| from typing import Dict, Any | |
| import torch | |
| import base64 | |
| import io | |
| import os | |
| import json | |
| from PIL import Image | |
class EndpointHandler:
    """Placeholder inference handler for the diffsketcher model.

    No model is loaded: each request returns an SVG that echoes the prompt as
    centred text, plus a solid-colour PNG encoded as base64.
    """

    # Canvas size in pixels, shared by the SVG placeholder and the PNG placeholder.
    width: int = 512
    height: int = 512
    # RGB fill colour of the placeholder PNG.
    background = (100, 100, 100)

    def __init__(self, path: str = "") -> None:
        """Select the torch device (CUDA when available, otherwise CPU).

        Args:
            path: Model directory supplied by the hosting runtime; not used by
                this placeholder implementation.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing diffsketcher handler on {self.device}")

    @staticmethod
    def _extract_prompt(data: Dict[str, Any]) -> str:
        """Return ``data["prompt"]``, falling back to the first ``data["prompts"]`` entry.

        Returns ``""`` when neither key yields a prompt.
        """
        prompt = data.get("prompt", "")
        if not prompt:
            prompts = data.get("prompts")
            if isinstance(prompts, str):
                # A bare string must not be indexed, or only its first character survives.
                prompt = prompts
            elif prompts:
                prompt = prompts[0]
        # Normalise None (e.g. {"prompt": None} or {"prompts": [None]}) to an empty prompt.
        return "" if prompt is None else str(prompt)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Produce placeholder outputs for one request.

        Args:
            data: Request payload. The prompt is read from ``"prompt"``; if that
                is missing or empty, the first element of ``"prompts"`` is used
                (or the value itself when ``"prompts"`` is a single string).

        Returns:
            ``{"svg": <SVG markup>, "image": <base64-encoded PNG>}``.
        """
        # Function-scope import so the file's top-level import block is unchanged.
        from xml.sax.saxutils import escape

        prompt = self._extract_prompt(data)

        # Escape &, < and > so arbitrary prompt text cannot break or inject into the SVG markup.
        svg = (
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{self.width}" height="{self.height}" '
            f'viewBox="0 0 {self.width} {self.height}"><text x="50%" y="50%" dominant-baseline="middle" '
            f'text-anchor="middle" font-size="20">diffsketcher: {escape(prompt)}</text></svg>'
        )

        # Solid-colour placeholder raster, PNG-encoded in memory and returned as base64 text.
        image = Image.new("RGB", (self.width, self.height), color=self.background)
        with io.BytesIO() as buffered:
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("ascii")

        return {
            "svg": svg,
            "image": img_str,
        }