| | from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM |
| | from .configuration_deepseek_v2 import DeepseekV2Config |
| | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| | from typing import List, Optional, Tuple, Union |
| | from transformers.cache_utils import Cache |
| | import requests |
| | from PIL import Image, ImageOps, ImageDraw, ImageFont |
| | from io import BytesIO |
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import CrossEntropyLoss |
| | from torchvision import transforms |
| | from torchvision.transforms.functional import InterpolationMode |
| | import os |
| | from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector |
| | from addict import Dict |
| | from transformers import TextStreamer |
| | from .conversation import get_conv_template |
| | from abc import ABC |
| | import math |
| | import re |
| | from tqdm import tqdm |
| | import numpy as np |
| | import time |
| |
|
| |
|
def load_image(image_path):
    """Open an image and apply its EXIF orientation, best effort.

    Args:
        image_path: Anything accepted by ``Image.open``.

    Returns:
        The opened image after ``ImageOps.exif_transpose``. If that fails,
        the image re-opened without the orientation fix. ``None`` when the
        image cannot be opened at all.
    """
    try:
        image = Image.open(image_path)
        return ImageOps.exif_transpose(image)
    except Exception as e:
        print(f"error: {e}")

    # Fall back to the raw image. ``except Exception`` (instead of a bare
    # ``except``) lets KeyboardInterrupt / SystemExit propagate.
    try:
        return Image.open(image_path)
    except Exception:
        return None
| |
|
| |
|
def re_match(text):
    """Find ``<|ref|>…<|/ref|><|det|>…<|/det|>`` spans in ``text``.

    Returns:
        A 3-tuple ``(matches, image_spans, other_spans)``:

        * ``matches`` is the raw ``re.findall`` result, one
          ``(full_span, ref_text, det_text)`` tuple per hit.
        * ``image_spans`` holds the full spans that contain
          ``<|ref|>image<|/ref|>``.
        * ``other_spans`` holds every other full span.
    """
    span_pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(span_pattern, text, re.DOTALL)

    image_marker = '<|ref|>image<|/ref|>'
    full_spans = [groups[0] for groups in matches]
    image_spans = [span for span in full_spans if image_marker in span]
    other_spans = [span for span in full_spans if image_marker not in span]
    return matches, image_spans, other_spans
| |
|
| |
|
def extract_coordinates_and_label(ref_text, image_width, image_height):
    """Split one grounding match into its label and parsed coordinates.

    Args:
        ref_text: Indexable match; ``ref_text[1]`` is the label and
            ``ref_text[2]`` is the textual coordinate literal.
        image_width: Unused; kept for interface compatibility.
        image_height: Unused; kept for interface compatibility.

    Returns:
        ``(label_type, cor_list)`` on success, or ``None`` (after printing
        the error) when the match is malformed or the coordinates do not
        parse.
    """
    # Function-scope import so the block does not depend on a file-level one.
    import ast

    try:
        label_type = ref_text[1]
        # SECURITY: the coordinate text comes from generated model output, so
        # it is parsed with ``ast.literal_eval`` rather than ``eval``. Only
        # Python literals are accepted; anything else is reported and dropped.
        cor_list = ast.literal_eval(ref_text[2])
    except Exception as e:
        print(e)
        return None

    return (label_type, cor_list)
| |
|
| |
|
def draw_bounding_boxes(image, refs, ouput_path):
    """Draw labelled boxes for grounding matches and save ``image`` crops.

    Args:
        image: Source image; only read (``size``, ``copy``, ``crop``).
        refs: Iterable of matches, each passed to
            ``extract_coordinates_and_label``.
        ouput_path: Directory prefix; regions labelled ``'image'`` are saved
            as ``{ouput_path}/images/{n}.jpg``.

    Returns:
        A copy of ``image`` with box outlines, label text and a translucent
        fill overlay pasted on top.
    """
    image_width, image_height = image.size

    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Translucent fills go on a separate RGBA layer that is pasted at the end.
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    font = ImageFont.load_default()

    # Counts every 'image' region, whether or not saving the crop succeeded,
    # so file names follow the order of the regions.
    img_idx = 0

    for ref in refs:
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if not result:
                continue
            label_type, points_list = result

            color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
            color_a = color + (20, )
            # Titles get a thicker outline than every other label.
            outline_width = 4 if label_type == 'title' else 2

            for points in points_list:
                x1, y1, x2, y2 = points

                # Coordinates are divided by 999 and scaled to pixel space.
                x1 = int(x1 / 999 * image_width)
                y1 = int(y1 / 999 * image_height)
                x2 = int(x2 / 999 * image_width)
                y2 = int(y2 / 999 * image_height)

                if label_type == 'image':
                    try:
                        cropped = image.crop((x1, y1, x2, y2))
                        cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
                    except Exception as e:
                        print(e)
                    img_idx += 1

                try:
                    draw.rectangle([x1, y1, x2, y2], outline=color, width=outline_width)
                    draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)

                    # The label sits 15 px above the box, clamped to the top.
                    text_x = x1
                    text_y = max(0, y1 - 15)

                    text_bbox = draw.textbbox((0, 0), label_type, font=font)
                    text_width = text_bbox[2] - text_bbox[0]
                    text_height = text_bbox[3] - text_bbox[1]
                    draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
                                   fill=(255, 255, 255, 30))

                    draw.text((text_x, text_y), label_type, font=font, fill=color)
                except Exception as e:
                    # Best effort: one box that cannot be drawn must not stop
                    # the rest, but the failure is reported, not hidden.
                    print(e)
        except Exception:
            # Malformed match (e.g. a box without four values): skip it.
            continue
    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw
| |
|
| |
|
def process_image_with_refs(image, ref_texts, output_path):
    """Return ``draw_bounding_boxes(image, ref_texts, output_path)``.

    This is a thin pass-through kept as a separate entry point.
    """
    return draw_bounding_boxes(image, ref_texts, output_path)
| |
|
| |
|
| |
|
| |
|
| |
|
| | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): |
| | best_ratio_diff = float('inf') |
| | best_ratio = (1, 1) |
| | area = width * height |
| | for ratio in target_ratios: |
| | target_aspect_ratio = ratio[0] / ratio[1] |
| | ratio_diff = abs(aspect_ratio - target_aspect_ratio) |
| | if ratio_diff < best_ratio_diff: |
| | best_ratio_diff = ratio_diff |
| | best_ratio = ratio |
| | elif ratio_diff == best_ratio_diff: |
| | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: |
| | best_ratio = ratio |
| | |
| | return best_ratio |
| |
|
| |
|
def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False):
    """Resize ``image`` to a grid of ``image_size`` tiles and cut it up.

    Candidate grids are all ``(columns, rows)`` pairs whose product lies in
    ``[min_num, max_num]``; the one chosen by ``find_closest_aspect_ratio``
    determines the resize target.

    Args:
        image: Object exposing ``size``, ``resize`` and ``crop``.
        min_num: Smallest allowed number of tiles.
        max_num: Largest allowed number of tiles.
        image_size: Edge length of each square tile.
        use_thumbnail: When true and more than one tile was produced, append
            the whole image resized to ``image_size`` x ``image_size``.

    Returns:
        ``(tiles, grid)``: the tiles in row-major order (plus the optional
        thumbnail) and the selected ``(columns, rows)`` pair.
    """
    source_width, source_height = image.size
    aspect_ratio = source_width / source_height

    # Every (columns, rows) grid with an admissible tile count, ordered by
    # how many tiles it has.
    candidate_grids = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    }
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, source_width, source_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    processed_images = []
    for tile_index in range(blocks):
        row, column = divmod(tile_index, tiles_per_row)
        left = column * image_size
        top = row * image_size
        processed_images.append(
            resized_img.crop((left, top, left + image_size, top + image_size))
        )
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images, target_aspect_ratio
| |
|
| |
|
| |
|
def normalize_transform(mean, std):
    """Build a ``transforms.Normalize`` step, or ``None`` when not needed.

    A missing ``mean`` defaults to zeros and a missing ``std`` to ones, each
    with the length of the other argument. When both are ``None`` no transform
    is created and ``None`` is returned.
    """
    if mean is None and std is None:
        return None

    if mean is None:
        mean = [0.] * len(std)
    elif std is None:
        std = [1.] * len(mean)

    return transforms.Normalize(mean=mean, std=std)
| |
|
| |
|
| |
|
def format_messages(
    conversations: List[Dict[str, str]],
    sft_format: str = "deepseek",
    system_prompt: str = "",
):
    """
    Render a conversation through an SFT conversation template.

    Args:
        conversations (List[Dict]): Messages, each with a "role" and a "content" entry.
        sft_format (str, optional): Template name handed to ``get_conv_template``. Defaults to "deepseek".
        system_prompt (str, optional): System message set on the template. Defaults to "".

    Returns:
        sft_prompt (str): The template's prompt with surrounding whitespace stripped.
    """
    template = get_conv_template(sft_format)
    template.set_system_message(system_prompt)

    for turn in conversations:
        # Message content is stripped before it enters the template.
        template.append_message(turn["role"], turn["content"].strip())

    return template.get_prompt().strip()
| |
|
| |
|
def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
    """Tokenize ``text`` without special tokens, then add BOS/EOS ids.

    The ids are hard-coded: ``0`` is prepended when ``bos`` is true and ``1``
    is appended when ``eos`` is true.

    Returns:
        The token ids from ``tokenizer.encode`` with the requested ids added.
    """
    BOS_ID, EOS_ID = 0, 1

    token_ids = tokenizer.encode(text, add_special_tokens=False)
    token_ids = [BOS_ID] + token_ids if bos else token_ids
    token_ids = token_ids + [EOS_ID] if eos else token_ids
    return token_ids
| |
|
def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
    """
    Load every image referenced by a conversation as an RGB PIL image.

    Args:
        conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
            [
                {
                    "role": "User",
                    "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
                    "images": ["./examples/table_datasets.png"]
                },
                {"role": "Assistant", "content": ""},
            ]

    Returns:
        pil_images (List[PIL.Image.Image]): the list of PIL images, in message order.

    Raises:
        ValueError: if an image path cannot be loaded.
    """
    pil_images = []

    for message in conversations:
        if "images" not in message:
            continue

        for image_path in message["images"]:
            pil_img = load_image(image_path)
            if pil_img is None:
                # load_image returns None on failure; fail here with the path
                # instead of an AttributeError on ``None.convert`` below.
                raise ValueError(f"could not load image: {image_path!r}")
            pil_images.append(pil_img.convert("RGB"))

    return pil_images
| |
|
| |
|
class BaseTransform(ABC):
    """Base class for transforms.

    ``set_rng`` and ``__call__`` are no-ops here and return ``None``; reading
    ``default_shape`` raises ``NotImplementedError``.
    """

    def set_rng(self, *args, **kwargs):
        """Accept and ignore any arguments."""
        return None

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """Accept and ignore any arguments; this base version returns ``None``."""
        return None

    @property
    def default_shape(self):
        """Not provided by the base class."""
        raise NotImplementedError
| |
|
| |
|
class BasicImageTransform(BaseTransform):
    """``ToTensor`` followed by an optional normalization step."""

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True
    ):
        """Compose the pipeline.

        Args:
            mean: Per-channel mean, stored on ``self.mean``.
            std: Per-channel std, stored on ``self.std``.
            normalize: When true the second step comes from
                ``normalize_transform(mean, std)`` (omitted if that returns
                ``None``); when false an ``nn.Identity`` is appended instead.
        """
        self.mean = mean
        self.std = std

        steps = [transforms.ToTensor()]

        second_step = normalize_transform(mean, std) if normalize else nn.Identity()
        if second_step is not None:
            steps.append(second_step)

        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        """Run ``x`` through the composed pipeline and return the result."""
        return self.transform(x)
| |
|
class NoEOSTextStreamer(TextStreamer):
    """Text streamer that prints a newline in place of the EOS token text."""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Print ``text`` with each decoded EOS token replaced by a newline.

        Output is flushed immediately and no trailing newline is added.
        """
        eos_id = self.tokenizer.eos_token_id
        eos_text = self.tokenizer.decode([eos_id], skip_special_tokens=False)
        print(text.replace(eos_text, "\n"), flush=True, end="")
| |
|
| |
|
class DeepseekOCRConfig(DeepseekV2Config):
    """``DeepseekV2Config`` with ``model_type`` overridden to ``"DeepseekOCR"``."""
    model_type = "DeepseekOCR"
| |
|
class DeepseekOCRModel(DeepseekV2Model):
    """``DeepseekV2Model`` that splices projected image features into the
    token embeddings before running the decoder.

    Image pixels go through ``sam_model`` and ``vision_model``; their features
    are concatenated, projected by ``projector`` and written into
    ``inputs_embeds`` at the positions selected by ``images_seq_mask``.
    """
    config_class = DeepseekOCRConfig

    def __init__(self, config: DeepseekV2Config):
        super(DeepseekOCRModel, self).__init__(config)

        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()

        n_embed = 1280
        self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        # Learned embeddings appended after each feature row / after each
        # image. The attribute names (including the ``seperator`` spelling)
        # are kept because they are parameter names.
        self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)

    def _encode_view(self, sam_model, vision_model, pixels):
        """Run both encoders on ``pixels`` and project the fused features."""
        sam_features = sam_model(pixels)
        clip_features = vision_model(pixels, sam_features)
        # ``[:, 1:]`` drops the first position along dim 1 — presumably a
        # class token; confirm against ``build_clip_l``.
        fused = torch.cat(
            (clip_features[:, 1:], sam_features.flatten(2).permute(0, 2, 1)), dim=-1
        )
        return self.projector(fused)

    def _append_newline_and_flatten(self, grid):
        """Append ``image_newline`` to every row of a (rows, cols, dim) grid
        and flatten the result to (rows * (cols + 1), dim)."""
        rows, _, dim = grid.shape
        newline_column = self.image_newline[None, None, :].expand(rows, 1, dim)
        return torch.cat([grid, newline_column], dim=1).view(-1, dim)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed tokens, inject image features, and run the decoder.

        Args:
            images: Sequence with one ``(patches, image_ori)`` pair per batch
                item. All-zero tensors mean "no image" / "no patches".
            images_seq_mask: Per batch item, a boolean mask over sequence
                positions that receive image features.
            images_spatial_crop: Per batch item, ``(width_crop_num,
                height_crop_num)`` describing the patch grid.
            Remaining arguments are forwarded to ``DeepseekV2Model.forward``.

        Returns:
            Whatever ``DeepseekV2Model.forward`` returns for the resulting
            ``inputs_embeds``.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        sam_model = getattr(self, 'sam_model', None)
        vision_model = getattr(self, 'vision_model', None)

        # Use the embeddings' length when only ``inputs_embeds`` was given, so
        # the check below no longer dereferences a missing ``input_ids``.
        seq_len = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Vision features are only injected when there is a real image and we
        # are not on a single-token step (unless training).
        inject_images = (
            sam_model is not None
            and images is not None
            and (seq_len != 1 or self.training)
            and torch.sum(images[0][1]).item() != 0
        )

        if inject_images:
            for idx, (image, crop_shape) in enumerate(zip(images, images_spatial_crop)):
                patches = image[0]
                image_ori = image[1]

                with torch.no_grad():
                    if torch.sum(patches).item() != 0:
                        local_features = self._encode_view(sam_model, vision_model, patches)
                        global_features = self._encode_view(sam_model, vision_model, image_ori)

                        print('=====================')
                        print('BASE: ', global_features.shape)
                        print('PATCHES: ', local_features.shape)
                        print('=====================')

                        # The feature count per view is treated as a square
                        # grid (side = sqrt of the count).
                        _, hw, n_dim = global_features.shape
                        h = w = int(hw ** 0.5)

                        _2, hw2, n_dim2 = local_features.shape
                        h2 = w2 = int(hw2 ** 0.5)

                        width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]

                        global_features = self._append_newline_and_flatten(
                            global_features.view(h, w, n_dim)
                        )

                        # Stitch the per-patch grids into one large grid whose
                        # rows run across all patches in a patch row.
                        local_features = local_features.view(
                            height_crop_num, width_crop_num, h2, w2, n_dim2
                        ).permute(0, 2, 1, 3, 4).reshape(
                            height_crop_num * h2, width_crop_num * w2, n_dim2
                        )
                        local_features = self._append_newline_and_flatten(local_features)

                        global_local_features = torch.cat(
                            [local_features, global_features, self.view_seperator[None, :]], dim=0
                        )
                    else:
                        global_features = self._encode_view(sam_model, vision_model, image_ori)
                        print('=====================')
                        print('BASE: ', global_features.shape)
                        print('NO PATCHES')
                        print('=====================')
                        _, hw, n_dim = global_features.shape
                        h = w = int(hw ** 0.5)

                        global_features = self._append_newline_and_flatten(
                            global_features.view(h, w, n_dim)
                        )

                        global_local_features = torch.cat(
                            [global_features, self.view_seperator[None, :]], dim=0
                        )

                # Write the features into the masked positions. The mask is
                # moved to the embeddings' device rather than to CUDA.
                scatter_mask = images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device)
                inputs_embeds[idx].masked_scatter_(scatter_mask, global_local_features)

        return super(DeepseekOCRModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
| | |
| |
|
class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
    """Causal-LM head on top of ``DeepseekOCRModel``, plus an ``infer``
    helper that runs one prompt (optionally with one image) end to end."""

    config_class = DeepseekOCRConfig

    def __init__(self, config):
        # Calls the initializer of DeepseekV2ForCausalLM's parent, not
        # DeepseekV2ForCausalLM's own — presumably so the parent does not
        # build its own DeepseekV2Model first (confirm against
        # modeling_deepseekv2).
        super(DeepseekV2ForCausalLM, self).__init__(config)
        self.model = DeepseekOCRModel(config)

        self.vocab_size = config.vocab_size

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_model(self):
        """Return the wrapped ``DeepseekOCRModel``."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the model and the LM head; compute the loss when ``labels`` is given.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` resolves to true,
            otherwise a tuple ``(loss?, logits, *model_outputs[1:])``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Position i predicts token i + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)

            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the keyword arguments for one generation step.

        With a cache present, ``input_ids`` is cut down to the tokens not yet
        processed and ``attention_mask`` is cropped to the cache's maximum
        length when that would be exceeded. ``position_ids`` is derived from
        ``attention_mask`` when it is not supplied. The image arguments in
        ``kwargs`` are passed through unchanged.
        """
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # A mask longer than input_ids means part of the input lives only
            # in the cache; keep the unprocessed tail.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # Otherwise input_ids holds everything; drop what is cached.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Positions count attended tokens; masked slots get a dummy 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # The original also built a ``cache_position`` tensor here that was
        # never used and raised when ``position_ids`` was None; it is dropped.

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
                "images_seq_mask": kwargs.get("images_seq_mask", None),
                "images_spatial_crop": kwargs.get("images_spatial_crop", None),
            }
        )
        return model_inputs

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.

        NOTE: this replaces ``reset_parameters`` on the ``torch.nn.Linear`` and
        ``torch.nn.LayerNorm`` classes themselves, so it affects every such
        layer created afterwards in the process.
        """
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    @staticmethod
    def _save_geometry_plot(outputs, output_path):
        """Plot the line description in ``outputs`` to ``{output_path}/geo.jpg``.

        ``outputs`` must be a Python literal with a ``'Line'`` entry holding
        ``'line'`` and ``'line_endpoint'``. The text is model output, so it is
        parsed with ``ast.literal_eval`` instead of ``eval``. If it does not
        parse, the plot is skipped and the error printed. Solid and dashed
        line types are not drawn differently.
        """
        import ast
        import matplotlib.pyplot as plt

        parse_errors = (ValueError, SyntaxError, KeyError, IndexError, TypeError,
                        MemoryError, RecursionError)
        try:
            geometry = ast.literal_eval(outputs)['Line']
            lines = geometry['line']
            endpoints = geometry['line_endpoint']
        except parse_errors as e:
            print(e)
            return

        fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
        ax.set_xlim(-15, 15)
        ax.set_ylim(-15, 15)

        for line in lines:
            try:
                p0 = ast.literal_eval(line.split(' -- ')[0])
                p1 = ast.literal_eval(line.split(' -- ')[-1])

                ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
                ax.scatter(p0[0], p0[1], s=5, color='k')
                ax.scatter(p1[0], p1[1], s=5, color='k')
            except Exception:
                # Best effort: skip segments that cannot be parsed or drawn.
                continue

        for endpoint in endpoints:
            try:
                label = endpoint.split(': ')[0]
                (x, y) = ast.literal_eval(endpoint.split(': ')[1])
            except parse_errors:
                continue
            ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                        fontsize=5, fontweight='light')

        plt.savefig(f'{output_path}/geo.jpg')
        plt.close()

    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
        """Run generation for one prompt and, optionally, one image.

        Args:
            tokenizer: Tokenizer with ``encode``, ``decode`` and ``eos_token_id``.
            prompt: User text; an ``<image>`` tag marks where image tokens go.
            image_file: Path of the image to load; empty for text-only.
            output_path: Directory for saved results. When non-empty, it and
                its ``images`` subdirectory are created. When empty and
                ``save_results`` is set, the current directory is used.
            base_size: Edge length of the padded global view in crop mode.
            image_size: Edge length of local crops (crop mode) or of the
                single view (non-crop mode).
            crop_mode: Use a global view plus tiled local views for images
                larger than 640 px on either side.
            test_compress: Print token statistics after generation.
            save_results: Write ``result.mmd``, ``result_with_boxes.jpg``,
                cropped image regions and (if present) ``geo.jpg``.
            eval_mode: Disable streaming and return the decoded text.

        Returns:
            The decoded, stripped output text when ``eval_mode`` is set and the
            prompt contains ``<image>``; otherwise ``None``.

        Raises:
            ValueError: if ``prompt`` is empty.
        """
        self.disable_torch_init()

        # ``os.makedirs('')`` raises, so the default empty path must not be
        # passed through as-is.
        if save_results and not output_path:
            output_path = '.'
        if output_path:
            os.makedirs(output_path, exist_ok=True)
            os.makedirs(f'{output_path}/images', exist_ok=True)

        if not prompt:
            # Explicit exception: an ``assert`` would vanish under ``python -O``.
            raise ValueError('prompt is none!')

        user_message = {"role": "<|User|>", "content": f'{prompt}'}
        if image_file:
            user_message["images"] = [f'{image_file}']
        conversation = [user_message, {"role": "<|Assistant|>", "content": ""}]

        prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')

        patch_size = 16
        downsample_ratio = 4
        images = load_pil_images(conversation)

        valid_img_tokens = 0
        ratio = 1
        image_draw = None
        w = h = 0
        if images:
            # Text-only prompts have no image; indexing images[0]
            # unconditionally used to raise IndexError.
            image_draw = images[0].copy()
            w, h = image_draw.size
            # 1.0 for a square image, shrinking as the sides diverge.
            ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))

        image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
        pad_color = tuple(int(x * 255) for x in image_transform.mean)

        image_token = '<image>'
        image_token_id = 128815
        text_splits = prompt.split(image_token)

        images_list, images_crop_list, images_seq_mask = [], [], []
        tokenized_str = []
        images_spatial_crop = []
        for text_sep, image in zip(text_splits, images):
            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            if crop_mode:
                if image.size[0] <= 640 and image.size[1] <= 640:
                    crop_ratio = [1, 1]
                else:
                    images_crop_raw, crop_ratio = dynamic_preprocess(image)

                # Global view: whole image padded to base_size x base_size.
                global_view = ImageOps.pad(image, (base_size, base_size), color=pad_color)

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)

                images_list.append(image_transform(global_view).to(torch.bfloat16))

                width_crop_num, height_crop_num = crop_ratio
                images_spatial_crop.append([width_crop_num, height_crop_num])
                has_local_views = width_crop_num > 1 or height_crop_num > 1

                if has_local_views:
                    # Local views: the tiles from dynamic_preprocess.
                    for crop in images_crop_raw:
                        images_crop_list.append(image_transform(crop).to(torch.bfloat16))

                if image_size == 640:
                    valid_img_tokens += len(images_crop_list) * 100

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)

                # Placeholders: one per feature plus one per row for the
                # global view, then one extra, then the same row layout for
                # the local views.
                tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
                tokenized_image += [image_token_id]
                if has_local_views:
                    tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
                        num_queries * height_crop_num)
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

            else:
                # Single global view at image_size x image_size.
                if image_size <= 640:
                    print('directly resize')
                    image = image.resize((image_size, image_size))

                global_view = ImageOps.pad(image, (image_size, image_size), color=pad_color)
                images_list.append(image_transform(global_view).to(torch.bfloat16))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)
                elif base_size == 640:
                    valid_img_tokens += int(100 * 1)
                elif base_size == 512:
                    valid_img_tokens += int(64 * 1)

                width_crop_num, height_crop_num = 1, 1
                images_spatial_crop.append([width_crop_num, height_crop_num])

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)

                tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
                tokenized_image += [image_token_id]
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

        # Text after the last image tag (or the whole prompt if there is none).
        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        # Prepend the BOS token (id 0).
        bos_id = 0
        tokenized_str = [bos_id] + tokenized_str
        images_seq_mask = [False] + images_seq_mask

        input_ids = torch.LongTensor(tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        if len(images_list) == 0:
            # All-zero placeholders: the model's forward treats them as
            # "no image".
            images_ori = torch.zeros((1, 3, image_size, image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
            images_crop = torch.zeros((1, 3, base_size, base_size))
        else:
            images_ori = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0)
            else:
                images_crop = torch.zeros((1, 3, base_size, base_size))

        with torch.autocast("cuda", dtype=torch.bfloat16):
            with torch.no_grad():
                generate_kwargs = dict(
                    images=[(images_crop.cuda(), images_ori.cuda())],
                    images_seq_mask=images_seq_mask.unsqueeze(0).cuda(),
                    images_spatial_crop=images_spatial_crop,
                    temperature=0.0,
                    eos_token_id=tokenizer.eos_token_id,
                    max_new_tokens=8192,
                    # The two modes differ only in this value and in streaming.
                    no_repeat_ngram_size=35 if eval_mode else 20,
                    use_cache=True,
                )
                if not eval_mode:
                    generate_kwargs["streamer"] = NoEOSTextStreamer(
                        tokenizer, skip_prompt=True, skip_special_tokens=False)
                output_ids = self.generate(input_ids.unsqueeze(0).cuda(), **generate_kwargs)

        has_image_tag = '<image>' in conversation[0]['content']
        prompt_length = input_ids.shape[0]
        stop_str = '<|end▁of▁sentence|>'

        if has_image_tag and eval_mode:
            outputs = tokenizer.decode(output_ids[0, prompt_length:])
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            return outputs

        if has_image_tag and test_compress and image_draw is not None:
            outputs = tokenizer.decode(output_ids[0, prompt_length:])
            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
            print('='*50)
            print('image size: ', (w, h))
            print('valid image tokens: ', int(valid_img_tokens))
            print('output texts tokens (valid): ', pure_texts_outputs_token_length)
            print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2))
            print('='*50)

        if has_image_tag and save_results and image_draw is not None:
            outputs = tokenizer.decode(output_ids[0, prompt_length:])

            print('='*15 + 'save results:' + '='*15)

            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            matches_ref, matches_images, mathes_other = re_match(outputs)

            result = process_image_with_refs(image_draw, matches_ref, output_path)

            # Replace each image grounding span with a markdown link to the
            # crop saved under ``images/``.
            for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
                outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')

            # Remove the remaining grounding spans and normalize two LaTeX
            # macros.
            for a_match_other in tqdm(mathes_other, desc="other"):
                outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')

            with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile:
                afile.write(outputs)

            if 'line_type' in outputs:
                self._save_geometry_plot(outputs, output_path)

            result.save(f"{output_path}/result_with_boxes.jpg")
| |
|