Update handler to return PIL Images for Inference API compatibility
Browse files- __pycache__/handler.cpython-312.pyc +0 -0
- config.json +1 -1
- handler.py +25 -5
__pycache__/handler.cpython-312.pyc
CHANGED
|
Binary files a/__pycache__/handler.cpython-312.pyc and b/__pycache__/handler.cpython-312.pyc differ
|
|
|
config.json
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
"model_type": "diffsketcher",
|
| 4 |
"task": "text-to-svg",
|
| 5 |
"framework": "pytorch",
|
| 6 |
-
"pipeline_tag": "text-
|
| 7 |
"library_name": "diffusers",
|
| 8 |
"inference": {
|
| 9 |
"parameters": {
|
|
|
|
| 3 |
"model_type": "diffsketcher",
|
| 4 |
"task": "text-to-svg",
|
| 5 |
"framework": "pytorch",
|
| 6 |
+
"pipeline_tag": "text-to-image",
|
| 7 |
"library_name": "diffusers",
|
| 8 |
"inference": {
|
| 9 |
"parameters": {
|
handler.py
CHANGED
|
@@ -5,6 +5,9 @@ import torch
|
|
| 5 |
import numpy as np
|
| 6 |
from typing import Dict, Any, List
|
| 7 |
import math
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class EndpointHandler:
|
| 10 |
def __init__(self, path=""):
|
|
@@ -90,12 +93,20 @@ class EndpointHandler:
|
|
| 90 |
# Generate SVG content based on prompt
|
| 91 |
svg_content = self._generate_sketch_svg(prompt, width, height, num_paths, guidance_scale)
|
| 92 |
|
| 93 |
-
#
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
except Exception as e:
|
| 97 |
-
# Return error
|
| 98 |
-
|
|
|
|
| 99 |
|
| 100 |
def _generate_sketch_svg(self, prompt: str, width: int, height: int, num_paths: int, guidance_scale: float) -> str:
|
| 101 |
"""
|
|
@@ -147,7 +158,16 @@ class EndpointHandler:
|
|
| 147 |
self._add_sketch_lines(paths, width, height, colors, min(20, num_paths // 5))
|
| 148 |
|
| 149 |
svg_content = svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
def _add_circular_elements(self, paths, width, height, colors, count):
|
| 153 |
"""Add circular elements to the SVG"""
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
from typing import Dict, Any, List
|
| 7 |
import math
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import cairosvg
|
| 10 |
+
import io
|
| 11 |
|
| 12 |
class EndpointHandler:
|
| 13 |
def __init__(self, path=""):
|
|
|
|
| 93 |
# Generate SVG content based on prompt
|
| 94 |
svg_content = self._generate_sketch_svg(prompt, width, height, num_paths, guidance_scale)
|
| 95 |
|
| 96 |
+
# Convert SVG to PIL Image
|
| 97 |
+
try:
|
| 98 |
+
png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
|
| 99 |
+
image = Image.open(io.BytesIO(png_data))
|
| 100 |
+
return image
|
| 101 |
+
except Exception as svg_error:
|
| 102 |
+
# Fallback: create a simple error image
|
| 103 |
+
error_image = Image.new('RGB', (width, height), color='white')
|
| 104 |
+
return error_image
|
| 105 |
|
| 106 |
except Exception as e:
|
| 107 |
+
# Return error image
|
| 108 |
+
error_image = Image.new('RGB', (224, 224), color='white')
|
| 109 |
+
return error_image
|
| 110 |
|
| 111 |
def _generate_sketch_svg(self, prompt: str, width: int, height: int, num_paths: int, guidance_scale: float) -> str:
|
| 112 |
"""
|
|
|
|
| 158 |
self._add_sketch_lines(paths, width, height, colors, min(20, num_paths // 5))
|
| 159 |
|
| 160 |
svg_content = svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
|
| 161 |
+
|
| 162 |
+
# Convert SVG to PIL Image
|
| 163 |
+
try:
|
| 164 |
+
png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
|
| 165 |
+
image = Image.open(io.BytesIO(png_data))
|
| 166 |
+
return image
|
| 167 |
+
except Exception as e:
|
| 168 |
+
# Fallback: create a simple error image
|
| 169 |
+
error_image = Image.new('RGB', (width, height), color='white')
|
| 170 |
+
return error_image
|
| 171 |
|
| 172 |
def _add_circular_elements(self, paths, width, height, colors, count):
|
| 173 |
"""Add circular elements to the SVG"""
|