pediot committed on
Commit
5cc3519
·
1 Parent(s): fdf1598

normalize vectors

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. src/encoder.py +17 -11
  3. src/utils.py +9 -0
app.py CHANGED
@@ -4,7 +4,7 @@ from src.encoder import FashionCLIPEncoder
4
  from src.models import TextRequest, ImageRequest, Response
5
 
6
 
7
- encoder = FashionCLIPEncoder()
8
  app = FastAPI()
9
 
10
 
 
4
  from src.models import TextRequest, ImageRequest, Response
5
 
6
 
7
+ encoder = FashionCLIPEncoder(normalize=True)
8
  app = FastAPI()
9
 
10
 
src/encoder.py CHANGED
@@ -4,12 +4,16 @@ from PIL.Image import Image
4
  import torch
5
  from transformers import AutoModel, AutoProcessor
6
 
 
 
7
 
8
  MODEL_NAME = "Marqo/marqo-fashionCLIP"
9
 
10
 
11
  class FashionCLIPEncoder:
12
- def __init__(self):
 
 
13
  self.device = torch.device("cpu")
14
 
15
  self.processor = AutoProcessor.from_pretrained(
@@ -31,24 +35,26 @@ class FashionCLIPEncoder:
31
  "return_tensors": "pt",
32
  "truncation": True,
33
  }
 
34
  inputs = self.processor(text=texts, **kwargs)
35
 
36
  with torch.no_grad():
37
  batch = {k: v.to(self.device) for k, v in inputs.items()}
38
- return self._encode_text(batch)
 
 
39
 
40
  def encode_images(self, images: List[Image]) -> List[List[float]]:
41
- kwargs = {
42
- "return_tensors": "pt",
43
- }
44
- inputs = self.processor(images=images, **kwargs)
45
 
46
  with torch.no_grad():
47
  batch = {k: v.to(self.device) for k, v in inputs.items()}
48
- return self._encode_images(batch)
 
 
49
 
50
- def _encode_text(self, batch: Dict) -> List[List[float]]:
51
- return self.model.get_text_features(**batch).detach().cpu().numpy().tolist()
 
52
 
53
- def _encode_images(self, batch: Dict) -> List[List[float]]:
54
- return self.model.get_image_features(**batch).detach().cpu().numpy().tolist()
 
4
  import torch
5
  from transformers import AutoModel, AutoProcessor
6
 
7
+ from .utils import normalize_vectors
8
+
9
 
10
  MODEL_NAME = "Marqo/marqo-fashionCLIP"
11
 
12
 
13
  class FashionCLIPEncoder:
14
+ def __init__(self, normalize: bool = False):
15
+ self.normalize = normalize
16
+
17
  self.device = torch.device("cpu")
18
 
19
  self.processor = AutoProcessor.from_pretrained(
 
35
  "return_tensors": "pt",
36
  "truncation": True,
37
  }
38
+
39
  inputs = self.processor(text=texts, **kwargs)
40
 
41
  with torch.no_grad():
42
  batch = {k: v.to(self.device) for k, v in inputs.items()}
43
+ vectors = self.model.get_text_features(**batch)
44
+
45
+ return self._postprocess_vectors(vectors)
46
 
47
def encode_images(self, images: List[Image]) -> List[List[float]]:
    """Encode a batch of PIL images into embedding vectors.

    Runs the processor to build tensor inputs, moves them onto the
    encoder's device, extracts image features with gradients disabled,
    and converts the result to plain Python lists via the shared
    post-processing step (which applies optional L2 normalization).
    """
    processed = self.processor(images=images, return_tensors="pt")

    with torch.no_grad():
        model_inputs = {name: tensor.to(self.device) for name, tensor in processed.items()}
        features = self.model.get_image_features(**model_inputs)

    return self._postprocess_vectors(features)
55
 
56
+ def _postprocess_vectors(self, vectors: torch.Tensor) -> List[List[float]]:
57
+ if self.normalize:
58
+ vectors = normalize_vectors(vectors)
59
 
60
+ return vectors.detach().cpu().numpy().tolist()
 
src/utils.py CHANGED
@@ -20,6 +20,15 @@ def download_image_as_pil(url: str, timeout: int = 10) -> Image.Image:
20
 
21
  except Exception as e:
22
  return
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  def analyze_model_parameters(model: torch.nn.Module) -> Dict:
 
20
 
21
  except Exception as e:
22
  return
23
+
24
+
25
def normalize_vectors(vectors: torch.Tensor) -> torch.Tensor:
    """L2-normalize each row of a 2-D tensor of embedding vectors.

    Rows whose norm is at or below 1e-8 are left unscaled (divided by
    1.0) so that (near-)zero vectors do not produce NaN/inf from a
    division by zero.

    Args:
        vectors: Tensor of shape (batch, dim).

    Returns:
        Tensor of the same shape, with each non-degenerate row scaled
        to unit L2 norm.
    """
    # Fix: the original computed the identical torch.norm twice in a
    # row; the second call was redundant dead work. Compute it once.
    norms = torch.norm(vectors, p=2, dim=1, keepdim=True)
    # Replace near-zero norms with 1.0 to guard the division below.
    norms = torch.where(norms > 1e-8, norms, torch.ones_like(norms))
    return vectors / norms
32
 
33
 
34
  def analyze_model_parameters(model: torch.nn.Module) -> Dict: