|
|
import os |
|
|
import json |
|
|
import torch |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
import cairosvg |
|
|
import numpy as np |
|
|
|
|
|
class DiffSketcherHandler:
    """Request handler that turns a text prompt into an SVG plus a PNG preview.

    ``handle`` drives the stages ``initialize`` (first call only) ->
    ``preprocess`` -> ``inference`` -> ``postprocess``. Real DiffSketcher
    generation is not implemented here: every path currently returns a
    placeholder SVG (light-grey background, blue circle, the prompt text and a
    red "model not available" notice), rasterised to PNG with cairosvg.
    """

    def __init__(self):
        self.initialized = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Stable Diffusion pipeline, loaded lazily by ``_initialize_model``.
        self.model = None
        # pydiffvg module once ``initialize`` imports it successfully; None
        # otherwise. Set here so ``inference`` works even if ``initialize``
        # was never called.
        self.diffvg = None
        # Remembers a failed pipeline load so it is not retried per request.
        self._model_load_failed = False

    def initialize(self, context):
        """Mark the handler ready and try to import the optional pydiffvg backend.

        Args:
            context: Serving-framework context object; not used by this handler.
        """
        self.initialized = True
        try:
            import pydiffvg
        except ImportError as e:
            print(f"Warning: Could not import pydiffvg: {e}")
            print("Will use placeholder SVG generation")
            self.diffvg = None
        else:
            self.diffvg = pydiffvg
            print("Successfully imported pydiffvg")

    def _initialize_model(self):
        """Lazily load the ``runwayml/stable-diffusion-v1-5`` pipeline onto ``self.device``.

        No-op if the pipeline is already loaded or a previous attempt failed.
        On failure the error is printed, ``self.model`` stays None and the
        failure is remembered, so a missing dependency or unreachable weight
        store costs one attempt instead of one per request.
        """
        if self.model is not None or self._model_load_failed:
            return
        try:
            from diffusers import StableDiffusionPipeline

            # Half precision on CUDA, full precision on CPU.
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            self.model = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=dtype,
            ).to(self.device)
            print("Successfully initialized the model")
        except Exception as e:
            print(f"Error initializing model: {e}")
            print("Will use placeholder generation")
            self.model = None
            self._model_load_failed = True

    def preprocess(self, data):
        """Extract the prompt and generation parameters from a request payload.

        Args:
            data: Mapping with an optional ``"inputs"`` prompt string and an
                optional ``"parameters"`` mapping. NOTE(review): some serving
                frameworks pass a *list* of requests to ``handle``; this method
                expects a single mapping — confirm against the deployment.

        Returns:
            dict with ``prompt`` (``"a beautiful landscape"`` when missing or
            empty), ``num_paths`` (default 96), ``token_ind`` (default 4) and
            ``num_iter`` (default 800).
        """
        prompt = data.get("inputs", "") or "a beautiful landscape"
        # ``or {}`` also covers an explicit null/None for "parameters".
        parameters = data.get("parameters") or {}
        return {
            "prompt": prompt,
            "num_paths": parameters.get("num_paths", 96),
            "token_ind": parameters.get("token_ind", 4),
            "num_iter": parameters.get("num_iter", 800),
        }

    def _generate_placeholder_svg(self, prompt):
        """Render a 512x512 placeholder SVG and its rasterised preview.

        Args:
            prompt: Text drawn near the top-left corner of the canvas.

        Returns:
            Tuple ``(svg_string, image)``: the SVG markup as ``str`` and a PIL
            image decoded from cairosvg's PNG rendering of that markup.
        """
        import svgwrite  # only needed on this path, so imported lazily

        dwg = svgwrite.Drawing(size=(512, 512))
        # Full-canvas background.
        dwg.add(dwg.rect(insert=(0, 0), size=('100%', '100%'), fill='#f0f0f0'))
        # Centred decorative circle.
        dwg.add(dwg.circle(center=(256, 256), r=100, fill='#3498db'))
        # The prompt, then a notice near the bottom edge.
        dwg.add(dwg.text(prompt, insert=(50, 50), font_size=20, fill='black'))
        dwg.add(dwg.text("Placeholder SVG - Model not available",
                         insert=(50, 480), font_size=16, fill='red'))
        svg_string = dwg.tostring()

        png_data = cairosvg.svg2png(bytestring=svg_string.encode('utf-8'))
        image = Image.open(BytesIO(png_data))
        return svg_string, image

    def inference(self, inputs):
        """Produce an SVG and a preview image for ``inputs["prompt"]``.

        Only ``inputs["prompt"]`` is read; the other preprocessed parameters
        are currently ignored.

        Returns:
            dict with ``"svg"`` (SVG markup) and ``"image"`` (PIL image).
        """
        prompt = inputs["prompt"]

        # The pipeline is only loaded when the pydiffvg backend is available.
        # ``_initialize_model`` handles its own errors.
        if self.model is None and self.diffvg is not None:
            self._initialize_model()

        # NOTE(review): real DiffSketcher generation is not implemented. Even
        # with the pipeline and pydiffvg loaded, the result is the placeholder,
        # so the loaded pipeline is currently unused (it only costs load time
        # and device memory). Replace this call once generation exists.
        svg_string, image = self._generate_placeholder_svg(prompt)
        return {"svg": svg_string, "image": image}

    def postprocess(self, inference_output):
        """Encode the preview image as a PNG data URI next to the SVG markup.

        Returns:
            dict with ``"svg"`` (unchanged) and ``"image"``, a
            ``data:image/png;base64,...`` string.
        """
        svg_string = inference_output["svg"]
        image = inference_output["image"]

        with BytesIO() as buffered:
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()

        return {
            "svg": svg_string,
            "image": f"data:image/png;base64,{img_str}",
        }

    def handle(self, data, context):
        """Run preprocess -> inference -> postprocess for one request.

        Calls ``initialize(context)`` first if the handler is not yet
        initialised.
        """
        if not self.initialized:
            self.initialize(context)
        preprocessed_data = self.preprocess(data)
        inference_output = self.inference(preprocessed_data)
        return self.postprocess(inference_output)