# sam3 / app.py
# Hugging Face Space by akhaliq (HF Staff)
# Commit ed119eb (verified): "Update app.py" (9.38 kB)
import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib
from transformers import Sam3Processor, Sam3Model
import warnings
warnings.filterwarnings("ignore")

# Global model and processor: loaded once when the module is imported and
# reused by every call to `segment`.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only when running on the GPU; CPU inference stays float32.
_model_dtype = torch.float16 if device == "cuda" else torch.float32
model = Sam3Model.from_pretrained("facebook/sam3", torch_dtype=_model_dtype).to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")
def overlay_masks(image: Image.Image, masks: torch.Tensor, alpha: float = 0.5) -> Image.Image:
    """
    Overlay segmentation masks on the input image using rainbow colormap.

    Each mask gets its own colour sampled from the ``rainbow`` colormap and is
    alpha-composited over the image.

    Args:
        image: Input PIL image (any mode; converted to RGBA for compositing).
        masks: Segmentation masks as a torch tensor or array-like, presumably
            shaped (N, H, W) with H, W matching ``image`` (TODO confirm with
            the post-processor). A single 2-D mask is treated as N=1.
        alpha: Overlay opacity in [0, 1] (default 0.5); values outside the
            range are clamped.

    Returns:
        An RGB PIL image. It is returned in RGB both when masks are drawn
        and when there are none.
    """
    image = image.convert("RGBA")
    if isinstance(masks, torch.Tensor):
        masks = masks.detach().cpu().numpy()
    masks = np.asarray(masks)
    if masks.ndim == 2:
        masks = masks[np.newaxis]
    masks = 255 * masks.astype(np.uint8)
    n_masks = masks.shape[0]
    if n_masks == 0:
        return image.convert("RGB")

    alpha = min(max(float(alpha), 0.0), 1.0)
    cmap = matplotlib.colormaps.get_cmap("rainbow").resampled(n_masks)
    colors = [tuple(int(c * 255) for c in cmap(i)[:3]) for i in range(n_masks)]

    for mask, color in zip(masks, colors):
        mask_img = Image.fromarray(mask)
        overlay = Image.new("RGBA", image.size, color + (0,))
        # Binary mask pixels are 0 or 255, so scaling by alpha alone yields an
        # opacity in 0..255 (alpha=0.5 -> 127, matching the previous behaviour).
        overlay.putalpha(mask_img.point(lambda v: int(v * alpha)))
        image = Image.alpha_composite(image, overlay)
    return image.convert("RGB")
@spaces.GPU
def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """
    Perform promptable concept segmentation using SAM3.

    Runs the global ``processor``/``model`` on ``image`` with the text prompt,
    post-processes instance masks with the two thresholds, and overlays them.
    The ``@spaces.GPU`` decorator requests a GPU for the duration of the call
    on ZeroGPU Spaces.

    Args:
        image: Input image from the UI, or None if nothing was uploaded.
        text: Natural-language prompt; surrounding whitespace is ignored.
        threshold: Detection threshold forwarded to
            ``post_process_instance_segmentation``.
        mask_threshold: Mask threshold forwarded to the same call.

    Returns:
        ``(output_image, markdown_message)``. The output image is:
        - None if no image was given;
        - the input image for an empty prompt, no detections, or an error;
        - otherwise the input image with masks overlaid.
    """
    if image is None:
        return None, "❌ Please upload an image."
    # Treat a missing (None) prompt the same as an empty one.
    prompt = (text or "").strip()
    if not prompt:
        return image, "❌ Please enter a text prompt."
    try:
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]
        n_masks = len(results["masks"])
        if n_masks == 0:
            return image, f"❌ No objects found matching '{prompt}' (try adjusting thresholds or changing prompt)."
        overlaid_image = overlay_masks(image, results["masks"])
        # Sort so the preview really lists the five highest-confidence detections.
        top_scores = sorted(results["scores"].cpu().tolist(), reverse=True)[:5]
        scores_text = ", ".join(f"{s:.2f}" for s in top_scores)
        more = "..." if n_masks > 5 else ""
        info = f"✅ Found **{n_masks}** objects matching **'{prompt}'**\nConfidence scores: {scores_text}{more}"
        return overlaid_image, info
    except Exception as e:
        # UI boundary: report the failure in the results panel instead of crashing the app.
        return image, f"❌ Error during segmentation: {str(e)}"
def clear_all():
    """Reset the UI to its starting state.

    Returns a 5-tuple, presumably in the order of the Clear button's outputs:
    no input image, an empty prompt, no output image, and both threshold
    sliders back to 0.5.
    """
    default_threshold = 0.5
    return (None, "", None, default_threshold, default_threshold)
def segment_example(image_path: "str | Image.Image | None", prompt: str):
    """
    Handle example clicks (and example caching).

    Args:
        image_path: The example image, given either as a file path or as an
            already-loaded PIL image. NOTE(review): with a ``type="pil"`` image
            input, Gradio may hand over a PIL image rather than a path; both
            are accepted. Falsy values mean "no image".
        prompt: Text prompt for the example.

    Returns:
        The result of ``segment`` with both thresholds set to 0.5.
    """
    if isinstance(image_path, Image.Image):
        image = image_path
    elif image_path:
        # Load eagerly and close the file handle right away; convert to RGB the
        # same way utils.load_image_from_url normalises downloaded images.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
    else:
        image = None
    return segment(image, prompt, 0.5, 0.5)
# Gradio Interface
# Results-panel text shown before the first run; the Clear button restores it.
_INFO_PLACEHOLDER = "📝 Enter a prompt and click **Segment** to start."

with gr.Blocks(
    theme=gr.themes.Soft(),
    title="SAM3 - Promptable Concept Segmentation",
    css="""
    .gradio-container {max-width: 1400px !important;}
    """,
) as demo:
    gr.Markdown(
        """
        # SAM3 - Promptable Concept Segmentation (PCS)
        **SAM3** performs zero-shot instance segmentation using natural language prompts on images.
        Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks for all matching objects.
        Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
        """
    )
    gr.Markdown("### Inputs")
    with gr.Row(variant="panel"):
        image_input = gr.Image(
            label="Input Image",
            type="pil",
            height=400,
        )
        image_output = gr.Image(
            label="Output (Segmented Image)",
            height=400,
            interactive=False,
        )
    with gr.Row():
        text_input = gr.Textbox(
            label="Text Prompt",
            placeholder="e.g., a person, ear, cat, bicycle...",
            scale=3,
        )
        clear_btn = gr.Button("🔁 Clear", size="sm", variant="secondary")
    with gr.Row():
        thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Detection Threshold",
            info="Higher values = fewer detections (objectness confidence)",
        )
        mask_thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Mask Threshold",
            info="Higher values = sharper masks",
        )
    info_output = gr.Markdown(
        value=_INFO_PLACEHOLDER,
        label="Info / Results",
    )
    segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")

    # Clear button: reset inputs/outputs, then restore the placeholder so the
    # results panel does not keep describing the previous run.
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider],
    ).then(
        fn=lambda: _INFO_PLACEHOLDER,
        outputs=info_output,
    )

    # Segment button and pressing Enter in the prompt box run the same handler.
    segment_inputs = [image_input, text_input, thresh_slider, mask_thresh_slider]
    segment_outputs = [image_output, info_output]
    segment_btn.click(fn=segment, inputs=segment_inputs, outputs=segment_outputs)
    text_input.submit(fn=segment, inputs=segment_inputs, outputs=segment_outputs)

    # Examples
    gr.Markdown("### Examples")
    examples = [
        ["http://images.cocodataset.org/val2017/000000077595.jpg", "ear"],
        ["http://images.cocodataset.org/val2017/000000039769.jpg", "cat"],
        ["http://images.cocodataset.org/val2017/000000001247.jpg", "person"],
        ["http://images.cocodataset.org/val2017/000000521315.jpg", "bicycle"],
        ["http://images.cocodataset.org/val2017/000000029369.jpg", "dog"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        fn=segment_example,
        outputs=[image_output, info_output],
        cache_examples=True,
        examples_per_page=10,
        label="Try these COCO examples (URLs auto-load)",
    )
    gr.Markdown(
        """
        ### Notes
        - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
        - Supports natural language prompts like "a red car" or simple nouns.
        - GPU recommended for faster inference.
        - Thresholds control detection sensitivity and mask quality.
        """
    )
# Start the web server only when this file is run directly, not when imported.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
    )
# === utils.py ===
import torch
import numpy as np
from PIL import Image
import matplotlib
import requests
from io import BytesIO
def load_image_from_url(url: str, timeout: float = 10) -> Image.Image:
    """
    Load an image from a URL.
    Args:
        url: Image URL
        timeout: Seconds to wait for the server before giving up (default 10).
    Returns:
        PIL Image object, converted to RGB
    Raises:
        ValueError: if the download fails (request error or non-2xx status) or
            the payload cannot be decoded as an image. The underlying
            exception is chained as ``__cause__`` for debugging.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # convert() loads the pixel data, so the result outlives the `with`.
        with Image.open(BytesIO(response.content)) as image:
            return image.convert("RGB")
    except Exception as e:
        raise ValueError(f"Could not load image from URL: {str(e)}") from e
def validate_image(image: Image.Image) -> bool:
    """
    Check whether an image can be processed.
    Args:
        image: PIL Image object, or None
    Returns:
        True when an image is given and both its width and height are
        positive; False otherwise.
    """
    if image is None:
        return False
    width, height = image.size
    return width > 0 and height > 0
def resize_for_processing(image: Image.Image, max_size: int = 1024) -> Image.Image:
    """
    Resize image for processing while maintaining aspect ratio.

    Images whose longer side is already <= ``max_size`` are returned unchanged
    (the same object). Otherwise the longer side becomes exactly ``max_size``.
    The shorter side is scaled proportionally and truncated, but never drops
    below 1 pixel, so very elongated images still get a valid size.

    Args:
        image: Input PIL Image
        max_size: Maximum size for the longer dimension; must be positive
            whenever a resize is needed.
    Returns:
        Resized PIL Image (LANCZOS resampling), or ``image`` itself if it
        already fits.
    Raises:
        ValueError: if a resize is needed but ``max_size`` is not positive.
    """
    width, height = image.size
    if max(width, height) <= max_size:
        return image
    if max_size <= 0:
        raise ValueError(f"max_size must be positive, got {max_size}")
    if width > height:
        new_width = max_size
        # max(1, ...) keeps e.g. a 5000x2 image from collapsing to height 0,
        # which Pillow's resize rejects.
        new_height = max(1, int(height * max_size / width))
    else:
        new_height = max_size
        new_width = max(1, int(width * max_size / height))
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
def overlay_masks_advanced(image: Image.Image, masks: torch.Tensor, alpha: float = 0.5) -> Image.Image:
    """
    Advanced overlay function with customizable alpha.

    Each mask is painted in its own colour sampled from the ``tab10`` colormap
    and alpha-composited over the image.

    Args:
        image: Input PIL Image (any mode; converted to RGBA for compositing)
        masks: Segmentation masks as a torch tensor or array-like, presumably
            shaped (N, H, W) with H, W matching ``image`` (TODO confirm). A
            single 2-D mask is treated as N=1.
        alpha: Overlay opacity in [0, 1]; values outside are clamped.
    Returns:
        Overlaid PIL Image in RGB mode. The mode is RGB both when masks are
        drawn and when there are none.
    """
    image = image.convert("RGBA")
    if isinstance(masks, torch.Tensor):
        masks = masks.detach().cpu().numpy()
    masks = np.asarray(masks)
    if masks.ndim == 2:
        masks = masks[np.newaxis]
    masks = 255 * masks.astype(np.uint8)
    n_masks = masks.shape[0]
    if n_masks == 0:
        return image.convert("RGB")

    alpha = min(max(float(alpha), 0.0), 1.0)
    # Use a good colormap
    cmap = matplotlib.colormaps.get_cmap("tab10").resampled(n_masks)
    colors = [tuple(int(c * 255) for c in cmap(i)[:3]) for i in range(n_masks)]

    for mask, color in zip(masks, colors):
        mask_img = Image.fromarray(mask)
        overlay = Image.new("RGBA", image.size, color + (0,))
        # Binary mask pixels are already 0 or 255, so multiplying by alpha alone
        # gives an opacity in 0..255. An extra "* 255" would exceed 255. Pillow
        # clips point() LUT values, so every mask would become fully opaque.
        alpha_map = mask_img.point(lambda v: int(v * alpha))
        overlay.putalpha(alpha_map)
        image = Image.alpha_composite(image, overlay)
    return image.convert("RGB")