diffsketcher / handler.py
jree423's picture
Upload handler.py with huggingface_hub
9900cee verified
raw
history blame
3.11 kB
import base64
import io
import json
from xml.sax.saxutils import escape

from PIL import Image, ImageDraw
class EndpointHandler:
    """Placeholder handler that answers every request with a fixed sketch.

    Requests go through ``preprocess -> inference -> postprocess``. No model
    is loaded or run: ``inference`` always draws the same shapes and uses the
    request prompt only as the caption of the SVG output.
    """

    def __init__(self, path=""):
        """Create an uninitialised handler.

        Args:
            path: Location of the model files. Stored on ``self.path`` but not
                otherwise used by this placeholder.
        """
        self.path = path
        # Flipped by initialize(); __call__ runs it lazily on the first request.
        self.initialized = False

    def __call__(self, data):
        """Handle one request and return the generated outputs.

        Args:
            data: Request payload -- a prompt string, a payload dict, or a
                JSON document as ``bytes`` (see :meth:`preprocess`).

        Returns:
            ``None`` when *data* is ``None``; otherwise the dict produced by
            :meth:`inference` with ``"svg"`` and ``"image"`` keys.
        """
        if not self.initialized:
            self.initialize()
        if data is None:
            return None
        inputs = self.preprocess(data)
        outputs = self.inference(inputs)
        return self.postprocess(outputs)

    def initialize(self):
        """Mark the handler as ready; this placeholder has nothing to load."""
        self.initialized = True

    def preprocess(self, request):
        """Normalise *request* into a payload dict for :meth:`inference`.

        Args:
            request: One of:

                * ``str`` -- used verbatim as the prompt;
                * ``dict`` -- returned unchanged as the payload;
                * ``bytes``/``bytearray`` -- parsed as JSON. An object becomes
                  the payload, a string the prompt, an array the
                  ``"prompts"`` list; undecodable or non-JSON bytes are
                  decoded as UTF-8 text and used as the prompt;
                * anything else -- ``str(request)`` is used as the prompt.

        Returns:
            dict: The payload. Always a dict, so :meth:`inference` can call
            ``.get`` on it safely.
        """
        if isinstance(request, str):
            return {"prompt": request}
        if isinstance(request, dict):
            return request
        try:
            parsed = json.loads(request)
        except (TypeError, ValueError):
            # TypeError: not str/bytes/bytearray. ValueError covers both
            # json.JSONDecodeError and UnicodeDecodeError.
            if isinstance(request, (bytes, bytearray)):
                # Decode ourselves so the prompt is readable text rather
                # than the b'...' repr that str() would give.
                return {"prompt": bytes(request).decode("utf-8", errors="replace")}
            return {"prompt": str(request)}
        if isinstance(parsed, dict):
            return parsed
        if isinstance(parsed, list):
            return {"prompts": parsed}
        if isinstance(parsed, str):
            return {"prompt": parsed}
        # JSON scalars (numbers, booleans, null) become plain prompt text.
        return {"prompt": "" if parsed is None else str(parsed)}

    def inference(self, inputs):
        """Render the placeholder graphics for a preprocessed payload.

        This stands in for a real model call: the drawing is fixed and the
        prompt only appears as the SVG caption.

        Args:
            inputs: Payload dict as returned by :meth:`preprocess`.

        Returns:
            dict: ``{"svg": <SVG markup>, "image": <base64-encoded PNG>}``.
        """
        prompt = self._extract_prompt(inputs)
        return {"svg": self._render_svg(prompt), "image": self._render_png_base64()}

    @staticmethod
    def _extract_prompt(inputs):
        """Return the prompt text from *inputs*, or ``""`` if none is given.

        Sources are tried in order:

        1. ``"prompt"``;
        2. the first entry of a ``"prompts"`` list/tuple (a bare string is
           taken whole);
        3. a string ``"inputs"`` value, the request key used by Hugging Face
           Inference Endpoints.
        """
        prompt = inputs.get("prompt")
        if not prompt:
            prompts = inputs.get("prompts")
            if isinstance(prompts, str):
                # A bare string is one prompt, not a sequence of characters.
                prompt = prompts
            elif isinstance(prompts, (list, tuple)) and prompts:
                prompt = prompts[0]
        if not prompt and isinstance(inputs.get("inputs"), str):
            prompt = inputs["inputs"]
        return str(prompt) if prompt else ""

    @staticmethod
    def _render_svg(prompt):
        """Return the placeholder SVG document captioned with *prompt*."""
        # Escape &, < and > so user text cannot break or inject SVG markup.
        caption = escape(prompt)
        return (
            '<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" '
            'viewBox="0 0 512 512">\n'
            '  <rect width="512" height="512" fill="#f0f0f0"/>\n'
            '  <text x="256" y="50" font-family="Arial" font-size="20" '
            f'text-anchor="middle" fill="#333">Generated from: "{caption}"</text>\n'
            '  <g transform="translate(256, 256)">\n'
            '    <circle cx="0" cy="0" r="100" fill="#3498db" opacity="0.7"/>\n'
            '    <rect x="-50" y="-50" width="100" height="100" fill="#e74c3c" '
            'opacity="0.7"/>\n'
            '    <path d="M-100,-100 L100,100 M-100,100 L100,-100" '
            'stroke="#2c3e50" stroke-width="5"/>\n'
            "  </g>\n"
            "</svg>\n"
        )

    @staticmethod
    def _render_png_base64():
        """Draw the raster counterpart of the SVG shapes as a base64 PNG string.

        Pixel coordinates are the SVG group's ``translate(256, 256)`` offsets
        applied to its shapes. The SVG's opacity and caption text are not
        reproduced here.
        """
        img = Image.new("RGB", (512, 512), color="#f0f0f0")
        draw = ImageDraw.Draw(img)
        draw.ellipse((156, 156, 356, 356), fill="#3498db", outline="#3498db")
        draw.rectangle((206, 206, 306, 306), fill="#e74c3c", outline="#e74c3c")
        draw.line((156, 156, 356, 356), fill="#2c3e50", width=5)
        draw.line((156, 356, 356, 156), fill="#2c3e50", width=5)
        with io.BytesIO() as buffer:
            img.save(buffer, format="PNG")
            return base64.b64encode(buffer.getvalue()).decode("ascii")

    def postprocess(self, inference_output):
        """Return *inference_output* unchanged.

        No serialisation is done here; the dict from :meth:`inference` is
        handed back as-is.
        """
        return inference_output