import gradio as gr
import spaces
import torch
import diffusers
import transformers
import copy
import random
import numpy as np
import torchvision.transforms as T
import math
import os
import peft
from peft import LoraConfig
from safetensors import safe_open
from omegaconf import OmegaConf
from omnitry.models.transformer_flux import FluxTransformer2DModel
from omnitry.pipelines.pipeline_flux_fill import FluxFillPipeline

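# fetch the released OmniTry checkpoint (LoRA weights) from the Hugging Face Hub into ./OmniTry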
from huggingface_hub import snapshot_download
snapshot_download(repo_id="Kunbyte/OmniTry", local_dir="./OmniTry")

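# global setup: single CUDA device, bfloat16 weights, and the OmniTry config
# (LoRA hyperparameters and the object-class -> prompt map)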
device = torch.device('cuda:0')
weight_dtype = torch.bfloat16
args = OmegaConf.load('configs/omnitry_v1_unified.yaml')

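# load the FLUX.1-Fill-dev transformer (frozen) and build the fill pipeline around it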
transformer = FluxTransformer2DModel.from_pretrained('black-forest-labs/FLUX.1-Fill-dev', subfolder='transformer').requires_grad_(False).to(device, dtype=weight_dtype)
pipeline = FluxFillPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-Fill-dev',
    transformer=transformer,
    torch_dtype=weight_dtype
).to(device)

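# attach two LoRA adapters with identical configs: 'vtryon_lora' for the person (try-on) branch
# and 'garment_lora' for the object/garment branch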
lora_config = LoraConfig(
    r=args.lora_rank,
    lora_alpha=args.lora_alpha,
    init_lora_weights="gaussian",
    target_modules=[
        'x_embedder',
        'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
        'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
        'ff.net.0.proj', 'ff.net.2', 'ff_context.net.0.proj', 'ff_context.net.2',
        'norm1_context.linear', 'norm1.linear', 'norm.linear', 'proj_mlp', 'proj_out'
    ]
)
transformer.add_adapter(lora_config, adapter_name='vtryon_lora')
transformer.add_adapter(lora_config, adapter_name='garment_lora')

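# load the OmniTry LoRA weights into both adapters (strict=False: keys that do not match are ignored)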
with safe_open('OmniTry/omnitry_v1_unified.safetensors', framework="pt") as f:
    lora_weights = {k: f.get_tensor(k) for k in f.keys()}
    transformer.load_state_dict(lora_weights, strict=False)

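# Patch every PEFT LoRA Linear layer so that the two adapters are applied to different samples of
# the same batch: the first sample (person branch) goes through 'vtryon_lora' while the remaining
# sample(s) (garment branch) go through 'garment_lora'. One forward pass over the batch of 2 thus
# runs both adapters at once.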
def create_hacked_forward(module):

    def lora_forward(self, active_adapter, x, *args, **kwargs):
        # base-layer output plus the low-rank update of a single named adapter
        result = self.base_layer(x, *args, **kwargs)
        if active_adapter is not None:
            torch_result_dtype = result.dtype
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]
            x = x.to(lora_A.weight.dtype)
            result = result + lora_B(lora_A(dropout(x))) * scaling
            result = result.to(torch_result_dtype)
        return result

    def hacked_lora_forward(self, x, *args, **kwargs):
        # sample 0 -> person/try-on adapter, samples 1: -> garment adapter
        return torch.cat((
            lora_forward(self, 'vtryon_lora', x[:1], *args, **kwargs),
            lora_forward(self, 'garment_lora', x[1:], *args, **kwargs),
        ), dim=0)

    return hacked_lora_forward.__get__(module, type(module))


for n, m in transformer.named_modules():
    if isinstance(m, peft.tuners.lora.layer.Linear):
        m.forward = create_hacked_forward(m)

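# seed the python, numpy and torch RNGs for reproducible generations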
def seed_everything(seed=0):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

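# main inference entry point: takes the person image, the object image and its class,
# plus sampling controls (steps, guidance scale, seed); seed=-1 draws a random seed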
@spaces.GPU
def generate(person_image, object_image, object_class, steps, guidance_scale, seed):
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    seed_everything(seed)

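    # resize the person image so that its area is at most 1024x1024 and both sides are multiples of 16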
    max_area = 1024 * 1024
    oW = person_image.width
    oH = person_image.height

    ratio = math.sqrt(max_area / (oW * oH))
    ratio = min(1, ratio)
    tW, tH = int(oW * ratio) // 16 * 16, int(oH * ratio) // 16 * 16
    transform = T.Compose([
        T.Resize((tH, tW)),
        T.ToTensor(),
    ])
    person_image = transform(person_image)

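    # fit the object image inside the person-image canvas and center it on a white background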
    ratio = min(tW / object_image.width, tH / object_image.height)
    transform = T.Compose([
        T.Resize((int(object_image.height * ratio), int(object_image.width * ratio))),
        T.ToTensor(),
    ])
    object_image_padded = torch.ones_like(person_image)
    object_image = transform(object_image)
    new_h, new_w = object_image.shape[1], object_image.shape[2]
    min_x = (tW - new_w) // 2
    min_y = (tH - new_h) // 2
    object_image_padded[:, min_y: min_y + new_h, min_x: min_x + new_w] = object_image

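    # conditioning: the class prompt is duplicated for both samples, the person and padded
    # object images are stacked as a batch of 2, and the mask is all zeros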
    prompts = [args.object_map[object_class]] * 2
    img_cond = torch.stack([person_image, object_image_padded]).to(dtype=weight_dtype, device=device)
    mask = torch.zeros_like(img_cond).to(img_cond)

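    # run the fill pipeline; the patched LoRA forward routes sample 0 through the try-on adapter
    # and sample 1 through the garment adapter, and the first output image is the try-on result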
    with torch.no_grad():
        img = pipeline(
            prompt=prompts,
            height=tH,
            width=tW,
            img_cond=img_cond,
            mask=mask,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=torch.Generator(device).manual_seed(seed),
        ).images[0]

    return img

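# Gradio demo: upload a person image and an object image, pick the object class, and view the try-on result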
if __name__ == '__main__':

    with gr.Blocks() as demo:
        gr.Markdown('# Demo of OmniTry')
        with gr.Row():
            with gr.Column():
                person_image = gr.Image(type="pil", label="Person Image", height=800)
                run_button = gr.Button(value="Submit", variant='primary')

            with gr.Column():
                object_image = gr.Image(type="pil", label="Object Image", height=800)
                object_class = gr.Dropdown(label='Object Class', choices=list(args.object_map.keys()))

            with gr.Column():
                image_out = gr.Image(type="pil", label="Output", height=800)

        with gr.Accordion("Advanced ⚙️", open=False):
            guidance_scale = gr.Slider(label="Guidance scale", minimum=1, maximum=50, value=30, step=0.1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
            seed = gr.Number(label="Seed", value=-1, precision=0)

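        # clickable example inputs for a range of object classes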
        with gr.Row():
            gr.Examples(
                examples=[
                    [
                        './demo_example/person_top_cloth.jpg',
                        './demo_example/object_top_cloth.jpg',
                        'top clothes',
                    ],
                    [
                        './demo_example/person_bottom_cloth.jpg',
                        './demo_example/object_bottom_cloth.jpg',
                        'bottom clothes',
                    ],
                    [
                        './demo_example/person_dress.jpg',
                        './demo_example/object_dress.jpg',
                        'dress',
                    ],
                    [
                        './demo_example/person_shoes.jpg',
                        './demo_example/object_shoes.jpg',
                        'shoe',
                    ],
                    [
                        './demo_example/person_earrings.jpg',
                        './demo_example/object_earrings.jpg',
                        'earrings',
                    ],
                    [
                        './demo_example/person_bracelet.jpg',
                        './demo_example/object_bracelet.jpg',
                        'bracelet',
                    ],
                    [
                        './demo_example/person_necklace.jpg',
                        './demo_example/object_necklace.jpg',
                        'necklace',
                    ],
                    [
                        './demo_example/person_ring.jpg',
                        './demo_example/object_ring.jpg',
                        'ring',
                    ],
                    [
                        './demo_example/person_sunglasses.jpg',
                        './demo_example/object_sunglasses.jpg',
                        'sunglasses',
                    ],
                    [
                        './demo_example/person_glasses.jpg',
                        './demo_example/object_glasses.jpg',
                        'glasses',
                    ],
                    [
                        './demo_example/person_belt.jpg',
                        './demo_example/object_belt.jpg',
                        'belt',
                    ],
                    [
                        './demo_example/person_bag.jpg',
                        './demo_example/object_bag.jpg',
                        'bag',
                    ],
                    [
                        './demo_example/person_hat.jpg',
                        './demo_example/object_hat.jpg',
                        'hat',
                    ],
                    [
                        './demo_example/person_tie.jpg',
                        './demo_example/object_tie.jpg',
                        'tie',
                    ],
                    [
                        './demo_example/person_bowtie.jpg',
                        './demo_example/object_bowtie.jpg',
                        'bow tie',
                    ],
                ],
                inputs=[person_image, object_image, object_class],
                examples_per_page=100
            )

        run_button.click(generate, inputs=[person_image, object_image, object_class, steps, guidance_scale, seed], outputs=[image_out])

    demo.launch(server_name="0.0.0.0")