"""Inference handler that turns a text prompt into a PIL image.

When the optional ``simplified_diffsketcher`` module imports and its model
initializes, prompts are rendered with it. Otherwise (or when the model fails
at request time) a placeholder SVG showing the prompt is rasterized with
cairosvg, and if that also fails a plain PIL image with the prompt is drawn.
"""

import base64
import io
import json
import os
import subprocess
import sys
from xml.sax.saxutils import escape

import numpy as np
import torch
from PIL import Image, ImageDraw


def _pip_install(*packages):
    """Install *packages* with pip into the running interpreter's environment."""
    # ``sys.executable -m pip`` guarantees the packages are installed for this
    # interpreter; a bare ``pip`` on PATH may belong to another Python.
    subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])


# NOTE: the fallbacks below install missing packages at import time, so
# importing this module may reach the network and modify the environment.

# Safely import cairosvg with fallback
try:
    import cairosvg
except ImportError:
    print("Warning: cairosvg not found. Installing...")
    _pip_install("cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2")
    import cairosvg

# Safely import clip with fallback
try:
    import clip
except ImportError:
    print("Warning: clip not found. Installing...")
    _pip_install("git+https://github.com/openai/CLIP.git")
    import clip

# Import the simplified DiffSketcher; ``None`` selects the placeholder path.
try:
    from simplified_diffsketcher import SimplifiedDiffSketcher
except ImportError:
    print("Warning: simplified_diffsketcher not found. Using placeholder.")
    SimplifiedDiffSketcher = None


class EndpointHandler:
    """Callable handler that maps a request payload to an image."""

    def __init__(self, model_dir):
        """Initialize the handler.

        Args:
            model_dir: Directory handed to ``SimplifiedDiffSketcher``; also
                stored as ``self.model_dir``.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing model on device: {self.device}")

        # Define ``self.model`` up front so the attribute exists even when the
        # placeholder generator ends up being used.
        self.model = None
        self.use_model = False

        if SimplifiedDiffSketcher is None:
            print("Using placeholder SVG generator")
            return

        try:
            self.model = SimplifiedDiffSketcher(model_dir)
            self.use_model = True
            print("Simplified DiffSketcher initialized successfully")
        except Exception as e:
            print(f"Error initializing simplified DiffSketcher: {e}")

    def generate_placeholder_svg(self, prompt, width=512, height=512):
        """Return an SVG document showing *prompt* centred on a light-grey canvas.

        The prompt is XML-escaped so characters such as ``<`` or ``&`` in user
        input cannot break the markup or inject extra SVG elements.

        Args:
            prompt: Text to display; converted with ``str()`` first.
            width: Canvas width in pixels.
            height: Canvas height in pixels.

        Returns:
            The SVG document as a string.
        """
        text = escape(str(prompt))
        return (
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" '
            f'height="{height}" viewBox="0 0 {width} {height}">'
            '<rect width="100%" height="100%" fill="#f0f0f0"/>'
            '<text x="50%" y="50%" font-family="sans-serif" font-size="20" '
            'fill="black" text-anchor="middle" dominant-baseline="middle">'
            f"{text}</text></svg>"
        )

    @staticmethod
    def _extract_prompt(data):
        """Pull the prompt out of a request payload, or return a default text."""
        if isinstance(data, dict) and "inputs" in data:
            inputs = data["inputs"]
            if isinstance(inputs, str):
                return inputs
            if isinstance(inputs, dict) and "text" in inputs:
                return inputs["text"]
        return "No prompt provided"

    @staticmethod
    def _text_image(text, background, fill, size=(512, 512)):
        """Draw *text* centred on a solid-colour RGB image of *size*."""
        image = Image.new("RGB", size, color=background)
        draw = ImageDraw.Draw(image)
        draw.text((size[0] // 2, size[1] // 2), str(text), fill=fill, anchor="mm")
        return image

    def _render_placeholder(self, prompt):
        """Rasterize the placeholder SVG for *prompt*; draw it with PIL on failure."""
        svg_content = self.generate_placeholder_svg(prompt)
        try:
            png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
            image = Image.open(io.BytesIO(png_data))
            # Image.open only parses the header; decode now so a corrupt PNG
            # triggers the fallback here rather than failing in the caller.
            image.load()
            return image
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")
            return self._text_image(prompt, background="#f0f0f0", fill="black")

    def __call__(self, data):
        """Handle a request to the model.

        Args:
            data: Request payload. The prompt is ``data["inputs"]`` when that
                is a string, or ``data["inputs"]["text"]`` when ``inputs`` is
                a dict containing ``"text"``; otherwise the prompt is
                ``"No prompt provided"``.

        Returns:
            The ``"image"`` entry of the model's result when the model is in
            use and succeeds; otherwise the rendered placeholder image. Errors
            are not raised: an unexpected failure yields a red 512x512 image
            showing the error message.
        """
        try:
            prompt = self._extract_prompt(data)

            if self.use_model:
                try:
                    result = self.model(prompt)
                    return result["image"]
                except Exception as e:
                    print(f"Error using simplified DiffSketcher: {e}")
                    # Fall through to the placeholder renderer.

            return self._render_placeholder(prompt)
        except Exception as e:
            print(f"Error in handler: {e}")
            return self._text_image(
                f"Error: {str(e)}", background="#ff0000", fill="white"
            )