Clemylia commited on
Commit
9c1c93c
·
verified ·
1 Parent(s): 89ffd2a

Upload vae_model_architecture.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vae_model_architecture.py +35 -0
vae_model_architecture.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # NOTE: LATENT_DIM doit être le même que celui utilisé pour l'entraînement (128)
6
+ class VAE(nn.Module):
7
+ def __init__(self, latent_dim=128):
8
+ super(VAE, self).__init__()
9
+ self.latent_dim = latent_dim
10
+
11
+ # ENCODEUR
12
+ self.encoder = nn.Sequential(
13
+ nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
14
+ nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(),
15
+ nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.ReLU(),
16
+ nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), nn.ReLU(),
17
+ nn.Flatten()
18
+ )
19
+ self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
20
+ self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)
21
+
22
+ # DÉCODEUR
23
+ self.decoder_input = nn.Linear(latent_dim, 256 * 4 * 4)
24
+ self.decoder = nn.Sequential(
25
+ nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), nn.ReLU(),
26
+ nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(),
27
+ nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
28
+ nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
29
+ nn.Tanh()
30
+ )
31
+
32
+ def decode(self, z):
33
+ h = self.decoder_input(z)
34
+ h = h.view(-1, 256, 4, 4)
35
+ return self.decoder(h)