# Gradio demo: F5-TTS inference with an editable SmoothCache schedule.
# Each request is synthesized twice -- once without caching, once with the
# current attn/ff cache schedule -- so the processing times can be compared.

import os
import sys
import time
from functools import lru_cache
from importlib.resources import files

import gradio as gr
import numpy as np
import tomli
import torch
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from PIL import Image, ImageDraw

sys.path.append("F5-TTS/src/f5_tts")
from infer.utils_infer import (
    cross_fade_duration,
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    speed,
)
from SmoothCache import SmoothCacheHelper

try:
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    if USING_SPACES:
        return spaces.GPU(func)
    return func


# Constants
layer_names = ["attn", "ff"]
colors_rgb = [(255, 103, 35), (0, 210, 106)]  # orange (attn), green (ff)
cell_size = 20
spacing = 2
n_layers = 2

# Presets: 1 = compute the layer at that step, 0 = reuse the cached output.
presets = {
    "32 NFE, α=0.15": {
        "attn": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1],
        "ff": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1],
    },
    "32 NFE, α=0.25": {
        "attn": [1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
        "ff": [1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
    },
    "16 NFE, α=0.3": {
        "attn": [1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1],
        "ff": [1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1],
    },
    "16 NFE, α=0.5": {
        "attn": [1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1],
        "ff": [1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1],
    },
}
default_preset = "32 NFE, α=0.15"

# Global state
cache_schedule = {
    "attn": presets[default_preset]["attn"][:],
    "ff": presets[default_preset]["ff"][:],
}

# Model setup, mirroring the F5-TTS basic inference example.
with open(os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"), "rb") as f:
    config = tomli.load(f)
model = config.get("model", "F5TTS_v1_Base")
ckpt_file = config.get("ckpt_file", "")
vocab_file = config.get("vocab_file", "")
model_cfg = OmegaConf.load(
    config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if not ckpt_file:
    ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
if not vocab_file:
    vocab_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/vocab.txt"))

ema_model = load_model(model_cls, model_arc, ckpt_file, vocab_file=vocab_file)
vocoder = load_vocoder()


@gpu_decorator
def render_grid(schedule: dict) -> np.ndarray:
    """Render the cache schedule as an image: one row per layer, one column per step."""
    n_steps = len(schedule["attn"])
    img = Image.new(
        "RGB",
        (n_steps * (cell_size + spacing), n_layers * (cell_size + spacing)),
        "white",
    )
    draw = ImageDraw.Draw(img)
    for row in range(n_layers):
        layer = layer_names[row]
        for col in range(n_steps):
            x0 = col * (cell_size + spacing)
            y0 = row * (cell_size + spacing)
            x1 = x0 + cell_size
            y1 = y0 + cell_size
            color = colors_rgb[row] if schedule[layer][col] == 1 else "white"
            draw.rectangle([x0, y0, x1, y1], fill=color, outline="black")
    return np.array(img)


@gpu_decorator
def apply_preset(preset_name):
    global cache_schedule
    if preset_name in presets:
        schedule = presets[preset_name]
        cache_schedule["attn"] = schedule["attn"][:]
        cache_schedule["ff"] = schedule["ff"][:]
    # For "Custom" (not in presets), fall through and re-render the current schedule.
    return render_grid(cache_schedule), len(cache_schedule["attn"])


@gpu_decorator
def toggle_cell(evt: gr.SelectData):
    """Flip one compute/cache cell when the grid image is clicked."""
    global cache_schedule
    col = evt.index[0] // (cell_size + spacing)
    row = evt.index[1] // (cell_size + spacing)
    layer = layer_names[row]
    if col < len(cache_schedule[layer]):
        cache_schedule[layer][col] ^= 1
    return render_grid(cache_schedule), "Custom"


@gpu_decorator
def reset_schedule(n_steps):
    global cache_schedule
    cache_schedule = {"attn": [1] * n_steps, "ff": [1] * n_steps}
    return render_grid(cache_schedule), "Custom"


@gpu_decorator
def update_nfe(nfe_value):
    return reset_schedule(nfe_value)


@gpu_decorator
def load_default():
    return render_grid(cache_schedule), default_preset


# NOTE: lru_cache requires all infer() arguments to be hashable. It also keys
# only on those arguments, so cached results do not reflect later edits to the
# global cache_schedule. cross_fade_duration and speed come from the module-level
# defaults imported from infer.utils_infer.
@lru_cache(maxsize=1000)
@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    nfe_step=32,
):
    global cache_schedule
    show_info = gr.Info

    # Early returns must match the four outputs wired to this callback.
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update(), gr.update(), gr.update()
    if not gen_text.strip():
        gr.Warning("Please enter text to generate or upload a text file.")
        return gr.update(), gr.update(), gr.update(), gr.update()

    # Fresh random inference seed on every (uncached) call.
    seed = np.random.randint(0, 2**31 - 1)
    torch.manual_seed(seed)

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    # Pass 1: baseline synthesis without caching.
    start_time = time.time()
    final_wave, final_sample_rate, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )
    process_time = time.time() - start_time

    # Pass 2: wrap every DiTBlock's attn/ff with SmoothCache and rerun.
    cache_helper = SmoothCacheHelper(
        model=ema_model.transformer,
        block_classes=get_class("f5_tts.model.modules.DiTBlock"),
        components_to_wrap=["attn", "ff"],
        schedule=cache_schedule,
    )
    cache_helper.enable()
    start_time = time.time()
    final_wave_cache, final_sample_rate_cache, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )
    process_time_cache = time.time() - start_time
    cache_helper.disable()

    # (Silence removal and spectrogram saving from the original F5-TTS app are omitted here.)
    return (
        (final_sample_rate, final_wave),
        (final_sample_rate_cache, final_wave_cache),
        process_time,
        process_time_cache,
    )


with gr.Blocks() as demo:
    gr.Markdown("## 🎛️ Cache Schedule Editor")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    ref_text_input = gr.Textbox(
        label="Reference Text",
        info="Leave blank to automatically transcribe the reference audio. "
        "If you enter text or upload a file, it will override automatic transcription.",
    )
    gen_text_input = gr.Textbox(label="Text to Generate")

    with gr.Row():
        with gr.Column(scale=0):
            preset_dropdown = gr.Dropdown(
                choices=list(presets.keys()) + ["Custom"],
                label="Choose Preset",
                value=default_preset,
            )
            nfe_slider = gr.Slider(
                4, 64, value=len(cache_schedule["attn"]), step=1, label="Number of Steps (NFE)"
            )
        with gr.Column(scale=1):
            gr.Markdown(
                "Click the grid to customize the cache schedule  \n"
                "🟧 = compute attn layer  \n"
                "🟩 = compute FFN layer  \n"
                "⬜ = cached layer"
            )
            image = gr.Image(type="numpy", interactive=True, scale=1)

    generate_btn = gr.Button("Synthesize", variant="primary")

    with gr.Row():
        with gr.Group():
            audio_output = gr.Audio(label="Synthesized Audio (No Cache)")
            process_time = gr.Textbox(label="⏱ Process Time", interactive=False)
        with gr.Group():
            audio_output_cache = gr.Audio(label="Synthesized Audio (Cache)")
            process_time_cache = gr.Textbox(label="⏱ Process Time", interactive=False)

    # Wire up logic.
    preset_dropdown.change(fn=apply_preset, inputs=preset_dropdown, outputs=[image, nfe_slider])
    image.select(fn=toggle_cell, outputs=[image, preset_dropdown])
    # Use .input (fires only on user interaction) so the programmatic slider
    # update emitted by apply_preset does not immediately reset the schedule it loaded.
    nfe_slider.input(fn=update_nfe, inputs=nfe_slider, outputs=[image, preset_dropdown])
    generate_btn.click(
        infer,
        inputs=[ref_audio_input, ref_text_input, gen_text_input, nfe_slider],
        outputs=[audio_output, audio_output_cache, process_time, process_time_cache],
    )
    demo.load(fn=load_default, outputs=[image, preset_dropdown])

demo.launch()
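
# Usage sketch (assumptions: the F5-TTS repo is cloned next to this script so
# the sys.path.append above resolves, the f5_tts package and SmoothCache are
# importable, and the filename below is hypothetical):
#
#   python cache_schedule_demo.py
#
# Gradio prints a local URL; pick a preset or click grid cells (white cells
# reuse the cached attn/ff output at that step), then press Synthesize to
# compare the uncached and cached processing times.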