# vibevoice/modular/modeling_vibevoice_streaming.py
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
from .configuration_vibevoice_streaming import VibeVoiceStreamingConfig
logger = logging.get_logger(__name__)
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]


class BinaryClassifier(nn.Module):
    """Two-layer MLP that outputs a single raw logit (no sigmoid applied)."""

    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
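
# A minimal, commented usage sketch (not part of the module's API): the
# classifier above returns raw logits, so a binary objective would typically
# pair it with `nn.BCEWithLogitsLoss`. The shapes below are illustrative
# assumptions only.
#
#     classifier = BinaryClassifier(hidden_size=1024)
#     hidden = torch.randn(2, 1024)              # (batch, hidden_size)
#     logits = classifier(hidden)                # (batch, 1), unnormalized
#     loss = nn.BCEWithLogitsLoss()(logits, torch.zeros(2, 1))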


class SpeechConnector(nn.Module):
    """Projects acoustic VAE latents into the language model's hidden space."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        x = self.fc1(features)
        x = self.norm(x)
        x = self.fc2(x)
        return x
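
# Commented shape sketch (illustrative, assumed dimensions): the connector maps
# acoustic VAE latents of size `config.acoustic_vae_dim` to the language
# model's hidden size, frame by frame.
#
#     connector = SpeechConnector(input_dim=64, output_dim=1024)
#     latents = torch.randn(1, 50, 64)           # (batch, frames, acoustic_vae_dim)
#     lm_features = connector(latents)           # (batch, frames, hidden_size)

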
# @auto_docstring
class VibeVoiceStreamingPreTrainedModel(PreTrainedModel):
    config_class = VibeVoiceStreamingConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        if isinstance(module, VibeVoiceDiffusionHead):
            module.initialize_weights()
            return
        # Use the language model's initializer_range if available
        if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02  # Default value
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()


# @auto_docstring
class VibeVoiceStreamingModel(VibeVoiceStreamingPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Initialize the Qwen2 backbone for language modeling.
        # The lower Transformer layers are used only for encoding text, while the upper
        # Transformer layers encode text and generate speech. To keep the code clean,
        # we construct two language models.
        # The final norm layer of the first language_model is set to identity and is not used at inference.
        lm_config = copy.deepcopy(config.decoder_config)
        lm_backbone_num_hidden_layers = getattr(lm_config, 'num_hidden_layers', 24) - config.tts_backbone_num_hidden_layers
        lm_config.num_hidden_layers = lm_backbone_num_hidden_layers
        self.language_model = AutoModel.from_config(lm_config)
        self.language_model.norm = nn.Identity()

        # We only need the Transformer layers here. Note that embed_tokens in tts_language_model is unused.
        tts_lm_config = copy.deepcopy(lm_config)
        tts_lm_config.num_hidden_layers = config.tts_backbone_num_hidden_layers
        self.tts_language_model = AutoModel.from_config(tts_lm_config)

        # Marks which tokens should be spoken by the TTS model.
        self.tts_input_types = nn.Embedding(num_embeddings=2, embedding_dim=config.decoder_config.hidden_size)

        # Initialize speech components if needed
        self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
        self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)

        # Register scaling factors as buffers - use 1D tensors for FSDP compatibility
        self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
        self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))

        # Initialize prediction head for speech generation
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)

        # Initialize noise scheduler
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type
        )
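
    # Heavily hedged, commented sketch (assumptions: a diffusers-style scheduler
    # API and a prediction head callable as `prediction_head(noisy, timestep, condition)`;
    # the real call signature lives in VibeVoiceDiffusionHead and the dedicated
    # inference class). It illustrates how the scheduler and head initialized above
    # are typically combined to denoise one acoustic latent frame:
    #
    #     self.noise_scheduler.set_timesteps(num_inference_steps)
    #     latent = torch.randn(batch, self.config.acoustic_vae_dim, device=condition.device)
    #     for t in self.noise_scheduler.timesteps:
    #         model_output = self.prediction_head(latent, t, condition)
    #         latent = self.noise_scheduler.step(model_output, t, latent).prev_sample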

    def get_input_embeddings(self):
        if hasattr(self.language_model, 'embed_tokens'):
            # If the language model exposes an embed_tokens attribute, return it directly
            return self.language_model.embed_tokens
        for name, attr in self.language_model.fullmap.items():  # when parallelized by nnscaler, parameter names are remapped
            if attr.orig_name == 'embed_tokens.weight':
                return getattr(self.language_model, name)
        raise AttributeError("Could not locate input embeddings on the language model.")

    def set_input_embeddings(self, value):
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None):
        """Set the speech tokenizer used for encoding and decoding speech."""
        self.acoustic_tokenizer = acoustic_tokenizer

        # Keep the tokenizer in evaluation mode
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.eval()

    def forward(self, *args, **kwargs):
        """
        Intentionally not implemented.

        This streaming model is split into two explicit submodules:
        - `language_model` for plain text processing (lower layers).
        - `tts_language_model` for TTS-related upper layers.

        We deliberately avoid a unified `forward` to prevent accidental calls
        that mix responsibilities.

        To use the model:
        - Call `self.language_model(...)` for text embeddings / hidden states.
        - Call `self.tts_language_model(...)` for the TTS portion.
        - Use the dedicated inference class for combined generation logic.
        """
        raise RuntimeError(
            "VibeVoiceStreamingModel.forward is intentionally disabled. "
            "Use `model.language_model(...)` or `model.tts_language_model(...)` instead."
        )


AutoModel.register(VibeVoiceStreamingConfig, VibeVoiceStreamingModel)

__all__ = [
    "VibeVoiceStreamingPreTrainedModel",
    "VibeVoiceStreamingModel",
]
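

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the library API. Assumptions are flagged
# inline: the checkpoint path is hypothetical, and the way the `tts_input_types`
# marker is combined with hidden states is an illustrative guess; the dedicated
# inference class defines the actual generation pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical local checkpoint directory containing a VibeVoiceStreamingConfig
    # and matching weights.
    model = VibeVoiceStreamingModel.from_pretrained("path/to/vibevoice-streaming-checkpoint")
    model.eval()

    # Dummy token ids; in practice these come from the paired tokenizer/processor.
    input_ids = torch.randint(0, 100, (1, 16))

    with torch.no_grad():
        # Lower, text-only layers (their final norm is an Identity here).
        text_hidden = model.language_model(input_ids=input_ids).last_hidden_state

        # Mark every position as "to be spoken" (embedding index 1) before the
        # upper, TTS-specific layers; an assumed combination, for illustration only.
        marker = model.tts_input_types(torch.ones_like(input_ids))
        tts_hidden = model.tts_language_model(inputs_embeds=text_hidden + marker).last_hidden_state

    print(tts_hidden.shape)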