#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Full implementation of DiffSketcher handler.

Exposes ``EndpointHandler``: a callable that turns a text prompt into a sketch
image.  When the DiffSketcher modules import (and initialise) successfully,
generation goes through the real optimisation loop; otherwise a procedural
mock produces the SVG.
"""

import os
import sys
import torch
import numpy as np
from PIL import Image
import random
import io
import base64
import cairosvg
import math
import time

# Add the DiffSketcher repository to the path
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "DiffSketcher"))

# Add the mock diffvg to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import mock_diffvg as diffvg

# Try to import the real DiffSketcher modules
try:
    from models.clip_model import ClipModel
    from models.sd_model import StableDiffusion
    from models.loss import Loss
    from models.painter_params import Painter, PainterOptimizer
    from utils.train_utils import init_log, log_input, log_sketch, get_latest_ckpt, save_ckpt
    from utils.vector_utils import (
        svg_to_png, create_dir, init_svg, read_svg, get_svg_size,
        get_svg_path_d, get_svg_path_width, get_svg_color,
        set_svg_path_d, set_svg_path_width, set_svg_color,
        get_svg_meta, set_svg_meta, get_svg_path_bbox, get_svg_bbox, get_png_size,
        get_svg_path_group, get_svg_group_opacity, set_svg_group_opacity,
        get_svg_group_path_indices,
        get_svg_group_path_opacity, set_svg_group_path_opacity,
        get_svg_group_path_fill, set_svg_group_path_fill,
        get_svg_group_path_stroke, set_svg_group_path_stroke,
        get_svg_group_path_stroke_width, set_svg_group_path_stroke_width,
        get_svg_group_path_stroke_opacity, set_svg_group_path_stroke_opacity,
        get_svg_group_path_fill_opacity, set_svg_group_path_fill_opacity,
        get_svg_group_path_stroke_linecap, set_svg_group_path_stroke_linecap,
        get_svg_group_path_stroke_linejoin, set_svg_group_path_stroke_linejoin,
        get_svg_group_path_stroke_miterlimit, set_svg_group_path_stroke_miterlimit,
        get_svg_group_path_stroke_dasharray, set_svg_group_path_stroke_dasharray,
        get_svg_group_path_stroke_dashoffset, set_svg_group_path_stroke_dashoffset,
        get_svg_group_path_transform, set_svg_group_path_transform,
        get_svg_group_transform, set_svg_group_transform,
        get_svg_path_transform, set_svg_path_transform,
        get_svg_path_fill, set_svg_path_fill,
        get_svg_path_stroke, set_svg_path_stroke,
        get_svg_path_stroke_width, set_svg_path_stroke_width,
        get_svg_path_stroke_opacity, set_svg_path_stroke_opacity,
        get_svg_path_fill_opacity, set_svg_path_fill_opacity,
        get_svg_path_stroke_linecap, set_svg_path_stroke_linecap,
        get_svg_path_stroke_linejoin, set_svg_path_stroke_linejoin,
        get_svg_path_stroke_miterlimit, set_svg_path_stroke_miterlimit,
        get_svg_path_stroke_dasharray, set_svg_path_stroke_dasharray,
        get_svg_path_stroke_dashoffset, set_svg_path_stroke_dashoffset,
    )
    REAL_DIFFSKETCHER_AVAILABLE = True
except ImportError:
    print("Warning: Could not import DiffSketcher modules. Using mock implementation instead.")
    REAL_DIFFSKETCHER_AVAILABLE = False


class EndpointHandler:
    """Callable endpoint handler: text prompt in, rendered sketch image out.

    Uses the real DiffSketcher pipeline when it imported and initialised
    successfully; otherwise (or if real generation raises) falls back to a
    procedural mock generator.
    """

    # RGB (min, max) ranges used by the mock generator, indexed by
    # ``sum(ord(c) for c in prompt) % 5``.
    _PALETTES = (
        [(200, 255), (100, 180), (50, 150)],   # 0: warm colors
        [(50, 150), (100, 180), (200, 255)],   # 1: cool colors
        [(150, 200), (100, 150), (50, 100)],   # 2: earthy tones
        [(200, 255), (50, 255), (50, 255)],    # 3: vibrant colors
        [(100, 200), (100, 200), (100, 200)],  # 4: grayscale with accent
    )

    def __init__(self, path=""):
        """
        Initialize the DiffSketcher model.

        Args:
            path (str): Path to the model directory
        """
        self.path = path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing DiffSketcher handler on {self.device}")

        # Check if the real DiffSketcher is available
        self.use_real_diffsketcher = REAL_DIFFSKETCHER_AVAILABLE
        if self.use_real_diffsketcher:
            try:
                # Initialize the real DiffSketcher model
                self._init_real_diffsketcher()
            except Exception as e:
                # Any initialisation failure downgrades to the mock instead of
                # taking the whole endpoint down.
                print(f"Error initializing real DiffSketcher: {e}")
                self.use_real_diffsketcher = False

        if not self.use_real_diffsketcher:
            print("Using mock DiffSketcher implementation")

    def _init_real_diffsketcher(self):
        """Initialize the real DiffSketcher model components (CLIP, SD, loss, painter, optimizer)."""
        # Load model weights
        # NOTE(review): model_dir is resolved here but never passed to any of the
        # constructors below — confirm whether they should receive it.
        model_dir = os.path.join(self.path, "models", "diffsketcher")
        if not os.path.exists(model_dir):
            model_dir = "/workspace/vector_models/models/diffsketcher"

        # Initialize CLIP model
        self.clip_model = ClipModel(device=self.device)

        # Initialize Stable Diffusion model
        self.sd_model = StableDiffusion(device=self.device)

        # Initialize loss function
        self.loss_fn = Loss(device=self.device)

        # Initialize painter parameters
        self.painter = Painter(
            num_paths=48,
            num_segments=4,
            canvas_size=512,
            device=self.device
        )

        # Initialize painter optimizer
        self.painter_optimizer = PainterOptimizer(
            self.painter,
            lr=1e-2,
            device=self.device
        )

    def svg_to_png(self, svg_string, width=512, height=512):
        """
        Convert SVG string to PNG image.

        Args:
            svg_string (str | bytes): SVG document (text, or already-encoded bytes)
            width (int): Width of the output image
            height (int): Height of the output image

        Returns:
            PIL.Image.Image: The rendered image, or a light-gray RGB canvas of
            the requested size if rendering fails.
        """
        try:
            # cairosvg wants bytes; accept pre-encoded input as well as str.
            svg_bytes = svg_string if isinstance(svg_string, bytes) else svg_string.encode('utf-8')
            png_data = cairosvg.svg2png(bytestring=svg_bytes, output_width=width, output_height=height)
            return Image.open(io.BytesIO(png_data))
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")
            # Return a blank image if conversion fails
            return Image.new('RGB', (width, height), color=(240, 240, 240))

    @staticmethod
    def _seed_everything(seed):
        """Seed Python's ``random``, NumPy's legacy global RNG and torch (CPU + CUDA) with *seed*.

        The seed is reduced modulo 2**32 because ``np.random.seed`` rejects
        values outside [0, 2**32 - 1]; seeds already in that range are unchanged.
        """
        seed = int(seed) % (2 ** 32)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def generate_svg(self, prompt, negative_prompt="", num_paths=96, guidance_scale=7.5, seed=None):
        """
        Generate SVG using DiffSketcher.

        Args:
            prompt (str): Text prompt
            negative_prompt (str): Negative text prompt
            num_paths (int): Number of paths
            guidance_scale (float): Guidance scale
            seed (int): Random seed; a random one in [0, 100000] is drawn when None

        Returns:
            tuple: (svg_string, png_image)
        """
        # Set random seed for reproducibility
        if seed is None:
            seed = random.randint(0, 100000)
        self._seed_everything(seed)

        if self.use_real_diffsketcher:
            try:
                # Generate SVG using the real DiffSketcher
                return self._generate_svg_real(prompt, negative_prompt, num_paths, guidance_scale)
            except Exception as e:
                print(f"Error generating SVG with real DiffSketcher: {e}")
                # The failed attempt may have consumed RNG state; reseed so the
                # fallback output depends only on the seed.
                self._seed_everything(seed)

        # Generate SVG using the mock implementation (primary path or fallback)
        return self._generate_svg_mock(prompt, negative_prompt, num_paths, guidance_scale)

    def _generate_svg_real(self, prompt, negative_prompt, num_paths, guidance_scale,
                           num_steps=1000, log_every=100):
        """
        Generate SVG using the real DiffSketcher.

        Args:
            prompt (str): Text prompt
            negative_prompt (str): Negative text prompt
            num_paths (int): Number of paths
            guidance_scale (float): Guidance scale
            num_steps (int): Number of optimisation steps
            log_every (int): Print the loss every this many steps (0 disables)

        Returns:
            tuple: (svg_string, png_image)
        """
        # Initialize painter with the specified number of paths
        # NOTE(review): this only sets an attribute on an already-constructed
        # Painter (built with num_paths=48); confirm Painter honours the change.
        self.painter.num_paths = num_paths

        # Get CLIP embeddings for the prompt
        text_embeddings = self.clip_model.get_text_embeddings(prompt, negative_prompt)

        # Initialize SVG (this value is overwritten by painter.get_svg() below)
        svg_string = init_svg(self.painter.canvas_size, self.painter.canvas_size)

        # Optimize the SVG
        for i in range(num_steps):
            # Forward pass
            svg_tensor = self.painter.get_image()

            # Calculate loss
            loss = self.loss_fn(svg_tensor, text_embeddings, guidance_scale)

            # Backward pass, then update and clear gradients
            loss.backward()
            self.painter_optimizer.step()
            self.painter_optimizer.zero_grad()

            # Log progress
            if log_every and i % log_every == 0:
                print(f"Step {i}, Loss: {loss.item()}")

        # Get the final SVG
        svg_string = self.painter.get_svg()

        # Convert SVG to PNG
        png_image = self.svg_to_png(svg_string)

        return svg_string, png_image

    @staticmethod
    def _random_rgb(color_ranges):
        """Draw (r, g, b) integers from the inclusive (min, max) ranges in color_ranges[0..2].

        Draws happen in r, g, b order, so the RNG sequence matches three
        consecutive ``random.randint`` calls.
        """
        r = random.randint(color_ranges[0][0], color_ranges[0][1])
        g = random.randint(color_ranges[1][0], color_ranges[1][1])
        b = random.randint(color_ranges[2][0], color_ranges[2][1])
        return r, g, b

    def _generate_svg_mock(self, prompt, negative_prompt, num_paths, guidance_scale):
        """
        Generate SVG using the mock implementation.

        The palette is picked deterministically from the prompt's characters and
        the subject from keywords in the prompt; everything else comes from the
        (already seeded) ``random`` module.

        Args:
            prompt (str): Text prompt
            negative_prompt (str): Negative text prompt (unused by the mock)
            num_paths (int): Number of paths (used by the abstract generator)
            guidance_scale (float): Guidance scale (unused by the mock)

        Returns:
            tuple: (svg_string, png_image)
        """
        # Create a color palette based on the prompt
        palette_seed = sum(ord(c) for c in prompt) % 5
        color_ranges = list(self._PALETTES[palette_seed])

        # NOTE(review): the SVG markup literals in this revision are empty or
        # whitespace-only (e.g. `svg_string += ''`), so the assembled string has
        # no <svg> root element; presumably cairosvg rejects it and svg_to_png
        # returns its blank fallback canvas. The markup looks stripped — TODO
        # restore it. The random draws feeding those templates are kept so the
        # RNG sequence is unchanged.

        # Create a simple SVG with some paths - DiffSketcher style (sketch-like with bold strokes)
        svg_string = f""" DiffSketcher: {prompt} """

        # Add a grid pattern (characteristic of DiffSketcher)
        svg_string += """ """

        # Add some sketch-like paths (DiffSketcher specializes in sketch-like vector graphics)
        svg_string += ''

        # Pick a subject from prompt keywords; first match wins, abstract otherwise.
        prompt_lower = prompt.lower()
        subject_generators = (
            (("car",), self._generate_car_svg),
            (("face", "portrait"), self._generate_face_svg),
            (("landscape", "mountain"), self._generate_landscape_svg),
            (("flower", "plant"), self._generate_flower_svg),
            (("animal", "dog", "cat"), self._generate_animal_svg),
        )
        for keywords, generator in subject_generators:
            if any(keyword in prompt_lower for keyword in keywords):
                svg_string += generator(color_ranges)
                break
        else:
            # Generate abstract art
            svg_string += self._generate_abstract_svg(color_ranges, num_paths)

        svg_string += ''

        # Convert SVG to PNG
        png_image = self.svg_to_png(svg_string)

        return svg_string, png_image

    def _generate_car_svg(self, color_ranges):
        """Generate a car SVG fragment (body colour drawn from color_ranges)."""
        car_svg = ""

        # Car body
        r, g, b = self._random_rgb(color_ranges)
        car_svg += f''

        # Windows
        car_svg += ''

        # Wheels
        car_svg += ''
        car_svg += ''
        car_svg += ''
        car_svg += ''

        # Headlights
        car_svg += ''
        car_svg += ''

        return car_svg

    def _generate_face_svg(self, color_ranges):
        """Generate a face SVG fragment (skin from color_ranges, random mouth and hair)."""
        face_svg = ""

        # Face shape
        r, g, b = self._random_rgb(color_ranges)
        face_svg += f''

        # Eyes
        face_svg += ''
        face_svg += ''
        face_svg += ''
        face_svg += ''

        # Eyebrows
        face_svg += ''
        face_svg += ''

        # Nose
        face_svg += ''

        # Mouth: smile with probability 0.7, otherwise neutral
        if random.random() < 0.7:
            face_svg += ''
        else:
            face_svg += ''

        # Hair
        hair_r, hair_g, hair_b = self._random_rgb([(0, 100), (0, 100), (0, 100)])
        face_svg += f''

        return face_svg

    def _generate_landscape_svg(self, color_ranges):
        """Generate a landscape SVG fragment: sky, sun, 5 mountains, ground and 10 trees."""
        landscape_svg = ""

        # Sky
        sky_r, sky_g, sky_b = self._random_rgb([(100, 200), (150, 255), (200, 255)])
        landscape_svg += f''

        # Sun
        sun_x = random.randint(50, 462)
        sun_y = random.randint(50, 150)
        landscape_svg += f''

        # Mountains
        for _ in range(5):
            mountain_x = random.randint(-100, 512)
            mountain_width = random.randint(200, 400)
            mountain_height = random.randint(100, 200)
            r, g, b = self._random_rgb([(50, 150), (50, 150), (50, 150)])
            landscape_svg += f''

            # Snow cap
            landscape_svg += f''

        # Ground
        ground_r, ground_g, ground_b = self._random_rgb([(50, 150), (100, 200), (50, 100)])
        landscape_svg += f''

        # Trees
        for _ in range(10):
            tree_x = random.randint(20, 492)
            tree_y = random.randint(320, 450)
            tree_height = random.randint(50, 100)

            # Trunk
            landscape_svg += f''

            # Foliage
            foliage_r, foliage_g, foliage_b = self._random_rgb([(0, 100), (100, 200), (0, 100)])
            landscape_svg += f''

        return landscape_svg

    def _generate_flower_svg(self, color_ranges):
        """Generate a flower SVG fragment: stem, two leaves, centre and 5-12 petals."""
        flower_svg = ""

        # Stem
        stem_height = random.randint(150, 300)
        flower_svg += f''

        # Leaves
        leaf_y1 = random.randint(350, 420)
        leaf_y2 = random.randint(280, 349)
        flower_svg += f''
        flower_svg += f''

        # Flower center (the stem grows up from y=450)
        center_y = 450 - stem_height
        flower_svg += f''

        # Petals, evenly spaced around the centre
        r, g, b = self._random_rgb(color_ranges)
        num_petals = random.randint(5, 12)
        petal_length = random.randint(40, 70)
        for i in range(num_petals):
            angle = 2 * math.pi * i / num_petals
            petal_x = 256 + petal_length * math.cos(angle)
            petal_y = center_y + petal_length * math.sin(angle)
            control_x1 = 256 + petal_length * 0.5 * math.cos(angle - 0.3)
            control_y1 = center_y + petal_length * 0.5 * math.sin(angle - 0.3)
            control_x2 = 256 + petal_length * 0.5 * math.cos(angle + 0.3)
            control_y2 = center_y + petal_length * 0.5 * math.sin(angle + 0.3)
            flower_svg += f''

        return flower_svg

    def _generate_animal_svg(self, color_ranges):
        """Generate an animal SVG fragment (body colour drawn from color_ranges)."""
        animal_svg = ""

        # Body
        r, g, b = self._random_rgb(color_ranges)
        animal_svg += f''

        # Head
        animal_svg += f''

        # Eyes
        animal_svg += ''
        animal_svg += ''
        animal_svg += ''
        animal_svg += ''

        # Nose
        animal_svg += ''

        # Ears
        animal_svg += f''
        animal_svg += f''

        # Legs
        animal_svg += ''
        animal_svg += ''
        animal_svg += ''
        animal_svg += ''

        # Tail
        animal_svg += f''

        return animal_svg

    def _generate_abstract_svg(self, color_ranges, num_paths):
        """Generate abstract art: num_paths random polylines/cubic curves on a 512x512 canvas."""
        abstract_svg = ""

        for _ in range(num_paths):
            # Random color and stroke width
            r, g, b = self._random_rgb(color_ranges)
            stroke_width = random.uniform(1, 5)

            # Random start point
            path_data = "M"
            x, y = random.uniform(0, 512), random.uniform(0, 512)
            path_data += f"{x},{y} "

            # 2-5 segments, each a cubic curve or a straight line with equal odds
            num_segments = random.randint(2, 5)
            for _ in range(num_segments):
                if random.random() > 0.5:
                    # Curve
                    cx1, cy1 = random.uniform(0, 512), random.uniform(0, 512)
                    cx2, cy2 = random.uniform(0, 512), random.uniform(0, 512)
                    x, y = random.uniform(0, 512), random.uniform(0, 512)
                    path_data += f"C{cx1},{cy1} {cx2},{cy2} {x},{y} "
                else:
                    # Line
                    x, y = random.uniform(0, 512), random.uniform(0, 512)
                    path_data += f"L{x},{y} "

            # Add path to SVG
            abstract_svg += f''

        return abstract_svg

    def __call__(self, data):
        """
        Process the input data and generate SVG output.

        Accepted payloads:
          * flat: ``{"prompt": ..., "num_paths": ..., "seed": ...}``
          * Hugging Face style: ``{"inputs": ..., "parameters": {...}}``
          * a bare prompt string
        Top-level keys take precedence over keys nested in ``"parameters"``.

        Args:
            data (dict | str): Input data containing the prompt and other parameters

        Returns:
            PIL.Image.Image: Output image (a blank 512x512 canvas if no prompt is given)
        """
        if isinstance(data, str):
            data = {"inputs": data}
        elif data is None:
            data = {}

        parameters = data.get("parameters")
        if not isinstance(parameters, dict):
            parameters = {}

        def option(name, default=None):
            # Top-level first so existing flat payloads behave exactly as before.
            if name in data:
                return data[name]
            return parameters.get(name, default)

        # Extract parameters from the input data
        prompt = option("prompt", "")
        if not prompt and "inputs" in data:
            prompt = data.get("inputs", "")

        if not prompt:
            # Create a default error image
            return Image.new('RGB', (512, 512), color=(240, 240, 240))

        negative_prompt = option("negative_prompt", "")
        num_paths = int(option("num_paths", 96))
        guidance_scale = float(option("guidance_scale", 7.5))
        seed = option("seed")
        if seed is not None:
            seed = int(seed)

        # Generate SVG
        _, png_image = self.generate_svg(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_paths=num_paths,
            guidance_scale=guidance_scale,
            seed=seed
        )

        # Return the image directly (not as a dictionary)
        return png_image