akhaliq committed
Commit ff645cc · verified · Parent(s): b09e323

Upload folder using huggingface_hub

Files changed (3)
  1. app.py +198 -0
  2. requirements.txt +17 -0
  3. utils.py +2 -0
app.py ADDED
@@ -0,0 +1,198 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image
+ import matplotlib
+ from transformers import Sam3Processor, Sam3Model
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ # Global model and processor (fp16 on GPU, fp32 on CPU)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = Sam3Model.from_pretrained(
+     "facebook/sam3",
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+ ).to(device)
+ processor = Sam3Processor.from_pretrained("facebook/sam3")
+
+ def overlay_masks(image: Image.Image, masks: torch.Tensor) -> Image.Image:
+     """
+     Overlay segmentation masks on the input image using a rainbow colormap.
+     """
+     image = image.convert("RGBA")
+     # Binary masks -> uint8 arrays in {0, 255}, one per detected instance
+     masks = 255 * masks.cpu().numpy().astype(np.uint8)
+
+     n_masks = masks.shape[0]
+     if n_masks == 0:
+         return image.convert("RGB")
+
+     cmap = matplotlib.colormaps.get_cmap("rainbow").resampled(n_masks)
+     colors = [
+         tuple(int(c * 255) for c in cmap(i)[:3])
+         for i in range(n_masks)
+     ]
+
+     for mask, color in zip(masks, colors):
+         mask_img = Image.fromarray(mask)
+         # Semi-transparent colored overlay (50% alpha inside the mask)
+         overlay = Image.new("RGBA", image.size, color + (0,))
+         alpha = mask_img.point(lambda v: int(v * 0.5))
+         overlay.putalpha(alpha)
+         image = Image.alpha_composite(image, overlay)
+     return image.convert("RGB")
+
+ def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
+     """
+     Perform promptable concept segmentation using SAM3.
+     """
+     if image is None:
+         return None, "❌ Please upload an image."
+     if not text or not text.strip():
+         return image, "❌ Please enter a text prompt."
+
+     try:
+         inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         results = processor.post_process_instance_segmentation(
+             outputs,
+             threshold=threshold,
+             mask_threshold=mask_threshold,
+             target_sizes=inputs.get("original_sizes").tolist()
+         )[0]
+
+         n_masks = len(results["masks"])
+         if n_masks == 0:
+             return image, f"❌ No objects found matching '{text}' (try adjusting thresholds or changing the prompt)."
+
+         overlaid_image = overlay_masks(image, results["masks"])
+
+         scores_text = ", ".join(f"{s:.2f}" for s in results["scores"].cpu().numpy()[:5])  # Top 5 scores
+         info = f"✅ Found **{n_masks}** objects matching **'{text}'**\nConfidence scores: {scores_text}{'...' if n_masks > 5 else ''}"
+
+         return overlaid_image, info
+
+     except Exception as e:
+         return image, f"❌ Error during segmentation: {e}"
+
+ # Gradio Interface
+ with gr.Blocks(
+     theme=gr.themes.Soft(),
+     title="SAM3 - Promptable Concept Segmentation",
+     css="""
+     .gradio-container {max-width: 1400px !important;}
+     """
+ ) as demo:
+     gr.Markdown(
+         """
+         # SAM3 - Promptable Concept Segmentation (PCS)
+
+         **SAM3** performs zero-shot instance segmentation on images using natural language prompts.
+         Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks for all matching objects.
+
+         Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
+         """
+     )
+
+     gr.Markdown("### Inputs")
+     with gr.Row(variant="panel"):
+         image_input = gr.Image(
+             label="Input Image",
+             type="pil",
+             height=400,
+             sources=["upload", "clipboard"]
+         )
+         image_output = gr.Image(
+             label="Output (Segmented Image)",
+             height=400,
+             interactive=False
+         )
+
+     with gr.Row():
+         text_input = gr.Textbox(
+             label="Text Prompt",
+             placeholder="e.g., a person, ear, cat, bicycle...",
+             scale=3
+         )
+         clear_btn = gr.Button("🔄 Clear", size="sm", variant="secondary")
+
+     with gr.Row():
+         thresh_slider = gr.Slider(
+             minimum=0.0,
+             maximum=1.0,
+             value=0.5,
+             step=0.01,
+             label="Detection Threshold",
+             info="Higher values = fewer detections (objectness confidence)"
+         )
+         mask_thresh_slider = gr.Slider(
+             minimum=0.0,
+             maximum=1.0,
+             value=0.5,
+             step=0.01,
+             label="Mask Threshold",
+             info="Higher values = sharper masks"
+         )
+
+     info_output = gr.Markdown(
+         value="📝 Enter a prompt and click **Segment** to start."
+     )
+
+     segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")
+
+     # Events (the Clear handler is wired here, after the sliders it resets exist)
+     clear_btn.click(
+         fn=lambda: (None, "", None, 0.5, 0.5),
+         outputs=[image_output, text_input, image_input, thresh_slider, mask_thresh_slider]
+     )
+     segment_btn.click(
+         fn=segment,
+         inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
+         outputs=[image_output, info_output]
+     ).then(
+         fn=lambda: gr.Info("Segmentation complete!")
+     )
+
+     # Examples (threshold defaults included so each row matches segment's signature)
+     gr.Markdown("### Examples")
+     examples = [
+         ["http://images.cocodataset.org/val2017/000000077595.jpg", "ear", 0.5, 0.5],
+         ["http://images.cocodataset.org/val2017/000000039769.jpg", "cat", 0.5, 0.5],
+         ["http://images.cocodataset.org/val2017/000000001247.jpg", "person", 0.5, 0.5],
+         ["http://images.cocodataset.org/val2017/000000521315.jpg", "bicycle", 0.5, 0.5],
+         ["http://images.cocodataset.org/val2017/000000029369.jpg", "dog", 0.5, 0.5]
+     ]
+     gr.Examples(
+         examples=examples,
+         inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
+         fn=segment,
+         outputs=[image_output, info_output],
+         cache_examples=False,
+         run_on_click=True,
+         examples_per_page=10,
+         label="Try these COCO examples (URLs auto-load)"
+     )
+
+     gr.Markdown(
+         """
+         ### Notes
+         - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
+         - Supports natural language prompts like "a red car" or simple nouns.
+         - GPU recommended for faster inference.
+         - Thresholds control detection sensitivity and mask quality.
+         """
+     )
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
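
For reference, a minimal standalone sketch of the same inference path without the Gradio UI. It assumes the Sam3Processor/Sam3Model API and post-processing call exactly as used in app.py above (SAM3 support requires the from-source transformers install pinned in requirements.txt); the COCO URL is one of the demo examples.

    import requests
    import torch
    from PIL import Image
    from transformers import Sam3Processor, Sam3Model

    # Same model/processor setup as app.py, kept in full precision for simplicity
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Sam3Model.from_pretrained("facebook/sam3").to(device)
    processor = Sam3Processor.from_pretrained("facebook/sam3")

    # One of the COCO example images from the demo
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

    # Text-prompted instance segmentation, mirroring segment() in app.py
    inputs = processor(images=image, text="cat", return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_instance_segmentation(
        outputs,
        threshold=0.5,
        mask_threshold=0.5,
        target_sizes=inputs.get("original_sizes").tolist(),
    )[0]
    print(f"Found {len(results['masks'])} masks; scores: {results['scores'].tolist()}")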
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ git+https://github.com/huggingface/transformers
+ torch
+ torchvision
+ torchaudio
+ gradio
+ matplotlib
+ numpy
+ Pillow
+ accelerate
+ tokenizers
+ datasets
+ requests
+ opencv-python
+ scipy
+ imageio
+ scikit-image
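
transformers is pinned to a Git checkout because, at the time of this commit, the Sam3* classes used by app.py are presumably not in a released version. An optional startup guard (a sketch, not part of the commit) can make that dependency explicit:

    # Optional sanity check: fail fast if the installed transformers
    # build does not ship the SAM3 classes that app.py imports.
    import transformers

    try:
        from transformers import Sam3Model, Sam3Processor  # noqa: F401
    except ImportError as err:
        raise SystemExit(
            f"transformers {transformers.__version__} lacks SAM3 support; "
            "reinstall from source as pinned in requirements.txt"
        ) from err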
utils.py ADDED
@@ -0,0 +1,2 @@
+ # No additional utility functions needed beyond what's in app.py.
+ # All helpers (overlay_masks, segment) are defined in app.py for simplicity and global access.