import base64
import json
import logging
import os
from io import BytesIO

import cairosvg
import numpy as np
import torch
from PIL import Image

logger = logging.getLogger(__name__)


class DiffSketcherHandler:
    """Request handler that turns a text prompt into an SVG plus a PNG preview.

    The real DiffSketcher pipeline is not wired in yet: every request is
    currently answered by ``_generate_placeholder_svg``. When ``pydiffvg`` is
    importable, a Stable Diffusion pipeline is additionally loaded on first
    use, but its output is not consumed anywhere yet.
    """

    def __init__(self):
        self.initialized = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        # Set here (not only in ``initialize``) so ``inference`` can be called
        # on a freshly constructed handler without raising AttributeError.
        self.diffvg = None
        # Remember that a load was tried so a failing load is not repeated
        # (re-downloading / re-loading the pipeline) on every request.
        self._model_load_attempted = False

    def initialize(self, context):
        """Mark the handler as ready and try to import the optional ``pydiffvg``.

        Args:
            context: Serving-framework context object; not used here.

        Returns:
            None.
        """
        self.initialized = True
        # Import dependencies here to avoid issues during startup.
        try:
            import pydiffvg
        except ImportError as e:
            logger.warning("Could not import pydiffvg: %s", e)
            logger.warning("Will use placeholder SVG generation")
            self.diffvg = None
        else:
            self.diffvg = pydiffvg
            logger.info("Successfully imported pydiffvg")
        # The diffusion model itself is loaded lazily by ``_initialize_model``.
        return None

    def _initialize_model(self):
        """Load the Stable Diffusion pipeline once, on first use.

        Any failure is logged and leaves ``self.model`` as ``None``. A failed
        load is not retried on subsequent calls.
        """
        if self.model is not None or getattr(self, "_model_load_attempted", False):
            return
        self._model_load_attempted = True

        try:
            from diffusers import StableDiffusionPipeline

            # Half precision on GPU to reduce memory; full precision on CPU,
            # where float16 kernels are generally unsupported or slow.
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            self.model = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=dtype,
            ).to(self.device)
            logger.info("Successfully initialized the model")
        except Exception:
            # Best effort: any failure (missing package, download error, OOM)
            # falls back to placeholder generation instead of failing requests.
            logger.exception(
                "Error initializing model; will use placeholder generation"
            )
            self.model = None

    @staticmethod
    def _coerce_int(value, default):
        """Return ``value`` converted to ``int``, or ``default`` if that fails."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default

    def preprocess(self, data):
        """Extract the prompt and generation parameters from a request payload.

        Args:
            data: Mapping with an optional ``"inputs"`` prompt and an optional
                ``"parameters"`` mapping. ``None`` is treated as an empty
                payload; a missing or empty prompt falls back to a default.

        Returns:
            dict with keys ``prompt``, ``num_paths``, ``token_ind`` and
            ``num_iter``. Numeric parameters that cannot be converted to
            ``int`` are replaced by their defaults.
        """
        data = data or {}

        prompt = data.get("inputs") or "a beautiful landscape"

        parameters = data.get("parameters")
        if not isinstance(parameters, dict):
            parameters = {}

        return {
            "prompt": prompt,
            "num_paths": self._coerce_int(parameters.get("num_paths", 96), 96),
            "token_ind": self._coerce_int(parameters.get("token_ind", 4), 4),
            "num_iter": self._coerce_int(parameters.get("num_iter", 800), 800),
        }

    def _generate_placeholder_svg(self, prompt):
        """Build a 512x512 placeholder SVG showing the prompt, plus a PNG preview.

        Args:
            prompt: Text drawn into the SVG.

        Returns:
            Tuple ``(svg_string, image)``, where ``image`` is a PIL image
            rasterised from the SVG by cairosvg.
        """
        import svgwrite

        dwg = svgwrite.Drawing(size=(512, 512))
        # Light-grey background covering the whole canvas.
        dwg.add(dwg.rect(insert=(0, 0), size=("100%", "100%"), fill="#f0f0f0"))
        # Centered blue circle.
        dwg.add(dwg.circle(center=(256, 256), r=100, fill="#3498db"))
        # The prompt as text (svgwrite escapes it for XML).
        dwg.add(dwg.text(prompt, insert=(50, 50), font_size=20, fill="black"))
        # Visible marker that this output is not a real model result.
        dwg.add(
            dwg.text(
                "Placeholder SVG - Model not available",
                insert=(50, 480),
                font_size=16,
                fill="red",
            )
        )
        svg_string = dwg.tostring()

        # Rasterise the SVG to PNG for the preview image.
        png_data = cairosvg.svg2png(bytestring=svg_string.encode("utf-8"))
        image = Image.open(BytesIO(png_data))
        return svg_string, image

    def inference(self, inputs):
        """Produce an SVG and a preview image for the preprocessed request.

        Args:
            inputs: Dict from ``preprocess``; only ``"prompt"`` is read.

        Returns:
            dict with ``"svg"`` (SVG markup string) and ``"image"`` (PIL image).
        """
        prompt = inputs["prompt"]

        # Lazily load the model, only when pydiffvg is available.
        # ``_initialize_model`` handles its own errors and won't retry a failure.
        if self.model is None and self.diffvg is not None:
            self._initialize_model()

        if self.model is not None and self.diffvg is not None:
            try:
                # TODO(review): the actual DiffSketcher implementation belongs
                # here. Until then the loaded pipeline is unused and only
                # occupies memory; the placeholder is returned instead.
                svg_string, image = self._generate_placeholder_svg(prompt)
            except Exception:
                logger.exception("Error during model inference")
                svg_string, image = self._generate_placeholder_svg(prompt)
        else:
            # Use placeholder if model is not available.
            svg_string, image = self._generate_placeholder_svg(prompt)

        return {"svg": svg_string, "image": image}

    def postprocess(self, inference_output):
        """Convert the inference output into a JSON-serialisable response.

        Args:
            inference_output: Dict with ``"svg"`` (str) and ``"image"`` (PIL image).

        Returns:
            dict with the SVG string and the image as a PNG ``data:`` URI.
        """
        svg_string = inference_output["svg"]
        image = inference_output["image"]

        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        img_base64 = f"data:image/png;base64,{img_str}"

        return {"svg": svg_string, "image": img_base64}

    def handle(self, data, context):
        """Entry point: initialise on first call, then run the full pipeline.

        Args:
            data: Request payload passed to ``preprocess``.
            context: Serving-framework context forwarded to ``initialize``.

        Returns:
            The dict produced by ``postprocess``.
        """
        if not self.initialized:
            self.initialize(context)

        preprocessed_data = self.preprocess(data)
        inference_output = self.inference(preprocessed_data)
        return self.postprocess(inference_output)