from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging

from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler

from .configuration_vibevoice_streaming import VibeVoiceStreamingConfig

logger = logging.get_logger(__name__)

# Compatibility shim: ensure modeling_utils.ALL_PARALLEL_STYLES exists, since
# some transformers versions do not define it.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]


class BinaryClassifier(nn.Module):
    """Two-layer MLP that maps hidden states to a single logit."""

    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
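
# Usage sketch (hypothetical sizes; the classifier returns a raw logit, so
# apply a sigmoid to get a probability):
#
#     clf = BinaryClassifier(hidden_size=1024)
#     prob = torch.sigmoid(clf(torch.randn(2, 1024)))  # -> shape (2, 1)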


class SpeechConnector(nn.Module):
    """Projects speech features into the language model's hidden space."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        x = self.fc1(features)
        x = self.norm(x)
        x = self.fc2(x)
        return x
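
# Shape sketch (hypothetical dims, mirroring the acoustic_vae_dim -> hidden_size
# projection used by acoustic_connector below):
#
#     connector = SpeechConnector(input_dim=64, output_dim=1024)
#     latents = torch.randn(2, 50, 64)   # (batch, frames, acoustic_vae_dim)
#     states = connector(latents)        # -> (2, 50, 1024)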


# @auto_docstring
class VibeVoiceStreamingPreTrainedModel(PreTrainedModel):
    config_class = VibeVoiceStreamingConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        if isinstance(module, VibeVoiceDiffusionHead):
            module.initialize_weights()
            return
        # Use the language model's initializer_range if available
        if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02  # Default value
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()


# @auto_docstring
class VibeVoiceStreamingModel(VibeVoiceStreamingPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Initialize the Qwen2 model for language modeling.
        # The lower Transformer layers are used only for encoding text, while the
        # upper Transformer layers both encode text and generate speech. To keep
        # the code clean, we construct two language models. The final norm layer
        # of the first language_model is replaced with an identity and is not
        # used at inference.
        lm_config = copy.deepcopy(config.decoder_config)
        lm_backbone_num_hidden_layers = getattr(lm_config, 'num_hidden_layers', 24) - config.tts_backbone_num_hidden_layers
        lm_config.num_hidden_layers = lm_backbone_num_hidden_layers
        self.language_model = AutoModel.from_config(lm_config)
        self.language_model.norm = nn.Identity()
        # We only need the Transformer layers here; note that embed_tokens in
        # tts_language_model is unused.
        tts_lm_config = copy.deepcopy(lm_config)
        tts_lm_config.num_hidden_layers = config.tts_backbone_num_hidden_layers
        self.tts_language_model = AutoModel.from_config(tts_lm_config)

        # Marks which text needs to be spoken by the TTS model.
        self.tts_input_types = nn.Embedding(num_embeddings=2, embedding_dim=config.decoder_config.hidden_size)

        # Initialize speech components.
        self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
        self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)

        # Register scaling factors as buffers for FSDP/state_dict compatibility
        # (scalar tensors, NaN until computed).
        self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
        self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))

        # Initialize the prediction head for speech generation.
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)

        # Initialize the noise scheduler.
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type
        )
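
        # Note (assumption): at inference the diffusion head denoises acoustic
        # latents iteratively under this scheduler; the number of inference
        # steps is chosen by the caller.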

    def get_input_embeddings(self):
        if hasattr(self.language_model, 'embed_tokens'):
            # If the language model exposes embed_tokens directly, return it.
            return self.language_model.embed_tokens
        # Otherwise the model was parallelized by nnscaler, which renames
        # parameters; look the embedding up through the name map.
        for name, attr in self.language_model.fullmap.items():
            if attr.orig_name == 'embed_tokens.weight':
                return getattr(self.language_model, name)
        raise AttributeError('embed_tokens not found on language_model')

    def set_input_embeddings(self, value):
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None):
        """Set the speech tokenizer used for encoding and decoding speech."""
        self.acoustic_tokenizer = acoustic_tokenizer
        # Keep the tokenizer in evaluation mode.
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.eval()
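
    # Usage sketch (assumption: a separately loaded tokenizer module replacing
    # the config-initialized one built in __init__; the path is hypothetical):
    #
    #     tok = AutoModel.from_pretrained("path/to/acoustic_tokenizer")
    #     model.set_speech_tokenizers(acoustic_tokenizer=tok)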

    def forward(self, *args, **kwargs):
        """
        Intentionally not implemented.

        This streaming model is split into two explicit submodules:
        - `language_model` for plain text processing (lower layers).
        - `tts_language_model` for the TTS-related upper layers.

        We deliberately avoid a unified `forward` to prevent accidental calls
        that mix responsibilities.

        To use the model:
        - Call `self.language_model(...)` for text embeddings / hidden states.
        - Call `self.tts_language_model(...)` for the TTS portion.
        - Use the dedicated inference class for combined generation logic.
        """
        raise RuntimeError(
            "VibeVoiceStreamingModel.forward is intentionally disabled. "
            "Use `model.language_model(...)` or `model.tts_language_model(...)` instead."
        )


AutoModel.register(VibeVoiceStreamingConfig, VibeVoiceStreamingModel)
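
# End-to-end usage sketch (hedged: the submodules follow the standard
# transformers call signature, but exact inputs depend on the configs in use,
# and `input_ids` here is illustrative):
#
#     config = VibeVoiceStreamingConfig()
#     model = VibeVoiceStreamingModel(config)
#     embeds = model.get_input_embeddings()(input_ids)       # (batch, seq, hidden)
#     text_out = model.language_model(inputs_embeds=embeds)  # lower layers
#     tts_out = model.tts_language_model(                    # upper TTS layers
#         inputs_embeds=text_out.last_hidden_state)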

__all__ = [
    "VibeVoiceStreamingPreTrainedModel",
    "VibeVoiceStreamingModel",
]