feat: use token2wav in non-streaming tts

#7
by airlsyn - opened
Files changed (1) hide show
  1. modeling_minicpmo.py +31 -51
modeling_minicpmo.py CHANGED
@@ -252,33 +252,19 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
252
  assert os.path.exists(model_dir), f"Asset directory not found: {model_dir}"
253
  return model_dir
254
 
255
- def init_tts(self, streaming=False, model_dir=None, enable_float16=False, n_timesteps=10):
256
- if streaming:
257
- if self.config.tts_config.audio_tokenizer_type != "s3tokenizer_step_audio":
258
- logger.warning("audio tokenizer type is set to s3tokenizer_step_audio")
259
- self.tts.config.audio_tokenizer_type = "s3tokenizer_step_audio"
260
-
261
- try:
262
- from stepaudio2 import Token2wav
263
- except ImportError:
264
- raise ImportError("Please install Token2wav via: pip install minicpmo-utils[all]")
265
-
266
- model_dir = self._ensure_asset_dir("assets/token2wav", model_dir)
267
- self.tts.audio_tokenizer = Token2wav(model_dir, float16=enable_float16, n_timesteps=n_timesteps)
268
- return self.tts.audio_tokenizer
269
- else:
270
- if self.config.tts_config.audio_tokenizer_type != "s3tokenizer":
271
- logger.warning("audio tokenizer type is set to s3tokenizer")
272
- self.tts.config.audio_tokenizer_type = "s3tokenizer"
273
 
274
- try:
275
- from cosyvoice.cli.cosyvoice import CosyVoice2
276
- except ImportError:
277
- raise ImportError("Please install CosyVoice via: pip install minicpmo-utils[all]")
278
 
279
- model_dir = self._ensure_asset_dir("assets/CosyVoice2-0.5B", model_dir)
280
- self.tts.audio_tokenizer = CosyVoice2(model_dir=model_dir, load_jit=False, load_trt=False, fp16=False)
281
- return self.tts.audio_tokenizer
282
 
283
  def get_input_embeddings(self):
284
  return self.llm.get_input_embeddings()
@@ -1336,27 +1322,25 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
1336
  ),
1337
  )
1338
 
1339
- if self.tts.config.audio_tokenizer_type == "s3tokenizer":
1340
- generated_tokens = outputs.new_ids.squeeze(-1)
1341
- reference_audio = audio_prompt
1342
- prompt_speech_16k = None
1343
- if reference_audio is not None:
1344
- logger.debug("use reference audio in data to generate waveform")
1345
- prompt_speech_16k = torch.tensor(reference_audio).unsqueeze(0)
1346
-
1347
- for i, j in enumerate(
1348
- self.tts.audio_tokenizer.token2wav(
1349
- speech_token=generated_tokens,
1350
- speech_token_len=torch.tensor([generated_tokens.shape[1]], device=generated_tokens.device),
1351
- prompt_speech_16k=prompt_speech_16k,
1352
- stream=False,
1353
- )
1354
- ):
1355
- waveform_pred = j["tts_speech"]
1356
- waveform_sample_rate = self.tts.audio_tokenizer.sample_rate # 24000 here, not 16000 input.
1357
- return waveform_pred[0]
1358
- else:
1359
- raise NotImplementedError
1360
 
1361
  @torch.inference_mode()
1362
  def init_token2wav_cache(self, prompt_speech_16k):
@@ -2511,11 +2495,7 @@ class MiniCPMODuplex:
2511
  # Initialize TTS (same as __init__)
2512
  enable_float16 = get_param("enable_float16")
2513
  n_timesteps = get_param("n_timesteps")
2514
- instance.model.init_tts(
2515
- streaming=True,
2516
- enable_float16=enable_float16,
2517
- n_timesteps=n_timesteps,
2518
- )
2519
 
2520
  instance.break_event = threading.Event()
2521
  instance.session_stop_event = threading.Event()
 
252
  assert os.path.exists(model_dir), f"Asset directory not found: {model_dir}"
253
  return model_dir
254
 
255
def init_tts(self, model_dir=None, enable_float16=False, n_timesteps=10, **kwargs):
    """Initialize the Token2wav vocoder used for TTS waveform generation.

    Args:
        model_dir: Optional path to the token2wav assets; defaults to the
            bundled "assets/token2wav" resolved via ``self._ensure_asset_dir``.
        enable_float16: Load Token2wav weights in float16 when True.
        n_timesteps: Number of generation steps passed to Token2wav.
        **kwargs: Accepted only for backward compatibility with the previous
            signature (e.g. the removed ``streaming`` flag). Any values
            received here are ignored, with a warning.

    Returns:
        The Token2wav instance, which is also stored on
        ``self.tts.audio_tokenizer``.

    Raises:
        ImportError: If the optional ``stepaudio2`` package is not installed.
    """
    if kwargs:
        # Fix: a bare **kwargs silently swallowed obsolete arguments such as
        # streaming=True from older callers; make the ignore visible so those
        # call sites get updated instead of silently changing behavior.
        logger.warning("init_tts: ignoring unsupported arguments: %s", sorted(kwargs))

    # NOTE(review): reads self.config.tts_config but writes
    # self.tts.config — mirrors the original code; confirm the two config
    # objects are meant to stay in sync.
    if self.config.tts_config.audio_tokenizer_type != "s3tokenizer_step_audio":
        logger.warning("audio tokenizer type is set to s3tokenizer_step_audio")
        self.tts.config.audio_tokenizer_type = "s3tokenizer_step_audio"

    try:
        from stepaudio2 import Token2wav
    except ImportError as e:
        # Chain the underlying import failure so the root cause stays visible.
        raise ImportError("Please install Token2wav via: pip install minicpmo-utils[all]") from e

    model_dir = self._ensure_asset_dir("assets/token2wav", model_dir)
    self.tts.audio_tokenizer = Token2wav(model_dir, float16=enable_float16, n_timesteps=n_timesteps)
    return self.tts.audio_tokenizer
268
 
269
  def get_input_embeddings(self):
270
  return self.llm.get_input_embeddings()
 
1322
  ),
1323
  )
1324
 
1325
+ import io
1326
+
1327
+ import soundfile as sf
1328
+
1329
+ generated_tokens = outputs.new_ids.squeeze(-1)
1330
+ reference_audio = audio_prompt
1331
+ prompt_wav_path = None
1332
+ if reference_audio is not None:
1333
+ logger.debug("use reference audio in data to generate waveform")
1334
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
1335
+ prompt_wav_path = tmp_wav.name
1336
+ sf.write(prompt_wav_path, reference_audio, 16000)
1337
+ wav_bytes = self.tts.audio_tokenizer(
1338
+ generated_tokens.squeeze(0).tolist(),
1339
+ prompt_wav_path,
1340
+ )
1341
+ # convert wav bytes back to tensor for caller compatibility
1342
+ waveform, sr = sf.read(io.BytesIO(wav_bytes))
1343
+ return torch.tensor(waveform, dtype=torch.float32)
 
 
1344
 
1345
  @torch.inference_mode()
1346
  def init_token2wav_cache(self, prompt_speech_16k):
 
2495
  # Initialize TTS (same as __init__)
2496
  enable_float16 = get_param("enable_float16")
2497
  n_timesteps = get_param("n_timesteps")
2498
+ instance.model.init_tts(enable_float16=enable_float16, n_timesteps=n_timesteps)
 
 
 
 
2499
 
2500
  instance.break_event = threading.Event()
2501
  instance.session_stop_event = threading.Event()