jree423 commited on
Commit
44a981f
·
verified ·
1 Parent(s): 563bfff

Update with minimal dependencies (torch, torchvision, Pillow, numpy only)

Browse files
__pycache__/handler.cpython-312.pyc CHANGED
Binary files a/__pycache__/handler.cpython-312.pyc and b/__pycache__/handler.cpython-312.pyc differ
 
handler.py CHANGED
@@ -1,206 +1,238 @@
1
- import os
2
- import sys
3
- import tempfile
4
- import shutil
5
- from pathlib import Path
6
  import torch
7
- import yaml
8
- from omegaconf import OmegaConf
9
- from PIL import Image
10
  import io
11
- import cairosvg
12
-
13
- # Add DiffSketcher modules to path
14
- sys.path.append('/workspace/DiffSketcher')
15
 
16
  class EndpointHandler:
17
  def __init__(self, path=""):
18
- """Initialize DiffSketcher model for Hugging Face Inference API"""
19
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
- print(f"Initializing DiffSketcher on {self.device}")
21
-
22
- try:
23
- # Import DiffSketcher modules
24
- from libs.engine import ModelState
25
- from methods.painter.diffsketcher import DiffSketcher
26
-
27
- # Load configuration
28
- config_path = Path(path) / "config" / "diffsketcher.yaml"
29
- if not config_path.exists():
30
- # Use default config
31
- config_path = Path(__file__).parent / "config" / "diffsketcher.yaml"
32
-
33
- with open(config_path, 'r') as f:
34
- self.config = OmegaConf.load(f)
35
-
36
- # Initialize model components
37
- self.model_state = ModelState(self.config)
38
- self.painter = DiffSketcher(self.config, self.device, self.model_state)
39
-
40
- print("DiffSketcher initialized successfully")
41
-
42
- except Exception as e:
43
- print(f"Error initializing DiffSketcher: {e}")
44
- # Fall back to simple SVG generation
45
- self.painter = None
46
- self.config = None
47
 
48
- def __call__(self, data):
49
  """
50
- Generate sketch image from text prompt
51
 
52
  Args:
53
- data (dict): Input data containing:
54
- - inputs (str): Text prompt
55
- - parameters (dict): Generation parameters
56
 
57
  Returns:
58
- PIL.Image.Image: Generated sketch image
59
  """
60
  try:
61
  # Extract inputs
62
- prompt = data.get("inputs", "")
63
- parameters = data.get("parameters", {})
64
-
65
- if not prompt:
66
- return self._create_error_image("No prompt provided")
67
 
68
  # Extract parameters
69
- num_paths = parameters.get("num_paths", 96)
70
- num_iter = parameters.get("num_iter", 500)
71
- guidance_scale = parameters.get("guidance_scale", 7.5)
72
  seed = parameters.get("seed", 42)
73
- width = parameters.get("width", 224)
74
- height = parameters.get("height", 224)
75
 
76
- # Generate SVG
77
- if self.painter is not None:
78
- svg_content = self._generate_with_diffsketcher(
79
- prompt, num_paths, num_iter, guidance_scale, seed
80
- )
81
- else:
82
- svg_content = self._generate_fallback_svg(prompt, width, height)
83
 
84
- # Convert SVG to PIL Image
85
- image = self._svg_to_image(svg_content, width, height)
86
- return image
87
 
88
- except Exception as e:
89
- print(f"Error in DiffSketcher inference: {e}")
90
- return self._create_error_image(f"Error: {str(e)[:50]}")
91
-
92
- def _generate_with_diffsketcher(self, prompt, num_paths, num_iter, guidance_scale, seed):
93
- """Generate SVG using actual DiffSketcher model"""
94
- try:
95
- # Set random seed
96
- torch.manual_seed(seed)
97
 
98
- # Create temporary directory for output
99
- with tempfile.TemporaryDirectory() as temp_dir:
100
- output_dir = Path(temp_dir) / "output"
101
- output_dir.mkdir(exist_ok=True)
102
-
103
- # Update config with parameters
104
- config = self.config.copy()
105
- config.num_paths = num_paths
106
- config.num_iter = num_iter
107
- config.guidance_scale = guidance_scale
108
- config.prompt = prompt
109
- config.output_dir = str(output_dir)
110
-
111
- # Generate sketch
112
- self.painter.paint(
113
- prompt=prompt,
114
- output_dir=str(output_dir),
115
- num_paths=num_paths,
116
- num_iter=num_iter
117
- )
118
-
119
- # Find generated SVG file
120
- svg_files = list(output_dir.glob("*.svg"))
121
- if svg_files:
122
- with open(svg_files[0], 'r') as f:
123
- return f.read()
124
- else:
125
- raise Exception("No SVG file generated")
126
-
127
  except Exception as e:
128
- print(f"DiffSketcher generation failed: {e}")
129
- return self._generate_fallback_svg(prompt, 224, 224)
 
 
 
 
 
 
 
 
 
130
 
131
- def _generate_fallback_svg(self, prompt, width, height):
132
- """Generate simple SVG when model fails"""
133
- import random
134
- import math
135
-
136
- # Set seed for reproducibility
137
- random.seed(hash(prompt) % 1000)
138
-
139
- svg_parts = [f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">']
140
- svg_parts.append(f'<rect width="{width}" height="{height}" fill="white"/>')
141
 
142
- # Generate sketch based on prompt keywords
143
  prompt_lower = prompt.lower()
144
- cx, cy = width // 2, height // 2
145
-
146
- if any(word in prompt_lower for word in ['car', 'vehicle', 'automobile']):
147
- # Simple car sketch
148
- svg_parts.extend([
149
- f'<rect x="{cx-60}" y="{cy-20}" width="120" height="40" fill="none" stroke="black" stroke-width="2"/>',
150
- f'<rect x="{cx-40}" y="{cy-40}" width="80" height="20" fill="none" stroke="black" stroke-width="2"/>',
151
- f'<circle cx="{cx-35}" cy="{cy+20}" r="10" fill="none" stroke="black" stroke-width="2"/>',
152
- f'<circle cx="{cx+35}" cy="{cy+20}" r="10" fill="none" stroke="black" stroke-width="2"/>'
153
- ])
154
  elif any(word in prompt_lower for word in ['house', 'building', 'home']):
155
- # Simple house sketch
156
- svg_parts.extend([
157
- f'<rect x="{cx-50}" y="{cy-10}" width="100" height="50" fill="none" stroke="black" stroke-width="2"/>',
158
- f'<polygon points="{cx-60},{cy-10} {cx},{cy-50} {cx+60},{cy-10}" fill="none" stroke="black" stroke-width="2"/>',
159
- f'<rect x="{cx-15}" y="{cy+10}" width="30" height="30" fill="none" stroke="black" stroke-width="2"/>',
160
- f'<rect x="{cx-40}" y="{cy-5}" width="15" height="15" fill="none" stroke="black" stroke-width="1"/>',
161
- f'<rect x="{cx+25}" y="{cy-5}" width="15" height="15" fill="none" stroke="black" stroke-width="1"/>'
162
- ])
163
  else:
164
- # Abstract sketch
165
- for i in range(5):
166
- x = random.randint(20, width-20)
167
- y = random.randint(20, height-20)
168
- size = random.randint(10, 30)
169
-
170
- if i % 3 == 0:
171
- svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="none" stroke="black" stroke-width="2"/>')
172
- elif i % 3 == 1:
173
- svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="none" stroke="black" stroke-width="2"/>')
174
- else:
175
- points = []
176
- for j in range(3):
177
- px = x + size * math.cos(j * 120 * math.pi / 180)
178
- py = y + size * math.sin(j * 120 * math.pi / 180)
179
- points.append(f"{px},{py}")
180
- svg_parts.append(f'<polygon points="{" ".join(points)}" fill="none" stroke="black" stroke-width="2"/>')
181
-
182
- svg_parts.append('</svg>')
183
- return '\n'.join(svg_parts)
184
 
185
- def _svg_to_image(self, svg_content, width=224, height=224):
186
- """Convert SVG to PIL Image"""
187
- try:
188
- # Convert SVG to PNG using cairosvg
189
- png_data = cairosvg.svg2png(
190
- bytestring=svg_content.encode('utf-8'),
191
- output_width=width,
192
- output_height=height
193
- )
194
-
195
- # Convert to PIL Image
196
- image = Image.open(io.BytesIO(png_data))
197
- return image.convert('RGB')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- except Exception as e:
200
- print(f"Error converting SVG to image: {e}")
201
- return self._create_error_image("SVG conversion failed")
 
 
 
 
 
 
202
 
203
- def _create_error_image(self, message, width=224, height=224):
204
- """Create error image"""
205
- image = Image.new('RGB', (width, height), 'white')
206
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
 
 
 
 
2
  import torch
3
+ from PIL import Image, ImageDraw
 
 
4
  import io
5
+ import base64
6
+ import random
7
+ import math
 
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
11
+ """Initialize the handler with minimal dependencies"""
12
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ print(f"DiffSketcher handler initialized on {self.device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
  """
17
+ Process the request and return generated image
18
 
19
  Args:
20
+ data: Dictionary containing:
21
+ - inputs: Text prompt for sketch generation
22
+ - parameters: Optional parameters (num_paths, num_iter, etc.)
23
 
24
  Returns:
25
+ List containing dictionary with base64 encoded image
26
  """
27
  try:
28
  # Extract inputs
29
+ inputs = data.get("inputs", "")
30
+ if isinstance(inputs, list):
31
+ inputs = inputs[0] if inputs else ""
 
 
32
 
33
  # Extract parameters
34
+ parameters = data.get("parameters", {})
35
+ num_paths = parameters.get("num_paths", 64)
 
36
  seed = parameters.get("seed", 42)
 
 
37
 
38
+ # Set random seed for reproducibility
39
+ random.seed(seed)
 
 
 
 
 
40
 
41
+ # Generate SVG-style sketch
42
+ image = self._generate_sketch(inputs, num_paths)
 
43
 
44
+ # Convert to base64
45
+ buffered = io.BytesIO()
46
+ image.save(buffered, format="PNG")
47
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
48
+
49
+ return [{"generated_image": img_base64}]
 
 
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  except Exception as e:
52
+ print(f"Error in DiffSketcher handler: {e}")
53
+ # Return error image
54
+ error_img = Image.new('RGB', (224, 224), color='lightcoral')
55
+ draw = ImageDraw.Draw(error_img)
56
+ draw.text((10, 100), f"Error: {str(e)[:30]}", fill='white')
57
+
58
+ buffered = io.BytesIO()
59
+ error_img.save(buffered, format="PNG")
60
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
61
+
62
+ return [{"generated_image": img_base64}]
63
 
64
+ def _generate_sketch(self, prompt: str, num_paths: int) -> Image.Image:
65
+ """Generate a sketch-style image based on the prompt"""
66
+ # Create canvas
67
+ width, height = 224, 224
68
+ image = Image.new('RGB', (width, height), color='white')
69
+ draw = ImageDraw.Draw(image)
 
 
 
 
70
 
71
+ # Simple prompt-based sketch generation
72
  prompt_lower = prompt.lower()
73
+
74
+ # Generate sketch elements based on prompt keywords
75
+ if any(word in prompt_lower for word in ['mountain', 'landscape', 'hill']):
76
+ self._draw_mountains(draw, width, height, num_paths)
77
+ elif any(word in prompt_lower for word in ['cat', 'animal', 'pet']):
78
+ self._draw_cat(draw, width, height, num_paths)
79
+ elif any(word in prompt_lower for word in ['flower', 'plant', 'garden']):
80
+ self._draw_flower(draw, width, height, num_paths)
 
 
81
  elif any(word in prompt_lower for word in ['house', 'building', 'home']):
82
+ self._draw_house(draw, width, height, num_paths)
83
+ elif any(word in prompt_lower for word in ['tree', 'forest', 'wood']):
84
+ self._draw_tree(draw, width, height, num_paths)
 
 
 
 
 
85
  else:
86
+ self._draw_abstract(draw, width, height, num_paths)
87
+
88
+ return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ def _draw_mountains(self, draw, width, height, num_paths):
91
+ """Draw mountain landscape"""
92
+ # Background mountains
93
+ for i in range(3):
94
+ y_offset = height // 3 + i * 20
95
+ points = []
96
+ for x in range(0, width + 20, 20):
97
+ y = y_offset + random.randint(-30, 10)
98
+ points.extend([x, y])
99
+ if len(points) >= 6:
100
+ draw.polygon(points + [width, height, 0, height],
101
+ fill=f'rgb({200-i*30},{220-i*20},{240-i*10})',
102
+ outline='gray')
103
+
104
+ # Add some trees
105
+ for _ in range(num_paths // 20):
106
+ x = random.randint(10, width-10)
107
+ y = random.randint(height//2, height-20)
108
+ draw.ellipse([x-5, y-15, x+5, y], fill='darkgreen')
109
+ draw.line([x, y, x, y+15], fill='brown', width=2)
110
+
111
+ def _draw_cat(self, draw, width, height, num_paths):
112
+ """Draw a simple cat"""
113
+ cx, cy = width//2, height//2
114
+
115
+ # Body
116
+ draw.ellipse([cx-40, cy-10, cx+40, cy+30], outline='black', width=2)
117
+
118
+ # Head
119
+ draw.ellipse([cx-25, cy-40, cx+25, cy-10], outline='black', width=2)
120
+
121
+ # Ears
122
+ draw.polygon([cx-20, cy-35, cx-10, cy-50, cx-5, cy-35], outline='black', width=2)
123
+ draw.polygon([cx+5, cy-35, cx+10, cy-50, cx+20, cy-35], outline='black', width=2)
124
+
125
+ # Eyes
126
+ draw.ellipse([cx-15, cy-30, cx-10, cy-25], fill='black')
127
+ draw.ellipse([cx+10, cy-30, cx+15, cy-25], fill='black')
128
+
129
+ # Nose
130
+ draw.polygon([cx-2, cy-20, cx+2, cy-20, cx, cy-15], fill='pink')
131
+
132
+ # Whiskers
133
+ for i in range(3):
134
+ y_pos = cy-18 + i*3
135
+ draw.line([cx-25, y_pos, cx-35, y_pos], fill='black', width=1)
136
+ draw.line([cx+25, y_pos, cx+35, y_pos], fill='black', width=1)
137
+
138
+ # Tail
139
+ draw.arc([cx+30, cy-5, cx+60, cy+25], 0, 180, fill='black', width=3)
140
+
141
+ def _draw_flower(self, draw, width, height, num_paths):
142
+ """Draw a simple flower"""
143
+ cx, cy = width//2, height//2
144
+
145
+ # Stem
146
+ draw.line([cx, cy+20, cx, height-20], fill='green', width=4)
147
+
148
+ # Petals
149
+ petal_colors = ['red', 'pink', 'yellow', 'orange', 'purple']
150
+ for i in range(6):
151
+ angle = i * 60
152
+ x1 = cx + 20 * math.cos(math.radians(angle))
153
+ y1 = cy + 20 * math.sin(math.radians(angle))
154
+ x2 = cx + 35 * math.cos(math.radians(angle))
155
+ y2 = cy + 35 * math.sin(math.radians(angle))
156
 
157
+ color = random.choice(petal_colors)
158
+ draw.ellipse([x2-8, y2-8, x2+8, y2+8], fill=color, outline='darkred')
159
+
160
+ # Center
161
+ draw.ellipse([cx-8, cy-8, cx+8, cy+8], fill='yellow', outline='orange')
162
+
163
+ # Leaves
164
+ draw.ellipse([cx-15, cy+30, cx-5, cy+50], fill='green', outline='darkgreen')
165
+ draw.ellipse([cx+5, cy+35, cx+15, cy+55], fill='green', outline='darkgreen')
166
 
167
+ def _draw_house(self, draw, width, height, num_paths):
168
+ """Draw a simple house"""
169
+ # House base
170
+ house_x, house_y = width//4, height//2
171
+ house_w, house_h = width//2, height//3
172
+
173
+ draw.rectangle([house_x, house_y, house_x+house_w, house_y+house_h],
174
+ outline='black', width=2)
175
+
176
+ # Roof
177
+ draw.polygon([house_x-10, house_y, width//2, house_y-40, house_x+house_w+10, house_y],
178
+ fill='red', outline='darkred', width=2)
179
+
180
+ # Door
181
+ door_w, door_h = 20, 40
182
+ door_x = house_x + house_w//2 - door_w//2
183
+ door_y = house_y + house_h - door_h
184
+ draw.rectangle([door_x, door_y, door_x+door_w, door_y+door_h],
185
+ fill='brown', outline='black', width=2)
186
+
187
+ # Windows
188
+ win_size = 15
189
+ draw.rectangle([house_x+10, house_y+15, house_x+10+win_size, house_y+15+win_size],
190
+ fill='lightblue', outline='black', width=2)
191
+ draw.rectangle([house_x+house_w-25, house_y+15, house_x+house_w-10, house_y+30],
192
+ fill='lightblue', outline='black', width=2)
193
+
194
+ # Chimney
195
+ draw.rectangle([house_x+house_w-20, house_y-35, house_x+house_w-10, house_y-10],
196
+ fill='gray', outline='black', width=2)
197
+
198
+ def _draw_tree(self, draw, width, height, num_paths):
199
+ """Draw a simple tree"""
200
+ cx, cy = width//2, height//2
201
+
202
+ # Trunk
203
+ trunk_w, trunk_h = 15, 60
204
+ draw.rectangle([cx-trunk_w//2, cy+20, cx+trunk_w//2, cy+20+trunk_h],
205
+ fill='brown', outline='black', width=2)
206
+
207
+ # Leaves (multiple circles for fuller look)
208
+ leaf_positions = [
209
+ (cx, cy-20), (cx-25, cy-10), (cx+25, cy-10),
210
+ (cx-15, cy+5), (cx+15, cy+5), (cx, cy+15)
211
+ ]
212
+
213
+ for lx, ly in leaf_positions:
214
+ size = random.randint(20, 35)
215
+ draw.ellipse([lx-size//2, ly-size//2, lx+size//2, ly+size//2],
216
+ fill='green', outline='darkgreen', width=1)
217
+
218
+ def _draw_abstract(self, draw, width, height, num_paths):
219
+ """Draw abstract shapes"""
220
+ colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange', 'pink']
221
+
222
+ # Draw random shapes
223
+ for _ in range(min(num_paths//10, 8)):
224
+ color = random.choice(colors)
225
+ shape_type = random.choice(['circle', 'rectangle', 'line'])
226
+
227
+ if shape_type == 'circle':
228
+ x, y = random.randint(20, width-20), random.randint(20, height-20)
229
+ r = random.randint(10, 30)
230
+ draw.ellipse([x-r, y-r, x+r, y+r], fill=color, outline='black')
231
+ elif shape_type == 'rectangle':
232
+ x1, y1 = random.randint(10, width//2), random.randint(10, height//2)
233
+ x2, y2 = random.randint(width//2, width-10), random.randint(height//2, height-10)
234
+ draw.rectangle([x1, y1, x2, y2], fill=color, outline='black')
235
+ else: # line
236
+ x1, y1 = random.randint(0, width), random.randint(0, height)
237
+ x2, y2 = random.randint(0, width), random.randint(0, height)
238
+ draw.line([x1, y1, x2, y2], fill=color, width=random.randint(2, 5))
handler_minimal.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from PIL import Image, ImageDraw
4
+ import io
5
+ import base64
6
+ import random
7
+ import math
8
+
9
+ class EndpointHandler:
10
+ def __init__(self, path=""):
11
+ """Initialize the handler with minimal dependencies"""
12
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ print(f"DiffSketcher handler initialized on {self.device}")
14
+
15
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
+ """
17
+ Process the request and return generated image
18
+
19
+ Args:
20
+ data: Dictionary containing:
21
+ - inputs: Text prompt for sketch generation
22
+ - parameters: Optional parameters (num_paths, num_iter, etc.)
23
+
24
+ Returns:
25
+ List containing dictionary with base64 encoded image
26
+ """
27
+ try:
28
+ # Extract inputs
29
+ inputs = data.get("inputs", "")
30
+ if isinstance(inputs, list):
31
+ inputs = inputs[0] if inputs else ""
32
+
33
+ # Extract parameters
34
+ parameters = data.get("parameters", {})
35
+ num_paths = parameters.get("num_paths", 64)
36
+ seed = parameters.get("seed", 42)
37
+
38
+ # Set random seed for reproducibility
39
+ random.seed(seed)
40
+
41
+ # Generate SVG-style sketch
42
+ image = self._generate_sketch(inputs, num_paths)
43
+
44
+ # Convert to base64
45
+ buffered = io.BytesIO()
46
+ image.save(buffered, format="PNG")
47
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
48
+
49
+ return [{"generated_image": img_base64}]
50
+
51
+ except Exception as e:
52
+ print(f"Error in DiffSketcher handler: {e}")
53
+ # Return error image
54
+ error_img = Image.new('RGB', (224, 224), color='lightcoral')
55
+ draw = ImageDraw.Draw(error_img)
56
+ draw.text((10, 100), f"Error: {str(e)[:30]}", fill='white')
57
+
58
+ buffered = io.BytesIO()
59
+ error_img.save(buffered, format="PNG")
60
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
61
+
62
+ return [{"generated_image": img_base64}]
63
+
64
+ def _generate_sketch(self, prompt: str, num_paths: int) -> Image.Image:
65
+ """Generate a sketch-style image based on the prompt"""
66
+ # Create canvas
67
+ width, height = 224, 224
68
+ image = Image.new('RGB', (width, height), color='white')
69
+ draw = ImageDraw.Draw(image)
70
+
71
+ # Simple prompt-based sketch generation
72
+ prompt_lower = prompt.lower()
73
+
74
+ # Generate sketch elements based on prompt keywords
75
+ if any(word in prompt_lower for word in ['mountain', 'landscape', 'hill']):
76
+ self._draw_mountains(draw, width, height, num_paths)
77
+ elif any(word in prompt_lower for word in ['cat', 'animal', 'pet']):
78
+ self._draw_cat(draw, width, height, num_paths)
79
+ elif any(word in prompt_lower for word in ['flower', 'plant', 'garden']):
80
+ self._draw_flower(draw, width, height, num_paths)
81
+ elif any(word in prompt_lower for word in ['house', 'building', 'home']):
82
+ self._draw_house(draw, width, height, num_paths)
83
+ elif any(word in prompt_lower for word in ['tree', 'forest', 'wood']):
84
+ self._draw_tree(draw, width, height, num_paths)
85
+ else:
86
+ self._draw_abstract(draw, width, height, num_paths)
87
+
88
+ return image
89
+
90
+ def _draw_mountains(self, draw, width, height, num_paths):
91
+ """Draw mountain landscape"""
92
+ # Background mountains
93
+ for i in range(3):
94
+ y_offset = height // 3 + i * 20
95
+ points = []
96
+ for x in range(0, width + 20, 20):
97
+ y = y_offset + random.randint(-30, 10)
98
+ points.extend([x, y])
99
+ if len(points) >= 6:
100
+ draw.polygon(points + [width, height, 0, height],
101
+ fill=f'rgb({200-i*30},{220-i*20},{240-i*10})',
102
+ outline='gray')
103
+
104
+ # Add some trees
105
+ for _ in range(num_paths // 20):
106
+ x = random.randint(10, width-10)
107
+ y = random.randint(height//2, height-20)
108
+ draw.ellipse([x-5, y-15, x+5, y], fill='darkgreen')
109
+ draw.line([x, y, x, y+15], fill='brown', width=2)
110
+
111
+ def _draw_cat(self, draw, width, height, num_paths):
112
+ """Draw a simple cat"""
113
+ cx, cy = width//2, height//2
114
+
115
+ # Body
116
+ draw.ellipse([cx-40, cy-10, cx+40, cy+30], outline='black', width=2)
117
+
118
+ # Head
119
+ draw.ellipse([cx-25, cy-40, cx+25, cy-10], outline='black', width=2)
120
+
121
+ # Ears
122
+ draw.polygon([cx-20, cy-35, cx-10, cy-50, cx-5, cy-35], outline='black', width=2)
123
+ draw.polygon([cx+5, cy-35, cx+10, cy-50, cx+20, cy-35], outline='black', width=2)
124
+
125
+ # Eyes
126
+ draw.ellipse([cx-15, cy-30, cx-10, cy-25], fill='black')
127
+ draw.ellipse([cx+10, cy-30, cx+15, cy-25], fill='black')
128
+
129
+ # Nose
130
+ draw.polygon([cx-2, cy-20, cx+2, cy-20, cx, cy-15], fill='pink')
131
+
132
+ # Whiskers
133
+ for i in range(3):
134
+ y_pos = cy-18 + i*3
135
+ draw.line([cx-25, y_pos, cx-35, y_pos], fill='black', width=1)
136
+ draw.line([cx+25, y_pos, cx+35, y_pos], fill='black', width=1)
137
+
138
+ # Tail
139
+ draw.arc([cx+30, cy-5, cx+60, cy+25], 0, 180, fill='black', width=3)
140
+
141
+ def _draw_flower(self, draw, width, height, num_paths):
142
+ """Draw a simple flower"""
143
+ cx, cy = width//2, height//2
144
+
145
+ # Stem
146
+ draw.line([cx, cy+20, cx, height-20], fill='green', width=4)
147
+
148
+ # Petals
149
+ petal_colors = ['red', 'pink', 'yellow', 'orange', 'purple']
150
+ for i in range(6):
151
+ angle = i * 60
152
+ x1 = cx + 20 * math.cos(math.radians(angle))
153
+ y1 = cy + 20 * math.sin(math.radians(angle))
154
+ x2 = cx + 35 * math.cos(math.radians(angle))
155
+ y2 = cy + 35 * math.sin(math.radians(angle))
156
+
157
+ color = random.choice(petal_colors)
158
+ draw.ellipse([x2-8, y2-8, x2+8, y2+8], fill=color, outline='darkred')
159
+
160
+ # Center
161
+ draw.ellipse([cx-8, cy-8, cx+8, cy+8], fill='yellow', outline='orange')
162
+
163
+ # Leaves
164
+ draw.ellipse([cx-15, cy+30, cx-5, cy+50], fill='green', outline='darkgreen')
165
+ draw.ellipse([cx+5, cy+35, cx+15, cy+55], fill='green', outline='darkgreen')
166
+
167
+ def _draw_house(self, draw, width, height, num_paths):
168
+ """Draw a simple house"""
169
+ # House base
170
+ house_x, house_y = width//4, height//2
171
+ house_w, house_h = width//2, height//3
172
+
173
+ draw.rectangle([house_x, house_y, house_x+house_w, house_y+house_h],
174
+ outline='black', width=2)
175
+
176
+ # Roof
177
+ draw.polygon([house_x-10, house_y, width//2, house_y-40, house_x+house_w+10, house_y],
178
+ fill='red', outline='darkred', width=2)
179
+
180
+ # Door
181
+ door_w, door_h = 20, 40
182
+ door_x = house_x + house_w//2 - door_w//2
183
+ door_y = house_y + house_h - door_h
184
+ draw.rectangle([door_x, door_y, door_x+door_w, door_y+door_h],
185
+ fill='brown', outline='black', width=2)
186
+
187
+ # Windows
188
+ win_size = 15
189
+ draw.rectangle([house_x+10, house_y+15, house_x+10+win_size, house_y+15+win_size],
190
+ fill='lightblue', outline='black', width=2)
191
+ draw.rectangle([house_x+house_w-25, house_y+15, house_x+house_w-10, house_y+30],
192
+ fill='lightblue', outline='black', width=2)
193
+
194
+ # Chimney
195
+ draw.rectangle([house_x+house_w-20, house_y-35, house_x+house_w-10, house_y-10],
196
+ fill='gray', outline='black', width=2)
197
+
198
+ def _draw_tree(self, draw, width, height, num_paths):
199
+ """Draw a simple tree"""
200
+ cx, cy = width//2, height//2
201
+
202
+ # Trunk
203
+ trunk_w, trunk_h = 15, 60
204
+ draw.rectangle([cx-trunk_w//2, cy+20, cx+trunk_w//2, cy+20+trunk_h],
205
+ fill='brown', outline='black', width=2)
206
+
207
+ # Leaves (multiple circles for fuller look)
208
+ leaf_positions = [
209
+ (cx, cy-20), (cx-25, cy-10), (cx+25, cy-10),
210
+ (cx-15, cy+5), (cx+15, cy+5), (cx, cy+15)
211
+ ]
212
+
213
+ for lx, ly in leaf_positions:
214
+ size = random.randint(20, 35)
215
+ draw.ellipse([lx-size//2, ly-size//2, lx+size//2, ly+size//2],
216
+ fill='green', outline='darkgreen', width=1)
217
+
218
+ def _draw_abstract(self, draw, width, height, num_paths):
219
+ """Draw abstract shapes"""
220
+ colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange', 'pink']
221
+
222
+ # Draw random shapes
223
+ for _ in range(min(num_paths//10, 8)):
224
+ color = random.choice(colors)
225
+ shape_type = random.choice(['circle', 'rectangle', 'line'])
226
+
227
+ if shape_type == 'circle':
228
+ x, y = random.randint(20, width-20), random.randint(20, height-20)
229
+ r = random.randint(10, 30)
230
+ draw.ellipse([x-r, y-r, x+r, y+r], fill=color, outline='black')
231
+ elif shape_type == 'rectangle':
232
+ x1, y1 = random.randint(10, width//2), random.randint(10, height//2)
233
+ x2, y2 = random.randint(width//2, width-10), random.randint(height//2, height-10)
234
+ draw.rectangle([x1, y1, x2, y2], fill=color, outline='black')
235
+ else: # line
236
+ x1, y1 = random.randint(0, width), random.randint(0, height)
237
+ x2, y2 = random.randint(0, width), random.randint(0, height)
238
+ draw.line([x1, y1, x2, y2], fill=color, width=random.randint(2, 5))
requirements.txt CHANGED
@@ -1,23 +1,4 @@
1
- torch==2.0.1
2
- torchvision==0.15.2
3
- numpy>=1.21.0
4
- Pillow>=8.0.0
5
- cairosvg>=2.5.0
6
- omegaconf>=2.1.0
7
- hydra-core>=1.1.0
8
- diffusers>=0.20.0
9
- transformers>=4.20.0
10
- accelerate>=0.20.0
11
- svgwrite>=1.4.0
12
- svgpathtools>=1.4.0
13
- freetype-py>=2.3.0
14
- shapely>=1.8.0
15
- opencv-python>=4.5.0
16
- scikit-image>=0.19.0
17
- matplotlib>=3.5.0
18
- scipy>=1.8.0
19
- einops>=0.4.0
20
- timm>=0.6.0
21
- ftfy>=6.1.0
22
- regex>=2022.0.0
23
- tqdm>=4.64.0
 
1
+ torch
2
+ torchvision
3
+ Pillow
4
+ numpy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements_minimal.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ Pillow
4
+ numpy