LittleNyima committed
Commit c950b46 · 1 Parent(s): 3e86d6b

Update README and code to enhance clarity

Files changed (3)
  1. README.md +9 -90
  2. sampling_ddim.py +75 -0
  3. sampling_ddpm.py +62 -0
README.md CHANGED
@@ -6,9 +6,9 @@ pipeline_tag: unconditional-image-generation
  tags:
  - art
  ---
- # Abstract
+ # ddpm-anime-faces-64

- **DDPM** model trained on [huggan/anime-faces](https://huggingface.co/datasets/huggan/anime-faces) dataset.
+ **ddpm-anime-faces-64** is an educational project introducing the training and sampling processes of DDPM and DDIM models. The model is trained on the [huggan/anime-faces](https://huggingface.co/datasets/huggan/anime-faces) dataset.

  ## Training Arguments

@@ -25,95 +25,14 @@ tags:

  For training code, please refer to [this link](https://github.com/LittleNyima/code-snippets/blob/master/ddpm-tutorial/ddpm_training.py).

- # Inference
+ ## Inference

- This project aims to implement DDPM from scratch, so `DDPMScheduler` is not used. Instead, I use only `UNet2DModel` and implement a simple scheduler myself. The inference code is:
+ This project aims to implement DDPM from scratch, so `DDPMScheduler` is not used. Instead, I use only `UNet2DModel` and implement a simple scheduler myself.

- ```python
- import torch
- from tqdm import tqdm
- from diffusers import UNet2DModel
-
- class DDPM:
-     def __init__(
-         self,
-         num_train_timesteps: int = 1000,
-         beta_start: float = 0.0001,
-         beta_end: float = 0.02,
-     ):
-         self.num_train_timesteps = num_train_timesteps
-         self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
-         self.alphas = 1.0 - self.betas
-         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
-         self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1)
-
-     def add_noise(
-         self,
-         original_samples: torch.Tensor,
-         noise: torch.Tensor,
-         timesteps: torch.Tensor,
-     ):
-         alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
-         noise = noise.to(original_samples.device)
-         timesteps = timesteps.to(original_samples.device)
-
-         # \sqrt{\bar\alpha_t}
-         sqrt_alpha_prod = alphas_cumprod[timesteps].flatten() ** 0.5
-         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
-         # \sqrt{1 - \bar\alpha_t}
-         sqrt_one_minus_alpha_prod = (1.0 - alphas_cumprod[timesteps]).flatten() ** 0.5
-         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
-         return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
-
-     @torch.no_grad()
-     def sample(
-         self,
-         unet: UNet2DModel,
-         batch_size: int,
-         in_channels: int,
-         sample_size: int,
-     ):
-         betas = self.betas.to(unet.device)
-         alphas = self.alphas.to(unet.device)
-         alphas_cumprod = self.alphas_cumprod.to(unet.device)
-         timesteps = self.timesteps.to(unet.device)
-         images = torch.randn((batch_size, in_channels, sample_size, sample_size), device=unet.device)
-         for timestep in tqdm(timesteps, desc='Sampling'):
-             pred_noise: torch.Tensor = unet(images, timestep).sample
-
-             # mean of q(x_{t-1}|x_t)
-             alpha_t = alphas[timestep]
-             alpha_cumprod_t = alphas_cumprod[timestep]
-             sqrt_alpha_t = alpha_t ** 0.5
-             one_minus_alpha_t = 1.0 - alpha_t
-             sqrt_one_minus_alpha_cumprod_t = (1 - alpha_cumprod_t) ** 0.5
-             mean = (images - one_minus_alpha_t / sqrt_one_minus_alpha_cumprod_t * pred_noise) / sqrt_alpha_t
-
-             # variance of q(x_{t-1}|x_t)
-             if timestep > 1:
-                 beta_t = betas[timestep]
-                 one_minus_alpha_cumprod_t_minus_one = 1.0 - alphas_cumprod[timestep - 1]
-                 one_divided_by_sigma_square = alpha_t / beta_t + 1.0 / one_minus_alpha_cumprod_t_minus_one
-                 variance = (1.0 / one_divided_by_sigma_square) ** 0.5
-             else:
-                 variance = torch.zeros_like(timestep)
-
-             epsilon = torch.randn_like(images)
-             images = mean + variance * epsilon
-         images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
-         return images
-
- model = UNet2DModel.from_pretrained('ddpm-animefaces-64').cuda()
- ddpm = DDPM()
- images = ddpm.sample(model, 32, 3, 64)
-
- from diffusers.utils import make_image_grid, numpy_to_pil
- image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=8)
- image_grid.save('ddpm-sample-results.png')
- ```
-
- This can also be found in [this link](https://github.com/LittleNyima/code-snippets/blob/master/ddpm-tutorial/ddpm_sampling.py).
+ Please refer to `sampling_ddpm.py` and `sampling_ddim.py` for detailed usage.
+
+ ## References
+
+ 1. [DDPM Tutorial (Written in Chinese)](https://littlenyima.github.io/posts/13-denoising-diffusion-probabilistic-models/)
+ 2. [DDIM Tutorial (Written in Chinese)](https://littlenyima.github.io/posts/14-denoising-diffusion-implicit-models/)
+ 3. [GitHub Repo](https://github.com/LittleNyima/code-snippets)
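For readers skimming the diff, a quick recap of the notation the two new scripts rely on (standard DDPM notation; the symbols correspond to the `betas`, `alphas`, and `alphas_cumprod` tensors built in each script's `__init__`):

$$
\alpha_t = 1 - \beta_t, \qquad \bar\alpha_t = \prod_{s=1}^{t} \alpha_s, \qquad
x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I}).
$$

The last relation is the forward (noising) process used during training; both samplers below invert it step by step, with `UNet2DModel` predicting $\epsilon$.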
sampling_ddim.py ADDED
@@ -0,0 +1,75 @@
+ # model
+
+ from diffusers import UNet2DModel
+
+ model = UNet2DModel.from_pretrained('ddpm-anime-faces-64').cuda()
+
+
+ # core
+
+ import torch
+ import math
+ from tqdm import tqdm
+
+ class DDIM:
+     def __init__(
+         self,
+         num_train_timesteps: int = 1000,
+         beta_start: float = 0.0001,
+         beta_end: float = 0.02,
+         sample_steps: int = 20,
+     ):
+         self.num_train_timesteps = num_train_timesteps
+         self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+         self.alphas = 1.0 - self.betas
+         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+         self.timesteps = torch.linspace(num_train_timesteps - 1, 0, sample_steps).long()
+
+     @torch.no_grad()
+     def sample(
+         self,
+         unet: UNet2DModel,
+         batch_size: int,
+         in_channels: int,
+         sample_size: int,
+         eta: float = 0.0,
+     ):
+         alphas = self.alphas.to(unet.device)
+         alphas_cumprod = self.alphas_cumprod.to(unet.device)
+         timesteps = self.timesteps.to(unet.device)
+         images = torch.randn((batch_size, in_channels, sample_size, sample_size), device=unet.device)
+         for t, tau in tqdm(list(zip(timesteps[:-1], timesteps[1:])), desc='Sampling'):
+             pred_noise: torch.Tensor = unet(images, t).sample
+
+             # sigma_t
+             if not math.isclose(eta, 0.0):
+                 one_minus_alpha_prod_tau = 1.0 - alphas_cumprod[tau]
+                 one_minus_alpha_prod_t = 1.0 - alphas_cumprod[t]
+                 one_minus_alpha_t = 1.0 - alphas[t]
+                 sigma_t = eta * (one_minus_alpha_prod_tau * one_minus_alpha_t / one_minus_alpha_prod_t) ** 0.5
+             else:
+                 sigma_t = torch.zeros_like(alphas[0])
+
+             # first term of x_tau
+             alphas_cumprod_tau = alphas_cumprod[tau]
+             sqrt_alphas_cumprod_tau = alphas_cumprod_tau ** 0.5
+             alphas_cumprod_t = alphas_cumprod[t]
+             sqrt_alphas_cumprod_t = alphas_cumprod_t ** 0.5
+             sqrt_one_minus_alphas_cumprod_t = (1.0 - alphas_cumprod_t) ** 0.5
+             first_term = sqrt_alphas_cumprod_tau * (images - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t
+
+             # second term of x_tau
+             coeff = (1.0 - alphas_cumprod_tau - sigma_t ** 2) ** 0.5
+             second_term = coeff * pred_noise
+
+             epsilon = torch.randn_like(images)
+             images = first_term + second_term + sigma_t * epsilon
+         images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
+         return images
+
+ ddim = DDIM()
+ images = ddim.sample(model, 32, 3, 64)
+
+ from diffusers.utils import make_image_grid, numpy_to_pil
+ image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=8)
+ image_grid.save('ddim-sample-results.png')
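For reference, each iteration of the `sample` loop above moves from timestep $t$ to the next retained timestep $\tau$ with the DDIM update implemented by `first_term`, `second_term`, and `sigma_t` (here $\epsilon_\theta$ is `pred_noise`):

$$
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}, \qquad
x_\tau = \sqrt{\bar\alpha_\tau}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_\tau - \sigma_t^2}\,\epsilon_\theta(x_t, t) + \sigma_t z, \quad z \sim \mathcal{N}(0, \mathbf{I}),
$$

$$
\sigma_t = \eta\,\sqrt{\frac{(1 - \bar\alpha_\tau)(1 - \alpha_t)}{1 - \bar\alpha_t}}.
$$

With the default `eta=0.0` the update is deterministic ($\sigma_t = 0$), which is the usual DDIM sampler; larger `eta` injects fresh noise at every step. A couple of hypothetical invocations, using only the constructor and method arguments defined above (the specific values are just examples):

```python
# 100 subsampled timesteps instead of 20; still deterministic (eta defaults to 0.0).
images = DDIM(sample_steps=100).sample(model, 32, 3, 64)

# Stochastic variant: eta > 0 adds noise scaled by sigma_t at every step.
images = DDIM(sample_steps=50).sample(model, 32, 3, 64, eta=1.0)
```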
sampling_ddpm.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ from tqdm import tqdm
+ from diffusers import UNet2DModel
+
+ class DDPM:
+     def __init__(
+         self,
+         num_train_timesteps: int = 1000,
+         beta_start: float = 0.0001,
+         beta_end: float = 0.02,
+     ):
+         self.num_train_timesteps = num_train_timesteps
+         self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+         self.alphas = 1.0 - self.betas
+         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+         self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1)
+
+     @torch.no_grad()
+     def sample(
+         self,
+         unet: UNet2DModel,
+         batch_size: int,
+         in_channels: int,
+         sample_size: int,
+     ):
+         betas = self.betas.to(unet.device)
+         alphas = self.alphas.to(unet.device)
+         alphas_cumprod = self.alphas_cumprod.to(unet.device)
+         timesteps = self.timesteps.to(unet.device)
+         images = torch.randn((batch_size, in_channels, sample_size, sample_size), device=unet.device)
+         for timestep in tqdm(timesteps, desc='Sampling'):
+             pred_noise: torch.Tensor = unet(images, timestep).sample
+
+             # mean of q(x_{t-1}|x_t)
+             alpha_t = alphas[timestep]
+             alpha_cumprod_t = alphas_cumprod[timestep]
+             sqrt_alpha_t = alpha_t ** 0.5
+             one_minus_alpha_t = 1.0 - alpha_t
+             sqrt_one_minus_alpha_cumprod_t = (1 - alpha_cumprod_t) ** 0.5
+             mean = (images - one_minus_alpha_t / sqrt_one_minus_alpha_cumprod_t * pred_noise) / sqrt_alpha_t
+
+             # variance of q(x_{t-1}|x_t)
+             if timestep > 0:
+                 beta_t = betas[timestep]
+                 one_minus_alpha_cumprod_t_minus_one = 1.0 - alphas_cumprod[timestep - 1]
+                 one_divided_by_sigma_square = alpha_t / beta_t + 1.0 / one_minus_alpha_cumprod_t_minus_one
+                 variance = (1.0 / one_divided_by_sigma_square) ** 0.5
+             else:
+                 variance = torch.zeros_like(timestep)
+
+             epsilon = torch.randn_like(images)
+             images = mean + variance * epsilon
+         images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
+         return images
+
+ model = UNet2DModel.from_pretrained('ddpm-anime-faces-64').cuda()
+ ddpm = DDPM()
+ images = ddpm.sample(model, 32, 3, 64)
+
+ from diffusers.utils import make_image_grid, numpy_to_pil
+ image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=8)
+ image_grid.save('ddpm-sample-results.png')
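For reference, each iteration of the loop above samples $x_{t-1} \sim p_\theta(x_{t-1} \mid x_t)$ with the mean and scale computed as `mean` and `variance` (despite its name, the `variance` variable holds the standard deviation $\sigma_t$):

$$
\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right), \qquad
\frac{1}{\sigma_t^2} = \frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar\alpha_{t-1}}
\quad\Longleftrightarrow\quad
\sigma_t^2 = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t,
$$

$$
x_{t-1} = \mu_\theta(x_t, t) + \sigma_t z, \quad z \sim \mathcal{N}(0, \mathbf{I}), \qquad \sigma_0 = 0.
$$

The right-hand form of $\sigma_t^2$ is the posterior variance $\tilde\beta_t$ from the DDPM paper, so the reciprocal-sum expression used in the code is equivalent to it; the $\sigma_0 = 0$ case corresponds to the `timestep > 0` branch, which adds no noise at the final step.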