| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | !pip install -q diffusers transformers accelerate safetensors |
| |
|
| | import torch |
| | import gc |
| | from huggingface_hub import hf_hub_download |
| | from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler |
| | from transformers import CLIPTextModel, CLIPTokenizer |
| | from PIL import Image |
| | import numpy as np |
| |
|
# Free unreachable Python objects first so that any CUDA tensors they held go
# back to PyTorch's caching allocator; empty_cache() then returns those cached,
# now-unused blocks to the driver. The reverse order leaves memory freed by the
# collection stuck in the cache.
gc.collect()
torch.cuda.empty_cache()
| |
|
| | |
| | |
| | |
# Device and dtype used when placing the models and latents in this script.
DEVICE: str = "cuda"
DTYPE: torch.dtype = torch.float16

# Hugging Face Hub location of the Sol checkpoint loaded into the UNet below.
SOL_REPO: str = "AbstractPhil/sd15-flow-matching"
SOL_FILENAME: str = "sd15_flowmatch_david_weighted_efinal.pt"
| |
|
| | |
| | |
| | |
# Hub ids for the pretrained components shared by the sampler below.
_CLIP_REPO = "openai/clip-vit-large-patch14"
_SD15_REPO = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Text encoder + tokenizer used to build the UNet's cross-attention conditioning.
print("Loading CLIP...")
clip_tok = CLIPTokenizer.from_pretrained(_CLIP_REPO)
clip_enc = CLIPTextModel.from_pretrained(_CLIP_REPO, torch_dtype=DTYPE)
clip_enc = clip_enc.to(DEVICE).eval()

# VAE used to decode the final latents into pixels.
print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(_SD15_REPO, subfolder="vae", torch_dtype=DTYPE)
vae = vae.to(DEVICE).eval()

# Base SD1.5 UNet; its weights are overwritten with the Sol checkpoint later.
print("Loading UNet...")
unet = UNet2DConditionModel.from_pretrained(_SD15_REPO, subfolder="unet", torch_dtype=DTYPE)
unet = unet.to(DEVICE).eval()

# NOTE(review): this scheduler is built from DDPMScheduler's library defaults
# (presumably linear betas and clip_sample=True), not from the SD1.5
# scheduler config (scaled_linear betas, clip_sample=False). Confirm that this
# is the same scheduler the Sol trainer used, since both alpha_sigma() and the
# sampling loop read from it.
print("Loading DDPM Scheduler...")
sched = DDPMScheduler(num_train_timesteps=1000)
| |
|
| | |
| | |
| | |
print(f"\nLoading Sol from {SOL_REPO}...")
weights_path = hf_hub_download(repo_id=SOL_REPO, filename=SOL_FILENAME)
# NOTE(review): torch.load unpickles the downloaded file; only point
# SOL_REPO/SOL_FILENAME at a repository you trust.
checkpoint = torch.load(weights_path, map_location="cpu")

# The UNet weights live under the "student" entry of the checkpoint dict.
state_dict = checkpoint["student"]
print(f" gstep: {checkpoint.get('gstep', 'unknown')}")

# If the student was saved with a "unet." prefix, keep only those entries and
# strip the prefix from the *start* of each key. str.replace would also rewrite
# any later "unet." occurrence inside a key name.
_UNET_PREFIX = "unet."
if any(k.startswith(_UNET_PREFIX) for k in state_dict):
    state_dict = {
        k[len(_UNET_PREFIX):]: v
        for k, v in state_dict.items()
        if k.startswith(_UNET_PREFIX)
    }

# Drop "hooks." / "local_heads." entries (presumably training-time auxiliary
# modules); they are not loaded into the UNet.
state_dict = {
    k: v for k, v in state_dict.items() if not k.startswith(("hooks.", "local_heads."))
}

missing, unexpected = unet.load_state_dict(state_dict, strict=False)
print(f" Loaded: {len(state_dict)} keys, missing: {len(missing)}, unexpected: {len(unexpected)}")
# strict=False never raises, so a key-naming mismatch would otherwise leave the
# base SD1.5 weights in place without any visible sign. Show a sample.
if missing:
    print(f" WARNING: {len(missing)} UNet keys not found in checkpoint, e.g. {missing[:5]}")
if unexpected:
    print(f" WARNING: {len(unexpected)} checkpoint keys ignored, e.g. {unexpected[:5]}")

# Release the CPU copy of the checkpoint now that the weights are in the UNet.
del checkpoint, state_dict
gc.collect()

# Inference only: freeze all UNet parameters.
unet.requires_grad_(False)

print("✓ Sol ready!")
| |
|
| | |
| | |
| | |
def alpha_sigma(t: torch.LongTensor):
    """Return the (alpha, sigma) noise coefficients for integer timesteps ``t``.

    alpha = sqrt(alphas_cumprod[t]) and sigma = sqrt(1 - alphas_cumprod[t]),
    looked up in the global DDPM scheduler ``sched`` on ``DEVICE``. Both are
    reshaped to (-1, 1, 1, 1) so they broadcast over 4-D latents, and cast to
    float32. The original author notes this mirrors the trainer's computation.
    """
    acp_t = sched.alphas_cumprod.to(DEVICE)[t]
    bshape = (-1, 1, 1, 1)
    sigma = torch.sqrt(1.0 - acp_t).view(bshape).float()
    alpha = torch.sqrt(acp_t).view(bshape).float()
    return alpha, sigma
| |
|
| | |
| | |
| | |
@torch.inference_mode()
def _encode_prompt(text):
    """Encode ``text`` with the CLIP text encoder for UNet cross-attention.

    Pads/truncates to 77 tokens and returns ``last_hidden_state`` cast to DTYPE.
    """
    tokens = clip_tok(
        text, return_tensors="pt", padding="max_length", max_length=77, truncation=True
    ).to(DEVICE)
    return clip_enc(**tokens).last_hidden_state.to(DTYPE)


@torch.inference_mode()
def generate_sol(prompt, negative_prompt="", seed=42, steps=30, cfg=7.5, height=512, width=512):
    """Generate one image for ``prompt`` with the Sol UNet and classifier-free guidance.

    Sampling procedure (described by the original author as matching the
    trainer's ``sample()`` method):
      1. Use the DDPM scheduler's timesteps.
      2. The UNet predicts velocity v, where v = alpha*eps - sigma*x0.
      3. Convert v -> x0_hat -> eps_hat.
      4. Advance with ``sched.step(eps_hat, t, x_t)``.

    Args:
        prompt: Text prompt for the conditional branch.
        negative_prompt: Text for the unconditional CFG branch ("" = empty prompt).
        seed: If not None, passed to ``torch.manual_seed`` before sampling.
        steps: Number of scheduler inference steps.
        cfg: Guidance scale applied to the velocity predictions.
        height, width: Output size in pixels. Each must be a multiple of the
            VAE's downsampling factor. The 512 defaults give the previous
            64x64 latent with an 8x-downsampling VAE such as SD1.5's.

    Returns:
        A ``PIL.Image.Image`` built from the decoded RGB uint8 array.

    Raises:
        ValueError: If height/width are not multiples of the VAE downsampling factor.
    """
    # Pixel-to-latent downsampling factor, derived from the VAE config.
    vae_scale = 2 ** (len(vae.config.block_out_channels) - 1)
    if height % vae_scale or width % vae_scale:
        raise ValueError(
            f"height and width must be multiples of {vae_scale}, got {height}x{width}"
        )

    if seed is not None:
        torch.manual_seed(seed)

    cond = _encode_prompt(prompt)
    uncond = _encode_prompt(negative_prompt)

    sched.set_timesteps(steps, device=DEVICE)

    # Initial Gaussian latent noise.
    latent_shape = (1, unet.config.in_channels, height // vae_scale, width // vae_scale)
    x_t = torch.randn(latent_shape, device=DEVICE, dtype=DTYPE)

    print(f"Sampling '{prompt[:40]}' | {steps} steps, cfg={cfg}")
    report_every = max(1, steps // 5)

    for i, t_scalar in enumerate(sched.timesteps):
        t = torch.full((1,), t_scalar, device=DEVICE, dtype=torch.long)

        # The model predicts VELOCITY, not epsilon; CFG is applied to v.
        v_cond = unet(x_t, t, encoder_hidden_states=cond).sample
        v_uncond = unet(x_t, t, encoder_hidden_states=uncond).sample
        v_hat = v_uncond + cfg * (v_cond - v_uncond)

        # With x_t = alpha*x0 + sigma*eps and v = alpha*eps - sigma*x0:
        #   x0  = (alpha*x_t - sigma*v) / (alpha^2 + sigma^2)
        #   eps = (x_t - alpha*x0) / sigma
        # Computed in float32, in the same form the trainer is said to use.
        alpha, sigma = alpha_sigma(t)
        x_t32 = x_t.float()
        denom = alpha**2 + sigma**2
        x0_hat = (alpha * x_t32 - sigma * v_hat.float()) / (denom + 1e-8)
        eps_hat = (x_t32 - alpha * x0_hat) / (sigma + 1e-8)

        # Advance the DDPM scheduler using the epsilon estimate.
        x_t = sched.step(eps_hat, t_scalar, x_t32).prev_sample.to(DTYPE)

        if (i + 1) % report_every == 0:
            print(f" Step {i+1}/{steps}, t={t_scalar}")

    # Undo the VAE latent scaling (read from the VAE config, not hard-coded), then decode.
    image = vae.decode(x_t / vae.config.scaling_factor).sample
    image = (image / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).float().cpu().numpy()

    # Round before casting: a bare astype(uint8) truncates and biases pixels darker.
    return Image.fromarray((image * 255).round().astype(np.uint8))
| |
|
| | |
| | |
| | |
# Demo: render a few fixed prompts with the same seed and show them inline.
_rule = "=" * 60
print(f"\n{_rule}")
print("Generating test images with Sol (correct sampler)")
print(_rule)

from IPython.display import display

prompts = [
    "a castle at sunset",
    "a portrait of a woman",
    "a city street at night",
]

for prompt in prompts:
    print()
    img = generate_sol(
        prompt,
        negative_prompt="",
        seed=42,
        steps=30,
        cfg=7.5,
    )
    display(img)

print("\n✓ Done!")