from typing import Any, Callable
from typing import cast as type_cast

import torch
from transformers.cache_utils import DynamicCache
from transformers.configuration_utils import PretrainedConfig
from transformers.generation.utils import GenerateOutput
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
)

from .image_encoder import Qwen25VLEncoder
from .configuration_helium1_casa import Helium1CASAConfig
from .language_helium1_casa import (
    CausalHeliumOutput,
    Helium1CASAAttention,
    Helium1ForCausalLM,
    Helium1RMSNorm,
)

def meta_project(
    logits: torch.Tensor | list[torch.Tensor],
    projector: torch.nn.Module,
    norm: torch.nn.Module | None = None,
) -> torch.Tensor | list[torch.Tensor]:
    """Projection operation that handles both tensors and lists of tensors.

    Outputs either an (N, S, D) tensor (same-resolution images) or a list of
    N (S, D) tensors (where S can be a different sequence length per image).
    """
    split_sizes: list[int] | None = None
    if not isinstance(logits, torch.Tensor):
        # Ragged inputs: remember each image's sequence length, then fold the
        # list into a single (1, sum(S_i), D) tensor for the projection.
        split_sizes = [_x.shape[0] for _x in logits]
        logits = torch.cat(logits, dim=0)[None, :, :]
    logits = type_cast(torch.Tensor, logits)
    logits = projector(logits)

    assert isinstance(logits, torch.Tensor)
    if norm is not None:
        logits = norm(logits)
    if split_sizes is not None:
        # Undo the fold: split back into one (S_i, D) tensor per image.
        return list(torch.split(type_cast(torch.Tensor, logits[0]), split_sizes, dim=0))
    return logits
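

# A minimal usage sketch for `meta_project` (illustrative only: the 1152/2048
# dimensions and the Linear projector are assumptions, not values taken from
# this module):
#
#   projector = torch.nn.Linear(1152, 2048)
#   norm = Helium1RMSNorm(2048)
#   # Same-resolution images: an (N, S, D) tensor is projected as one batch.
#   batched = meta_project(torch.randn(2, 64, 1152), projector, norm)  # (2, 64, 2048)
#   # Mixed resolutions: a list of (S_i, D) tensors is concatenated, projected,
#   # and split back into a list of (S_i, D') tensors.
#   ragged = meta_project(
#       [torch.randn(64, 1152), torch.randn(16, 1152)], projector, norm
#   )  # -> [(64, 2048), (16, 2048)]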


class ImageProjection(torch.nn.Module):
    """Takes in a batch or sequence of images and returns embeddings
    which are then fed to the LM.

    :param config: Pretrained config carrying the `vision_config` used to build the encoder
    :param lm_model_dim: Output dimension (number of channels) for this module
    """

    def __init__(self, config: PretrainedConfig, lm_model_dim: int) -> None:
        super().__init__()
        self.config = config
        self.out_dim = lm_model_dim
        visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
        self.enc = Qwen25VLEncoder(visual=visual)
        self.proj_extra = self.init_proj_module()
        self.norm_extra = Helium1RMSNorm(self.out_dim)

    def init_proj_module(self) -> torch.nn.Module:
        """Init the projection module for the inserted and/or cross-attended image tokens"""
        if self.config.vision_config.out_dim == self.out_dim:
            return torch.nn.Identity()
        return torch.nn.Linear(self.config.vision_config.out_dim, self.out_dim)

    def forward(
        self, x: torch.Tensor | list[torch.Tensor]
    ) -> dict[str, torch.Tensor | list[torch.Tensor]]:
        """Image embedding mapping

        :param x: Either a tensor with shape (Bi, C, H, W) or a list of Bi tensors
            with shape (C, H, W) (or (H, W, C) in the case of Qwen)

        :return: Either a tensor with shape (num_total_images, S, D) or, if images
            can have different sequence lengths, a list of `num_total_images` tensors
            with shape (S, D)
        """
        og_dtype = x[0].dtype
        encoded = self.enc(x)["image_embeds"]
        encoded = [_x.to(og_dtype) for _x in encoded]
        # If every image yields the same sequence length, stack into a single
        # (N, S, D) tensor so the projection runs as one batch.
        if all(_x.shape[0] == encoded[0].shape[0] for _x in encoded):
            encoded = torch.stack(encoded, dim=0)

        image_embeds = meta_project(encoded, self.proj_extra, self.norm_extra)
        return {"image_embeds": image_embeds}
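

# A hedged sketch of exercising `ImageProjection` standalone (`casa_config`,
# the image tensors, and the 2048 output width are assumptions for
# illustration):
#
#   proj = ImageProjection(config=casa_config, lm_model_dim=2048)
#   out = proj([img_a, img_b])  # Qwen-style (H, W, C) tensors, any resolutions
#   # out["image_embeds"]: a (2, S, 2048) tensor if both images share S,
#   # otherwise a list of two (S_i, 2048) tensors.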


class V2Helium1(Helium1ForCausalLM):
    config_class = Helium1CASAConfig

    def __init__(self, config: Helium1CASAConfig, **kwargs: Any) -> None:
        del kwargs
        super().__init__(config)
        self.image_prefix = ImageProjection(config=config, lm_model_dim=self.token_dim)

    def get_device(self) -> str:
        """Return the device type of the model"""
        return next(self.parameters()).device.type

    @property
    def token_dim(self) -> int:
        """Returns the number of dimensions for the token representation"""
        return self.config.hidden_size

    @property
    def rotary_embed(self) -> Callable:
        """Returns the rotary embedding function of the underlying model"""
        return self.model.rotary_emb

    def _update_model_kwargs_for_generation(
        self,
        outputs: Any,
        model_kwargs: dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> dict[str, Any]:
        """This is required to handle multiple generate calls for subtitles"""
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )
        # After the first generate call, later calls must not re-run the
        # multimodal preprocessing (see `forward`).
        model_kwargs["__is_first_gen_call__"] = False
        return model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        past_key_values: DynamicCache | None = None,
        **kwargs: Any,
    ):
        __is_first_gen_call__ = kwargs.get("__is_first_gen_call__", True)
        if past_key_values is not None and (
            kwargs.get("cache_position") is None
            or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0
        ):
            # Rebuild cache_position from the cache length: the full prompt on
            # the first generate call, a single new position afterwards.
            past_length = past_key_values._seen_tokens
            kwargs["cache_position"] = torch.arange(
                past_length,
                past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1),
                dtype=torch.long,
                device=input_ids.device,
            )

        return super().prepare_inputs_for_generation(
            type_cast(torch.LongTensor, input_ids),
            past_key_values=past_key_values,
            **kwargs,
        )
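
    # Worked example of the cache_position rebuild above (numbers are
    # illustrative): with 10 tokens already cached, the first generate call
    # with a 5-token prompt gets cache_position = [10, 11, 12, 13, 14]; on
    # subsequent decode steps the range has length 1, i.e. just [past_length].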

    def prepare_multimodal_inputs(
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        labels: torch.Tensor | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        **_kwargs: Any,
    ) -> dict:
        """Build a batch mixing text and image data"""
        del _kwargs

        processed_inputs = {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "labels": labels,
            "attention_mask": attention_mask,
            "image_embeds_insertion_points": image_embeds_insertion_points,
        }
        if pixel_values is not None:
            processed_inputs.update(self.image_prefix(pixel_values))
            assert "image_embeds" in processed_inputs
            # Either one (N, S, D) tensor or a list of per-image (S, D) tensors.
            assert (
                isinstance(processed_inputs["image_embeds"], torch.Tensor)
                and processed_inputs["image_embeds"].ndim == 3
            ) or (
                isinstance(processed_inputs["image_embeds"], list)
                and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"])
            )

        processed_inputs["casa_windows_info"] = {
            "num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens),
            "num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens),
        }

        return processed_inputs
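
    # For reference, a hedged sketch of the batch `prepare_multimodal_inputs`
    # returns when pixel_values are present (shapes are assumptions):
    #
    #   {
    #       "input_ids": (B, T) token ids or None,
    #       "inputs_embeds": None,
    #       "labels": (B, T) labels or None,
    #       "attention_mask": (B, T) mask,
    #       "image_embeds_insertion_points": list of B index tensors,
    #       "image_embeds": (N, S, D) tensor or list of N (S, D) tensors,
    #       "casa_windows_info": {"num_pre_image_tokens": ...,
    #                             "num_post_image_tokens": ...},
    #   }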

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        return_loss: bool = True,
        labels: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        **kwargs: Any,
    ) -> CausalHeliumOutput:
        """Multimodal forward pass"""
        assert input_ids is not None or inputs_embeds is not None

        if self.training:
            assert return_loss is True, (
                "Helium models always compute their own labels/losses in train mode"
            )

        # Only the first generate call runs the image encoder and multimodal
        # preprocessing; later decode steps just embed the new tokens.
        if kwargs.get("__is_first_gen_call__", True):
            processed_inputs = self.prepare_multimodal_inputs(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                image_embeds_insertion_points=image_embeds_insertion_points,
                pixel_values=pixel_values,
                labels=labels,
                pre_image_tokens=pre_image_tokens,
                post_image_tokens=post_image_tokens,
            )
            processed_inputs.pop("inputs_embeds", None)
        else:
            processed_inputs = {
                "inputs_embeds": self.model.embed_tokens(input_ids),
                "attention_mask": attention_mask,
            }

        if (
            not self.config.casa_attention
            and (cp := kwargs.get("cache_position", None)) is not None
            and pixel_values is not None
        ):
            # Without CASA attention the image tokens are inserted into the
            # sequence, so cache positions must account for their count.
            start = cp[0].item()
            # // 4: Qwen2.5-VL merges 2x2 vision patches into a single token.
            num_image_tokens = (pixel_values[0].shape[0] * pixel_values[0].shape[1]) // 4
            num_tokens = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
            kwargs["cache_position"] = torch.arange(
                start + (0 if kwargs.get("__is_first_gen_call__", True) else num_image_tokens),
                start + num_tokens + num_image_tokens,
                dtype=cp.dtype,
                device=cp.device,
            )

        kwargs.pop("__is_first_gen_call__", True)
        return super().forward(
            **processed_inputs,
            **kwargs,
        )

    @torch.no_grad()
    def generate_from_image(
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        reset_streaming: bool = True,
        **kwargs: Any,
    ) -> "GenerateOutput | torch.LongTensor":
        """Generate text from a prompt plus optional images, managing the CASA
        streaming state around the underlying `generate` call.
        """
        assert input_ids is not None and inputs_embeds is None, (
            "Input IDs must be provided for generation"
        )

        if kwargs.get("past_key_values", None) is None:
            kwargs["past_key_values"] = DynamicCache()

        # Default the pad token to the (first) EOS token if none was given.
        if kwargs.get("pad_token_id", None) is None:
            kwargs["pad_token_id"] = kwargs.get("eos_token_id", None)
            if isinstance(kwargs["pad_token_id"], (list, tuple)):
                kwargs["pad_token_id"] = kwargs["pad_token_id"][0]

        self.start_casa_streaming_states()
        outputs = self.generate(
            input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_embeds_insertion_points=image_embeds_insertion_points,
            use_cache=True,
            **kwargs,
        )
        if reset_streaming:
            self.reset_casa_streaming_states()
        return outputs
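
    # A minimal end-to-end generation sketch (hedged: `processor`, the prompt,
    # and `max_new_tokens` are assumptions; only `generate_from_image` itself
    # comes from this module):
    #
    #   inputs = processor(text=prompt, images=[image], return_tensors="pt")
    #   out = model.generate_from_image(
    #       input_ids=inputs["input_ids"],
    #       attention_mask=inputs["attention_mask"],
    #       pixel_values=inputs["pixel_values"],
    #       image_embeds_insertion_points=inputs["image_embeds_insertion_points"],
    #       max_new_tokens=64,
    #   )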

    def reset_casa_streaming_states(self, clean_cache: bool = True) -> None:
        """Disable streaming on every CASA attention layer, optionally freeing its cache"""

        def __reset__(m: torch.nn.Module):
            if isinstance(m, Helium1CASAAttention):
                m._set_streaming(False, ())
                m.reset_streaming()
                if clean_cache:
                    del m.streaming_state.k
                    del m.streaming_state.v
                    del m.streaming_state.casa_handler

        self.apply(__reset__)

    def start_casa_streaming_states(self) -> None:
        """Enable streaming on every CASA attention layer"""

        def __start__(m: torch.nn.Module):
            if isinstance(m, Helium1CASAAttention):
                m._set_streaming(True, ())

        self.apply(__start__)
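
# Streaming note: `start_casa_streaming_states` switches every
# Helium1CASAAttention layer into streaming mode before generation, and
# `reset_casa_streaming_states` tears the per-layer state (k, v, and the CASA
# handler) back down afterwards, mirroring how `generate_from_image` wraps its
# `generate` call.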