"""Gradio demo for zero-shot voice cloning with nari-labs' Dia-1.6B.

If no transcript is supplied for the reference audio, it is transcribed
automatically with Groq's whisper-large-v3-turbo.
"""

import logging
import os
import sys

import gradio as gr
import numpy as np
from groq import Groq

from dia.model import Dia

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    stream=sys.stdout,
)

DEFAULT_REF_PATH = "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
DEFAULT_REF_TEXT = "That place in the distance, it's huge and dedicated to Lady Shah. It can only mean one thing. I have a hidden place close to the cloister where night orchids bloom."
DEFAULT_GEN_TEXT = "Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible."
SAMPLES_PATH = os.path.join(os.getcwd(), "samples")

# Load Dia-1.6B once at startup; the weights are pulled from the Hugging Face
# Hub on first use.
model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626")
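
# Optional sketch (assumption, not verified against the installed dia build):
# the nari-labs reference implementation also accepts a `device` argument on
# `from_pretrained`, so pinning the model to a specific accelerator would look
# something like:
#
#     model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626", device="cuda")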


def transcribe(file_path: str) -> str:
    """Transcribe the reference audio with Groq's whisper-large-v3-turbo.

    Requires the GROQ_API_KEY environment variable to be set.
    """
    client = Groq()
    with open(file_path, "rb") as file:
        transcription = client.audio.transcriptions.create(
            file=(file_path, file.read()),
            model="whisper-large-v3-turbo",
            temperature=0,
            response_format="verbose_json",
        )

    if not transcription.text:
        logging.warning("Transcription of the reference audio came back empty.")
    else:
        logging.info(f"Transcribed: {transcription.text}")
    return transcription.text
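
# Hypothetical direct use (assumes GROQ_API_KEY is exported and that a local
# "samples/ref.wav" file exists; both names are illustrative):
#
#     print(transcribe(os.path.join(SAMPLES_PATH, "ref.wav")))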


def infer(
    gen_text: str,
    ref_text: str = DEFAULT_REF_TEXT,
    ref_audio_path: str = DEFAULT_REF_PATH,
) -> tuple[int, np.ndarray]:
    """Generate speech with Dia from a reference audio/transcript pair and new text.

    Args:
        gen_text (str): The new text to synthesize.
        ref_text (str): The transcript of the reference audio. If empty, it is
            produced automatically with transcribe().
        ref_audio_path (str): The file path (or URL) of the reference audio.

    Returns:
        tuple[int, np.ndarray]: The sample rate (44100 Hz) and the generated
            audio waveform as a numpy array.
    """
    if not gen_text:
        raise ValueError("Please insert the new text to synthesize.")
    # If the user supplied their own reference audio but left the default
    # transcript in place, drop the stale transcript so it gets re-transcribed.
    if ref_audio_path != DEFAULT_REF_PATH and ref_text == DEFAULT_REF_TEXT:
        ref_text = ""
    if not ref_text:
        ref_text = transcribe(ref_audio_path)

    logging.info(f"Using reference: {ref_audio_path}")
    gr.Info("Starting inference request!")
    gr.Info("Encoding reference...")

    # Dia conditions on the reference transcript followed by the text to
    # generate; a space keeps the two from running together.
    output = model.generate(
        ref_text + " " + gen_text,
        audio_prompt=ref_audio_path,
        use_torch_compile=False,
        verbose=True,
        cfg_scale=4.0,
        temperature=1.8,
        top_p=0.90,
        cfg_filter_top_k=50,
    )

    # Normalize the output to a single numpy waveform.
    if isinstance(output, list):
        output = np.concatenate(output, axis=-1)
    elif not isinstance(output, np.ndarray):
        output = np.array(output, dtype=np.float32)

    # Dia decodes audio at 44.1 kHz.
    return (44100, output)
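
# Hypothetical direct use outside the Gradio UI (assumes `pip install soundfile`;
# the text and output filename are illustrative):
#
#     import soundfile as sf
#     sr, wav = infer("Hello there, this is a quick cloning test.")
#     sf.write("output.wav", wav, sr)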


demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(label="Text to Generate", value=DEFAULT_GEN_TEXT),
        gr.Textbox(label="Reference Text (Optional)", value=DEFAULT_REF_TEXT),
        gr.Audio(type="filepath", label="Reference Audio", value=DEFAULT_REF_PATH),
    ],
    outputs=gr.Audio(type="numpy", label="Generated Speech"),
    title="Dia-1.6B☁️",
    description="Upload a reference audio sample, provide the reference text, and enter new text to synthesize.",
)


if __name__ == "__main__":
    # allowed_paths lets Gradio serve files from the local samples directory.
    demo.queue(max_size=10).launch(allowed_paths=[SAMPLES_PATH], mcp_server=False, inbrowser=True)
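
# To run (hypothetical): export GROQ_API_KEY=<your key>, install the
# dependencies (gradio, numpy, groq, dia), then execute this script with
# Python; inbrowser=True opens the UI in the browser automatically.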