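"""Gradio demo: F5-TTS inference with and without SmoothCache.

The app lets you edit a per-step cache schedule for the attention and
feed-forward sublayers of the DiT backbone, then synthesizes the same text
twice (uncached vs. cached) so outputs and runtimes can be compared.
"""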
import os
import sys
import time

import numpy as np
import tomli
import torch
from cached_path import cached_path
from hydra.utils import get_class
from importlib.resources import files
from omegaconf import OmegaConf

sys.path.append('F5-TTS/src/f5_tts')
from infer.utils_infer import (
    cross_fade_duration,
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    speed,
)
from SmoothCache import SmoothCacheHelper

import gradio as gr
from PIL import Image, ImageDraw
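# On Hugging Face ZeroGPU Spaces, spaces.GPU schedules the decorated function
# on a GPU worker; when the `spaces` package is unavailable, gpu_decorator
# below is a no-op passthrough.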
try:
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    if USING_SPACES:
        return spaces.GPU(func)
    return func
# Constants
layer_names = ['attn', 'ff']
colors_rgb = [(255, 103, 35), (0, 210, 106)]  # orange (attn), green (ff)
cell_size = 20  # pixel size of one grid cell
spacing = 2  # pixel gap between cells
n_layers = 2
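# A cache schedule assigns each sampling step a 1 (compute the sublayer at
# this step) or a 0 (reuse the cached output from the last computed step),
# independently for the attention ('attn') and feed-forward ('ff') sublayers.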
# Presets
presets = {
    "32 NFE, α=0.15": {
        'attn': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1],
        'ff': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1],
    },
    "32 NFE, α=0.25": {
        'attn': [1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
        'ff': [1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
    },
    "16 NFE, α=0.3": {
        'attn': [1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1],
        'ff': [1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1],
    },
    "16 NFE, α=0.5": {
        'attn': [1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1],
        'ff': [1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1],
    },
}
default_preset = "32 NFE, α=0.15"
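# α is the SmoothCache error-tolerance threshold used when generating a
# schedule: a larger α allows more steps to be cached (more 0s), trading
# fidelity for speed.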
# Global state
cache_schedule = {
    'attn': presets[default_preset]['attn'][:],
    'ff': presets[default_preset]['ff'][:],
}
with open(os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"), "rb") as f:
    config = tomli.load(f)
model = config.get("model", "F5TTS_v1_Base")
ckpt_file = config.get("ckpt_file", "")
vocab_file = config.get("vocab_file", "")
model_cfg = OmegaConf.load(
    config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if not ckpt_file:
    ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
if not vocab_file:
    vocab_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/vocab.txt"))
ema_model = load_model(model_cls, model_arc, ckpt_file, vocab_file=vocab_file)
vocoder = load_vocoder()
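# The checkpoint and vocab are fetched from the Hugging Face Hub on first run;
# the model and vocoder are then loaded once at import time and shared by all
# requests.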
def render_grid(schedule: dict) -> np.ndarray:
    n_steps = len(schedule['attn'])
    img = Image.new("RGB", (n_steps * (cell_size + spacing), n_layers * (cell_size + spacing)), "white")
    draw = ImageDraw.Draw(img)
    for row in range(n_layers):
        layer = layer_names[row]
        for col in range(n_steps):
            x0 = col * (cell_size + spacing)
            y0 = row * (cell_size + spacing)
            x1 = x0 + cell_size
            y1 = y0 + cell_size
            color = colors_rgb[row] if schedule[layer][col] == 1 else "white"
            draw.rectangle([x0, y0, x1, y1], fill=color, outline="black")
    return np.array(img)
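# toggle_cell (defined below) inverts this geometry, dividing click
# coordinates by (cell_size + spacing) to recover which (step, layer) cell
# was hit.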
def apply_preset(preset_name):
    global cache_schedule
    if preset_name in presets:
        schedule = presets[preset_name]
        cache_schedule['attn'] = schedule['attn'][:]
        cache_schedule['ff'] = schedule['ff'][:]
    # "Custom" is a valid dropdown choice but not a preset; keep the current
    # schedule in that case.
    return render_grid(cache_schedule), len(cache_schedule['attn'])
def toggle_cell(evt: gr.SelectData):
    global cache_schedule
    col = evt.index[0] // (cell_size + spacing)
    row = evt.index[1] // (cell_size + spacing)
    if row < n_layers:
        layer = layer_names[row]
        if col < len(cache_schedule[layer]):
            cache_schedule[layer][col] ^= 1
    return render_grid(cache_schedule), "Custom"
def reset_schedule(n_steps):
    global cache_schedule
    cache_schedule = {
        'attn': [1] * n_steps,
        'ff': [1] * n_steps,
    }
    return render_grid(cache_schedule), "Custom"


def update_nfe(nfe_value):
    return reset_schedule(nfe_value)


def load_default():
    return render_grid(cache_schedule), default_preset
# NOTE: if infer() is ever memoized (e.g. with functools.lru_cache), all of
# its parameters must be hashable.
@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    nfe_step=32,
):
    global cache_schedule
    show_info = gr.Info
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update(), gr.update(), gr.update()
    # Seed the RNG with a fresh random seed for this request
    seed = np.random.randint(0, 2**31 - 1)
    torch.manual_seed(seed)
    if not gen_text.strip():
        gr.Warning("Please enter text to generate or upload a text file.")
        return gr.update(), gr.update(), gr.update(), gr.update()
    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
    start_time = time.time()
    # Baseline pass: full computation at every sampling step (no caching).
    final_wave, final_sample_rate, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )
    process_time = time.time() - start_time
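    # Cached pass: SmoothCacheHelper wraps the 'attn' and 'ff' sublayers of
    # every DiTBlock in the transformer; at steps where the schedule is 0, a
    # wrapped sublayer returns its cached output instead of recomputing.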
    cache_helper = SmoothCacheHelper(
        model=ema_model.transformer,
        block_classes=get_class("f5_tts.model.modules.DiTBlock"),
        components_to_wrap=['attn', 'ff'],
        schedule=cache_schedule,
    )
    cache_helper.enable()
    start_time = time.time()
    final_wave_cache, final_sample_rate_cache, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )
    process_time_cache = time.time() - start_time
    cache_helper.disable()
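    # At this point disable() has removed the wrappers, restoring the original
    # sublayer forwards for subsequent requests.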
    return (
        (final_sample_rate, final_wave),
        (final_sample_rate_cache, final_wave_cache),
        f"{process_time:.2f} s",
        f"{process_time_cache:.2f} s",
    )
with gr.Blocks() as demo:
    gr.Markdown("## 🎛️ Cache Schedule Editor")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    ref_text_input = gr.Textbox(
        label="Reference Text",
        info="Leave blank to automatically transcribe the reference audio. Entering text here overrides automatic transcription.",
    )
    gen_text_input = gr.Textbox(label="Text to Generate")
    with gr.Row():
        with gr.Column(scale=0):
            preset_dropdown = gr.Dropdown(
                choices=list(presets.keys()) + ["Custom"],
                label="Choose Preset",
                value=default_preset,
            )
            nfe_slider = gr.Slider(4, 64, value=len(cache_schedule['attn']), step=1, label="Number of Steps (NFE)")
        with gr.Column(scale=1):
            gr.Markdown(
                "Click Grid to Customize Cache Schedule\n"
                "🟧 = Compute Attn Layer\n"
                "🟩 = Compute FFN Layer\n"
                "⬜ = Cached Layer"
            )
            image = gr.Image(type="numpy", interactive=True, scale=1)
    generate_btn = gr.Button("Synthesize", variant="primary")
    with gr.Row():
        with gr.Group():
            audio_output = gr.Audio(label="Synthesized Audio (No Cache)")
            process_time = gr.Textbox(label="⏱ Process Time", interactive=False)
        with gr.Group():
            audio_output_cache = gr.Audio(label="Synthesized Audio (Cache)")
            process_time_cache = gr.Textbox(label="⏱ Process Time", interactive=False)
    # Wire up logic. Use .input rather than .change so programmatic updates
    # from other callbacks (e.g. a preset setting the NFE slider) do not
    # retrigger these handlers and wipe the chosen schedule.
    preset_dropdown.input(fn=apply_preset, inputs=preset_dropdown, outputs=[image, nfe_slider])
    image.select(fn=toggle_cell, outputs=[image, preset_dropdown])
    nfe_slider.input(fn=update_nfe, inputs=nfe_slider, outputs=[image, preset_dropdown])
    generate_btn.click(
        infer,
        inputs=[
            ref_audio_input,
            ref_text_input,
            gen_text_input,
            nfe_slider,
        ],
        outputs=[audio_output, audio_output_cache, process_time, process_time_cache],
    )
    demo.load(fn=load_default, outputs=[image, preset_dropdown])

demo.launch()