feat: use token2wav in non-streaming tts
#7
by
airlsyn - opened
- modeling_minicpmo.py +31 -51
modeling_minicpmo.py
CHANGED
|
@@ -252,33 +252,19 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
| 252 |
assert os.path.exists(model_dir), f"Asset directory not found: {model_dir}"
|
| 253 |
return model_dir
|
| 254 |
|
| 255 |
-
def init_tts(self,
|
| 256 |
-
if
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
self.tts.config.audio_tokenizer_type = "s3tokenizer_step_audio"
|
| 260 |
-
|
| 261 |
-
try:
|
| 262 |
-
from stepaudio2 import Token2wav
|
| 263 |
-
except ImportError:
|
| 264 |
-
raise ImportError("Please install Token2wav via: pip install minicpmo-utils[all]")
|
| 265 |
-
|
| 266 |
-
model_dir = self._ensure_asset_dir("assets/token2wav", model_dir)
|
| 267 |
-
self.tts.audio_tokenizer = Token2wav(model_dir, float16=enable_float16, n_timesteps=n_timesteps)
|
| 268 |
-
return self.tts.audio_tokenizer
|
| 269 |
-
else:
|
| 270 |
-
if self.config.tts_config.audio_tokenizer_type != "s3tokenizer":
|
| 271 |
-
logger.warning("audio tokenizer type is set to s3tokenizer")
|
| 272 |
-
self.tts.config.audio_tokenizer_type = "s3tokenizer"
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
|
| 283 |
def get_input_embeddings(self):
|
| 284 |
return self.llm.get_input_embeddings()
|
|
@@ -1336,27 +1322,25 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
| 1336 |
),
|
| 1337 |
)
|
| 1338 |
|
| 1339 |
-
|
| 1340 |
-
|
| 1341 |
-
|
| 1342 |
-
|
| 1343 |
-
|
| 1344 |
-
|
| 1345 |
-
|
| 1346 |
-
|
| 1347 |
-
|
| 1348 |
-
|
| 1349 |
-
|
| 1350 |
-
|
| 1351 |
-
|
| 1352 |
-
|
| 1353 |
-
|
| 1354 |
-
|
| 1355 |
-
|
| 1356 |
-
|
| 1357 |
-
|
| 1358 |
-
else:
|
| 1359 |
-
raise NotImplementedError
|
| 1360 |
|
| 1361 |
@torch.inference_mode()
|
| 1362 |
def init_token2wav_cache(self, prompt_speech_16k):
|
|
@@ -2511,11 +2495,7 @@ class MiniCPMODuplex:
|
|
| 2511 |
# Initialize TTS (same as __init__)
|
| 2512 |
enable_float16 = get_param("enable_float16")
|
| 2513 |
n_timesteps = get_param("n_timesteps")
|
| 2514 |
-
instance.model.init_tts(
|
| 2515 |
-
streaming=True,
|
| 2516 |
-
enable_float16=enable_float16,
|
| 2517 |
-
n_timesteps=n_timesteps,
|
| 2518 |
-
)
|
| 2519 |
|
| 2520 |
instance.break_event = threading.Event()
|
| 2521 |
instance.session_stop_event = threading.Event()
|
|
|
|
| 252 |
assert os.path.exists(model_dir), f"Asset directory not found: {model_dir}"
|
| 253 |
return model_dir
|
| 254 |
|
| 255 |
+
def init_tts(self, model_dir=None, enable_float16=False, n_timesteps=10, **kwargs):
    """Initialize the Token2wav audio tokenizer used for TTS waveform generation.

    Args:
        model_dir: Directory holding the token2wav assets; when ``None`` it is
            resolved via ``self._ensure_asset_dir("assets/token2wav", ...)``.
        enable_float16: Run Token2wav in float16 when True.
        n_timesteps: Number of Token2wav synthesis timesteps (default 10).
        **kwargs: Accepted and ignored for backward compatibility with older
            callers (e.g. a legacy ``streaming=`` keyword).

    Returns:
        The constructed ``Token2wav`` instance, also stored on
        ``self.tts.audio_tokenizer``.

    Raises:
        ImportError: If the optional ``stepaudio2`` dependency is not installed.
    """
    # Force the expected tokenizer type; warn when the config disagrees.
    # NOTE(review): the check reads self.config.tts_config but the write goes to
    # self.tts.config — presumably both reference the same config; confirm.
    if self.config.tts_config.audio_tokenizer_type != "s3tokenizer_step_audio":
        logger.warning("audio tokenizer type is set to s3tokenizer_step_audio")
        self.tts.config.audio_tokenizer_type = "s3tokenizer_step_audio"

    try:
        from stepaudio2 import Token2wav
    except ImportError as err:
        # Chain the original error (B904) so the underlying import failure
        # remains visible in the traceback instead of being swallowed.
        raise ImportError("Please install Token2wav via: pip install minicpmo-utils[all]") from err

    model_dir = self._ensure_asset_dir("assets/token2wav", model_dir)
    self.tts.audio_tokenizer = Token2wav(model_dir, float16=enable_float16, n_timesteps=n_timesteps)
    return self.tts.audio_tokenizer
|
| 268 |
|
| 269 |
def get_input_embeddings(self):
|
| 270 |
return self.llm.get_input_embeddings()
|
|
|
|
| 1322 |
),
|
| 1323 |
)
|
| 1324 |
|
| 1325 |
+
import io
|
| 1326 |
+
|
| 1327 |
+
import soundfile as sf
|
| 1328 |
+
|
| 1329 |
+
generated_tokens = outputs.new_ids.squeeze(-1)
|
| 1330 |
+
reference_audio = audio_prompt
|
| 1331 |
+
prompt_wav_path = None
|
| 1332 |
+
if reference_audio is not None:
|
| 1333 |
+
logger.debug("use reference audio in data to generate waveform")
|
| 1334 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
|
| 1335 |
+
prompt_wav_path = tmp_wav.name
|
| 1336 |
+
sf.write(prompt_wav_path, reference_audio, 16000)
|
| 1337 |
+
wav_bytes = self.tts.audio_tokenizer(
|
| 1338 |
+
generated_tokens.squeeze(0).tolist(),
|
| 1339 |
+
prompt_wav_path,
|
| 1340 |
+
)
|
| 1341 |
+
# convert wav bytes back to tensor for caller compatibility
|
| 1342 |
+
waveform, sr = sf.read(io.BytesIO(wav_bytes))
|
| 1343 |
+
return torch.tensor(waveform, dtype=torch.float32)
|
|
|
|
|
|
|
| 1344 |
|
| 1345 |
@torch.inference_mode()
|
| 1346 |
def init_token2wav_cache(self, prompt_speech_16k):
|
|
|
|
| 2495 |
# Initialize TTS (same as __init__)
|
| 2496 |
enable_float16 = get_param("enable_float16")
|
| 2497 |
n_timesteps = get_param("n_timesteps")
|
| 2498 |
+
instance.model.init_tts(enable_float16=enable_float16, n_timesteps=n_timesteps)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2499 |
|
| 2500 |
instance.break_event = threading.Event()
|
| 2501 |
instance.session_stop_event = threading.Event()
|