Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| class ImageDecoder(nn.Module): | |
| def __init__(self, img_size, input_nc, output_nc, ngf=16, norm_layer=nn.LayerNorm): | |
| super(ImageDecoder, self).__init__() | |
| n_upsampling = int(math.log(img_size, 2)) | |
| ks_list = [3] * (n_upsampling // 3) + [5] * (n_upsampling - n_upsampling // 3) | |
| stride_list = [2] * n_upsampling | |
| decoder = [] | |
| chn_mult = [] | |
| for i in range(n_upsampling): | |
| chn_mult.append(2 ** (n_upsampling - i - 1)) | |
| decoder += [nn.ConvTranspose2d(input_nc, chn_mult[0] * ngf, | |
| kernel_size=ks_list[0], stride=stride_list[0], | |
| padding=ks_list[0] // 2, output_padding=stride_list[0]-1), | |
| norm_layer([chn_mult[0] * ngf, 2, 2]), | |
| nn.ReLU(True)] | |
| for i in range(1, n_upsampling): # add upsampling layers | |
| chn_prev = chn_mult[i - 1] * ngf | |
| chn_next = chn_mult[i] * ngf | |
| decoder += [nn.ConvTranspose2d(chn_prev, chn_next, kernel_size=ks_list[i], stride=stride_list[i], padding=ks_list[i] // 2, output_padding=stride_list[i]-1), | |
| norm_layer([chn_next, 2 ** (i+1) , 2 ** (i+1)]), | |
| nn.ReLU(True)] | |
| decoder += [nn.Conv2d(chn_mult[-1] * ngf, output_nc, kernel_size=7, padding=7 // 2)] | |
| decoder += [nn.Sigmoid()] | |
| self.decode = nn.Sequential(*decoder) | |
| def forward(self, latent_feat, trg_char, trg_img=None): | |
| """Standard forward""" | |
| dec_input = torch.cat((latent_feat, trg_char),-1) | |
| dec_input = dec_input.view(dec_input.size(0), dec_input.size(1), 1, 1) | |
| dec_out = self.decode(dec_input) | |
| output = {} | |
| output['gen_imgs'] = dec_out | |
| if trg_img is not None: | |
| output['img_l1loss'] = F.l1_loss(dec_out, trg_img) | |
| return output | |