| | import asyncio |
| | import functools |
| | import logging |
| | import random |
| | import time |
| | import uuid |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Optional, List, Tuple, Union, AsyncGenerator, Dict, Any |
| | from concurrent.futures import ThreadPoolExecutor |
| |
|
| | import librosa |
| | import torch |
| | import numpy as np |
| | import torchaudio |
| | import sounddevice as sd |
| | import io |
| | from torch import nn |
| | from IPython.display import Audio, display |
| |
|
| | from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, TokensPrompt, RequestOutput |
| | from vllm.multimodal import MultiModalDataDict |
| | from vllm.utils import Counter |
| |
|
| | from TTS.TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder |
| |
|
| | from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder |
| | from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler |
| |
|
| | from .xtts2_config import XTTSConfig, XTTSGPTConfig |
| | from .tokenizer import XTTSTokenizerFast |
| |
|
| | from ..xtts2_gpt.xtts2_gpt_modeling import LearnedPositionEmbeddings |
| |
|
| |
|
def wav_to_mel_cloning(
    wav,
    mel_norms_file="../experiments/clips_mel_norms.pth",
    mel_norms=None,
    device=torch.device("cpu"),
    n_fft=4096,
    hop_length=1024,
    win_length=4096,
    power=2,
    normalized=False,
    sample_rate=22050,
    f_min=0,
    f_max=8000,
    n_mels=80,
):
    """Turn a waveform into a normalized log-mel spectrogram for voice cloning.

    The mel filterbank uses slaney normalization; the result is divided by
    per-channel statistics loaded from ``mel_norms_file`` unless ``mel_norms``
    is supplied directly.
    """
    mel_transform = torchaudio.transforms.MelSpectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        power=power,
        normalized=normalized,
        sample_rate=sample_rate,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        norm="slaney",
    ).to(device)
    # Log-compress, flooring at 1e-5 so log(0) can never occur.
    spectrogram = mel_transform(wav.to(device))
    log_mel = torch.log(torch.clamp(spectrogram, min=1e-5))
    # Lazily load normalization statistics when the caller did not pass them.
    if mel_norms is None:
        mel_norms = torch.load(mel_norms_file, map_location=device)
    return log_mel / mel_norms.unsqueeze(0).unsqueeze(-1)
| |
|
| |
|
def load_audio(audiopath, sampling_rate):
    """Load an audio file as a mono (1, n_samples) tensor at ``sampling_rate``.

    Multi-channel input is averaged down to mono, the signal is resampled if
    the file's native rate differs, and samples are clipped into [-1, 1].
    """
    audio, file_sr = torchaudio.load(audiopath)

    # Downmix anything that is not already mono.
    if audio.size(0) != 1:
        audio = audio.mean(dim=0, keepdim=True)

    # Bring the clip to the requested rate when needed.
    if file_sr != sampling_rate:
        audio = torchaudio.functional.resample(audio, file_sr, sampling_rate)

    # Clamp decoder overshoot in place.
    audio.clip_(-1, 1)
    return audio
| |
|
| |
|
@dataclass
class XTTSRequest:
    """Container for XTTS inference request data.

    `text` is either a plain string or an async generator of text fragments;
    the generator form is required by the streaming entry point.
    """
    request_id: str  # caller-chosen id; per-sequence suffixes are appended downstream
    text: Union[AsyncGenerator[str, None], str]
    language: str  # language code forwarded to the tokenizer
    speaker_file: str  # path to the reference speaker audio file
    generate_every_n_chars: Optional[int] = None  # chars buffered before a mid-stream synthesis pass (streaming mode)
    temperature: float = 0.75  # GPT sampling temperature
    top_p: float = 0.85
    top_k: int = 50
    repetition_penalty: float = 5.0  # applied via LogitsRepetitionPenalizer, not vLLM's built-in penalty
    length_penalty: float = 1.0
    do_sample: bool = True
    max_ref_length: int = 60  # max seconds of reference audio used
    gpt_cond_len: int = 30  # seconds of audio used for GPT conditioning
    gpt_cond_chunk_len: int = 4  # chunk length (s) for perceiver-based conditioning
| |
|
| |
|
| | import threading |
| |
|
class HiddenStatesCollector:
    """Thread-safe accumulator for hidden-state tensors keyed by request id.

    The model runner invokes the collector (or a callback from
    ``bind_to_request``) once per forward pass; ``get_hidden_states`` then
    returns the concatenated states exactly once, removing them from the store.
    """

    def __init__(self):
        # request_id -> tensors (possibly None placeholders) in arrival order
        self.outputs: Dict[str, List[Optional[torch.Tensor]]] = {}
        self.lock = threading.Lock()

    def __call__(self, outputs: Optional[torch.Tensor], request_id: str):
        """Save outputs for a specific request."""
        with self.lock:
            self.outputs.setdefault(request_id, []).append(outputs)

    def get_hidden_states(self, request_id) -> Optional[torch.Tensor]:
        """Pop and concatenate everything collected for ``request_id``.

        Returns:
            The concatenated tensor, or None when nothing (or only None
            entries) was collected for this id.
        """
        with self.lock:
            outputs = self.outputs.pop(request_id, None)
        if outputs is None:
            return None
        # Fix: the signature admits None entries, but torch.cat raised on
        # them — drop None placeholders before concatenating.
        tensors = [t for t in outputs if t is not None]
        if not tensors:
            return None
        return torch.cat(tensors, dim=0)

    def bind_to_request(self, request_id: str):
        """Return a collector callback pre-bound to ``request_id``."""
        def bound_collector(outputs: Optional[torch.Tensor], _request_id: str = None):
            self(outputs, request_id)
        return bound_collector
| |
|
class ExtendedSamplingParams(SamplingParams, kw_only=True):
    """Extended sampling parameters that allows additional fields while maintaining compatibility with SamplingParams.

    This class inherits from SamplingParams and allows adding new required fields
    without conflicting with the base class's optional fields ordering.
    """
    # NOTE(review): `kw_only=True` is forwarded to the base's
    # __init_subclass__ — this relies on vLLM's SamplingParams being a
    # struct-style class that accepts it; confirm against the pinned vLLM
    # version before upgrading.
    # Callback receiving hidden states for this request; it is supplied per
    # request in get_model_logits via bind_to_request.
    hidden_state_collector: HiddenStatesCollector
| |
|
| |
|
class LogitsRepetitionPenalizer:
    """Logits processor that damps scores of already-generated tokens.

    Positive logits of seen tokens are divided by the penalty and negative
    ones multiplied by it, mirroring the classic CTRL-style repetition
    penalty. A penalty of 1.0 is a no-op.
    """

    def __init__(self, repetition_penalty: float):
        if repetition_penalty < 0:
            raise ValueError("Repetition penalty must be non-negative")
        self.repetition_penalty = repetition_penalty

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        """Apply repetition penalty to the logits based on previous tokens."""
        penalty = self.repetition_penalty
        # Neutral penalty or empty history: return the logits untouched.
        if penalty == 1.0 or len(token_ids) == 0:
            return logits

        seen = torch.tensor(token_ids, dtype=torch.long, device=logits.device)
        seen_scores = logits[seen]
        # Shrink positive scores, push negative ones further down (in place).
        logits[seen] = torch.where(
            seen_scores > 0,
            seen_scores / penalty,
            seen_scores * penalty,
        )
        return logits
| |
|
| |
|
@dataclass
class XTTSOutput:
    """Container for XTTS inference output with integrated audio utilities.

    Attributes:
        request_id: Identifier of the originating request.
        wav: Waveform samples as floats in [-1, 1] (mono expected).
        sample_rate: Sample rate of `wav`; the XTTS vocoder emits 24 kHz.
    """
    request_id: str
    wav: np.ndarray
    sample_rate: int = 24000

    def to_tensor(self) -> torch.Tensor:
        """Convert numpy array to torch tensor (shares memory, no copy)."""
        if isinstance(self.wav, np.ndarray):
            return torch.from_numpy(self.wav)
        return self.wav

    def to_bytes(self, format: str = 'wav', sample_width: int = 2) -> bytes:
        """Convert audio to bytes format.

        Args:
            format: Output format ('wav' or 'raw')
            sample_width: Bit depth (1, 2, or 4 bytes per sample)

        Returns:
            Audio data as bytes

        Raises:
            ValueError: if `format` is neither 'wav' nor 'raw'.
        """
        wav_tensor = self.to_tensor()

        # torchaudio.save expects a (channels, samples) layout.
        if wav_tensor.dim() == 1:
            wav_tensor = wav_tensor.unsqueeze(0)

        # Clamp out-of-range float samples before quantization/encoding.
        wav_tensor = torch.clamp(wav_tensor, -1.0, 1.0)

        if format == 'wav':
            buffer = io.BytesIO()
            # NOTE(review): for sample_width != 2 this requests PCM_F with
            # sample_width*8 bits; torchaudio only supports float encodings at
            # 32/64 bits, so sample_width=1 likely fails here — confirm.
            torchaudio.save(
                buffer,
                wav_tensor,
                self.sample_rate,
                format="wav",
                encoding="PCM_S" if sample_width == 2 else "PCM_F",
                bits_per_sample=sample_width * 8
            )
            return buffer.getvalue()

        elif format == 'raw':
            # Scale floats into the integer range of the requested width.
            if sample_width == 2:
                wav_tensor = (wav_tensor * 32767).to(torch.int16)
            elif sample_width == 4:
                wav_tensor = (wav_tensor * 2147483647).to(torch.int32)
            else:
                wav_tensor = (wav_tensor * 127).to(torch.int8)
            return wav_tensor.cpu().numpy().tobytes()

        else:
            raise ValueError(f"Unsupported format: {format}")

    def save(self,
             filename: Union[str, Path],
             sample_rate: Optional[int] = None,
             format: Optional[str] = None) -> None:
        """Save audio to file.

        Args:
            filename: Output filename
            sample_rate: Optional new sample rate for resampling
            format: Optional format override (default: inferred from extension)
        """
        wav_tensor = self.to_tensor()
        if wav_tensor.dim() == 1:
            wav_tensor = wav_tensor.unsqueeze(0)

        # Resample on the fly when a different output rate is requested.
        if sample_rate and sample_rate != self.sample_rate:
            wav_tensor = torchaudio.functional.resample(
                wav_tensor,
                orig_freq=self.sample_rate,
                new_freq=sample_rate
            )
        else:
            sample_rate = self.sample_rate

        torchaudio.save(
            filename,
            wav_tensor,
            sample_rate,
            format=format
        )

    def resample(self, new_sample_rate: int) -> 'XTTSOutput':
        """Create new XTTSOutput with resampled audio.

        Args:
            new_sample_rate: Target sample rate

        Returns:
            New XTTSOutput instance with resampled audio
        """
        wav_tensor = self.to_tensor()
        if wav_tensor.dim() == 1:
            wav_tensor = wav_tensor.unsqueeze(0)

        resampled = torchaudio.functional.resample(
            wav_tensor,
            orig_freq=self.sample_rate,
            new_freq=new_sample_rate
        )

        return XTTSOutput(
            request_id=self.request_id,
            wav=resampled.squeeze().numpy(),
            sample_rate=new_sample_rate
        )

    def get_info(self) -> Tuple[int, int, float]:
        """Get audio information.

        Returns:
            Tuple of (number of samples, sample rate, duration in seconds)

        Note:
            Assumes `wav` is 1-D mono; for a 2-D array len() would count
            channels, not samples.
        """
        n_samples = len(self.wav)
        duration = n_samples / self.sample_rate
        return n_samples, self.sample_rate, duration

    @classmethod
    def from_tensor(cls, request_id: str, tensor: torch.Tensor, sample_rate: int = 24000) -> 'XTTSOutput':
        """Create XTTSOutput from torch tensor.

        Args:
            request_id: Request identifier
            tensor: Audio tensor
            sample_rate: Sample rate of the audio

        Returns:
            New XTTSOutput instance
        """
        return cls(
            request_id=request_id,
            wav=tensor.squeeze().cpu().numpy(),
            sample_rate=sample_rate
        )

    @classmethod
    def from_file(cls, request_id: str, filename: Union[str, Path]) -> 'XTTSOutput':
        """Create XTTSOutput from audio file.

        Args:
            request_id: Request identifier
            filename: Path to audio file

        Returns:
            New XTTSOutput instance
        """
        wav_tensor, sample_rate = torchaudio.load(filename)
        return cls.from_tensor(request_id, wav_tensor, sample_rate)

    def play(self) -> None:
        """Play the audio through the default sound device.
        For use in regular Python scripts/applications."""
        # Accept tensors defensively even though `wav` is typed as ndarray.
        if isinstance(self.wav, torch.Tensor):
            audio_data = self.wav.cpu().numpy()
        else:
            audio_data = self.wav

        # sounddevice expects float32 samples in [-1, 1].
        if audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        audio_data = np.clip(audio_data, -1.0, 1.0)

        # Blocking playback: sd.wait() returns once the clip has finished.
        sd.play(audio_data, self.sample_rate)
        sd.wait()

    def display(self) -> Optional[Audio]:
        """Display audio player in Jupyter notebook.
        Returns Audio widget if in notebook, None otherwise."""
        try:
            # Encode to WAV bytes so the widget is self-contained.
            audio_bytes = self.to_bytes(format='wav')

            audio_widget = Audio(audio_bytes, rate=self.sample_rate, autoplay=False)
            display(audio_widget)
            return audio_widget
        except Exception as e:
            print(f"Could not display audio widget: {str(e)}")
            print("Try using .play() method instead")
            return None

    def preview(self) -> None:
        """Smart play method that chooses appropriate playback method."""
        try:
            # Prefer the notebook widget; fall back to local playback.
            if self.display() is None:
                self.play()
        except Exception as e:
            print(f"Error playing audio: {str(e)}")
|
| |
|
| | class Xtts(nn.Module): |
| | """Async XTTS model implementation using VLLM's AsyncEngine.""" |
| |
|
    def __init__(self, hifi_config: XTTSConfig, gpt_config: XTTSGPTConfig, tensor_parallel_size: int = 1, **kwargs):
        """Assemble the XTTS pipeline: tokenizer, conditioning encoders,
        HiFi-GAN vocoder and the vLLM async engine for GPT audio tokens.

        Args:
            hifi_config: Vocoder (HiFi-GAN decoder) configuration.
            gpt_config: GPT audio-token language-model configuration.
            tensor_parallel_size: Tensor-parallel degree for the vLLM engine.
            **kwargs: Accepted for constructor compatibility; unused here.
        """
        super().__init__()

        self.hifi_config = hifi_config
        self.gpt_config = gpt_config
        # Token ids framing the generated audio-token sequences.
        self.mel_bos_token_id = gpt_config.start_audio_token
        self.mel_eos_token_id = gpt_config.stop_audio_token
        self.tp = tensor_parallel_size
        self.tokenizer = XTTSTokenizerFast.from_pretrained("AstraMindAI/xtts2-gpt")
        self.request_counter = Counter()
        # Pool for blocking work (e.g. vocoder inference) kept off the event loop.
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.hidden_states_collector = HiddenStatesCollector()

        # Placeholder mel-normalization stats; real values arrive from the
        # checkpoint via load_state_dict (see from_pretrained).
        self.register_buffer("mel_stats", torch.ones(80))

        # Encodes reference mel spectrograms into conditioning vectors.
        self.conditioning_encoder = ConditioningEncoder(
            gpt_config.audio_config.mel_channels,
            gpt_config.hidden_size,
            num_attn_heads=gpt_config.num_attention_heads
        )

        self.text_embedding = nn.Embedding(
            gpt_config.number_text_tokens,
            gpt_config.hidden_size
        )

        # NOTE(review): the guard checks max_audio_tokens while the table is
        # sized by max_text_tokens — looks inherited from upstream XTTS code;
        # confirm this asymmetry is intentional.
        self.text_pos_embedding = (
            LearnedPositionEmbeddings(
                gpt_config.max_text_tokens + 2,
                gpt_config.hidden_size,
                supports_pp=False
            )
            if gpt_config.max_audio_tokens != -1
            else functools.partial(gpt_config.null_position_embeddings, dim=gpt_config.hidden_size)
        )

        # Optional perceiver that resamples conditioning to 32 latents.
        if gpt_config.use_perceiver_resampler:
            self.conditioning_perceiver = PerceiverResampler(
                dim=gpt_config.hidden_size,
                depth=2,
                dim_context=gpt_config.hidden_size,
                num_latents=32,
                dim_head=64,
                heads=8,
                ff_mult=4,
                use_flash_attn=False,
            )

        # Neural vocoder: GPT hidden states + speaker embedding -> waveform.
        self.hifigan_decoder = HifiDecoder(
            input_sample_rate=self.hifi_config.input_sample_rate,
            output_sample_rate=self.hifi_config.output_sample_rate,
            output_hop_length=self.hifi_config.output_hop_length,
            ar_mel_length_compression=self.hifi_config.gpt_code_stride_len,
            decoder_input_dim=self.hifi_config.decoder_input_dim,
            d_vector_dim=self.hifi_config.d_vector_dim,
            cond_d_vector_in_each_upsampling_layer=self.hifi_config.cond_d_vector_in_each_upsampling_layer,
        )

        # NOTE(review): text_head / final_norm are not referenced elsewhere in
        # this file — presumably kept so load_state_dict matches the
        # checkpoint; confirm.
        self.text_head = nn.Linear(gpt_config.hidden_size, gpt_config.number_text_tokens, bias=True)
        self.final_norm = nn.LayerNorm(gpt_config.hidden_size, eps=1e-5, bias=True)

        # Spin up the vLLM engine serving the GPT audio-token model.
        self.init_vllm_engine()

        # Bound concurrent conditioning-latent computations.
        self.max_concurrency = 10
        self.semaphore = asyncio.BoundedSemaphore(self.max_concurrency)
| |
|
| | def half(self): |
| | |
| | return |
| |
|
| | def to(self, *args, **kwargs): |
| | |
| | dtype = kwargs.get('dtype', None) |
| | if dtype == torch.float16 or dtype == torch.bfloat16: |
| | kwargs['dtype'] = torch.float32 |
| | elif len(args) > 0 and (args[0] == torch.float16 or args[0] == torch.bfloat16): |
| | args = list(args) |
| | args[0] = torch.float32 |
| | args = tuple(args) |
| | return super().to(*args, **kwargs) |
| |
|
| | @property |
| | def device(self): |
| | """Get the current device of the model.""" |
| | return next(self.parameters()).device |
| |
|
| | @property |
| | def dtype(self): |
| | """Get the current dtype of the model.""" |
| | return next(self.parameters()).dtype |
| |
|
| | @staticmethod |
| | def get_memory_percentage(memory: int) -> float: |
| | """Get memory percentage.""" |
| | total_memory = torch.cuda.get_device_properties(0).total_memory |
| | reserved_memory = torch.cuda.memory_reserved(0) |
| | allocated_memory = torch.cuda.memory_allocated(0) |
| | available_memory = total_memory - reserved_memory - allocated_memory |
| | return memory / available_memory |
| |
|
    def init_vllm_engine(self):
        """Initialize models with AsyncVLLMEngine.

        Sizing notes:
          - max_model_len covers the text prompt plus generated audio tokens.
          - gpu_memory_utilization is derived from a fixed 3 GiB budget
            relative to currently-free GPU memory (see get_memory_percentage).
        """
        engine_args = AsyncEngineArgs(
            model="AstraMindAI/xtts2-gpt",
            tensor_parallel_size=self.tp,
            dtype="auto",
            disable_log_stats=True,
            max_model_len=self.gpt_config.max_text_tokens + self.gpt_config.max_audio_tokens,
            gpu_memory_utilization=self.get_memory_percentage(3 * 1024 ** 3),
            trust_remote_code=True,
            enforce_eager=True,  # skip CUDA-graph capture; required for the hidden-state hook path
            limit_mm_per_prompt={"audio": 1},  # exactly one conditioning-audio input per prompt
            max_num_batched_tokens=7296,  # NOTE(review): magic number — document how it was derived
        )

        self.llm_engine = AsyncLLMEngine.from_engine_args(engine_args)
| |
|
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: torch.dtype = torch.float32,
        device_map: Optional[str] = "auto",
        tensor_parallel_size: int = 1,
        **kwargs,
    ) -> "Xtts":
        """Load pretrained XTTS model from HuggingFace Hub.

        Args:
            pretrained_model_name_or_path: Hub repo id or local directory.
            torch_dtype: Requested dtype; half dtypes are rewritten to fp32
                by `to` (see that method).
            device_map: NOTE(review): currently unused — remove or wire up.
            tensor_parallel_size: vLLM tensor-parallel degree.
            **kwargs: Forwarded to the constructor.

        Returns:
            The model with checkpoint weights loaded, moved to CUDA
            (assumes a CUDA device is available).
        """
        from huggingface_hub import hf_hub_download
        import json
        import os

        # Resolve config.json from the Hub unless a local directory was given.
        if not os.path.exists(pretrained_model_name_or_path):
            config_file = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="config.json"
            )
            with open(config_file, 'r') as f:
                config = json.load(f)

        else:
            # Local checkout: read the config directly.
            with open(os.path.join(pretrained_model_name_or_path, "config.json"), 'r') as f:
                config = json.load(f)

        # One JSON file carries both configs: 'gpt_config' nested, the rest
        # is the HiFi-GAN/vocoder config.
        gpt_config = XTTSGPTConfig(**config['gpt_config'])
        hifi_config = XTTSConfig(**config)

        model = cls(
            hifi_config=hifi_config,
            gpt_config=gpt_config,
            tensor_parallel_size=tensor_parallel_size,
            **kwargs
        )

        # Resolve the safetensors checkpoint the same way as the config.
        if not os.path.exists(pretrained_model_name_or_path):
            hifigan_weights = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="xtts-v2.safetensors"
            )
        else:
            hifigan_weights = os.path.join(pretrained_model_name_or_path, "xtts-v2.safetensors")

        import safetensors.torch

        # NOTE(review): strict load — the checkpoint must cover every module
        # registered in __init__ (including the mel_stats buffer).
        hifigan_state = safetensors.torch.load_file(hifigan_weights)
        model.load_state_dict(hifigan_state)

        # Keep the raw config dict around for callers.
        model.config = config

        # `to` keeps fp32 for half dtypes; final move assumes CUDA.
        model = model.to(torch_dtype)
        model = model.to('cuda')

        return model
| |
|
| | @staticmethod |
| | def load_audio(audio_path: Union[str, Path], sampling_rate: int = 22050) -> torch.Tensor: |
| | audio, lsr = torchaudio.load(audio_path) |
| |
|
| | |
| | if audio.size(0) != 1: |
| | audio = torch.mean(audio, dim=0, keepdim=True) |
| |
|
| | if lsr != sampling_rate: |
| | audio = torchaudio.functional.resample(audio, lsr, sampling_rate) |
| |
|
| | |
| | audio.clip_(-1, 1) |
| | return audio |
| |
|
| | @torch.inference_mode() |
| | def get_speaker_embedding(self, audio, sr): |
| | audio_16k = torchaudio.functional.resample(audio, sr, 16000) |
| | return ( |
| | self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True) |
| | .unsqueeze(-1) |
| | .to(self.device) |
| | ) |
| |
|
    @torch.inference_mode()
    def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
        """Compute the conditioning latents for the GPT model from the given audio.

        Args:
            audio: (1, n_samples) reference waveform.
            sr: Sample rate of `audio`; resampled to 22.05 kHz if different.
            length: Max seconds of audio to use (values <= 0 keep everything).
            chunk_length: Per-chunk seconds when the perceiver resampler is used.

        Returns:
            Style embedding transposed over dims 1 and 2 (sequence-major;
            confirm exact layout against ConditioningEncoder's output).
        """
        if sr != 22050:
            audio = torchaudio.functional.resample(audio, sr, 22050)
        if length > 0:
            audio = audio[:, : 22050 * length]
        if self.gpt_config.use_perceiver_resampler:
            # Perceiver path: embed fixed-size chunks and average them.
            style_embs = []
            for i in range(0, audio.shape[1], 22050 * chunk_length):
                audio_chunk = audio[:, i: i + 22050 * chunk_length]

                # Skip tail chunks shorter than 0.33 s — too short to embed.
                # NOTE(review): if *every* chunk is skipped, torch.stack below
                # raises on an empty list; guard upstream with a minimum
                # reference length.
                if audio_chunk.size(-1) < 22050 * 0.33:
                    continue

                mel_chunk = wav_to_mel_cloning(
                    audio_chunk,
                    mel_norms=self.mel_stats.cpu(),
                    n_fft=2048,
                    hop_length=256,
                    win_length=1024,
                    power=2,
                    normalized=False,
                    sample_rate=22050,
                    f_min=0,
                    f_max=8000,
                    n_mels=80,
                )
                style_emb = self.get_style_emb(mel_chunk.to(self.device), None)
                style_embs.append(style_emb)

            # Mean over chunk embeddings.
            cond_latent = torch.stack(style_embs).mean(dim=0)
        else:
            # Single pass over the whole clip with larger FFT windows.
            mel = wav_to_mel_cloning(
                audio,
                mel_norms=self.mel_stats.cpu(),
                n_fft=4096,
                hop_length=1024,
                win_length=4096,
                power=2,
                normalized=False,
                sample_rate=22050,
                f_min=0,
                f_max=8000,
                n_mels=80,
            )
            cond_latent = self.get_style_emb(mel.to(self.device))
        return cond_latent.transpose(1, 2)
| |
|
| | @torch.inference_mode() |
| | def get_conditioning_latents( |
| | self, |
| | audio_path, |
| | max_ref_length=30, |
| | gpt_cond_len=6, |
| | gpt_cond_chunk_len=6, |
| | librosa_trim_db=None, |
| | sound_norm_refs=False, |
| | load_sr=22050, |
| | ): |
| | """Get the conditioning latents for the GPT model from the given audio.""" |
| | |
| | assert isinstance(audio_path, str) or isinstance(audio_path, list), "audio_path must be a string or a list." |
| |
|
| | if not isinstance(audio_path, list): |
| | audio_paths = [audio_path] |
| | else: |
| | audio_paths = audio_path |
| |
|
| | speaker_embeddings = [] |
| | audios = [] |
| | for file_path in audio_paths: |
| | audio = load_audio(file_path, load_sr) |
| | audio = audio[:, : load_sr * max_ref_length].to(self.device).to(self.dtype) |
| | if sound_norm_refs: |
| | audio = (audio / torch.abs(audio).max()) * 0.75 |
| | if librosa_trim_db is not None: |
| | audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0] |
| |
|
| | |
| | speaker_embedding = self.get_speaker_embedding(audio, load_sr) |
| | speaker_embeddings.append(speaker_embedding) |
| |
|
| | audios.append(audio) |
| |
|
| | |
| | full_audio = torch.cat(audios, dim=-1) |
| | gpt_cond_latents = self.get_gpt_cond_latents( |
| | full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len |
| | ) |
| |
|
| | speaker_embedding = torch.stack(speaker_embeddings) |
| | speaker_embedding = speaker_embedding.mean(dim=0) |
| |
|
| | return gpt_cond_latents, speaker_embedding |
| |
|
| | def get_style_emb(self, cond_input: torch.Tensor, return_latent: bool = False) -> torch.Tensor: |
| | """Get conditioning embeddings from mel spectrograms.""" |
| | if not return_latent: |
| | if cond_input.ndim == 4: |
| | cond_input = cond_input.squeeze(1) |
| | conds = self.conditioning_encoder(cond_input) |
| |
|
| | if hasattr(self, 'conditioning_perceiver'): |
| | conds = self.conditioning_perceiver( |
| | conds.permute(0, 2, 1) |
| | ).transpose(1, 2) |
| | else: |
| | conds = cond_input.unsqueeze(1) |
| | return conds |
| |
|
    async def prepare_text_tokens_async(self, text: str, language: str, split_text=False) \
            -> Tuple[List[Union[int, List[int]]], List[torch.Tensor]]:
        """Prepare text tokens for the given text and language.

        Returns a tuple of:
          - placeholder token lists (all 1s) matching each sequence, used as
            dummy prompt tokens by the audio engine;
          - the corresponding embedded (token + positional) tensors.
        """

        async def elaborate_tokens(text_tokens: List[int]) -> torch.Tensor:
            # Frame with BOS/EOS and move to the embedding weight's device.
            # NOTE: mutates the input list in place.
            text_tokens.insert(0, self.tokenizer.bos_token_id)
            text_tokens.append(self.tokenizer.eos_token_id)
            return torch.tensor(text_tokens).unsqueeze(0).to(self.text_embedding.weight.device)

        async def embed_tokens(text_tokens: Union[torch.Tensor, List[torch.Tensor]]) -> List[torch.Tensor]:
            # One embedding tensor per sequence (token + learned positional).
            embeds = []
            if isinstance(text_tokens, list):
                for list_element in text_tokens:
                    embeds.append(self.text_embedding(list_element) + self.text_pos_embedding(list_element))
            else:
                embeds.append(self.text_embedding(text_tokens) + self.text_pos_embedding(text_tokens))
            return embeds

        fake_tokens_for_audio_generation = []
        if split_text:
            # Sentence-split path: one token list per split.
            text_tokens = self.tokenizer.batch_encode_with_split(text, lang=[language])
            for idx, text_token in enumerate(text_tokens):
                text_tokens[idx] = await elaborate_tokens(text_token)
                fake_tokens_for_audio_generation.append([1] * len(text_token))
        else:
            text_tokens = self.tokenizer.batch_encode(text, lang=[language])
            text_tokens = await elaborate_tokens(text_tokens)
            # NOTE(review): here text_tokens is a (1, n) tensor, so
            # len(text_tokens) == 1 — unlike the split branch, which counts
            # tokens. Confirm this asymmetry is intended.
            fake_tokens_for_audio_generation = [1] * len(text_tokens)
        return fake_tokens_for_audio_generation, await embed_tokens(text_tokens)
| |
|
| | async def prepare_inputs_async(self, text: str, language: str, speaker_file: Union[str, Path], |
| | max_ref_length: int, gpt_cond_len: int, gpt_cond_chunk_len: int, split_text: bool) \ |
| | -> Tuple[List[List[int]], List[torch.Tensor], torch.Tensor]: |
| | """Prepare input text with conditioning tokens. Return combined conditioning latents""" |
| | |
| | text_tokens, text_embeddings = await self.prepare_text_tokens_async(text, language, split_text) |
| |
|
| | |
| | gpt_cond_latent, speaker_embeddings = await self.get_conditioning_latents_async( |
| | speaker_file, |
| | max_ref_length, |
| | gpt_cond_len, |
| | gpt_cond_chunk_len |
| | ) |
| |
|
| | cond_latents = [] |
| | for text_embedding in text_embeddings: |
| | |
| | cond_latents.append((torch.cat([gpt_cond_latent, text_embedding], dim=1).squeeze(0) |
| | .to(self.llm_engine.engine.model_config.dtype))) |
| |
|
| | return text_tokens, cond_latents, speaker_embeddings |
| |
|
| | async def get_conditioning_latents_async( |
| | self, |
| | audio_path, |
| | max_ref_length=30, |
| | gpt_cond_len=6, |
| | gpt_cond_chunk_len=6, |
| | librosa_trim_db=None, |
| | sound_norm_refs=False, |
| | load_sr=22050, |
| | ): |
| | """Async version of get_conditioning_latents with concurrency control.""" |
| | async with self.semaphore: |
| | |
| | result = await asyncio.get_event_loop().run_in_executor( |
| | None, |
| | functools.partial(self.get_conditioning_latents, |
| | audio_path, |
| | max_ref_length, |
| | gpt_cond_len, |
| | gpt_cond_chunk_len, |
| | librosa_trim_db, |
| | sound_norm_refs, |
| | load_sr) |
| | ) |
| | return result |
| |
|
    async def get_model_logits(self, token_ids: List[int], conditioning: MultiModalDataDict) -> torch.Tensor:
        """Get model logits for a specific request.

        Runs a one-token generation purely for its side effect: the bound
        collector captures the hidden states for the framed token sequence.

        Args:
            token_ids: Generated audio-token ids (without BOS/EOS framing).
            conditioning: Multimodal payload; callers set
                ``is_logits_only_mode`` so the engine skips normal decoding.

        Returns:
            (1, len(framed tokens), hidden) tensor on the model device/dtype.

        Raises:
            RuntimeError: on collection timeout (300 s) or when no hidden
                states were captured for the request.
        """
        request_id = uuid.uuid4().hex

        # Frame the sequence: BOS plus a tail of 5 EOS tokens.
        # NOTE(review): the 5x EOS padding looks deliberate (decoder context
        # flush?) — document the reason once confirmed.
        token_ids = [self.mel_bos_token_id] + token_ids + [self.mel_eos_token_id] * 5

        engine_inputs = TokensPrompt(prompt_token_ids=token_ids)
        engine_inputs["multi_modal_data"] = conditioning

        # Collector callback pre-bound to this request id.
        bound_collector = self.hidden_states_collector.bind_to_request(request_id)

        # max_tokens=1: we only need the prefill pass, not generation.
        sampling_params = ExtendedSamplingParams(
            detokenize=False,
            max_tokens=1,
            hidden_state_collector=bound_collector,
        )

        generator = self.llm_engine.generate(
            prompt=engine_inputs,
            sampling_params=sampling_params,
            request_id=request_id
        )

        # Drain the generator; only the collector side effect matters.
        try:
            async def consume_generator():
                async for _ in generator:
                    pass

            await asyncio.wait_for(consume_generator(), timeout=300)
        except asyncio.TimeoutError:
            raise RuntimeError("Timeout while generating logits")

        # Pop the captured states (single-use per request id).
        hidden_states = self.hidden_states_collector.get_hidden_states(request_id)

        if hidden_states is None:
            raise RuntimeError(f"No hidden states collected for request {request_id}")

        # Keep only the states for the framed prompt tokens.
        return hidden_states[-len(token_ids):, ...].unsqueeze(0).to(self.device).to(self.dtype)
| |
|
| |
|
    async def process_tokens_to_speech(
            self,
            generators: List[AsyncGenerator[RequestOutput, None]],
            speaker_embeddings: torch.Tensor,
            multimodal_data: List[torch.Tensor],
            chunk_size: int = 20,
    ) -> AsyncGenerator[XTTSOutput, None]:
        """
        Process multiple token generators concurrently and emit results sequentially.
        Uses a queue-based approach to handle multiple generators reliably.

        All generators run concurrently as tasks, each writing finished audio
        into its own queue; results are yielded queue-by-queue so the output
        order follows the order of `generators`, not completion order.
        """
        # One queue per generator; a None entry is the end-of-stream sentinel.
        queues = [asyncio.Queue() for _ in generators]

        # Fan out: one task per generator.
        tasks = []
        for i, generator in enumerate(generators):
            task = asyncio.create_task(
                self._process_single_generator(
                    generator,
                    queues[i],
                    speaker_embeddings,
                    multimodal_data[i],
                    chunk_size
                )
            )
            tasks.append(task)

        try:
            # Drain queues in input order to preserve sequence ordering.
            for i, queue in enumerate(queues):
                while True:
                    result = await queue.get()
                    if result is None:
                        # This generator is done; move to the next queue.
                        break
                    else:
                        yield result

        finally:
            # Cancel stragglers if the consumer stops early, then reap them.
            for task in tasks:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
| |
|
    async def _process_single_generator(
            self,
            generator: AsyncGenerator[RequestOutput, None],
            queue: asyncio.Queue,
            speaker_embeddings: torch.Tensor,
            gpt_embed_input: torch.Tensor,
            chunk_size: int
    ) -> None:
        """Process a single generator and put results in its queue.

        Consumes audio tokens from `generator`; once the request finishes, a
        second logits-only pass fetches hidden states, which the HiFi-GAN
        vocoder turns into a waveform pushed onto `queue`. A trailing None is
        always enqueued as the end-of-stream sentinel, even on error.

        NOTE(review): `chunk_size` is currently unused — audio is produced
        only on completion, not in incremental chunks.
        """
        try:
            last_decoded_token = 0
            accumulated_tokens = []

            async for output in generator:
                # Each event carries the full token history; keep only the
                # tokens that are new since the previous event.
                new_tokens = output.outputs[0].token_ids[last_decoded_token:]
                accumulated_tokens.extend(new_tokens)
                last_decoded_token = len(accumulated_tokens)

                if output.finished:
                    # Second pass over the finished sequence in
                    # logits-only mode to capture hidden states.
                    hidden_states = await self.get_model_logits(
                        accumulated_tokens,
                        {
                            "audio": {
                                'embeds': gpt_embed_input,
                                "is_logits_only_mode": True
                            }
                        }
                    )

                    # HiFi-GAN inference is blocking; run it on the pool.
                    wav = await asyncio.get_event_loop().run_in_executor(
                        self.executor,
                        lambda: self.hifigan_decoder.inference(
                            hidden_states,
                            g=speaker_embeddings
                        ).cpu().numpy().squeeze()
                    )

                    await queue.put(XTTSOutput(
                        request_id=output.request_id,
                        wav=wav
                    ))

                    # Reset in case the stream were to continue (it breaks below).
                    accumulated_tokens = []

                if output.finished:
                    break

        except Exception as e:
            # Errors are logged, not propagated: the sentinel below still
            # unblocks the consumer in process_tokens_to_speech.
            logging.error(f"Error in generator processing: {e}")
        finally:
            # Always signal end-of-stream for this generator.
            await queue.put(None)
| |
|
| | async def generate_speech_async_from_streaming_source(self, request: XTTSRequest) -> AsyncGenerator[XTTSOutput, None]: |
| | """Generate speech for streaming source of text, making a streaming source of audio tokens and then decoding |
| | and returning a streaming audio response.""" |
| | assert isinstance(request.text, AsyncGenerator), "Text must be an AsyncGenerator for streaming source." |
| | |
| | gpt_cond_latent, speaker_embeddings = await self.get_conditioning_latents_async( |
| | request.speaker_file, |
| | request.max_ref_length, |
| | request.gpt_cond_len, |
| | request.gpt_cond_chunk_len |
| | ) |
| | sampling_params = SamplingParams( |
| | temperature=request.temperature, |
| | top_p=request.top_p, |
| | detokenize=False, |
| | top_k=request.top_k, |
| | logits_processors=[LogitsRepetitionPenalizer(request.repetition_penalty)], |
| | repetition_penalty=1.0, |
| | max_tokens=self.gpt_config.gpt_max_audio_tokens, |
| | ignore_eos=True, |
| | stop_token_ids=[self.mel_eos_token_id], |
| | ) |
| |
|
| | accumulated_text = "" |
| | async for text in request.text: |
| | text = text.strip() |
| | accumulated_text += text |
| |
|
| | if len(accumulated_text) > request.generate_every_n_chars: |
| | tokens, embeddings = await self.prepare_text_tokens_async(accumulated_text, request.language) |
| | gpt_embed_input = [torch.cat([gpt_cond_latent, embeddings[0]], dim=0)] |
| |
|
| | engine_inputs = TokensPrompt(prompt_token_ids=tokens) |
| | if gpt_embed_input is not None: |
| | engine_inputs["multi_modal_data"] = {"audio": {"embeds": gpt_embed_input, "is_logits_only_mode": False}} |
| | token_generator = [self.llm_engine.generate( |
| | prompt=engine_inputs, |
| | sampling_params=sampling_params, |
| | request_id=request.request_id, |
| | )] |
| | |
| | async for output in self.process_tokens_to_speech( |
| | token_generator, |
| | speaker_embeddings, |
| | gpt_embed_input, |
| | chunk_size=50 |
| | ): |
| | yield output |
| |
|
| | accumulated_text = "" |
| |
|
| | async def generate_speech_from_text_async(self, request: XTTSRequest) -> AsyncGenerator[XTTSOutput, None]: |
| | """Generate speech for a single request asynchronously.""" |
| | |
| | tokens_list, gpt_embed_inputs, speaker_embeddings = await self.prepare_inputs_async( |
| | request.text, |
| | request.language, |
| | request.speaker_file, |
| | request.max_ref_length, |
| | request.gpt_cond_len, |
| | request.gpt_cond_chunk_len, |
| | split_text=True |
| | ) |
| |
|
| | |
| | generators = [] |
| | for seq_index, sequence in enumerate(tokens_list): |
| | sampling_params = SamplingParams( |
| | temperature=request.temperature, |
| | top_p=request.top_p, |
| | detokenize=False, |
| | top_k=request.top_k, |
| | logits_processors=[LogitsRepetitionPenalizer(request.repetition_penalty)], |
| | repetition_penalty=1.0, |
| | max_tokens=self.gpt_config.gpt_max_audio_tokens, |
| | ignore_eos=True, |
| | stop_token_ids=[self.mel_eos_token_id], |
| | ) |
| |
|
| | engine_inputs = TokensPrompt(prompt_token_ids=sequence) |
| | if gpt_embed_inputs is not None: |
| | engine_inputs["multi_modal_data"] = {"audio": {"embeds": gpt_embed_inputs[seq_index], "is_logits_only_mode": False}} |
| |
|
| | |
| | token_generator = self.llm_engine.generate( |
| | prompt=engine_inputs, |
| | sampling_params=sampling_params, |
| | request_id=f"{request.request_id}_{seq_index}", |
| | ) |
| | generators.append(token_generator) |
| |
|
| | |
| | async for output in self.process_tokens_to_speech( |
| | generators, |
| | speaker_embeddings, |
| | gpt_embed_inputs, |
| | chunk_size=50 |
| | ): |
| | yield output |
| |
|
| | def generate_speech_from_text(self, request: XTTSRequest) -> List[XTTSOutput]: |
| | """ |
| | Synchronous wrapper for generate_speech_from_text_async. |
| | |
| | Args: |
| | request: XTTSRequest object containing generation parameters |
| | |
| | Returns: |
| | List of XTTSOutput containing the generated speech segments |
| | """ |
| |
|
| | async def _collect_outputs(): |
| | outputs = [] |
| | async for output in self.generate_speech_from_text_async(request): |
| | outputs.append(output) |
| | return outputs |
| |
|
| | |
| | import asyncio |
| |
|
| | |
| | try: |
| | loop = asyncio.get_event_loop() |
| | except RuntimeError: |
| | loop = asyncio.new_event_loop() |
| | asyncio.set_event_loop(loop) |
| |
|
| | if loop.is_running(): |
| | |
| | new_loop = asyncio.new_event_loop() |
| | results = new_loop.run_until_complete(_collect_outputs()) |
| | new_loop.close() |
| | else: |
| | results = loop.run_until_complete(_collect_outputs()) |
| |
|
| | return results |
| |
|