playmak3r committed
Commit c476e57 · 1 Parent(s): 7a62e14

Fix output return type; install missing libs; update DEFAULT_REF_TEXT; improve error handling in infer function

Files changed (2):
  1. app.py +13 -7
  2. requirements.txt +2 -1
app.py CHANGED
@@ -15,7 +15,7 @@ logging.basicConfig(
 DEFAULT_REF_PATH = "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
 DEFAULT_GEN_TEXT = "Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible."
 SAMPLES_PATH = os.path.join(os.getcwd(), "samples")
-DEFAULT_REF_TEXT = ""
+DEFAULT_REF_TEXT = "That place in the distance, it's huge and dedicated to Lady Shah. It can only mean one thing. I have a hidden place close to the cloister where night orchids bloom."
 
 model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626")
 
@@ -35,7 +35,7 @@ def transcribe(file_path: str):
 
 def infer(
     gen_text: str,
-    ref_text: str = "",
+    ref_text: str = DEFAULT_REF_TEXT,
     ref_audio_path: str = DEFAULT_REF_PATH,
 ) -> tuple[int, np.ndarray]:
     """
@@ -45,12 +45,12 @@ def infer(
         ref_text (str): The text corresponding to the reference audio.
         ref_audio_path (str): The file path to the reference audio.
     Returns:
-        tuple[int, np.ndarray]: A tuple containing the sample rate (24000) and the generated audio waveform as a numpy array.
+        tuple[int, np.ndarray]: A tuple containing the sample rate (44100) and the generated audio waveform as a numpy array.
     """
 
     if gen_text is None or not len(gen_text):
-        raise Exception("Please insert the new text to synthesize.")
-    #if ref_audio_path != DEFAULT_REF_PATH and ref_text == DEFAULT_REF_TEXT: ref_text = ""
+        raise ValueError("Please insert the new text to synthesize.")
+    if ref_audio_path != DEFAULT_REF_PATH and ref_text == DEFAULT_REF_TEXT: ref_text = ""
     if not len(ref_text):
         ref_text = transcribe(ref_audio_path)
 
@@ -58,6 +58,7 @@ def infer(
     gr.Info("Starting inference request!")
     gr.Info("Encoding reference...")
 
+    # ndarray[Unknown, Unknown] | list[ndarray[Unknown, Unknown]]
     output = model.generate(
         ref_text + gen_text,
         audio_prompt=ref_audio_path,
@@ -69,6 +70,11 @@ def infer(
         cfg_filter_top_k=50,
     )
 
+    if isinstance(output, list):
+        output = np.concatenate(output, axis=-1)  # join the audio chunks
+    elif not isinstance(output, np.ndarray):
+        output = np.array(output, dtype=np.float32)
+
     return (44100, output)
 
 
@@ -76,7 +82,7 @@ demo = gr.Interface(
     fn=infer,
     inputs=[
         gr.Textbox(label="Text to Generate", value=DEFAULT_GEN_TEXT),
-        gr.Textbox(label="Reference Text (Optional)"),
+        gr.Textbox(label="Reference Text (Optional)", value=DEFAULT_REF_TEXT),
         gr.Audio(type="filepath", label="Reference Audio", value=DEFAULT_REF_PATH),
     ],
     outputs=gr.Audio(type="numpy", label="Generated Speech"),
@@ -85,4 +91,4 @@ demo = gr.Interface(
 )
 
 if __name__ == "__main__":
-    demo.queue(max_size=10).launch(allowed_paths=[SAMPLES_PATH], mcp_server=True, inbrowser=True)
+    demo.queue(max_size=10).launch(allowed_paths=[SAMPLES_PATH], mcp_server=False, inbrowser=True)
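
Note on the return-type fix: per the comment added above, model.generate() may return either a single ndarray or a list of ndarray chunks, while gr.Audio(type="numpy") expects one (sample_rate, ndarray) tuple. Below is a minimal, self-contained sketch of that normalization; normalize_output is a hypothetical helper for illustration, not part of app.py.

import numpy as np

def normalize_output(output) -> np.ndarray:
    # A list of chunks is joined along the last (time) axis.
    if isinstance(output, list):
        return np.concatenate(output, axis=-1)
    # Anything else not already an ndarray is coerced to float32 samples.
    if not isinstance(output, np.ndarray):
        return np.array(output, dtype=np.float32)
    return output

# Hypothetical chunked output: two half-second mono chunks at 44.1 kHz.
chunks = [np.zeros(22050, dtype=np.float32), np.ones(22050, dtype=np.float32)]
print(normalize_output(chunks).shape)  # (44100,) -- one contiguous waveform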
requirements.txt CHANGED
@@ -7,4 +7,5 @@ soundfile>=0.13.1
 torchaudio>=2.0.0
 torch>=2.0.0
 gradio-dialogue>=0.0.4
-groq
+groq
+torchcodec
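
Note on torchcodec: presumably the "missing lib" from the commit message; recent torchaudio releases have been routing audio decoding through torchcodec, so loading the reference audio can fail without it. A quick sanity check under that assumption (the file path is hypothetical):

import torchaudio

# Hypothetical local file; any decodable audio file would do.
waveform, sample_rate = torchaudio.load("samples/reference.flac")
print(waveform.shape, sample_rate)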