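# vae32ch / scale.py
# Author: recoilme — "Upload folder using huggingface_hub" (commit 14c7142, verified)
# Computes per-channel latent mean/std of the "vae32ch" VAE over a PNG dataset
# and saves a copy of the VAE with those statistics written into its config.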
import torch
import numpy as np
from PIL import Image
from diffusers import AutoencoderKL
from tqdm import tqdm
import pathlib
# ── 1. Load the VAE in fp32, switch to eval mode, move it to the GPU ─────────
vae = AutoencoderKL.from_pretrained(
    "vae32ch",
    torch_dtype=torch.float32,
)
vae.eval()
vae.cuda()
# Encoder downsampling factor: 2 ** (number of resolution levels - 1),
# e.g. 8 for four entries in block_out_channels.
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
# ── 2. Collect every PNG under the dataset root (recursive) ──────────────────
dataset_path = pathlib.Path("/workspace/ds")
# Sorting makes the truncated subset below deterministic across runs.
image_paths = sorted(dataset_path.rglob("*.png"))
print(f"Найдено картинок: {len(image_paths)}")

# Keep only the first MAX_IMAGES paths (in sorted order).
# NOTE(review): the original comment said "first 3000" while the slice kept
# 30000; the slice's value is preserved here — confirm the intended limit.
MAX_IMAGES = 30000
image_paths = image_paths[:MAX_IMAGES]
# ── 3. Preprocessing: center-crop to a multiple of the scale factor, no resize ─
def preprocess(path, scale_factor=None):
    """Load an image file and turn it into a VAE input tensor.

    The image is converted to RGB and center-cropped so that both sides are
    multiples of ``scale_factor``; it is never resampled.

    Args:
        path: Path of the image file to load.
        scale_factor: Crop granularity in pixels. ``None`` (the default) uses
            the module-level ``vae_scale_factor``.

    Returns:
        A float32 tensor of shape [1, 3, H, W] with values in [-1, 1].

    Raises:
        ValueError: If either image side is smaller than ``scale_factor``;
            cropping would otherwise yield an empty image.
    """
    if scale_factor is None:
        scale_factor = vae_scale_factor

    # The context manager guarantees the file handle is released; convert()
    # returns a fully loaded, independent image that outlives the file.
    with Image.open(path) as src:
        img = src.convert("RGB")

    w, h = img.size
    new_w = (w // scale_factor) * scale_factor
    new_h = (h // scale_factor) * scale_factor
    if new_w == 0 or new_h == 0:
        raise ValueError(
            f"image {w}x{h} is smaller than the scale factor {scale_factor}"
        )
    if new_w != w or new_h != h:
        # Center crop: trim the remainder evenly from both sides.
        left = (w - new_w) // 2
        top = (h - new_h) // 2
        img = img.crop((left, top, left + new_w, top + new_h))

    x = torch.from_numpy(np.array(img).astype(np.float32) / 255.0)  # [H, W, 3] in [0, 1]
    x = x.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
    x = x * 2.0 - 1.0  # rescale to [-1, 1]
    return x
# ── 4. Encode every image and record per-channel latent statistics ───────────
latent_channels = vae.config.latent_channels  # 32 for this checkpoint (per original note)
all_means = []  # one [C] tensor per successfully encoded image
all_stds = []   # one [C] tensor per successfully encoded image
errors = []     # (path, error message) for every image that failed
with torch.no_grad():
    for path in tqdm(image_paths, desc="Encoding"):
        try:
            x = preprocess(path).cuda()
            # Draw a sample from the latent distribution (rather than taking
            # its mean/mode), as the original script does.
            lat = vae.encode(x).latent_dist.sample()  # [1, C, h, w]
            flat = lat.squeeze(0).float().reshape(latent_channels, -1)  # [C, h*w]
            ch_mean = flat.mean(dim=1).cpu()
            # Population std (unbiased=False) is 0 — not NaN — when the latent
            # has a single spatial position; the default unbiased estimator
            # would return NaN there and poison the dataset-wide std.
            ch_std = flat.std(dim=1, unbiased=False).cpu()
        except Exception as e:
            # Best effort: skip images that cannot be read/encoded, report below.
            errors.append((path, str(e)))
        else:
            # Append both only after both succeeded so the lists stay aligned.
            all_means.append(ch_mean)
            all_stds.append(ch_std)
if errors:
    print(f"\nОшибки ({len(errors)}):")
    for p, e in errors:
        print(f" {p}: {e}")
# Aggregate per-image statistics into dataset-level per-channel mean/std.
if not all_means:
    raise RuntimeError(
        "no images were encoded successfully; cannot compute latent statistics"
    )
means = torch.stack(all_means).double()  # [N, C]
stds = torch.stack(all_stds).double()    # [N, C]
# Each image is weighted equally (same weighting as the plain average of means).
mean = means.mean(dim=0).float()  # [C]
# Law of total variance (equal image weights): Var = E_i[Var_i] + Var_i(E_i).
# Averaging the per-image stds alone ignores the spread of per-image means and,
# by Jensen's inequality, underestimates the overall std of the latents.
var = (stds ** 2).mean(dim=0) + means.var(dim=0, unbiased=False)
std = var.sqrt().float()  # [C]
print(f"\nОбработано картинок: {len(all_means)}")
print(f"\nlatents_mean ({latent_channels} каналов):")
print(mean.tolist())
print(f"\nlatents_std ({latent_channels} каналов):")
print(std.tolist())
# ── 5. Build a new VAE with the same architecture and copy the weights ───────
# Building from the *entire* config (not a hand-picked subset of nine fields)
# keeps every other field — e.g. sample_size and any layer-layout flags your
# diffusers version defines — instead of silently falling back to the
# AutoencoderKL defaults, which could change the module layout.
cfg = vae.config
new_vae = AutoencoderKL.from_config(cfg)
new_vae.eval()
# Copy the weights. strict=True makes any missing or unexpected key an error,
# rather than leaving a layer silently at its random initialisation.
result = new_vae.load_state_dict(vae.state_dict(), strict=True)
print(f"\nВеса перенесены: {result}")
# Write the per-channel statistics into the model config. The scalar
# scaling_factor / shift_factor are set to 1.0 / 0.0 — presumably so that
# consumers normalise with the per-channel vectors only; confirm downstream.
scaling_entries = {
    "latents_mean": mean.tolist(),
    "latents_std": std.tolist(),
}
new_vae.register_to_config(**scaling_entries, scaling_factor=1.0, shift_factor=0.0)

# Echo the first four entries of each stored vector as a sanity check.
cfg_mean_preview = new_vae.config.latents_mean[:4]
print(f"\nlatents_mean в конфиге: {cfg_mean_preview}...")
cfg_std_preview = new_vae.config.latents_std[:4]
print(f"latents_std в конфиге: {cfg_std_preview}...")
# ── 6. Save the updated VAE (weights + config) to ./vae32ch2 ─────────────────
new_vae.save_pretrained(save_directory="vae32ch2")
print("\nСохранено в vae32ch2/")