"""Gradio demo for Dia2 text-to-speech: multi-turn dialogue scripting with
optional per-speaker voice prompts and adjustable sampling controls."""

from __future__ import annotations

import contextlib
import io
import os
from pathlib import Path
from typing import List

import gradio as gr
import torch
import spaces

from dia2 import Dia2, GenerationConfig, SamplingConfig

DEFAULT_REPO = os.environ.get("DIA2_DEFAULT_REPO", "nari-labs/Dia2-2B")
MAX_TURNS = 10
INITIAL_TURNS = 2

# Lazily initialised model handle, shared across requests.
_dia: Dia2 | None = None


def _get_dia() -> Dia2:
    global _dia
    if _dia is None:
        _dia = Dia2.from_repo(DEFAULT_REPO, device="cuda", dtype="bfloat16")
    return _dia


def _concat_script(turn_count: int, turn_values: List[str]) -> str:
    """Join non-empty turns into a [S1]/[S2]-tagged script, alternating speakers."""
    lines: List[str] = []
    for idx in range(min(turn_count, len(turn_values))):
        text = (turn_values[idx] or "").strip()
        if not text:
            continue
        speaker = "[S1]" if idx % 2 == 0 else "[S2]"
        lines.append(f"{speaker} {text}")
    return "\n".join(lines)


EXAMPLES: dict[str, dict[str, List[str] | str | None]] = {
    "Intro": {
        "turns": [
            "Hello Dia2 fans! Today we're unveiling the new open TTS model.",
            "Sounds exciting. Can you show a sample right now?",
            "Absolutely. (laughs) Just press generate.",
        ],
        "voice_s1": "example_prefix1.wav",
        "voice_s2": "example_prefix2.wav",
    },
    "Customer Support": {
        "turns": [
            "Thanks for calling. How can I help you today?",
            "My parcel never arrived and it's been two weeks.",
            "I'm sorry about that. Let me check your tracking number.",
            "Appreciate it. I really need that package soon.",
        ],
        "voice_s1": "example_prefix1.wav",
        "voice_s2": "example_prefix2.wav",
    },
}


def _apply_turn_visibility(count: int) -> List[gr.Update]:
    """Show the first `count` turn textboxes and hide the rest."""
    return [gr.update(visible=i < count) for i in range(MAX_TURNS)]


def _add_turn(count: int):
    count = min(count + 1, MAX_TURNS)
    return (count, *_apply_turn_visibility(count))


def _remove_turn(count: int):
    count = max(1, count - 1)
    return (count, *_apply_turn_visibility(count))


def _load_example(name: str, count: int):
    """Fill the turn textboxes and voice prompts from a named example."""
    data = EXAMPLES.get(name)
    if not data:
        return (count, *_apply_turn_visibility(count), None, None)
    turns = data.get("turns", [])
    voice_s1_path = data.get("voice_s1")
    voice_s2_path = data.get("voice_s2")
    new_count = min(len(turns), MAX_TURNS)
    updates: List[gr.Update] = []
    for idx in range(MAX_TURNS):
        if idx < new_count:
            updates.append(gr.update(value=turns[idx], visible=True))
        else:
            updates.append(gr.update(value="", visible=idx < INITIAL_TURNS))
    return (new_count, *updates, voice_s1_path, voice_s2_path)


def _prepare_prefix(file_path: str | None) -> str | None:
    """Return a validated filesystem path for a voice-prompt file, or None."""
    if not file_path:
        return None
    path = Path(file_path)
    if not path.exists():
        return None
    return str(path)


@spaces.GPU(duration=100)
def generate_audio(
    turn_count: int,
    *inputs,
):
    # Positional inputs arrive in the same order as the `inputs` list wired
    # to generate_btn.click below: MAX_TURNS textboxes, two voice files,
    # then the sampling controls.
    turn_values = list(inputs[:MAX_TURNS])
    voice_s1 = inputs[MAX_TURNS]
    voice_s2 = inputs[MAX_TURNS + 1]
    cfg_scale = float(inputs[MAX_TURNS + 2])
    text_temperature = float(inputs[MAX_TURNS + 3])
    audio_temperature = float(inputs[MAX_TURNS + 4])
    text_top_k = int(inputs[MAX_TURNS + 5])
    audio_top_k = int(inputs[MAX_TURNS + 6])
    include_prefix = bool(inputs[MAX_TURNS + 7])

    script = _concat_script(turn_count, turn_values)
    if not script.strip():
        raise gr.Error("Please enter at least one non-empty speaker turn.")

    dia = _get_dia()
    config = GenerationConfig(
        cfg_scale=cfg_scale,
        text=SamplingConfig(temperature=text_temperature, top_k=text_top_k),
        audio=SamplingConfig(temperature=audio_temperature, top_k=audio_top_k),
        use_cuda_graph=True,
    )
    kwargs = {
        "prefix_speaker_1": _prepare_prefix(voice_s1),
        "prefix_speaker_2": _prepare_prefix(voice_s2),
        "include_prefix": include_prefix,
    }
    # Capture the model's verbose stdout so it can be shown in the Logs box.
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        result = dia.generate(
            script,
            config=config,
            output_wav=None,
            verbose=True,
            **kwargs,
        )
    waveform = result.waveform.detach().cpu().numpy()
    sample_rate = result.sample_rate
    timestamps = result.timestamps
    log_text = buffer.getvalue().strip()
    table = [[w, round(t, 3)] for w, t in timestamps]
    return (sample_rate, waveform), table, log_text or "Generation finished."


def build_interface() -> gr.Blocks:
    with gr.Blocks(
        title="Dia2 TTS", css=".compact-turn textarea {min-height: 60px}"
    ) as demo:
        gr.Markdown(
            """## Dia2 — Open TTS Model
Compose dialogue, attach optional voice prompts, and generate audio (CUDA graphs enabled by default)."""
        )
        turn_state = gr.State(INITIAL_TURNS)
        with gr.Row(equal_height=True):
            example_dropdown = gr.Dropdown(
                choices=["(select example)"] + list(EXAMPLES.keys()),
                label="Examples",
                value="(select example)",
            )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Script")
                    controls = []
                    for idx in range(MAX_TURNS):
                        speaker = "[S1]" if idx % 2 == 0 else "[S2]"
                        box = gr.Textbox(
                            label=f"{speaker} turn {idx + 1}",
                            lines=2,
                            elem_classes=["compact-turn"],
                            placeholder=f"Enter dialogue for {speaker}…",
                            visible=idx < INITIAL_TURNS,
                        )
                        controls.append(box)
                    with gr.Row():
                        add_btn = gr.Button("Add Turn")
                        remove_btn = gr.Button("Remove Turn")
                with gr.Group():
                    gr.Markdown("### Voice Prompts")
                    with gr.Row():
                        voice_s1 = gr.File(
                            label="[S1] voice (wav/mp3)", type="filepath"
                        )
                        voice_s2 = gr.File(
                            label="[S2] voice (wav/mp3)", type="filepath"
                        )
                with gr.Group():
                    gr.Markdown("### Sampling")
                    cfg_scale = gr.Slider(
                        1.0, 8.0, value=6.0, step=0.1, label="CFG Scale"
                    )
                    with gr.Group():
                        gr.Markdown("#### Text Sampling")
                        text_temperature = gr.Slider(
                            0.1, 1.5, value=0.6, step=0.05, label="Text Temperature"
                        )
                        text_top_k = gr.Slider(
                            1, 200, value=50, step=1, label="Text Top-K"
                        )
                    with gr.Group():
                        gr.Markdown("#### Audio Sampling")
                        audio_temperature = gr.Slider(
                            0.1, 1.5, value=0.8, step=0.05, label="Audio Temperature"
                        )
                        audio_top_k = gr.Slider(
                            1, 200, value=50, step=1, label="Audio Top-K"
                        )
                    include_prefix = gr.Checkbox(
                        label="Keep prefix audio in output", value=False
                    )
                generate_btn = gr.Button("Generate", variant="primary")
            with gr.Column(scale=1):
                gr.Markdown("### Output")
                audio_out = gr.Audio(label="Waveform", interactive=False)
                timestamps = gr.Dataframe(
                    headers=["word", "seconds"], label="Timestamps"
                )
                log_box = gr.Textbox(label="Logs", lines=8)

        add_btn.click(
            _add_turn,
            inputs=turn_state,
            outputs=[turn_state, *controls],
        )
        remove_btn.click(
            _remove_turn,
            inputs=turn_state,
            outputs=[turn_state, *controls],
        )
        example_dropdown.change(
            _load_example,
            inputs=[example_dropdown, turn_state],
            outputs=[turn_state, *controls, voice_s1, voice_s2],
        )
        generate_btn.click(
            generate_audio,
            inputs=[
                turn_state,
                *controls,
                voice_s1,
                voice_s2,
                cfg_scale,
                text_temperature,
                audio_temperature,
                text_top_k,
                audio_top_k,
                include_prefix,
            ],
            outputs=[audio_out, timestamps, log_box],
        )
    return demo


if __name__ == "__main__":
    app = build_interface()
    app.queue(default_concurrency_limit=1)
    app.launch(share=True)
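
# ---------------------------------------------------------------------------
# Standalone usage sketch (no Gradio), mirroring the calls generate_audio()
# makes above. The API surface shown is exactly what this app already uses
# (Dia2.from_repo / GenerationConfig / SamplingConfig / Dia2.generate); the
# script text and the "out.wav" path are illustrative, and passing a path
# via output_wav is an assumption inferred from the output_wav=None call
# above, not a documented guarantee.
#
#     from dia2 import Dia2, GenerationConfig, SamplingConfig
#
#     dia = Dia2.from_repo("nari-labs/Dia2-2B", device="cuda", dtype="bfloat16")
#     config = GenerationConfig(
#         cfg_scale=6.0,
#         text=SamplingConfig(temperature=0.6, top_k=50),
#         audio=SamplingConfig(temperature=0.8, top_k=50),
#         use_cuda_graph=True,
#     )
#     result = dia.generate(
#         "[S1] Hello there.\n[S2] Hi! Ready when you are.",
#         config=config,
#         output_wav="out.wav",  # assumed: a path here writes the wav to disk
#         verbose=True,
#     )
#     print(result.sample_rate, result.waveform.shape, result.timestamps[:5])
# ---------------------------------------------------------------------------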