| | import torch |
| | import numpy as np |
| | from PIL import Image |
| | from diffusers import AutoencoderKL |
| | from tqdm import tqdm |
| | import pathlib |
| |
|
| | |
# Load the source VAE in float32, put it into inference mode and move it
# to the GPU.
vae = AutoencoderKL.from_pretrained("vae32ch", torch_dtype=torch.float32)
vae.eval()
vae.cuda()

# Derived from the number of entries in block_out_channels; presumably the
# VAE's total spatial downsampling factor. preprocess() crops every image
# to a multiple of this value.
_num_halvings = len(vae.config.block_out_channels) - 1
vae_scale_factor = 2 ** _num_halvings

# Collect every PNG below the dataset root, in sorted order.
dataset_path = pathlib.Path("/workspace/ds")
image_paths = list(dataset_path.rglob("*.png"))
image_paths.sort()
print(f"Найдено картинок: {len(image_paths)}")

# Keep only the first 30000 paths; anything beyond that is dropped.
del image_paths[30000:]
| |
|
| | |
def preprocess(path):
    """Load an image file and turn it into a VAE input tensor.

    The image is converted to RGB, centre-cropped so that both sides are
    multiples of the module-level ``vae_scale_factor``, and rescaled from
    ``[0, 255]`` to ``[-1, 1]``.

    Args:
        path: Location of the image file, in any form accepted by
            ``PIL.Image.open``.

    Returns:
        A float32 tensor of shape ``(1, 3, H, W)`` on the CPU.

    Raises:
        ValueError: If either side of the image is shorter than
            ``vae_scale_factor``, so nothing would be left after cropping.
    """
    # convert() returns an independent in-memory image, so the underlying
    # file handle can be released as soon as the ``with`` block exits.
    with Image.open(path) as source:
        img = source.convert("RGB")
    w, h = img.size

    # Largest size not exceeding the original that is a multiple of the
    # scale factor along each axis.
    new_w = (w // vae_scale_factor) * vae_scale_factor
    new_h = (h // vae_scale_factor) * vae_scale_factor

    # Fail early with a clear message instead of passing an empty crop on
    # to the encoder.
    if new_w == 0 or new_h == 0:
        raise ValueError(
            f"image {w}x{h} is smaller than the VAE scale factor "
            f"{vae_scale_factor}"
        )

    if new_w != w or new_h != h:
        # Centre crop: discard the surplus pixels evenly from both sides.
        left = (w - new_w) // 2
        top = (h - new_h) // 2
        img = img.crop((left, top, left + new_w, top + new_h))

    # HWC uint8 -> float32 in [0, 1] -> NCHW with a batch dimension of 1.
    x = torch.from_numpy(np.array(img).astype(np.float32) / 255.0)
    x = x.permute(2, 0, 1).unsqueeze(0)
    # Map [0, 1] to [-1, 1].
    x = x * 2.0 - 1.0
    return x
| |
|
| | |
latent_channels = vae.config.latent_channels

# Running per-channel totals over every latent value of every image.
# Pooling the raw sums (instead of averaging per-image means and stds)
# gives the true dataset-wide statistics: averaging per-image stds ignores
# the variation between images and therefore underestimates the spread.
# float64 keeps the sum of squares accurate over many images.
value_sum = torch.zeros(latent_channels, dtype=torch.float64)
square_sum = torch.zeros(latent_channels, dtype=torch.float64)
values_per_channel = 0
processed = 0
errors = []

with torch.no_grad():
    for path in tqdm(image_paths, desc="Encoding"):
        try:
            x = preprocess(path).cuda()
            lat = vae.encode(x).latent_dist.sample()
            # (1, C, h, w) -> (C, h * w), one row per latent channel.
            flat = lat.squeeze(0).double().reshape(latent_channels, -1)
            image_sum = flat.sum(dim=1).cpu()
            image_square_sum = flat.pow(2).sum(dim=1).cpu()
        except Exception as e:
            # Best effort: record the failure and carry on with the
            # remaining images; all failures are reported after the loop.
            errors.append((path, str(e)))
            continue

        # Only update the totals once the whole image was handled
        # successfully, so a failure never leaves them half-updated.
        value_sum += image_sum
        square_sum += image_square_sum
        values_per_channel += flat.shape[1]
        processed += 1

if errors:
    print(f"\nОшибки ({len(errors)}):")
    for p, e in errors:
        print(f" {p}: {e}")

if processed == 0:
    raise RuntimeError(
        "No images were encoded successfully; "
        "latent statistics cannot be computed."
    )

# Population statistics per channel: Var[x] = E[x^2] - E[x]^2. The clamp
# guards against a tiny negative value caused by rounding.
_mean64 = value_sum / values_per_channel
_var64 = (square_sum / values_per_channel - _mean64.pow(2)).clamp(min=0.0)
mean = _mean64.float()
std = _var64.sqrt().float()

print(f"\nОбработано картинок: {processed}")
print(f"\nlatents_mean ({latent_channels} каналов):")
print(mean.tolist())
print(f"\nlatents_std ({latent_channels} каналов):")
print(std.tolist())
| |
|
| | |
cfg = vae.config

# Instantiate the copy from the complete source config. Listing a few
# constructor arguments by hand drops every other config field, so any
# field not listed would fall back to the class default and the copy could
# differ from the source model.
new_vae = AutoencoderKL.from_config(cfg)
new_vae.eval()

# strict=True turns any missing or unexpected weight into an error, so a
# model with randomly initialised layers can never be saved by accident.
result = new_vae.load_state_dict(vae.state_dict(), strict=True)
print(f"\nВеса перенесены: {result}")

# Store the measured per-channel statistics in the config and neutralise
# the scalar scale/shift values.
new_vae.register_to_config(
    latents_mean=mean.tolist(),
    latents_std=std.tolist(),
    scaling_factor=1.0,
    shift_factor=0.0,
)

print(f"\nlatents_mean в конфиге: {new_vae.config.latents_mean[:4]}...")
print(f"latents_std в конфиге: {new_vae.config.latents_std[:4]}...")

# Write the weights together with the updated config.
new_vae.save_pretrained("vae32ch2")
print("\nСохранено в vae32ch2/")
| |
|