# Helium1-VL-2B / modeling_helium1_casa.py
from typing import Any, Callable
from typing import cast as type_cast
import torch
from transformers.cache_utils import DynamicCache
from transformers.configuration_utils import PretrainedConfig
from transformers.generation.utils import GenerateOutput
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionTransformerPretrainedModel,
)
from .image_encoder import Qwen25VLEncoder
from .configuration_helium1_casa import Helium1CASAConfig
from .language_helium1_casa import (
CausalHeliumOutput,
Helium1CASAAttention,
Helium1ForCausalLM,
Helium1RMSNorm,
)
def meta_project(
logits: torch.Tensor | list[torch.Tensor],
projector: torch.nn.Module,
norm: torch.nn.Module | None = None,
) -> torch.Tensor | list[torch.Tensor]:
"""Projection operation that handles both tensors and list of tensors
Outputs either a (N, S, D) tensors (same resolution images) or a list of N (S, D) tensors (where
S can be a different sequence length per image)
"""
split_sizes: list[int] | None = None
if not isinstance(logits, torch.Tensor):
split_sizes = [_x.shape[0] for _x in logits]
logits = torch.cat(logits, dim=0)[None, :, :]
logits = type_cast(torch.Tensor, logits)
logits = projector(logits)
assert isinstance(logits, torch.Tensor)
if norm is not None:
logits = norm(logits)
if split_sizes is not None:
return list(torch.split(type_cast(torch.Tensor, logits[0]), split_sizes, dim=0))
return logits
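
# Illustrative sketch (kept as a comment so nothing runs on import): how
# `meta_project` behaves on a padded batch versus a ragged list; the Linear
# below is a stand-in for `ImageProjection.proj_extra`.
#
#     proj = torch.nn.Linear(64, 32)
#     batch = torch.randn(4, 10, 64)                     # (N, S, D), shared S
#     meta_project(batch, proj).shape                    # -> (4, 10, 32)
#     ragged = [torch.randn(5, 64), torch.randn(7, 64)]  # per-image S
#     [t.shape for t in meta_project(ragged, proj)]      # -> [(5, 32), (7, 32)]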
class ImageProjection(torch.nn.Module):
"""Takes in a batch or sequence of images and returns embeddings
which are then fed to the LM.
:param config: KyuteyeConfig object
:param lm_model_dim: Output dimension (number of channels) for this module
"""
def __init__(self, config: PretrainedConfig, lm_model_dim: int) -> None:
super().__init__()
self.config = config
self.out_dim = lm_model_dim
visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
self.enc = Qwen25VLEncoder(visual=visual)
# Projection layer
self.proj_extra = self.init_proj_module()
# Output normalizations
self.norm_extra = Helium1RMSNorm(self.out_dim)
def init_proj_module(self) -> torch.nn.Module:
"""Init the project module for the inserted and/or cross-attended image tokens"""
if self.config.vision_config.out_dim == self.out_dim:
return torch.nn.Identity()
return torch.nn.Linear(self.config.vision_config.out_dim, self.out_dim)
def forward(
self, x: torch.Tensor | list[torch.Tensor]
) -> dict[
str,
torch.Tensor | list[torch.Tensor],
]:
"""Image embedding mapping
:param x: Either a tensor with shape (Bi, C, H, W) or a list of Bi tensors
with shape (C, H, W) (or (H, W, C) in the case of Qwen)
:return: Either a tensor with shape (num_total_image, S, D) or, if images
can have different seq length, a list of `num_total_images` Tensors with shape
(S, D)
"""
# Apply image encoder
og_dtype = x[0].dtype
encoded = self.enc(x)["image_embeds"]
encoded = [_x.to(og_dtype) for _x in encoded]
        # If all images share the same sequence length, stack them into one batch
        if all(_e.shape[0] == encoded[0].shape[0] for _e in encoded):
            encoded = torch.stack(encoded, dim=0)
# Extra projection
image_embeds = meta_project(encoded, self.proj_extra, self.norm_extra)
# Apply different projection for extra vs cross attended tokens
return {"image_embeds": image_embeds}
class V2Helium1(Helium1ForCausalLM): # pyright: ignore[reportIncompatibleMethodOverride]
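    """Helium1 causal LM extended with an image-projection prefix (CASA variant)"""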
config_class = Helium1CASAConfig
def __init__(self, config: Helium1CASAConfig, **kwargs: Any) -> None:
del kwargs
super().__init__(config)
self.image_prefix = ImageProjection(config=config, lm_model_dim=self.token_dim)
def get_device(self) -> str:
"""Return the device type of the model"""
return next(self.parameters()).device.type
@property
def token_dim(self) -> int:
"""Returns the number of dimensions for the token representation"""
return self.config.hidden_size
@property
def rotary_embed(self) -> Callable:
"""Returns the rotary embedding function of the underlying model"""
return self.model.rotary_emb
def _update_model_kwargs_for_generation(
self,
outputs: Any,
model_kwargs: dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
):
"""This is required to handle multiple gen calls for subtitles"""
# Call parent to get default updates
model_kwargs = super()._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder, num_new_tokens
)
# Used by prepare_inputs_for_generation
model_kwargs["__is_first_gen_call__"] = False
return model_kwargs
def prepare_inputs_for_generation( # pyright: ignore[reportIncompatibleMethodOverride]
self,
input_ids: torch.Tensor,
past_key_values: DynamicCache | None = None,
**kwargs: Any,
):
__is_first_gen_call__ = kwargs.get("__is_first_gen_call__", True)
if past_key_values is not None and (
kwargs.get("cache_position") is None
or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0
):
# We're continuing from a cached state
past_length = past_key_values._seen_tokens
kwargs["cache_position"] = torch.arange(
past_length,
past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1),
dtype=torch.long,
device=input_ids.device,
)
return super().prepare_inputs_for_generation(
type_cast(torch.LongTensor, input_ids),
past_key_values=past_key_values,
**kwargs,
)
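
    # Worked example of the cache-position logic above: with 20 tokens already
    # in the cache, a first call over a 5-token prompt yields
    # cache_position = [20, 21, 22, 23, 24]; every later decode step (where
    # `__is_first_gen_call__` is False) advances a single position, e.g. [25].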
def prepare_multimodal_inputs(
self,
# text only training
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
image_embeds_insertion_points: list[torch.Tensor] | None = None,
labels: torch.Tensor | None = None,
# image values
pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
pre_image_tokens: list[int] | None = None,
post_image_tokens: list[int] | None = None,
**_kwargs: Any,
) -> dict:
"""Get a batch data mixing text and image data"""
del _kwargs
processed_inputs = {
"input_ids": input_ids,
"inputs_embeds": inputs_embeds,
"labels": labels,
"attention_mask": attention_mask,
"image_embeds_insertion_points": image_embeds_insertion_points,
}
if pixel_values is not None:
processed_inputs.update(self.image_prefix(pixel_values))
assert "image_embeds" in processed_inputs
assert (
isinstance(processed_inputs["image_embeds"], torch.Tensor)
and processed_inputs["image_embeds"].ndim == 3
) or (
isinstance(processed_inputs["image_embeds"], list)
and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"])
)
# Add kwargs necessary to compute cu_seqlens windows for CASA
processed_inputs["casa_windows_info"] = {
"num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens),
"num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens),
}
return processed_inputs
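
    # Illustrative shape of the returned batch (values are made up):
    #
    #     {
    #         "input_ids": ...,                        # (B, T) text tokens
    #         "inputs_embeds": None,
    #         "labels": ...,
    #         "attention_mask": ...,
    #         "image_embeds_insertion_points": [...],
    #         "image_embeds": ...,                     # (N, S, D) or list of (S_i, D)
    #         "casa_windows_info": {"num_pre_image_tokens": 2,
    #                               "num_post_image_tokens": 1},
    #     }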
def forward( # pyright: ignore[reportIncompatibleMethodOverride]
self,
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
return_loss: bool = True,
labels: torch.Tensor | None = None,
image_embeds_insertion_points: list[torch.Tensor] | None = None,
pre_image_tokens: list[int] | None = None,
post_image_tokens: list[int] | None = None,
**kwargs: Any,
) -> CausalHeliumOutput:
"""Multi modal forward pass"""
assert input_ids is not None or inputs_embeds is not None
if self.training:
            assert return_loss is True, (
                "Helium models always compute their own labels/losses in train mode"
            )
# Case 1: For first generation call we need to compute pixel values and CASA states
if kwargs.get("__is_first_gen_call__", True):
processed_inputs = self.prepare_multimodal_inputs(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
image_embeds_insertion_points=image_embeds_insertion_points,
pixel_values=pixel_values,
labels=labels,
pre_image_tokens=pre_image_tokens,
post_image_tokens=post_image_tokens,
)
processed_inputs.pop("inputs_embeds", None)
        else:
            # Case 2: later generation calls only embed the newly produced tokens
processed_inputs = {
"inputs_embeds": self.model.embed_tokens(input_ids),
"attention_mask": attention_mask,
}
# For Helium prefix, we need to update the positions by the number
# of image tokens inserted in the first call
if (
not self.config.casa_attention
and (cp := kwargs.get("cache_position", None)) is not None
and pixel_values is not None
):
start = kwargs["cache_position"][0].item()
num_image_tokens = (pixel_values[0].shape[0] * pixel_values[0].shape[1]) // 4
num_tokens = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] # type: ignore
kwargs["cache_position"] = torch.arange(
start + (0 if kwargs.get("__is_first_gen_call__", True) else num_image_tokens),
start + num_tokens + num_image_tokens,
dtype=cp.dtype,
device=cp.device,
)
kwargs.pop("__is_first_gen_call__", True)
out = super().forward(
**processed_inputs, # type: ignore
**kwargs,
)
return out
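
    # Worked example for the non-CASA position offset above (illustrative
    # numbers): an image covering 32 x 32 patches yields
    # num_image_tokens = 32 * 32 // 4 = 256; a first call over a 10-token
    # prompt starting at cache position 0 produces positions [0, 266), and
    # each later single-token step starts at start + 256.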
@torch.no_grad()
def generate_from_image( # pyright: ignore[reportInconsistentOverload,reportIncompatibleMethodOverride]
self,
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
image_embeds_insertion_points: list[torch.Tensor] | None = None,
pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
reset_streaming: bool = True,
**kwargs: Any,
) -> "GenerateOutput | torch.LongTensor":
assert input_ids is not None and inputs_embeds is None, (
"Input IDs must be provided for generation"
)
# init self-attention KVCache
if kwargs.get("past_key_values", None) is None:
kwargs["past_key_values"] = DynamicCache()
        # Avoid a `generate` warning when no pad token is set
if kwargs.get("pad_token_id", None) is None:
kwargs["pad_token_id"] = kwargs.get("eos_token_id", None)
if isinstance(kwargs["pad_token_id"], (list, tuple)):
kwargs["pad_token_id"] = kwargs["pad_token_id"][0]
self.start_casa_streaming_states()
outputs = self.generate(
input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
image_embeds_insertion_points=image_embeds_insertion_points,
use_cache=True,
**kwargs,
)
if reset_streaming:
self.reset_casa_streaming_states()
return outputs
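
    # Illustrative usage (hypothetical checkpoint path and preprocessed batch;
    # a processor is assumed to supply the insertion points):
    #
    #     model = V2Helium1.from_pretrained(checkpoint_dir)
    #     out = model.generate_from_image(
    #         input_ids=batch["input_ids"],
    #         attention_mask=batch["attention_mask"],
    #         pixel_values=batch["pixel_values"],
    #         image_embeds_insertion_points=batch["image_embeds_insertion_points"],
    #         max_new_tokens=64,
    #     )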
def reset_casa_streaming_states(self, clean_cache: bool = True) -> None:
def __reset__(m: torch.nn.Module):
if isinstance(m, Helium1CASAAttention):
m._set_streaming(False, ())
m.reset_streaming()
if clean_cache:
del m.streaming_state.k
del m.streaming_state.v
del m.streaming_state.casa_handler
self.apply(__reset__)
def start_casa_streaming_states(self) -> None:
def __start__(m: torch.nn.Module):
if isinstance(m, Helium1CASAAttention):
m._set_streaming(True, ())
self.apply(__start__)
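
# Illustrative streaming lifecycle (mirrors what `generate_from_image` does
# internally when `reset_streaming=True`):
#
#     model.start_casa_streaming_states()   # enable streaming on CASA layers
#     model.generate(...)                   # CASA layers accumulate k/v state
#     model.reset_casa_streaming_states()   # drop the per-call buffers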