# mistral-text-encoder / mistral_text_encoding_core.py
from typing import List, Optional, Sequence, Union

import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration

def format_text_input(prompts: List[str], system_message: Optional[str] = None):
    # Remove [IMG] tokens from the prompts to avoid Pixtral validation issues
    # when truncation is enabled: the processor counts [IMG] tokens and fails
    # if the count changes after truncation.
    cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
    # Prepend the system turn only when a system message is actually provided,
    # so a None system_message does not get serialized into the template.
    system_turn = (
        [{"role": "system", "content": [{"type": "text", "text": system_message}]}]
        if system_message is not None
        else []
    )
    return [
        system_turn + [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
        for prompt in cleaned_txt
    ]
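
# Illustrative only: format_text_input(["a cat"], system_message="Describe.")
# returns one conversation per prompt, e.g.
# [[{"role": "system", "content": [{"type": "text", "text": "Describe."}]},
#   {"role": "user", "content": [{"type": "text", "text": "a cat"}]}]]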

def get_mistral_3_small_prompt_embeds(
    text_encoder: Mistral3ForConditionalGeneration,
    tokenizer: AutoProcessor,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 512,
    system_message: str = (
        "You are an AI that reasons about image descriptions. You give "
        "structured responses focusing on object relationships, object "
        "attribution and actions without speculation."
    ),
    hidden_states_layers: Sequence[int] = (10, 20, 30),
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    # Format input messages
    messages_batch = format_text_input(prompts=prompt, system_message=system_message)

    # Tokenize all conversations at once, padded/truncated to a fixed length
    inputs = tokenizer.apply_chat_template(
        messages_batch,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_sequence_length,
    )

    # Move to device
    input_ids = inputs["input_ids"].to(text_encoder.device)
    attention_mask = inputs["attention_mask"].to(text_encoder.device)

    # Forward pass through the model
    with torch.inference_mode():
        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )
    # Keep only the selected intermediate layers, stacked along a new channel
    # dim: each hidden state is (B, L, H), so `out` is (B, num_layers, L, H).
    out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
    out = out.to(dtype=text_encoder.dtype, device=text_encoder.device)

    # Flatten the stacked layers into the feature dimension:
    # (B, num_layers, L, H) -> (B, L, num_layers * H)
    batch_size, num_channels, seq_len, hidden_dim = out.shape
    prompt_embeds = out.permute(0, 2, 1, 3).reshape(
        batch_size, seq_len, num_channels * hidden_dim
    )
    return prompt_embeds
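
# Illustrative only: with the default hidden_states_layers (10, 20, 30) and a
# model hidden size H, two prompts padded to max_sequence_length=512 come back
# as prompt_embeds of shape (2, 512, 3 * H).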

def prepare_text_ids(
    x: torch.Tensor,  # (B, L, D) or (L, D)
    t_coord: Optional[torch.Tensor] = None,
):
    # Build (t, h, w, l) position coordinates for each token. Text tokens
    # occupy a single "spatial" position, so t, h and w stay at 0 (unless a
    # per-sample t_coord is supplied) and only the sequence index l varies.
    if x.ndim == 2:
        # The signature also accepts unbatched (L, D) input; add a batch dim.
        x = x.unsqueeze(0)
    B, L, _ = x.shape
    out_ids = []
    for i in range(B):
        t = torch.arange(1) if t_coord is None else t_coord[i]
        h = torch.arange(1)
        w = torch.arange(1)
        l = torch.arange(L)
        coords = torch.cartesian_prod(t, h, w, l)
        out_ids.append(coords)
    return torch.stack(out_ids)
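
# Illustrative only: for x of shape (2, 8, 4096) and no t_coord, the returned
# ids have shape (2, 8, 4), with rows (t, h, w, l) = (0, 0, 0, 0), ...,
# (0, 0, 0, 7) for every batch element.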

def encode_prompt(
    text_encoder: Mistral3ForConditionalGeneration,
    tokenizer: AutoProcessor,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 512,
):
    if prompt is None:
        prompt = ""
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt_embeds is None:
        prompt_embeds = get_mistral_3_small_prompt_embeds(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            max_sequence_length=max_sequence_length,
        )

    # Duplicate the embeddings once per requested image, keeping copies for the
    # same prompt adjacent: (B, L, D) -> (B * num_images_per_prompt, L, D)
    batch_size, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    text_ids = prepare_text_ids(prompt_embeds)
    text_ids = text_ids.to(text_encoder.device)
    return prompt_embeds, text_ids
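

if __name__ == "__main__":
    # Minimal usage sketch, not part of the encoder itself. The checkpoint id
    # below is an assumption: substitute whichever Mistral 3 checkpoint this
    # module is meant to be paired with.
    model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"  # assumed
    processor = AutoProcessor.from_pretrained(model_id)
    model = Mistral3ForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )

    embeds, ids = encode_prompt(
        text_encoder=model,
        tokenizer=processor,
        prompt="a photo of an astronaut riding a horse",
        num_images_per_prompt=1,
    )
    # Expected: embeds is (1, 512, 3 * hidden_size), ids is (1, 512, 4).
    print(embeds.shape, ids.shape)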