# app.py — Echo-DiT TTS / Voice Conversion demo
# (HF Space by GaboChoropan; commit c9d5eb3, verified)
import gradio as gr
import torch
from inference import (
load_model_from_hf,
load_fish_ae_from_hf,
load_pca_state_from_hf,
load_audio,
sample_pipeline,
ae_encode,
ae_decode,
)
from samplers import sample_euler_cfg_any, GuidanceMode
# =========================
# Global model load
# =========================
# Models are loaded once at import time so every Gradio request reuses them.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
model = load_model_from_hf(device=DEVICE)  # diffusion transformer checkpoint
fish_ae = load_fish_ae_from_hf(device=DEVICE)  # audio autoencoder (waveform <-> latent)
pca_state = load_pca_state_from_hf(device=DEVICE)  # NOTE(review): presumably PCA stats for latent (de)whitening — confirm in `inference`
# =========================
# Sampler wrapper
# =========================
def sampler(
    model,
    speaker_latent,
    speaker_mask,
    text_input_ids,
    text_mask,
    seed,
    *,
    force_speaker_kv: bool = False,
):
    """Run Euler CFG sampling with this app's fixed hyper-parameters.

    Thin wrapper around ``sample_euler_cfg_any`` that pins the guidance
    mode, step count, and CFG scales used by both TTS and VC.

    Args:
        model: The loaded diffusion model.
        speaker_latent / speaker_mask: Speaker conditioning tensors.
        text_input_ids / text_mask: Text conditioning tensors.
        seed: RNG seed forwarded as ``rng_seed``.
        force_speaker_kv: When True, scales the speaker key values
            (``speaker_k_scale=1.5``) for a stronger timbre lock.
    """
    # Fixed sampling configuration shared by every call from this app.
    fixed_cfg = {
        "guidance_mode": GuidanceMode.JOINT,
        "num_steps": 30,
        "cfg_scale_text": 3.0,
        "cfg_scale_speaker": None,
        "cfg_min_t": 0.0,
        "cfg_max_t": 1.0,
        "truncation_factor": None,
        "rescale_k": None,
        "rescale_sigma": None,
        # Speaker-KV scaling is the only toggleable knob; the min-t /
        # max-layers bounds are always passed and only take effect when
        # the scale is non-None (presumably — confirm in `samplers`).
        "speaker_k_scale": 1.5 if force_speaker_kv else None,
        "speaker_k_min_t": 0.6,
        "speaker_k_max_layers": 24,
        # Adaptive projected guidance fully disabled.
        "apg_eta_text": None,
        "apg_eta_speaker": None,
        "apg_momentum_text": None,
        "apg_momentum_speaker": None,
        "apg_norm_text": None,
        "apg_norm_speaker": None,
    }
    return sample_euler_cfg_any(
        model=model,
        speaker_latent=speaker_latent,
        speaker_mask=speaker_mask,
        text_input_ids=text_input_ids,
        text_mask=text_mask,
        rng_seed=seed,
        **fixed_cfg,
    )
# =========================
# Voice Conversion
# =========================
@torch.inference_mode()
def voice_conversion_pipeline(
    src_audio,
    tgt_audio,
    seed,
    force_speaker_kv,
):
    """Resynthesize *src_audio* content with the speaker timbre of *tgt_audio*.

    Returns the decoded waveform tensor produced by ``ae_decode``.
    """
    # Encode the source utterance into the content latent space.
    # NOTE(review): z_src is never consumed below — the sampler starts from
    # noise conditioned only on the speaker latent. Confirm whether source
    # content should be injected; otherwise this encode is dead work.
    z_src = ae_encode(fish_ae, pca_state, src_audio.to(fish_ae.dtype))
    # Speaker conditioning comes from the target utterance.
    # NOTE(review): helper is reached through sample_pipeline.__globals__
    # because it is not re-exported by `inference` — fragile; consider
    # exporting get_speaker_latent_and_mask directly.
    get_speaker = sample_pipeline.__globals__["get_speaker_latent_and_mask"]
    speaker_latent, speaker_mask = get_speaker(
        fish_ae,
        pca_state,
        tgt_audio.to(fish_ae.dtype),
    )
    speaker_latent = speaker_latent.to(DEVICE)
    speaker_mask = speaker_mask.to(DEVICE)
    # No text conditioning: a single zero token (BOS) with an all-true mask.
    text_input_ids = torch.zeros((1, 1), dtype=torch.int32, device=DEVICE)
    text_mask = torch.ones((1, 1), dtype=torch.bool, device=DEVICE)
    # Sample a latent under speaker conditioning, then decode to audio.
    z_out = sampler(
        model,
        speaker_latent,
        speaker_mask,
        text_input_ids,
        text_mask,
        seed,
        force_speaker_kv=force_speaker_kv,
    )
    return ae_decode(fish_ae, pca_state, z_out)
# =========================
# Gradio callback
# =========================
def generate_audio(
    mode,
    text_prompt,
    vc_source_audio,
    speaker_audio,
    seed,
    force_speaker_kv,
):
    """Dispatch a generation request to TTS or voice conversion.

    Args:
        mode: "Text to Speech" or "Voice Conversion" (from the Radio).
        text_prompt: Text to synthesize (TTS mode only).
        vc_source_audio: Filepath of the content source clip (VC mode).
        speaker_audio: Filepath of the speaker/target clip (optional in TTS).
        seed: RNG seed (arrives as a float from gr.Number).
        force_speaker_kv: Checkbox toggling the strong timbre lock.

    Returns:
        A ``(sample_rate, waveform)`` tuple, the format ``gr.Audio(type="numpy")``
        expects for outputs.

    Raises:
        gr.Error: When required inputs for the selected mode are missing.
    """
    seed = int(seed)  # gr.Number delivers a float even with precision=0
    if mode == "Voice Conversion":
        if vc_source_audio is None or speaker_audio is None:
            raise gr.Error("VC requires BOTH source and target audio")
        src = load_audio(vc_source_audio).unsqueeze(1)  # add channel dim — TODO confirm expected layout
        tgt = load_audio(speaker_audio)
        audio = voice_conversion_pipeline(
            src,
            tgt,
            seed,
            force_speaker_kv,
        )
    else:
        # Robustness: fail with a clear UI error instead of synthesizing silence.
        if not text_prompt:
            raise gr.Error("TTS requires a non-empty text prompt")
        speaker = load_audio(speaker_audio) if speaker_audio else None
        audio = sample_pipeline(
            model,
            fish_ae,
            pca_state,
            sampler,
            text_prompt,
            speaker,
            seed,
        )
    # BUGFIX: gr.Audio(type="numpy") expects (sample_rate, data); the original
    # returned (data, 44100), which Gradio cannot interpret as audio.
    return 44100, audio[0].cpu().numpy()
# =========================
# UI
# =========================
# Single-page Blocks layout: mode selector, shared inputs, one output player.
with gr.Blocks() as demo:
    gr.Markdown("# 🗣️ Echo-DiT TTS / Voice Conversion")
    # Switches the callback between the TTS and VC code paths.
    mode = gr.Radio(
        ["Text to Speech", "Voice Conversion"],
        value="Text to Speech",
        label="Mode",
    )
    # Used only in TTS mode; ignored by voice conversion.
    text_prompt = gr.Textbox(
        label="Text Prompt (TTS only)",
        lines=4,
    )
    # Content source clip; required only in VC mode.
    vc_source_audio = gr.Audio(
        label="Source Audio (VC)",
        type="filepath",
    )
    # Speaker reference: optional in TTS, required in VC.
    speaker_audio = gr.Audio(
        label="Speaker / Target Audio",
        type="filepath",
    )
    # precision=0 renders an integer field, but the value still arrives as float.
    seed = gr.Number(
        value=0,
        precision=0,
        label="Seed",
    )
    # Maps to sampler(force_speaker_kv=...) — enables speaker-KV scaling.
    force_speaker_kv = gr.Checkbox(
        label="Force Speaker KV (strong timbre lock)",
        value=False,
    )
    generate = gr.Button("Generate")
    # type="numpy" — the callback must return a (sample_rate, ndarray) tuple.
    output_audio = gr.Audio(label="Output", type="numpy")
    # Input order must match generate_audio's parameter order.
    generate.click(
        generate_audio,
        inputs=[
            mode,
            text_prompt,
            vc_source_audio,
            speaker_audio,
            seed,
            force_speaker_kv,
        ],
        outputs=output_audio,
    )
demo.launch()