import os

# Disable NCCL P2P/IB transports (works around hangs on hosts without proper
# NVLink/InfiniBand) and let the CUDA allocator grow segments on demand.
# These must be set before torch initializes CUDA.
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import gc
import math
import random
from collections import defaultdict
from datetime import datetime

import bitsandbytes as bnb
import numpy as np
import torch
import torch.distributed
import torch.nn.functional as F
import wandb
from accelerate import Accelerator
from datasets import load_from_disk
from diffusers import UNet2DConditionModel, AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from PIL import Image, ImageOps
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Sampler
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

ds_path = "/workspace/sdxs-1b/datasets/ds1234_flux32"
project = "unet"

batch_size = 32
base_learning_rate = 3e-5
min_learning_rate = 1e-5
num_epochs = 8
sample_interval_share = 20     # previews are generated ~this many times per epoch
cfg_dropout = 0.10             # probability of dropping a caption for classifier-free guidance
max_length = 248
use_wandb = True
use_comet_ml = False
save_model = True
use_decay = True
fbp = False                    # fused backward pass: one 8-bit optimizer per parameter
optimizer_type = "adam8bit"
torch_compile = False
unet_gradient = True           # enable gradient checkpointing in the UNet
loss_normalize = False
fixed_seed = False
shuffle = True
comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
comet_ml_workspace = "recoilme"

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)

save_barrier = 1.5             # checkpoints are saved while loss < min_loss * save_barrier
warmup_percent = 0.03
beta2 = 0.995
eps = 1e-7
clip_grad_norm = 1.0
limit = 0                      # > 0 limits the dataset to the first `limit` rows
checkpoints_folder = ""
gradient_accumulation_steps = 1
dtype = torch.float32
mixed_precision = "no"

# Sampling (preview) settings
n_diffusion_steps = 40
samples_to_generate = 12
guidance_scale = 4

generated_folder = "samples"
os.makedirs(generated_folder, exist_ok=True)

# Seed derived from today's date (only enforced when fixed_seed is set)
current_date = datetime.now()
seed = int(current_date.strftime("%Y%m%d")) + 1
if fixed_seed:
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps,
)
device = accelerator.device

print("init")

if accelerator.is_main_process:
    if use_wandb:
        wandb.init(project=project, config={
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
            "optimizer_type": optimizer_type,
        })
    if use_comet_ml:
        from comet_ml import Experiment
        comet_experiment = Experiment(
            api_key=comet_ml_api_key,
            project_name=project,
            workspace=comet_ml_workspace,
        )
        comet_experiment.log_parameters({
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
        })

# Frozen components: Flux.2 VAE, tokenizer, text encoder, flow-matching scheduler
vae = AutoencoderKLFlux2.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("tokenizer")
text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")

def encode_texts(texts, max_length=max_length):
    """Tokenize and encode captions; returns (prompt_embeds, attention_mask)."""
    if texts is None:
        texts = [""]
    if isinstance(texts, str):
        texts = [texts]

    with torch.no_grad():
        toks = tokenizer(
            texts,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        outputs = text_model(
            input_ids=toks.input_ids,
            attention_mask=toks.attention_mask,
            output_hidden_states=True,
        )

        # Take the penultimate hidden layer and re-apply the final LayerNorm,
        # a common recipe for conditioning diffusion models on text encoders.
        prompt_embeds = outputs.hidden_states[-2]
        prompt_embeds = text_model.text_model.final_layer_norm(prompt_embeds)

    return prompt_embeds, toks.attention_mask
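
# Illustrative shapes (the hidden size is an assumption here; it must match the
# text encoder, cf. hidden_dim = 2048 in get_negative_embedding below):
#   emb, mask = encode_texts("a cat")
#   emb.shape  -> (1, 248, 2048)   # (batch, max_length, hidden_dim)
#   mask.shape -> (1, 248)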

shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None:
    shift_factor = 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None:
    scaling_factor = 1.0

def _patchify_latents(latents):
    # 2x2 space-to-depth: (B, C, H, W) -> (B, 4C, H/2, W/2)
    batch_size, num_channels_latents, height, width = latents.shape
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 1, 3, 5, 2, 4)
    latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
    return latents


def _unpatchify_latents(latents):
    # Inverse of _patchify_latents: (B, 4C, H/2, W/2) -> (B, C, H, W)
    batch_size, num_channels_latents, height, width = latents.shape
    latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
    latents = latents.permute(0, 1, 4, 2, 5, 3)
    latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
    return latents
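
# Round-trip sanity check (illustrative): patchify followed by unpatchify is an
# exact identity for even H and W, since both use only view/permute/reshape.
_x = torch.arange(1 * 16 * 8 * 8, dtype=torch.float32).reshape(1, 16, 8, 8)
assert torch.equal(_unpatchify_latents(_patchify_latents(_x)), _x)
del _x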

def flux_encode(vae, latents):
    # Normalize raw VAE latents with the Flux.2 VAE's BatchNorm running stats
    # (tracked over patchified channels), then restore the original layout.
    image_latents = _patchify_latents(latents)
    latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
    latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps)
    latents = (image_latents - latents_bn_mean) / latents_bn_std
    return _unpatchify_latents(latents)


def flux_decode(vae, latents):
    # Inverse of flux_encode: de-normalize latents before vae.decode().
    image_latents = _patchify_latents(latents)
    latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
    latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps)
    latents = image_latents * latents_bn_std + latents_bn_mean
    return _unpatchify_latents(latents)
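
# flux_decode inverts flux_encode up to floating-point error (illustrative;
# assumes `vae` exposes the Flux.2 BatchNorm stats used above):
#   z = torch.randn(1, 32, 64, 64, device=device)
#   torch.allclose(flux_decode(vae, flux_encode(vae, z)), z, atol=1e-5)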

class DistributedResolutionBatchSampler(Sampler):
    """Batch sampler that groups samples by (width, height) so every batch
    holds a single resolution, split evenly across DDP ranks."""

    def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
        self.dataset = dataset
        self.batch_size = max(1, batch_size // num_replicas)  # per-GPU batch size
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0

        try:
            widths = np.array(dataset["width"])
            heights = np.array(dataset["height"])
        except KeyError:
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))

        # Indices grouped by unique (width, height)
        self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
        self.size_groups = {}
        for w, h in self.size_keys:
            mask = (widths == w) & (heights == h)
            self.size_groups[(w, h)] = np.where(mask)[0]

        # Keep only full global batches (drop_last semantics)
        self.group_num_batches = {}
        total_batches = 0
        for size, indices in self.size_groups.items():
            num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
            self.group_num_batches[size] = num_full_batches
            total_batches += num_full_batches

        self.num_batches = (total_batches // self.num_replicas) * self.num_replicas

    def __iter__(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        all_batches = []
        rng = np.random.RandomState(self.epoch)  # same shuffle on every rank

        for size, indices in self.size_groups.items():
            indices = indices.copy()
            if self.shuffle:
                rng.shuffle(indices)
            num_full_batches = self.group_num_batches[size]
            if num_full_batches == 0:
                continue
            valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
            batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
            # Each rank takes its own contiguous slice of every global batch
            start_idx = self.rank * self.batch_size
            end_idx = start_idx + self.batch_size
            gpu_batches = batches[:, start_idx:end_idx]
            all_batches.extend(gpu_batches)

        if self.shuffle:
            rng.shuffle(all_batches)
        accelerator.wait_for_everyone()
        return iter(all_batches)

    def __len__(self):
        return self.num_batches

    def set_epoch(self, epoch):
        self.epoch = epoch
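
# Note: every batch this sampler yields is homogeneous in resolution, so the
# collate function can stack latents without padding. Illustrative single-GPU use:
#   sampler = DistributedResolutionBatchSampler(dataset, batch_size=32, num_replicas=1, rank=0)
#   next(iter(sampler))  # -> 32 indices sharing one (width, height)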

def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    """Pick a small fixed set of samples per resolution for periodic previews."""
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for i, (w, h) in enumerate(zip(widths, heights)):
        size_groups[(w, h)].append(i)

    fixed_samples = {}
    for size, indices in size_groups.items():
        n_samples = min(samples_per_group, len(indices))
        if len(size_groups) == 1:
            n_samples = samples_to_generate
        if n_samples == 0:
            continue
        sample_indices = random.sample(indices, n_samples)
        samples_data = [dataset[idx] for idx in sample_indices]

        latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
        texts = [item["text"] for item in samples_data]
        embeddings, masks = encode_texts(texts)

        fixed_samples[size] = (latents, embeddings, masks, texts)

    print(f"Created {len(fixed_samples)} groups of fixed samples by resolution")
    return fixed_samples

if limit > 0:
    dataset = load_from_disk(ds_path).select(range(limit))
else:
    dataset = load_from_disk(ds_path)

# Drop the anime-sfw subsets by image path
dataset = dataset.filter(
    lambda x: [not (path.startswith("/workspace/dataset/animesfw") or path.startswith("/workspace/dataset/d4/animesfw"))
               for path in x["image_path"]],
    batched=True,
    batch_size=10000,
    num_proc=8,
)
print(f"Examples remaining after filtering: {len(dataset)}")

def collate_fn_simple(batch):
    latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype)

    # Caption cleanup plus classifier-free-guidance dropout: captions starting
    # with "zero" are forced empty, a random cfg_dropout fraction is emptied,
    # a leading "." is stripped, and boilerplate prefixes are removed.
    raw_texts = [item["text"] for item in batch]
    texts = [
        "" if t.lower().startswith("zero")
        else "" if random.random() < cfg_dropout
        else t[1:].lstrip() if t.startswith(".")
        else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ", "").strip()
        for t in raw_texts
    ]

    embeddings, attention_mask = encode_texts(texts)
    attention_mask = attention_mask.to(dtype=torch.int64)

    return latents, embeddings, attention_mask
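
# Illustrative caption transformations performed above (hypothetical inputs):
#   "zero person"             -> ""          (explicit empty-prompt marker)
#   ". a red car"             -> "a red car" (leading period stripped)
#   "The image shows a beach" -> "a beach"   (boilerplate prefix removed)
# Independently, each caption has a cfg_dropout (10%) chance of becoming "",
# which trains the unconditional branch needed for classifier-free guidance.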

batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle=shuffle,
)

dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
if accelerator.is_main_process:
    print("Total batches per process:", len(dataloader))
dataloader = accelerator.prepare(dataloader)

start_epoch = 0
global_step = 0
total_training_steps = len(dataloader) * num_epochs
world_size = accelerator.state.num_processes

# Load the UNet from the latest checkpoint (training resumes from it)
latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
    print("Loading UNet from checkpoint:", latest_checkpoint)
    unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
    if unet_gradient:
        unet.enable_gradient_checkpointing()
    unet.set_use_memory_efficient_attention_xformers(False)
    try:
        unet.set_attn_processor(AttnProcessor2_0())
    except Exception as e:
        print(f"Failed to enable SDPA: {e}")
        unet.set_use_memory_efficient_attention_xformers(True)
else:
    raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")

def create_optimizer(name, params):
    if name == "adam8bit":
        return bnb.optim.AdamW8bit(
            params, lr=base_learning_rate, betas=(0.9, beta2), eps=eps, weight_decay=0.01,
        )
    if name == "adam":
        return torch.optim.AdamW(
            params, lr=base_learning_rate, betas=(0.9, beta2), eps=1e-8, weight_decay=0.01,
        )
    raise ValueError(f"Unknown optimizer: {name}")

if fbp:
    # Fused backward pass: a dedicated 8-bit optimizer per parameter, stepped
    # from a post-accumulate-grad hook so gradients are freed immediately.
    trainable_params = list(unet.parameters())
    optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}

    def optimizer_hook(param):
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad(set_to_none=True)

    for param in trainable_params:
        param.register_post_accumulate_grad_hook(optimizer_hook)
    unet, optimizer = accelerator.prepare(unet, optimizer_dict)
else:
    unet.requires_grad_(True)
    optimizer = create_optimizer(optimizer_type, unet.parameters())

    def lr_schedule(step):
        # Linear warmup for the first warmup_percent of training, then cosine
        # decay from base_learning_rate down to min_learning_rate.
        x = step / (total_training_steps * world_size)
        warmup = warmup_percent
        if not use_decay:
            return base_learning_rate
        if x < warmup:
            return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
        decay_ratio = (x - warmup) / (1 - warmup)
        return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
            (1 + math.cos(math.pi * decay_ratio))

    lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
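
# Quick numeric check of the schedule above (illustrative; S denotes the global
# step budget total_training_steps * world_size used inside lr_schedule):
#   lr_schedule(0)        -> 1e-5  (start of warmup)
#   lr_schedule(0.03 * S) -> 3e-5  (end of warmup, peak learning rate)
#   lr_schedule(S)        -> 1e-5  (end of cosine decay)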

if torch_compile:
    print("compiling")
    unet = torch.compile(unet)
    print("compiling - ok")

fixed_samples = get_fixed_samples_by_resolution(dataset)

def get_negative_embedding(neg_prompt="", batch_size=1):
    # An empty negative prompt maps to zero embeddings with a full attention
    # mask; hidden_dim = 2048 must match the text encoder's hidden size.
    if not neg_prompt:
        hidden_dim = 2048
        seq_len = max_length
        empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
        empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
        return empty_emb, empty_mask

    uncond_emb, uncond_mask = encode_texts([neg_prompt])
    uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
    uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
    return uncond_emb, uncond_mask


uncond_emb, uncond_mask = get_negative_embedding("low quality")
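
# The "low quality" negative prompt is encoded once and reused for every
# preview batch; passing "" would instead return the zero-embedding
# unconditional branch built above.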

@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    """Run CFG sampling for every fixed resolution group and log the images."""
    uncond_emb, uncond_mask = uncond_data
    original_model = None
    try:
        if not torch_compile:
            original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
        else:
            original_model = unet.eval()

        vae.to(device=device).eval()

        all_generated_images = []
        all_captions = []

        for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
            width, height = size
            sample_latents = sample_latents.to(dtype=dtype, device=device)
            sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
            sample_mask = sample_mask.to(device=device)

            latents = torch.randn(
                sample_latents.shape,
                device=device,
                dtype=sample_latents.dtype,
                generator=torch.Generator(device=device).manual_seed(seed),
            )

            scheduler.set_timesteps(n_diffusion_steps, device=device)

            for t in scheduler.timesteps:
                if guidance_scale != 1:
                    # Duplicate latents: first half unconditional, second half conditional
                    latent_model_input = torch.cat([latents, latents], dim=0)
                    curr_batch_size = sample_text_embeddings.shape[0]
                    neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
                    text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)
                    neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
                    attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
                else:
                    latent_model_input = latents
                    text_embeddings_batch = sample_text_embeddings
                    attention_mask_batch = sample_mask

                model_out = original_model(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings_batch,
                    encoder_attention_mask=attention_mask_batch,
                )
                flow = getattr(model_out, "sample", model_out)

                if guidance_scale != 1:
                    # Classifier-free guidance on the predicted velocity
                    flow_uncond, flow_cond = flow.chunk(2)
                    flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

                latents = scheduler.step(flow, t, latents).prev_sample

            current_latents = latents
            if step == 0:
                # At step 0, decode the ground-truth latents as a VAE sanity check
                current_latents = sample_latents

            latents = current_latents.detach() * scaling_factor + shift_factor
            latents = flux_decode(vae, latents)
            decoded = vae.decode(latents.to(torch.float32)).sample
            decoded_fp32 = decoded.to(torch.float32)

            for img_idx, img_tensor in enumerate(decoded_fp32):
                img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
                img = img.transpose(1, 2, 0)

                if np.isnan(img).any():
                    print("NaNs found in decoded image! Step:", step)
                pil_img = Image.fromarray((img * 255).astype("uint8"))

                # Pad every preview to the largest resolution so the grid aligns
                max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
                max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
                max_w_overall = max(255, max_w_overall)
                max_h_overall = max(255, max_h_overall)
                padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
                all_generated_images.append(padded_img)

                caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
                all_captions.append(caption_text)

                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=96)

        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images})
        if use_comet_ml and accelerator.is_main_process:
            for i, img in enumerate(all_generated_images):
                comet_experiment.log_image(
                    image_data=img,
                    name=f"step_{step}_img_{i}",
                    step=step,
                    metadata={"caption": all_captions[i]},
                )
    finally:
        # Move the VAE back to CPU and release everything the sampler touched
        vae.to("cpu")
        try:
            all_generated_images.clear()
            all_captions.clear()
            del all_generated_images, all_captions
            del latents, current_latents, latent_model_input, flow
            del decoded, decoded_fp32
            del sample_latents, sample_text_embeddings, sample_mask
            del model_out
        except UnboundLocalError:
            pass

        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        gc.collect()
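
# Classifier-free guidance arithmetic used above, on one scalar (illustrative):
#   flow_uncond = 0.2, flow_cond = 1.0, guidance_scale = 4
#   flow = 0.2 + 4 * (1.0 - 0.2) = 3.4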

# Generate a preview before training starts (step 0 decodes ground truth)
if accelerator.is_main_process:
    if save_model:
        print("Generating samples before training starts...")
        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
accelerator.wait_for_everyone()


def save_checkpoint(unet, variant=""):
    if accelerator.is_main_process:
        if not torch_compile:
            model_to_save = accelerator.unwrap_model(unet)
        else:
            model_to_save = unet

        if variant != "":
            model_to_save.to(dtype=torch.float16).save_pretrained(
                os.path.join(checkpoints_folder, f"{project}"), variant=variant
            )
        else:
            model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))

        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        gc.collect()
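
# Usage: save_checkpoint(unet) writes checkpoints_folder/<project>/ in full
# precision; save_checkpoint(unet, "fp16") casts to float16 first and writes
# the "fp16" weight variant, as done after training below.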

if accelerator.is_main_process:
    print(f"Total steps per GPU: {total_training_steps}")

epoch_loss_points = []
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")

steps_per_epoch = len(dataloader)
sample_interval = max(1, steps_per_epoch // sample_interval_share)
min_loss = 4.0
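
# Rectified-flow objective used in the loop below: sample t ~ U(1e-5, 1 - 1e-5),
# interpolate x_t = (1 - t) * x0 + t * noise, and regress the UNet output onto
# the constant velocity v = noise - x0. Worked example for one scalar:
#   x0 = 2.0, noise = -1.0, t = 0.25  ->  x_t = 1.25, target v = -3.0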

for epoch in range(start_epoch, start_epoch + num_epochs):
    batch_losses = []
    batch_grads = []
    batch_sampler.set_epoch(epoch)
    accelerator.wait_for_everyone()
    unet.train()

    for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
        with accelerator.accumulate(unet):
            if not save_model and epoch == 0 and step == 5:
                used_gb = torch.cuda.max_memory_allocated() / 1024**3
                print(f"Step {step}: {used_gb:.2f} GB")

            noise = torch.randn_like(latents, dtype=latents.dtype)

            # t ~ U(1e-5, 1 - 1e-5); linear interpolation between data and noise
            u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
            t = u * (1 - 2 * 1e-5) + 1e-5
            noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise

            # Map continuous t in (0, 1) onto the scheduler's discrete timestep range
            timesteps = t.to(torch.float32).mul(999.0)
            timesteps = timesteps.clamp(0, scheduler.config.num_train_timesteps - 1)

            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=embeddings,
                encoder_attention_mask=attention_mask,
            ).sample

            # Velocity target for rectified flow
            target = noise - latents
            mse_loss = F.mse_loss(model_pred.float(), target.float())
            batch_losses.append(mse_loss.detach().item())

            # Periodic barrier to keep ranks aligned around logging/sampling steps
            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            accelerator.backward(mse_loss)

            grad = 0.0
            if not fbp:
                if accelerator.sync_gradients:
                    grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
                    grad = grad_val.float().item() if torch.is_tensor(grad_val) else float(grad_val)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if accelerator.sync_gradients:
                global_step += 1
                progress_bar.update(1)
            if accelerator.is_main_process:
                if fbp:
                    current_lr = base_learning_rate
                else:
                    current_lr = lr_scheduler.get_last_lr()[0]
                batch_grads.append(grad)

                log_data = {
                    "loss_mse": mse_loss.detach().item(),
                    "lr": current_lr,
                    "grad": grad,
                }
                if accelerator.sync_gradients:
                    if use_wandb:
                        wandb.log(log_data, step=global_step)
                    if use_comet_ml:
                        comet_experiment.log_metrics(log_data, step=global_step)

                if global_step % sample_interval == 0 or global_step == 50:
                    if save_model:
                        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                    elif epoch % 10 == 0:
                        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)

                    if save_model:
                        # Only save while the recent loss stays within save_barrier
                        # of the best loss observed at a save point.
                        has_losses = len(batch_losses) > 0
                        avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
                        last_loss = batch_losses[-1] if has_losses else 0.0
                        max_loss = max(avg_sample_loss, last_loss)
                        should_save = max_loss < min_loss * save_barrier
                        print(
                            f"Saving: {should_save} | Max: {max_loss:.4f} | "
                            f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
                        )
                        if should_save:
                            min_loss = max_loss
                            save_checkpoint(unet)
                    unet.train()

    if accelerator.is_main_process:
        avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
        avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0

        print(f"\nEpoch {epoch} finished. Average loss: {avg_epoch_loss:.6f}")
        log_data_ep = {
            "epoch_loss": avg_epoch_loss,
            "epoch_grad": avg_epoch_grad,
            "epoch": epoch + 1,
        }
        if use_wandb:
            wandb.log(log_data_ep)
        if use_comet_ml:
            comet_experiment.log_metrics(log_data_ep)

if accelerator.is_main_process:
    print("Training finished! Saving the final model...")
    save_checkpoint(unet, "fp16")
    if use_comet_ml:
        comet_experiment.end()

accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()

print("Done!")