import torch import numpy as np from PIL import Image from diffusers import AutoencoderKL from tqdm import tqdm import pathlib # ── 1. Загружаем VAE ────────────────────────────────────────────────────────── vae = AutoencoderKL.from_pretrained("vae32ch", torch_dtype=torch.float32) vae.eval().cuda() vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) # = 8 # ── 2. Собираем все PNG рекурсивно ─────────────────────────────────────────── dataset_path = pathlib.Path("/workspace/ds") image_paths = sorted(dataset_path.rglob("*.png")) print(f"Найдено картинок: {len(image_paths)}") # Берём первые 3000 image_paths = image_paths[:30000] # ── 3. Препроцессинг — кроп до кратного 8 без ресайза ──────────────────────── def preprocess(path): img = Image.open(path).convert("RGB") w, h = img.size new_w = (w // vae_scale_factor) * vae_scale_factor new_h = (h // vae_scale_factor) * vae_scale_factor if new_w != w or new_h != h: left = (w - new_w) // 2 top = (h - new_h) // 2 img = img.crop((left, top, left + new_w, top + new_h)) x = torch.from_numpy(np.array(img).astype(np.float32) / 255.0) x = x.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W] x = x * 2.0 - 1.0 # [-1, 1] return x # ── 4. Считаем статистику по каналам ───────────────────────────────────────── latent_channels = vae.config.latent_channels # 32 all_means = [] # [N, C] all_stds = [] # [N, C] errors = [] with torch.no_grad(): for path in tqdm(image_paths, desc="Encoding"): try: x = preprocess(path).cuda() lat = vae.encode(x).latent_dist.sample() # [1, C, H, W] flat = lat.squeeze(0).float().reshape(latent_channels, -1) # [C, H*W] all_means.append(flat.mean(dim=1).cpu()) # [C] all_stds.append(flat.std(dim=1).cpu()) # [C] except Exception as e: errors.append((path, str(e))) if errors: print(f"\nОшибки ({len(errors)}):") for p, e in errors: print(f" {p}: {e}") mean = torch.stack(all_means).mean(dim=0) # [C] std = torch.stack(all_stds).mean(dim=0) # [C] print(f"\nОбработано картинок: {len(all_means)}") print(f"\nlatents_mean ({latent_channels} каналов):") print(mean.tolist()) print(f"\nlatents_std ({latent_channels} каналов):") print(std.tolist()) # ── 5. Создаём новый VAE с той же архитектурой + scaling векторы ────────────── cfg = vae.config new_vae = AutoencoderKL( in_channels = cfg.in_channels, out_channels = cfg.out_channels, latent_channels = cfg.latent_channels, block_out_channels = cfg.block_out_channels, layers_per_block = cfg.layers_per_block, norm_num_groups = cfg.norm_num_groups, act_fn = cfg.act_fn, down_block_types = cfg.down_block_types, up_block_types = cfg.up_block_types, ) new_vae.eval() # Переносим веса result = new_vae.load_state_dict(vae.state_dict(), strict=False) print(f"\nВеса перенесены: {result}") # Прописываем scaling векторы в конфиг new_vae.register_to_config( latents_mean = mean.tolist(), latents_std = std.tolist(), scaling_factor = 1.0, shift_factor = 0.0, ) print(f"\nlatents_mean в конфиге: {new_vae.config.latents_mean[:4]}...") print(f"latents_std в конфиге: {new_vae.config.latents_std[:4]}...") # ── 6. Сохраняем ────────────────────────────────────────────────────────────── new_vae.save_pretrained("vae32ch2") print("\nСохранено в vae32ch2/")