diff --git "a/modeling.py" "b/modeling.py" --- "a/modeling.py" +++ "b/modeling.py" @@ -1,1182 +1,1180 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Qwen2-VL model.""" - -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import CrossEntropyLoss - -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.generation import GenerationMixin -from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_layers import GradientCheckpointingLayer -from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from transformers.processing_utils import Unpack -from transformers.utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging -from transformers.configuration_utils import PretrainedConfig, layer_type_validation - -from transformers import AutoConfig, AutoModelForCausalLM -from transformers.modeling_outputs import ( - ModelOutput, -) -from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( - Qwen2_5_VLVisionConfig, - Qwen2_5_VLTextConfig, - Qwen2_5_VLConfig, -) -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLAttention, - Qwen2RMSNorm, - Qwen2_5_VLRotaryEmbedding, -) -from DCMoE import UniMoEAudioSparseMoeBlock -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel - -logger = logging.get_logger(__name__) - -FAST_INIT = True -if FAST_INIT: - logger.warning(f"using FAST initial for Grin Qwen2_vl !!!") - -class Qwen2_5_VLMoETextConfig(Qwen2_5_VLTextConfig): - model_type = "qwen2_5_vl_moe_text" - - def __init__( - self, - mlp_dynamic_expert_num=4, - mlp_dynamic_null_expert_num=0, - mlp_dynamic_top_p=0.7, - mlp_dynamic_top_k=2, - mlp_fixed_expert_num=2, - dynamic_intermediate_size=8960, - shared_intermediate_size=8960, - ignore_differentiable_router=False, - enable_expert_tensor_parallelism: bool = False, - ep_size=1, - fixed_ep_size=1, - router_jitter_noise=0.01, - input_jitter_noise=0.01, - token_drop=False, - drop_policy: str = "probs", - min_capacity: int = 8, - capacity_factor: float = 1.0, - fp32_gate=True, - avg_hidden_states_last=False, - 
drop_token_num_print=True, - **kwargs, - ): - - super().__init__(**kwargs) - self.mlp_dynamic_expert_num = mlp_dynamic_expert_num - self.mlp_dynamic_top_p = mlp_dynamic_top_p - self.mlp_dynamic_top_k = mlp_dynamic_top_k - self.mlp_fixed_expert_num = mlp_fixed_expert_num - self.mlp_dynamic_null_expert_num = mlp_dynamic_null_expert_num - self.dynamic_intermediate_size = dynamic_intermediate_size - self.shared_intermediate_size = shared_intermediate_size - self.ignore_differentiable_router = ignore_differentiable_router - self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism - self.ep_size = ep_size - self.fixed_ep_size = fixed_ep_size - self.input_jitter_noise = input_jitter_noise - self.router_jitter_noise = router_jitter_noise - self.token_drop = token_drop - self.drop_policy = drop_policy - self.min_capacity = min_capacity - self.capacity_factor = capacity_factor - self.fp32_gate = fp32_gate - self.avg_hidden_states_last = avg_hidden_states_last - self.drop_token_num_print = drop_token_num_print - -class UniMoEAudioConfig(PretrainedConfig): - model_type = "uni_audio_rvq_qwen2_5vl_moe" - sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLMoETextConfig} - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - text_config=None, - vision_config=None, - image_token_id=151655, - video_token_id=151656, - codec_vocab_size=1028, - codec_delay_pattern=[0, 8, 9, 10, 11, 12, 13, 14, 15], - codec_channels=9, - codec_eos_value=1024, - codec_pad_value=1025, - codec_bos_value=1026, - codec_placeholder_value=None, - **kwargs, - ): - if isinstance(vision_config, dict): - self.vision_config = self.sub_configs["vision_config"](**vision_config) - elif vision_config is None: - self.vision_config = self.sub_configs["vision_config"]() - - if isinstance(text_config, dict): - self.text_config = self.sub_configs["text_config"](**text_config) - elif text_config is None: - self.text_config = self.sub_configs["text_config"](**kwargs) - - self.image_token_id = image_token_id - self.video_token_id = video_token_id - self.codec_vocab_size = codec_vocab_size - self.codec_delay_pattern = codec_delay_pattern - self.codec_channels = codec_channels - self.codec_eos_value = codec_eos_value - self.codec_pad_value = codec_pad_value - self.codec_bos_value = codec_bos_value - self.codec_placeholder_value = codec_placeholder_value - - super().__init__(**kwargs) - -@dataclass -class MoEQwen2_5VLCausalLMOutputWithPast(ModelOutput): - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[List[torch.FloatTensor]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - rope_deltas: Optional[torch.LongTensor] = None - all_router_logits: Tuple = None - all_router_top_k: Tuple = None - all_router_expert_mask: Tuple = None - all_router_weight: Tuple = None - aux_balance_loss: torch.FloatTensor = None - - -@dataclass -class BaseModelOutputWithPast(ModelOutput): - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - all_router_logits: Tuple = None - all_router_top_k: Tuple = None - all_router_weight: Tuple = None - all_router_expert_mask: Tuple = None - all_aux_loss: Tuple = None - - -class Qwen2_5_VLMoEDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: 
Qwen2_5_VLMoETextConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." - ) - - self.self_attn = Qwen2_5_VLAttention(config, layer_idx) - self.mlp = UniMoEAudioSparseMoeBlock(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - padding_token_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - output_router_logits_and_topk: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - hidden_states = residual + hidden_states - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states, router_logits, router_top_k, router_expert_mask, router_weight, aux_loss = self.mlp(hidden_states, padding_token_mask) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if output_router_logits_and_topk: - outputs += (router_logits,) - outputs += (router_top_k,) - outputs += (router_expert_mask,) - outputs += (router_weight,) - outputs += (aux_loss,) - - return outputs - - -class Qwen2_5_VLMoEPreTrainedModel(PreTrainedModel): - config_class = UniMoEAudioConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2_5_VLMoEDecoderLayer", "Qwen2_5_VLVisionBlock"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_static_cache = True - _supports_attention_backend = True - - def _init_weights(self, module): - std = self.config.initializer_range - if FAST_INIT: - if isinstance(module, UniMoEAudioSparseMoeBlock): - module.gate.weight.data.normal_(mean=0.0, std=std) - if module.gate.bias is not None: - module.gate.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - else: - if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, 
std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Qwen2RMSNorm): - module.weight.data.fill_(1.0) - - -class Qwen2_5_VLMoETextModel(Qwen2_5_VLMoEPreTrainedModel): - config_class = Qwen2_5_VLMoETextConfig - def __init__(self, config: Qwen2_5_VLMoETextConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2_5_VLMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) - self.has_sliding_layers = "sliding_attention" in self.config.layer_types - self.gradient_checkpointing = False - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - padding_token_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits_and_topk: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - if use_cache and past_key_values is None and not torch.jit.is_tracing(): - past_key_values = DynamicCache() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) - elif position_ids.dim() == 2: - position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - - if not isinstance(causal_mask_mapping := attention_mask, dict): - mask_kwargs = { - "config": self.config, - "input_embeds": inputs_embeds, - "attention_mask": attention_mask, - "cache_position": cache_position, - "past_key_values": past_key_values, - "position_ids": position_ids, - } - causal_mask_mapping = { - "full_attention": create_causal_mask(**mask_kwargs), - } - if self.has_sliding_layers: - causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) - - hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits_and_topk else None - all_router_top_k = () if output_router_logits_and_topk else None - all_router_expert_mask = () - all_router_weight = () - all_aux_loss = () - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - padding_token_mask=padding_token_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits_and_topk=output_router_logits_and_topk, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if output_router_logits_and_topk: - all_router_logits += (layer_outputs[-5],) - all_router_top_k += (layer_outputs[-4],) - all_router_expert_mask += (layer_outputs[-3],) - all_router_weight += (layer_outputs[-2],) - all_aux_loss += (layer_outputs[-1],) - - hidden_states = self.norm(hidden_states) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if not return_dict: - return tuple( - v for v in [ - hidden_states, - past_key_values, - all_hidden_states, - all_self_attns, - all_router_logits, - all_router_top_k, - all_router_expert_mask, - all_router_weight, - all_aux_loss] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - all_router_logits=all_router_logits, - all_router_top_k=all_router_top_k, - all_router_expert_mask=all_router_expert_mask, - all_router_weight=all_router_weight, - all_aux_loss=all_aux_loss, - ) - - -class UniMoEAudio(Qwen2_5_VLMoEPreTrainedModel): - base_model_prefix = "" - _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] - config_class = UniMoEAudioConfig - _checkpoint_conversion_mapping = { - "^visual": "visual", - 
r"^model(?!\.(language_model|visual))": "language_model", - } - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config, attn_implementation=config._attn_implementation) - self.language_model = Qwen2_5_VLMoETextModel._from_config(config.text_config) - self.rope_deltas = None - self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) - self.num_channels = config.codec_channels - self.codec_vocab_size = config.codec_vocab_size - self.codec_embed_tokens = nn.ModuleList( - [nn.Embedding(self.codec_vocab_size, config.text_config.hidden_size) for embed_idx in range(self.num_channels)]) - self.codec_placeholder_value = config.codec_placeholder_value - self.codec_head = nn.Linear(config.text_config.hidden_size, self.num_channels * self.codec_vocab_size, bias=False) - self.post_init() - - @property - def cur_aux_weight(self): - if self.training_steps >= self.l_aux_weight_decay_steps: - return self.min_l_aux_weight - return self.l_aux_weight - (self.l_aux_weight - self.min_l_aux_weight) / self.l_aux_weight_decay_steps * self.training_steps - - def get_input_embeddings(self): - return self.language_model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - - def get_rope_index( - self, - input_ids: Optional[torch.LongTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - spatial_merge_size = self.config.vision_config.spatial_merge_size - image_token_id = self.config.image_token_id - video_token_id = self.config.video_token_id - vision_start_token_id = self.config.vision_start_token_id - mrope_position_deltas = [] - if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): - total_input_ids = input_ids - if attention_mask is None: - attention_mask = torch.ones_like(total_input_ids) - position_ids = torch.ones( - 3, - input_ids.shape[0], - input_ids.shape[1], - dtype=input_ids.dtype, - device=input_ids.device, - ) - image_index, video_index = 0, 0 - attention_mask = attention_mask.to(total_input_ids.device) - for i, input_ids in enumerate(total_input_ids): - input_ids = input_ids[attention_mask[i] == 1] - image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - input_tokens = input_ids.tolist() - llm_pos_ids_list: list = [] - st = 0 - remain_images, remain_videos = image_nums, video_nums - for _ in range(image_nums + video_nums): - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if 
ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - second_per_grid_t = 0 - image_index += 1 - remain_images -= 1 - ed = ed_image - - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - if second_per_grid_ts is not None: - second_per_grid_t = second_per_grid_ts[video_index] - else: - second_per_grid_t = 1.0 - video_index += 1 - remain_videos -= 1 - ed = ed_video - llm_grid_t, llm_grid_h, llm_grid_w = ( - t.item(), - h.item() // spatial_merge_size, - w.item() // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - range_tensor = torch.arange(llm_grid_t).view(-1, 1) - expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) - second_per_grid_t = torch.as_tensor( - second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device - ) - - time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second - - time_tensor_long = time_tensor.long() - t_index = time_tensor_long.flatten() - - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) - mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) - mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) - return position_ids, mrope_position_deltas - else: - if attention_mask is not None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) - max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] - mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] - else: - position_ids = ( - torch.arange(input_ids.shape[1], device=input_ids.device) - .view(1, 1, -1) - .expand(3, input_ids.shape[0], -1) - ) - mrope_position_deltas = torch.zeros( - [input_ids.shape[0], 1], - device=input_ids.device, - dtype=input_ids.dtype, - ) - - return position_ids, mrope_position_deltas - - def get_video_features(self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None): - pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() - video_embeds = torch.split(video_embeds, split_sizes) - return video_embeds - - def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): - pixel_values = pixel_values.type(self.visual.dtype) - image_embeds = 
self.visual(pixel_values, grid_thw=image_grid_thw) - split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() - image_embeds = torch.split(image_embeds, split_sizes) - return image_embeds - - - def codec_embedding(self, codec_input_ids): - x = None - for i in range(self.num_channels): - channel_tokens = codec_input_ids[..., i] - channel_embed = self.codec_embed_tokens[i](channel_tokens) - x = channel_embed if x is None else x + channel_embed - return x - - def calculate_input_embedding(self, input_ids, codec_input_ids): - inputs_embeds = self.language_model.embed_tokens(input_ids) - if codec_input_ids is not None: - codec_input_embeds = self.codec_embedding(codec_input_ids) - - codec_mask = (input_ids == self.codec_placeholder_value).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds = inputs_embeds.masked_scatter(codec_mask, codec_input_embeds) - return inputs_embeds - - @can_return_tuple - def forward( - self, - input_ids: torch.LongTensor = None, - codec_input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - codec_labels: Optional[torch.LongTensor] = None, - padding_token_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits_and_topk: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - **kwargs, - - ) -> Union[Tuple, MoEQwen2_5VLCausalLMOutputWithPast]: - return_dict = True - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if inputs_embeds is None: - inputs_embeds = self.calculate_input_embedding(input_ids, codec_input_ids) - - if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_embeds = torch.cat(image_embeds, dim=0) - - if input_ids is None: - image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - image_mask = image_mask.all(-1) - else: - image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (image_mask).sum() - image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_features = image_embeds.shape[0] - if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - video_embeds = torch.cat(video_embeds, dim=0) - - if input_ids is None: - video_mask = 
inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - video_mask = video_mask.all(-1) - else: - video_mask = input_ids == self.config.video_token_id - - n_video_tokens = (video_mask).sum() - n_video_features = video_embeds.shape[0] - video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - - if position_ids is None: - attention_mask_tensor = ( - attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] - ) - if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: - attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) - attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min - attention_mask_tensor = (1.0 - attention_mask_tensor).int() - prefill_compiled_stage = is_torchdynamo_compiling() and ( - (input_ids is not None and input_ids.shape[1] != 1) - or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) - ) - prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( - (cache_position is not None and cache_position[0] == 0) - or (past_key_values is None or past_key_values.get_seq_length() == 0) - ) - if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: - position_ids, rope_deltas = self.get_rope_index( - input_ids, - image_grid_thw, - video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - attention_mask=attention_mask_tensor, - ) - self.rope_deltas = rope_deltas - - else: - batch_size, seq_length, _ = inputs_embeds.shape - delta = ( - (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) - if cache_position is not None - else 0 - ) - position_ids = torch.arange(seq_length, device=inputs_embeds.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - if cache_position is not None: - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - - if padding_token_mask is None: - padding_token_mask = attention_mask.bool() - - outputs = self.language_model( - input_ids=None, - position_ids=position_ids, - attention_mask=attention_mask, - padding_token_mask=padding_token_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits_and_topk=output_router_logits_and_topk, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states).float() - codec_logits = self.codec_head(hidden_states).float() - codec_logits = codec_logits.view((logits.shape[0], logits.shape[1], self.num_channels, self.codec_vocab_size)) - - loss = None - if labels is not None: - - all_aux_loss = outputs.all_aux_loss if return_dict else outputs[-1] - all_aux_loss = torch.mean(torch.cat([l.unsqueeze(0) for l in all_aux_loss], dim=0)) - aux_loss = self.cur_aux_weight * all_aux_loss - self.training_steps += 1 - codec_loss = None - - if 
codec_labels is not None: - for i in range(self.num_channels): - channel_logits = codec_logits[:, :, i].float() - channel_labels = codec_labels[:, :, i] - shift_channel_logits = channel_logits[..., :-1, :].contiguous() - shift_channel_labels = channel_labels[..., 1:].contiguous() - - if i!= 0 and (shift_channel_labels != -100).sum() == 0: - continue - - loss_fct = CrossEntropyLoss() - shift_channel_logits = shift_channel_logits.view(-1, self.codec_vocab_size) - shift_channel_labels = shift_channel_labels.view(-1) - shift_channel_labels = shift_channel_labels.to(shift_channel_logits.device) - channel_loss = loss_fct(shift_channel_logits, shift_channel_labels) - codec_loss = channel_loss if codec_loss is None else codec_loss + channel_loss - - loss = codec_loss + aux_loss - - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return MoEQwen2_5VLCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - all_router_logits=outputs.all_router_logits, - all_router_top_k=outputs.all_router_top_k, - all_router_expert_mask=outputs.all_router_expert_mask, - all_router_weight=outputs.all_router_weight, - aux_balance_loss=all_aux_loss, - ) - - @staticmethod - def _sample_next_token( - logits_BCxV: torch.Tensor, - temperature: float, - top_p: float, - top_k: int, - audio_eos_value: int, - ) -> torch.Tensor: - if temperature == 0.0: - return torch.argmax(logits_BCxV, dim=-1) - - logits_BCxV = logits_BCxV / temperature - - if audio_eos_value is not None and audio_eos_value >= 0: - top_logit_indices_BC = torch.argmax(logits_BCxV, dim=-1) - eos_not_highest_mask_BC = top_logit_indices_BC != audio_eos_value - mask_eos_unless_highest_BCxV = torch.zeros_like(logits_BCxV, dtype=torch.bool) - mask_eos_unless_highest_BCxV[eos_not_highest_mask_BC, audio_eos_value] = True - logits_BCxV = logits_BCxV.masked_fill(mask_eos_unless_highest_BCxV, -torch.inf) - - if top_k is not None: - _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=top_k, dim=-1) - mask = torch.ones_like(logits_BCxV, dtype=torch.bool) - mask = mask.scatter(dim=-1, index=top_k_indices_BCxV, value=False) - logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf) - - if top_p < 1.0: - probs_BCxV = torch.softmax(logits_BCxV, dim=-1) - sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True) - cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1) - - sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p - sorted_indices_to_remove_BCxV = torch.roll(sorted_indices_to_remove_BCxV, shifts=1, dims=-1) - sorted_indices_to_remove_BCxV[..., 0] = torch.zeros_like(sorted_indices_to_remove_BCxV[..., 0]) - - indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV) - indices_to_remove_BCxV = indices_to_remove_BCxV.scatter(dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV) - logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf) - - final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1) - - sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1) - sampled_indices_C = sampled_indices_BC.squeeze(-1) - return sampled_indices_C - - def _decoder_step( - self, - tokens_Bx1xC: torch.Tensor, - model_kwargs, - cfg_scale: float, - neg_input_size: int, - temperature: float, - top_p: float, - top_k: int, - do_sample=True, - eos_prob_mul_factor=1.0, - labels_Bx1xC=None, - use_cache=True, - 
enable_eos=True, - ) -> torch.Tensor: - B = tokens_Bx1xC.shape[0] - audio_eos_value = self.config.codec_eos_value - attention_mask = model_kwargs["attention_mask"] - cache_position = model_kwargs["cache_position"] - past_key_values = model_kwargs["past_key_values"] - input_ids = model_kwargs["input_ids"] - codec_input_ids = model_kwargs["codec_input_ids"] - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -tokens_Bx1xC.shape[1] :] - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - tokens_Bx1xC = tokens_Bx1xC.repeat_interleave(neg_input_size, dim=0) - codec_input_ids = torch.cat((codec_input_ids, tokens_Bx1xC), dim=1) if codec_input_ids is not None else tokens_Bx1xC.clone() - input_ids = torch.cat((input_ids, torch.ones(input_ids.shape[0], 1).to(input_ids) * self.codec_placeholder_value), dim=-1) - - if use_cache: - codec_input_embeds = self.codec_embedding(tokens_Bx1xC) - outputs = self.language_model( - input_ids=None, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=codec_input_embeds, - use_cache=True, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - cache_position=cache_position, - ) - - else: - batch_codec_input_ids = codec_input_ids.contiguous().view(-1, self.num_channels) - - inputs_embeds = self.calculate_input_embedding(input_ids, batch_codec_input_ids) - outputs = self.language_model( - input_ids=None, - attention_mask=attention_mask, - position_ids=attention_mask.long().cumsum(-1) - 1, - past_key_values=None, - inputs_embeds=inputs_embeds, - use_cache=True, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - cache_position=None, - ) - - last_hidden_state = outputs.last_hidden_state - codec_logits = self.codec_head(last_hidden_state).float() - codec_logits = codec_logits.view((codec_logits.shape[0], codec_logits.shape[1], self.num_channels, self.codec_vocab_size)) - model_kwargs["past_key_values"] = outputs.past_key_values - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) - model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 - model_kwargs["input_ids"] = input_ids - model_kwargs["codec_input_ids"] = codec_input_ids - - logits_Bx1xCxV = codec_logits[: , -1:].clone() - logits_last_2BxCxV = logits_Bx1xCxV[:, -1] - logits_last_Bx2xCxV = logits_last_2BxCxV.view(B, neg_input_size, *logits_last_2BxCxV.shape[1:]) - if cfg_scale is not None: - cond_logits_BxCxV = logits_last_Bx2xCxV[:, -1, :, :] # Shape [B, C, V] - logits_BxCxV = cond_logits_BxCxV - for ni in range(neg_input_size - 1): - uncond_logits_BxCxV = logits_last_Bx2xCxV[:, ni, :, :] # Shape [B, C, V] - cfg_weight = cfg_scale[ni] if isinstance(cfg_scale, List) else cfg_scale - logits_BxCxV = logits_BxCxV + cfg_weight * (cond_logits_BxCxV - uncond_logits_BxCxV) - else: - logits_BxCxV = logits_last_Bx2xCxV[:, -1, :, :] # Shape [B, C, V] - - if enable_eos: - logits_BxCxV[:, :, audio_eos_value + 1 :] = torch.full_like( - logits_BxCxV[:, :, audio_eos_value + 1 :], - fill_value=-torch.inf, - ) - logits_BxCxV[:, 1:, audio_eos_value:] = torch.full_like( - logits_BxCxV[:, 1:, audio_eos_value:], - fill_value=-torch.inf, - ) - logits_BxCxV[:, 0, audio_eos_value] *= torch.tensor(eos_prob_mul_factor, device=self.device) - - else: - 
logits_BxCxV[:, :, audio_eos_value:] = torch.full_like( - logits_BxCxV[:, :, audio_eos_value:], - fill_value=-torch.inf, - ) - - - flat_logits_BCxV = logits_BxCxV.reshape(B * self.num_channels, -1) - if do_sample: - pred_BC = self._sample_next_token( - flat_logits_BCxV.float(), - temperature=temperature, - top_p=top_p, - top_k=top_k, - audio_eos_value=audio_eos_value, - ) - else: - pred_BC = torch.argmax(flat_logits_BCxV, dim=1) - - pred_BxC = pred_BC.view(B, self.num_channels) - - return pred_BxC, model_kwargs - - def generate( - self, - input_ids, - attention_mask, - dec_output, - max_tokens, - min_tokens=None, - codec_input_ids: Optional[torch.Tensor] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - neg_input_size = 2, - cfg_scale = 3.0, - temperature: float = 1.2, - top_p: float = 0.95, - cfg_filter_top_k: int = 45, - eos_prob_mul_factor: float = 0.8, - do_sample: bool = True, - debug_guidance_step: int = 0, - use_cache=True, - ): - if codec_input_ids is not None: - assert use_cache - batch_size = input_ids.shape[0] // neg_input_size - audio_eos_value = self.config.codec_eos_value - audio_pad_value = self.config.codec_pad_value - delay_pattern = self.config.codec_delay_pattern - max_delay_pattern = max(delay_pattern) - delay_pattern_Cx = torch.tensor(delay_pattern, device=self.device, dtype=torch.long) - - dec_step = min(dec_output.prefill_steps) - 1 - - eos_detected_Bx = torch.zeros((batch_size,), dtype=torch.bool, device=self.device) - eos_countdown_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device) - finished_step_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device) - - bos_over = False - model_kwargs = dict(attention_mask=attention_mask, use_cache=True) - model_kwargs["past_key_values"] = DynamicCache() - model_kwargs["cache_position"] = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 - attention_mask = model_kwargs["attention_mask"] - past_key_values = model_kwargs["past_key_values"] - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - cache_position = torch.arange(0, input_ids.shape[-1], device=input_ids.device) - inputs_embeds = self.calculate_input_embedding(input_ids, codec_input_ids) - outputs = self.language_model( - input_ids=None, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - pixel_values_videos=pixel_values_videos, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - use_cache=True, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - cache_position=cache_position, - ) - - model_kwargs["input_ids"] = input_ids - model_kwargs["codec_input_ids"] = None - model_kwargs["labels"] = torch.ones_like(input_ids[neg_input_size-1::neg_input_size]) * -100 - labels_Bx1xC = dec_output.get_labels_at(0) - if labels_Bx1xC is not None: - model_kwargs["codec_labels"] = (torch.ones_like(input_ids[neg_input_size-1::neg_input_size]) * -100).unsqueeze(-1).expand(-1, -1, self.num_channels) - assert (labels_Bx1xC != self.config.codec_bos_value).sum() == 0 - labels_Bx1xC = torch.full_like(labels_Bx1xC, -100) - model_kwargs["codec_labels"] = 
torch.cat((model_kwargs["codec_labels"], labels_Bx1xC), dim=1) - model_kwargs["past_key_values"] = outputs.past_key_values - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) - model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 - - while dec_step < max_tokens: - if (eos_countdown_Bx == 0).all(): - break - - current_step_idx = dec_step + 1 - tokens_Bx1xC = dec_output.get_tokens_at(dec_step) - labels_Bx1xC = dec_output.get_labels_at(dec_step + 1) - - pred_BxC, model_kwargs = self._decoder_step( - tokens_Bx1xC=tokens_Bx1xC, - model_kwargs=model_kwargs, - cfg_scale=cfg_scale, - neg_input_size=neg_input_size, - temperature=temperature, - top_p=top_p, - top_k=cfg_filter_top_k, - do_sample=do_sample, - eos_prob_mul_factor=eos_prob_mul_factor, - labels_Bx1xC=labels_Bx1xC, - use_cache=use_cache, - enable_eos=(min_tokens is None or dec_step >= min_tokens), - ) - if labels_Bx1xC is not None and (dec_step < debug_guidance_step or debug_guidance_step==-1): - pred_BxC = labels_Bx1xC[:, 0] - - active_mask_Bx = eos_countdown_Bx != 0 - eos_trigger_Bx = torch.zeros_like(active_mask_Bx) - if active_mask_Bx.any(): - is_eos_token = (~eos_detected_Bx[active_mask_Bx]) & (pred_BxC[active_mask_Bx, 0] == audio_eos_value) - is_max_len = current_step_idx >= max_tokens - max_delay_pattern - eos_trigger_Bx[active_mask_Bx] = is_eos_token | is_max_len - eos_detected_Bx |= eos_trigger_Bx - start_countdown_mask_Bx = eos_trigger_Bx & (eos_countdown_Bx < 0) - if start_countdown_mask_Bx.any(): - eos_countdown_Bx[start_countdown_mask_Bx] = max_delay_pattern - finished_step_Bx[start_countdown_mask_Bx] = current_step_idx - - padding_mask_Bx = eos_countdown_Bx > 0 - if padding_mask_Bx.any(): - pred_active_BxC = pred_BxC[padding_mask_Bx].clone() - countdown_active_Bx = eos_countdown_Bx[padding_mask_Bx] - step_after_eos_Bx = max_delay_pattern - countdown_active_Bx - step_after_eos_Bx_ = step_after_eos_Bx.unsqueeze(1) - delay_pattern_Cx_ = delay_pattern_Cx.unsqueeze(0) - eos_mask_NxC = step_after_eos_Bx_ == delay_pattern_Cx_ - pad_mask_NxC = step_after_eos_Bx_ > delay_pattern_Cx_ - pred_active_BxC[eos_mask_NxC] = audio_eos_value - pred_active_BxC[pad_mask_NxC] = audio_pad_value - pred_BxC[padding_mask_Bx] = pred_active_BxC - eos_countdown_Bx[padding_mask_Bx] -= 1 - - if not bos_over: - bos_over = all(current_step_idx - prefill_step >= max_delay_pattern for prefill_step in dec_output.prefill_steps) - - dec_output.update_one(pred_BxC, current_step_idx, not bos_over) - dec_step += 1 - - final_step = dec_step + 1 - finished_step_Bx[finished_step_Bx == -1] = final_step - max_delay_pattern - prefill_steps_tensor = torch.tensor(dec_output.prefill_steps, device=self.device) - lengths_Bx = finished_step_Bx - prefill_steps_tensor - lengths_Bx = torch.clamp(lengths_Bx, min=0) - max_len = lengths_Bx.max().item() + max_delay_pattern - - if max_len > 0: - num_channels = self.num_channels - generated_codes = torch.full( - (batch_size, max_len, num_channels), - fill_value=audio_pad_value, - dtype=torch.long, - device=self.device, - ) - - for i in range(batch_size): - start_step = dec_output.prefill_steps[i] - actual_len = lengths_Bx[i].item() + max_delay_pattern - if actual_len > 0: - tokens_to_copy = dec_output.generated_tokens[i, start_step : start_step + actual_len, :] - generated_codes[i, :actual_len, :] = tokens_to_copy - - return generated_codes, lengths_Bx - else: - print("Warning: Nothing 
generated for any sequence in the batch.") - return None, None - -# AutoConfig.register("qwen2_5_vl_moe_text", Qwen2_5_VLMoETextConfig) -# AutoModelForCausalLM.register(Qwen2_5_VLMoETextConfig, Qwen2_5_VLMoETextModel) - -# AutoConfig.register("uni_audio_rvq_qwen2_5vl_moe", UniMoEAudioConfig) -# AutoModelForCausalLM.register(UniMoEAudioConfig, UniMoEAudio) +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2-VL model.""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from transformers.configuration_utils import PretrainedConfig, layer_type_validation + +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.modeling_outputs import ( + ModelOutput, +) +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( + Qwen2_5_VLVisionConfig, + Qwen2_5_VLTextConfig, + Qwen2_5_VLConfig, +) +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLAttention, + Qwen2RMSNorm, + Qwen2_5_VLRotaryEmbedding, +) +from DCMoE import UniMoEAudioSparseMoeBlock +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel + +logger = logging.get_logger(__name__) + +FAST_INIT = True + +class Qwen2_5_VLMoETextConfig(Qwen2_5_VLTextConfig): + model_type = "qwen2_5_vl_moe_text" + + def __init__( + self, + mlp_dynamic_expert_num=4, + mlp_dynamic_null_expert_num=0, + mlp_dynamic_top_p=0.7, + mlp_dynamic_top_k=2, + mlp_fixed_expert_num=2, + dynamic_intermediate_size=8960, + shared_intermediate_size=8960, + ignore_differentiable_router=False, + enable_expert_tensor_parallelism: bool = False, + ep_size=1, + fixed_ep_size=1, + router_jitter_noise=0.01, + 
input_jitter_noise=0.01, + token_drop=False, + drop_policy: str = "probs", + min_capacity: int = 8, + capacity_factor: float = 1.0, + fp32_gate=True, + avg_hidden_states_last=False, + drop_token_num_print=True, + **kwargs, + ): + + super().__init__(**kwargs) + self.mlp_dynamic_expert_num = mlp_dynamic_expert_num + self.mlp_dynamic_top_p = mlp_dynamic_top_p + self.mlp_dynamic_top_k = mlp_dynamic_top_k + self.mlp_fixed_expert_num = mlp_fixed_expert_num + self.mlp_dynamic_null_expert_num = mlp_dynamic_null_expert_num + self.dynamic_intermediate_size = dynamic_intermediate_size + self.shared_intermediate_size = shared_intermediate_size + self.ignore_differentiable_router = ignore_differentiable_router + self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism + self.ep_size = ep_size + self.fixed_ep_size = fixed_ep_size + self.input_jitter_noise = input_jitter_noise + self.router_jitter_noise = router_jitter_noise + self.token_drop = token_drop + self.drop_policy = drop_policy + self.min_capacity = min_capacity + self.capacity_factor = capacity_factor + self.fp32_gate = fp32_gate + self.avg_hidden_states_last = avg_hidden_states_last + self.drop_token_num_print = drop_token_num_print + +class UniMoEAudioConfig(PretrainedConfig): + model_type = "uni_audio_rvq_qwen2_5vl_moe" + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLMoETextConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=151655, + video_token_id=151656, + codec_vocab_size=1028, + codec_delay_pattern=[0, 8, 9, 10, 11, 12, 13, 14, 15], + codec_channels=9, + codec_eos_value=1024, + codec_pad_value=1025, + codec_bos_value=1026, + codec_placeholder_value=None, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + self.text_config = self.sub_configs["text_config"](**kwargs) + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.codec_vocab_size = codec_vocab_size + self.codec_delay_pattern = codec_delay_pattern + self.codec_channels = codec_channels + self.codec_eos_value = codec_eos_value + self.codec_pad_value = codec_pad_value + self.codec_bos_value = codec_bos_value + self.codec_placeholder_value = codec_placeholder_value + + super().__init__(**kwargs) + +@dataclass +class MoEQwen2_5VLCausalLMOutputWithPast(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + all_router_logits: Tuple = None + all_router_top_k: Tuple = None + all_router_expert_mask: Tuple = None + all_router_weight: Tuple = None + aux_balance_loss: torch.FloatTensor = None + + +@dataclass +class BaseModelOutputWithPast(ModelOutput): + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + all_router_logits: Tuple = None + all_router_top_k: Tuple = None + 
all_router_weight: Tuple = None + all_router_expert_mask: Tuple = None + all_aux_loss: Tuple = None + + +class Qwen2_5_VLMoEDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen2_5_VLMoETextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + + self.self_attn = Qwen2_5_VLAttention(config, layer_idx) + self.mlp = UniMoEAudioSparseMoeBlock(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + padding_token_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits_and_topk: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits, router_top_k, router_expert_mask, router_weight, aux_loss = self.mlp(hidden_states, padding_token_mask) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits_and_topk: + outputs += (router_logits,) + outputs += (router_top_k,) + outputs += (router_expert_mask,) + outputs += (router_weight,) + outputs += (aux_loss,) + + return outputs + + +class Qwen2_5_VLMoEPreTrainedModel(PreTrainedModel): + config_class = UniMoEAudioConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2_5_VLMoEDecoderLayer", "Qwen2_5_VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_flash_attn_3 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if FAST_INIT: + if isinstance(module, UniMoEAudioSparseMoeBlock): + module.gate.weight.data.normal_(mean=0.0, std=std) + if module.gate.bias is not None: + module.gate.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + else: + if isinstance(module, (nn.Linear, 
nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) + + +class Qwen2_5_VLMoETextModel(Qwen2_5_VLMoEPreTrainedModel): + config_class = Qwen2_5_VLMoETextConfig + def __init__(self, config: Qwen2_5_VLMoETextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2_5_VLMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + padding_token_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits_and_topk: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if not isinstance(causal_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits_and_topk else None + all_router_top_k = () if output_router_logits_and_topk else None + all_router_expert_mask = () + all_router_weight = () + all_aux_loss = () + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + padding_token_mask=padding_token_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits_and_topk=output_router_logits_and_topk, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits_and_topk: + all_router_logits += (layer_outputs[-5],) + all_router_top_k += (layer_outputs[-4],) + all_router_expert_mask += (layer_outputs[-3],) + all_router_weight += (layer_outputs[-2],) + all_aux_loss += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attns, + all_router_logits, + all_router_top_k, + all_router_expert_mask, + all_router_weight, + all_aux_loss] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + all_router_logits=all_router_logits, + all_router_top_k=all_router_top_k, + all_router_expert_mask=all_router_expert_mask, + all_router_weight=all_router_weight, + all_aux_loss=all_aux_loss, + ) + + +class UniMoEAudio(Qwen2_5_VLMoEPreTrainedModel): + base_model_prefix = "" + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + config_class = UniMoEAudioConfig + _checkpoint_conversion_mapping = { + "^visual": "visual", + 
r"^model(?!\.(language_model|visual))": "language_model", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config, attn_implementation=config._attn_implementation) + self.language_model = Qwen2_5_VLMoETextModel._from_config(config.text_config) + self.rope_deltas = None + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.num_channels = config.codec_channels + self.codec_vocab_size = config.codec_vocab_size + self.codec_embed_tokens = nn.ModuleList( + [nn.Embedding(self.codec_vocab_size, config.text_config.hidden_size) for embed_idx in range(self.num_channels)]) + self.codec_placeholder_value = config.codec_placeholder_value + self.codec_head = nn.Linear(config.text_config.hidden_size, self.num_channels * self.codec_vocab_size, bias=False) + self.post_init() + + @property + def cur_aux_weight(self): + if self.training_steps >= self.l_aux_weight_decay_steps: + return self.min_l_aux_weight + return self.l_aux_weight - (self.l_aux_weight - self.min_l_aux_weight) / self.l_aux_weight_decay_steps * self.training_steps + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if 
ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + second_per_grid_t = torch.as_tensor( + second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device + ) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_video_features(self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None): + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + video_embeds = torch.split(video_embeds, split_sizes) + return video_embeds + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = 
self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds + + + def codec_embedding(self, codec_input_ids): + x = None + for i in range(self.num_channels): + channel_tokens = codec_input_ids[..., i] + channel_embed = self.codec_embed_tokens[i](channel_tokens) + x = channel_embed if x is None else x + channel_embed + return x + + def calculate_input_embedding(self, input_ids, codec_input_ids): + inputs_embeds = self.language_model.embed_tokens(input_ids) + if codec_input_ids is not None: + codec_input_embeds = self.codec_embedding(codec_input_ids) + + codec_mask = (input_ids == self.codec_placeholder_value).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(codec_mask, codec_input_embeds) + return inputs_embeds + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + codec_input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + codec_labels: Optional[torch.LongTensor] = None, + padding_token_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits_and_topk: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + **kwargs, + + ) -> Union[Tuple, MoEQwen2_5VLCausalLMOutputWithPast]: + return_dict = True + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if inputs_embeds is None: + inputs_embeds = self.calculate_input_embedding(input_ids, codec_input_ids) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0) + + if input_ids is None: + image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + image_mask = image_mask.all(-1) + else: + image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (image_mask).sum() + image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_embeds.shape[0] + if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0) + + if input_ids is None: + video_mask = 
inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + video_mask = video_mask.all(-1) + else: + video_mask = input_ids == self.config.video_token_id + + n_video_tokens = (video_mask).sum() + n_video_features = video_embeds.shape[0] + video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if position_ids is None: + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + prefill_compiled_stage = is_torchdynamo_compiling() and ( + (input_ids is not None and input_ids.shape[1] != 1) + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) + ) + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( + (cache_position is not None and cache_position[0] == 0) + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ) + if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + if padding_token_mask is None: + padding_token_mask = attention_mask.bool() + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + padding_token_mask=padding_token_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits_and_topk=output_router_logits_and_topk, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states).float() + codec_logits = self.codec_head(hidden_states).float() + codec_logits = codec_logits.view((logits.shape[0], logits.shape[1], self.num_channels, self.codec_vocab_size)) + + loss = None + if labels is not None: + + all_aux_loss = outputs.all_aux_loss if return_dict else outputs[-1] + all_aux_loss = torch.mean(torch.cat([l.unsqueeze(0) for l in all_aux_loss], dim=0)) + aux_loss = self.cur_aux_weight * all_aux_loss + self.training_steps += 1 + codec_loss = None + + if 
codec_labels is not None: + for i in range(self.num_channels): + channel_logits = codec_logits[:, :, i].float() + channel_labels = codec_labels[:, :, i] + shift_channel_logits = channel_logits[..., :-1, :].contiguous() + shift_channel_labels = channel_labels[..., 1:].contiguous() + + if i!= 0 and (shift_channel_labels != -100).sum() == 0: + continue + + loss_fct = CrossEntropyLoss() + shift_channel_logits = shift_channel_logits.view(-1, self.codec_vocab_size) + shift_channel_labels = shift_channel_labels.view(-1) + shift_channel_labels = shift_channel_labels.to(shift_channel_logits.device) + channel_loss = loss_fct(shift_channel_logits, shift_channel_labels) + codec_loss = channel_loss if codec_loss is None else codec_loss + channel_loss + + loss = codec_loss + aux_loss + + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MoEQwen2_5VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + all_router_logits=outputs.all_router_logits, + all_router_top_k=outputs.all_router_top_k, + all_router_expert_mask=outputs.all_router_expert_mask, + all_router_weight=outputs.all_router_weight, + aux_balance_loss=all_aux_loss, + ) + + @staticmethod + def _sample_next_token( + logits_BCxV: torch.Tensor, + temperature: float, + top_p: float, + top_k: int, + audio_eos_value: int, + ) -> torch.Tensor: + if temperature == 0.0: + return torch.argmax(logits_BCxV, dim=-1) + + logits_BCxV = logits_BCxV / temperature + + if audio_eos_value is not None and audio_eos_value >= 0: + top_logit_indices_BC = torch.argmax(logits_BCxV, dim=-1) + eos_not_highest_mask_BC = top_logit_indices_BC != audio_eos_value + mask_eos_unless_highest_BCxV = torch.zeros_like(logits_BCxV, dtype=torch.bool) + mask_eos_unless_highest_BCxV[eos_not_highest_mask_BC, audio_eos_value] = True + logits_BCxV = logits_BCxV.masked_fill(mask_eos_unless_highest_BCxV, -torch.inf) + + if top_k is not None: + _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=top_k, dim=-1) + mask = torch.ones_like(logits_BCxV, dtype=torch.bool) + mask = mask.scatter(dim=-1, index=top_k_indices_BCxV, value=False) + logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf) + + if top_p < 1.0: + probs_BCxV = torch.softmax(logits_BCxV, dim=-1) + sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True) + cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1) + + sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p + sorted_indices_to_remove_BCxV = torch.roll(sorted_indices_to_remove_BCxV, shifts=1, dims=-1) + sorted_indices_to_remove_BCxV[..., 0] = torch.zeros_like(sorted_indices_to_remove_BCxV[..., 0]) + + indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV) + indices_to_remove_BCxV = indices_to_remove_BCxV.scatter(dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV) + logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf) + + final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1) + + sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1) + sampled_indices_C = sampled_indices_BC.squeeze(-1) + return sampled_indices_C + + def _decoder_step( + self, + tokens_Bx1xC: torch.Tensor, + model_kwargs, + cfg_scale: float, + neg_input_size: int, + temperature: float, + top_p: float, + top_k: int, + do_sample=True, + eos_prob_mul_factor=1.0, + labels_Bx1xC=None, + use_cache=True, + 
enable_eos=True, + ) -> torch.Tensor: + B = tokens_Bx1xC.shape[0] + audio_eos_value = self.config.codec_eos_value + attention_mask = model_kwargs["attention_mask"] + cache_position = model_kwargs["cache_position"] + past_key_values = model_kwargs["past_key_values"] + input_ids = model_kwargs["input_ids"] + codec_input_ids = model_kwargs["codec_input_ids"] + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -tokens_Bx1xC.shape[1] :] + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + tokens_Bx1xC = tokens_Bx1xC.repeat_interleave(neg_input_size, dim=0) + codec_input_ids = torch.cat((codec_input_ids, tokens_Bx1xC), dim=1) if codec_input_ids is not None else tokens_Bx1xC.clone() + input_ids = torch.cat((input_ids, torch.ones(input_ids.shape[0], 1).to(input_ids) * self.codec_placeholder_value), dim=-1) + + if use_cache: + codec_input_embeds = self.codec_embedding(tokens_Bx1xC) + outputs = self.language_model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=codec_input_embeds, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + cache_position=cache_position, + ) + + else: + batch_codec_input_ids = codec_input_ids.contiguous().view(-1, self.num_channels) + + inputs_embeds = self.calculate_input_embedding(input_ids, batch_codec_input_ids) + outputs = self.language_model( + input_ids=None, + attention_mask=attention_mask, + position_ids=attention_mask.long().cumsum(-1) - 1, + past_key_values=None, + inputs_embeds=inputs_embeds, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + cache_position=None, + ) + + last_hidden_state = outputs.last_hidden_state + codec_logits = self.codec_head(last_hidden_state).float() + codec_logits = codec_logits.view((codec_logits.shape[0], codec_logits.shape[1], self.num_channels, self.codec_vocab_size)) + model_kwargs["past_key_values"] = outputs.past_key_values + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 + model_kwargs["input_ids"] = input_ids + model_kwargs["codec_input_ids"] = codec_input_ids + + logits_Bx1xCxV = codec_logits[: , -1:].clone() + logits_last_2BxCxV = logits_Bx1xCxV[:, -1] + logits_last_Bx2xCxV = logits_last_2BxCxV.view(B, neg_input_size, *logits_last_2BxCxV.shape[1:]) + if cfg_scale is not None: + cond_logits_BxCxV = logits_last_Bx2xCxV[:, -1, :, :] # Shape [B, C, V] + logits_BxCxV = cond_logits_BxCxV + for ni in range(neg_input_size - 1): + uncond_logits_BxCxV = logits_last_Bx2xCxV[:, ni, :, :] # Shape [B, C, V] + cfg_weight = cfg_scale[ni] if isinstance(cfg_scale, List) else cfg_scale + logits_BxCxV = logits_BxCxV + cfg_weight * (cond_logits_BxCxV - uncond_logits_BxCxV) + else: + logits_BxCxV = logits_last_Bx2xCxV[:, -1, :, :] # Shape [B, C, V] + + if enable_eos: + logits_BxCxV[:, :, audio_eos_value + 1 :] = torch.full_like( + logits_BxCxV[:, :, audio_eos_value + 1 :], + fill_value=-torch.inf, + ) + logits_BxCxV[:, 1:, audio_eos_value:] = torch.full_like( + logits_BxCxV[:, 1:, audio_eos_value:], + fill_value=-torch.inf, + ) + logits_BxCxV[:, 0, audio_eos_value] *= torch.tensor(eos_prob_mul_factor, device=self.device) + + else: + 
logits_BxCxV[:, :, audio_eos_value:] = torch.full_like( + logits_BxCxV[:, :, audio_eos_value:], + fill_value=-torch.inf, + ) + + + flat_logits_BCxV = logits_BxCxV.reshape(B * self.num_channels, -1) + if do_sample: + pred_BC = self._sample_next_token( + flat_logits_BCxV.float(), + temperature=temperature, + top_p=top_p, + top_k=top_k, + audio_eos_value=audio_eos_value, + ) + else: + pred_BC = torch.argmax(flat_logits_BCxV, dim=1) + + pred_BxC = pred_BC.view(B, self.num_channels) + + return pred_BxC, model_kwargs + + def generate( + self, + input_ids, + attention_mask, + dec_output, + max_tokens, + min_tokens=None, + codec_input_ids: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + neg_input_size = 2, + cfg_scale = 3.0, + temperature: float = 1.2, + top_p: float = 0.95, + cfg_filter_top_k: int = 45, + eos_prob_mul_factor: float = 0.8, + do_sample: bool = True, + debug_guidance_step: int = 0, + use_cache=True, + ): + if codec_input_ids is not None: + assert use_cache + batch_size = input_ids.shape[0] // neg_input_size + audio_eos_value = self.config.codec_eos_value + audio_pad_value = self.config.codec_pad_value + delay_pattern = self.config.codec_delay_pattern + max_delay_pattern = max(delay_pattern) + delay_pattern_Cx = torch.tensor(delay_pattern, device=self.device, dtype=torch.long) + + dec_step = min(dec_output.prefill_steps) - 1 + + eos_detected_Bx = torch.zeros((batch_size,), dtype=torch.bool, device=self.device) + eos_countdown_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device) + finished_step_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device) + + bos_over = False + model_kwargs = dict(attention_mask=attention_mask, use_cache=True) + model_kwargs["past_key_values"] = DynamicCache() + model_kwargs["cache_position"] = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 + attention_mask = model_kwargs["attention_mask"] + past_key_values = model_kwargs["past_key_values"] + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + cache_position = torch.arange(0, input_ids.shape[-1], device=input_ids.device) + inputs_embeds = self.calculate_input_embedding(input_ids, codec_input_ids) + outputs = self.language_model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + cache_position=cache_position, + ) + + model_kwargs["input_ids"] = input_ids + model_kwargs["codec_input_ids"] = None + model_kwargs["labels"] = torch.ones_like(input_ids[neg_input_size-1::neg_input_size]) * -100 + labels_Bx1xC = dec_output.get_labels_at(0) + if labels_Bx1xC is not None: + model_kwargs["codec_labels"] = (torch.ones_like(input_ids[neg_input_size-1::neg_input_size]) * -100).unsqueeze(-1).expand(-1, -1, self.num_channels) + assert (labels_Bx1xC != self.config.codec_bos_value).sum() == 0 + labels_Bx1xC = torch.full_like(labels_Bx1xC, -100) + model_kwargs["codec_labels"] = 
torch.cat((model_kwargs["codec_labels"], labels_Bx1xC), dim=1) + model_kwargs["past_key_values"] = outputs.past_key_values + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 + + while dec_step < max_tokens: + if (eos_countdown_Bx == 0).all(): + break + + current_step_idx = dec_step + 1 + tokens_Bx1xC = dec_output.get_tokens_at(dec_step) + labels_Bx1xC = dec_output.get_labels_at(dec_step + 1) + + pred_BxC, model_kwargs = self._decoder_step( + tokens_Bx1xC=tokens_Bx1xC, + model_kwargs=model_kwargs, + cfg_scale=cfg_scale, + neg_input_size=neg_input_size, + temperature=temperature, + top_p=top_p, + top_k=cfg_filter_top_k, + do_sample=do_sample, + eos_prob_mul_factor=eos_prob_mul_factor, + labels_Bx1xC=labels_Bx1xC, + use_cache=use_cache, + enable_eos=(min_tokens is None or dec_step >= min_tokens), + ) + if labels_Bx1xC is not None and (dec_step < debug_guidance_step or debug_guidance_step==-1): + pred_BxC = labels_Bx1xC[:, 0] + + active_mask_Bx = eos_countdown_Bx != 0 + eos_trigger_Bx = torch.zeros_like(active_mask_Bx) + if active_mask_Bx.any(): + is_eos_token = (~eos_detected_Bx[active_mask_Bx]) & (pred_BxC[active_mask_Bx, 0] == audio_eos_value) + is_max_len = current_step_idx >= max_tokens - max_delay_pattern + eos_trigger_Bx[active_mask_Bx] = is_eos_token | is_max_len + eos_detected_Bx |= eos_trigger_Bx + start_countdown_mask_Bx = eos_trigger_Bx & (eos_countdown_Bx < 0) + if start_countdown_mask_Bx.any(): + eos_countdown_Bx[start_countdown_mask_Bx] = max_delay_pattern + finished_step_Bx[start_countdown_mask_Bx] = current_step_idx + + padding_mask_Bx = eos_countdown_Bx > 0 + if padding_mask_Bx.any(): + pred_active_BxC = pred_BxC[padding_mask_Bx].clone() + countdown_active_Bx = eos_countdown_Bx[padding_mask_Bx] + step_after_eos_Bx = max_delay_pattern - countdown_active_Bx + step_after_eos_Bx_ = step_after_eos_Bx.unsqueeze(1) + delay_pattern_Cx_ = delay_pattern_Cx.unsqueeze(0) + eos_mask_NxC = step_after_eos_Bx_ == delay_pattern_Cx_ + pad_mask_NxC = step_after_eos_Bx_ > delay_pattern_Cx_ + pred_active_BxC[eos_mask_NxC] = audio_eos_value + pred_active_BxC[pad_mask_NxC] = audio_pad_value + pred_BxC[padding_mask_Bx] = pred_active_BxC + eos_countdown_Bx[padding_mask_Bx] -= 1 + + if not bos_over: + bos_over = all(current_step_idx - prefill_step >= max_delay_pattern for prefill_step in dec_output.prefill_steps) + + dec_output.update_one(pred_BxC, current_step_idx, not bos_over) + dec_step += 1 + + final_step = dec_step + 1 + finished_step_Bx[finished_step_Bx == -1] = final_step - max_delay_pattern + prefill_steps_tensor = torch.tensor(dec_output.prefill_steps, device=self.device) + lengths_Bx = finished_step_Bx - prefill_steps_tensor + lengths_Bx = torch.clamp(lengths_Bx, min=0) + max_len = lengths_Bx.max().item() + max_delay_pattern + + if max_len > 0: + num_channels = self.num_channels + generated_codes = torch.full( + (batch_size, max_len, num_channels), + fill_value=audio_pad_value, + dtype=torch.long, + device=self.device, + ) + + for i in range(batch_size): + start_step = dec_output.prefill_steps[i] + actual_len = lengths_Bx[i].item() + max_delay_pattern + if actual_len > 0: + tokens_to_copy = dec_output.generated_tokens[i, start_step : start_step + actual_len, :] + generated_codes[i, :actual_len, :] = tokens_to_copy + + return generated_codes, lengths_Bx + else: + print("Warning: Nothing 
generated for any sequence in the batch.")
+            return None, None
+
+# AutoConfig.register("qwen2_5_vl_moe_text", Qwen2_5_VLMoETextConfig)
+# AutoModelForCausalLM.register(Qwen2_5_VLMoETextConfig, Qwen2_5_VLMoETextModel)
+
+# AutoConfig.register("uni_audio_rvq_qwen2_5vl_moe", UniMoEAudioConfig)
+# AutoModelForCausalLM.register(UniMoEAudioConfig, UniMoEAudio)
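+
+
+# ---------------------------------------------------------------------------
+# Illustrative, self-contained sketches (not used by the model classes above).
+#
+# This first sketch mirrors how `UniMoEAudio.codec_embedding` combines the
+# per-channel RVQ codebooks: each of the `codec_channels` embedding tables is
+# looked up with its own token stream and the results are summed into a single
+# hidden vector per frame. Channel count and vocab follow the
+# UniMoEAudioConfig defaults (9 channels, vocab 1028); `hidden_size=64` is an
+# arbitrary demo value, not the real text hidden size.
+def _demo_codec_embedding():
+    import torch
+    import torch.nn as nn
+
+    num_channels, vocab_size, hidden_size = 9, 1028, 64
+    embed_tokens = nn.ModuleList(
+        [nn.Embedding(vocab_size, hidden_size) for _ in range(num_channels)]
+    )
+    # codec_input_ids: (batch, seq_len, num_channels)
+    codec_input_ids = torch.randint(0, vocab_size, (2, 5, num_channels))
+    x = None
+    for i in range(num_channels):
+        channel_embed = embed_tokens[i](codec_input_ids[..., i])
+        x = channel_embed if x is None else x + channel_embed
+    assert x.shape == (2, 5, hidden_size)
+    return x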
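+
+
+# Sketch of the auxiliary-balance-loss weight schedule implemented by the
+# `cur_aux_weight` property: a linear decay from `l_aux_weight` down to
+# `min_l_aux_weight` over `l_aux_weight_decay_steps`, then held constant.
+# The three hyperparameters below are hypothetical values; the class expects
+# them (and `training_steps`) to be attached to the model by the training code.
+def _demo_aux_weight_schedule():
+    l_aux_weight, min_l_aux_weight, decay_steps = 1e-2, 1e-3, 10_000
+    weights = []
+    for step in (0, 5_000, 10_000, 20_000):
+        if step >= decay_steps:
+            weights.append(min_l_aux_weight)
+        else:
+            weights.append(l_aux_weight - (l_aux_weight - min_l_aux_weight) / decay_steps * step)
+    return weights  # approx. [0.01, 0.0055, 0.001, 0.001]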
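+
+
+# Sketch of the classifier-free-guidance combination applied in
+# `_decoder_step` for the common two-branch case (neg_input_size=2): the
+# conditional and unconditional logits come from one interleaved batch, and
+# the guided logits are extrapolated away from the unconditional branch by
+# `cfg_scale`. Shapes and the scale of 3.0 follow the defaults of `generate`;
+# the logits here are random stand-ins.
+def _demo_cfg_guidance():
+    import torch
+
+    cond = torch.randn(2, 9, 1028)    # conditional logits, (batch, channels, vocab)
+    uncond = torch.randn(2, 9, 1028)  # unconditional ("negative prompt") logits
+    cfg_scale = 3.0
+    guided = cond + cfg_scale * (cond - uncond)
+    return guided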
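+
+
+# Sketch of the temperature / top-k / top-p filtering used by
+# `_sample_next_token` (the EOS-specific masking is omitted). The sampling
+# parameters follow the defaults of `generate`; the logits are random
+# stand-ins, one row per batch element and channel.
+def _demo_top_k_top_p_sampling():
+    import torch
+
+    logits = torch.randn(4, 1028)  # (rows, vocab)
+    temperature, top_k, top_p = 1.2, 45, 0.95
+    logits = logits / temperature
+
+    # Top-k: keep only the k largest logits per row.
+    _, topk_idx = torch.topk(logits, k=top_k, dim=-1)
+    keep = torch.zeros_like(logits, dtype=torch.bool).scatter(dim=-1, index=topk_idx, value=True)
+    logits = logits.masked_fill(~keep, -torch.inf)
+
+    # Top-p (nucleus): drop the sorted tail once cumulative probability > top_p.
+    probs = torch.softmax(logits, dim=-1)
+    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
+    remove_sorted = torch.cumsum(sorted_probs, dim=-1) > top_p
+    remove_sorted = torch.roll(remove_sorted, shifts=1, dims=-1)
+    remove_sorted[..., 0] = False
+    remove = torch.zeros_like(remove_sorted).scatter(dim=-1, index=sorted_idx, src=remove_sorted)
+    logits = logits.masked_fill(remove, -torch.inf)
+
+    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1).squeeze(-1)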
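+
+
+# Sketch of the delay-pattern shutdown logic in `generate`: once a sequence
+# has triggered EOS, channel c receives its own EOS exactly `delay_pattern[c]`
+# steps later and PAD on every step after that. The delay pattern and EOS/PAD
+# ids follow the UniMoEAudioConfig defaults; the raw predictions are random
+# stand-ins.
+def _demo_delay_pattern_padding():
+    import torch
+
+    delay_pattern = torch.tensor([0, 8, 9, 10, 11, 12, 13, 14, 15])
+    eos_value, pad_value = 1024, 1025
+    # Two active sequences that triggered EOS 0 and 9 steps ago, respectively.
+    step_after_eos = torch.tensor([0, 9]).unsqueeze(1)         # (N, 1)
+    pred = torch.randint(0, 1024, (2, delay_pattern.numel()))  # (N, C)
+    pred[step_after_eos == delay_pattern.unsqueeze(0)] = eos_value  # channel emits EOS now
+    pred[step_after_eos > delay_pattern.unsqueeze(0)] = pad_value   # channel already past its delay
+    return pred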
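+
+
+# Sketch of how the commented-out Auto* registrations above would typically be
+# used to load the model through the transformers Auto classes. The checkpoint
+# path is a hypothetical placeholder for a directory whose config declares
+# model_type "uni_audio_rvq_qwen2_5vl_moe".
+def _demo_register_and_load():
+    from transformers import AutoConfig, AutoModelForCausalLM
+
+    AutoConfig.register("uni_audio_rvq_qwen2_5vl_moe", UniMoEAudioConfig)
+    AutoModelForCausalLM.register(UniMoEAudioConfig, UniMoEAudio)
+
+    config = AutoConfig.from_pretrained("/path/to/uni_moe_audio_checkpoint")
+    model = AutoModelForCausalLM.from_config(config)
+    return model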