# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import warnings
import os
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
import transformers
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers import LlamaForCausalLM, Qwen2ForCausalLM, Qwen3ForCausalLM, Qwen3MoeForCausalLM
from .configuration_internvl_chat import InternVLChatConfig
from .conversation import get_conv_template
from .modeling_intern_vit import InternVisionModel, has_flash_attn
from .speech_encoder import DualWrappedEncoder
from .speech_projector import EncoderProjectorConcat
logger = logging.get_logger(__name__)
# Speech related constants
IGNORE_INDEX = -100
SPEECH_TOKEN_INDEX = -200
DEFAULT_SPEECH_TOKEN = "<speech>"
# Image related constants
IMAGE_TOKEN_INDEX = -201
def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
"""Tokenize prompt with speech tokens, similar to OLA's implementation"""
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech>')]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == 'pt':
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f'Unsupported tensor type: {return_tensors}')
return input_ids
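# Usage sketch (hypothetical prompt and tokenizer, for illustration only): each '<speech>'
# placeholder is replaced by SPEECH_TOKEN_INDEX (-200) between the tokenized text chunks, so
# downstream code can locate where to splice in speech features.
#   ids = tokenizer_speech_token('<speech>\nTranscribe the audio.', tokenizer, return_tensors='pt')
#   # ids holds regular token ids with -200 at the placeholder position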
def version_cmp(v1, v2, op='eq'):
import operator
from packaging import version
op_func = getattr(operator, op)
return op_func(version.parse(v1), version.parse(v2))
# SpeechProjector is now imported from speech_projector.py as EncoderProjectorConcat
def build_speech_encoder(audio_config):
if audio_config.speech_encoder_type is None:
return None
return DualWrappedEncoder(audio_config)
def build_speech_projector(audio_config, llm_hidden_size):
# Check if fallback speech projector path is specified in config
fallback_path = getattr(audio_config, 'speech_projector', None)
if fallback_path and os.path.exists(fallback_path):
print(f"Loading speech projector from fallback path: {fallback_path}")
# Load the pretrained weights first to determine the expected dimensions
try:
state_dict = torch.load(fallback_path, map_location='cpu')
# Handle the 'model.speech_projector.' prefix in state_dict keys
speech_projector_state_dict = {}
for key, value in state_dict.items():
if key.startswith('model.speech_projector.'):
# Remove the 'model.speech_projector.' prefix
new_key = key.replace('model.speech_projector.', '')
speech_projector_state_dict[new_key] = value
else:
# If no prefix, use the key as-is
speech_projector_state_dict[key] = value
# Determine the expected input dimensions from the loaded weights
linear1_weight_shape = speech_projector_state_dict.get('linear1.weight', None)
if linear1_weight_shape is not None:
expected_input_dim = linear1_weight_shape.shape[1] # [out_features, in_features]
print(f"Detected expected input dimension from weights: {expected_input_dim}")
# Calculate the encoder hidden size and ds_rate that match this
# expected_input_dim = encoder_hidden_size * ds_rate
# We know current encoder outputs 2048 dim (combined Whisper+BEATs)
current_encoder_dim = 2048
required_ds_rate = expected_input_dim // current_encoder_dim
if expected_input_dim == current_encoder_dim * required_ds_rate:
print(f"Using ds_rate={required_ds_rate} to match loaded weights")
ds_rate = required_ds_rate
encoder_hidden_size = current_encoder_dim
else:
print(f"Warning: Cannot perfectly match dimensions. Expected {expected_input_dim}, current encoder {current_encoder_dim}")
print(f"Will use closest match: ds_rate={required_ds_rate}")
ds_rate = max(1, required_ds_rate) # Ensure at least 1
encoder_hidden_size = current_encoder_dim
else:
print("Warning: Could not determine input dimensions from weights, using defaults")
ds_rate = 5
encoder_hidden_size = 2048
except Exception as e:
print(f"Warning: Failed to analyze speech projector weights: {e}")
print("Using default dimensions")
ds_rate = 5
encoder_hidden_size = 2048
# Create a config with the determined dimensions
class ConfigWrapper:
def __init__(self, llm_hidden_size, ds_rate, encoder_hidden_size):
self.speech_encoder_ds_rate = ds_rate
self.speech_encoder_hidden_size = encoder_hidden_size
self.hidden_size = llm_hidden_size
wrapper_config = ConfigWrapper(llm_hidden_size, ds_rate, encoder_hidden_size)
projector = EncoderProjectorConcat(wrapper_config)
# Load the weights
try:
projector.load_state_dict(speech_projector_state_dict, strict=False)
print(f"Successfully loaded speech projector weights from {fallback_path}")
except Exception as e:
print(f"Warning: Failed to load speech projector weights: {e}")
print("Using randomly initialized speech projector")
return projector
# No usable fallback checkpoint: build the projector from the audio config (or skip it if no speech encoder is configured)
if audio_config.speech_encoder_type is None:
return None
class ConfigWrapper:
def __init__(self, audio_config, llm_hidden_size):
self.speech_encoder_ds_rate = audio_config.speech_encoder_ds_rate
self.speech_encoder_hidden_size = audio_config.speech_encoder_hidden_size
self.hidden_size = llm_hidden_size # Note: EncoderProjectorConcat uses hidden_size
wrapper_config = ConfigWrapper(audio_config, llm_hidden_size)
return EncoderProjectorConcat(wrapper_config)
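# Dimension bookkeeping for the projector (illustrative numbers mirroring the defaults above,
# not read from any checkpoint): EncoderProjectorConcat concatenates ds_rate adjacent frames,
# so its first linear layer expects encoder_hidden_size * ds_rate input features.
#   encoder_hidden_size = 2048        # combined Whisper+BEATs feature size assumed above
#   ds_rate = 5
#   expected_input_dim = 2048 * 5     # = 10240 = linear1.in_features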
class InternVLChatModel(PreTrainedModel):
config_class = InternVLChatConfig
main_input_name = 'pixel_values'
base_model_prefix = 'language_model'
_supports_flash_attn_2 = True
supports_gradient_checkpointing = True
_no_split_modules = [
"InternVisionModel",
"Qwen3DecoderLayer",
]
# support transformers 4.51+
_tp_plan = ''
def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
super().__init__(config)
assert version_cmp(transformers.__version__, '4.37.0', 'ge')
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
self.select_layer = config.select_layer
self.template = 'plm_v'
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
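# Worked example of the formula above (illustrative values, not read from this config):
# image_size=448, patch_size=14, downsample_ratio=0.5 -> (448 // 14) ** 2 * 0.5 ** 2 = 256 tokens per tile.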
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
use_flash_attn = use_flash_attn if has_flash_attn else False
config.vision_config.use_flash_attn = True if use_flash_attn else False
config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
logger.info(f'num_image_token: {self.num_image_token}')
logger.info(f'ps_version: {self.ps_version}')
if vision_model is not None:
self.vision_model = vision_model
else:
self.vision_model = InternVisionModel(config.vision_config)
if language_model is not None:
self.language_model = language_model
else:
architecture: str = config.llm_config.architectures[0]
if architecture == 'LlamaForCausalLM':
self.language_model = LlamaForCausalLM(config.llm_config)
elif architecture == 'Qwen2ForCausalLM':
self.language_model = Qwen2ForCausalLM(config.llm_config)
elif architecture == 'Qwen3MoeForCausalLM':
self.language_model = Qwen3MoeForCausalLM(config.llm_config)
elif architecture == 'Qwen3ForCausalLM':
self.language_model = Qwen3ForCausalLM(config.llm_config)
else:
raise NotImplementedError(f'{architecture} is not implemented.')
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.llm_config.hidden_size
self.mlp1 = nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size)
)
# Initialize speech encoder and projector
self.speech_encoder = build_speech_encoder(config.audio_config)
self.speech_projector = build_speech_projector(config.audio_config, llm_hidden_size)
# Add a dimension adjustment layer if needed
self.speech_dim_adapter = None
if self.speech_projector is not None and hasattr(self.speech_projector, 'encoder_dim'):
expected_encoder_dim = self.speech_projector.encoder_dim
# Our current encoder outputs 2048 (combined Whisper+BEATs)
actual_encoder_dim = 2048
if expected_encoder_dim != actual_encoder_dim:
print(f"Adding dimension adapter: {actual_encoder_dim} -> {expected_encoder_dim}")
self.speech_dim_adapter = nn.Linear(actual_encoder_dim, expected_encoder_dim)
self.img_context_token_id = None
self.speech_context_token_id = None
self.conv_template = get_conv_template(self.template)
self.system_message = self.conv_template.system_message
def get_speech_encoder(self):
return self.speech_encoder
def get_speech_projector(self):
return self.speech_projector
def encode_speech(self, speech, speech_lengths, speech_wav):
"""Encode speech similar to Ola's implementation"""
speech_encoder = self.get_speech_encoder()
if speech_encoder is None:
return None
# Process raw_wav to handle BEATs input format requirements
# Handle the list case before any tensor-only calls (.dim() on a list would raise AttributeError).
processed_raw_wav = speech_wav
if isinstance(speech_wav, list):
processed_raw_wav = torch.stack(speech_wav, dim=0)
# Call speech encoder with processed raw_wav parameter
try:
encoder_outs = speech_encoder(speech.permute(0, 2, 1), raw_wav=processed_raw_wav)
except Exception as e:
print(f"⚠️ BEATs processing failed: {e}")
print("🔄 Falling back to Whisper-only processing")
encoder_outs = speech_encoder(speech.permute(0, 2, 1), raw_wav=None)
speech_lengths = (speech_lengths + 1) // 2
# Apply dimension adapter if needed
if self.speech_dim_adapter is not None:
encoder_outs = self.speech_dim_adapter(encoder_outs)
# Apply the speech projector according to the configured projector type
speech_projector_type = getattr(self.config.audio_config, 'speech_projector_type', 'linear')
speech_projector = self.get_speech_projector()
if speech_projector_type == "linear" and speech_projector is not None:
encoder_outs = speech_projector(encoder_outs)
# Note: speech_projector.k is the downsampling rate
if hasattr(speech_projector, 'k'):
speech_lengths = speech_lengths // speech_projector.k
elif speech_projector_type != "linear":
raise ValueError(f'Unknown speech projector: {speech_projector_type}')
return encoder_outs
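# Length bookkeeping sketch (assuming the defaults above): the Whisper-style front end roughly
# halves the frame count and EncoderProjectorConcat downsamples by its factor k, so
#   projected_len ~= ((speech_lengths + 1) // 2) // speech_projector.k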
def prepare_inputs_labels_for_speech_vision_text(
self, input_ids, position_ids, attention_mask, past_key_values, labels,
speech, speech_lengths, speech_chunks, speech_wav, pixel_values, modalities, image_sizes=None, image_flags=None
):
"""Prepare inputs similar to Ola's implementation"""
speech_encoder = self.speech_encoder
if speech_encoder is None or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
# Encode speech
if speech is not None:
if not isinstance(speech, list):
if speech_chunks is not None:
speech = torch.split(speech, speech_chunks.tolist(), dim=0)
speech_lengths = torch.split(speech_lengths, speech_chunks.tolist(), dim=0)
speech_wav = torch.split(speech_wav, speech_chunks.tolist(), dim=0)
else:
speech = [speech]
speech_lengths = [speech_lengths]
speech_wav = [speech_wav]
speech_features = []
for idx in range(len(speech)):
speech_feat = self.encode_speech(speech[idx], speech_lengths[idx], speech_wav[idx])
if speech_feat is not None:
speech_features.append(speech_feat)
else:
speech_features = []
# Encode vision
if isinstance(modalities, str):
modalities = [modalities]
image_features = []
if pixel_values is not None:
if image_flags is not None:
image_flags = image_flags.squeeze(-1)
vit_embeds = self.extract_feature(pixel_values)
vit_embeds = vit_embeds[image_flags == 1]
# Apply vision projector
for idx in range(len(modalities)):
img_feat = self.mlp1(vit_embeds[idx:idx+1])
image_features.append(img_feat.flatten(0, 1))
else:
vit_embeds = self.extract_feature(pixel_values)
for idx in range(vit_embeds.shape[0]):
img_feat = vit_embeds[idx:idx+1]
image_features.append(img_feat.flatten(0, 1))
# Save original values
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# Remove padding using attention_mask
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
new_input_embeds = []
new_labels = []
cur_speech_idx = 0
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum()
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
num_speech_images = num_images + num_speech
if num_speech_images == 0:
# No speech or image tokens
cur_input_embeds_1 = self.language_model.get_input_embeddings()(cur_input_ids)
if len(speech_features) > cur_speech_idx:
cur_speech_features = speech_features[cur_speech_idx]
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0]], dim=0)
else:
cur_input_embeds = cur_input_embeds_1
if len(image_features) > cur_image_idx:
cur_images_features = image_features[cur_image_idx]
cur_input_embeds = torch.cat([cur_input_embeds, cur_images_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_speech_idx += 1
cur_image_idx += 1
continue
# Handle speech and image tokens
speech_image_token_indices = [-1] + torch.where((cur_input_ids == SPEECH_TOKEN_INDEX) | (cur_input_ids == IMAGE_TOKEN_INDEX))[0].tolist() + [cur_input_ids.shape[0]]
cur_input_ids_nospeech_image = []
cur_labels = labels[batch_idx]
cur_labels_nospeech_image = []
for i in range(len(speech_image_token_indices) - 1):
cur_input_ids_nospeech_image.append(cur_input_ids[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]])
cur_labels_nospeech_image.append(cur_labels[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]])
split_sizes = [x.shape[0] for x in cur_labels_nospeech_image]
cur_input_embeds = self.language_model.get_input_embeddings()(torch.cat(cur_input_ids_nospeech_image))
cur_input_embeds_no_speech_image = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
# Process the multimodal tokens in order, following OLA's approach
for i in range(num_speech_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i])
cur_new_labels.append(cur_labels_nospeech_image[i])
if i < num_speech_images:
# Determine which token type comes next based on position
if i < len(speech_image_token_indices) - 1:
token_pos = speech_image_token_indices[i + 1]
token_type = cur_input_ids[token_pos].item()
if token_type == SPEECH_TOKEN_INDEX and len(speech_features) > cur_speech_idx:
cur_speech_features = speech_features[cur_speech_idx]
cur_speech_idx += 1
cur_new_input_embeds.append(cur_speech_features)
cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
elif token_type == IMAGE_TOKEN_INDEX and len(image_features) > cur_image_idx:
cur_images_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_images_features)
cur_new_labels.append(torch.full((cur_images_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
# Handle missing modalities
if num_images == 0 and len(image_features) > cur_image_idx:
cur_new_input_embeds = torch.cat([cur_new_input_embeds, image_features[cur_image_idx][0:0]], dim=0)
cur_image_idx += 1
if num_speech == 0 and len(speech_features) > cur_speech_idx:
cur_new_input_embeds = torch.cat([cur_new_input_embeds, speech_features[cur_speech_idx][0:0]], dim=0)
cur_speech_idx += 1
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine and pad
max_len = max(x.shape[0] for x in new_input_embeds) if new_input_embeds else 0
batch_size = len(new_input_embeds)
if max_len > 0:
new_input_embeds_padded = []
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
new_input_embeds_padded.append(torch.cat((
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
cur_new_embed
), dim=0))
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
else:
new_input_embeds_padded.append(torch.cat((
cur_new_embed,
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
), dim=0))
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
else:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
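# Packing sketch for the method above (hypothetical sizes): per sample, text embeddings are split
# at the SPEECH_TOKEN_INDEX / IMAGE_TOKEN_INDEX positions, the corresponding speech or image
# features (each of width C) are spliced in, sequences are padded to the longest result, and the
# method returns inputs_embeds of shape (B, max_len, C) with IGNORE_INDEX labels at the inserted
# multimodal positions.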
def forward(
self,
pixel_values: torch.FloatTensor = None,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
image_flags: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
# Speech related parameters
speech: Optional[torch.FloatTensor] = None,
speech_lengths: Optional[torch.LongTensor] = None,
speech_chunks: Optional[torch.LongTensor] = None,
speech_wav: Optional[torch.FloatTensor] = None,
modalities: Optional[List[str]] = ["image"],
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Use Ola-style multimodal processing if speech or complex multimodal input is provided
if speech is not None or (pixel_values is not None and speech_chunks is not None):
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels
) = self.prepare_inputs_labels_for_speech_vision_text(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
speech,
speech_lengths,
speech_chunks,
speech_wav,
pixel_values,
modalities,
image_sizes=None,
image_flags=image_flags
)
if inputs_embeds is not None:
input_embeds = inputs_embeds
else:
# Fallback to simple processing
input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
input_ids_flat = input_ids.reshape(B * N)
# Process speech input if provided
if speech is not None and hasattr(self, 'speech_context_token_id') and self.speech_context_token_id is not None:
speech_features = self.encode_speech(speech, speech_lengths, speech_wav)
if speech_features is not None:
speech_selected = (input_ids_flat == self.speech_context_token_id)
if speech_selected.sum() > 0:
try:
input_embeds[speech_selected] = input_embeds[speech_selected] * 0.0 + speech_features.reshape(-1, C)[:speech_selected.sum()]
except Exception as e:
print(f'warning: {e}, speech processing fallback')
n_token = min(speech_selected.sum(), speech_features.size(0))
# Boolean-mask indexing returns a copy, so chained slice assignment would be a silent no-op;
# write the truncated features back through explicit indices instead.
speech_indices = torch.nonzero(speech_selected, as_tuple=True)[0][:n_token]
input_embeds[speech_indices] = input_embeds[speech_indices] * 0.0 + speech_features.reshape(-1, C)[:n_token]
# Process vision input if provided
if pixel_values is not None:
image_flags = image_flags.squeeze(-1)
vit_embeds = self.extract_feature(pixel_values)
vit_embeds = vit_embeds[image_flags == 1]
selected = (input_ids_flat == self.img_context_token_id)
try:
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
except Exception as e:
vit_embeds = vit_embeds.reshape(-1, C)
print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
f'vit_embeds.shape={vit_embeds.shape}')
n_token = min(selected.sum(), vit_embeds.size(0))
# Boolean-mask indexing returns a copy, so chained slice assignment would be a silent no-op;
# write the truncated features back through explicit indices instead.
selected_indices = torch.nonzero(selected, as_tuple=True)[0][:n_token]
input_embeds[selected_indices] = input_embeds[selected_indices] * 0.0 + vit_embeds[:n_token]
input_embeds = input_embeds.reshape(B, N, C)
else:
# Original simple processing for vision-only inputs
input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
input_ids_flat = input_ids.reshape(B * N)
# Process speech input if provided
if speech is not None and hasattr(self, 'speech_context_token_id') and self.speech_context_token_id is not None:
speech_features = self.encode_speech(speech, speech_lengths, speech_wav)
if speech_features is not None:
speech_selected = (input_ids_flat == self.speech_context_token_id)
if speech_selected.sum() > 0:
try:
input_embeds[speech_selected] = input_embeds[speech_selected] * 0.0 + speech_features.reshape(-1, C)[:speech_selected.sum()]
except Exception as e:
print(f'warning: {e}, speech processing fallback')
n_token = min(speech_selected.sum(), speech_features.size(0))
# Boolean-mask indexing returns a copy, so chained slice assignment would be a silent no-op;
# write the truncated features back through explicit indices instead.
speech_indices = torch.nonzero(speech_selected, as_tuple=True)[0][:n_token]
input_embeds[speech_indices] = input_embeds[speech_indices] * 0.0 + speech_features.reshape(-1, C)[:n_token]
# Process vision input if provided
if pixel_values is not None:
image_flags = image_flags.squeeze(-1)
vit_embeds = self.extract_feature(pixel_values)
vit_embeds = vit_embeds[image_flags == 1]
selected = (input_ids_flat == self.img_context_token_id)
try:
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
except Exception as e:
vit_embeds = vit_embeds.reshape(-1, C)
print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
f'vit_embeds.shape={vit_embeds.shape}')
n_token = min(selected.sum(), vit_embeds.size(0))
# Boolean-mask indexing returns a copy, so chained slice assignment would be a silent no-op;
# write the truncated features back through explicit indices instead.
selected_indices = torch.nonzero(selected, as_tuple=True)[0][:n_token]
input_embeds[selected_indices] = input_embeds[selected_indices] * 0.0 + vit_embeds[:n_token]
input_embeds = input_embeds.reshape(B, N, C)
outputs = self.language_model(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = outputs.logits
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
int(c / (scale_factor * scale_factor)))
if self.ps_version == 'v1':
warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
'which results in a transposed image.')
else:
x = x.permute(0, 2, 1, 3).contiguous()
return x
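# Shape sketch for pixel_shuffle with the default scale_factor=0.5 (illustrative sizes):
#   (N, 32, 32, 1024) -> (N, 16, 16, 4096), i.e. 4x fewer spatial tokens with 4x wider channels.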
def extract_feature(self, pixel_values):
if self.select_layer == -1:
vit_embeds = self.vision_model(
pixel_values=pixel_values,
output_hidden_states=False,
return_dict=True).last_hidden_state
else:
vit_embeds = self.vision_model(
pixel_values=pixel_values,
output_hidden_states=True,
return_dict=True).hidden_states[self.select_layer]
vit_embeds = vit_embeds[:, 1:, :]
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
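# extract_feature pipeline: drop the CLS token, reshape the patch tokens to an h x w grid,
# pixel-shuffle by downsample_ratio, then project with mlp1 so each image tile yields exactly
# num_image_token embeddings of LLM hidden size.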
def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None,
speech=None, speech_lengths=None, speech_wav=None, SPEECH_CONTEXT_TOKEN='<SPEECH_CONTEXT>'):
if history is not None or return_history:
print('Multi-turn chat is not yet supported in batch_chat.')
raise NotImplementedError
if image_counts is not None:
num_patches_list = image_counts
print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
if num_patches_list is None and pixel_values is None:
# Speech-only batches may omit image patch counts; default to zero patches per question.
num_patches_list = [0] * len(questions)
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
self.img_context_token_id = img_context_token_id
# Set up speech context token
if speech is not None:
speech_context_token_id = tokenizer.convert_tokens_to_ids(SPEECH_CONTEXT_TOKEN)
self.speech_context_token_id = speech_context_token_id
if verbose and pixel_values is not None:
image_bs = pixel_values.shape[0]
print(f'dynamic ViT batch size: {image_bs}')
if verbose and speech is not None:
speech_bs = speech.shape[0]
print(f'speech batch size: {speech_bs}')
queries = []
for idx, num_patches in enumerate(num_patches_list):
question = questions[idx]
if pixel_values is not None and '<image>' not in question:
question = '<image>\n' + question
if speech is not None and '<speech>' not in question:
question = '<speech>\n' + question
template = get_conv_template(self.template)
template.system_message = self.system_message
template.append_message(template.roles[0], question)
template.append_message(template.roles[1], None)
query = template.get_prompt()
# Replace image tokens
if pixel_values is not None:
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
query = query.replace('<image>', image_tokens, 1)
queries.append(query)
tokenizer.padding_side = 'left'
# Use OLA-style tokenization for speech inputs
if speech is not None:
input_ids = []
for idx, query in enumerate(queries):
if '<speech>' in query:
# Use OLA-style tokenization directly
tokens = tokenizer_speech_token(query, tokenizer, return_tensors='pt')
else:
# Replace speech tokens with context tokens for non-speech queries
speech_len = speech_lengths[idx] if speech_lengths is not None else speech.shape[1]
num_downsampled_frames = speech_len // self.config.audio_config.speech_encoder_ds_rate
num_speech_tokens = num_downsampled_frames + 3
speech_tokens = SPEECH_CONTEXT_TOKEN * num_speech_tokens
processed_query = query.replace('<speech>', speech_tokens, 1)
tokens = tokenizer(processed_query, return_tensors='pt').input_ids.squeeze(0)
input_ids.append(tokens)
# Pad sequences
max_len = max(len(ids) for ids in input_ids)
padded_input_ids = []
attention_mask = []
for ids in input_ids:
pad_len = max_len - len(ids)
if pad_len > 0:
padded_ids = torch.cat([torch.full((pad_len,), tokenizer.pad_token_id, dtype=ids.dtype), ids])
mask = torch.cat([torch.zeros(pad_len, dtype=torch.bool), torch.ones(len(ids), dtype=torch.bool)])
else:
padded_ids = ids
mask = torch.ones(len(ids), dtype=torch.bool)
padded_input_ids.append(padded_ids)
attention_mask.append(mask)
input_ids = torch.stack(padded_input_ids).to(self.device)
attention_mask = torch.stack(attention_mask).to(self.device)
else:
# Replace speech tokens with context tokens for non-OLA processing
processed_queries = []
for idx, query in enumerate(queries):
if speech is not None and '<speech>' in query:
speech_len = speech_lengths[idx] if speech_lengths is not None else speech.shape[1]
num_downsampled_frames = speech_len // self.config.audio_config.speech_encoder_ds_rate
num_speech_tokens = num_downsampled_frames + 3
speech_tokens = SPEECH_CONTEXT_TOKEN * num_speech_tokens
query = query.replace('<speech>', speech_tokens, 1)
processed_queries.append(query)
model_inputs = tokenizer(processed_queries, return_tensors='pt', padding=True)
input_ids = model_inputs['input_ids'].to(self.device)
attention_mask = model_inputs['attention_mask'].to(self.device)
eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
generation_config['eos_token_id'] = eos_token_id
generation_output = self.generate(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
speech=speech,
speech_lengths=speech_lengths,
speech_chunks=None,
speech_wav=speech_wav if speech_wav is not None else speech, # Use speech_wav if provided, otherwise fall back to speech
modalities=["image"],
**generation_config
)
responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
responses = [response.split(template.sep.strip())[0].strip() for response in responses]
return responses
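# Usage sketch for batch_chat (hypothetical, preprocessed tensors; pixel_values pre-tiled per question):
#   responses = model.batch_chat(tokenizer, pixel_values, questions,
#                                generation_config=dict(max_new_tokens=256, do_sample=False),
#                                num_patches_list=num_patches_list)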
def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
verbose=False, speech=None, speech_lengths=None, speech_wav=None, SPEECH_CONTEXT_TOKEN='<SPEECH_CONTEXT>'):
if history is None and pixel_values is not None and '<image>' not in question:
question = '<image>\n' + question
if history is None and speech is not None and '<speech>' not in question:
question = '<speech>\n' + question
if num_patches_list is None:
num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
self.img_context_token_id = img_context_token_id
# Set up speech context token
if speech is not None:
speech_context_token_id = tokenizer.convert_tokens_to_ids(SPEECH_CONTEXT_TOKEN)
self.speech_context_token_id = speech_context_token_id
template = get_conv_template(self.template)
template.system_message = self.system_message
eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
history = [] if history is None else history
for (old_question, old_answer) in history:
template.append_message(template.roles[0], old_question)
template.append_message(template.roles[1], old_answer)
template.append_message(template.roles[0], question)
template.append_message(template.roles[1], None)
query = template.get_prompt()
if verbose and pixel_values is not None:
image_bs = pixel_values.shape[0]
print(f'dynamic ViT batch size: {image_bs}')
if verbose and speech is not None:
speech_bs = speech.shape[0]
print(f'speech batch size: {speech_bs}')
# Replace image tokens
for num_patches in num_patches_list:
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
query = query.replace('<image>', image_tokens, 1)
# Use OLA-style tokenization for speech inputs
if speech is not None and '<speech>' in query:
# Use OLA-style tokenization directly with <speech> tokens
input_ids = tokenizer_speech_token(query, tokenizer, return_tensors='pt').unsqueeze(0).to(self.device)
attention_mask = torch.ones_like(input_ids, dtype=torch.bool).to(self.device)
else:
# Replace speech tokens with context tokens for non-OLA processing
if speech is not None:
speech_len = speech_lengths[0] if speech_lengths is not None else speech.shape[1]
# Account for downsampling and special tokens (begin, end, newline)
num_downsampled_frames = speech_len // self.config.audio_config.speech_encoder_ds_rate
# Add 3 for begin, end, and newline tokens
num_speech_tokens = num_downsampled_frames + 3
speech_tokens = SPEECH_CONTEXT_TOKEN * num_speech_tokens
query = query.replace('<speech>', speech_tokens, 1)
model_inputs = tokenizer(query, return_tensors='pt')
input_ids = model_inputs['input_ids'].to(self.device)
attention_mask = model_inputs['attention_mask'].to(self.device)
generation_config['eos_token_id'] = eos_token_id
generation_output = self.generate(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
speech=speech,
speech_lengths=speech_lengths,
speech_chunks=None,
speech_wav=speech_wav if speech_wav is not None else speech, # Use speech_wav if provided, otherwise fall back to speech
modalities=["image"],
**generation_config
)
response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
response = response.split(template.sep.strip())[0].strip()
history.append((question, response))
if return_history:
return response, history
else:
query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
if verbose:
print(query_to_print, response)
return response
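# Usage sketch for chat with speech input (hypothetical, preprocessed tensors):
#   response = model.chat(tokenizer, pixel_values=None, question='What is being said?',
#                         generation_config=dict(max_new_tokens=128),
#                         speech=speech_feats, speech_lengths=speech_lens, speech_wav=raw_wav)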
@torch.no_grad()
def generate(
self,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
visual_features: Optional[torch.FloatTensor] = None,
generation_config: Optional[GenerationConfig] = None,
output_hidden_states: Optional[bool] = None,
speech: Optional[torch.FloatTensor] = None,
speech_lengths: Optional[torch.LongTensor] = None,
speech_chunks: Optional[torch.LongTensor] = None,
speech_wav: Optional[torch.FloatTensor] = None,
modalities: Optional[List[str]] = ["image"],
**generate_kwargs,
) -> torch.LongTensor:
# Use Ola-style multimodal processing if speech or complex multimodal input is provided
if speech is not None or (pixel_values is not None and speech_chunks is not None):
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels
) = self.prepare_inputs_labels_for_speech_vision_text(
input_ids,
None, # position_ids
attention_mask,
None, # past_key_values
None, # labels
speech,
speech_lengths,
speech_chunks,
speech_wav,
pixel_values,
modalities,
image_sizes=None,
image_flags=None
)
if inputs_embeds is not None:
input_embeds = inputs_embeds
else:
# Fallback to simple processing
input_embeds = self.language_model.get_input_embeddings()(input_ids)
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
input_ids_flat = input_ids.reshape(B * N)
# Process speech input if provided
if speech is not None and hasattr(self, 'speech_context_token_id') and self.speech_context_token_id is not None:
speech_features = self.encode_speech(speech, speech_lengths, speech_wav)
if speech_features is not None:
speech_selected = (input_ids_flat == self.speech_context_token_id)
if speech_selected.sum() > 0:
input_embeds[speech_selected] = speech_features.reshape(-1, C)[:speech_selected.sum()].to(input_embeds.device)
# Process vision input if provided
if pixel_values is not None:
assert self.img_context_token_id is not None
if visual_features is not None:
vit_embeds = visual_features
else:
vit_embeds = self.extract_feature(pixel_values)
selected = (input_ids_flat == self.img_context_token_id)
assert selected.sum() != 0
input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
input_embeds = input_embeds.reshape(B, N, C)
else:
# Original simple processing for vision-only inputs
input_embeds = self.language_model.get_input_embeddings()(input_ids)
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
input_ids_flat = input_ids.reshape(B * N)
# Process vision input if provided
if pixel_values is not None:
assert self.img_context_token_id is not None
if visual_features is not None:
vit_embeds = visual_features
else:
vit_embeds = self.extract_feature(pixel_values)
selected = (input_ids_flat == self.img_context_token_id)
assert selected.sum() != 0
input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
input_embeds = input_embeds.reshape(B, N, C)
outputs = self.language_model.generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
generation_config=generation_config,
output_hidden_states=output_hidden_states,
use_cache=True,
**generate_kwargs,
)
return outputs
@property
def lm_head(self):
return self.language_model.get_output_embeddings()
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
return self.language_model.set_input_embeddings(value)
def set_output_embeddings(self, value):
return self.language_model.set_output_embeddings(value)