from __future__ import annotations

import contextlib
import io
import os
from pathlib import Path
from typing import List, Tuple

import gradio as gr
import torch
import spaces

from dia2 import Dia2, GenerationConfig, SamplingConfig

DEFAULT_REPO = os.environ.get("DIA2_DEFAULT_REPO", "nari-labs/Dia2-2B")
MAX_TURNS = 10
INITIAL_TURNS = 2
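
# Module-level cache: the Dia2 model is loaded lazily on the first generation
# request and reused by subsequent calls.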
_dia: Dia2 | None = None


def _get_dia() -> Dia2:
    global _dia
    if _dia is None:
        _dia = Dia2.from_repo(DEFAULT_REPO, device="cuda", dtype="bfloat16")
    return _dia

def _concat_script(turn_count: int, turn_values: List[str]) -> str:
    lines: List[str] = []
    for idx in range(min(turn_count, len(turn_values))):
        text = (turn_values[idx] or "").strip()
        if not text:
            continue
        speaker = "[S1]" if idx % 2 == 0 else "[S2]"
        lines.append(f"{speaker} {text}")
    return "\n".join(lines)

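
# When the "Intro" example below is loaded, _concat_script renders its first
# two turns as follows (speakers alternate [S1]/[S2] by turn index):
#   [S1] Hello Dia2 fans! Today we're unveiling the new open TTS model.
#   [S2] Sounds exciting. Can you show a sample right now?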
EXAMPLES: dict[str, dict[str, List[str] | str | None]] = {
    "Intro": {
        "turns": [
            "Hello Dia2 fans! Today we're unveiling the new open TTS model.",
            "Sounds exciting. Can you show a sample right now?",
            "Absolutely. (laughs) Just press generate.",
        ],
        "voice_s1": "example_prefix1.wav",
        "voice_s2": "example_prefix2.wav",
    },
    "Customer Support": {
        "turns": [
            "Thanks for calling. How can I help you today?",
            "My parcel never arrived and it's been two weeks.",
            "I'm sorry about that. Let me check your tracking number.",
            "Appreciate it. I really need that package soon.",
        ],
        "voice_s1": "example_prefix1.wav",
        "voice_s2": "example_prefix2.wav",
    },
}
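
# Turn-visibility helpers: the UI pre-creates MAX_TURNS textboxes, and these
# callbacks return one gr.update(visible=...) per textbox so turns can be
# shown or hidden without rebuilding the layout.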
def _apply_turn_visibility(count: int) -> List[gr.Update]:
    return [gr.update(visible=i < count) for i in range(MAX_TURNS)]


def _add_turn(count: int):
    count = min(count + 1, MAX_TURNS)
    return (count, *_apply_turn_visibility(count))


def _remove_turn(count: int):
    count = max(1, count - 1)
    return (count, *_apply_turn_visibility(count))

def _load_example(name: str, count: int):
    data = EXAMPLES.get(name)
    if not data:
        return (count, *_apply_turn_visibility(count), None, None)
    turns = data.get("turns", [])
    voice_s1_path = data.get("voice_s1")
    voice_s2_path = data.get("voice_s2")
    new_count = min(len(turns), MAX_TURNS)
    updates: List[gr.Update] = []
    for idx in range(MAX_TURNS):
        if idx < new_count:
            updates.append(gr.update(value=turns[idx], visible=True))
        else:
            updates.append(gr.update(value="", visible=idx < INITIAL_TURNS))
    # The returned tuple mirrors the dropdown's outputs: turn count, one update
    # per textbox, then the two voice prompt file paths.
    return (new_count, *updates, voice_s1_path, voice_s2_path)

def _prepare_prefix(file_path: str | None) -> str | None:
    # Voice prompts are optional: a missing path, or a file that does not
    # exist (e.g. the example_prefix*.wav files referenced in EXAMPLES), is
    # silently skipped.
    if not file_path:
        return None
    path = Path(file_path)
    if not path.exists():
        return None
    return str(path)
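

# generate_audio runs as a ZeroGPU task: @spaces.GPU(duration=100) asks the
# Spaces runtime for up to roughly 100 seconds of GPU time per call (exact
# enforcement is up to the runtime). The flat positional `inputs` arrive in
# the order wired up in build_interface: MAX_TURNS textboxes, two voice prompt
# files, CFG scale, text/audio temperatures, text/audio top-k, and the
# include-prefix flag.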
@spaces.GPU(duration=100)
def generate_audio(
    turn_count: int,
    *inputs,
):
    # Unpack the flat positional inputs (see the wiring in build_interface).
    turn_values = list(inputs[:MAX_TURNS])
    voice_s1 = inputs[MAX_TURNS]
    voice_s2 = inputs[MAX_TURNS + 1]
    cfg_scale = float(inputs[MAX_TURNS + 2])
    text_temperature = float(inputs[MAX_TURNS + 3])
    audio_temperature = float(inputs[MAX_TURNS + 4])
    text_top_k = int(inputs[MAX_TURNS + 5])
    audio_top_k = int(inputs[MAX_TURNS + 6])
    include_prefix = bool(inputs[MAX_TURNS + 7])

    script = _concat_script(turn_count, turn_values)
    if not script.strip():
        raise gr.Error("Please enter at least one non-empty speaker turn.")

    dia = _get_dia()
    config = GenerationConfig(
        cfg_scale=cfg_scale,
        text=SamplingConfig(temperature=text_temperature, top_k=text_top_k),
        audio=SamplingConfig(temperature=audio_temperature, top_k=audio_top_k),
        use_cuda_graph=True,
    )
    kwargs = {
        "prefix_speaker_1": _prepare_prefix(voice_s1),
        "prefix_speaker_2": _prepare_prefix(voice_s2),
        "include_prefix": include_prefix,
    }
    # Capture the model's verbose progress output so it can be shown in the UI.
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        result = dia.generate(
            script,
            config=config,
            output_wav=None,
            verbose=True,
            **kwargs,
        )
    waveform = result.waveform.detach().cpu().numpy()
    sample_rate = result.sample_rate
    timestamps = result.timestamps
    log_text = buffer.getvalue().strip()
    table = [[w, round(t, 3)] for w, t in timestamps]
    # Return values feed the click handler's outputs: the audio player, the
    # word/seconds dataframe, and the log textbox.
    return (sample_rate, waveform), table, log_text or "Generation finished."

def build_interface() -> gr.Blocks:
    with gr.Blocks(
        title="Dia2 TTS", css=".compact-turn textarea {min-height: 60px}"
    ) as demo:
        gr.Markdown(
            """## Dia2 — Open TTS Model
Compose dialogue, attach optional voice prompts, and generate audio (CUDA graphs enabled by default)."""
        )
        turn_state = gr.State(INITIAL_TURNS)
        with gr.Row(equal_height=True):
            example_dropdown = gr.Dropdown(
                choices=["(select example)"] + list(EXAMPLES.keys()),
                label="Examples",
                value="(select example)",
            )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Script")
                    controls = []
                    for idx in range(MAX_TURNS):
                        speaker = "[S1]" if idx % 2 == 0 else "[S2]"
                        box = gr.Textbox(
                            label=f"{speaker} turn {idx + 1}",
                            lines=2,
                            elem_classes=["compact-turn"],
                            placeholder=f"Enter dialogue for {speaker}…",
                            visible=idx < INITIAL_TURNS,
                        )
                        controls.append(box)
                    with gr.Row():
                        add_btn = gr.Button("Add Turn")
                        remove_btn = gr.Button("Remove Turn")
                with gr.Group():
                    gr.Markdown("### Voice Prompts")
                    with gr.Row():
                        voice_s1 = gr.File(
                            label="[S1] voice (wav/mp3)", type="filepath"
                        )
                        voice_s2 = gr.File(
                            label="[S2] voice (wav/mp3)", type="filepath"
                        )
                with gr.Group():
                    gr.Markdown("### Sampling")
                    cfg_scale = gr.Slider(
                        1.0, 8.0, value=6.0, step=0.1, label="CFG Scale"
                    )
                    with gr.Group():
                        gr.Markdown("#### Text Sampling")
                        text_temperature = gr.Slider(
                            0.1, 1.5, value=0.6, step=0.05, label="Text Temperature"
                        )
                        text_top_k = gr.Slider(
                            1, 200, value=50, step=1, label="Text Top-K"
                        )
                    with gr.Group():
                        gr.Markdown("#### Audio Sampling")
                        audio_temperature = gr.Slider(
                            0.1, 1.5, value=0.8, step=0.05, label="Audio Temperature"
                        )
                        audio_top_k = gr.Slider(
                            1, 200, value=50, step=1, label="Audio Top-K"
                        )
                    include_prefix = gr.Checkbox(
                        label="Keep prefix audio in output", value=False
                    )
                generate_btn = gr.Button("Generate", variant="primary")
            with gr.Column(scale=1):
                gr.Markdown("### Output")
                audio_out = gr.Audio(label="Waveform", interactive=False)
                timestamps = gr.Dataframe(
                    headers=["word", "seconds"], label="Timestamps"
                )
                log_box = gr.Textbox(label="Logs", lines=8)
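
        # Event wiring: the Add/Remove buttons and the example dropdown update
        # the turn-count state and each textbox's visibility; Generate forwards
        # the state plus every control to generate_audio.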
        add_btn.click(
            lambda c: _add_turn(c),
            inputs=turn_state,
            outputs=[turn_state, *controls],
        )
        remove_btn.click(
            lambda c: _remove_turn(c),
            inputs=turn_state,
            outputs=[turn_state, *controls],
        )
        example_dropdown.change(
            lambda name, c: _load_example(name, c),
            inputs=[example_dropdown, turn_state],
            outputs=[turn_state, *controls, voice_s1, voice_s2],
        )

        generate_btn.click(
            generate_audio,
            inputs=[
                turn_state,
                *controls,
                voice_s1,
                voice_s2,
                cfg_scale,
                text_temperature,
                audio_temperature,
                text_top_k,
                audio_top_k,
                include_prefix,
            ],
            outputs=[audio_out, timestamps, log_box],
        )
    return demo
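
# app.queue(default_concurrency_limit=1) lets at most one generation event run
# at a time; share=True requests a public gradio.live link for local runs
# (hosted Spaces typically ignore the flag).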


if __name__ == "__main__":
    app = build_interface()
    app.queue(default_concurrency_limit=1)
    app.launch(share=True)