akhaliq (HF Staff) committed
Commit bc62515 · verified · 1 Parent(s): c9de2d4

Update app.py

Files changed (1):
  1. app.py +34 -53
app.py CHANGED
@@ -3,8 +3,8 @@ import gradio as gr
 import torch
 import numpy as np
 from PIL import Image
-import matplotlib
 from transformers import Sam3Processor, Sam3Model
+import requests
 import warnings
 warnings.filterwarnings("ignore")
 
@@ -13,44 +13,21 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 model = Sam3Model.from_pretrained("facebook/sam3", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)
 processor = Sam3Processor.from_pretrained("facebook/sam3")
 
-def overlay_masks(image: Image.Image, masks: torch.Tensor) -> Image.Image:
-    """
-    Overlay segmentation masks on the input image using rainbow colormap.
-    """
-    image = image.convert("RGBA")
-    masks = 255 * masks.cpu().numpy().astype(np.uint8)
-
-    n_masks = masks.shape[0]
-    if n_masks == 0:
-        return image.convert("RGB")
-
-    cmap = matplotlib.colormaps.get_cmap("rainbow").resampled(n_masks)
-    colors = [
-        tuple(int(c * 255) for c in cmap(i)[:3])
-        for i in range(n_masks)
-    ]
-
-    for mask, color in zip(masks, colors):
-        mask_img = Image.fromarray(mask)
-        overlay = Image.new("RGBA", image.size, color + (0,))
-        alpha = mask_img.point(lambda v: int(v * 0.5))
-        overlay.putalpha(alpha)
-        image = Image.alpha_composite(image, overlay)
-    return image
-
 @spaces.GPU()
 def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
     """
     Perform promptable concept segmentation using SAM3.
+    Returns format compatible with gr.AnnotatedImage: (image, [(mask, label), ...])
     """
     if image is None:
         return None, "❌ Please upload an image."
 
+    if not text.strip():
+        return (image, []), "❌ Please enter a text prompt."
+
     try:
-        # Ensure inputs match model dtype
         inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)
 
-        # Convert inputs to match model dtype
         for key in inputs:
             if inputs[key].dtype == torch.float32:
                 inputs[key] = inputs[key].to(model.dtype)
@@ -67,41 +44,50 @@ def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
 
         n_masks = len(results['masks'])
         if n_masks == 0:
-            return image, f"❌ No objects found matching '{text}' (try adjusting thresholds or changing prompt)."
+            return (image, []), f"❌ No objects found matching '{text}' (try adjusting thresholds)."
 
-        overlaid_image = overlay_masks(image, results["masks"])
+        # Format for AnnotatedImage: list of (mask, label) tuples
+        # mask should be numpy array with values 0-1 (float) matching image dimensions
+        annotations = []
+        for i, (mask, score) in enumerate(zip(results['masks'], results['scores'])):
+            # Convert binary mask to float numpy array (0-1 range)
+            mask_np = mask.cpu().numpy().astype(np.float32)
+            label = f"{text} #{i+1} ({score:.2f})"
+            annotations.append((mask_np, label))
 
-        scores_text = ", ".join([f"{s:.2f}" for s in results['scores'].cpu().numpy()[:5]])  # Top 5 scores
+        scores_text = ", ".join([f"{s:.2f}" for s in results['scores'].cpu().numpy()[:5]])
         info = f"✅ Found **{n_masks}** objects matching **'{text}'**\nConfidence scores: {scores_text}{'...' if n_masks > 5 else ''}"
 
-        return overlaid_image, info
+        # Return tuple: (base_image, list_of_annotations)
+        return (image, annotations), info
 
     except Exception as e:
-        return image, f"❌ Error during segmentation: {str(e)}"
+        return (image, []), f"❌ Error during segmentation: {str(e)}"
 
 def clear_all():
     """Clear all inputs and outputs"""
-    return None, "", None, 0.5, 0.5
+    return None, "", None, 0.5, 0.5, "📝 Enter a prompt and click **Segment** to start."
 
 def segment_example(image_path: str, prompt: str):
     """Handle example clicks"""
-    image = Image.open(image_path) if image_path else None
+    if image_path.startswith("http"):
+        image = Image.open(requests.get(image_path, stream=True).raw).convert("RGB")
+    else:
+        image = Image.open(image_path).convert("RGB")
     return segment(image, prompt, 0.5, 0.5)
 
 # Gradio Interface
 with gr.Blocks(
     theme=gr.themes.Soft(),
     title="SAM3 - Promptable Concept Segmentation",
-    css="""
-    .gradio-container {max-width: 1400px !important;}
-    """
+    css=".gradio-container {max-width: 1400px !important;}"
 ) as demo:
     gr.Markdown(
         """
         # SAM3 - Promptable Concept Segmentation (PCS)
 
-        **SAM3** performs zero-shot instance segmentation using natural language prompts on images.
-        Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks for all matching objects.
+        **SAM3** performs zero-shot instance segmentation using natural language prompts.
+        Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks.
 
         Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
         """
@@ -114,15 +100,17 @@ with gr.Blocks(
             type="pil",
            height=400,
        )
+        # AnnotatedImage expects: (base_image, [(mask, label), ...])
         image_output = gr.AnnotatedImage(
             label="Output (Segmented Image)",
             height=400,
+            show_legend=True,
         )
 
     with gr.Row():
         text_input = gr.Textbox(
             label="Text Prompt",
-            placeholder="e.g., a person, ear, cat, bicycle...",
+            placeholder="e.g., person, ear, cat, bicycle...",
             scale=3
         )
         clear_btn = gr.Button("🔄 Clear", size="sm", variant="secondary")
@@ -134,7 +122,7 @@ with gr.Blocks(
             value=0.5,
             step=0.01,
             label="Detection Threshold",
-            info="Higher values = fewer detections (objectness confidence)"
+            info="Higher = fewer detections"
         )
         mask_thresh_slider = gr.Slider(
             minimum=0.0,
@@ -142,7 +130,7 @@ with gr.Blocks(
             value=0.5,
             step=0.01,
             label="Mask Threshold",
-            info="Higher values = sharper masks"
+            info="Higher = sharper masks"
         )
 
     info_output = gr.Markdown(
@@ -152,7 +140,6 @@ with gr.Blocks(
 
     segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")
 
-    # Add some example prompts
     gr.Examples(
         examples=[
             ["http://images.cocodataset.org/val2017/000000077595.jpg", "cat"],
@@ -163,29 +150,23 @@ with gr.Blocks(
         cache_examples=True,
     )
 
-    # Clear button handler
     clear_btn.click(
         fn=clear_all,
-        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider]
+        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output]
     )
 
-    # Segment button handler
     segment_btn.click(
         fn=segment,
         inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
         outputs=[image_output, info_output]
-    ).then(
-        fn=lambda: None,
     )
 
-
     gr.Markdown(
         """
        ### Notes
        - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
-        - Supports natural language prompts like "a red car" or simple nouns.
-        - GPU recommended for faster inference.
-        - Thresholds control detection sensitivity and mask quality.
+        - Click on segments in the output to see labels
+        - GPU recommended for faster inference
         """
     )
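
One detail worth noting in the diff: the retained loop in segment casts only torch.float32 entries to model.dtype. The guard matters because the processor emits a mix of tensor dtypes, and integer token ids must not be cast to float16. Below is a minimal self-contained sketch of the same pattern; the tensor names and shapes are illustrative assumptions, not values taken from this app:

import torch

# Stand-in for processor output: a float image tensor plus integer token ids.
# Names and shapes here are hypothetical; only the casting pattern mirrors app.py.
inputs = {
    "pixel_values": torch.randn(1, 3, 8, 8),        # float32
    "input_ids": torch.tensor([[101, 2003, 102]]),  # int64, must stay integer
}
model_dtype = torch.float16  # what model.dtype would be on CUDA

# Cast only floating-point tensors, exactly as the app's loop does.
for key in inputs:
    if inputs[key].dtype == torch.float32:
        inputs[key] = inputs[key].to(model_dtype)

assert inputs["pixel_values"].dtype == torch.float16
assert inputs["input_ids"].dtype == torch.int64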
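For reference, a minimal runnable sketch of the gr.AnnotatedImage value format the new segment returns. The base image, masks, labels, and scores below are synthetic placeholders rather than SAM3 outputs; mask values follow the 0-1 float convention noted in the diff:

import numpy as np
import gradio as gr

# Synthetic 256x256 RGB base image plus two float masks in [0, 1],
# each matching the image's height and width (placeholder data only).
image = np.zeros((256, 256, 3), dtype=np.uint8)
left = np.zeros((256, 256), dtype=np.float32)
left[:, :128] = 1.0
top = np.zeros((256, 256), dtype=np.float32)
top[:128, :] = 1.0

# AnnotatedImage value: (base_image, [(mask, label), ...])
value = (image, [(left, "object #1 (0.97)"), (top, "object #2 (0.85)")])

with gr.Blocks() as demo:
    out = gr.AnnotatedImage(show_legend=True)
    demo.load(fn=lambda: value, outputs=out)

demo.launch()

Handing the component this tuple is what lets the commit delete the hand-rolled matplotlib overlay: coloring, hover highlighting, and the legend all come from gr.AnnotatedImage itself.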