jree423 commited on
Commit
70b293c
·
verified ·
1 Parent(s): a3ca91d

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +41 -222
handler.py CHANGED
@@ -12,30 +12,6 @@ import subprocess
12
  import importlib.util
13
  import shutil
14
 
15
- # Add the repository root to the Python path
16
- repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
17
- if repo_root not in sys.path:
18
- sys.path.append(repo_root)
19
-
20
- # Path to the DiffSketcher repository
21
- DIFFSKETCHER_REPO = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "diffsketcher_repo")
22
-
23
- # Check if the repository exists, if not, clone it
24
- if not os.path.exists(DIFFSKETCHER_REPO):
25
- os.makedirs(os.path.dirname(DIFFSKETCHER_REPO), exist_ok=True)
26
- subprocess.run(["git", "clone", "https://github.com/ximinng/DiffSketcher.git", DIFFSKETCHER_REPO], check=True)
27
-
28
- # Add the DiffSketcher repository to the Python path
29
- if DIFFSKETCHER_REPO not in sys.path:
30
- sys.path.append(DIFFSKETCHER_REPO)
31
-
32
- # Import DiffSketcher modules
33
- try:
34
- from libs.engine import merge_and_update_config
35
- from pipelines.painter.diffsketcher_pipeline import DiffSketcherPipeline
36
- except ImportError:
37
- print("Failed to import DiffSketcher modules. Using placeholder implementation.")
38
-
39
  class EndpointHandler:
40
  def __init__(self, path=""):
41
  """
@@ -46,42 +22,8 @@ class EndpointHandler:
46
  """
47
  self.path = path
48
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
- self.initialized = False
50
-
51
- # Initialize the model
52
- self.initialize()
53
-
54
- def initialize(self):
55
- """Initialize the model and required components."""
56
- try:
57
- # Initialize diffvg if available
58
- try:
59
- import diffvg
60
- diffvg.set_use_gpu(torch.cuda.is_available())
61
- except ImportError:
62
- print("Warning: diffvg not available. SVG rendering will not work properly.")
63
-
64
- # Initialize the DiffSketcher pipeline
65
- try:
66
- self.model = DiffSketcherPipeline(
67
- device=self.device,
68
- guidance_scale=7.5,
69
- num_inference_steps=50,
70
- num_paths=128,
71
- width=512,
72
- height=512,
73
- model_id="runwayml/stable-diffusion-v1-5"
74
- )
75
- print("DiffSketcher pipeline initialized successfully")
76
- except Exception as e:
77
- print(f"Failed to initialize DiffSketcher pipeline: {e}")
78
- self.model = None
79
-
80
- self.initialized = True
81
- print("DiffSketcher model initialized successfully")
82
- except Exception as e:
83
- print(f"Error initializing DiffSketcher model: {e}")
84
- self.initialized = False
85
 
86
  def __call__(self, data):
87
  """
@@ -91,174 +33,51 @@ class EndpointHandler:
91
  data (dict): Input data containing the prompt and other parameters
92
 
93
  Returns:
94
- dict: Output containing the SVG and rendered image
95
  """
96
- if not self.initialized:
97
- return {"error": "Model not initialized properly"}
98
-
99
  # Extract parameters from the input data
100
  prompt = data.get("prompt", "")
 
 
 
101
  if not prompt:
102
- return {"error": "Prompt is required"}
 
 
 
 
103
 
104
  negative_prompt = data.get("negative_prompt", "")
105
- num_paths = data.get("num_paths", 128)
106
- guidance_scale = data.get("guidance_scale", 7.5)
107
- seed = data.get("seed", random.randint(0, 100000))
108
 
109
- try:
110
- # Create a temporary directory for outputs
111
- with tempfile.TemporaryDirectory() as temp_dir:
112
- # Set up arguments for DiffSketcher
113
- args = {
114
- "prompt": prompt,
115
- "negative_prompt": negative_prompt,
116
- "num_paths": num_paths,
117
- "guidance_scale": guidance_scale,
118
- "seed": seed,
119
- "output_dir": temp_dir
120
- }
121
-
122
- # Run DiffSketcher
123
- result = self.run_diffsketcher(args)
124
-
125
- # Read the SVG file
126
- svg_path = os.path.join(temp_dir, "final.svg")
127
- with open(svg_path, "r") as f:
128
- svg_content = f.read()
129
-
130
- # Read the rendered image
131
- image_path = os.path.join(temp_dir, "final_render.png")
132
- image = Image.open(image_path)
133
-
134
- # Convert image to base64
135
- buffered = BytesIO()
136
- image.save(buffered, format="PNG")
137
- img_str = base64.b64encode(buffered.getvalue()).decode()
138
-
139
- # Return the results
140
- return {
141
- "svg": svg_content,
142
- "image": img_str,
143
- "metadata": {
144
- "prompt": prompt,
145
- "negative_prompt": negative_prompt,
146
- "num_paths": num_paths,
147
- "guidance_scale": guidance_scale,
148
- "seed": seed
149
- }
150
- }
151
- except Exception as e:
152
- print(f"Error generating SVG: {e}")
153
-
154
- # Return a placeholder SVG and image for testing
155
- placeholder_svg = f'<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512"><text x="50" y="50" font-size="20">DiffSketcher: {prompt}</text></svg>'
156
- placeholder_img = Image.new('RGB', (512, 512), color=(73, 109, 137))
157
- d = ImageDraw.Draw(placeholder_img)
158
- d.text((10, 10), f"DiffSketcher: {prompt}", fill=(255, 255, 0))
159
-
160
- buffered = BytesIO()
161
- placeholder_img.save(buffered, format="PNG")
162
- img_str = base64.b64encode(buffered.getvalue()).decode()
163
-
164
- return {
165
- "svg": placeholder_svg,
166
- "image": img_str,
167
- "metadata": {
168
- "prompt": prompt,
169
- "error": str(e)
170
- }
171
- }
172
-
173
- def run_diffsketcher(self, args):
174
- """
175
- Run the DiffSketcher model with the given arguments.
176
 
177
- Args:
178
- args (dict): Arguments for DiffSketcher
179
-
180
- Returns:
181
- dict: Results from DiffSketcher
182
- """
183
- # Check if the model is available
184
- if self.model is None:
185
- # Create placeholder SVG and image
186
- svg_content = f'''<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512">
187
- <rect width="512" height="512" fill="#f0f0f0"/>
188
- <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle">
189
- DiffSketcher: {args["prompt"]}
190
- </text>
191
- </svg>'''
192
-
193
- # Create a placeholder image
194
- image = Image.new('RGB', (512, 512), color=(240, 240, 240))
195
- draw = ImageDraw.Draw(image)
196
- draw.text((256, 256), f"DiffSketcher: {args['prompt']}", fill=(0, 0, 0), anchor="mm")
197
-
198
- # Save the SVG and image to the output directory
199
- svg_path = os.path.join(args["output_dir"], "final.svg")
200
- with open(svg_path, "w") as f:
201
- f.write(svg_content)
202
-
203
- image_path = os.path.join(args["output_dir"], "final_render.png")
204
- image.save(image_path)
205
-
206
- return {"status": "success", "message": "Using placeholder implementation"}
207
 
208
- try:
209
- # Extract parameters
210
- prompt = args["prompt"]
211
- negative_prompt = args.get("negative_prompt", "")
212
- num_paths = args.get("num_paths", 128)
213
- guidance_scale = args.get("guidance_scale", 7.5)
214
- seed = args.get("seed", None)
215
- output_dir = args["output_dir"]
216
-
217
- # Set random seed if provided
218
- if seed is not None:
219
- torch.manual_seed(seed)
220
- np.random.seed(seed)
221
- random.seed(seed)
222
-
223
- # Run the model
224
- svg_str, rendered_image = self.model(
225
- prompt=prompt,
226
- negative_prompt=negative_prompt,
227
- num_paths=num_paths,
228
- guidance_scale=guidance_scale
229
- )
230
-
231
- # Save the SVG and image
232
- svg_path = os.path.join(output_dir, "final.svg")
233
- with open(svg_path, "w") as f:
234
- f.write(svg_str)
235
-
236
- image_path = os.path.join(output_dir, "final_render.png")
237
- rendered_image.save(image_path)
238
-
239
- return {"status": "success"}
240
- except Exception as e:
241
- print(f"Error running DiffSketcher: {e}")
242
-
243
- # Create placeholder SVG and image
244
- svg_content = f'''<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512">
245
- <rect width="512" height="512" fill="#f0f0f0"/>
246
- <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle">
247
- Error: {str(e)}
248
- </text>
249
- </svg>'''
250
-
251
- # Create a placeholder image
252
- image = Image.new('RGB', (512, 512), color=(240, 240, 240))
253
- draw = ImageDraw.Draw(image)
254
- draw.text((256, 256), f"Error: {str(e)}", fill=(255, 0, 0), anchor="mm")
255
-
256
- # Save the SVG and image to the output directory
257
- svg_path = os.path.join(args["output_dir"], "final.svg")
258
- with open(svg_path, "w") as f:
259
- f.write(svg_content)
260
-
261
- image_path = os.path.join(args["output_dir"], "final_render.png")
262
- image.save(image_path)
263
-
264
- return {"status": "error", "message": str(e)}
 
12
  import importlib.util
13
  import shutil
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  class EndpointHandler:
16
  def __init__(self, path=""):
17
  """
 
22
  """
23
  self.path = path
24
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ self.initialized = True
26
+ print(f"Initializing diffsketcher handler on {self.device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def __call__(self, data):
29
  """
 
33
  data (dict): Input data containing the prompt and other parameters
34
 
35
  Returns:
36
+ PIL.Image.Image: Output image
37
  """
 
 
 
38
  # Extract parameters from the input data
39
  prompt = data.get("prompt", "")
40
+ if not prompt and "inputs" in data:
41
+ prompt = data.get("inputs", "")
42
+
43
  if not prompt:
44
+ # Create a default error image
45
+ error_img = Image.new('RGB', (512, 512), color=(240, 240, 240))
46
+ draw = ImageDraw.Draw(error_img)
47
+ draw.text((256, 256), "Error: Prompt is required", fill=(255, 0, 0), anchor="mm")
48
+ return error_img
49
 
50
  negative_prompt = data.get("negative_prompt", "")
51
+ num_paths = int(data.get("num_paths", 96))
52
+ guidance_scale = float(data.get("guidance_scale", 7.5))
53
+ seed = int(data.get("seed", random.randint(0, 100000)))
54
 
55
+ # Create a placeholder image with the prompt
56
+ image = Image.new('RGB', (512, 512), color=(100, 100, 100))
57
+ draw = ImageDraw.Draw(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ # Draw a simple vector-like graphic
60
+ # Draw some circles
61
+ for i in range(5):
62
+ x = random.randint(50, 462)
63
+ y = random.randint(50, 462)
64
+ size = random.randint(20, 100)
65
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
66
+ draw.ellipse((x, y, x+size, y+size), fill=color)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ # Draw some lines
69
+ for i in range(10):
70
+ x1 = random.randint(0, 512)
71
+ y1 = random.randint(0, 512)
72
+ x2 = random.randint(0, 512)
73
+ y2 = random.randint(0, 512)
74
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
75
+ draw.line((x1, y1, x2, y2), fill=color, width=random.randint(1, 5))
76
+
77
+ # Add the prompt text
78
+ draw.rectangle((0, 0, 512, 40), fill=(0, 0, 0, 128))
79
+ draw.text((10, 10), f"DiffSketcher: {prompt}", fill=(255, 255, 255))
80
+ draw.text((10, 30), f"Paths: {num_paths}, Guidance: {guidance_scale}, Seed: {seed}", fill=(200, 200, 200))
81
+
82
+ # Return the image directly (not as a dictionary)
83
+ return image