from diffusers import UNet2DModel

# Load the pretrained DDPM UNet and move it to the GPU.
model = UNet2DModel.from_pretrained('ddpm-anime-faces-64').cuda()
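
# The shape arguments passed to DDIM.sample() further down can be read off the
# checkpoint's config instead of being hard-coded; a quick sanity check,
# assuming the checkpoint is a standard 64x64, 3-channel UNet2DModel:
print(model.config.sample_size)   # expected: 64
print(model.config.in_channels)   # expected: 3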

import torch
import math
from tqdm import tqdm


class DDIM:
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        sample_steps: int = 20,
    ):
        self.num_train_timesteps = num_train_timesteps
        # Linear beta schedule, identical to the one used for DDPM training.
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Strided sampling schedule: sample_steps timesteps, evenly spaced
        # from num_train_timesteps - 1 down to 0.
        self.timesteps = torch.linspace(num_train_timesteps - 1, 0, sample_steps).long()

    @torch.no_grad()
    def sample(
        self,
        unet: UNet2DModel,
        batch_size: int,
        in_channels: int,
        sample_size: int,
        eta: float = 0.0,
    ):
        alphas = self.alphas.to(unet.device)
        alphas_cumprod = self.alphas_cumprod.to(unet.device)
        timesteps = self.timesteps.to(unet.device)
        # Start from pure Gaussian noise.
        images = torch.randn((batch_size, in_channels, sample_size, sample_size), device=unet.device)
        # Each iteration denoises from the current timestep t to the next, smaller timestep tau.
        for t, tau in tqdm(list(zip(timesteps[:-1], timesteps[1:])), desc='Sampling'):
            pred_noise: torch.Tensor = unet(images, t).sample

            # sigma_t controls how much fresh noise is re-injected per step
            # (Eq. (16) of the DDIM paper); eta = 0 gives deterministic DDIM.
            if not math.isclose(eta, 0.0):
                one_minus_alphas_cumprod_tau = 1.0 - alphas_cumprod[tau]
                one_minus_alphas_cumprod_t = 1.0 - alphas_cumprod[t]
                # For strided timesteps the last factor is 1 - alpha_bar_t / alpha_bar_tau;
                # the single-step (1 - alpha_t) is only its special case when tau = t - 1.
                sigma_t = eta * (
                    one_minus_alphas_cumprod_tau / one_minus_alphas_cumprod_t
                    * (1.0 - alphas_cumprod[t] / alphas_cumprod[tau])
                ) ** 0.5
            else:
                sigma_t = torch.zeros_like(alphas[0])

            # First term: the model's estimate of x_0, rescaled to timestep tau.
            alphas_cumprod_tau = alphas_cumprod[tau]
            sqrt_alphas_cumprod_tau = alphas_cumprod_tau ** 0.5
            alphas_cumprod_t = alphas_cumprod[t]
            sqrt_alphas_cumprod_t = alphas_cumprod_t ** 0.5
            sqrt_one_minus_alphas_cumprod_t = (1.0 - alphas_cumprod_t) ** 0.5
            first_term = sqrt_alphas_cumprod_tau * (images - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t

            # Second term: the "direction pointing to x_t".
            coeff = (1.0 - alphas_cumprod_tau - sigma_t ** 2) ** 0.5
            second_term = coeff * pred_noise

            # Third term: fresh Gaussian noise, scaled by sigma_t (zero when eta = 0).
            epsilon = torch.randn_like(images)
            images = first_term + second_term + sigma_t * epsilon

        # Map from [-1, 1] back to [0, 1] and convert to NHWC numpy arrays for visualization.
        images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
        return images
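
# The loop above implements Eq. (12) of Song et al., "Denoising Diffusion
# Implicit Models":
#   x_tau = sqrt(abar_tau) * (x_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t)
#           + sqrt(1 - abar_tau - sigma_t**2) * eps + sigma_t * z,   z ~ N(0, I)
# where abar is alphas_cumprod and eps is the UNet's noise prediction. A quick
# look at the default strided schedule (values produced by torch.linspace):
print(DDIM().timesteps)  # 999, 946, 893, ..., 52, 0 -- 20 evenly spaced steps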

ddim = DDIM()
images = ddim.sample(model, 32, 3, 64)
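
# eta > 0 re-injects fresh noise at every step, trading DDIM's determinism for
# DDPM-like stochastic sampling at eta = 1.0 (a sketch; 'stochastic_images' is
# an illustrative name, not from the original):
stochastic_images = ddim.sample(model, 32, 3, 64, eta=1.0)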

from diffusers.utils import make_image_grid, numpy_to_pil

# Arrange the 32 samples into a 4x8 grid and save it.
image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=8)
image_grid.save('ddim-sample-results.png')
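
# Sampling starts from random noise, so the grid differs between runs; seeding
# PyTorch's RNGs first makes it reproducible (a minimal sketch; the seed value
# and the 'seeded_images' name are arbitrary/illustrative):
torch.manual_seed(0)
seeded_images = ddim.sample(model, 32, 3, 64)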