from torch import nn
from transformers import AutoProcessor, CLIPVisionModel, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
import torch.nn.functional as F


class ToyModel(nn.Module):
    """ CLIP + GPT2 """
    def __init__(self, vision_model_path, language_model_path):
        super().__init__()
        # load vision encoder
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_model_path)
        self.processor = AutoProcessor.from_pretrained(vision_model_path)
        # load language model
        self.language_model = AutoModelForCausalLM.from_pretrained(language_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(language_model_path)
        # MLP connector: projects CLIP hidden states (768-d for ViT-B/32)
        # into the GPT-2 embedding space (also 768-d)
        self.mlp = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Linear(768, 768),
            nn.ReLU()
        )

    def encode_image(self, image):
        inputs = self.processor(images=image, return_tensors="pt")
        return self.vision_encoder(**inputs)

    def encode_text(self, prompt):
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        prompt_embeddings = self.language_model.get_input_embeddings()(input_ids)
        return prompt_embeddings

    @torch.no_grad()
    def chat(self, image, text):
        # encode image
        outputs = self.encode_image(image)
        image_embeddings = outputs.last_hidden_state
        # encode text
        text_embeddings = self.encode_text(text)
        # embedding fusion: project the image tokens and prepend them to the text tokens
        image_embeddings = self.mlp(image_embeddings)
        embedding = torch.cat((image_embeddings, text_embeddings), dim=1)
        outputs = self.language_model(inputs_embeds=embedding)
        # decode logits to text (greedy argmax at every position in a single forward pass)
        logits = outputs.logits
        preds = F.softmax(logits, dim=-1).argmax(dim=-1)
        text_output = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
        return text_output


if __name__ == '__main__':
    model = ToyModel('/home/yuan/huggingface/model/clip-vit-base-patch32',
                     '/home/yuan/huggingface/model/gpt2')
    image = Image.open('/home/yuan/RS-VL-Perception/examples_v2/thief.png')
    text = 'I am Iron Man'
    print(model.chat(image, text))
    # [",....\n.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n 37\n\n\n 40 40 40 40\n'm a Man,"]
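    # Note: gibberish output is expected here. The MLP connector is randomly
    # initialized and has never been trained, so the projected image tokens are
    # noise to GPT-2, and the argmax above decodes one token per input position
    # in a single forward pass rather than generating autoregressively.
    # Training the connector on image-text pairs (and decoding with
    # language_model.generate) would be the natural next step.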