# Dia-1.6B / app.py
import logging
import os
import sys
import gradio as gr
import numpy as np
from groq import Groq
from dia.model import Dia
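# NOTE: the Groq client reads its API key from the GROQ_API_KEY environment
# variable; it must be set for reference-audio transcription to work.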
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    stream=sys.stdout,
)
DEFAULT_REF_PATH = "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
DEFAULT_GEN_TEXT = "Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible."
SAMPLES_PATH = os.path.join(os.getcwd(), "samples")
DEFAULT_REF_TEXT = "That place in the distance, it's huge and dedicated to Lady Shah. It can only mean one thing. I have a hidden place close to the cloister where night orchids bloom."
model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626")
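# Dia-1.6B is loaded once at startup; generation works on CPU but is
# substantially faster on a GPU.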

def transcribe(file_path: str) -> str:
    """Transcribe the reference audio with Groq's hosted Whisper model."""
    client = Groq()
    with open(file_path, "rb") as file:
        transcription = client.audio.transcriptions.create(
            file=(file_path, file.read()),
            model="whisper-large-v3-turbo",
            temperature=0,
            response_format="verbose_json",
        )
    if not transcription.text:
        logging.warning("Transcription of the reference audio returned no text.")
    else:
        logging.info(f"Transcribed: {transcription.text}")
    return transcription.text
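# Example usage (hypothetical file path; requires GROQ_API_KEY to be set):
#   ref_text = transcribe("samples/reference.flac")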

def infer(
    gen_text: str,
    ref_text: str = DEFAULT_REF_TEXT,
    ref_audio_path: str = DEFAULT_REF_PATH,
) -> tuple[int, np.ndarray]:
    """
    Generates speech with Dia-1.6B given a reference audio clip, its transcript, and new text to synthesize.

    Args:
        gen_text (str): The new text to synthesize.
        ref_text (str): The transcript of the reference audio.
        ref_audio_path (str): The file path to the reference audio.

    Returns:
        tuple[int, np.ndarray]: The sample rate (44100 Hz) and the generated audio waveform as a numpy array.
    """
    if not gen_text:
        raise ValueError("Please enter the new text to synthesize.")
    # If the user uploaded their own reference audio but left the default
    # reference text untouched, discard the default and re-transcribe.
    if ref_audio_path != DEFAULT_REF_PATH and ref_text == DEFAULT_REF_TEXT:
        ref_text = ""
    if not ref_text:
        ref_text = transcribe(ref_audio_path)
    logging.info(f"Using reference: {ref_audio_path}")
    gr.Info("Starting inference request!")
    gr.Info("Encoding reference...")
    # model.generate may return a single ndarray or a list of ndarray chunks.
    output = model.generate(
        ref_text + gen_text,
        audio_prompt=ref_audio_path,
        use_torch_compile=False,
        verbose=True,
        cfg_scale=4.0,
        temperature=1.8,
        top_p=0.90,
        cfg_filter_top_k=50,
    )
    if isinstance(output, list):
        output = np.concatenate(output, axis=-1)  # Join the audio chunks
    elif not isinstance(output, np.ndarray):
        output = np.array(output, dtype=np.float32)
    return (44100, output)  # Dia generates audio at 44.1 kHz
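# Sketch: saving a generation to disk outside of Gradio (assumes the optional
# soundfile package is installed; "generated.wav" is an arbitrary output name):
#   import soundfile as sf
#   sr, wav = infer("Hello from Dia.")
#   sf.write("generated.wav", wav, sr)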

demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(label="Text to Generate", value=DEFAULT_GEN_TEXT),
        gr.Textbox(label="Reference Text (Optional)", value=DEFAULT_REF_TEXT),
        gr.Audio(type="filepath", label="Reference Audio", value=DEFAULT_REF_PATH),
    ],
    outputs=gr.Audio(type="numpy", label="Generated Speech"),
    title="Dia-1.6B",
    description="Upload a reference audio sample, provide the reference text, and enter new text to synthesize.",
)

if __name__ == "__main__":
    demo.queue(max_size=10).launch(allowed_paths=[SAMPLES_PATH], mcp_server=False, inbrowser=True)