manevamarija committed on
Commit
2c7a9e7
·
verified ·
1 Parent(s): eef716f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +25 -11
handler.py CHANGED
@@ -3,47 +3,61 @@ from PIL import Image
3
  from io import BytesIO
4
  import base64
5
  import torch
 
6
  from transformers import CLIPProcessor, CLIPModel
7
 
 
8
  class EndpointHandler():
9
  def __init__(self, path=""):
10
  self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
11
  self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
12
  self.model.eval()
13
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
15
  """
16
  Args:
17
  data: {
18
  "inputs": {
19
- "image": base64 string,
20
- "candiates": list of strings
21
  }
22
  }
23
 
24
  Returns:
25
- List of dicts with raw cosine similarity scores (not softmax probabilities).
26
  """
27
  inputs = data.get("inputs", data)
28
 
29
  # Decode and process image
30
  image = Image.open(BytesIO(base64.b64decode(inputs["image"]))).convert("RGB")
31
- categories = inputs["candiates"]
32
 
33
- # Get image and text features
34
- processed = self.processor(text=categories, images=image, return_tensors="pt", padding=True)
 
35
  with torch.no_grad():
36
  image_features = self.model.get_image_features(processed["pixel_values"])
37
  text_features = self.model.get_text_features(processed["input_ids"], attention_mask=processed["attention_mask"])
38
 
39
- # Normalize (L2) to get cosine similarity
40
  image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
41
  text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
42
 
43
- similarity = (image_features @ text_features.T).squeeze(0) # shape: (num_labels,)
 
44
 
45
- # Format output with raw cosine scores
46
- result = [{"label": label, "score": score.item()} for label, score in zip(categories, similarity)]
47
  result = sorted(result, key=lambda x: x["score"], reverse=True)
48
- return result
49
 
 
 
3
  from io import BytesIO
4
  import base64
5
  import torch
6
+ import csv
7
  from transformers import CLIPProcessor, CLIPModel
8
 
9
+
10
class EndpointHandler():
    """Inference endpoint for zero-shot image classification with CLIP.

    Loads the category vocabulary from ``categories.csv`` once at startup,
    pre-computes the L2-normalized text embeddings for every category, and
    at request time only embeds the incoming image before ranking all
    categories by cosine similarity.
    """

    # Number of top-scoring categories returned per request.
    TOP_K = 20

    def __init__(self, path=""):
        # NOTE(review): model id is hard-coded; `path` is accepted for
        # inference-endpoint API compatibility but ignored — confirm intentional.
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.model.eval()

        # Load categories once; they are fixed for the handler's lifetime.
        self.categories = self.load_categories_from_csv("categories.csv")
        if not self.categories:
            # Fail fast with a clear message instead of a confusing
            # tokenizer/processor error on the first request.
            raise ValueError("categories.csv produced no categories")

        # The category set never changes, so run the text tower once here
        # instead of on every request.
        text_inputs = self.processor(text=self.categories, return_tensors="pt", padding=True)
        with torch.no_grad():
            text_features = self.model.get_text_features(
                text_inputs["input_ids"], attention_mask=text_inputs["attention_mask"]
            )
        # L2-normalize so a plain dot product later equals cosine similarity.
        self.text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

    def load_categories_from_csv(self, filepath: str) -> List[str]:
        """Read one category per row (first column) from *filepath*.

        Blank rows are skipped; surrounding whitespace is stripped.
        """
        categories = []
        with open(filepath, newline='', encoding='utf-8') as csvfile:
            reader = csv.reader(csvfile)
            for row in reader:
                if row:
                    categories.append(row[0].strip())
        return categories

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data: {
                "inputs": {
                    "image": base64-encoded image bytes
                }
            }

        Returns:
            Top 20 categories with highest similarity score, as
            [{"label": str, "score": float}, ...] sorted descending.
        """
        inputs = data.get("inputs", data)

        # Decode and process image
        image = Image.open(BytesIO(base64.b64decode(inputs["image"]))).convert("RGB")
        processed = self.processor(images=image, return_tensors="pt")

        with torch.no_grad():
            image_features = self.model.get_image_features(processed["pixel_values"])

        # Normalize the image embedding; text embeddings were normalized in __init__.
        image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)

        # Cosine similarity against every category, shape: (num_categories,)
        similarity = (image_features @ self.text_features.T).squeeze(0)

        # Prepare result
        result = [
            {"label": label, "score": score.item()}
            for label, score in zip(self.categories, similarity)
        ]
        result.sort(key=lambda x: x["score"], reverse=True)
        return result[:self.TOP_K]  # Return top 20