Spaces:
Sleeping
Sleeping
| from typing import List | |
| from PIL.Image import Image | |
| import torch | |
| from transformers import AutoModel, AutoProcessor | |
| from .utils import normalize_vectors | |
MODEL_NAME = "Marqo/marqo-fashionCLIP"


class FashionCLIPEncoder:
    """Embed text and images with the Marqo FashionCLIP model.

    The processor and model are loaded with
    ``from_pretrained(MODEL_NAME, trust_remote_code=True)``, moved to
    ``self.device`` and switched to eval mode. Each ``encode_*`` method returns
    one embedding per input as a ``List[float]``, optionally passed through
    ``normalize_vectors`` first.
    """

    def __init__(
        self,
        normalize: bool = False,
        *,
        device: "str | torch.device | None" = None,
    ):
        """Load the processor and model.

        Args:
            normalize: When true, embeddings are passed through
                ``normalize_vectors`` before being returned.
            device: Device to run the model on (anything ``torch.device``
                accepts, e.g. ``"cuda"``). ``None`` keeps the previous
                default of CPU.
        """
        self.normalize = normalize
        # Default to CPU so existing callers that pass no device are unaffected.
        self.device = torch.device("cpu" if device is None else device)
        self.processor = AutoProcessor.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
        )
        self.model = AutoModel.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
        )
        self.model.to(self.device)
        self.model.eval()

    def encode_text(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per string in ``texts``.

        Texts are padded to the processor's maximum length and truncated if
        longer. An empty ``texts`` returns ``[]`` without calling the
        processor or the model.
        """
        if not texts:
            # Nothing to embed: avoid sending an empty batch through the
            # processor and model.
            return []
        inputs = self.processor(
            text=texts,
            padding="max_length",
            return_tensors="pt",
            truncation=True,
        )
        return self._embed(inputs, self.model.get_text_features)

    def encode_images(self, images: List["Image"]) -> List[List[float]]:
        """Return one embedding per image in ``images``.

        An empty ``images`` returns ``[]`` without calling the processor or
        the model.
        """
        if not images:
            # Nothing to embed: avoid sending an empty batch through the
            # processor and model.
            return []
        inputs = self.processor(images=images, return_tensors="pt")
        return self._embed(inputs, self.model.get_image_features)

    def _embed(self, inputs, feature_fn) -> List[List[float]]:
        """Move processor outputs to ``self.device``, run ``feature_fn`` on them
        without gradient tracking, and post-process the resulting tensor."""
        with torch.no_grad():
            batch = {key: value.to(self.device) for key, value in inputs.items()}
            vectors = feature_fn(**batch)
        return self._postprocess_vectors(vectors)

    def _postprocess_vectors(self, vectors: torch.Tensor) -> List[List[float]]:
        """Optionally normalize ``vectors``, then convert them to nested Python
        lists on the CPU."""
        if self.normalize:
            vectors = normalize_vectors(vectors)
        return vectors.detach().cpu().numpy().tolist()