import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib
from transformers import Sam3Processor, Sam3Model
import warnings

warnings.filterwarnings("ignore")

# Global model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Sam3Model.from_pretrained(
    "facebook/sam3",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")


def overlay_masks(image: Image.Image, masks: torch.Tensor) -> Image.Image:
    """Overlay segmentation masks on the input image using a rainbow colormap."""
    image = image.convert("RGBA")
    masks = 255 * masks.cpu().numpy().astype(np.uint8)
    n_masks = masks.shape[0]
    if n_masks == 0:
        return image.convert("RGB")

    cmap = matplotlib.colormaps.get_cmap("rainbow").resampled(n_masks)
    colors = [tuple(int(c * 255) for c in cmap(i)[:3]) for i in range(n_masks)]

    for mask, color in zip(masks, colors):
        mask_img = Image.fromarray(mask)
        overlay = Image.new("RGBA", image.size, color + (0,))
        # Mask pixels are 0 or 255; halve them to get a 50% alpha channel
        alpha = mask_img.point(lambda v: int(v * 0.5))
        overlay.putalpha(alpha)
        image = Image.alpha_composite(image, overlay)

    return image


@spaces.GPU()
def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """Perform promptable concept segmentation using SAM3."""
    if image is None:
        return None, "❌ Please upload an image."

    try:
        inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]

        n_masks = len(results["masks"])
        if n_masks == 0:
            return image, f"❌ No objects found matching '{text}' (try adjusting the thresholds or changing the prompt)."

        overlaid_image = overlay_masks(image, results["masks"])

        scores_text = ", ".join(f"{s:.2f}" for s in results["scores"].cpu().numpy()[:5])  # Top 5 scores
        info = (
            f"✅ Found **{n_masks}** objects matching **'{text}'**\n"
            f"Confidence scores: {scores_text}{'...' if n_masks > 5 else ''}"
        )
        return overlaid_image, info

    except Exception as e:
        return image, f"❌ Error during segmentation: {str(e)}"


def clear_all():
    """Clear all inputs and outputs."""
    return None, "", None, 0.5, 0.5


def segment_example(image_path: str, prompt: str):
    """Handle example clicks."""
    image = Image.open(image_path) if image_path else None
    return segment(image, prompt, 0.5, 0.5)


# Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="SAM3 - Promptable Concept Segmentation",
    css=".gradio-container {max-width: 1400px !important;}",
) as demo:
    gr.Markdown(
        """
# SAM3 - Promptable Concept Segmentation (PCS)

**SAM3** performs zero-shot instance segmentation on images using natural language prompts.
Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks for all matching objects.

Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
        """
    )

    gr.Markdown("### Inputs")

    with gr.Row(variant="panel"):
        image_input = gr.Image(
            label="Input Image",
            type="pil",
            height=400,
        )
        image_output = gr.Image(
            label="Output (Segmented Image)",
            height=400,
            interactive=False,
        )

    with gr.Row():
        text_input = gr.Textbox(
            label="Text Prompt",
            placeholder="e.g., a person, ear, cat, bicycle...",
            scale=3,
        )
        clear_btn = gr.Button("🔍 Clear", size="sm", variant="secondary")

    with gr.Row():
        thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Detection Threshold",
            info="Higher values = fewer detections (objectness confidence)",
        )
        mask_thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Mask Threshold",
            info="Higher values = sharper masks",
        )

    info_output = gr.Markdown(
        value="📝 Enter a prompt and click **Segment** to start.",
        label="Info / Results",
    )

    segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")

    # Clear button handler
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider],
    )

    # Segment button handler
    segment_btn.click(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output],
    )

    # Examples
    gr.Markdown("### Examples")

    examples = [
        ["http://images.cocodataset.org/val2017/000000077595.jpg", "ear"],
        ["http://images.cocodataset.org/val2017/000000039769.jpg", "cat"],
        ["http://images.cocodataset.org/val2017/000000001247.jpg", "person"],
        ["http://images.cocodataset.org/val2017/000000521315.jpg", "bicycle"],
        ["http://images.cocodataset.org/val2017/000000029369.jpg", "dog"],
    ]

    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        fn=segment_example,
        outputs=[image_output, info_output],
        cache_examples=True,
        examples_per_page=10,
        label="Try these COCO examples (URLs auto-load)",
    )

    gr.Markdown(
        """
### Notes
- **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
- Supports natural language prompts like "a red car" or simple nouns.
- A GPU is recommended for faster inference.
- The thresholds control detection sensitivity and mask quality.
        """
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
```

=== utils.py ===
```python
import torch
import numpy as np
from PIL import Image
import matplotlib
import requests
from io import BytesIO


def load_image_from_url(url: str) -> Image.Image:
    """
    Load an image from a URL.

    Args:
        url: Image URL

    Returns:
        PIL Image object
    """
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return image.convert("RGB")
    except Exception as e:
        raise ValueError(f"Could not load image from URL: {str(e)}")


def validate_image(image: Image.Image) -> bool:
    """
    Validate whether the image is suitable for processing.

    Args:
        image: PIL Image object

    Returns:
        True if valid, False otherwise
    """
    if image is None:
        return False
    if image.size[0] <= 0 or image.size[1] <= 0:
        return False
    return True


def resize_for_processing(image: Image.Image, max_size: int = 1024) -> Image.Image:
    """
    Resize an image for processing while maintaining its aspect ratio.

    Args:
        image: Input PIL Image
        max_size: Maximum size for the longer dimension

    Returns:
        Resized PIL Image
    """
    width, height = image.size
    if max(width, height) <= max_size:
        return image

    if width > height:
        new_width = max_size
        new_height = int(height * max_size / width)
    else:
        new_height = max_size
        new_width = int(width * max_size / height)

    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)


def overlay_masks_advanced(image: Image.Image, masks: torch.Tensor, alpha: float = 0.5) -> Image.Image:
    """
    Advanced overlay function with customizable alpha.

    Args:
        image: Input PIL Image
        masks: Segmentation masks tensor
        alpha: Overlay transparency (0-1)

    Returns:
        Overlaid PIL Image
    """
    image = image.convert("RGBA")
    masks = 255 * masks.cpu().numpy().astype(np.uint8)
    n_masks = masks.shape[0]
    if n_masks == 0:
        return image.convert("RGB")

    # Use a qualitative colormap so adjacent instances get distinct colors
    cmap = matplotlib.colormaps.get_cmap("tab10").resampled(n_masks)
    colors = [tuple(int(c * 255) for c in cmap(i)[:3]) for i in range(n_masks)]

    for mask, color in zip(masks, colors):
        mask_img = Image.fromarray(mask)
        overlay = Image.new("RGBA", image.size, color + (0,))
        # Mask pixels are already 0 or 255, so scale them by alpha directly
        alpha_map = mask_img.point(lambda v: int(v * alpha))
        overlay.putalpha(alpha_map)
        image = Image.alpha_composite(image, overlay)

    return image
```
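For reference, here is a minimal standalone sketch (no Gradio UI) of how the utils.py helpers could be combined with the same SAM3 calls used in app.py. The checkpoint name, the processor/model classes, the post-processing call, and the "original_sizes" key are taken from app.py above; the example URL, prompt, thresholds, and output filename are illustrative only.

```python
# Standalone usage sketch (assumes app.py and utils.py from this repo are on the path).
import torch
from transformers import Sam3Processor, Sam3Model

from utils import load_image_from_url, resize_for_processing, overlay_masks_advanced

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Sam3Model.from_pretrained("facebook/sam3").to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")

# Load one of the COCO example images and cap its longer side at 1024 px
image = load_image_from_url("http://images.cocodataset.org/val2017/000000039769.jpg")
image = resize_for_processing(image, max_size=1024)

# Same prompt-conditioned forward pass as in app.py's segment()
inputs = processor(images=image, text="cat", return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

results = processor.post_process_instance_segmentation(
    outputs,
    threshold=0.5,
    mask_threshold=0.5,
    target_sizes=inputs.get("original_sizes").tolist(),
)[0]

# Blend the predicted instance masks over the image and save the result
overlay = overlay_masks_advanced(image, results["masks"], alpha=0.4)
overlay.convert("RGB").save("cat_segmentation.png")
```

Resizing before inference mirrors the `max_size=1024` default in `resize_for_processing` and keeps memory use predictable on smaller GPUs.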