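"""Inference handler for DiffSketcher-style text-to-SVG generation.

Request flow: handle() -> preprocess() -> inference() -> postprocess().
If pydiffvg or the diffusion backbone cannot be loaded, the handler falls
back to a placeholder SVG so the endpoint still returns a valid response.
"""
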
import base64
from io import BytesIO

import torch
import cairosvg
from PIL import Image

class DiffSketcherHandler:
    def __init__(self):
        self.initialized = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        
    def initialize(self, context):
        """Initialize the handler."""
        self.initialized = True
        
        # Import pydiffvg lazily so a missing or broken native build does not
        # crash worker startup; placeholder SVG generation is used instead.
        try:
            import pydiffvg
            self.diffvg = pydiffvg
            print("Successfully imported pydiffvg")
        except ImportError as e:
            print(f"Warning: Could not import pydiffvg: {e}")
            print("Will use placeholder SVG generation")
            self.diffvg = None
        
        # The diffusion model itself is loaded lazily on the first request.
    
    def _initialize_model(self):
        """Initialize the actual model when needed."""
        if self.model is not None:
            return
        
        try:
            # Try to import and initialize the actual model
            from diffusers import StableDiffusionPipeline
            
            # Load the Stable Diffusion v1.5 backbone (half precision on GPU)
            self.model = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            ).to(self.device)
            
            print("Successfully initialized the model")
        except Exception as e:
            print(f"Error initializing model: {e}")
            print("Will use placeholder generation")
            self.model = None
    
    def preprocess(self, data):
        """Preprocess the input data."""
        inputs = data.get("inputs", "")
        if not inputs:
            inputs = "a beautiful landscape"
        
        # DiffSketcher-style parameters: number of vector strokes, the prompt
        # token index used for attention-guided stroke initialization, and
        # the number of optimization iterations.
        parameters = data.get("parameters", {})
        num_paths = parameters.get("num_paths", 96)
        token_ind = parameters.get("token_ind", 4)
        num_iter = parameters.get("num_iter", 800)
        
        return {
            "prompt": inputs,
            "num_paths": num_paths,
            "token_ind": token_ind,
            "num_iter": num_iter
        }
    
    def _generate_placeholder_svg(self, prompt):
        """Generate a placeholder SVG when the actual model is not available."""
        import svgwrite
        
        # Create a simple SVG
        dwg = svgwrite.Drawing(size=(512, 512))
        # Add a background rectangle
        dwg.add(dwg.rect(insert=(0, 0), size=('100%', '100%'), fill='#f0f0f0'))
        # Add a circle
        dwg.add(dwg.circle(center=(256, 256), r=100, fill='#3498db'))
        # Add the prompt as text
        dwg.add(dwg.text(prompt, insert=(50, 50), font_size=20, fill='black'))
        # Add a note that this is a placeholder
        dwg.add(dwg.text("Placeholder SVG - Model not available", 
                         insert=(50, 480), font_size=16, fill='red'))
        
        svg_string = dwg.tostring()
        
        # Convert SVG to PNG for preview
        png_data = cairosvg.svg2png(bytestring=svg_string.encode('utf-8'))
        image = Image.open(BytesIO(png_data))
        
        return svg_string, image
    
    def inference(self, inputs):
        """Run inference with the preprocessed inputs."""
        prompt = inputs["prompt"]
        
        # Try to initialize the model if not already done
        if self.model is None and self.diffvg is not None:
            try:
                self._initialize_model()
            except Exception as e:
                print(f"Error initializing model during inference: {e}")
        
        # If we have a working model, use it
        if self.model is not None and self.diffvg is not None:
            try:
                # This would be the actual DiffSketcher implementation
                # For now, we'll just generate a placeholder
                svg_string, image = self._generate_placeholder_svg(prompt)
            except Exception as e:
                print(f"Error during model inference: {e}")
                svg_string, image = self._generate_placeholder_svg(prompt)
        else:
            # Use placeholder if model is not available
            svg_string, image = self._generate_placeholder_svg(prompt)
        
        return {
            "svg": svg_string,
            "image": image
        }
    
    def postprocess(self, inference_output):
        """Post-process the model output."""
        svg_string = inference_output["svg"]
        image = inference_output["image"]
        
        # Convert image to base64 for JSON response
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        img_base64 = f"data:image/png;base64,{img_str}"
        
        return {
            "svg": svg_string,
            "image": img_base64
        }
    
    def handle(self, data, context):
        """Handle the request."""
        if not self.initialized:
            self.initialize(context)
            
        preprocessed_data = self.preprocess(data)
        inference_output = self.inference(preprocessed_data)
        return self.postprocess(inference_output)
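

# Hedged usage sketch: this assumes the handler can be invoked directly with a
# plain dict payload and no serving-framework context object. The payload keys
# ("inputs", "parameters") mirror what preprocess() expects; the prompt and
# parameter values below are illustrative only.
if __name__ == "__main__":
    handler = DiffSketcherHandler()
    result = handler.handle(
        {"inputs": "a sketch of a lighthouse", "parameters": {"num_paths": 32}},
        context=None,
    )
    print(result["svg"][:200])           # start of the generated SVG markup
    print(result["image"][:64] + "...")  # truncated base64 PNG data URI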