| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from typing import List, Optional, Tuple, Union |
| |
|
| | from transformers import (AutoConfig, AutoModelForCausalLM, |
| | OlmoConfig, OlmoModel, OlmoForCausalLM) |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from transformers.generation.utils import GenerateOutput |
| | from abc import ABC, abstractmethod |
| |
|
| | import re |
| | import os |
| | import math |
| | import random |
| | import shutil |
| | from .mm_utils import get_anyres_image_grid_shape, rank0_print |
| |
|
| | from .mm_utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN |
| |
|
| | import torch |
| | from einops import rearrange, repeat |
| |
|
| | try: |
| | from einops_exts import rearrange_many |
| | except: |
| | pass |
| |
|
| |
|
| |
|
| | from torch import einsum |
| |
|
| | from torch import Tensor, device |
| | import torch.utils.checkpoint |
| | from torch.nn import CrossEntropyLoss |
| |
|
| | from transformers.activations import ACT2FN |
| | from transformers.modeling_outputs import ( |
| | BaseModelOutputWithPastAndCrossAttentions, |
| | BaseModelOutputWithPoolingAndCrossAttentions, |
| | CausalLMOutputWithCrossAttentions, |
| | MaskedLMOutput, |
| | ) |
| | from transformers.modeling_utils import ( |
| | PreTrainedModel, |
| | apply_chunking_to_forward, |
| | find_pruneable_heads_and_indices, |
| | prune_linear_layer, |
| | ) |
| | from transformers.utils import logging |
| | from transformers.models.bert.configuration_bert import BertConfig |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| | |
class PoolerProjector(nn.Module):
    """Map vision patch tokens into the LM hidden space via 2x2 strided-conv pooling.

    The token sequence is laid back onto the square patch grid whose side length is
    ``vision_cfg.image_size // vision_cfg.patch_size``. A kernel-2 / stride-2 convolution
    then halves each spatial side while mapping ``mm_hidden_size`` to ``hidden_size``.
    The result is flattened back into tokens and passed through GELU + Linear.
    """

    def __init__(self, config, vision_cfg):
        super().__init__()
        self._config = config
        # Side length (in patches) of the square patch grid.
        self.hw = vision_cfg.image_size // vision_cfg.patch_size

        self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2)
        self.proj = nn.Sequential(
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )

    def forward(self, x, *args, **kwargs):
        """Pool ``x`` of shape (batch, hw*hw, mm_hidden_size); extra arguments are ignored."""
        side = self.hw
        assert side * side == x.shape[1]
        batch = x.shape[0]
        # (B, N, C) -> (B, side, side, C) -> (B, C, side, side) for Conv2d.
        grid = x.view(batch, side, side, -1).permute(0, 3, 1, 2)
        pooled = self.conv_pool(grid)
        # (B, hidden, h', w') -> (B, h'*w', hidden)
        tokens = pooled.flatten(2).transpose(1, 2)
        return self.proj(tokens)

    @property
    def config(self):
        return dict(mm_projector_type="pooler")
| |
|
class IdentityMap(nn.Module):
    """Projector that returns its first input untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        # Extra arguments exist only so this matches the other projectors' call signature.
        del args, kwargs
        return x

    @property
    def config(self):
        return dict(mm_projector_type="identity")
| |
|
| |
|
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual Linear -> GELU -> Linear MLP.

    The residual branch is added to the *normalised* input, not to the raw input.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        normed = self.pre_norm(x)
        update = self.proj(normed)
        return normed + update
| |
|
| |
|
def build_vision_projector(config, delay_load=False, **kwargs):
    """Create the module mapping vision features (``mm_hidden_size``) to LM features (``hidden_size``).

    ``config.mm_projector_type`` (default ``"linear"``) selects the architecture:

    * ``"linear"`` -- one ``nn.Linear``.
    * ``"pooler"`` -- :class:`PoolerProjector`; ``vision_cfg`` must be passed as a keyword.
    * ``"mlp{N}x_gelu"`` -- N Linear layers separated by GELU.
    * ``"mlp{N}x_res{M}x_gelu"`` -- the same MLP followed by M :class:`SimpleResBlock`.
    * ``"identity"`` -- :class:`IdentityMap`.

    ``delay_load`` is accepted but not used here.

    Raises:
        ValueError: if the projector type is not recognised.
    """
    kind = getattr(config, "mm_projector_type", "linear")

    def _mlp_layers(depth):
        # First Linear changes width; every following Linear is preceded by a GELU.
        layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(depth - 1):
            layers += [nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size)]
        return layers

    if kind == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    if kind == "pooler":
        return PoolerProjector(config, kwargs["vision_cfg"])

    plain = re.match(r"^mlp(\d+)x_gelu$", kind)
    if plain is not None:
        return nn.Sequential(*_mlp_layers(int(plain.group(1))))

    residual = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", kind)
    if residual is not None:
        mlp_depth, res_depth = (int(g) for g in residual.groups())
        layers = _mlp_layers(mlp_depth)
        layers.extend(SimpleResBlock(config.hidden_size) for _ in range(res_depth))
        return nn.Sequential(*layers)

    if kind == "identity":
        return IdentityMap()

    raise ValueError(f"Unknown projector type: {kind}")
| |
|
| | |
class SpatialPool(nn.Module):
    """Spatially down-sample a sequence of vision patch features.

    Features are reshaped back onto their 2-D patch grid, pooled with a stride of
    ``model_args.mm_spatial_pool_stride`` (average, max, or a learned strided conv),
    and flattened back into a token sequence.
    """

    def __init__(self, model_args, vision_tower):
        super().__init__()

        self.mode = model_args.mm_spatial_pool_mode
        self.stride = model_args.mm_spatial_pool_stride
        self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size)

        if self.mode == "average":
            self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
        elif self.mode == "max":
            self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
        elif self.mode == "conv":
            self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride)
        else:
            # Report the offending *mode*: `self.pool` has not been assigned on this path.
            raise ValueError(f"Unknown pooling mode: {self.mode}.")

    def forward(self, image_features, images, *args, **kwargs):
        """Pool ``image_features`` (batch, tokens, channels) over its patch grid.

        Only the ratio ``images.shape[2] / images.shape[3]`` is used. It is presumably the
        (height, width) of NCHW images; confirm with callers. Returns a
        (batch, pooled_tokens, channels) contiguous tensor.
        """
        # tokens = H * W and H / W = img_h / img_w  =>  W = sqrt(tokens * img_w / img_h).
        ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))
        ori_H = int(ori_W * images.shape[2] // images.shape[3])

        B, _, F = image_features.shape

        # Row-major grid of H rows by W columns (previously used H for both sides, which
        # broke every non-square input).
        image_features_spatial = image_features.view(B, ori_H, ori_W, F).permute(0, 3, 1, 2)
        image_features_spatial_pool = self.pool(image_features_spatial)

        return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()

    @property
    def config(self):
        return {
            "mm_resampler_type": "spatial_pool",
            "mm_spatial_pool_stride": self.stride,
            "mm_spatial_pool_mode": self.mode,
            "mm_spatial_pool_out_channels": self.out_channels,
        }

    @property
    def hidden_size(self):
        return self.out_channels
| |
|
def disabled_train(self, mode=True):
    """No-op stand-in for ``model.train``.

    Overwriting a model's ``train`` with this function freezes its current train/eval
    mode: the requested ``mode`` is ignored and the object is returned unchanged.
    """
    del mode  # deliberately ignored
    return self
| |
|
| | |
class BertEmbeddings(nn.Module):
    """Word + (absolute) position embeddings, optionally prefixed by query embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # Attribute name `LayerNorm` kept so parameter names match existing checkpoints.
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Buffer of shape (1, max_position_embeddings): positions 0..max-1.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Embed ``input_ids`` and/or ``query_embeds``.

        When both are given, ``query_embeds`` is concatenated in front of the text
        embeddings along dim 1. When only ``query_embeds`` is given, it is used as is.
        Position ids default to the slice starting at ``past_key_values_length``.
        """
        seq_length = 0 if input_ids is None else input_ids.size()[1]

        if position_ids is None:
            start = past_key_values_length
            position_ids = self.position_ids[:, start : start + seq_length].clone()

        if input_ids is None:
            embeddings = query_embeds
        else:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                embeddings = embeddings + self.position_embeddings(position_ids)
            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)

        embeddings = self.LayerNorm(embeddings)
        return self.dropout(embeddings)
| |
|
| |
|
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention used by the Q-Former BERT layers.

    Acts as self-attention over ``hidden_states``. When ``encoder_hidden_states`` is passed
    to :meth:`forward`, it acts as cross-attention with keys/values projected from
    ``config.encoder_width``; the module must then be built with ``is_cross_attention=True``.
    A ``position_embedding_type`` of ``"relative_key"`` or ``"relative_key_query"`` adds
    learned relative-distance terms to the attention scores.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads))

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Keys/values come from the encoder, whose feature width may differ.
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per signed distance in [-(max - 1), max - 1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        # When True, cross-attention probabilities and their gradients are stored for inspection.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        """Reshape (batch, seq, all_head_size) into (batch, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Attend from ``hidden_states`` to itself (optionally with a KV cache) or to the encoder.

        ``attention_mask`` / ``encoder_attention_mask`` are added to the raw scores, so
        they are expected in additive form (0 to keep, large negative to drop).

        Returns ``(context, present_key_value)``, or ``(context, attention_probs,
        present_key_value)`` when ``output_attentions`` is set.
        """
        # Cross-attention is selected by the presence of encoder states; in that case
        # keys/values come from the encoder and the encoder mask replaces attention_mask.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Append the new keys/values to the cached ones along the sequence axis.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Raw scores: (batch, heads, query_len, key_len).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            # Queries and keys may differ in length (cross-attention, or a self-attention
            # cache). With a cache the new queries sit at the end of the key sequence, so
            # their absolute positions are offset by the cached length. Cross-attention
            # queries and keys belong to different sequences and both start at 0.
            query_length = query_layer.shape[2]
            key_length = key_layer.shape[2]
            offset = 0 if is_cross_attention else key_length - query_length
            position_ids_l = torch.arange(offset, offset + query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # match query dtype (e.g. fp16)

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Additive mask.
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # Dropout drops whole attended-to tokens; the undropped probs are what gets returned.
        attention_probs_dropped = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs
| |
|
| |
|
class BertSelfOutput(nn.Module):
    """Post-attention block: Linear -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states`` and normalise it together with the residual ``input_tensor``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
| |
|
| |
|
class BertAttention(nn.Module):
    """Attention sub-layer: :class:`BertSelfAttention` followed by :class:`BertSelfOutput`."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        # Head indices removed so far (accumulated across prune_heads calls).
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the Q/K/V and output projections."""
        if len(heads) == 0:
            return
        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads,
            attn.num_attention_heads,
            attn.attention_head_size,
            self.pruned_heads,
        )

        # Q/K/V lose output rows; the output projection loses the matching input columns.
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping consistent with the new layer sizes.
        attn.num_attention_heads -= len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention and the output projection.

        Returns ``(attention_output, *rest)``, where ``rest`` is everything after the
        context tensor in the inner attention's result.
        """
        context, *rest = self.self(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
        )
        return (self.output(context, hidden_states), *rest)
| |
|
| |
|
class BertIntermediate(nn.Module):
    """Feed-forward expansion: Linear(hidden -> intermediate) followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        # A string names an activation in ACT2FN; anything else is used as the callable itself.
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
| |
|
| |
|
class BertOutput(nn.Module):
    """Feed-forward contraction: Linear(intermediate -> hidden) -> dropout -> residual -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Contract ``hidden_states`` and normalise it together with the residual ``input_tensor``."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
| |
|
| |
|
class BertLayer(nn.Module):
    """One Q-Former transformer layer.

    Self-attention runs over the whole sequence. When ``query_length > 0``, the first
    ``query_length`` positions are treated as query tokens. They additionally
    cross-attend to ``encoder_hidden_states`` on layers built with cross-attention
    (``config.add_cross_attention`` and ``layer_num % config.cross_attention_freq == 0``).
    They also use their own feed-forward weights (``intermediate_query`` /
    ``output_query``); any remaining text positions use the regular feed-forward.

    Returns ``(layer_output, *attention_probs, present_key_value)``.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1  # axis along which feed-forward chunking splits the input
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        self.has_cross_attention = bool(self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0)
        if self.has_cross_attention:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # Separate feed-forward weights for the query-token positions.
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        def run_ffn(chunk_fn, tensor):
            return apply_chunking_to_forward(chunk_fn, self.chunk_size_feed_forward, self.seq_len_dim, tensor)

        # Only the first two cache entries (self-attention key/value) are used here.
        cached_self_kv = None if past_key_value is None else past_key_value[:2]
        self_results = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
        )
        attention_output = self_results[0]
        extra_outputs = self_results[1:-1]  # attention probs, if requested
        present_key_value = self_results[-1]

        if query_length <= 0:
            layer_output = run_ffn(self.feed_forward_chunk, attention_output)
        else:
            query_part = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
                cross_results = self.crossattention(
                    query_part,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_part = cross_results[0]
                extra_outputs = extra_outputs + cross_results[1:-1]

            layer_output = run_ffn(self.feed_forward_chunk_query, query_part)
            if attention_output.shape[1] > query_length:
                # Text tokens following the queries go through the regular feed-forward.
                text_output = run_ffn(self.feed_forward_chunk, attention_output[:, query_length:, :])
                layer_output = torch.cat([layer_output, text_output], dim=1)

        return (layer_output,) + extra_outputs + (present_key_value,)

    def feed_forward_chunk(self, attention_output):
        """Regular feed-forward applied to text positions."""
        return self.output(self.intermediate(attention_output), attention_output)

    def feed_forward_chunk_query(self, attention_output):
        """Feed-forward applied to query-token positions."""
        return self.output_query(self.intermediate_query(attention_output), attention_output)
| |
|
| |
|
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` :class:`BertLayer` modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run every layer in order and collect the optional per-layer outputs.

        With ``return_dict=False`` this returns a tuple of the non-None entries among
        (last hidden state, next cache, all hidden states, self-attentions,
        cross-attentions). Otherwise it returns a
        ``BaseModelOutputWithPastAndCrossAttentions``.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                    use_cache = False

                def create_custom_forward(module, past_key_value=past_key_value):
                    # Bind this layer's cache entry now: checkpoint re-invokes custom_forward
                    # during backward, after the loop variable has moved on.
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            # Layer result: (output, self_probs?, cross_probs?, present_key_value).
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # Cross-attention probs exist only on layers that actually ran cross-attention
                # (built with it and given query tokens). Elsewhere index 2 is the KV cache.
                ran_cross_attention = getattr(layer_module, "has_cross_attention", False) and query_length > 0
                if all_cross_attentions is not None and ran_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
| |
|
| |
|
class BertPooler(nn.Module):
    """Pool a sequence by projecting its first token through Linear + tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Index 0 along the sequence axis: the first token's hidden state.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
| |
|
| |
|
class BertPredictionHeadTransform(nn.Module):
    """Linear -> activation -> LayerNorm transform applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        act = config.hidden_act
        # A string is looked up in ACT2FN; otherwise the object is used as the activation.
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
| |
|
| |
|
class BertLMPredictionHead(nn.Module):
    """Masked-LM head: hidden transform followed by a projection to vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The projection is created without a bias. A standalone zero-initialised bias
        # parameter is then attached to it, so `self.bias` and `self.decoder.bias` are the
        # same tensor.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        transformed = self.transform(hidden_states)
        return self.decoder(transformed)
| |
|
| |
|
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing a :class:`BertLMPredictionHead` as ``predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return vocabulary logits for ``sequence_output``."""
        return self.predictions(sequence_output)
| |
|
| |
|
class BertPreTrainedModel(PreTrainedModel):
    """
    Abstract base tying BERT's config class and weight-initialisation scheme into
    ``PreTrainedModel``, which provides the interface for downloading and loading
    pretrained models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    # Presumably suppresses missing-key reports for the `position_ids` buffer when loading.
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place.

        LayerNorm: weight 1, bias 0. Linear / Embedding weights: normal with std
        ``config.initializer_range``. Linear biases (when present): 0.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
| |
|
| |
|
| | class BertModel(BertPreTrainedModel): |
| | """ |
| | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of |
| | cross-attention is added between the self-attention layers, following the architecture described in `Attention is |
| | all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, |
| | Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. |
| | argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an |
| | input to the forward pass. |
| | """ |
| |
|
| | def __init__(self, config, add_pooling_layer=False): |
| | super().__init__(config) |
| | self.config = config |
| |
|
| | self.embeddings = BertEmbeddings(config) |
| |
|
| | self.encoder = BertEncoder(config) |
| |
|
| | self.pooler = BertPooler(config) if add_pooling_layer else None |
| |
|
| | self.init_weights() |
| |
|
| | def get_input_embeddings(self): |
| | return self.embeddings.word_embeddings |
| |
|
| | def set_input_embeddings(self, value): |
| | self.embeddings.word_embeddings = value |
| |
|
| | def _prune_heads(self, heads_to_prune): |
| | """ |
| | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base |
| | class PreTrainedModel |
| | """ |
| | for layer, heads in heads_to_prune.items(): |
| | self.encoder.layer[layer].attention.prune_heads(heads) |
| |
|
| | def get_extended_attention_mask( |
| | self, |
| | attention_mask: Tensor, |
| | input_shape: Tuple[int], |
| | device: device, |
| | is_decoder: bool, |
| | has_query: bool = False, |
| | ) -> Tensor: |
| | """ |
| | Makes broadcastable attention and causal masks so that future and masked tokens are ignored. |
| | |
| | Arguments: |
| | attention_mask (:obj:`torch.Tensor`): |
| | Mask with ones indicating tokens to attend to, zeros for tokens to ignore. |
| | input_shape (:obj:`Tuple[int]`): |
| | The shape of the input to the model. |
| | device: (:obj:`torch.device`): |
| | The device of the input to the model. |
| | |
| | Returns: |
| | :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. |
| | """ |
| | |
| | |
| | if attention_mask.dim() == 3: |
| | extended_attention_mask = attention_mask[:, None, :, :] |
| | elif attention_mask.dim() == 2: |
| | |
| | |
| | |
| | if is_decoder: |
| | batch_size, seq_length = input_shape |
| |
|
| | seq_ids = torch.arange(seq_length, device=device) |
| | causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] |
| |
|
| | |
| | |
| | causal_mask = causal_mask.to(attention_mask.dtype) |
| |
|
| | if causal_mask.shape[1] < attention_mask.shape[1]: |
| | prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] |
| | if has_query: |
| | causal_mask = torch.cat( |
| | [ |
| | torch.zeros( |
| | (batch_size, prefix_seq_len, seq_length), |
| | device=device, |
| | dtype=causal_mask.dtype, |
| | ), |
| | causal_mask, |
| | ], |
| | axis=1, |
| | ) |
| | causal_mask = torch.cat( |
| | [ |
| | torch.ones( |
| | (batch_size, causal_mask.shape[1], prefix_seq_len), |
| | device=device, |
| | dtype=causal_mask.dtype, |
| | ), |
| | causal_mask, |
| | ], |
| | axis=-1, |
| | ) |
| | extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] |
| | else: |
| | extended_attention_mask = attention_mask[:, None, None, :] |
| | else: |
| | raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape)) |
| |
|
| | |
| | |
| | |
| | |
| | |
| | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) |
| | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 |
| | return extended_attention_mask |
| |
|
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    head_mask=None,
    query_embeds=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    is_decoder=False,
):
    r"""
    Encode text tokens and/or learned query embeddings with the BERT encoder.

    input_ids (:obj:`torch.LongTensor`, `optional`):
        Text token ids. May be omitted when ``query_embeds`` is given.
    query_embeds (:obj:`torch.FloatTensor`, `optional`):
        Query embeddings passed to ``self.embeddings``; ``query_embeds.shape[1]`` is
        forwarded to the encoder as ``query_length``.
    encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
        Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
        the model is configured as a decoder. A list of such tensors is also accepted; its first element
        determines the shape of the default all-ones encoder mask.
    encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
        Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
        the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.
        A list of masks is inverted element-wise.
    past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
        Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
        If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
        (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
        instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
    use_cache (:obj:`bool`, `optional`):
        If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
        decoding (see :obj:`past_key_values`).
    is_decoder (:obj:`bool`):
        When True the self-attention mask is built in decoder mode from ``input_ids.shape``
        (so ``input_ids`` must be provided in that mode).

    Returns:
        ``BaseModelOutputWithPoolingAndCrossAttentions`` when ``return_dict`` is true, otherwise the tuple
        ``(sequence_output, pooled_output) + encoder_outputs[1:]``.

    Raises:
        ValueError: if both ``input_ids`` and ``query_embeds`` are None.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Explicit check instead of `assert`, which is removed when Python runs with -O.
    if input_ids is None and query_embeds is None:
        raise ValueError("You have to specify query_embeds when input_ids is None")

    # Cached length (dim 2 of the first layer's first cache tensor) minus the
    # configured query length is the past length handed to the embeddings.
    past_key_values_length = (
        past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
    )

    query_length = query_embeds.shape[1] if query_embeds is not None else 0

    embedding_output = self.embeddings(
        input_ids=input_ids,
        position_ids=position_ids,
        query_embeds=query_embeds,
        past_key_values_length=past_key_values_length,
    )

    input_shape = embedding_output.size()[:-1]
    batch_size, seq_length = input_shape
    device = embedding_output.device

    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

    # Broadcastable self-attention mask. In decoder mode the shape comes from the
    # text ids only; `has_query` tells the helper whether query tokens are present.
    # NOTE(review): this path dereferences input_ids, so is_decoder=True with
    # input_ids=None fails with AttributeError.
    if is_decoder:
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask,
            input_ids.shape,
            device,
            is_decoder,
            has_query=(query_embeds is not None),
        )
    else:
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder)

    # Cross-attention mask: one inverted mask, or a list of them when several
    # encoder masks are supplied.
    if encoder_hidden_states is not None:
        first_states = encoder_hidden_states[0] if isinstance(encoder_hidden_states, list) else encoder_hidden_states
        encoder_batch_size, encoder_sequence_length, _ = first_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

        if isinstance(encoder_attention_mask, list):
            encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
        else:
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
    else:
        encoder_extended_attention_mask = None

    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    encoder_outputs = self.encoder(
        embedding_output,
        attention_mask=extended_attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_extended_attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        query_length=query_length,
    )
    sequence_output = encoder_outputs[0]
    pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

    if not return_dict:
        return (sequence_output, pooled_output) + encoder_outputs[1:]

    return BaseModelOutputWithPoolingAndCrossAttentions(
        last_hidden_state=sequence_output,
        pooler_output=pooled_output,
        past_key_values=encoder_outputs.past_key_values,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
        cross_attentions=encoder_outputs.cross_attentions,
    )
| |
|
| |
|
class BertLMHeadModel(BertPreTrainedModel):
    """BERT encoder (no pooler) plus an MLM-style prediction head, used as a left-to-right LM.

    Optional query embeddings are prepended by the underlying ``BertModel``; their positions
    are stripped from the sequence output before scoring, so logits cover text positions only.
    """

    # Regex patterns used by transformers' checkpoint loading to silence
    # unexpected / missing-key warnings for these parameters.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        # The vocabulary projection layer of the prediction head.
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
            Passing labels forces ``use_cache=False``.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
            When given, ``query_embeds`` is ignored (set to None before the encoder call).
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        query_embeds (:obj:`torch.FloatTensor`, `optional`):
            Query embeddings forwarded to ``BertModel``; the first ``query_embeds.shape[1]`` output
            positions are dropped before the prediction head.
        return_logits (:obj:`bool`):
            If True, return only ``prediction_scores[:, :-1, :]`` (logits for every text position except
            the last) and skip the loss computation.
        is_decoder (:obj:`bool`):
            Forwarded to ``BertModel``; selects decoder-mode self-attention masking.
        reduction (:obj:`str`):
            Passed to ``CrossEntropyLoss`` (with ``label_smoothing=0.1``). With ``"none"`` the per-token
            losses are reshaped to ``(batch_size, -1)`` and summed per sample.
        Returns:
            ``CausalLMOutputWithCrossAttentions`` when ``return_dict`` is true, otherwise a tuple
            ``(lm_loss?, prediction_scores, *outputs[2:])`` (the loss is included only when labels are given).
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False
        if past_key_values is not None:
            # Cached decoding step: the query prefix is not re-embedded.
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            # Strip the query-token positions so only text positions are scored.
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # Next-token prediction: scores at position t are compared with the label at t + 1.
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                # Sum per-token losses into one value per sample.
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            # outputs[1] is the pooled output (None here, since there is no pooler); skip it.
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
        """Build the keyword arguments for one generation step.

        The text attention mask defaults to all ones and is always prefixed with an all-ones mask
        covering the query tokens. When ``past`` is given, only the last input token is kept.

        NOTE(review): newer transformers versions pass the cache as ``past_key_values``; with this
        signature that lands in ``model_kwargs`` and is not used — confirm against the pinned version.
        """
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # With a cache, only the newest token needs to be fed.
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "query_embeds": query_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Reorder every cached tensor along dim 0 by ``beam_idx`` (beam-search bookkeeping)."""
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
| |
|
| |
|
class BertForMaskedLM(BertPreTrainedModel):
    """BERT encoder (no pooler) plus a masked-language-modelling head.

    Optional query embeddings are prepended by the underlying ``BertModel``; their positions are
    removed from the sequence output before scoring.
    """

    # Regex patterns used by transformers' checkpoint loading to silence
    # unexpected / missing-key warnings for these parameters.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        # The vocabulary projection layer of the prediction head.
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        query_embeds (:obj:`torch.FloatTensor`, `optional`):
            Query embeddings forwarded to ``BertModel``; the first ``query_embeds.shape[1]`` output
            positions are dropped before the prediction head. When omitted, every position is scored.
        return_logits (:obj:`bool`):
            If True, return the prediction scores directly and skip the loss.

        Returns:
            ``MaskedLMOutput`` when ``return_dict`` is true, otherwise a tuple
            ``(masked_lm_loss?, prediction_scores, *outputs[2:])``.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        # Always start from the full sequence output so the no-query path has a
        # defined value; then strip the query-token prefix when queries were used.
        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = sequence_output[:, query_embeds.shape[1] :, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # outputs[1] is the pooled output (None here, since there is no pooler); skip it.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|
| |
|
class Qformer(nn.Module):
    """Q-Former vision resampler: learned query tokens cross-attend to vision features.

    Reads from ``model_args``:
        mm_qformer_depth: forwarded to ``build_Qformer`` as ``cross_attention_freq``.
        mm_qformer_latents: number of learned query tokens.
        mm_qformer_pretrained: optional checkpoint path; its ``"model"`` entry is loaded
            (minus ``t5_proj*`` keys) into this module.
    """

    def __init__(self, model_args, vision_tower):
        super().__init__()

        self.depth = model_args.mm_qformer_depth
        self.num_latents = model_args.mm_qformer_latents
        self.pretrained = model_args.mm_qformer_pretrained

        self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents)

        if self.pretrained is not None:
            # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
            pretrained_dict = torch.load(self.pretrained, map_location="cpu")["model"]
            # This module has no `t5_proj` submodule, so those keys are dropped before loading.
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith("t5_proj")}
            self.load_state_dict(pretrained_dict)

    def build_Qformer(self, vision_width, cross_attention_freq, num_query_token):
        """Create the BERT Q-Former, its learned query tokens, and the vision LayerNorm.

        Args:
            vision_width: feature size of the incoming vision features (``encoder_width``).
            cross_attention_freq: stored on the BERT config as ``cross_attention_freq``.
                NOTE(review): the caller passes ``mm_qformer_depth`` here — confirm that is intended.
            num_query_token: number of learned queries (also stored as ``query_length``).

        Returns:
            ``(Qformer, query_tokens, ln_vision)``.
        """
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # Insert cross-attention layers every `cross_attention_freq` blocks (config flag read by the BERT layers).
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        # Only the query path is used: drop the LM head, text embeddings and the
        # per-layer text feed-forward modules.
        # NOTE(review): presumably the layers keep separate query-specific FFNs — confirm in BertLayer.
        Qformer.cls = None
        Qformer.bert.embeddings.word_embeddings = None
        Qformer.bert.embeddings.position_embeddings = None
        for layer in Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        return Qformer, query_tokens, nn.LayerNorm(vision_width)

    def forward(self, image_features, *args, **kwargs):
        """Resample ``image_features`` into one output token per learned query.

        Extra positional/keyword arguments are accepted for interface compatibility and ignored.
        Returns ``last_hidden_state`` of the Q-Former BERT run on the query tokens.
        """
        x = self.ln_vision(image_features)
        # All-ones (long) cross-attention mask, allocated directly on the features' device.
        image_atts = torch.ones(x.size()[:-1], dtype=torch.long, device=x.device)

        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        return query_output.last_hidden_state

    @property
    def hidden_size(self):
        # Hidden size of the bert-base-uncased config used by build_Qformer.
        return 768

    @property
    def config(self):
        return {
            "mm_resampler_type": "qformer",
            "mm_qformer_depth": self.depth,
            "mm_qformer_latents": self.num_latents,
            "mm_qformer_pretrained": self.pretrained,
        }
| |
|
| |
|
| | |
def exists(val):
    """Report whether ``val`` holds a value, i.e. is not ``None`` (falsy values still count)."""
    if val is None:
        return False
    return True
| |
|
| |
|
def FeedForward(dim, mult=4):
    """Pre-norm MLP: LayerNorm -> Linear(dim, int(dim * mult)) -> GELU -> Linear back to dim.

    Both linear layers are bias-free. Returned as an ``nn.Sequential``.
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
| |
|
| |
|
class PerceiverAttention(nn.Module):
    """Multi-head attention from latents to the concatenation of media features and latents.

    Args:
        dim: feature size of both the media tokens and the latents.
        dim_head: per-head projection size.
        heads: number of attention heads.

    Head splitting/merging uses plain torch reshapes, so this module does not depend on the
    optional ``einops_exts`` package.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def _split_heads(self, t):
        """Reshape ``(b, t, n, h*d)`` to ``(b, h, t, n, d)`` (einops ``b t n (h d) -> b h t n d``)."""
        b, T, n, hd = t.shape
        return t.reshape(b, T, n, self.heads, hd // self.heads).permute(0, 3, 1, 2, 4)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latents (torch.Tensor): latent features
                shape (b, T, n2, D)
        Returns:
            torch.Tensor of shape (b, T, n2, D).
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        q = self.to_q(latents)
        # Keys/values are computed over the media tokens AND the latents themselves.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        q, k, v = (self._split_heads(t) for t in (q, k, v))
        q = q * self.scale

        # Scaled dot-product attention; subtracting the row max is a numerical-stability guard.
        sim = torch.matmul(q, k.transpose(-1, -2))
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = torch.matmul(attn, v)  # (b, h, T, n2, d)
        b, h, T, n, d = out.shape
        # Merge heads back: b h t n d -> b t n (h d)
        out = out.permute(0, 2, 3, 1, 4).reshape(b, T, n, h * d)
        return self.to_out(out)
| |
|
| |
|
class PerceiverResamplerModule(nn.Module):
    """Stack of (PerceiverAttention, FeedForward) blocks that compresses media tokens into latents.

    Args:
        dim: feature size of media tokens and latents.
        depth: number of attention/feed-forward blocks.
        dim_head, heads: attention head configuration.
        num_latents: number of learned latent vectors.
        max_num_media: if given, enables learned per-media (T axis) embeddings.
        max_num_frames: if given, enables learned per-frame (F axis) embeddings.
        ff_mult: feed-forward expansion; a value <= 0 replaces the MLP with ``nn.Identity``.
    """

    def __init__(
        self,
        *,
        dim,
        depth=6,
        dim_head=64,
        heads=8,
        num_latents=64,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.frame_embs = None if max_num_frames is None else nn.Parameter(torch.randn(max_num_frames, dim))
        self.media_time_embs = None if max_num_media is None else nn.Parameter(torch.randn(max_num_media, 1, dim))

        blocks = []
        for _ in range(depth):
            attention = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
            mlp = FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity()
            blocks.append(nn.ModuleList([attention, mlp]))
        self.layers = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)
        Returns:
            shape (b, T, n, D) where n is the number of learned latents
        """
        b, T, F, v = x.shape[:4]

        # Optional learned frame / media-time offsets, then flatten frames x tokens.
        if self.frame_embs is not None:
            x = x + repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
        x = rearrange(x, "b T F v d -> b T (F v) d")
        if self.media_time_embs is not None:
            x = x + self.media_time_embs[:T]

        # Residual attention / feed-forward updates of the broadcast latents.
        latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
        for attention, mlp in self.layers:
            latents = latents + attention(x, latents)
            latents = latents + mlp(latents)
        return self.norm(latents)
| |
|
| |
|
class PerceiverResampler(nn.Module):
    """Vision resampler wrapping ``PerceiverResamplerModule`` for single-image features.

    Reads ``mm_perceiver_depth``, ``mm_perceiver_latents``, ``mm_perceiver_ff_mult`` and
    ``mm_perceiver_pretrained`` from ``model_args``; the perceiver width is ``vision_tower.hidden_size``.
    """

    def __init__(self, model_args, vision_tower):
        super().__init__()

        self.depth = model_args.mm_perceiver_depth
        self.num_latents = model_args.mm_perceiver_latents
        self.ff_mult = model_args.mm_perceiver_ff_mult
        self.pretrained = model_args.mm_perceiver_pretrained

        self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult)

        if self.pretrained is not None:
            # Map to CPU so GPU-saved checkpoints also load on CPU-only hosts; load_state_dict
            # then copies into the module's parameters (same as the Qformer resampler).
            # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
            self.load_state_dict(torch.load(self.pretrained, map_location="cpu"))

    def forward(self, image_features, *args, **kwargs):
        """Resample ``image_features`` by treating it as a single media item with a single frame.

        Two singleton axes are inserted after the batch dimension to match the perceiver's
        ``(b, T, F, v, D)`` input; the T axis is squeezed from the result. Extra arguments are ignored.
        """
        return self.perceiver(image_features[:, None, None]).squeeze(1)

    @property
    def config(self):
        return {
            "mm_resampler_type": "perceiver",
            "mm_perceiver_depth": self.depth,
            "mm_perceiver_latents": self.num_latents,
            "mm_perceiver_ff_mult": self.ff_mult,
            "mm_perceiver_pretrained": self.pretrained,
        }
| |
|
| | |
class MaskedDrop(nn.Module):
    """Training-time random token dropping for visual features.

    Modes (``mm_mask_drop_mode``):
        "fixed":    keep ``int(num_tokens * ratio)`` randomly chosen tokens per feature.
        "range":    keep ``int(num_tokens * U(ratio_lower, ratio_upper))`` tokens per feature;
                    each kept tensor retains a leading batch axis of size 1 and the result stays a list.
        "cls_only": keep only the first token of each feature.
    In eval mode, and with probability ``skip_percentage`` during training, features pass through unchanged.
    """

    def __init__(self, model_args):
        super().__init__()

        self.mode = model_args.mm_mask_drop_mode
        self.skip_percentage = model_args.mm_mask_drop_skip_percentage
        self.ratio = model_args.mm_mask_drop_ratio
        self.ratio_upper = model_args.mm_mask_drop_ratio_upper
        self.ratio_lower = model_args.mm_mask_drop_ratio_lower

    def forward(self, image_features, *args, **kwargs):
        # Pass-through at eval time, or on a randomly skipped training call.
        if not self.training or self.skip_percentage > random.random():
            return image_features

        kept = []
        for feature in image_features:
            n_tokens = feature.shape[0]
            if self.mode == "fixed":
                n_keep = int(n_tokens * self.ratio)
                x_masked = self.random_masking(feature.unsqueeze(0), n_keep)[0]
                kept.append(x_masked[0])
            elif self.mode == "range":
                n_keep = int(n_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
                kept.append(self.random_masking(feature.unsqueeze(0), n_keep)[0])
            elif self.mode == "cls_only":
                kept.append(feature[0:1])
            else:
                raise ValueError(f"Unexpected masked drop mode: {self.mode}")

        # Stack back into one tensor unless the mode is "range"; list inputs are only
        # stacked in "cls_only" mode (where every entry has the same length).
        should_stack = self.mode != "range" and (type(image_features) is not list or self.mode == "cls_only")
        if should_stack:
            kept = torch.stack(kept, dim=0)

        return kept

    @property
    def config(self):
        return {
            "mm_resampler_type": "masked_drop",
            "mm_mask_drop_mode": self.mode,
            "mm_mask_drop_skip_percentage": self.skip_percentage,
            "mm_mask_drop_ratio": self.ratio,
            "mm_mask_drop_ratio_upper": self.ratio_upper,
            "mm_mask_drop_ratio_lower": self.ratio_lower,
        }

    def random_masking(self, x, len_keep):
        """Randomly keep ``len_keep`` tokens per sample of ``x`` (shape ``[N, L, D]``).

        Returns:
            x_masked:    ``[N, len_keep, D]`` kept tokens (in shuffled order).
            mask:        ``[N, L]`` with 0 for kept positions and 1 for dropped ones, in original order.
            ids_restore: ``[N, L]`` indices that undo the shuffle.
        """
        batch, length, width = x.shape

        # argsort of uniform noise = an independent random permutation per sample.
        noise = torch.rand(batch, length, device=x.device)
        shuffle = noise.argsort(dim=1)
        restore = shuffle.argsort(dim=1)

        keep_idx = shuffle[:, :len_keep]
        gather_idx = keep_idx.unsqueeze(-1).repeat(1, 1, width)
        x_masked = x.gather(1, gather_idx)

        # Build the mask in shuffled order, then map it back to the original token order.
        mask = torch.ones([batch, length], device=x.device)
        mask[:, :len_keep] = 0
        mask = mask.gather(1, restore)

        return x_masked, mask, restore
| |
|
class IdentityMap(torch.nn.Module):
    """No-op resampler: hands back its first argument untouched."""

    def __init__(self):
        super(IdentityMap, self).__init__()

    def forward(self, x, *args, **kwargs):
        # Extra arguments exist only so callers can use the common resampler signature.
        del args, kwargs
        return x

    @property
    def config(self):
        cfg = {"mm_resampler_type": None}
        return cfg
| |
|
| | |
def build_vision_resampler(model_args, delay_load=False, **kwargs):
    """Instantiate the resampler selected by ``model_args.mm_resampler_type``.

    Recognised types: "masked_drop", "spatial_pool", "perceiver", "qformer", or ``None``
    (identity). ``kwargs`` are forwarded to every constructor except ``MaskedDrop``;
    ``delay_load`` is accepted but unused.

    Raises:
        ValueError: for any other resampler type.
    """
    kind = getattr(model_args, "mm_resampler_type", None)

    # Factories are lambdas so each class name is only resolved when actually selected.
    factories = (
        ("masked_drop", lambda: MaskedDrop(model_args)),
        ("spatial_pool", lambda: SpatialPool(model_args, **kwargs)),
        ("perceiver", lambda: PerceiverResampler(model_args, **kwargs)),
        ("qformer", lambda: Qformer(model_args, **kwargs)),
    )
    for name, make in factories:
        if kind == name:
            return make()

    if kind is None:
        return IdentityMap()

    raise ValueError(f"Unknown resampler type: {kind}")
| |
|
| | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig |
| |
|
| | |
class CLIPVisionTower(nn.Module):
    r"""
    Wrapper around a Hugging Face CLIP vision encoder used as the image tower.

    Attributes:
        is_loaded (bool): whether the CLIP model and image processor have been loaded.
        vision_tower_name (str): model name or path given to ``from_pretrained``.
        select_layer (int): index into ``hidden_states`` used for feature extraction.
        select_feature (str): selection mode — "patch" (drop the first token), "cls_patch"
            (keep all tokens), or a multi-layer variant prefixed with "slicefour_" or
            "slice_m25811_f6_" whose layers are concatenated along the feature axis.
        cfg_only: ``CLIPVisionConfig`` kept when loading is deferred.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower}")
            self.load_model()
            return

        # Even with delayed loading, load eagerly when the tower is trainable, since its
        # weights are then expected to be part of the checkpoint.
        if getattr(args, "unfreeze_mm_vision_tower", False):
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
            return

        if hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()
            return

        # Deferred: keep only the config so shape properties still work.
        self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Load the image processor and frozen CLIP vision model (no-op if already loaded)."""
        if self.is_loaded:
            rank0_print(f"{self.vision_tower_name} is already loaded, `load_model` called again, skipping.")
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick (and possibly concatenate) hidden-state layers according to ``select_feature``."""
        hidden_states = image_forward_outs.hidden_states
        mode = self.select_feature

        if mode in ("slicefour_patch", "slicefour_cls_patch"):
            # Every (len // 4)-th layer, starting at stride + select_layer.
            stride = len(hidden_states) // 4
            picked = [hidden_states[i] for i in range(stride + self.select_layer, len(hidden_states), stride)]
            image_features = torch.cat(picked, dim=-1)
            mode = mode.replace("slicefour_", "")
        elif mode in ("slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"):
            image_features = torch.cat([hidden_states[i] for i in (-2, -5, -8, -11, 6)], dim=-1)
            mode = mode.replace("slice_m25811_f6_", "")
        else:
            image_features = hidden_states[self.select_layer]

        if mode == "patch":
            return image_features[:, 1:]
        if mode == "cls_patch":
            return image_features
        raise ValueError(f"Unexpected select feature: {mode}")

    def forward(self, images):
        """Encode a batch tensor, or a list of single images (each processed separately)."""
        if type(images) is not list:
            outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            return self.feature_select(outs).to(images.dtype)

        features = []
        for img in images:
            out = self.vision_tower(img.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
            features.append(self.feature_select(out).to(img.dtype))
        return features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        # Multi-layer modes concatenate 4 or 5 layers along the feature axis.
        size = self.config.hidden_size
        if "slicefour" in self.select_feature:
            size *= 4
        if "slice_m25811_f6" in self.select_feature:
            size *= 5
        return size

    @property
    def num_patches_per_side(self):
        cfg = self.config
        return cfg.image_size // cfg.patch_size

    @property
    def num_patches(self):
        cfg = self.config
        count = (cfg.image_size // cfg.patch_size) ** 2
        # "cls_patch" modes keep the extra leading token.
        return count + 1 if "cls_patch" in self.select_feature else count

    @property
    def image_size(self):
        return self.config.image_size
| |
|
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Build the CLIP vision tower named by ``vision_tower_cfg``.

    The name is read from ``mm_vision_tower``, falling back to ``vision_tower``. It is accepted
    when it is an existing path, starts with "openai" or "laion", or contains "ShareGPT4V".
    Extra ``kwargs`` (e.g. ``delay_load``) are forwarded to ``CLIPVisionTower``.

    Raises:
        ValueError: if no tower name is configured or the name is not recognised.
    """
    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
    # A missing name would otherwise crash inside os.path.exists(None) with an opaque TypeError.
    if vision_tower is not None:
        if (
            os.path.exists(vision_tower)
            or vision_tower.startswith(("openai", "laion"))
            or "ShareGPT4V" in vision_tower
        ):
            return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f"Unknown vision tower: {vision_tower}")
| |
|
| | class InstellaVLMetaModel: |
| |
|
def __init__(self, config):
    """Mixin constructor: forward ``config`` up the MRO, then build the multimodal modules.

    Only when ``config`` has ``mm_vision_tower`` are the vision tower, vision resampler and
    projector created. If ``mm_patch_merge_type`` additionally contains "unpad", an
    ``image_newline`` parameter of size ``config.hidden_size`` is allocated with ``torch.empty``
    (uninitialised). NOTE(review): presumably its values come from a checkpoint — confirm.
    """
    super(InstellaVLMetaModel, self).__init__(config)

    if not hasattr(config, "mm_vision_tower"):
        return

    self.vision_tower = build_vision_tower(config, delay_load=getattr(config, "delay_load", False))
    self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
    self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)

    merge_type = getattr(config, "mm_patch_merge_type", "")
    if "unpad" in merge_type:
        self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
| |
|
def get_vision_tower(self):
    """Return the vision tower, unwrapping the one-element list used for FSDP; None if unset."""
    tower = getattr(self, "vision_tower", None)
    return tower[0] if type(tower) is list else tower
| |
|
| | def initialize_vision_modules(self, model_args, fsdp=None): |
| | vision_tower = model_args.vision_tower |
| | mm_vision_select_layer = model_args.mm_vision_select_layer |
| | mm_vision_select_feature = model_args.mm_vision_select_feature |
| | pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter |
| | mm_patch_merge_type = model_args.mm_patch_merge_type |
| |
|
| | self.config.mm_vision_tower = vision_tower |
| | self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "") |
| |
|
| | if self.get_vision_tower() is None: |
| | vision_tower = build_vision_tower(model_args) |
| | vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower) |
| | for k, v in vision_resampler.config.items(): |
| | setattr(self.config, k, v) |
| |
|
| | if fsdp is not None and len(fsdp) > 0: |
| | self.vision_tower = [vision_tower] |
| | self.vision_resampler = [vision_resampler] |
| | else: |
| | self.vision_tower = vision_tower |
| | self.vision_resampler = vision_resampler |
| | else: |
| | if fsdp is not None and len(fsdp) > 0: |
| | vision_resampler = self.vision_resampler[0] |
| | vision_tower = self.vision_tower[0] |
| | else: |
| | vision_resampler = self.vision_resampler |
| | vision_tower = self.vision_tower |
| | vision_tower.load_model() |
| |
|
| | |
| | for p in self.vision_resampler.parameters(): |
| | p.requires_grad = True |
| |
|
| | self.config.use_mm_proj = True |
| | self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear") |
| | self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size) |
| | self.config.mm_vision_select_layer = mm_vision_select_layer |
| | self.config.mm_vision_select_feature = mm_vision_select_feature |
| | self.config.mm_patch_merge_type = mm_patch_merge_type |
| | self.config.online_training = model_args.online_training |
| |
|
| | if getattr(self, "mm_projector", None) is None: |
| | self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config) |
| |
|
| | if "unpad" in mm_patch_merge_type: |
| | embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) |
| | self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std) |
| | else: |
| | |
| | for p in self.mm_projector.parameters(): |
| | p.requires_grad = True |
| |
|
| | if pretrain_mm_mlp_adapter is not None: |
| | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu") |
| |
|
| | def get_w(weights, keyword): |
| | return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k} |
| |
|
| | incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector")) |
| | rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}") |
| | incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False) |
| | rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}") |
| | |
| | if 'tmp-' in pretrain_mm_mlp_adapter: |
| | pretrain_mm_mlp_adapter_folder = os.path.dirname(pretrain_mm_mlp_adapter) |
| | shutil.rmtree(pretrain_mm_mlp_adapter_folder, ignore_errors=True) |
| | |
| |
|
| |
|
def unpad_image(tensor, original_size):
    """
    Crop away the symmetric padding added when an image was letterboxed to a new size.

    Args:
        tensor (torch.Tensor): Padded tensor whose last two dimensions are
            (height, width), e.g. ``C x H x W``. Any number of leading
            dimensions is accepted.
        original_size (tuple): Size of the original image as ``(width, height)``
            (this is the order the tuple is unpacked in below).

    Returns:
        torch.Tensor: A slice of ``tensor`` with padded rows (when the original is
        relatively wider than ``tensor``) or padded columns (otherwise) removed
        equally from both sides.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[-2:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Width was the limiting side, so padding sits above and below.
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        return tensor[..., padding : current_height - padding, :]

    # Height was the limiting side (or ratios match), so padding sits left and right.
    scale_factor = current_height / original_height
    new_width = int(original_width * scale_factor)
    padding = (current_width - new_width) // 2
    return tensor[..., :, padding : current_width - padding]
| |
|
| |
|
class InstellaVLMetaForCausalLM(ABC):
    """Mixin that turns ``input_ids`` plus images into ``inputs_embeds`` for InstellaVL.

    Concrete subclasses implement :meth:`get_model`; the mixin additionally uses
    ``self.config``, ``self.model``, ``self.device`` and ``self.training`` of the
    model it is mixed into.
    """

    @abstractmethod
    def get_model(self):
        """Return the multimodal backbone (owner of ``embed_tokens``, ``mm_projector``, ...)."""
        pass

    def get_vision_tower(self):
        """Return the backbone's vision tower (may be ``None``)."""
        return self.get_model().get_vision_tower()

    def get_2dPool(self, image_feature):
        """Spatially downsample per-frame patch features.

        Args:
            image_feature (torch.Tensor): ``(num_frames, num_tokens, dim)`` where
                ``num_tokens == num_patches_per_side ** 2``.

        Returns:
            torch.Tensor: ``(num_frames, pooled_tokens, dim)`` pooled according to
            ``config.mm_spatial_pool_mode``.

        Raises:
            ValueError: If ``mm_spatial_pool_mode`` is not "average", "max" or "bilinear".
        """
        height = width = self.get_vision_tower().num_patches_per_side
        num_frames, num_tokens, num_dim = image_feature.shape
        image_feature = image_feature.view(num_frames, height, width, -1)
        image_feature = image_feature.permute(0, 3, 1, 2).contiguous()  # -> (frames, dim, H, W)
        pool_mode = self.config.mm_spatial_pool_mode
        if pool_mode == "average":
            image_feature = nn.functional.avg_pool2d(image_feature, self.config.mm_spatial_pool_stride)
        elif pool_mode == "max":
            image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
        elif pool_mode == "bilinear":
            # NOTE: bilinear mode always halves each side; mm_spatial_pool_stride is not used.
            grid_height, grid_width = image_feature.shape[2:]
            scaled_shape = [math.ceil(grid_height / 2), math.ceil(grid_width / 2)]
            image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
        else:
            raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
        image_feature = image_feature.permute(0, 2, 3, 1)
        image_feature = image_feature.view(num_frames, -1, num_dim)
        return image_feature

    def encode_images(self, images):
        """Run ``images`` through the vision tower, then the multimodal projector."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
        """Encode a concatenated batch with the vision tower and project it per sample.

        Args:
            videos_or_images (torch.Tensor): All samples concatenated along dim 0.
            video_idx_in_batch (list[int]): Indices (after splitting) of video samples;
                these are additionally pooled with :meth:`get_2dPool`.
            split_sizes (list[int], optional): Per-sample sizes passed to ``torch.split``.

        Returns:
            list[torch.Tensor]: One projected feature tensor per sample.
        """
        videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
        per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0)
        all_videos_or_images_features = []
        for idx, feat in enumerate(per_videos_or_images_features):
            feat = self.get_model().mm_projector(feat)
            if idx in video_idx_in_batch:
                feat = self.get_2dPool(feat)
            all_videos_or_images_features.append(feat)
        return all_videos_or_images_features

    def add_token_per_grid(self, image_feature):
        """Append ``image_newline`` after every row of each frame's token grid.

        Args:
            image_feature (torch.Tensor): ``(num_frames, num_tokens, dim)``; ``num_tokens``
                is assumed to be a perfect square (``side * side``).

        Returns:
            torch.Tensor: ``(num_frames * side * (side + 1), dim)`` flattened tokens.
        """
        resize_h = int(math.sqrt(image_feature.shape[1]))
        num_frames = image_feature.shape[0]
        image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1)
        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
        return image_feature

    def add_token_per_frame(self, image_feature):
        """Append one ``image_newline`` token to the end of every frame.

        Args:
            image_feature (torch.Tensor): ``(num_frames, num_tokens, dim)``.

        Returns:
            torch.Tensor: ``(num_frames, num_tokens + 1, dim)``.
        """
        image_feature = image_feature.permute(2, 0, 1).contiguous()
        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
        image_feature = image_feature.permute(1, 2, 0).contiguous()
        return image_feature

    def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=None, image_sizes=None):
        """Splice projected image features into the token embeddings.

        Every ``IMAGE_TOKEN_INDEX`` placeholder in ``input_ids`` is replaced by the
        corresponding image feature sequence, whose label positions become
        ``IGNORE_INDEX``. Sequences are then truncated to
        ``config.tokenizer_model_max_length`` (if set) and re-padded on the side
        given by ``config.tokenizer_padding_side`` (default "right").

        Args:
            input_ids (torch.LongTensor): ``(batch, seq_len)`` token ids.
            position_ids (torch.LongTensor, optional): Only used for its dtype; the
                returned value is ``None`` if this was ``None`` (unless position
                skipping is active).
            attention_mask (torch.Tensor, optional): Used to strip padding before splicing.
            past_key_values: Returned unchanged.
            labels (torch.LongTensor, optional): LM labels.
            images (torch.Tensor | list[torch.Tensor] | None): Pixel inputs.
            modalities (list[str] | str, optional): Per-sample modality ("image" or
                "video"); defaults to ``["image"]``.
            image_sizes (list, optional): Per-image original sizes for anyres/unpad merging.

        Returns:
            tuple: ``(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels)``.
            Without splicing the inputs are returned as-is with ``inputs_embeds=None``;
            otherwise ``input_ids`` is ``None`` and ``inputs_embeds`` is the padded
            ``(batch, max_len, dim)`` tensor.
        """
        vision_tower = self.get_vision_tower()
        # No vision tower, no images, or a single-token decoding step: nothing to splice.
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if modalities is None:
            modalities = ["image"]
        elif isinstance(modalities, str):
            modalities = [modalities]

        if isinstance(images, list) or images.ndim == 5:
            if isinstance(images, list):
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

            video_idx_in_batch = [idx for idx, modality in enumerate(modalities) if modality == "video"]

            # Make every per-sample tensor 4-D.
            images_list = [image if image.ndim == 4 else image.unsqueeze(0) for image in images]

            # Encode all samples in one pass, then split the result back per sample.
            concat_images = torch.cat(images_list, dim=0)
            split_sizes = [image.shape[0] for image in images_list]
            encoded_image_features = torch.split(self.encode_images(concat_images), split_sizes)
            image_features = [
                self.get_2dPool(image_feat) if idx in video_idx_in_batch else image_feat
                for idx, image_feat in enumerate(encoded_image_features)
            ]

            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")

            if mm_patch_merge_type == "flat":
                image_features = [x.flatten(0, 1) for x in image_features]

            elif mm_patch_merge_type.startswith("spatial"):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_idx in video_idx_in_batch:
                        # Video: insert newline tokens according to mm_newline_position.
                        newline_position = self.config.mm_newline_position
                        if newline_position == "grid":
                            image_feature = self.add_token_per_grid(image_feature)
                        elif newline_position == "frame":
                            image_feature = self.add_token_per_frame(image_feature).flatten(0, 1)
                        elif newline_position == "one_token":
                            image_feature = image_feature.flatten(0, 1)
                            if 'unpad' in mm_patch_merge_type:
                                image_feature = torch.cat((
                                    image_feature,
                                    self.model.image_newline[None].to(image_feature.device)
                                ), dim=0)
                        elif newline_position == "no_token":
                            image_feature = image_feature.flatten(0, 1)
                        else:
                            raise ValueError(f"Unexpected mm_newline_position: {self.config.mm_newline_position}")

                    elif image_feature.shape[0] > 1:
                        # Several views: the first is treated as the base image, the rest as crops.
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_vision_tower().num_patches_per_side
                        assert height * width == base_image_feature.shape[0]

                        matched_anyres_max_num_patches = None
                        if "anyres_max" in image_aspect_ratio:
                            matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
                            if matched_anyres_max_num_patches:
                                max_num_patches = int(matched_anyres_max_num_patches.group(1))

                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                            if hasattr(self.get_vision_tower(), "image_size"):
                                vision_tower_image_size = self.get_vision_tower().image_size
                            else:
                                raise ValueError("vision_tower_image_size is not found in the vision tower.")
                            try:
                                num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
                            except Exception as e:
                                # Best effort: fall back to a 2x2 crop grid.
                                rank0_print(f"Error: {e}")
                                num_patch_width, num_patch_height = 2, 2
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            image_feature = image_feature.view(2, 2, height, width, -1)

                        if "maxpool2x2" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = nn.functional.max_pool2d(image_feature, 2)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
                            unit = image_feature.shape[2]
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            c, h, w = image_feature.shape
                            # Downscale when the grid holds noticeably more than max_num_patches * unit**2 positions.
                            times = math.sqrt(h * w / (max_num_patches * unit**2))
                            if times > 1.1:
                                image_feature = image_feature[None]
                                image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
                            image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)

                        if "nobase" not in mm_patch_merge_type:
                            image_feature = torch.cat((base_image_feature, image_feature), dim=0)

                    else:
                        # Single view: drop the view axis and optionally add one newline token.
                        image_feature = image_feature[0]
                        if "unpad" in mm_patch_merge_type:
                            image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)

                    # Append exactly once per sample so features stay aligned with placeholders.
                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        # Image start/end marker tokens are not supported in this code path.
        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError

        # Keep the caller's originals so optional outputs can be returned as None.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # Drop masked-out (padding) positions; each sample becomes a 1-D tensor.
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
                if len(image_features) > 0:
                    # Concatenating a zero-length slice leaves the embeddings unchanged;
                    # presumably it keeps the vision branch in the autograd graph.
                    dummy_features = image_features[min(cur_image_idx, len(image_features) - 1)]
                    cur_input_embeds = torch.cat([cur_input_embeds, dummy_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Split the sequence at every image placeholder and embed the text pieces in one call.
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_labels = labels[batch_idx]
            cur_input_ids_noim = []
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)

            # Interleave text pieces with image features; image positions get IGNORE_INDEX labels.
            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    try:
                        cur_image_features = image_features[cur_image_idx]
                    except IndexError:
                        # More placeholders than features: reuse the previous feature.
                        cur_image_features = image_features[cur_image_idx - 1]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            new_input_embeds.append(torch.cat(cur_new_input_embeds))
            new_labels.append(torch.cat(cur_new_labels))

        # Truncate every sample to the tokenizer limit (x[:None] keeps the whole sequence).
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
        new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
        new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Re-pad to a rectangular batch on the configured side.
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            pad_block = torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
                new_input_embeds_padded.append(torch.cat((pad_block, cur_new_embed), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, pad_block), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        new_labels = None if _labels is None else new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None
        if getattr(self.config, "use_pos_skipping", False) and self.training:
            # Offset positions by random amounts before/after a random split point.
            position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
            split_position = random.randint(0, new_input_embeds.size(1))
            left_add = random.randint(0, self.config.pos_skipping_range)
            right_add = random.randint(left_add, self.config.pos_skipping_range)
            position_ids[:, :split_position] += left_add
            position_ids[:, split_position:] += right_add

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Add image special tokens to ``tokenizer`` and resize/initialize the embeddings.

        New start/end token rows are initialized to the mean of the existing rows,
        optionally overwritten from ``model_args.pretrain_mm_mlp_adapter``.

        Raises:
            ValueError: If the pretrained ``model.embed_tokens.weight`` has an unexpected shape.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # Initialize the new rows with the mean of the pre-existing rows.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
| |
|
class InstellaVLConfig(OlmoConfig):
    """Model configuration for InstellaVL.

    Inherits every setting from ``OlmoConfig`` and only overrides the
    ``model_type`` identifier, which is "instellavl".
    """

    model_type = "instellavl"
| |
|
| |
|
def disable_torch_init():
    r"""
    Turn ``reset_parameters`` into a no-op on ``torch.nn.Linear`` and ``torch.nn.LayerNorm``.

    Layers of these types built afterwards skip their default random
    initialisation, which speeds up model construction. The patch is applied to
    the classes themselves, so it affects the whole process and is not undone.
    """
    import torch

    for module_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(module_cls, "reset_parameters", lambda self: None)
| |
|
| |
|
class InstellaVLModel(InstellaVLMetaModel, OlmoModel):
    """OLMo backbone combined with the InstellaVL vision modules.

    Construction goes through ``InstellaVLMetaModel.__init__`` (first in the MRO),
    which hands ``config`` on to ``OlmoModel`` before adding the vision parts.
    """

    config_class = InstellaVLConfig

    def __init__(self, config: OlmoConfig):
        super().__init__(config)
| |
|
| |
|
class InstellaVLForCausalLM(OlmoForCausalLM, InstellaVLMetaForCausalLM):
    r"""
    InstellaVLForCausalLM is a class that extends OlmoForCausalLM and InstellaVLMetaForCausalLM to provide
    a language model with multimodal capabilities, specifically for handling images along with text.

    1. Attributes:
        - config_class (type): The configuration class to use for this model.
        - model (InstellaVLModel): The underlying model.
        - lm_head (nn.Linear): The linear layer for language modeling head.

    2. Methods:

        1. `__init__(config: InstellaVLConfig)`:
            Initializes the InstellaVLForCausalLM model with the given configuration.

        2. `get_model() -> InstellaVLModel`:
            Returns the underlying model.

        3. `forward() -> Union[Tuple, CausalLMOutputWithPast]`:
            Performs a forward pass through the model.

        4. `generate() -> Union[GenerateOutput, torch.LongTensor]`:
            Generates text based on the input.

        5. `prepare_inputs_for_generation(input_ids: torch.LongTensor,) -> dict`:
            Prepares inputs for text generation.
    """

    config_class = InstellaVLConfig

    def __init__(self, config: OlmoConfig):
        r"""
        Initializes the InstellaVLForCausalLM model.

        Args:
            - config (OlmoConfig): Configuration object for the model. Its
              ``model_type`` is overwritten with "instellavl".

        Attributes:
            - model (InstellaVLModel): The main model instance.
            - lm_head (torch.nn.Linear): Linear layer that maps hidden states to vocabulary size.
        """
        # Skip OlmoForCausalLM.__init__ and call the next initializer in the MRO;
        # the backbone and LM head are created explicitly below instead.
        super(OlmoForCausalLM, self).__init__(config)
        # NOTE: patches torch.nn.Linear / LayerNorm.reset_parameters process-wide.
        disable_torch_init()
        config.model_type = "instellavl"
        self.model = InstellaVLModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_model(self):
        """Return the underlying :class:`InstellaVLModel`."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        modalities: Optional[List[str]] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            - input_ids (torch.LongTensor, optional): Input token IDs.
            - attention_mask (torch.Tensor, optional): Attention mask.
            - position_ids (torch.LongTensor, optional): Position IDs.
            - past_key_values (List[torch.FloatTensor], optional): Past key values for caching.
            - inputs_embeds (torch.FloatTensor, optional): Input embeddings. When given,
              image splicing is skipped.
            - labels (torch.LongTensor, optional): Labels for language modeling.
            - use_cache (bool, optional): Whether to use cache.
            - output_attentions (bool, optional): Whether to output attentions.
            - output_hidden_states (bool, optional): Whether to output hidden states.
            - images (torch.FloatTensor, optional): Input images.
            - image_sizes (List[List[int]], optional): Sizes of input images.
            - return_dict (bool, optional): Whether to return a dictionary.
            - modalities (List[str], optional): Per-sample modalities; defaults to ``["image"]``.
            - cache_position (optional): Accepted for API compatibility; not forwarded
              to the base model.

        Returns:
            Union[Tuple, CausalLMOutputWithPast]: The output of the forward pass.
        """
        if modalities is None:
            modalities = ["image"]
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                modalities,
                image_sizes
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        r"""
        Args:
            - inputs (torch.Tensor, optional): Input token IDs.
            - images (torch.Tensor, optional): Input images.
            - image_sizes (torch.Tensor, optional): Sizes of input images.
            - modalities (List[str], optional): Per-sample modalities ("image"/"video");
              defaults to ``["image"]`` and is forwarded to the multimodal splicing.
            - **kwargs: Additional arguments forwarded to ``GenerationMixin.generate``.

        Returns:
            Union[GenerateOutput, torch.LongTensor]: The generated tokens.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is passed in ``kwargs``.
        """
        if modalities is None:
            modalities = ["image"]
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                modalities,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)
        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        r"""
        Args:
            - input_ids (torch.LongTensor): Input token IDs.
            - past_key_values (List[torch.FloatTensor], optional): Past key values for caching.
            - inputs_embeds (torch.FloatTensor, optional): Input embeddings.
            - **kwargs: Additional arguments; ``images`` and ``image_sizes`` are popped
              and re-attached to the returned dict when present.

        Returns:
            dict: Prepared inputs for generation.
        """
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs
| |
|
# Register the "instellavl" model type and its causal-LM class with the
# transformers Auto* factories (module-level side effect at import time).
AutoConfig.register("instellavl", InstellaVLConfig)
AutoModelForCausalLM.register(InstellaVLConfig, InstellaVLForCausalLM)
| |
|