Hugging Face Space app (page scrape header; Space status: Sleeping)
| import gradio as gr | |
| import torch | |
| from inference import ( | |
| load_model_from_hf, | |
| load_fish_ae_from_hf, | |
| load_pca_state_from_hf, | |
| load_audio, | |
| sample_pipeline, | |
| ae_encode, | |
| ae_decode, | |
| ) | |
| from samplers import sample_euler_cfg_any, GuidanceMode | |
# =========================
# Global model load
# =========================
# Prefer GPU when available; all latents/inputs are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Loaded once at import time and shared by every Gradio request.
# NOTE(review): assumed to download weights from the Hugging Face Hub on
# first run — confirm caching behavior in `inference`.
model = load_model_from_hf(device=DEVICE)  # main generative model
fish_ae = load_fish_ae_from_hf(device=DEVICE)  # audio autoencoder (encode/decode waveforms)
pca_state = load_pca_state_from_hf(device=DEVICE)  # PCA state applied around the AE latent
# =========================
# Sampler wrapper
# =========================
def sampler(
    model,
    speaker_latent,
    speaker_mask,
    text_input_ids,
    text_mask,
    seed,
    *,
    force_speaker_kv: bool = False,
):
    """Run Euler CFG sampling with this demo's fixed hyper-parameters.

    Only the conditioning tensors and the seed vary per request; every
    guidance knob is pinned here. ``force_speaker_kv`` toggles speaker
    key scaling (a stronger timbre lock) — when off, the scale is None
    and the feature is disabled.
    """
    # Speaker-KV scaling is only active when explicitly requested.
    k_scale = 1.5 if force_speaker_kv else None

    # All fixed guidance settings for this Space, gathered in one place.
    guidance_kwargs = dict(
        guidance_mode=GuidanceMode.JOINT,
        num_steps=30,
        cfg_scale_text=3.0,
        cfg_scale_speaker=None,
        cfg_min_t=0.0,
        cfg_max_t=1.0,
        truncation_factor=None,
        rescale_k=None,
        rescale_sigma=None,
        speaker_k_scale=k_scale,
        speaker_k_min_t=0.6,
        speaker_k_max_layers=24,
        apg_eta_text=None,
        apg_eta_speaker=None,
        apg_momentum_text=None,
        apg_momentum_speaker=None,
        apg_norm_text=None,
        apg_norm_speaker=None,
    )

    return sample_euler_cfg_any(
        model=model,
        speaker_latent=speaker_latent,
        speaker_mask=speaker_mask,
        text_input_ids=text_input_ids,
        text_mask=text_mask,
        rng_seed=seed,
        **guidance_kwargs,
    )
# =========================
# Voice Conversion
# =========================
def voice_conversion_pipeline(
    src_audio,
    tgt_audio,
    seed,
    force_speaker_kv,
):
    """Re-synthesize ``src_audio`` in the voice of ``tgt_audio``.

    Encodes both clips with the Fish AE, samples a new latent conditioned
    on the target speaker (text conditioning is a lone BOS token), and
    decodes the result back to a waveform tensor.
    """
    # Content latent from the source clip.
    # NOTE(review): z_src is computed but never consumed below — suspected
    # bug (should the sampler receive it as content conditioning?). Kept
    # as-is to preserve behavior; confirm against the sampler's interface.
    z_src = ae_encode(fish_ae, pca_state, src_audio.to(fish_ae.dtype))

    # HACK: the helper that builds the speaker latent/mask is not exported
    # by `inference`, so it is fetched from sample_pipeline's module globals.
    speaker_fn = sample_pipeline.__globals__[
        "get_speaker_latent_and_mask"
    ]
    speaker_latent, speaker_mask = speaker_fn(
        fish_ae,
        pca_state,
        tgt_audio.to(fish_ae.dtype),
    )
    speaker_latent = speaker_latent.to(DEVICE)
    speaker_mask = speaker_mask.to(DEVICE)

    # Voice conversion carries no text prompt: condition on a single BOS token.
    text_input_ids = torch.zeros((1, 1), dtype=torch.int32, device=DEVICE)
    text_mask = torch.ones((1, 1), dtype=torch.bool, device=DEVICE)

    z_out = sampler(
        model,
        speaker_latent,
        speaker_mask,
        text_input_ids,
        text_mask,
        seed,
        force_speaker_kv=force_speaker_kv,
    )

    # Decode the sampled latent back to audio.
    return ae_decode(fish_ae, pca_state, z_out)
# =========================
# Gradio callback
# =========================
def generate_audio(
    mode,
    text_prompt,
    vc_source_audio,
    speaker_audio,
    seed,
    force_speaker_kv,
):
    """Gradio click handler: dispatch to TTS or voice conversion.

    Returns a ``(sample_rate, waveform)`` tuple in the order that
    ``gr.Audio(type="numpy")`` expects.

    Raises:
        gr.Error: if required inputs for the selected mode are missing.
    """
    # gr.Number delivers floats even with precision=0; samplers want an int.
    seed = int(seed)
    if mode == "Voice Conversion":
        if vc_source_audio is None or speaker_audio is None:
            raise gr.Error("VC requires BOTH source and target audio")
        # Source gets a channel dim; target is used only for speaker identity.
        src = load_audio(vc_source_audio).unsqueeze(1)
        tgt = load_audio(speaker_audio)
        audio = voice_conversion_pipeline(src, tgt, seed, force_speaker_kv)
    else:
        # TTS path: the prompt is mandatory, the speaker clip is optional.
        if not text_prompt or not text_prompt.strip():
            raise gr.Error("Text to Speech requires a non-empty text prompt")
        speaker = load_audio(speaker_audio) if speaker_audio else None
        audio = sample_pipeline(
            model,
            fish_ae,
            pca_state,
            sampler,
            text_prompt,
            speaker,
            seed,
        )
    # BUG FIX: gr.Audio(type="numpy") expects (sample_rate, data); the
    # original returned (data, 44100), which Gradio misreads.
    return 44100, audio[0].cpu().numpy()
# =========================
# UI
# =========================
with gr.Blocks() as demo:
    gr.Markdown("# 🗣️ Echo-DiT TTS / Voice Conversion")

    # Input widgets, declared in the order generate_audio expects them.
    mode = gr.Radio(["Text to Speech", "Voice Conversion"], value="Text to Speech", label="Mode")
    text_prompt = gr.Textbox(label="Text Prompt (TTS only)", lines=4)
    vc_source_audio = gr.Audio(label="Source Audio (VC)", type="filepath")
    speaker_audio = gr.Audio(label="Speaker / Target Audio", type="filepath")
    seed = gr.Number(value=0, precision=0, label="Seed")
    force_speaker_kv = gr.Checkbox(label="Force Speaker KV (strong timbre lock)", value=False)

    generate = gr.Button("Generate")
    output_audio = gr.Audio(label="Output", type="numpy")

    # Wire the button; the inputs list must match generate_audio's signature.
    generate.click(
        generate_audio,
        inputs=[mode, text_prompt, vc_source_audio, speaker_audio, seed, force_speaker_kv],
        outputs=output_audio,
    )

demo.launch()