| import torch | |
| from open_clip import create_model | |
| from transformers import PretrainedConfig, PreTrainedModel, CLIPProcessor | |
| from transformers.models.clip.modeling_clip import CLIPOutput | |
| from typing import Optional, Tuple, Union | |
class MarqoFashionCLIPConfig(PretrainedConfig):
    """HF-style config whose only job is to record which open_clip model to build.

    Attributes:
        open_clip_model_name: identifier passed to ``open_clip.create_model``.
    """

    def __init__(
        self,
        open_clip_model_name: str = "",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Everything except the open_clip model name is delegated to the base class.
        self.open_clip_model_name = open_clip_model_name
class MarqoFashionCLIP(PreTrainedModel):
    """transformers-compatible wrapper exposing an open_clip CLIP model.

    The underlying model is created from ``config.open_clip_model_name`` and put
    in eval mode at construction time; all encoders run under
    ``torch.inference_mode()``, so this wrapper is inference-only.
    """

    config_class = MarqoFashionCLIPConfig

    def __init__(self, config: MarqoFashionCLIPConfig):
        super().__init__(config)
        self.config = config
        # output_dict=True makes the open_clip model return dict outputs;
        # only the encode_image / encode_text entry points are used here.
        self.model = create_model(config.open_clip_model_name, output_dict=True)
        self.model.to(self.device)
        self.model.eval()

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        normalize: bool = False,
        **kwargs,
    ) -> torch.FloatTensor:
        """Encode preprocessed images; optionally L2-normalize the embeddings.

        Extra keyword arguments are accepted for API compatibility and ignored.
        """
        with torch.inference_mode():
            return self.model.encode_image(pixel_values, normalize=normalize)

    def get_text_features(
        self,
        input_ids: torch.Tensor,
        normalize: bool = False,
        **kwargs,
    ) -> torch.FloatTensor:
        """Encode token ids; optionally L2-normalize the embeddings.

        Extra keyword arguments are accepted for API compatibility and ignored.
        """
        with torch.inference_mode():
            return self.model.encode_text(input_ids, normalize=normalize)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
        """Embed both modalities and return their pairwise similarity logits.

        Returns a ``CLIPOutput`` when ``return_dict`` is truthy, otherwise the
        tuple ``(logits_per_image, logits_per_text, text_embeds, image_embeds)``.
        """
        image_embeds = self.get_image_features(pixel_values=pixel_values, normalize=True)
        text_embeds = self.get_text_features(input_ids=input_ids, normalize=True)
        # Both sides are unit-normalized, so the matmul is cosine similarity.
        # NOTE(review): unlike HF CLIPModel, no logit_scale is applied — confirm
        # this is intentional for this checkpoint before using the logits for loss.
        logits_per_text = text_embeds @ image_embeds.T
        logits_per_image = logits_per_text.T
        if return_dict:
            return CLIPOutput(
                logits_per_image=logits_per_image,
                logits_per_text=logits_per_text,
                text_embeds=text_embeds,
                image_embeds=image_embeds,
            )
        return logits_per_image, logits_per_text, text_embeds, image_embeds