import os
import sys
import json
import torch
import numpy as np
from PIL import Image, ImageDraw
import io
import base64
from typing import Dict, Any, List
import tempfile
import xml.etree.ElementTree as ET
import math
class DiffSketcherHandler:
    """Text-to-SVG sketch handler.

    ``__call__`` turns a text prompt into an SVG document. The document is
    generated procedurally: keyword matches in the prompt choose a colour
    palette and which shape families to draw, and NumPy's global RNG (seeded
    per request) places the shapes. The diffusion model is NOT used for this;
    ``load_model`` can load a Stable Diffusion pipeline and a DiffSketcher
    ``Painter``, but nothing in this class currently consumes them.
    """

    # (keywords, palette) pairs checked in order; the first group that has a
    # keyword occurring (as a substring) in the lower-cased prompt wins.
    _PALETTES = (
        (("nature", "tree", "forest", "green", "plant"),
         ("#2E7D32", "#388E3C", "#43A047", "#4CAF50", "#66BB6A")),
        (("sky", "blue", "ocean", "water", "sea"),
         ("#1565C0", "#1976D2", "#1E88E5", "#2196F3", "#42A5F5")),
        (("fire", "red", "warm", "sun", "orange"),
         ("#D32F2F", "#F44336", "#FF5722", "#FF9800", "#FFC107")),
        (("purple", "violet", "magic", "mystical"),
         ("#512DA8", "#673AB7", "#9C27B0", "#E91E63", "#F06292")),
    )
    # Greys used when no palette keyword matches.
    _DEFAULT_PALETTE = ("#424242", "#616161", "#757575", "#9E9E9E", "#BDBDBD")

    def __init__(self, path=""):
        """Create the handler; no model is loaded here.

        Args:
            path: Accepted for callers that pass a model directory. It is
                stored on the instance but not otherwise used.
        """
        self.path = path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Populated by load_model(); initialized here so the attributes always
        # exist even if load_model() was never called or failed.
        self.pipe = None
        self.painter = None
        self.model_loaded = False

    def load_model(self):
        """Load the Stable Diffusion pipeline and the DiffSketcher painter.

        Returns:
            True on success. False if any step raised; the error is printed to
            stderr and ``model_loaded`` is left False.
        """
        try:
            # Project-local DiffSketcher modules, imported lazily so the
            # procedural SVG path in __call__ works without them installed.
            from methods.painter.diffsketcher import Painter
            from methods.diffusers_warp import StableDiffusionPipeline

            # Half precision only when CUDA is available.
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            self.pipe = StableDiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-2-1-base",
                torch_dtype=dtype,
                safety_checker=None,
                requires_safety_checker=False,
            ).to(self.device)
            self.painter = Painter(args=self._get_default_args(), pipe=self.pipe)
        except Exception as e:
            print(f"Error loading model: {str(e)}", file=sys.stderr)
            self.model_loaded = False
            return False
        self.model_loaded = True
        return True

    def _get_default_args(self):
        """Return an attribute bag of default DiffSketcher settings.

        Only ``load_model`` passes this object on, to ``Painter(args=...)``.
        NOTE(review): the meaning of each field is defined by Painter, which is
        not visible here.
        """
        class Args:
            def __init__(self):
                self.token_ind = 4
                self.num_paths = 96
                self.num_iter = 500
                self.guidance_scale = 7.5
                self.lr_scheduler = True
                self.lr = 1.0
                self.color_lr = 0.01
                self.width_lr = 0.1
                self.opacity_lr = 0.01
                self.width = 224
                self.height = 224
                self.seed = 42
                self.eval_step = 10
                self.save_step = 10
        return Args()

    def __call__(self, data: Any) -> str:
        """Generate an SVG sketch for a text prompt.

        Args:
            data: Either a dict ``{"inputs": <prompt>, "parameters": {...}}``
                or any other value, which is converted with ``str()`` and used
                as the prompt. Recognized parameters: ``num_paths`` (default
                96), ``width`` / ``height`` (default 224, must be >= 1),
                ``seed`` (default 42) and ``guidance_scale`` (default 7.5;
                accepted but it does not affect the procedural output).

        Returns:
            The SVG document as a string. This method does not raise: on any
            error, it returns a small SVG that shows the XML-escaped error
            message, and the error is also printed to stderr.
        """
        try:
            if isinstance(data, dict):
                prompt = data.get("inputs", "")
                # A JSON body with "parameters": null would otherwise break .get().
                parameters = data.get("parameters") or {}
            else:
                prompt = str(data)
                parameters = {}
            if not prompt:
                prompt = "a simple drawing"
            prompt = str(prompt)

            # Coerce up front so string-typed JSON values work and invalid ones
            # fail here with a clear ValueError/TypeError.
            num_paths = int(parameters.get("num_paths", 96))
            width = int(parameters.get("width", 224))
            height = int(parameters.get("height", 224))
            seed = int(parameters.get("seed", 42))
            guidance_scale = float(parameters.get("guidance_scale", 7.5))
            if width < 1 or height < 1:
                raise ValueError(f"width and height must be >= 1, got {width}x{height}")

            # Seed both global RNGs so the same prompt + seed gives the same SVG.
            np.random.seed(seed)
            torch.manual_seed(seed)

            return self._generate_sketch_svg(prompt, width, height, num_paths, guidance_scale)
        except Exception as e:
            print(f"Error generating sketch: {e}", file=sys.stderr)
            return self._error_svg(str(e))

    @staticmethod
    def _error_svg(message: str, width: int = 224, height: int = 224) -> str:
        """Return a minimal SVG that displays ``message`` (XML-escaped)."""
        from xml.sax.saxutils import escape

        return (
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
            f'viewBox="0 0 {width} {height}">\n'
            f'<rect width="100%" height="100%" fill="#FFEBEE"/>\n'
            f'<text x="10" y="20" font-family="sans-serif" font-size="10" fill="#C62828">'
            f'Error: {escape(message)}</text>\n'
            '</svg>'
        )

    def _generate_sketch_svg(self, prompt: str, width: int, height: int, num_paths: int, guidance_scale: float) -> str:
        """Build the SVG document for ``prompt``.

        Palette: the first matching group in ``_PALETTES`` (substring match on
        the lower-cased prompt, so e.g. "sun" also matches "sunday"), otherwise
        greys. Shape families: each matching keyword group independently adds
        its elements, in the fixed order below. A generic mix is added if
        fewer than ``num_paths // 4`` elements were produced, and then up to
        20 short sketch lines.

        ``guidance_scale`` is accepted for interface compatibility but does not
        influence the output.
        """
        prompt_lower = prompt.lower()

        def mentions(*words):
            return any(word in prompt_lower for word in words)

        colors = next(
            (list(palette) for keywords, palette in self._PALETTES if mentions(*keywords)),
            list(self._DEFAULT_PALETTE),
        )

        paths = []
        if mentions("circle", "round", "ball", "sun", "moon"):
            self._add_circular_elements(paths, width, height, colors, num_paths // 3)
        if mentions("house", "building", "square", "box"):
            self._add_rectangular_elements(paths, width, height, colors, num_paths // 3)
        if mentions("mountain", "triangle", "peak", "roof"):
            self._add_triangular_elements(paths, width, height, colors, num_paths // 3)
        if mentions("flower", "star", "organic", "natural"):
            self._add_organic_paths(paths, width, height, colors, num_paths // 2)
        # Flowing lines for movement or abstract concepts.
        if mentions("flowing", "wind", "wave", "abstract", "movement"):
            self._add_flowing_lines(paths, width, height, colors, num_paths // 2)
        # Little or nothing matched: fall back to a generic mix of shapes.
        if len(paths) < num_paths // 4:
            self._add_general_sketch_elements(paths, width, height, colors, num_paths)
        # A few loose strokes on top for a hand-drawn look.
        self._add_sketch_lines(paths, width, height, colors, min(20, num_paths // 5))

        svg_header = (
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
            f'viewBox="0 0 {width} {height}">'
        )
        svg_footer = "</svg>"
        return "\n".join([svg_header, *paths, svg_footer])

    @staticmethod
    def _randint(low, high):
        """``np.random.randint(low, high)`` that returns ``low`` for an empty range.

        Canvas-relative bounds such as ``(30, width - 30)`` are empty on small
        canvases, where NumPy would raise ValueError.
        """
        if high <= low:
            return low
        return int(np.random.randint(low, high))

    @staticmethod
    def _paint(color, opacity, stroke_width, filled):
        """Return SVG presentation attributes for a filled or outline-only shape."""
        fill = color if filled else "none"
        return (
            f'fill="{fill}" stroke="{color}" stroke-width="{stroke_width}" '
            f'opacity="{opacity:.2f}"'
        )

    @staticmethod
    def _line(x1, y1, x2, y2, color, opacity, stroke_width):
        """Return an SVG ``<line>`` element with a round line cap."""
        return (
            f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{color}" '
            f'stroke-width="{stroke_width}" opacity="{opacity:.2f}" stroke-linecap="round"/>'
        )

    def _add_circular_elements(self, paths, width, height, colors, count):
        """Append ``count`` circles, each randomly filled or outlined, to ``paths``."""
        for _ in range(count):
            cx = self._randint(30, width - 30)
            cy = self._randint(30, height - 30)
            r = np.random.randint(8, 40)
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.3, 0.8)
            stroke_width = np.random.randint(1, 3)
            paint = self._paint(color, opacity, stroke_width, np.random.random() > 0.5)
            paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" {paint}/>')

    def _add_rectangular_elements(self, paths, width, height, colors, count):
        """Append ``count`` rectangles, each randomly filled or outlined, to ``paths``."""
        for _ in range(count):
            x = self._randint(10, width - 50)
            y = self._randint(10, height - 50)
            w = np.random.randint(20, 60)
            h = np.random.randint(20, 60)
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.3, 0.8)
            stroke_width = np.random.randint(1, 3)
            paint = self._paint(color, opacity, stroke_width, np.random.random() > 0.5)
            paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" {paint}/>')

    def _add_triangular_elements(self, paths, width, height, colors, count):
        """Append ``count`` triangles to ``paths``.

        Each triangle has its apex above the base point, and is randomly
        filled or outlined.
        """
        for _ in range(count):
            x1 = self._randint(20, width - 20)
            y1 = self._randint(40, height - 20)
            x2 = x1 + np.random.randint(-30, 30)
            y2 = y1 - np.random.randint(20, 50)
            x3 = x1 + np.random.randint(-30, 30)
            y3 = y1
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.3, 0.8)
            stroke_width = np.random.randint(1, 3)
            points = f"{x1},{y1} {x2},{y2} {x3},{y3}"
            paint = self._paint(color, opacity, stroke_width, np.random.random() > 0.5)
            paths.append(f'<polygon points="{points}" {paint}/>')

    def _add_organic_paths(self, paths, width, height, colors, count):
        """Append ``count`` unfilled paths of 2-4 chained cubic Béziers to ``paths``."""
        for _ in range(count):
            x = self._randint(20, width - 20)
            y = self._randint(20, height - 20)
            segments = [f"M {x} {y}"]
            for _ in range(np.random.randint(2, 5)):
                c1x = x + np.random.randint(-40, 40)
                c1y = y + np.random.randint(-40, 40)
                c2x = x + np.random.randint(-40, 40)
                c2y = y + np.random.randint(-40, 40)
                end_x = x + np.random.randint(-60, 60)
                end_y = y + np.random.randint(-60, 60)
                segments.append(f"C {c1x} {c1y}, {c2x} {c2y}, {end_x} {end_y}")
                # Each curve starts where the previous one ended.
                x, y = end_x, end_y
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.4, 0.9)
            stroke_width = np.random.randint(1, 4)
            path_data = " ".join(segments)
            paths.append(
                f'<path d="{path_data}" fill="none" stroke="{color}" '
                f'stroke-width="{stroke_width}" opacity="{opacity:.2f}" '
                f'stroke-linecap="round" stroke-linejoin="round"/>'
            )

    def _add_flowing_lines(self, paths, width, height, colors, count):
        """Append ``count`` straight lines between random canvas points to ``paths``."""
        for _ in range(count):
            x1 = self._randint(0, width)
            y1 = self._randint(0, height)
            x2 = self._randint(0, width)
            y2 = self._randint(0, height)
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.3, 0.7)
            stroke_width = np.random.randint(1, 3)
            paths.append(self._line(x1, y1, x2, y2, color, opacity, stroke_width))

    def _add_general_sketch_elements(self, paths, width, height, colors, count):
        """Append ``count // 3`` random outline circles, rectangles or lines to ``paths``.

        This is the fallback used when the prompt produced too few elements.
        """
        for _ in range(count // 3):
            element_type = np.random.choice(['circle', 'rect', 'line'])
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.3, 0.8)
            if element_type == 'circle':
                cx = self._randint(20, width - 20)
                cy = self._randint(20, height - 20)
                r = np.random.randint(5, 25)
                paint = self._paint(color, opacity, 1, False)
                paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" {paint}/>')
            elif element_type == 'rect':
                x = self._randint(10, width - 40)
                y = self._randint(10, height - 40)
                w = np.random.randint(15, 40)
                h = np.random.randint(15, 40)
                paint = self._paint(color, opacity, 1, False)
                paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" {paint}/>')
            else:
                x1 = self._randint(0, width)
                y1 = self._randint(0, height)
                x2 = self._randint(0, width)
                y2 = self._randint(0, height)
                paths.append(self._line(x1, y1, x2, y2, color, opacity, 1))

    def _add_sketch_lines(self, paths, width, height, colors, count):
        """Append ``count`` short faint strokes to ``paths``.

        Each stroke starts at a random canvas point and is at most 50 px long
        per axis, so it may extend past the canvas edge.
        """
        for _ in range(count):
            x1 = self._randint(0, width)
            y1 = self._randint(0, height)
            x2 = x1 + np.random.randint(-50, 50)
            y2 = y1 + np.random.randint(-50, 50)
            color = np.random.choice(colors)
            opacity = np.random.uniform(0.2, 0.6)
            # randint's upper bound is exclusive, so this is always 1; the draw
            # is kept so the RNG stream for a given seed stays the same.
            stroke_width = np.random.randint(1, 2)
            paths.append(self._line(x1, y1, x2, y2, color, opacity, stroke_width))
# Single module-level instance, built once when this module is imported.
handler: DiffSketcherHandler = DiffSketcherHandler(path="")