Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import DetrImageProcessor, DetrForObjectDetection | |
| from PIL import Image, ImageDraw, ImageFont | |
| import requests # To handle image URLs if needed, but we focus on uploads | |
# Labels this app reports, and the outline colour used for each of them.
target_labels = ["cat", "dog"]
colors = {"cat": "red", "dog": "blue"}

# Load the DETR preprocessor and detection model from the Hugging Face Hub.
# No `revision` argument is passed, so the checkpoint's default revision is used;
# the original note recommends listing `timm` in requirements.txt for it.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101-dc5")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101-dc5")

# Class ids are resolved by label name from the model's own config instead of
# being hard-coded, so they always match the mapping this checkpoint uses.
id2label = model.config.id2label
target_ids = [
    class_id for class_id in id2label if id2label[class_id] in target_labels
]
def _measure_text(draw, text, font):
    """Return the (width, height) that *text* spans from its drawing origin.

    Prefers ``ImageDraw.textbbox``; falls back to the legacy ``font.getsize``
    (removed in Pillow 10) and finally to a fixed estimate, so a label
    background can be drawn on any Pillow version.
    """
    if hasattr(draw, "textbbox"):
        # Measured from (0, 0), so right/bottom are the extents from the text
        # origin, including any left/top bearing of the font.
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom
    if hasattr(font, "getsize"):
        return font.getsize(text)
    return 50, 10  # Rough estimate when no measurement API is available


def detect_objects(image_input):
    """
    Detect cats and dogs in the input image using DETR.

    Args:
        image_input (PIL.Image.Image | array-like | None): Input image. Anything
            that is not a PIL image is passed through ``Image.fromarray``.

    Returns:
        PIL.Image.Image | None: An RGB copy of the input with a coloured box and
        a "label: score" caption drawn for every detected cat/dog, or ``None``
        when no image was supplied. The caller's image is never modified.
    """
    if image_input is None:
        return None

    # Convert Gradio input (if numpy) to a PIL image; type="pil" should already do this.
    if not isinstance(image_input, Image.Image):
        image_input = Image.fromarray(image_input)
    # Normalise to RGB so the named box colours render as intended and grayscale /
    # RGBA inputs are handled uniformly. convert() always returns a new image,
    # so this is also the working copy that gets drawn on.
    image = image_input.convert("RGB")

    # Preprocess and run inference. Gradients are never needed here, so skip
    # autograd bookkeeping to save memory and time.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Rescale the predicted boxes to the image size; PIL gives (width, height),
    # the post-processor is given (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    # A lower threshold (e.g. 0.5) might find more objects.
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.7
    )[0]

    draw = ImageDraw.Draw(image)
    try:
        # Use the default font, or specify a path to a .ttf file available in the Space.
        font = ImageFont.load_default()
    except OSError:
        print("Default font not found. Using basic drawing without text.")
        font = None

    detections_found = False
    for score, label_id, box in zip(results["scores"], results["labels"], results["boxes"]):
        label_id = label_id.item()
        if label_id not in target_ids:
            continue

        detections_found = True
        confidence = score.item()
        box = [round(coord, 2) for coord in box.tolist()]
        label = id2label[label_id]
        box_color = colors.get(label, "green")  # Default to green if label has no colour
        print(f"Detected {label} with confidence {round(confidence, 3)} at {box}")

        draw.rectangle(box, outline=box_color, width=3)

        if font is not None:
            # Caption on a filled background anchored at the box's top-left corner,
            # with 2 px of padding on every side.
            text = f"{label}: {confidence:.2f}"
            text_width, text_height = _measure_text(draw, text, font)
            text_bg_coords = [
                (box[0], box[1]),
                (box[0] + text_width + 4, box[1] + text_height + 4),
            ]
            draw.rectangle(text_bg_coords, fill=box_color)
            draw.text((box[0] + 2, box[1] + 2), text, fill="white", font=font)

    if not detections_found:
        print("No cats or dogs detected with the current threshold.")

    return image
# Text shown at the top of the Gradio page.
title = "Cat & Dog Detector (using DETR ResNet-101)"
description = (
    "Upload an image and the model will draw bounding boxes "
    "around detected cats and dogs. Uses the facebook/detr-resnet-101-dc5 model from Hugging Face."
)

# Input and output widgets; both exchange PIL images with detect_objects.
upload_component = gr.Image(type="pil", label="Upload Image")
result_component = gr.Image(type="pil", label="Output Image with Detections")

# Clickable sample inputs, given as URLs. Paths to images uploaded to the
# Space could be listed here instead.
example_images = [
    # Example image URL with cats
    ["http://images.cocodataset.org/val2017/000000039769.jpg"],
    # Example image with dog and cat
    ["https://storage.googleapis.com/petbacker/images/blog/2017/dog-and-cat-cover.jpg"],
]

# Assemble the single-function interface around detect_objects.
iface = gr.Interface(
    fn=detect_objects,
    inputs=upload_component,
    outputs=result_component,
    title=title,
    description=description,
    examples=example_images,
    allow_flagging="never",  # Flagging is turned off; change if feedback collection is wanted
)

# Start the web app.
iface.launch()