diffsketcher / diffsketcher_handler.py
jree423's picture
Update model files for Inference API
4039872 verified
raw
history blame
5.34 kB
import os
import json
import torch
import base64
from io import BytesIO
from PIL import Image
import cairosvg
import numpy as np
class DiffSketcherHandler:
    """Serve text-to-sketch requests, returning an SVG and a PNG preview.

    Request flow is ``handle`` -> ``preprocess`` -> ``inference`` ->
    ``postprocess``.

    NOTE: the actual DiffSketcher optimisation is not implemented in this
    class. Every path through ``inference`` currently returns the drawing
    produced by ``_generate_placeholder_svg``.
    """

    def __init__(self):
        self.initialized = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        # Defined here (not only in ``initialize``) so that ``inference`` can
        # be called on a fresh instance without raising AttributeError.
        self.diffvg = None

    def initialize(self, context):
        """Import the optional ``pydiffvg`` backend and mark the handler ready.

        Args:
            context: Serving context supplied by the caller; not used here.

        Returns:
            None. On a failed import ``self.diffvg`` is left as ``None`` and
            the handler falls back to placeholder generation.
        """
        # Import dependencies here to avoid issues during startup
        try:
            import pydiffvg
        except ImportError as e:
            print(f"Warning: Could not import pydiffvg: {e}")
            print("Will use placeholder SVG generation")
            self.diffvg = None
        else:
            self.diffvg = pydiffvg
            print("Successfully imported pydiffvg")

        # Flag the handler as ready only once setup has finished, so an
        # unexpected error above leaves it un-initialized and retryable.
        self.initialized = True
        # We'll initialize the actual model only when needed
        return None

    def _initialize_model(self):
        """Load the Stable Diffusion pipeline on first use.

        Does nothing when ``self.model`` is already set. Any failure is
        reported on stdout and leaves ``self.model`` as ``None``; because of
        that, a later call will attempt the load again.
        """
        if self.model is not None:
            return

        try:
            # Try to import and initialize the actual model
            from diffusers import StableDiffusionPipeline

            # Half precision is only requested when running on CUDA.
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            self.model = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=dtype,
            ).to(self.device)
            print("Successfully initialized the model")
        except Exception as e:
            print(f"Error initializing model: {e}")
            print("Will use placeholder generation")
            self.model = None

    def preprocess(self, data):
        """Extract the prompt and generation parameters from a request.

        Args:
            data: Mapping with an optional ``"inputs"`` prompt and an optional
                ``"parameters"`` mapping. ``None`` (for either the payload or
                its ``"parameters"`` entry) is treated as empty.

        Returns:
            dict with keys ``"prompt"``, ``"num_paths"``, ``"token_ind"`` and
            ``"num_iter"``; missing values fall back to the defaults below.
        """
        data = data or {}
        # An absent or empty prompt falls back to the default prompt.
        prompt = data.get("inputs") or "a beautiful landscape"
        # ``or {}`` also covers an explicit ``"parameters": null`` payload.
        parameters = data.get("parameters") or {}

        return {
            "prompt": prompt,
            "num_paths": parameters.get("num_paths", 96),
            "token_ind": parameters.get("token_ind", 4),
            "num_iter": parameters.get("num_iter", 800),
        }

    def _generate_placeholder_svg(self, prompt):
        """Generate a placeholder SVG when the actual model is not available.

        Args:
            prompt: Text drawn near the top of the placeholder.

        Returns:
            Tuple ``(svg_string, image)``: the SVG markup and a PIL image
            rendered from it via ``cairosvg``.
        """
        import svgwrite

        # Create a simple 512x512 SVG
        dwg = svgwrite.Drawing(size=(512, 512))
        # Background rectangle
        dwg.add(dwg.rect(insert=(0, 0), size=('100%', '100%'), fill='#f0f0f0'))
        # Centered circle
        dwg.add(dwg.circle(center=(256, 256), r=100, fill='#3498db'))
        # The prompt as text
        dwg.add(dwg.text(prompt, insert=(50, 50), font_size=20, fill='black'))
        # A note that this is a placeholder
        dwg.add(dwg.text("Placeholder SVG - Model not available",
                         insert=(50, 480), font_size=16, fill='red'))
        svg_string = dwg.tostring()

        # Convert SVG to PNG for preview
        png_data = cairosvg.svg2png(bytestring=svg_string.encode('utf-8'))
        image = Image.open(BytesIO(png_data))
        return svg_string, image

    def inference(self, inputs):
        """Run inference with the preprocessed inputs.

        Args:
            inputs: dict produced by ``preprocess``; only ``"prompt"`` is read.

        Returns:
            dict with ``"svg"`` (SVG markup) and ``"image"`` (PIL image).
        """
        prompt = inputs["prompt"]

        # Try to initialize the model if not already done. The method handles
        # its own errors, so no extra try/except is needed around the call.
        if self.model is None and self.diffvg is not None:
            self._initialize_model()

        # The actual DiffSketcher implementation would go here when both
        # ``self.model`` and ``self.diffvg`` are available. For now every
        # path produces the placeholder drawing.
        svg_string, image = self._generate_placeholder_svg(prompt)

        return {
            "svg": svg_string,
            "image": image,
        }

    def postprocess(self, inference_output):
        """Post-process the model output into a JSON-friendly response.

        Args:
            inference_output: dict with ``"svg"`` and ``"image"`` as returned
                by ``inference``.

        Returns:
            dict with ``"svg"`` unchanged and ``"image"`` as a
            ``data:image/png;base64,...`` URI.
        """
        svg_string = inference_output["svg"]
        image = inference_output["image"]

        # Convert image to base64 for JSON response
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("ascii")

        return {
            "svg": svg_string,
            "image": f"data:image/png;base64,{img_str}",
        }

    def handle(self, data, context):
        """Handle the request end to end.

        Lazily runs ``initialize`` on the first call, then chains
        ``preprocess``, ``inference`` and ``postprocess``.
        """
        if not self.initialized:
            self.initialize(context)

        preprocessed_data = self.preprocess(data)
        inference_output = self.inference(preprocessed_data)
        return self.postprocess(inference_output)