import functools
import math
from array import array

import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import List, Optional, Union, Iterable, Tuple, Mapping

from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata, Attention
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import InputContext, INPUT_REGISTRY
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.sequence import IntermediateTensors, SequenceData, VLLM_TOKEN_ID_ARRAY_TYPE
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP

from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
from TTS.tts.layers.xtts.gpt import LearnedPositionEmbeddings

_AUDIO_PLACEHOLDER_TOKEN = 8192  # token ID reserving audio positions in dummy/profiling sequences
_AUDIO_TOKENS_PER_SECOND = 6.25  # audio-code rate used below to size audio segments
_CODE_STRIDE_LEN = 1024


class GPT2Attention(nn.Module):
    """Multi-head self-attention with tensor-parallel QKV and output projections."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output
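

# Note: vLLM's Attention layer applies causal masking and manages the paged
# KV cache internally, driven by `attn_metadata`; the module above only has
# to supply the projected q/k/v tensors and the per-layer cache reference.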


class GPT2MLP(nn.Module):
    def __init__(
        self,
        intermediate_size: int,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function, quant_config, intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states
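

# The `c_attn`/`c_proj`/`c_fc` names intentionally mirror HF GPT-2's Conv1D
# modules so checkpoint keys line up; `XttsGPT.load_weights` below transposes
# those weights into the layout expected by the parallel linear layers.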


class GPT2Block(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.attn",
        )
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(
            inner_dim,
            config,
            quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Pre-norm residual attention block.
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = attn_output + residual

        # Pre-norm residual MLP block.
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


def get_xtts_max_audio_tokens(ctx: InputContext) -> int:
    """Return the maximum number of audio tokens per audio input.

    A fixed upper bound is used so vLLM can reserve sequence space when
    profiling multimodal memory usage.
    """
    return 608


def dummy_seq_data_for_xtts(
    ctx: InputContext,
    seq_len: int,
    audio_count: int,
) -> SequenceData:
    """Create dummy sequence data for XTTS profiling."""
    # Assume a 5-second reference clip: ceil(6.25 tokens/s * 5 s) = 32
    # placeholder tokens per audio input.
    audio_len_tokens = math.ceil(_AUDIO_TOKENS_PER_SECOND * 5)
    audio_placeholder = array(
        VLLM_TOKEN_ID_ARRAY_TYPE,
        [_AUDIO_PLACEHOLDER_TOKEN],
    ) * audio_len_tokens

    # One separator token after each audio segment.
    audio_token_ids = (audio_placeholder + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count

    # Pad the rest of the sequence with zeros.
    other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - len(audio_token_ids))

    return SequenceData(audio_token_ids + other_token_ids)


def dummy_conditioning_for_xtts(
    ctx: InputContext,
    audio_count: int,
) -> dict:
    """Create dummy conditioning data for XTTS."""
    # Each entry is an (80-channel mel spectrogram, sample rate) pair.
    return {
        "audio": [(torch.zeros(80, 1024), 22050) for _ in range(audio_count)]
    }


def dummy_data_for_xtts(
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, dict]:
    """Create complete dummy data for XTTS profiling."""
    audio_count = mm_counts["audio"]
    seq_data = dummy_seq_data_for_xtts(ctx, seq_len, audio_count)
    cond_data = dummy_conditioning_for_xtts(ctx, audio_count)
    return (seq_data, cond_data)
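

# vLLM invokes the dummy-data hook above during engine initialization to
# profile worst-case memory usage before allocating the KV cache; the dummy
# values never reach real model outputs.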


def input_mapper_for_xtts(ctx: InputContext, data: object) -> MultiModalInputs:
    """Map raw audio inputs to the XTTS multimodal format."""
    if not isinstance(data, list):
        data = [data]

    # Each audio input is expected to be an (audio, sample_rate) tuple.
    for audio_input in data:
        if not isinstance(audio_input, tuple):
            raise NotImplementedError(f"Unsupported data type: {type(audio_input)}")

    return MultiModalInputs({"cond_latents": data})
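

# The registrations below wire this class into vLLM's multimodal pipeline:
# the input mapper turns "audio" inputs into the `cond_latents` keyword
# argument of `XttsGPT.forward`, the max-tokens hook bounds how much sequence
# space audio may occupy, and the dummy-data hook supplies profiling inputs.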


@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_xtts)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", get_xtts_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_xtts)
class XttsGPT(nn.Module, SupportsMultiModal, SupportsPP):
    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        # Encodes reference mel spectrograms into speaker-conditioning latents.
        self.conditioning_encoder = ConditioningEncoder(
            config.audio_config.mel_channels,
            config.hidden_size,
            num_attn_heads=config.num_attention_heads,
        )

        if config.use_perceiver_resampler:
            self.conditioning_perceiver = PerceiverResampler(
                dim=config.hidden_size,
                depth=2,
                dim_context=config.hidden_size,
                num_latents=32,
                dim_head=64,
                heads=8,
                ff_mult=4,
                use_flash_attn=False,
            )

        self.gpt = XttsGPT2Model(
            config,
            cache_config,
            quant_config,
            prefix="gpt",
        )

        # Separate output heads for text tokens and mel (audio) tokens.
        self.text_head = ColumnParallelLinear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            quant_config=quant_config,
            prefix="text_head",
        )
        self.mel_head = ColumnParallelLinear(
            config.hidden_size,
            config.num_audio_tokens,
            bias=False,
            quant_config=quant_config,
            prefix="mel_head",
        )

        self.sampler = Sampler()

    def get_style_emb(self, cond_input: torch.Tensor, return_latent: bool = False) -> torch.Tensor:
        """Get conditioning embeddings from mel spectrograms."""
        if not return_latent:
            # (B, 1, 80, T) -> (B, 80, T)
            if cond_input.ndim == 4:
                cond_input = cond_input.squeeze(1)
            conds = self.conditioning_encoder(cond_input)

            if hasattr(self, 'conditioning_perceiver'):
                conds = self.conditioning_perceiver(
                    conds.permute(0, 2, 1)
                ).transpose(1, 2)
        else:
            # The input is already a latent; just add a length-1 axis.
            conds = cond_input.unsqueeze(1)
        return conds
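
    # Shape flow in get_style_emb (a sketch, assuming XTTS defaults): a mel
    # input of shape (B, 80, T) is encoded to (B, hidden, S); the perceiver
    # consumes it as (B, S, hidden), resamples it down to 32 latents, and the
    # result is returned transposed as (B, hidden, 32).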

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        cond_latents: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward pass following the vLLM model interface."""
        if cond_latents is not None:
            # Prefill with speaker conditioning: prepend the conditioning
            # latents to the text token embeddings.
            input_embeds = self.gpt.get_input_embeddings()(input_ids)
            combined_embeds = torch.cat([cond_latents, input_embeds], dim=1)
            hidden_states = self.gpt(
                inputs_embeds=combined_embeds,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_tensors,
            )
        else:
            hidden_states = self.gpt(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_tensors,
            )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Compute combined text and mel logits for the sampled positions."""
        selected = hidden_states[sampling_metadata.selected_token_indices]
        # The parallel linear layers return (output, bias); only the output
        # is needed here.
        text_logits, _ = self.text_head(selected)
        mel_logits, _ = self.mel_head(selected)
        return torch.cat([text_logits, mel_logits], dim=1)
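
    # Because of the concatenation above, token IDs in [0, vocab_size) are
    # text tokens and IDs in [vocab_size, vocab_size + num_audio_tokens) are
    # mel codes, so the sampler operates over one combined vocabulary.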

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens using the vLLM sampler."""
        return self.sampler(logits, sampling_metadata)
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights following the vLLM pattern."""
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in weights:
            if name not in params_dict:
                continue

            param = params_dict[name]
            # HF GPT-2 checkpoints store these projections as Conv1D modules,
            # whose weights are transposed relative to nn.Linear.
            if "c_attn" in name or "c_proj" in name or "c_fc" in name:
                if name.endswith(".weight"):
                    loaded_weight = loaded_weight.t()

            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


class XttsGPT2Model(nn.Module):
    """vLLM-style implementation of the GPT-2 core architecture."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        # Separate embedding tables for text tokens and mel (audio) tokens.
        self.text_embedding = VocabParallelEmbedding(
            config.number_text_tokens,
            config.hidden_size,
        )
        self.mel_embedding = VocabParallelEmbedding(
            config.num_audio_tokens,
            config.hidden_size,
        )

        self.text_pos_embedding = (
            LearnedPositionEmbeddings(
                config.max_text_tokens + 2,
                config.hidden_size,
            )
            if config.max_text_tokens != -1
            else functools.partial(config.null_position_embeddings, dim=config.hidden_size)
        )

        self.mel_pos_embedding = (
            LearnedPositionEmbeddings(
                config.max_audio_tokens + 3,
                config.hidden_size,
            )
            if config.max_audio_tokens != -1
            else functools.partial(config.null_position_embeddings, dim=config.hidden_size)
        )

        self.h = nn.ModuleList([
            GPT2Block(
                config,
                cache_config,
                quant_config,
                prefix=f"{prefix}.h.{i}",
            ) for i in range(config.num_hidden_layers)
        ])

        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def get_input_embeddings(self):
        return self.text_embedding

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        positions: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.text_embedding(input_ids)
            hidden_states = inputs_embeds

            if positions is not None:
                position_embeds = self.text_pos_embedding(positions)
                hidden_states = hidden_states + position_embeds
        else:
            # Later pipeline stages receive hidden states from the previous
            # stage instead of recomputing embeddings.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i, block in enumerate(self.h):
            hidden_states = block(
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states
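

# A minimal usage sketch (an assumption, not part of this file): vLLM looks up
# model classes through its registry, so serving this implementation would
# look roughly like the following, with the architecture name matching the
# checkpoint's config.json:
#
#     from vllm import ModelRegistry
#     ModelRegistry.register_model("XttsGPT", XttsGPT)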