|
|
|
|
|
|
|
|
""" |
|
|
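Inference endpoint handler for DiffSketcher (text-prompt to SVG sketch generation).

Uses the real DiffSketcher models when they can be imported and initialized,
and otherwise falls back to a procedural mock that draws simple SVG sketches.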
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import random |
|
|
import io |
|
|
import base64 |
|
|
import cairosvg |
|
|
import math |
|
|
import time |
|
|
|
|
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "DiffSketcher")) |
|
|
|
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
|
|
import mock_diffvg as diffvg |
|
|
|
|
|
|
|
|
try: |
|
|
from models.clip_model import ClipModel |
|
|
from models.sd_model import StableDiffusion |
|
|
from models.loss import Loss |
|
|
from models.painter_params import Painter, PainterOptimizer |
|
|
from utils.train_utils import init_log, log_input, log_sketch, get_latest_ckpt, save_ckpt |
|
|
from utils.vector_utils import ( |
|
|
svg_to_png, create_dir, init_svg, read_svg, get_svg_size, get_svg_path_d, |
|
|
get_svg_path_width, get_svg_color, set_svg_path_d, set_svg_path_width, |
|
|
set_svg_color, get_svg_meta, set_svg_meta, get_svg_path_bbox, get_svg_bbox, |
|
|
get_png_size, get_svg_path_group, get_svg_group_opacity, set_svg_group_opacity, |
|
|
get_svg_group_path_indices, get_svg_group_path_opacity, set_svg_group_path_opacity, |
|
|
get_svg_group_path_fill, set_svg_group_path_fill, get_svg_group_path_stroke, |
|
|
set_svg_group_path_stroke, get_svg_group_path_stroke_width, set_svg_group_path_stroke_width, |
|
|
get_svg_group_path_stroke_opacity, set_svg_group_path_stroke_opacity, |
|
|
get_svg_group_path_fill_opacity, set_svg_group_path_fill_opacity, |
|
|
get_svg_group_path_stroke_linecap, set_svg_group_path_stroke_linecap, |
|
|
get_svg_group_path_stroke_linejoin, set_svg_group_path_stroke_linejoin, |
|
|
get_svg_group_path_stroke_miterlimit, set_svg_group_path_stroke_miterlimit, |
|
|
get_svg_group_path_stroke_dasharray, set_svg_group_path_stroke_dasharray, |
|
|
get_svg_group_path_stroke_dashoffset, set_svg_group_path_stroke_dashoffset, |
|
|
get_svg_group_path_transform, set_svg_group_path_transform, |
|
|
get_svg_group_transform, set_svg_group_transform, |
|
|
get_svg_path_transform, set_svg_path_transform, |
|
|
get_svg_path_fill, set_svg_path_fill, |
|
|
get_svg_path_stroke, set_svg_path_stroke, |
|
|
get_svg_path_stroke_width, set_svg_path_stroke_width, |
|
|
get_svg_path_stroke_opacity, set_svg_path_stroke_opacity, |
|
|
get_svg_path_fill_opacity, set_svg_path_fill_opacity, |
|
|
get_svg_path_stroke_linecap, set_svg_path_stroke_linecap, |
|
|
get_svg_path_stroke_linejoin, set_svg_path_stroke_linejoin, |
|
|
get_svg_path_stroke_miterlimit, set_svg_path_stroke_miterlimit, |
|
|
get_svg_path_stroke_dasharray, set_svg_path_stroke_dasharray, |
|
|
get_svg_path_stroke_dashoffset, set_svg_path_stroke_dashoffset, |
|
|
) |
|
|
REAL_DIFFSKETCHER_AVAILABLE = True |
|
|
except ImportError: |
|
|
print("Warning: Could not import DiffSketcher modules. Using mock implementation instead.") |
|
|
REAL_DIFFSKETCHER_AVAILABLE = False |
|
|
|
|
|
class EndpointHandler:
    """Text-prompt to sketch endpoint handler backed by DiffSketcher or a mock.

    If the DiffSketcher modules were importable (``REAL_DIFFSKETCHER_AVAILABLE``)
    and ``_init_real_diffsketcher`` succeeds, sketches are produced by
    optimizing a ``Painter`` against a loss on CLIP text embeddings. Otherwise,
    or when the real path raises, a procedural mock draws a themed SVG chosen
    from keywords in the prompt. ``__call__`` returns the rasterized PIL image.
    """

    # Colour palettes for the mock: (min, max) ranges for the R, G and B
    # channels. One is chosen from the sum of the prompt's character codes.
    _MOCK_PALETTES = (
        ((200, 255), (100, 180), (50, 150)),
        ((50, 150), (100, 180), (200, 255)),
        ((150, 200), (100, 150), (50, 100)),
        ((200, 255), (50, 255), (50, 255)),
        ((100, 200), (100, 200), (100, 200)),
    )

    # Mock themes, tried in order: the first theme with a keyword that occurs
    # as a plain substring of the lower-cased prompt wins (so e.g. "cartoon"
    # selects the car). Prompts matching no theme get abstract strokes.
    _MOCK_THEMES = (
        (("car",), "_generate_car_svg"),
        (("face", "portrait"), "_generate_face_svg"),
        (("landscape", "mountain"), "_generate_landscape_svg"),
        (("flower", "plant"), "_generate_flower_svg"),
        (("animal", "dog", "cat"), "_generate_animal_svg"),
    )

    def __init__(self, path=""):
        """Initialize the handler and, when possible, the real DiffSketcher models.

        Args:
            path (str): Path to the model directory; ``models/diffsketcher``
                beneath it is the first candidate model directory.
        """
        self.path = path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing DiffSketcher handler on {self.device}")

        self.use_real_diffsketcher = REAL_DIFFSKETCHER_AVAILABLE
        if self.use_real_diffsketcher:
            try:
                self._init_real_diffsketcher()
            except Exception as e:
                # Best effort: any failure while building the models degrades
                # the handler to the mock instead of failing start-up.
                print(f"Error initializing real DiffSketcher: {e}")
                self.use_real_diffsketcher = False

        if not self.use_real_diffsketcher:
            print("Using mock DiffSketcher implementation")

    def _init_real_diffsketcher(self):
        """Build the CLIP model, Stable Diffusion model, loss, painter and optimizer.

        The resolved model directory is stored on ``self.model_dir``.
        """
        model_dir = os.path.join(self.path, "models", "diffsketcher")
        if not os.path.exists(model_dir):
            model_dir = "/workspace/vector_models/models/diffsketcher"
        # NOTE(review): the directory is resolved but not passed to any of the
        # constructors below -- confirm whether weights should be loaded from it.
        self.model_dir = model_dir

        self.clip_model = ClipModel(device=self.device)
        self.sd_model = StableDiffusion(device=self.device)
        self.loss_fn = Loss(device=self.device)
        self.painter = Painter(
            num_paths=48,
            num_segments=4,
            canvas_size=512,
            device=self.device,
        )
        self.painter_optimizer = PainterOptimizer(
            self.painter,
            lr=1e-2,
            device=self.device,
        )

    def svg_to_png(self, svg_string, width=512, height=512):
        """Rasterize an SVG string with cairosvg.

        Args:
            svg_string (str): SVG document text.
            width (int): Output width in pixels.
            height (int): Output height in pixels.

        Returns:
            PIL.Image.Image: The decoded PNG. If conversion fails, the error is
            printed and a plain light-grey RGB image of the requested size is
            returned instead.
        """
        try:
            png_data = cairosvg.svg2png(bytestring=svg_string.encode('utf-8'),
                                        output_width=width,
                                        output_height=height)
            return Image.open(io.BytesIO(png_data))
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")
            return Image.new('RGB', (width, height), color=(240, 240, 240))

    @staticmethod
    def _seed_everything(seed):
        """Seed the Python, NumPy and PyTorch (CPU and CUDA) global RNGs."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # PyTorch silently ignores this call when CUDA is unavailable.
        torch.cuda.manual_seed(seed)

    def generate_svg(self, prompt, negative_prompt="", num_paths=96, guidance_scale=7.5, seed=None):
        """Generate an SVG sketch for a prompt.

        Args:
            prompt (str): Text prompt.
            negative_prompt (str): Negative text prompt (real path only).
            num_paths (int): Number of paths.
            guidance_scale (float): Guidance scale (real path only).
            seed (int): Random seed; a random one in [0, 100000] is drawn from
                the current ``random`` state when omitted.

        Returns:
            tuple: (svg_string, png_image)
        """
        if seed is None:
            seed = random.randint(0, 100000)
        self._seed_everything(seed)

        if self.use_real_diffsketcher:
            try:
                return self._generate_svg_real(prompt, negative_prompt, num_paths, guidance_scale)
            except Exception as e:
                print(f"Error generating SVG with real DiffSketcher: {e}")
                # Re-seed so the fallback sketch depends only on the seed, not
                # on how far the failed optimization advanced the RNGs.
                self._seed_everything(seed)

        return self._generate_svg_mock(prompt, negative_prompt, num_paths, guidance_scale)

    def _generate_svg_real(self, prompt, negative_prompt, num_paths, guidance_scale,
                           num_iterations=1000, log_every=100):
        """Generate an SVG by optimizing the painter against the DiffSketcher loss.

        Args:
            prompt (str): Text prompt.
            negative_prompt (str): Negative text prompt.
            num_paths (int): Number of paths (assigned to ``painter.num_paths``).
            guidance_scale (float): Guidance scale passed to the loss.
            num_iterations (int): Number of optimization steps.
            log_every (int): Print the loss every this many steps; 0 disables it.

        Returns:
            tuple: (svg_string, png_image)
        """
        # NOTE(review): the painter was constructed with num_paths=48; confirm
        # that assigning the attribute afterwards actually changes the path
        # count. The painter is also not reset here, so successive requests
        # presumably continue from the previous request's paths.
        self.painter.num_paths = num_paths

        text_embeddings = self.clip_model.get_text_embeddings(prompt, negative_prompt)

        for step in range(num_iterations):
            svg_tensor = self.painter.get_image()
            loss = self.loss_fn(svg_tensor, text_embeddings, guidance_scale)
            loss.backward()
            self.painter_optimizer.step()
            self.painter_optimizer.zero_grad()
            if log_every and step % log_every == 0:
                print(f"Step {step}, Loss: {loss.item()}")

        svg_string = self.painter.get_svg()
        return svg_string, self.svg_to_png(svg_string)

    def _generate_svg_mock(self, prompt, negative_prompt, num_paths, guidance_scale):
        """Generate an SVG with the procedural mock and rasterize it.

        ``negative_prompt`` and ``guidance_scale`` are accepted for signature
        parity with the real path but are not used.

        Returns:
            tuple: (svg_string, png_image)
        """
        svg_string = self._build_mock_svg(prompt, num_paths)
        return svg_string, self.svg_to_png(svg_string)

    def _build_mock_svg(self, prompt, num_paths):
        """Compose the complete mock SVG document for a prompt (no rasterization).

        Layers, bottom to top: gradient background and title text, a faint
        grid, then the themed subject inside the pencil-texture filter group.
        """
        from xml.sax.saxutils import escape

        palette = self._MOCK_PALETTES[sum(ord(c) for c in prompt) % len(self._MOCK_PALETTES)]
        return "".join((
            # The prompt is user text: escape it so "<" or "&" cannot break the XML.
            self._mock_svg_header(escape(prompt)),
            self._mock_grid_svg(),
            '<g filter="url(#pencil-texture)">',
            self._mock_subject_svg(prompt, palette, num_paths),
            '</g></svg>',
        ))

    @staticmethod
    def _mock_svg_header(title):
        """Open the <svg> element: defs, gradient background and the title text.

        ``title`` must already be XML-escaped.
        """
        return f"""<svg viewBox="0 0 512 512" xmlns="http://www.w3.org/2000/svg">
    <defs>
        <linearGradient id="bg-gradient" x1="0%" y1="0%" x2="100%" y2="100%">
            <stop offset="0%" style="stop-color:#f8f8f8;stop-opacity:1" />
            <stop offset="100%" style="stop-color:#e0e0e0;stop-opacity:1" />
        </linearGradient>
        <filter id="pencil-texture" x="0" y="0" width="100%" height="100%">
            <feTurbulence type="fractalNoise" baseFrequency="0.05" numOctaves="2" result="noise"/>
            <feDisplacementMap in="SourceGraphic" in2="noise" scale="2" xChannelSelector="R" yChannelSelector="G"/>
        </filter>
    </defs>
    <rect width="512" height="512" fill="url(#bg-gradient)"/>
    <text x="10" y="30" font-family="Arial" font-size="20" font-weight="bold" fill="black">DiffSketcher: {title}</text>
"""

    @staticmethod
    def _mock_grid_svg(size=512, spacing=32):
        """Faint background grid: 1px horizontal, then vertical, lines every ``spacing`` px."""
        offsets = range(spacing, size, spacing)
        horizontal = [f'<path d="M0,{o} L{size},{o}" stroke="#000" stroke-width="1"/>' for o in offsets]
        vertical = [f'<path d="M{o},0 L{o},{size}" stroke="#000" stroke-width="1"/>' for o in offsets]
        return '<g opacity="0.1">' + "".join(horizontal + vertical) + '</g>'

    def _mock_subject_svg(self, prompt, color_ranges, num_paths):
        """Pick the mock theme from prompt keywords and draw it."""
        text = prompt.lower()
        for keywords, method_name in self._MOCK_THEMES:
            if any(keyword in text for keyword in keywords):
                return getattr(self, method_name)(color_ranges)
        return self._generate_abstract_svg(color_ranges, num_paths)

    @staticmethod
    def _random_rgb(color_ranges):
        """Return ``rgb(r,g,b)`` with each channel drawn from its inclusive (min, max) range.

        Draws exactly one ``random.randint`` per channel, in R, G, B order.
        """
        r, g, b = (random.randint(low, high) for low, high in color_ranges[:3])
        return f"rgb({r},{g},{b})"

    def _generate_car_svg(self, color_ranges):
        """Draw a side-view car whose body colour comes from ``color_ranges``."""
        body = self._random_rgb(color_ranges)
        return "".join((
            # Body.
            f'<path d="M100,300 Q150,250 200,250 L350,250 Q400,250 450,300 L450,350 Q400,380 350,380 L200,380 Q150,380 100,350 Z" fill="{body}" stroke="black" stroke-width="3" />',
            # Window band.
            '<path d="M150,280 L200,260 L350,260 L400,280 L400,300 L350,320 L200,320 L150,300 Z" fill="#a0d0ff" stroke="black" stroke-width="2" />',
            # Wheels: dark tyre with a lighter hub.
            '<circle cx="150" cy="380" r="40" fill="#333" stroke="black" stroke-width="2" />',
            '<circle cx="150" cy="380" r="20" fill="#777" stroke="black" stroke-width="2" />',
            '<circle cx="400" cy="380" r="40" fill="#333" stroke="black" stroke-width="2" />',
            '<circle cx="400" cy="380" r="20" fill="#777" stroke="black" stroke-width="2" />',
            # Yellow light at the left end, red light at the right end.
            '<circle cx="110" cy="320" r="15" fill="#ffff00" stroke="black" stroke-width="2" />',
            '<circle cx="440" cy="320" r="15" fill="#ff0000" stroke="black" stroke-width="2" />',
        ))

    def _generate_face_svg(self, color_ranges):
        """Draw a face: head, hair, eyes, eyebrows, nose and a (70%) smiling or flat mouth.

        Random draws happen in a fixed order (head colour, mouth, hair colour).
        The hair is painted directly over the head but *before* the features;
        painting it last would cover the eyes, eyebrows and nose.
        """
        skin = self._random_rgb(color_ranges)
        smiling = random.random() < 0.7
        hair = self._random_rgb(((0, 100), (0, 100), (0, 100)))

        if smiling:
            mouth = '<path d="M200,320 Q256,380 312,320" fill="none" stroke="black" stroke-width="3" />'
        else:
            mouth = '<path d="M200,330 L312,330" fill="none" stroke="black" stroke-width="3" />'

        return "".join((
            f'<ellipse cx="256" cy="256" rx="150" ry="180" fill="{skin}" stroke="black" stroke-width="3" />',
            f'<path d="M106,256 Q106,100 256,100 Q406,100 406,256" fill="{hair}" stroke="black" stroke-width="3" />',
            # Eyes: white ellipse with a dark pupil.
            '<ellipse cx="200" cy="200" rx="30" ry="20" fill="white" stroke="black" stroke-width="2" />',
            '<circle cx="200" cy="200" r="10" fill="#333" />',
            '<ellipse cx="312" cy="200" rx="30" ry="20" fill="white" stroke="black" stroke-width="2" />',
            '<circle cx="312" cy="200" r="10" fill="#333" />',
            # Eyebrows.
            '<path d="M170,170 Q200,150 230,170" fill="none" stroke="black" stroke-width="3" />',
            '<path d="M282,170 Q312,150 342,170" fill="none" stroke="black" stroke-width="3" />',
            # Nose.
            '<path d="M256,220 Q270,280 256,300 Q242,280 256,220" fill="none" stroke="black" stroke-width="2" />',
            mouth,
        ))

    def _generate_landscape_svg(self, color_ranges):
        """Draw a landscape: sky, sun, five snow-capped mountains, ground and ten trees.

        ``color_ranges`` is accepted for a uniform theme signature but unused;
        every colour here comes from the fixed ranges below.
        """
        horizon = 300
        parts = []

        sky = self._random_rgb(((100, 200), (150, 255), (200, 255)))
        parts.append(f'<rect x="0" y="0" width="512" height="{horizon}" fill="{sky}" />')

        sun_x = random.randint(50, 462)
        sun_y = random.randint(50, 150)
        parts.append(f'<circle cx="{sun_x}" cy="{sun_y}" r="40" fill="#ffff00" />')

        for _ in range(5):
            mountain_x = random.randint(-100, 512)
            mountain_width = random.randint(200, 400)
            mountain_height = random.randint(100, 200)
            rock = self._random_rgb(((50, 150), (50, 150), (50, 150)))

            peak_x = mountain_x + mountain_width / 2
            peak_y = horizon - mountain_height
            parts.append(f'<path d="M{mountain_x},{horizon} L{peak_x},{peak_y} L{mountain_x + mountain_width},{horizon} Z" fill="{rock}" stroke="black" stroke-width="2" />')

            # Snow cap: the part of the peak above 70% of its height. The
            # mountain's half-width shrinks linearly to 0 at the peak, so at
            # that height it is 30% of the base half-width; putting the cap's
            # corners there keeps the cap on the slopes instead of overhanging.
            snow_y = horizon - mountain_height * 0.7
            snow_half_width = (mountain_width / 2) * 0.3
            parts.append(f'<path d="M{peak_x - snow_half_width},{snow_y} L{peak_x},{peak_y} L{peak_x + snow_half_width},{snow_y} Z" fill="white" />')

        ground = self._random_rgb(((50, 150), (100, 200), (50, 100)))
        parts.append(f'<rect x="0" y="{horizon}" width="512" height="212" fill="{ground}" />')

        for _ in range(10):
            tree_x = random.randint(20, 492)
            tree_y = random.randint(320, 450)
            tree_height = random.randint(50, 100)
            # Trunk hangs down from tree_y; the round crown sits on top of it.
            parts.append(f'<rect x="{tree_x-5}" y="{tree_y}" width="10" height="{tree_height}" fill="#8B4513" />')
            foliage = self._random_rgb(((0, 100), (100, 200), (0, 100)))
            parts.append(f'<circle cx="{tree_x}" cy="{tree_y - tree_height/2}" r="{tree_height/2}" fill="{foliage}" />')

        return "".join(parts)

    def _generate_flower_svg(self, color_ranges):
        """Draw a flower: stem, two leaves, a yellow centre and 5-12 petals."""
        stem_height = random.randint(150, 300)
        center_y = 450 - stem_height
        parts = [f'<path d="M256,450 L256,{center_y}" fill="none" stroke="#0a0" stroke-width="5" />']

        leaf_y1 = random.randint(350, 420)
        leaf_y2 = random.randint(280, 349)
        parts.append(f'<path d="M256,{leaf_y1} Q200,{leaf_y1-30} 180,{leaf_y1-10}" fill="none" stroke="#0a0" stroke-width="3" />')
        parts.append(f'<path d="M256,{leaf_y2} Q310,{leaf_y2-30} 330,{leaf_y2-10}" fill="none" stroke="#0a0" stroke-width="3" />')

        parts.append(f'<circle cx="256" cy="{center_y}" r="20" fill="#ff0" stroke="#000" stroke-width="2" />')

        petal_fill = self._random_rgb(color_ranges)
        num_petals = random.randint(5, 12)
        petal_length = random.randint(40, 70)

        for i in range(num_petals):
            angle = 2 * math.pi * i / num_petals
            tip_x = 256 + petal_length * math.cos(angle)
            tip_y = center_y + petal_length * math.sin(angle)
            # Two control points at half length, rotated +/-0.3 rad, give the
            # petal its width; the path goes out to the tip and back.
            control_x1 = 256 + petal_length * 0.5 * math.cos(angle - 0.3)
            control_y1 = center_y + petal_length * 0.5 * math.sin(angle - 0.3)
            control_x2 = 256 + petal_length * 0.5 * math.cos(angle + 0.3)
            control_y2 = center_y + petal_length * 0.5 * math.sin(angle + 0.3)
            parts.append(f'<path d="M256,{center_y} C{control_x1},{control_y1} {control_x2},{control_y2} {tip_x},{tip_y} C{control_x2},{control_y2} {control_x1},{control_y1} 256,{center_y}" fill="{petal_fill}" stroke="#000" stroke-width="1" />')

        return "".join(parts)

    def _generate_animal_svg(self, color_ranges):
        """Draw a four-legged animal with one fur colour from ``color_ranges``."""
        fur = self._random_rgb(color_ranges)
        parts = [
            # Body and head.
            f'<ellipse cx="300" cy="300" rx="150" ry="80" fill="{fur}" stroke="black" stroke-width="3" />',
            f'<circle cx="150" cy="280" r="70" fill="{fur}" stroke="black" stroke-width="3" />',
            # Eyes and nose.
            '<circle cx="130" cy="260" r="10" fill="white" stroke="black" stroke-width="1" />',
            '<circle cx="130" cy="260" r="5" fill="black" />',
            '<circle cx="170" cy="260" r="10" fill="white" stroke="black" stroke-width="1" />',
            '<circle cx="170" cy="260" r="5" fill="black" />',
            '<circle cx="150" cy="290" r="10" fill="black" />',
            # Ears.
            f'<path d="M100,230 L80,180 L120,200 Z" fill="{fur}" stroke="black" stroke-width="2" />',
            f'<path d="M200,230 L220,180 L180,200 Z" fill="{fur}" stroke="black" stroke-width="2" />',
        ]
        # Legs. These were previously plain strings, so the SVG got a literal
        # "rgb({r},{g},{b})" instead of the fur colour.
        for leg_x in (200, 250, 350, 400):
            parts.append(f'<rect x="{leg_x}" y="350" width="20" height="80" fill="{fur}" stroke="black" stroke-width="2" />')
        # Tail.
        parts.append(f'<path d="M450,300 Q500,250 520,300" fill="none" stroke="{fur}" stroke-width="10" />')
        return "".join(parts)

    def _generate_abstract_svg(self, color_ranges, num_paths):
        """Draw ``num_paths`` random open strokes of 2-5 line or cubic segments each."""
        strokes = []
        for _ in range(num_paths):
            colour = self._random_rgb(color_ranges)
            stroke_width = random.uniform(1, 5)

            x, y = random.uniform(0, 512), random.uniform(0, 512)
            commands = [f"M{x},{y} "]
            for _segment in range(random.randint(2, 5)):
                if random.random() > 0.5:
                    cx1, cy1 = random.uniform(0, 512), random.uniform(0, 512)
                    cx2, cy2 = random.uniform(0, 512), random.uniform(0, 512)
                    x, y = random.uniform(0, 512), random.uniform(0, 512)
                    commands.append(f"C{cx1},{cy1} {cx2},{cy2} {x},{y} ")
                else:
                    x, y = random.uniform(0, 512), random.uniform(0, 512)
                    commands.append(f"L{x},{y} ")

            path_data = "".join(commands)
            strokes.append(f'<path d="{path_data}" fill="none" stroke="{colour}" stroke-width="{stroke_width}" />')
        return "".join(strokes)

    @staticmethod
    def _extract_prompt(data):
        """Return ``data["prompt"]``, or ``data["inputs"]`` when the prompt is missing/empty."""
        prompt = data.get("prompt", "")
        if not prompt and "inputs" in data:
            prompt = data.get("inputs", "")
        return prompt

    @staticmethod
    def _parse_generation_params(data):
        """Read and coerce the optional generation options from a request.

        Each option is taken from the top level of ``data`` when present there,
        otherwise from a nested ``data["parameters"]`` dict, otherwise its
        default.

        Returns:
            dict: ``negative_prompt``, ``num_paths`` (int), ``guidance_scale``
            (float) and ``seed`` (int or None), ready for ``generate_svg``.

        Raises:
            ValueError, TypeError: If a numeric option cannot be converted.
        """
        nested = data.get("parameters")
        if not isinstance(nested, dict):
            nested = {}

        def lookup(key, default):
            if key in data:
                return data[key]
            return nested.get(key, default)

        seed = lookup("seed", None)
        return {
            "negative_prompt": lookup("negative_prompt", ""),
            "num_paths": int(lookup("num_paths", 96)),
            "guidance_scale": float(lookup("guidance_scale", 7.5)),
            "seed": None if seed is None else int(seed),
        }

    def __call__(self, data):
        """Handle one request and return the rendered sketch.

        Args:
            data (dict): Request payload. The prompt comes from ``"prompt"``, or
                from ``"inputs"`` if the former is missing or empty. Options
                (``negative_prompt``, ``num_paths``, ``guidance_scale``,
                ``seed``) come from the top level or a nested ``"parameters"`` dict.

        Returns:
            PIL.Image.Image: The rasterized sketch, or a blank light-grey
            512x512 image when no prompt was supplied.
        """
        prompt = self._extract_prompt(data)
        if not prompt:
            # Options are deliberately not parsed here: an empty prompt must
            # not fail on a malformed option.
            return Image.new('RGB', (512, 512), color=(240, 240, 240))

        options = self._parse_generation_params(data)
        _svg_string, png_image = self.generate_svg(prompt=prompt, **options)
        return png_image