File size: 5,990 Bytes
75921b2
ff645cc
 
 
 
 
bc62515
ff645cc
 
 
 
 
 
 
 
e3ba680
ff645cc
 
 
bc62515
ff645cc
 
 
 
bc62515
 
 
ff645cc
 
 
30a638e
 
 
 
ff645cc
 
 
 
 
 
 
 
 
 
 
 
bc62515
ff645cc
bc62515
 
 
 
 
 
 
 
ff645cc
bc62515
ff645cc
 
bc62515
 
ff645cc
 
bc62515
ff645cc
ed119eb
 
bc62515
ed119eb
 
 
bc62515
 
 
 
ed119eb
 
ff645cc
 
 
 
bc62515
ff645cc
 
 
 
 
bc62515
 
ff645cc
 
 
 
 
 
 
 
 
 
 
 
bc62515
c9de2d4
ff645cc
 
bc62515
ff645cc
 
 
 
 
bc62515
ff645cc
 
ed119eb
ff645cc
 
 
 
 
 
 
 
bc62515
ff645cc
 
 
 
 
 
 
bc62515
ff645cc
 
 
 
 
 
 
 
 
30a638e
 
1993b7d
30a638e
 
 
 
4c0f830
30a638e
 
ed119eb
 
bc62515
ed119eb
 
ff645cc
 
 
 
 
 
 
 
 
 
bc62515
 
ff645cc
 
 
 
fd4d970
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import Sam3Processor, Sam3Model
import requests
import warnings
warnings.filterwarnings("ignore")

# Global model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Sam3Model.from_pretrained("facebook/sam3", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")

@spaces.GPU()
def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """
    Perform promptable concept segmentation using SAM3.
    Returns format compatible with gr.AnnotatedImage: (image, [(mask, label), ...])
    """
    if image is None:
        return None, "❌ Please upload an image."
    
    if not text.strip():
        return (image, []), "❌ Please enter a text prompt."
    
    try:
        inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)
        
        for key in inputs:
            if inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)
        
        with torch.no_grad():
            outputs = model(**inputs)
        
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]
        
        n_masks = len(results['masks'])
        if n_masks == 0:
            return (image, []), f"❌ No objects found matching '{text}' (try adjusting thresholds)."
        
        # Format for AnnotatedImage: list of (mask, label) tuples
        # mask should be numpy array with values 0-1 (float) matching image dimensions
        annotations = []
        for i, (mask, score) in enumerate(zip(results['masks'], results['scores'])):
            # Convert binary mask to float numpy array (0-1 range)
            mask_np = mask.cpu().numpy().astype(np.float32)
            label = f"{text} #{i+1} ({score:.2f})"
            annotations.append((mask_np, label))
        
        scores_text = ", ".join([f"{s:.2f}" for s in results['scores'].cpu().numpy()[:5]])
        info = f"βœ… Found **{n_masks}** objects matching **'{text}'**\nConfidence scores: {scores_text}{'...' if n_masks > 5 else ''}"
        
        # Return tuple: (base_image, list_of_annotations)
        return (image, annotations), info
        
    except Exception as e:
        return (image, []), f"❌ Error during segmentation: {str(e)}"

def clear_all():
    """Clear all inputs and outputs"""
    return None, "", None, 0.5, 0.5, "πŸ“ Enter a prompt and click **Segment** to start."

def segment_example(image_path: str, prompt: str):
    """Handle example clicks"""
    if image_path.startswith("http"):
        image = Image.open(requests.get(image_path, stream=True).raw).convert("RGB")
    else:
        image = Image.open(image_path).convert("RGB")
    return segment(image, prompt, 0.5, 0.5)

# Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="SAM3 - Promptable Concept Segmentation",
    css=".gradio-container {max-width: 1400px !important;}"
) as demo:
    gr.Markdown(
        """
        # SAM3 - Promptable Concept Segmentation (PCS)
        
        **SAM3** performs zero-shot instance segmentation using natural language prompts.
        Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks.
        
        Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
        """
    )
    
    gr.Markdown("### Inputs")
    with gr.Row(variant="panel"):
        image_input = gr.Image(
            label="Input Image",
            type="pil",
            height=400,
        )
        # AnnotatedImage expects: (base_image, [(mask, label), ...])
        image_output = gr.AnnotatedImage(
            label="Output (Segmented Image)",
            height=400,
            show_legend=True,
        )
    
    with gr.Row():
        text_input = gr.Textbox(
            label="Text Prompt",
            placeholder="e.g., person, ear, cat, bicycle...",
            scale=3
        )
        clear_btn = gr.Button("πŸ” Clear", size="sm", variant="secondary")
    
    with gr.Row():
        thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Detection Threshold",
            info="Higher = fewer detections"
        )
        mask_thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Mask Threshold",
            info="Higher = sharper masks"
        )
    
    info_output = gr.Markdown(
        value="πŸ“ Enter a prompt and click **Segment** to start.",
        label="Info / Results"
    )
    
    segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")
    
    gr.Examples(
        examples=[
            ["http://images.cocodataset.org/val2017/000000077595.jpg", "cat"],
        ],
        inputs=[image_input, text_input],
        outputs=[image_output, info_output],
        fn=segment_example,
        cache_examples=False,
    )
    
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output]
    )
    
    segment_btn.click(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output]
    )
    
    gr.Markdown(
        """
        ### Notes
        - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
        - Click on segments in the output to see labels
        - GPU recommended for faster inference
        """
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)