"""Inference handler that turns a text prompt into a sketch-style SVG.

The handler can optionally load the real DiffSketcher pipeline via
``load_model``; ``__call__`` itself uses a lightweight procedural generator
that picks a colour palette and a set of primitive shapes from keywords found
in the prompt.
"""

import base64
import io
import json
import math
import os
import sys
import tempfile
import xml.etree.ElementTree as ET
from typing import Any, Dict, List
from xml.sax.saxutils import escape

import numpy as np
import torch

try:
    from PIL import Image, ImageDraw
except ImportError:
    # Nothing in this module references Pillow directly; keep the names
    # defined for importers but do not fail at import time when it is absent.
    Image = ImageDraw = None

_SVG_NS = "http://www.w3.org/2000/svg"
_DEFAULT_SIZE = 224


class DiffSketcherHandler:
    """Callable handler producing SVG text for an Inference-API style request."""

    # (keywords, palette) pairs, checked in order; the first match wins.
    _PALETTES = (
        (("nature", "tree", "forest", "green", "plant"),
         ["#2E7D32", "#388E3C", "#43A047", "#4CAF50", "#66BB6A"]),
        (("sky", "blue", "ocean", "water", "sea"),
         ["#1565C0", "#1976D2", "#1E88E5", "#2196F3", "#42A5F5"]),
        (("fire", "red", "warm", "sun", "orange"),
         ["#D32F2F", "#F44336", "#FF5722", "#FF9800", "#FFC107"]),
        (("purple", "violet", "magic", "mystical"),
         ["#512DA8", "#673AB7", "#9C27B0", "#E91E63", "#F06292"]),
    )
    _DEFAULT_PALETTE = ["#424242", "#616161", "#757575", "#9E9E9E", "#BDBDBD"]

    def __init__(self, path=""):
        """Select the compute device; ``path`` is accepted but not used here."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Defined up front so callers can test them before load_model() runs.
        self.pipe = None
        self.painter = None
        self.model_loaded = False

    def load_model(self):
        """Load the DiffSketcher model and dependencies.

        Returns:
            True when the pipeline and painter were created, False otherwise
            (the error is printed, not raised).
        """
        try:
            # Import DiffSketcher modules lazily: they are project-local and
            # may be unavailable where only the procedural fallback is used.
            from methods.painter.diffsketcher import Painter
            from methods.diffusers_warp import StableDiffusionPipeline

            # Load the diffusion model
            self.pipe = StableDiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-2-1-base",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            ).to(self.device)

            # Initialize the painter
            self.painter = Painter(
                args=self._get_default_args(),
                pipe=self.pipe,
            )

            self.model_loaded = True
            return True
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            self.model_loaded = False
            return False

    def _get_default_args(self):
        """Get default arguments for DiffSketcher as a simple attribute bag."""

        class Args:
            def __init__(self):
                self.token_ind = 4
                self.num_paths = 96
                self.num_iter = 500
                self.guidance_scale = 7.5
                self.lr_scheduler = True
                self.lr = 1.0
                self.color_lr = 0.01
                self.width_lr = 0.1
                self.opacity_lr = 0.01
                self.width = 224
                self.height = 224
                self.seed = 42
                self.eval_step = 10
                self.save_step = 10

        return Args()

    def __call__(self, data: Dict[str, Any]):
        """Generate an SVG sketch from a text prompt.

        Args:
            data: Either a dict with ``"inputs"`` (the prompt) and an optional
                ``"parameters"`` dict (``num_paths``, ``width``, ``height``,
                ``seed``, ``guidance_scale``), or any other object, which is
                converted to a string and used as the prompt.

        Returns:
            SVG markup as a string. Any failure is reported as a small SVG
            whose text starts with ``Error:`` instead of raising.
        """
        try:
            # Extract inputs
            if isinstance(data, dict):
                prompt = data.get("inputs", "")
                # An explicit ``"parameters": None`` is treated like a missing key.
                parameters = data.get("parameters") or {}
            else:
                prompt = str(data)
                parameters = {}

            prompt = str(prompt) if prompt else "a simple drawing"

            # Extract parameters; int() also accepts numeric strings.
            num_paths = int(parameters.get("num_paths", 96))
            width = int(parameters.get("width", _DEFAULT_SIZE))
            height = int(parameters.get("height", _DEFAULT_SIZE))
            seed = int(parameters.get("seed", 42))
            guidance_scale = parameters.get("guidance_scale", 7.5)

            if width <= 0 or height <= 0:
                raise ValueError("width and height must be positive")

            # Set random seed for reproducibility. NumPy only accepts seeds in
            # [0, 2**32), so wrap the value instead of failing on others.
            np.random.seed(seed % (2 ** 32))
            torch.manual_seed(seed)

            # Return SVG as text for Inference API
            return self._generate_sketch_svg(
                prompt, width, height, num_paths, guidance_scale
            )
        except Exception as e:
            # Top-level boundary: report the failure as an SVG document.
            return self._error_svg(str(e))

    @staticmethod
    def _error_svg(message: str) -> str:
        """Return a small SVG document displaying ``Error: <message>``."""
        text = escape(f"Error: {message}")
        return (
            f'<svg xmlns="{_SVG_NS}" width="{_DEFAULT_SIZE}" height="{_DEFAULT_SIZE}" '
            f'viewBox="0 0 {_DEFAULT_SIZE} {_DEFAULT_SIZE}">'
            f'<text x="10" y="20" font-size="12" fill="red">{text}</text>'
            f'</svg>'
        )

    @staticmethod
    def _randint(low: int, high: int) -> int:
        """Draw an int from ``[low, high)``, tolerating an empty range.

        ``np.random.randint`` raises when ``low >= high``, which happens when
        the canvas is smaller than the hard-coded margins. In that case return
        the (non-negative) midpoint without consuming a random draw.
        """
        if high <= low:
            return max(0, (low + high) // 2)
        return int(np.random.randint(low, high))

    @staticmethod
    def _paint(color, opacity: float, stroke_width: int = 1, filled: bool = False) -> str:
        """Return the presentation attributes shared by every shape."""
        if filled:
            return f'fill="{color}" opacity="{opacity:.2f}"'
        return (
            f'fill="none" stroke="{color}" stroke-width="{stroke_width}" '
            f'opacity="{opacity:.2f}"'
        )

    def _generate_sketch_svg(self, prompt: str, width: int, height: int,
                             num_paths: int, guidance_scale: float) -> str:
        """Generate a sketch-style SVG based on the text prompt.

        Keywords are matched as substrings of the lower-cased prompt to pick a
        palette and the shape families to draw. ``guidance_scale`` is accepted
        for interface compatibility and is not used by this generator.

        NOTE(review): the SVG markup literals were empty in the source as
        received; the element/attribute layout here is reconstructed from the
        values each helper computes -- confirm against the intended look.
        """
        prompt_lower = prompt.lower()

        def mentions(keywords):
            return any(word in prompt_lower for word in keywords)

        # Color palette based on prompt content
        colors = self._DEFAULT_PALETTE
        for keywords, palette in self._PALETTES:
            if mentions(keywords):
                colors = palette
                break

        # Shape families triggered by the prompt, in drawing order.
        shape_rules = (
            (("circle", "round", "ball", "sun", "moon"),
             self._add_circular_elements, num_paths // 3),
            (("house", "building", "square", "box"),
             self._add_rectangular_elements, num_paths // 3),
            (("mountain", "triangle", "peak", "roof"),
             self._add_triangular_elements, num_paths // 3),
            (("flower", "star", "organic", "natural"),
             self._add_organic_paths, num_paths // 2),
            # Flowing lines for movement or abstract concepts
            (("flowing", "wind", "wave", "abstract", "movement"),
             self._add_flowing_lines, num_paths // 2),
        )

        paths: List[str] = []
        for keywords, add_shapes, count in shape_rules:
            if mentions(keywords):
                add_shapes(paths, width, height, colors, count)

        # If no specific shapes detected, add general sketch elements
        if len(paths) < num_paths // 4:
            self._add_general_sketch_elements(paths, width, height, colors, num_paths)

        # Add some random sketch lines for artistic effect
        self._add_sketch_lines(paths, width, height, colors, min(20, num_paths // 5))

        svg_header = (
            f'<svg xmlns="{_SVG_NS}" width="{width}" height="{height}" '
            f'viewBox="0 0 {width} {height}">'
        )
        svg_footer = '</svg>'
        return '\n'.join([svg_header, *paths, svg_footer])

    def _add_circular_elements(self, paths, width, height, colors, count):
        """Append ``count`` circles (randomly filled or outlined) to ``paths``."""
        for _ in range(count):
            cx = self._randint(30, width - 30)
            cy = self._randint(30, height - 30)
            r = self._randint(8, 40)
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.3, 0.8)
            stroke_width = self._randint(1, 3)
            filled = np.random.random() > 0.5
            style = self._paint(color, opacity, stroke_width, filled)
            paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" {style}/>')

    def _add_rectangular_elements(self, paths, width, height, colors, count):
        """Append ``count`` rectangles (randomly filled or outlined) to ``paths``."""
        for _ in range(count):
            x = self._randint(10, width - 50)
            y = self._randint(10, height - 50)
            w = self._randint(20, 60)
            h = self._randint(20, 60)
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.3, 0.8)
            stroke_width = self._randint(1, 3)
            filled = np.random.random() > 0.5
            style = self._paint(color, opacity, stroke_width, filled)
            paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" {style}/>')

    def _add_triangular_elements(self, paths, width, height, colors, count):
        """Append ``count`` triangles (randomly filled or outlined) to ``paths``."""
        for _ in range(count):
            x1 = self._randint(20, width - 20)
            y1 = self._randint(40, height - 20)
            x2 = x1 + self._randint(-30, 30)
            y2 = y1 - self._randint(20, 50)
            x3 = x1 + self._randint(-30, 30)
            y3 = y1  # base of the triangle is horizontal
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.3, 0.8)
            stroke_width = self._randint(1, 3)
            points = f"{x1},{y1} {x2},{y2} {x3},{y3}"
            filled = np.random.random() > 0.5
            style = self._paint(color, opacity, stroke_width, filled)
            paths.append(f'<polygon points="{points}" {style}/>')

    def _add_organic_paths(self, paths, width, height, colors, count):
        """Append ``count`` multi-segment cubic Bezier curves to ``paths``."""
        for _ in range(count):
            start_x = self._randint(20, width - 20)
            start_y = self._randint(20, height - 20)

            # Chain 2-4 cubic segments, each starting where the last ended.
            segments = [f"M {start_x} {start_y}"]
            for _ in range(self._randint(2, 5)):
                control_x1 = start_x + self._randint(-40, 40)
                control_y1 = start_y + self._randint(-40, 40)
                control_x2 = start_x + self._randint(-40, 40)
                control_y2 = start_y + self._randint(-40, 40)
                end_x = start_x + self._randint(-60, 60)
                end_y = start_y + self._randint(-60, 60)
                segments.append(
                    f"C {control_x1} {control_y1}, {control_x2} {control_y2}, "
                    f"{end_x} {end_y}"
                )
                start_x, start_y = end_x, end_y
            path_data = " ".join(segments)

            color = np.random.choice(colors)
            opacity = np.random.uniform(0.4, 0.9)
            stroke_width = self._randint(1, 4)
            style = self._paint(color, opacity, stroke_width)
            paths.append(f'<path d="{path_data}" {style}/>')

    def _add_flowing_lines(self, paths, width, height, colors, count):
        """Append ``count`` straight lines between random canvas points."""
        for _ in range(count):
            x1 = self._randint(0, width)
            y1 = self._randint(0, height)
            x2 = self._randint(0, width)
            y2 = self._randint(0, height)
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.3, 0.7)
            stroke_width = self._randint(1, 3)
            style = self._paint(color, opacity, stroke_width)
            paths.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" {style}/>')

    def _add_general_sketch_elements(self, paths, width, height, colors, count):
        """Append ``count // 3`` random circles, rectangles or lines.

        Used when the prompt triggered few or no specific shape families.
        """
        for _ in range(count // 3):
            # Mix of circles, rectangles, and lines
            element_type = np.random.choice(['circle', 'rect', 'line'])
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.3, 0.8)
            style = self._paint(color, opacity)

            if element_type == 'circle':
                cx = self._randint(20, width - 20)
                cy = self._randint(20, height - 20)
                r = self._randint(5, 25)
                paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" {style}/>')
            elif element_type == 'rect':
                x = self._randint(10, width - 40)
                y = self._randint(10, height - 40)
                w = self._randint(15, 40)
                h = self._randint(15, 40)
                paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" {style}/>')
            else:
                x1 = self._randint(0, width)
                y1 = self._randint(0, height)
                x2 = self._randint(0, width)
                y2 = self._randint(0, height)
                paths.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" {style}/>')

    def _add_sketch_lines(self, paths, width, height, colors, count):
        """Append ``count`` short random strokes for artistic effect."""
        for _ in range(count):
            x1 = self._randint(0, width)
            y1 = self._randint(0, height)
            x2 = x1 + self._randint(-50, 50)
            y2 = y1 + self._randint(-50, 50)
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.2, 0.6)
            # randint's upper bound is exclusive, so this is always 1.
            stroke_width = self._randint(1, 2)
            style = self._paint(color, opacity, stroke_width)
            paths.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" {style}/>')


# Create handler instance
handler = DiffSketcherHandler()