"""Gradio demo for Echo-DiT: text-to-speech (TTS) and voice conversion (VC)."""

import gradio as gr
import torch

from inference import (
    load_model_from_hf,
    load_fish_ae_from_hf,
    load_pca_state_from_hf,
    load_audio,
    sample_pipeline,
    ae_encode,
    ae_decode,
    get_speaker_latent_and_mask,
)
from samplers import sample_euler_cfg_any, GuidanceMode

# =========================
# Global model load
# =========================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = load_model_from_hf(device=DEVICE)
fish_ae = load_fish_ae_from_hf(device=DEVICE)
pca_state = load_pca_state_from_hf(device=DEVICE)

# Output sample rate handed to Gradio (Hz).
# NOTE(review): assumed to match the fish_ae decoder's output rate — confirm.
OUTPUT_SAMPLE_RATE = 44100


# =========================
# Sampler wrapper
# =========================
def sampler(
    model,
    speaker_latent,
    speaker_mask,
    text_input_ids,
    text_mask,
    seed,
    *,
    force_speaker_kv: bool = False,
):
    """Run Euler CFG sampling with the demo's fixed hyper-parameters.

    Args:
        model: Diffusion model passed through to ``sample_euler_cfg_any``.
        speaker_latent: Speaker-conditioning latent.
        speaker_mask: Mask for the speaker latent.
        text_input_ids: Tokenized text conditioning.
        text_mask: Mask for the text tokens.
        seed: RNG seed for reproducible sampling.
        force_speaker_kv: When True, enables speaker-KV scaling
            (``speaker_k_scale=1.5``) for a stronger timbre lock.

    Returns:
        Whatever ``sample_euler_cfg_any`` returns (the sampled latent).
    """
    return sample_euler_cfg_any(
        model=model,
        speaker_latent=speaker_latent,
        speaker_mask=speaker_mask,
        text_input_ids=text_input_ids,
        text_mask=text_mask,
        rng_seed=seed,
        guidance_mode=GuidanceMode.JOINT,
        num_steps=30,
        cfg_scale_text=3.0,
        cfg_scale_speaker=None,
        cfg_min_t=0.0,
        cfg_max_t=1.0,
        truncation_factor=None,
        rescale_k=None,
        rescale_sigma=None,
        # Speaker-KV scaling only engages when explicitly requested via the UI.
        speaker_k_scale=1.5 if force_speaker_kv else None,
        speaker_k_min_t=0.6,
        speaker_k_max_layers=24,
        apg_eta_text=None,
        apg_eta_speaker=None,
        apg_momentum_text=None,
        apg_momentum_speaker=None,
        apg_norm_text=None,
        apg_norm_speaker=None,
    )


# =========================
# Voice Conversion
# =========================
@torch.inference_mode()
def voice_conversion_pipeline(src_audio, tgt_audio, seed, force_speaker_kv):
    """Convert ``src_audio`` to the timbre of ``tgt_audio``.

    Args:
        src_audio: Source waveform tensor (content to preserve).
        tgt_audio: Target waveform tensor (voice/timbre to copy).
        seed: RNG seed for the sampler.
        force_speaker_kv: Forwarded to :func:`sampler` for a stronger
            timbre lock.

    Returns:
        Decoded audio tensor from ``ae_decode``.
    """
    # Encode the source (content) into the AE latent space.
    # NOTE(review): z_src is never passed to the sampler, so the source
    # content latent does not reach generation — confirm this is intended.
    z_src = ae_encode(fish_ae, pca_state, src_audio.to(fish_ae.dtype))

    # Speaker conditioning comes from the target audio.
    # (Previously reached via sample_pipeline.__globals__; the name is a
    # module-level function of `inference`, so it is imported directly now.)
    speaker_latent, speaker_mask = get_speaker_latent_and_mask(
        fish_ae,
        pca_state,
        tgt_audio.to(fish_ae.dtype),
    )
    speaker_latent = speaker_latent.to(DEVICE)
    speaker_mask = speaker_mask.to(DEVICE)

    # VC uses no text conditioning: feed a single token id 0 (BOS) with a
    # one-element mask so the text branch is effectively inert.
    text_input_ids = torch.zeros((1, 1), dtype=torch.int32, device=DEVICE)
    text_mask = torch.ones((1, 1), dtype=torch.bool, device=DEVICE)

    z_out = sampler(
        model,
        speaker_latent,
        speaker_mask,
        text_input_ids,
        text_mask,
        seed,
        force_speaker_kv=force_speaker_kv,
    )
    return ae_decode(fish_ae, pca_state, z_out)


# =========================
# Gradio callback
# =========================
def generate_audio(
    mode,
    text_prompt,
    vc_source_audio,
    speaker_audio,
    seed,
    force_speaker_kv,
):
    """Dispatch a UI request to TTS or voice conversion.

    Args:
        mode: "Text to Speech" or "Voice Conversion" (radio value).
        text_prompt: Text to synthesize (TTS only).
        vc_source_audio: Filepath of the VC source clip (or None).
        speaker_audio: Filepath of the speaker/target clip (or None).
        seed: RNG seed (arrives as a number from gr.Number).
        force_speaker_kv: Checkbox value forwarded to the sampler.

    Returns:
        ``(sample_rate, waveform)`` as expected by ``gr.Audio(type="numpy")``.

    Raises:
        gr.Error: When required inputs for the chosen mode are missing.
    """
    seed = int(seed)

    if mode == "Voice Conversion":
        if vc_source_audio is None or speaker_audio is None:
            raise gr.Error("VC requires BOTH source and target audio")
        # NOTE(review): only the source gets unsqueeze(1) — presumably adding
        # a channel dim expected by ae_encode; verify against load_audio.
        src = load_audio(vc_source_audio).unsqueeze(1)
        tgt = load_audio(speaker_audio)
        audio = voice_conversion_pipeline(src, tgt, seed, force_speaker_kv)
    else:
        # Guard the TTS branch the same way VC guards its inputs.
        if not text_prompt or not text_prompt.strip():
            raise gr.Error("TTS requires a text prompt")
        speaker = load_audio(speaker_audio) if speaker_audio else None
        audio = sample_pipeline(
            model,
            fish_ae,
            pca_state,
            sampler,
            text_prompt,
            speaker,
            seed,
        )

    # gr.Audio(type="numpy") expects (sample_rate, data); the original
    # returned (data, sample_rate), which Gradio misinterprets.
    return OUTPUT_SAMPLE_RATE, audio[0].cpu().numpy()


# =========================
# UI
# =========================
with gr.Blocks() as demo:
    gr.Markdown("# 🗣️ Echo-DiT TTS / Voice Conversion")

    mode = gr.Radio(
        ["Text to Speech", "Voice Conversion"],
        value="Text to Speech",
        label="Mode",
    )
    text_prompt = gr.Textbox(
        label="Text Prompt (TTS only)",
        lines=4,
    )
    vc_source_audio = gr.Audio(
        label="Source Audio (VC)",
        type="filepath",
    )
    speaker_audio = gr.Audio(
        label="Speaker / Target Audio",
        type="filepath",
    )
    seed = gr.Number(
        value=0,
        precision=0,
        label="Seed",
    )
    force_speaker_kv = gr.Checkbox(
        label="Force Speaker KV (strong timbre lock)",
        value=False,
    )

    generate = gr.Button("Generate")
    output_audio = gr.Audio(label="Output", type="numpy")

    generate.click(
        generate_audio,
        inputs=[
            mode,
            text_prompt,
            vc_source_audio,
            speaker_audio,
            seed,
            force_speaker_kv,
        ],
        outputs=output_audio,
    )

# Guard the launch so importing this module doesn't start a server.
if __name__ == "__main__":
    demo.launch()