jree423 committed on
Commit
27980cb
·
verified ·
1 Parent(s): d646333

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +215 -57
handler.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  import os
2
  import sys
3
  import torch
@@ -10,6 +16,54 @@ import cairosvg
10
  import math
11
  import time
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  class EndpointHandler:
14
  def __init__(self, path=""):
15
  """
@@ -22,8 +76,50 @@ class EndpointHandler:
22
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
  print(f"Initializing DiffSketcher handler on {self.device}")
24
 
25
- # In a real implementation, we would load the model weights and initialize all components
26
- # For now, we'll use an advanced implementation that generates realistic SVG images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def svg_to_png(self, svg_string, width=512, height=512):
29
  """
@@ -66,11 +162,89 @@ class EndpointHandler:
66
  if seed is not None:
67
  random.seed(seed)
68
  np.random.seed(seed)
 
 
69
  else:
70
  seed = random.randint(0, 100000)
71
  random.seed(seed)
72
  np.random.seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  # Create a color palette based on the prompt
75
  word_sum = sum(ord(c) for c in prompt)
76
  palette_seed = word_sum % 5
@@ -342,29 +516,29 @@ class EndpointHandler:
342
  animal_svg += f'<ellipse cx="300" cy="300" rx="150" ry="80" fill="rgb({r},{g},{b})" stroke="black" stroke-width="3" />'
343
 
344
  # Head
345
- animal_svg += f'<circle cx="150" cy="250" r="70" fill="rgb({r},{g},{b})" stroke="black" stroke-width="3" />'
346
 
347
  # Eyes
348
- animal_svg += '<circle cx="130" cy="230" r="10" fill="white" stroke="black" stroke-width="1" />'
349
- animal_svg += '<circle cx="130" cy="230" r="5" fill="black" />'
350
- animal_svg += '<circle cx="170" cy="230" r="10" fill="white" stroke="black" stroke-width="1" />'
351
- animal_svg += '<circle cx="170" cy="230" r="5" fill="black" />'
352
 
353
  # Nose
354
- animal_svg += '<ellipse cx="150" cy="260" rx="15" ry="10" fill="#ffcccc" stroke="black" stroke-width="1" />'
355
 
356
  # Ears
357
- animal_svg += f'<path d="M100,200 L80,150 L120,180 Z" fill="rgb({r},{g},{b})" stroke="black" stroke-width="2" />'
358
- animal_svg += f'<path d="M200,200 L220,150 L180,180 Z" fill="rgb({r},{g},{b})" stroke="black" stroke-width="2" />'
359
 
360
  # Legs
361
- animal_svg += '<rect x="220" y="380" width="20" height="80" fill="#888" stroke="black" stroke-width="2" />'
362
- animal_svg += '<rect x="280" y="380" width="20" height="80" fill="#888" stroke="black" stroke-width="2" />'
363
- animal_svg += '<rect x="340" y="380" width="20" height="80" fill="#888" stroke="black" stroke-width="2" />'
364
- animal_svg += '<rect x="400" y="380" width="20" height="80" fill="#888" stroke="black" stroke-width="2" />'
365
 
366
  # Tail
367
- animal_svg += f'<path d="M450,300 Q500,250 520,200" fill="none" stroke="black" stroke-width="3" />'
368
 
369
  return animal_svg
370
 
@@ -372,55 +546,39 @@ class EndpointHandler:
372
  """Generate abstract art SVG."""
373
  abstract_svg = ""
374
 
375
- # Generate some random shapes
376
- for i in range(min(num_paths, 30)):
377
- shape_type = random.choice(["circle", "rect", "path", "line"])
378
-
379
  r = random.randint(color_ranges[0][0], color_ranges[0][1])
380
  g = random.randint(color_ranges[1][0], color_ranges[1][1])
381
  b = random.randint(color_ranges[2][0], color_ranges[2][1])
382
 
383
- if shape_type == "circle":
384
- cx = random.randint(50, 462)
385
- cy = random.randint(50, 462)
386
- radius = random.randint(10, 100)
387
- opacity = random.uniform(0.3, 0.8)
388
-
389
- abstract_svg += f'<circle cx="{cx}" cy="{cy}" r="{radius}" fill="rgb({r},{g},{b})" fill-opacity="{opacity:.2f}" stroke="black" stroke-width="2" />'
390
 
391
- elif shape_type == "rect":
392
- x = random.randint(50, 412)
393
- y = random.randint(50, 412)
394
- width = random.randint(20, 100)
395
- height = random.randint(20, 100)
396
- opacity = random.uniform(0.3, 0.8)
397
-
398
- abstract_svg += f'<rect x="{x}" y="{y}" width="{width}" height="{height}" fill="rgb({r},{g},{b})" fill-opacity="{opacity:.2f}" stroke="black" stroke-width="2" />'
399
 
400
- elif shape_type == "path":
401
- points = []
402
- for j in range(random.randint(3, 8)):
403
- x = random.randint(50, 462)
404
- y = random.randint(50, 462)
405
- points.append((x, y))
406
-
407
- path_d = f"M{points[0][0]},{points[0][1]}"
408
- for point in points[1:]:
409
- path_d += f" L{point[0]},{point[1]}"
410
- path_d += " Z"
411
-
412
- opacity = random.uniform(0.3, 0.8)
413
-
414
- abstract_svg += f'<path d="{path_d}" fill="rgb({r},{g},{b})" fill-opacity="{opacity:.2f}" stroke="black" stroke-width="2" />'
415
 
416
- elif shape_type == "line":
417
- x1 = random.randint(50, 462)
418
- y1 = random.randint(50, 462)
419
- x2 = random.randint(50, 462)
420
- y2 = random.randint(50, 462)
421
- width = random.randint(1, 10)
422
-
423
- abstract_svg += f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="rgb({r},{g},{b})" stroke-width="{width}" />'
 
 
 
 
 
 
 
424
 
425
  return abstract_svg
426
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Full implementation of DiffSketcher handler.
5
+ """
6
+
7
  import os
8
  import sys
9
  import torch
 
16
  import math
17
  import time
18
 
19
# Add the DiffSketcher repository to the path
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "DiffSketcher"))

# Add the mock diffvg to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import mock_diffvg as diffvg

# Try to import the real DiffSketcher modules.
# REAL_DIFFSKETCHER_AVAILABLE is True only if every name below imports; a single
# missing name raises ImportError and switches the handler to the mock path.
try:
    from models.clip_model import ClipModel
    from models.sd_model import StableDiffusion
    from models.loss import Loss
    from models.painter_params import Painter, PainterOptimizer
    from utils.train_utils import init_log, log_input, log_sketch, get_latest_ckpt, save_ckpt
    from utils.vector_utils import (
        svg_to_png, create_dir, init_svg, read_svg, get_svg_size, get_svg_path_d,
        get_svg_path_width, get_svg_color, set_svg_path_d, set_svg_path_width,
        set_svg_color, get_svg_meta, set_svg_meta, get_svg_path_bbox, get_svg_bbox,
        get_png_size, get_svg_path_group, get_svg_group_opacity, set_svg_group_opacity,
        get_svg_group_path_indices, get_svg_group_path_opacity, set_svg_group_path_opacity,
        get_svg_group_path_fill, set_svg_group_path_fill, get_svg_group_path_stroke,
        set_svg_group_path_stroke, get_svg_group_path_stroke_width, set_svg_group_path_stroke_width,
        get_svg_group_path_stroke_opacity, set_svg_group_path_stroke_opacity,
        get_svg_group_path_fill_opacity, set_svg_group_path_fill_opacity,
        get_svg_group_path_stroke_linecap, set_svg_group_path_stroke_linecap,
        get_svg_group_path_stroke_linejoin, set_svg_group_path_stroke_linejoin,
        get_svg_group_path_stroke_miterlimit, set_svg_group_path_stroke_miterlimit,
        get_svg_group_path_stroke_dasharray, set_svg_group_path_stroke_dasharray,
        get_svg_group_path_stroke_dashoffset, set_svg_group_path_stroke_dashoffset,
        get_svg_group_path_transform, set_svg_group_path_transform,
        get_svg_group_transform, set_svg_group_transform,
        get_svg_path_transform, set_svg_path_transform,
        get_svg_path_fill, set_svg_path_fill,
        get_svg_path_stroke, set_svg_path_stroke,
        get_svg_path_stroke_width, set_svg_path_stroke_width,
        get_svg_path_stroke_opacity, set_svg_path_stroke_opacity,
        get_svg_path_fill_opacity, set_svg_path_fill_opacity,
        get_svg_path_stroke_linecap, set_svg_path_stroke_linecap,
        get_svg_path_stroke_linejoin, set_svg_path_stroke_linejoin,
        get_svg_path_stroke_miterlimit, set_svg_path_stroke_miterlimit,
        get_svg_path_stroke_dasharray, set_svg_path_stroke_dasharray,
        get_svg_path_stroke_dashoffset, set_svg_path_stroke_dashoffset,
    )
    REAL_DIFFSKETCHER_AVAILABLE = True
except ImportError as import_error:
    # Include the exception text so the log shows which module or name was
    # missing; with this many imported names the bare warning is not actionable.
    print(f"Warning: Could not import DiffSketcher modules ({import_error}). Using mock implementation instead.")
    REAL_DIFFSKETCHER_AVAILABLE = False
66
+
67
  class EndpointHandler:
68
  def __init__(self, path=""):
69
  """
 
76
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
  print(f"Initializing DiffSketcher handler on {self.device}")
78
 
79
+ # Check if the real DiffSketcher is available
80
+ self.use_real_diffsketcher = REAL_DIFFSKETCHER_AVAILABLE
81
+
82
+ if self.use_real_diffsketcher:
83
+ try:
84
+ # Initialize the real DiffSketcher model
85
+ self._init_real_diffsketcher()
86
+ except Exception as e:
87
+ print(f"Error initializing real DiffSketcher: {e}")
88
+ self.use_real_diffsketcher = False
89
+
90
+ if not self.use_real_diffsketcher:
91
+ print("Using mock DiffSketcher implementation")
92
+
93
def _init_real_diffsketcher(self):
    """
    Construct the components used by the real DiffSketcher path.

    Reads ``self.path`` and ``self.device`` and sets ``self.clip_model``,
    ``self.sd_model``, ``self.loss_fn``, ``self.painter`` and
    ``self.painter_optimizer``. Any exception raised by a constructor
    propagates to the caller.
    """
    # Resolve the weights directory, falling back to a fixed workspace path.
    # NOTE(review): this value is not passed to any constructor below, so it
    # currently has no effect on what gets loaded; confirm the intended use.
    bundled_dir = os.path.join(self.path, "models", "diffsketcher")
    model_dir = (
        bundled_dir
        if os.path.exists(bundled_dir)
        else "/workspace/vector_models/models/diffsketcher"
    )

    device = self.device

    # Text encoder, diffusion model and loss, all placed on the same device.
    self.clip_model = ClipModel(device=device)
    self.sd_model = StableDiffusion(device=device)
    self.loss_fn = Loss(device=device)

    # Painter with the default stroke configuration.
    painter_config = {
        "num_paths": 48,
        "num_segments": 4,
        "canvas_size": 512,
        "device": device,
    }
    self.painter = Painter(**painter_config)

    # Optimizer that updates the painter's parameters.
    self.painter_optimizer = PainterOptimizer(self.painter, lr=1e-2, device=device)
123
 
124
  def svg_to_png(self, svg_string, width=512, height=512):
125
  """
 
162
  if seed is not None:
163
  random.seed(seed)
164
  np.random.seed(seed)
165
+ torch.manual_seed(seed)
166
+ torch.cuda.manual_seed(seed)
167
  else:
168
  seed = random.randint(0, 100000)
169
  random.seed(seed)
170
  np.random.seed(seed)
171
+ torch.manual_seed(seed)
172
+ torch.cuda.manual_seed(seed)
173
+
174
+ if self.use_real_diffsketcher:
175
+ try:
176
+ # Generate SVG using the real DiffSketcher
177
+ return self._generate_svg_real(prompt, negative_prompt, num_paths, guidance_scale)
178
+ except Exception as e:
179
+ print(f"Error generating SVG with real DiffSketcher: {e}")
180
+ # Fall back to mock implementation
181
+ return self._generate_svg_mock(prompt, negative_prompt, num_paths, guidance_scale)
182
+ else:
183
+ # Generate SVG using the mock implementation
184
+ return self._generate_svg_mock(prompt, negative_prompt, num_paths, guidance_scale)
185
+
186
+ def _generate_svg_real(self, prompt, negative_prompt, num_paths, guidance_scale):
187
+ """
188
+ Generate SVG using the real DiffSketcher.
189
+
190
+ Args:
191
+ prompt (str): Text prompt
192
+ negative_prompt (str): Negative text prompt
193
+ num_paths (int): Number of paths
194
+ guidance_scale (float): Guidance scale
195
+
196
+ Returns:
197
+ tuple: (svg_string, png_image)
198
+ """
199
+ # Initialize painter with the specified number of paths
200
+ self.painter.num_paths = num_paths
201
+
202
+ # Get CLIP embeddings for the prompt
203
+ text_embeddings = self.clip_model.get_text_embeddings(prompt, negative_prompt)
204
+
205
+ # Initialize SVG
206
+ svg_string = init_svg(self.painter.canvas_size, self.painter.canvas_size)
207
 
208
+ # Optimize the SVG
209
+ for i in range(1000): # Number of optimization steps
210
+ # Forward pass
211
+ svg_tensor = self.painter.get_image()
212
+
213
+ # Calculate loss
214
+ loss = self.loss_fn(svg_tensor, text_embeddings, guidance_scale)
215
+
216
+ # Backward pass
217
+ loss.backward()
218
+
219
+ # Update parameters
220
+ self.painter_optimizer.step()
221
+ self.painter_optimizer.zero_grad()
222
+
223
+ # Log progress
224
+ if i % 100 == 0:
225
+ print(f"Step {i}, Loss: {loss.item()}")
226
+
227
+ # Get the final SVG
228
+ svg_string = self.painter.get_svg()
229
+
230
+ # Convert SVG to PNG
231
+ png_image = self.svg_to_png(svg_string)
232
+
233
+ return svg_string, png_image
234
+
235
+ def _generate_svg_mock(self, prompt, negative_prompt, num_paths, guidance_scale):
236
+ """
237
+ Generate SVG using the mock implementation.
238
+
239
+ Args:
240
+ prompt (str): Text prompt
241
+ negative_prompt (str): Negative text prompt
242
+ num_paths (int): Number of paths
243
+ guidance_scale (float): Guidance scale
244
+
245
+ Returns:
246
+ tuple: (svg_string, png_image)
247
+ """
248
  # Create a color palette based on the prompt
249
  word_sum = sum(ord(c) for c in prompt)
250
  palette_seed = word_sum % 5
 
516
  animal_svg += f'<ellipse cx="300" cy="300" rx="150" ry="80" fill="rgb({r},{g},{b})" stroke="black" stroke-width="3" />'
517
 
518
  # Head
519
+ animal_svg += f'<circle cx="150" cy="280" r="70" fill="rgb({r},{g},{b})" stroke="black" stroke-width="3" />'
520
 
521
  # Eyes
522
+ animal_svg += '<circle cx="130" cy="260" r="10" fill="white" stroke="black" stroke-width="1" />'
523
+ animal_svg += '<circle cx="130" cy="260" r="5" fill="black" />'
524
+ animal_svg += '<circle cx="170" cy="260" r="10" fill="white" stroke="black" stroke-width="1" />'
525
+ animal_svg += '<circle cx="170" cy="260" r="5" fill="black" />'
526
 
527
  # Nose
528
+ animal_svg += '<circle cx="150" cy="290" r="10" fill="black" />'
529
 
530
  # Ears
531
+ animal_svg += f'<path d="M100,230 L80,180 L120,200 Z" fill="rgb({r},{g},{b})" stroke="black" stroke-width="2" />'
532
+ animal_svg += f'<path d="M200,230 L220,180 L180,200 Z" fill="rgb({r},{g},{b})" stroke="black" stroke-width="2" />'
533
 
534
  # Legs
535
for leg_x in (200, 250, 350, 400):
    # Must be an f-string: without the prefix the literal text "{r},{g},{b}"
    # ends up in the fill attribute and the legs lose the body colour.
    animal_svg += f'<rect x="{leg_x}" y="350" width="20" height="80" fill="rgb({r},{g},{b})" stroke="black" stroke-width="2" />'
539
 
540
  # Tail
541
+ animal_svg += f'<path d="M450,300 Q500,250 520,300" fill="none" stroke="rgb({r},{g},{b})" stroke-width="10" />'
542
 
543
  return animal_svg
544
 
 
546
  """Generate abstract art SVG."""
547
  abstract_svg = ""
548
 
549
+ # Generate random paths
550
+ for i in range(num_paths):
551
+ # Random color
 
552
  r = random.randint(color_ranges[0][0], color_ranges[0][1])
553
  g = random.randint(color_ranges[1][0], color_ranges[1][1])
554
  b = random.randint(color_ranges[2][0], color_ranges[2][1])
555
 
556
+ # Random stroke width
557
+ stroke_width = random.uniform(1, 5)
 
 
 
 
 
558
 
559
+ # Random path
560
+ path_data = "M"
561
+ x, y = random.uniform(0, 512), random.uniform(0, 512)
562
+ path_data += f"{x},{y} "
 
 
 
 
563
 
564
+ # Random number of segments
565
+ num_segments = random.randint(2, 5)
 
 
 
 
 
 
 
 
 
 
 
 
 
566
 
567
+ for j in range(num_segments):
568
+ # Random curve or line
569
+ if random.random() > 0.5:
570
+ # Curve
571
+ cx1, cy1 = random.uniform(0, 512), random.uniform(0, 512)
572
+ cx2, cy2 = random.uniform(0, 512), random.uniform(0, 512)
573
+ x, y = random.uniform(0, 512), random.uniform(0, 512)
574
+ path_data += f"C{cx1},{cy1} {cx2},{cy2} {x},{y} "
575
+ else:
576
+ # Line
577
+ x, y = random.uniform(0, 512), random.uniform(0, 512)
578
+ path_data += f"L{x},{y} "
579
+
580
+ # Add path to SVG
581
+ abstract_svg += f'<path d="{path_data}" fill="none" stroke="rgb({r},{g},{b})" stroke-width="{stroke_width}" />'
582
 
583
  return abstract_svg
584