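The helpers below build packed prompt embeddings from text prompts using Mistral 3 Small as the text encoder: prompts are wrapped in a chat template, run through the model once, and the hidden states from a few intermediate layers are stacked into the final embedding, with matching per-token position ids.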
```python
from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration


def format_text_input(prompts: List[str], system_message: Optional[str] = None):
    # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
    # when truncation is enabled. The processor counts [IMG] tokens and fails
    # if the count changes after truncation.
    cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
    # Only emit a system turn when a system message was actually provided.
    system_turn = (
        [{"role": "system", "content": [{"type": "text", "text": system_message}]}]
        if system_message is not None
        else []
    )
    return [
        system_turn + [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
        for prompt in cleaned_txt
    ]
```
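For reference, here is a minimal sketch of what `format_text_input` produces; the prompt and system message are made up for illustration:

```python
messages = format_text_input(
    prompts=["A red fox [IMG] jumping over a log"],
    system_message="Describe the scene.",
)
# [IMG] is stripped before the chat template ever sees the prompt:
# messages == [
#     [
#         {"role": "system", "content": [{"type": "text", "text": "Describe the scene."}]},
#         {"role": "user", "content": [{"type": "text", "text": "A red fox  jumping over a log"}]},
#     ]
# ]
```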
```python
def get_mistral_3_small_prompt_embeds(
    text_encoder: Mistral3ForConditionalGeneration,
    tokenizer: AutoProcessor,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 512,
    system_message: str = (
        "You are an AI that reasons about image descriptions. You give "
        "structured responses focusing on object relationships, object "
        "attribution and actions without speculation."
    ),
    hidden_states_layers: Tuple[int, ...] = (10, 20, 30),
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    # Format the batch as chat messages (system + user turn per prompt).
    messages_batch = format_text_input(prompts=prompt, system_message=system_message)

    # Tokenize all messages at once with fixed-length padding and truncation.
    inputs = tokenizer.apply_chat_template(
        messages_batch,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_sequence_length,
    )

    # Move inputs to the text encoder's device.
    input_ids = inputs["input_ids"].to(text_encoder.device)
    attention_mask = inputs["attention_mask"].to(text_encoder.device)

    # Forward pass; we only need hidden states, not generated tokens.
    with torch.inference_mode():
        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

    # Keep only the selected intermediate layers and stack them:
    # (batch, num_layers, seq_len, hidden_dim).
    out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
    out = out.to(dtype=text_encoder.dtype, device=text_encoder.device)

    # Concatenate the selected layers along the feature axis:
    # (batch, seq_len, num_layers * hidden_dim).
    batch_size, num_channels, seq_len, hidden_dim = out.shape
    prompt_embeds = out.permute(0, 2, 1, 3).reshape(
        batch_size, seq_len, num_channels * hidden_dim
    )
    return prompt_embeds
```
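The stack-then-reshape at the end concatenates the chosen hidden layers along the feature axis, so with three layers of hidden size H the embeddings come out as (batch, seq_len, 3 * H). A standalone shape check with dummy tensors (no model needed):

```python
import torch

batch, seq_len, hidden = 2, 512, 16
layers = [torch.randn(batch, seq_len, hidden) for _ in range(3)]

out = torch.stack(layers, dim=1)  # (2, 3, 512, 16)
out = out.permute(0, 2, 1, 3).reshape(batch, seq_len, 3 * hidden)
print(out.shape)  # torch.Size([2, 512, 48])
```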
```python
def prepare_text_ids(
    x: torch.Tensor,  # (B, L, D) or (L, D)
    t_coord: Optional[torch.Tensor] = None,
):
    # Accept unbatched input, as documented above.
    if x.ndim == 2:
        x = x.unsqueeze(0)
    B, L, _ = x.shape
    out_ids = []
    for i in range(B):
        # Text tokens share a single (t, h, w) position; only the sequence
        # index l varies, giving one (t, h, w, l) row per token.
        t = torch.arange(1) if t_coord is None else t_coord[i]
        h = torch.arange(1)
        w = torch.arange(1)
        l = torch.arange(L)
        coords = torch.cartesian_prod(t, h, w, l)
        out_ids.append(coords)
    return torch.stack(out_ids)
```
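`prepare_text_ids` assigns each text token a 4-component (t, h, w, l) position id: everything is pinned to a single spatiotemporal location and only the sequence index varies. A quick check:

```python
import torch

ids = prepare_text_ids(torch.zeros(2, 4, 8))  # (B=2, L=4, D=8)
print(ids.shape)  # torch.Size([2, 4, 4])
print(ids[0])
# tensor([[0, 0, 0, 0],
#         [0, 0, 0, 1],
#         [0, 0, 0, 2],
#         [0, 0, 0, 3]])
```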
```python
def encode_prompt(
    text_encoder: Mistral3ForConditionalGeneration,
    tokenizer: AutoProcessor,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 512,
):
    if prompt is None:
        prompt = ""
    prompt = [prompt] if isinstance(prompt, str) else prompt

    # Compute embeddings unless they were precomputed and passed in.
    if prompt_embeds is None:
        prompt_embeds = get_mistral_3_small_prompt_embeds(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            max_sequence_length=max_sequence_length,
        )

    # Duplicate the embeddings for each image generated per prompt.
    batch_size, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    text_ids = prepare_text_ids(prompt_embeds)
    text_ids = text_ids.to(text_encoder.device)
    return prompt_embeds, text_ids
```
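Putting it together, a minimal end-to-end sketch. The checkpoint id here is an assumption, not something the snippet above specifies (any Mistral 3 checkpoint with a chat-capable processor should work); adjust dtype and device to your hardware:

```python
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration

# Assumed checkpoint id, for illustration; substitute the one your pipeline uses.
model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

tokenizer = AutoProcessor.from_pretrained(model_id)
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda"
)

prompt_embeds, text_ids = encode_prompt(
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    prompt="A watercolor painting of a lighthouse at dusk",
    num_images_per_prompt=2,
)
print(prompt_embeds.shape)  # (2, 512, 3 * hidden_dim)
print(text_ids.shape)       # (2, 512, 4)
```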