# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch UniMoE-Audio model: a Qwen2.5-VL backbone with sparse-MoE decoder layers and multi-channel audio-codec heads."""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import (
    ModelOutput,
)
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLVisionConfig,
    Qwen2_5_VLTextConfig,
    Qwen2_5_VLConfig,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLAttention,
    Qwen2RMSNorm,
    Qwen2_5_VLRotaryEmbedding,
)
from DCMoE import UniMoEAudioSparseMoeBlock
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel

logger = logging.get_logger(__name__)

# When True, `_init_weights` only initializes MoE router gates and embeddings; every other
# module keeps whatever initialization its constructor gave it.
FAST_INIT = True


class Qwen2_5_VLMoETextConfig(Qwen2_5_VLTextConfig):
    """Qwen2.5-VL text config extended with Mixture-of-Experts settings.

    Every extra field is stored verbatim on the config. The config is handed to
    `UniMoEAudioSparseMoeBlock` in each decoder layer, which presumably reads these fields
    (the block is defined in `DCMoE`, so its use of them is not visible here — TODO confirm).
    All remaining keyword arguments go to `Qwen2_5_VLTextConfig`.
    """

    model_type = "qwen2_5_vl_moe_text"

    def __init__(
        self,
        mlp_dynamic_expert_num=4,
        mlp_dynamic_null_expert_num=0,
        mlp_dynamic_top_p=0.7,
        mlp_dynamic_top_k=2,
        mlp_fixed_expert_num=2,
        dynamic_intermediate_size=8960,
        shared_intermediate_size=8960,
        ignore_differentiable_router=False,
        enable_expert_tensor_parallelism: bool = False,
        ep_size=1,
        fixed_ep_size=1,
        router_jitter_noise=0.01,
        input_jitter_noise=0.01,
        token_drop=False,
        drop_policy: str = "probs",
        min_capacity: int = 8,
        capacity_factor: float = 1.0,
        fp32_gate=True,
        avg_hidden_states_last=False,
        drop_token_num_print=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.mlp_dynamic_expert_num = mlp_dynamic_expert_num
        self.mlp_dynamic_top_p = mlp_dynamic_top_p
        self.mlp_dynamic_top_k = mlp_dynamic_top_k
        self.mlp_fixed_expert_num = mlp_fixed_expert_num
        self.mlp_dynamic_null_expert_num = mlp_dynamic_null_expert_num
        self.dynamic_intermediate_size = dynamic_intermediate_size
        self.shared_intermediate_size = shared_intermediate_size
        self.ignore_differentiable_router = ignore_differentiable_router
        self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism
        self.ep_size = ep_size
        self.fixed_ep_size = fixed_ep_size
        self.input_jitter_noise = input_jitter_noise
        self.router_jitter_noise = router_jitter_noise
        self.token_drop = token_drop
        self.drop_policy = drop_policy
        self.min_capacity = min_capacity
        self.capacity_factor = capacity_factor
        self.fp32_gate = fp32_gate
        self.avg_hidden_states_last = avg_hidden_states_last
        self.drop_token_num_print = drop_token_num_print


class UniMoEAudioConfig(PretrainedConfig):
    """Top-level configuration for `UniMoEAudio`.

    Args:
        text_config: `dict`, an already-built `Qwen2_5_VLMoETextConfig`, or `None`.
            When `None`, the text config is built from `**kwargs`.
        vision_config: `dict`, an already-built `Qwen2_5_VLVisionConfig`, or `None` (defaults).
        image_token_id / video_token_id: ids of the image / video placeholder tokens in `input_ids`.
        codec_vocab_size: vocabulary size of each codec channel (rows per codec embedding table,
            and per-channel width of the codec head).
        codec_delay_pattern: one offset per codec channel. Defaults to
            `[0, 8, 9, 10, 11, 12, 13, 14, 15]` (a fresh list per instance).
        codec_channels: number of parallel codec channels.
        codec_eos_value / codec_pad_value / codec_bos_value: special codec token values.
        codec_placeholder_value: token id in `input_ids` whose embeddings are replaced by
            summed codec-channel embeddings.
    """

    model_type = "uni_audio_rvq_qwen2_5vl_moe"
    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLMoETextConfig}
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=151655,
        video_token_id=151656,
        codec_vocab_size=1028,
        codec_delay_pattern=None,
        codec_channels=9,
        codec_eos_value=1024,
        codec_pad_value=1025,
        codec_bos_value=1026,
        codec_placeholder_value=None,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()
        else:
            # An already-built config object: keep it (it used to be silently dropped,
            # leaving `self.vision_config` unset).
            self.vision_config = vision_config

        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            self.text_config = self.sub_configs["text_config"](**kwargs)
        else:
            self.text_config = text_config

        # Avoid a shared mutable default: each config owns its own list.
        if codec_delay_pattern is None:
            codec_delay_pattern = [0, 8, 9, 10, 11, 12, 13, 14, 15]

        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.codec_vocab_size = codec_vocab_size
        self.codec_delay_pattern = list(codec_delay_pattern)
        self.codec_channels = codec_channels
        self.codec_eos_value = codec_eos_value
        self.codec_pad_value = codec_pad_value
        self.codec_bos_value = codec_bos_value
        self.codec_placeholder_value = codec_placeholder_value
        super().__init__(**kwargs)


@dataclass
class MoEQwen2_5VLCausalLMOutputWithPast(ModelOutput):
    """Output of `UniMoEAudio.forward`: causal-LM fields plus MoE router diagnostics."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None
    # Per-layer router outputs, populated when `output_router_logits_and_topk` is set.
    all_router_logits: Tuple = None
    all_router_top_k: Tuple = None
    all_router_expert_mask: Tuple = None
    all_router_weight: Tuple = None
    # Mean of the per-layer MoE auxiliary losses (only computed when `labels` is given).
    aux_balance_loss: torch.FloatTensor = None


@dataclass
class BaseModelOutputWithPast(ModelOutput):
    """Decoder output extended with per-layer MoE router outputs.

    NOTE: intentionally shadows `transformers.modeling_outputs.BaseModelOutputWithPast`
    within this module.
    """

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    all_router_logits: Tuple = None
    all_router_top_k: Tuple = None
    all_router_weight: Tuple = None
    all_router_expert_mask: Tuple = None
    all_aux_loss: Tuple = None


class Qwen2_5_VLMoEDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder layer: Qwen2.5-VL self-attention followed by a sparse MoE feed-forward block."""

    def __init__(self, config: Qwen2_5_VLMoETextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )
        self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
        self.mlp = UniMoEAudioSparseMoeBlock(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Selects which causal-mask variant ("full_attention" / "sliding_attention") this layer uses.
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        padding_token_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits_and_topk: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply attention and the MoE block, each with a residual connection.

        `padding_token_mask` is forwarded to the MoE block only. Extra `**kwargs` are
        accepted but not forwarded to attention.

        Returns:
            A tuple `(hidden_states,)`, extended with the attention weights when
            `output_attentions`, and then with `(router_logits, router_top_k,
            router_expert_mask, router_weight, aux_loss)` when `output_router_logits_and_topk`.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self-attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

        # Sparse MoE feed-forward
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, router_logits, router_top_k, router_expert_mask, router_weight, aux_loss = self.mlp(
            hidden_states, padding_token_mask
        )
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if output_router_logits_and_topk:
            # Callers index these five trailing entries from the end (-5 .. -1).
            outputs += (router_logits,)
            outputs += (router_top_k,)
            outputs += (router_expert_mask,)
            outputs += (router_weight,)
            outputs += (aux_loss,)
        return outputs


class Qwen2_5_VLMoEPreTrainedModel(PreTrainedModel):
    """Shared `PreTrainedModel` base (config class, parallelism/attention flags, weight init)."""

    config_class = UniMoEAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2_5_VLMoEDecoderLayer", "Qwen2_5_VLVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_flash_attn_3 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize `module` in place using `config.initializer_range` as the normal std.

        With `FAST_INIT`, only MoE router gates and embeddings are touched. Otherwise
        Linear/Conv3d, Embedding and RMSNorm modules are (re)initialized.
        NOTE(review): reads `self.config.initializer_range`, which `UniMoEAudioConfig` does not
        define itself — presumably supplied via kwargs / the checkpoint config; confirm.
        """
        std = self.config.initializer_range
        if FAST_INIT:
            if isinstance(module, UniMoEAudioSparseMoeBlock):
                module.gate.weight.data.normal_(mean=0.0, std=std)
                if module.gate.bias is not None:
                    module.gate.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
        else:
            if isinstance(module, (nn.Linear, nn.Conv3d)):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
            elif isinstance(module, Qwen2RMSNorm):
                module.weight.data.fill_(1.0)


class Qwen2_5_VLMoETextModel(Qwen2_5_VLMoEPreTrainedModel):
    """Text decoder: token embeddings, a stack of MoE decoder layers, and a final RMSNorm."""

    config_class = Qwen2_5_VLMoETextConfig

    def __init__(self, config: Qwen2_5_VLMoETextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2_5_VLMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        padding_token_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits_and_topk: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of `input_ids` / `inputs_embeds` must be given. 2-D `position_ids` are
        broadcast to the three M-RoPE axes. When missing, they default to `cache_position`.
        `attention_mask` may be a pre-built dict keyed by attention type; otherwise causal
        masks are built here.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # M-RoPE uses three position streams (leading dim 3).
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits_and_topk else None
        all_router_top_k = () if output_router_logits_and_topk else None
        # These three stay as (possibly empty) tuples even when router outputs are not requested.
        all_router_expert_mask = ()
        all_router_weight = ()
        all_aux_loss = ()

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                padding_token_mask=padding_token_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                output_router_logits_and_topk=output_router_logits_and_topk,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if output_router_logits_and_topk:
                # Router outputs are the last five entries of the layer's output tuple.
                all_router_logits += (layer_outputs[-5],)
                all_router_top_k += (layer_outputs[-4],)
                all_router_expert_mask += (layer_outputs[-3],)
                all_router_weight += (layer_outputs[-2],)
                all_aux_loss += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_self_attns,
                    all_router_logits,
                    all_router_top_k,
                    all_router_expert_mask,
                    all_router_weight,
                    all_aux_loss,
                ]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            all_router_logits=all_router_logits,
            all_router_top_k=all_router_top_k,
            all_router_expert_mask=all_router_expert_mask,
            all_router_weight=all_router_weight,
            all_aux_loss=all_aux_loss,
        )


class UniMoEAudio(Qwen2_5_VLMoEPreTrainedModel):
    """Qwen2.5-VL vision tower + MoE text decoder with a text LM head and a multi-channel codec head."""

    base_model_prefix = ""
    # Must name the decoder-layer class actually used by this model (was "Qwen2_5_VLDecoderLayer",
    # which does not occur here, so decoder layers could be split across devices).
    _no_split_modules = ["Qwen2_5_VLMoEDecoderLayer", "Qwen2_5_VLVisionBlock"]
    config_class = UniMoEAudioConfig
    _checkpoint_conversion_mapping = {
        "^visual": "visual",
        r"^model(?!\.(language_model|visual))": "language_model",
    }
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(
            config.vision_config, attn_implementation=config._attn_implementation
        )
        self.language_model = Qwen2_5_VLMoETextModel._from_config(config.text_config)
        # Cached M-RoPE offsets from the prefill step; reused on later decode steps.
        self.rope_deltas = None
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)

        self.num_channels = config.codec_channels
        self.codec_vocab_size = config.codec_vocab_size
        # One embedding table per codec channel; channel embeddings are summed (see `codec_embedding`).
        self.codec_embed_tokens = nn.ModuleList(
            [nn.Embedding(self.codec_vocab_size, config.text_config.hidden_size) for embed_idx in range(self.num_channels)]
        )
        self.codec_placeholder_value = config.codec_placeholder_value
        # Single projection producing all channels' logits; reshaped to (..., num_channels, codec_vocab_size).
        self.codec_head = nn.Linear(config.text_config.hidden_size, self.num_channels * self.codec_vocab_size, bias=False)

        self.post_init()

    @property
    def cur_aux_weight(self):
        """Current MoE aux-loss weight.

        Decays linearly from `l_aux_weight` to `min_l_aux_weight` over
        `l_aux_weight_decay_steps` steps, then stays at `min_l_aux_weight`.
        NOTE(review): `training_steps`, `l_aux_weight`, `min_l_aux_weight` and
        `l_aux_weight_decay_steps` are not set in `__init__`; they are presumably assigned
        externally (e.g. by the training script) — confirm.
        """
        if self.training_steps >= self.l_aux_weight_decay_steps:
            return self.min_l_aux_weight
        return self.l_aux_weight - (self.l_aux_weight - self.min_l_aux_weight) / self.l_aux_weight_decay_steps * self.training_steps

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute 3-axis (temporal/height/width) M-RoPE position ids and per-sample deltas.

        Text tokens advance all three axes together. Image/video tokens get grid positions
        (video temporal positions are scaled by `second_per_grid_ts` and
        `vision_config.tokens_per_second`).

        Returns:
            `(position_ids, mrope_position_deltas)`, where `position_ids` has shape
            `(3, batch, seq_len)` and the deltas have shape `(batch, 1)`.

        NOTE(review): reads `config.vision_start_token_id`, which `UniMoEAudioConfig` does not
        declare explicitly — presumably passed through kwargs; confirm.
        """
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for i, input_ids in enumerate(total_input_ids):
                # Only unmasked tokens receive positions.
                input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        second_per_grid_t = 0
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        if second_per_grid_ts is not None:
                            second_per_grid_t = second_per_grid_ts[video_index]
                        else:
                            second_per_grid_t = 1.0
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    # Text span preceding this visual segment.
                    text_len = ed - st

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    # Visual segment: temporal index scaled to seconds * tokens_per_second.
                    range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                    expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)

                    second_per_grid_t = torch.as_tensor(
                        second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
                    )

                    time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second

                    time_tensor_long = time_tensor.long()
                    t_index = time_tensor_long.flatten()

                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                # Trailing text after the last visual segment.
                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas

    def get_video_features(self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None):
        """Encode videos with the vision tower and split the result into one tensor per video."""
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        video_embeds = torch.split(video_embeds, split_sizes)
        return video_embeds

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """Encode images with the vision tower and split the result into one tensor per image."""
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)
        return image_embeds

    def codec_embedding(self, codec_input_ids):
        """Sum the per-channel codec embeddings; the last dim of `codec_input_ids` indexes channels."""
        x = None
        for i in range(self.num_channels):
            channel_tokens = codec_input_ids[..., i]
            channel_embed = self.codec_embed_tokens[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed
        return x

    def calculate_input_embedding(self, input_ids, codec_input_ids):
        """Embed `input_ids` and overwrite codec-placeholder positions with codec embeddings.

        The codec embeddings are scattered in order into positions where
        `input_ids == codec_placeholder_value`.
        """
        inputs_embeds = self.language_model.embed_tokens(input_ids)
        if codec_input_ids is not None:
            codec_input_embeds = self.codec_embedding(codec_input_ids)

            codec_mask = (input_ids == self.codec_placeholder_value).unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds = inputs_embeds.masked_scatter(codec_mask, codec_input_embeds)

        return inputs_embeds

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        codec_input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        codec_labels: Optional[torch.LongTensor] = None,
        padding_token_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits_and_topk: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, MoEQwen2_5VLCausalLMOutputWithPast]:
        """Run the multimodal MoE model and return text/codec logits plus optional losses.

        Positions equal to `config.codec_placeholder_value` receive codec embeddings, and
        image/video placeholder positions receive vision-tower features.

        Args:
            labels: only its presence is checked. When given, the mean per-layer MoE aux loss
                is computed (this needs `output_router_logits_and_topk` so the decoder returns
                router outputs), weighted by `cur_aux_weight`, and `self.training_steps` is
                incremented.
            codec_labels: per-channel targets (last dim indexes channels). They are shifted by
                one position for next-token cross-entropy; -100 entries are ignored. Channels
                other than 0 with no valid targets are skipped.
            padding_token_mask: passed to every MoE block. Defaults to `attention_mask.bool()`,
                or to all-True when no attention mask is given.

        Returns:
            `MoEQwen2_5VLCausalLMOutputWithPast`. `loss` is the summed per-channel codec CE
            (plus the weighted aux loss when `labels` was given), or `None` without `codec_labels`.

        Raises:
            ValueError: if the number of image/video placeholder tokens differs from the number
                of produced vision features.
        """
        return_dict = True
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if inputs_embeds is None:
            inputs_embeds = self.calculate_input_embedding(input_ids, codec_input_ids)

        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
            image_embeds = torch.cat(image_embeds, dim=0)

            if input_ids is None:
                image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                image_mask = image_mask.all(-1)
            else:
                image_mask = input_ids == self.config.image_token_id

            n_image_tokens = (image_mask).sum()
            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            n_image_features = image_embeds.shape[0]
            if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        if pixel_values_videos is not None:
            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
            video_embeds = torch.cat(video_embeds, dim=0)

            if input_ids is None:
                video_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                video_mask = video_mask.all(-1)
            else:
                video_mask = input_ids == self.config.video_token_id

            n_video_tokens = (video_mask).sum()
            n_video_features = video_embeds.shape[0]
            video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                )

            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if position_ids is None:
            attention_mask_tensor = (
                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
            )
            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
                # Recover a 2-D 0/1 mask from a 4-D additive mask.
                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
                attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
                attention_mask_tensor = (1.0 - attention_mask_tensor).int()

            # Compute M-RoPE indices once during prefill; reuse the cached deltas while decoding.
            prefill_compiled_stage = is_torchdynamo_compiling() and (
                (input_ids is not None and input_ids.shape[1] != 1)
                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
            )
            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
                (cache_position is not None and cache_position[0] == 0)
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            )
            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    attention_mask=attention_mask_tensor,
                )
                self.rope_deltas = rope_deltas
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `delta` is the int 0
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        if padding_token_mask is None:
            if attention_mask is not None:
                padding_token_mask = attention_mask.bool()
            else:
                # No attention mask means no padding: mark every position as a real token
                # (previously `None.bool()` raised AttributeError here).
                padding_token_mask = torch.ones(
                    inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device
                )

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            padding_token_mask=padding_token_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits_and_topk=output_router_logits_and_topk,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states).float()
        codec_logits = self.codec_head(hidden_states).float()
        codec_logits = codec_logits.view((logits.shape[0], logits.shape[1], self.num_channels, self.codec_vocab_size))

        loss = None
        # Initialized up front: previously these were unbound when `labels` was None, which
        # crashed the return statement (and the loss sum when only `codec_labels` was given).
        all_aux_loss = None
        aux_loss = None
        if labels is not None:
            layer_aux_losses = outputs.all_aux_loss if return_dict else outputs[-1]
            all_aux_loss = torch.mean(torch.cat([l.unsqueeze(0) for l in layer_aux_losses], dim=0))
            aux_loss = self.cur_aux_weight * all_aux_loss
            self.training_steps += 1

        codec_loss = None
        if codec_labels is not None:
            for i in range(self.num_channels):
                channel_logits = codec_logits[:, :, i].float()
                channel_labels = codec_labels[:, :, i]
                # Next-token prediction: logits at t predict labels at t + 1.
                shift_channel_logits = channel_logits[..., :-1, :].contiguous()
                shift_channel_labels = channel_labels[..., 1:].contiguous()

                # Channel 0 always contributes; other channels are skipped when fully ignored.
                if i != 0 and (shift_channel_labels != -100).sum() == 0:
                    continue

                loss_fct = CrossEntropyLoss()
                shift_channel_logits = shift_channel_logits.view(-1, self.codec_vocab_size)
                shift_channel_labels = shift_channel_labels.view(-1)
                shift_channel_labels = shift_channel_labels.to(shift_channel_logits.device)
                channel_loss = loss_fct(shift_channel_logits, shift_channel_labels)
                codec_loss = channel_loss if codec_loss is None else codec_loss + channel_loss

            loss = codec_loss if aux_loss is None else codec_loss + aux_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MoEQwen2_5VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            all_router_logits=outputs.all_router_logits,
            all_router_top_k=outputs.all_router_top_k,
            all_router_expert_mask=outputs.all_router_expert_mask,
            all_router_weight=outputs.all_router_weight,
            aux_balance_loss=all_aux_loss,
        )

    @staticmethod
    def _sample_next_token(
        logits_BCxV: torch.Tensor,
        temperature: float,
        top_p: float,
        top_k: int,
        audio_eos_value: int,
    ) -> torch.Tensor:
        """Sample one token per row of a 2-D logits matrix.

        `temperature == 0` means greedy argmax. Otherwise the steps are:
        1. Scale logits by `1 / temperature`.
        2. Mask EOS in every row where it is not already the top logit.
        3. Apply top-k, then top-p (nucleus) filtering.
        4. Draw one sample per row.

        Returns:
            A 1-D tensor with one token index per input row.
        """
        if temperature == 0.0:
            return torch.argmax(logits_BCxV, dim=-1)

        logits_BCxV = logits_BCxV / temperature

        if audio_eos_value is not None and audio_eos_value >= 0:
            # EOS may only be sampled in rows where it is already the argmax.
            top_logit_indices_BC = torch.argmax(logits_BCxV, dim=-1)
            eos_not_highest_mask_BC = top_logit_indices_BC != audio_eos_value
            mask_eos_unless_highest_BCxV = torch.zeros_like(logits_BCxV, dtype=torch.bool)
            mask_eos_unless_highest_BCxV[eos_not_highest_mask_BC, audio_eos_value] = True
            logits_BCxV = logits_BCxV.masked_fill(mask_eos_unless_highest_BCxV, -torch.inf)

        if top_k is not None:
            _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=top_k, dim=-1)
            mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
            mask = mask.scatter(dim=-1, index=top_k_indices_BCxV, value=False)
            logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)

        if top_p < 1.0:
            probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
            sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
            cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

            sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
            # Shift right so the token that crosses the threshold is kept; always keep the top token.
            sorted_indices_to_remove_BCxV = torch.roll(sorted_indices_to_remove_BCxV, shifts=1, dims=-1)
            sorted_indices_to_remove_BCxV[..., 0] = torch.zeros_like(sorted_indices_to_remove_BCxV[..., 0])

            indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
            indices_to_remove_BCxV = indices_to_remove_BCxV.scatter(
                dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
            )
            logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

        final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)

        sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
        sampled_indices_C = sampled_indices_BC.squeeze(-1)
        return sampled_indices_C

    def _decoder_step(
        self,
        tokens_Bx1xC: torch.Tensor,
        model_kwargs,
        cfg_scale: float,
        neg_input_size: int,
        temperature: float,
        top_p: float,
        top_k: int,
        do_sample=True,
        eos_prob_mul_factor=1.0,
        labels_Bx1xC=None,
        use_cache=True,
        enable_eos=True,
    ) -> Tuple[torch.Tensor, dict]:
        """Run one autoregressive codec decoding step with classifier-free guidance.

        `tokens_Bx1xC` is repeated `neg_input_size` times along the batch dim. The last copy
        of each group is treated as the conditional stream; the others are unconditional/negative
        streams. `model_kwargs` must hold `attention_mask`, `cache_position`,
        `past_key_values`, `input_ids` and `codec_input_ids`; it is updated in place and
        returned for the next step. `labels_Bx1xC` is accepted but unused.

        `cfg_scale` may be a scalar, a list/tuple with one weight per negative stream, or
        `None` (no guidance).

        Returns:
            `(pred_BxC, model_kwargs)`: one predicted codec token per sample and channel, and
            the updated kwargs.
        """
        B = tokens_Bx1xC.shape[0]
        audio_eos_value = self.config.codec_eos_value
        attention_mask = model_kwargs["attention_mask"]
        cache_position = model_kwargs["cache_position"]
        past_key_values = model_kwargs["past_key_values"]
        input_ids = model_kwargs["input_ids"]
        codec_input_ids = model_kwargs["codec_input_ids"]

        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -tokens_Bx1xC.shape[1] :]
            position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        tokens_Bx1xC = tokens_Bx1xC.repeat_interleave(neg_input_size, dim=0)
        codec_input_ids = (
            torch.cat((codec_input_ids, tokens_Bx1xC), dim=1) if codec_input_ids is not None else tokens_Bx1xC.clone()
        )
        # Append one codec placeholder per row so the text ids stay aligned with codec steps.
        input_ids = torch.cat(
            (input_ids, torch.ones(input_ids.shape[0], 1).to(input_ids) * self.codec_placeholder_value), dim=-1
        )

        if use_cache:
            # Only the newest codec frame is fed; history lives in the KV cache.
            codec_input_embeds = self.codec_embedding(tokens_Bx1xC)
            outputs = self.language_model(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=codec_input_embeds,
                use_cache=True,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
                cache_position=cache_position,
            )
        else:
            # No cache: re-embed the full sequence every step.
            batch_codec_input_ids = codec_input_ids.contiguous().view(-1, self.num_channels)
            inputs_embeds = self.calculate_input_embedding(input_ids, batch_codec_input_ids)
            outputs = self.language_model(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=attention_mask.long().cumsum(-1) - 1,
                past_key_values=None,
                inputs_embeds=inputs_embeds,
                use_cache=True,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
                cache_position=None,
            )

        last_hidden_state = outputs.last_hidden_state
        codec_logits = self.codec_head(last_hidden_state).float()
        codec_logits = codec_logits.view(
            (codec_logits.shape[0], codec_logits.shape[1], self.num_channels, self.codec_vocab_size)
        )

        model_kwargs["past_key_values"] = outputs.past_key_values
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
        )
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
        model_kwargs["input_ids"] = input_ids
        model_kwargs["codec_input_ids"] = codec_input_ids

        logits_Bx1xCxV = codec_logits[:, -1:].clone()
        logits_last_2BxCxV = logits_Bx1xCxV[:, -1]
        logits_last_Bx2xCxV = logits_last_2BxCxV.view(B, neg_input_size, *logits_last_2BxCxV.shape[1:])
        if cfg_scale is not None:
            cond_logits_BxCxV = logits_last_Bx2xCxV[:, -1, :, :]  # Shape [B, C, V]
            logits_BxCxV = cond_logits_BxCxV
            for ni in range(neg_input_size - 1):
                uncond_logits_BxCxV = logits_last_Bx2xCxV[:, ni, :, :]  # Shape [B, C, V]
                cfg_weight = cfg_scale[ni] if isinstance(cfg_scale, (list, tuple)) else cfg_scale
                logits_BxCxV = logits_BxCxV + cfg_weight * (cond_logits_BxCxV - uncond_logits_BxCxV)
        else:
            logits_BxCxV = logits_last_Bx2xCxV[:, -1, :, :]  # Shape [B, C, V]

        if enable_eos:
            # Forbid every id above EOS on all channels, and EOS-or-above on channels 1+;
            # only channel 0 may emit EOS, with its EOS logit scaled by `eos_prob_mul_factor`.
            logits_BxCxV[:, :, audio_eos_value + 1 :] = torch.full_like(
                logits_BxCxV[:, :, audio_eos_value + 1 :],
                fill_value=-torch.inf,
            )
            logits_BxCxV[:, 1:, audio_eos_value:] = torch.full_like(
                logits_BxCxV[:, 1:, audio_eos_value:],
                fill_value=-torch.inf,
            )
            logits_BxCxV[:, 0, audio_eos_value] *= torch.tensor(eos_prob_mul_factor, device=self.device)
        else:
            logits_BxCxV[:, :, audio_eos_value:] = torch.full_like(
                logits_BxCxV[:, :, audio_eos_value:],
                fill_value=-torch.inf,
            )

        flat_logits_BCxV = logits_BxCxV.reshape(B * self.num_channels, -1)
        if do_sample:
            pred_BC = self._sample_next_token(
                flat_logits_BCxV.float(),
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                audio_eos_value=audio_eos_value,
            )
        else:
            pred_BC = torch.argmax(flat_logits_BCxV, dim=1)

        pred_BxC = pred_BC.view(B, self.num_channels)
        return pred_BxC, model_kwargs

    def generate(
        self,
        input_ids,
        attention_mask,
        dec_output,
        max_tokens,
        min_tokens=None,
        codec_input_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        neg_input_size = 2,
        cfg_scale = 3.0,
        temperature: float = 1.2,
        top_p: float = 0.95,
        cfg_filter_top_k: int = 45,
        eos_prob_mul_factor: float = 0.8,
        do_sample: bool = True,
        debug_guidance_step: int = 0,
        use_cache=True,
    ):
        # NOTE(review): this method continues beyond the edited span; the visible head below is
        # reproduced unchanged.
        if codec_input_ids is not None:
            assert use_cache
        batch_size = input_ids.shape[0] // neg_input_size

        audio_eos_value = self.config.codec_eos_value
        audio_pad_value = self.config.codec_pad_value
        delay_pattern = self.config.codec_delay_pattern
        max_delay_pattern = max(delay_pattern)
        delay_pattern_Cx = torch.tensor(delay_pattern, device=self.device, dtype=torch.long)

        dec_step = min(dec_output.prefill_steps) - 1

        eos_detected_Bx = torch.zeros((batch_size,), dtype=torch.bool, device=self.device)
        eos_countdown_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)
        finished_step_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)

        bos_over = False
        model_kwargs = dict(attention_mask=attention_mask, use_cache=True)
model_kwargs["past_key_values"] = DynamicCache() model_kwargs["cache_position"] = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 attention_mask = model_kwargs["attention_mask"] past_key_values = model_kwargs["past_key_values"] position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) cache_position = torch.arange(0, input_ids.shape[-1], device=input_ids.device) inputs_embeds = self.calculate_input_embedding(input_ids, codec_input_ids) outputs = self.language_model( input_ids=None, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, pixel_values=pixel_values, pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, second_per_grid_ts=second_per_grid_ts, use_cache=True, output_attentions=False, output_hidden_states=False, return_dict=True, cache_position=cache_position, ) model_kwargs["input_ids"] = input_ids model_kwargs["codec_input_ids"] = None model_kwargs["labels"] = torch.ones_like(input_ids[neg_input_size-1::neg_input_size]) * -100 labels_Bx1xC = dec_output.get_labels_at(0) if labels_Bx1xC is not None: model_kwargs["codec_labels"] = (torch.ones_like(input_ids[neg_input_size-1::neg_input_size]) * -100).unsqueeze(-1).expand(-1, -1, self.num_channels) assert (labels_Bx1xC != self.config.codec_bos_value).sum() == 0 labels_Bx1xC = torch.full_like(labels_Bx1xC, -100) model_kwargs["codec_labels"] = torch.cat((model_kwargs["codec_labels"], labels_Bx1xC), dim=1) model_kwargs["past_key_values"] = outputs.past_key_values attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 while dec_step < max_tokens: if (eos_countdown_Bx == 0).all(): break current_step_idx = dec_step + 1 tokens_Bx1xC = 
dec_output.get_tokens_at(dec_step) labels_Bx1xC = dec_output.get_labels_at(dec_step + 1) pred_BxC, model_kwargs = self._decoder_step( tokens_Bx1xC=tokens_Bx1xC, model_kwargs=model_kwargs, cfg_scale=cfg_scale, neg_input_size=neg_input_size, temperature=temperature, top_p=top_p, top_k=cfg_filter_top_k, do_sample=do_sample, eos_prob_mul_factor=eos_prob_mul_factor, labels_Bx1xC=labels_Bx1xC, use_cache=use_cache, enable_eos=(min_tokens is None or dec_step >= min_tokens), ) if labels_Bx1xC is not None and (dec_step < debug_guidance_step or debug_guidance_step==-1): pred_BxC = labels_Bx1xC[:, 0] active_mask_Bx = eos_countdown_Bx != 0 eos_trigger_Bx = torch.zeros_like(active_mask_Bx) if active_mask_Bx.any(): is_eos_token = (~eos_detected_Bx[active_mask_Bx]) & (pred_BxC[active_mask_Bx, 0] == audio_eos_value) is_max_len = current_step_idx >= max_tokens - max_delay_pattern eos_trigger_Bx[active_mask_Bx] = is_eos_token | is_max_len eos_detected_Bx |= eos_trigger_Bx start_countdown_mask_Bx = eos_trigger_Bx & (eos_countdown_Bx < 0) if start_countdown_mask_Bx.any(): eos_countdown_Bx[start_countdown_mask_Bx] = max_delay_pattern finished_step_Bx[start_countdown_mask_Bx] = current_step_idx padding_mask_Bx = eos_countdown_Bx > 0 if padding_mask_Bx.any(): pred_active_BxC = pred_BxC[padding_mask_Bx].clone() countdown_active_Bx = eos_countdown_Bx[padding_mask_Bx] step_after_eos_Bx = max_delay_pattern - countdown_active_Bx step_after_eos_Bx_ = step_after_eos_Bx.unsqueeze(1) delay_pattern_Cx_ = delay_pattern_Cx.unsqueeze(0) eos_mask_NxC = step_after_eos_Bx_ == delay_pattern_Cx_ pad_mask_NxC = step_after_eos_Bx_ > delay_pattern_Cx_ pred_active_BxC[eos_mask_NxC] = audio_eos_value pred_active_BxC[pad_mask_NxC] = audio_pad_value pred_BxC[padding_mask_Bx] = pred_active_BxC eos_countdown_Bx[padding_mask_Bx] -= 1 if not bos_over: bos_over = all(current_step_idx - prefill_step >= max_delay_pattern for prefill_step in dec_output.prefill_steps) dec_output.update_one(pred_BxC, current_step_idx, not 
bos_over) dec_step += 1 final_step = dec_step + 1 finished_step_Bx[finished_step_Bx == -1] = final_step - max_delay_pattern prefill_steps_tensor = torch.tensor(dec_output.prefill_steps, device=self.device) lengths_Bx = finished_step_Bx - prefill_steps_tensor lengths_Bx = torch.clamp(lengths_Bx, min=0) max_len = lengths_Bx.max().item() + max_delay_pattern if max_len > 0: num_channels = self.num_channels generated_codes = torch.full( (batch_size, max_len, num_channels), fill_value=audio_pad_value, dtype=torch.long, device=self.device, ) for i in range(batch_size): start_step = dec_output.prefill_steps[i] actual_len = lengths_Bx[i].item() + max_delay_pattern if actual_len > 0: tokens_to_copy = dec_output.generated_tokens[i, start_step : start_step + actual_len, :] generated_codes[i, :actual_len, :] = tokens_to_copy return generated_codes, lengths_Bx else: print("Warning: Nothing generated for any sequence in the batch.") return None, None # AutoConfig.register("qwen2_5_vl_moe_text", Qwen2_5_VLMoETextConfig) # AutoModelForCausalLM.register(Qwen2_5_VLMoETextConfig, Qwen2_5_VLMoETextModel) # AutoConfig.register("uni_audio_rvq_qwen2_5vl_moe", UniMoEAudioConfig) # AutoModelForCausalLM.register(UniMoEAudioConfig, UniMoEAudio)