File size: 3,105 Bytes
8de5d7e a96d3ab 2f521ae bebb135 3885d88 9900cee 2f521ae 3885d88 9900cee 3885d88 2f521ae 9900cee 2f521ae 9900cee 2f521ae 9900cee 2f521ae 9900cee 2f521ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import base64
import io
import json
from xml.sax.saxutils import escape as xml_escape

from PIL import Image, ImageDraw
class EndpointHandler:
    """Placeholder inference handler that returns a fixed vector graphic.

    No model is loaded or called. Each request yields an SVG that echoes the
    request prompt (XML-escaped) plus a static PNG drawing, base64-encoded.
    """

    def __init__(self, path=""):
        """Record the artifact path and defer setup until the first call.

        Args:
            path: Location of model artifacts. Stored on the instance; this
                placeholder does not read it.
        """
        self.path = path
        self.initialized = False

    def __call__(self, data):
        """Handle one request: preprocess, run inference, postprocess.

        Args:
            data: Raw request (str prompt, dict payload, or JSON text/bytes).

        Returns:
            None when ``data`` is None; otherwise the dict produced by
            ``inference`` (keys ``"svg"`` and ``"image"``).
        """
        if not self.initialized:
            self.initialize()
        if data is None:
            return None
        inputs = self.preprocess(data)
        outputs = self.inference(inputs)
        return self.postprocess(outputs)

    def initialize(self):
        """Mark the handler as ready; the placeholder has nothing to load."""
        self.initialized = True

    def preprocess(self, request):
        """Normalize a raw request into a payload dict.

        - ``str``: treated as a single prompt, never parsed as JSON.
        - ``dict``: returned unchanged (same object).
        - anything else: parsed with ``json.loads``. A decoded dict is the
          payload and a decoded string is the prompt. For any other decoded
          value, or when decoding fails, the request's text is the prompt.

        Returns:
            A dict; always safe for ``inference`` to call ``.get`` on.
        """
        if isinstance(request, str):
            return {"prompt": request}
        if isinstance(request, dict):
            return request
        try:
            payload = json.loads(request)
        except (TypeError, ValueError):
            # TypeError: not str/bytes/bytearray. ValueError covers
            # JSONDecodeError and UnicodeDecodeError for invalid bytes.
            return {"prompt": self._as_text(request)}
        if isinstance(payload, dict):
            return payload
        if isinstance(payload, str):
            return {"prompt": payload}
        # Lists/numbers/etc. are not payloads; keep the raw text as the prompt
        # instead of letting inference() fail on a non-dict.
        return {"prompt": self._as_text(request)}

    @staticmethod
    def _as_text(value):
        """Return ``value`` as text, decoding bytes as UTF-8 (lossy)."""
        if isinstance(value, (bytes, bytearray)):
            return bytes(value).decode("utf-8", errors="replace")
        return str(value)

    def inference(self, inputs):
        """Generate the placeholder SVG and PNG for a payload dict.

        This is a placeholder: no model is called. The SVG embeds the
        resolved prompt; the PNG is a fixed drawing that ignores it.

        Args:
            inputs: Payload dict, as returned by ``preprocess``.

        Returns:
            ``{"svg": <SVG markup str>, "image": <base64 PNG str>}``.
        """
        prompt = self._resolve_prompt(inputs)
        return {"svg": self._build_svg(prompt), "image": self._render_png_base64()}

    @staticmethod
    def _resolve_prompt(inputs):
        """Pick the prompt text from a payload dict; ``""`` if none is usable.

        Lookup order: ``"prompt"``; then ``"prompts"`` (first element of a
        list/tuple, or a plain string used whole); then a string ``"inputs"``
        field. NOTE(review): ``"inputs"`` is the top-level key of typical
        Hugging Face endpoint request bodies — confirm for this deployment.
        """
        prompt = inputs.get("prompt")
        if not prompt:
            prompts = inputs.get("prompts")
            if isinstance(prompts, str):
                prompt = prompts
            elif isinstance(prompts, (list, tuple)) and prompts:
                prompt = prompts[0]
        if not prompt:
            fallback = inputs.get("inputs")
            if isinstance(fallback, str):
                prompt = fallback
        return str(prompt) if prompt else ""

    @staticmethod
    def _build_svg(prompt):
        """Return the 512x512 placeholder SVG with ``prompt`` as its caption.

        The prompt is made XML-safe first: C0 control characters other than
        tab/LF/CR are dropped (XML 1.0 forbids them), then ``&``, ``<`` and
        ``>`` are escaped so user text cannot inject markup.
        """
        cleaned = "".join(ch for ch in prompt if ch in "\t\n\r" or ch >= " ")
        safe = xml_escape(cleaned)
        lines = (
            "",
            '<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" viewBox="0 0 512 512">',
            '<rect width="512" height="512" fill="#f0f0f0"/>',
            f'<text x="256" y="50" font-family="Arial" font-size="20" text-anchor="middle" fill="#333">Generated from: "{safe}"</text>',
            '<g transform="translate(256, 256)">',
            '<circle cx="0" cy="0" r="100" fill="#3498db" opacity="0.7"/>',
            '<rect x="-50" y="-50" width="100" height="100" fill="#e74c3c" opacity="0.7"/>',
            '<path d="M-100,-100 L100,100 M-100,100 L100,-100" stroke="#2c3e50" stroke-width="5"/>',
            "</g>",
            "</svg>",
            "",
        )
        return "\n".join(lines)

    @staticmethod
    def _render_png_base64():
        """Draw the static 512x512 preview and return it as base64 PNG text.

        Roughly mirrors the SVG shapes, but is fully opaque and has no caption.
        """
        img = Image.new("RGB", (512, 512), color="#f0f0f0")
        draw = ImageDraw.Draw(img)
        draw.ellipse((156, 156, 356, 356), fill="#3498db", outline="#3498db")
        draw.rectangle((206, 206, 306, 306), fill="#e74c3c", outline="#e74c3c")
        draw.line((156, 156, 356, 356), fill="#2c3e50", width=5)
        draw.line((156, 356, 356, 156), fill="#2c3e50", width=5)
        with io.BytesIO() as buffered:
            img.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode("ascii")

    def postprocess(self, inference_output):
        """Return the inference output unchanged (already JSON-serializable)."""
        return inference_output
|