import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

#################################################################################
#                                      AE                                       #
#################################################################################
class AE(nn.Module):
    def __init__(self, input_width=3, output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1)):
        super().__init__()
        self.output_emb_width = output_emb_width
        self.encoder = Encoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[:-1], ch_mult=ch_mult[1:])
        self.decoder = Decoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[::-1][1:], ch_mult=ch_mult[::-1][:-1])

    def preprocess(self, x):
        # Channels-last (B, H, W, C) input -> channels-first (B, C, H, W) for the conv stacks.
        x = x.permute(0, 3, 1, 2).float()
        return x

    def encode(self, x):
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        return x_encoder

    def forward(self, x):
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        x_out = self.decoder(x_encoder)
        return x_out

    def decode(self, x):
        x_out = self.decoder(x)
        return x_out

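# Illustrative smoke test (not from the original source): round-trip a dummy
# channels-last tensor through the default AE. The shapes assume the default
# ch_mult=(1,1,1), i.e. two stride-2 stages that downsample the time axis by 4.
def _demo_ae_roundtrip():
    model = AE().eval()
    x = torch.randn(2, 16, 1, 3)       # (batch, time, 1, channels)
    z = model.encode(x)                # (2, 4, 4, 1): channels-first latent
    x_hat = model(x)                   # (2, 16, 1, 3): channels-last again
    assert x_hat.shape == x.shape
    return z, x_hat
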
#################################################################################
#                                      VAE                                      #
#################################################################################
class VAE(nn.Module):
    def __init__(self, input_width=3, output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1)):
        super().__init__()
        self.output_emb_width = output_emb_width
        # The encoder emits 2 * output_emb_width channels: mean and log-variance.
        self.encoder = Encoder(input_width, output_emb_width*2, width, depth, in_ch_mult=ch_mult[:-1], ch_mult=ch_mult[1:])
        self.decoder = Decoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[::-1][1:], ch_mult=ch_mult[::-1][:-1])

    def preprocess(self, x):
        # Channels-last (B, H, W, C) input -> channels-first (B, C, H, W) for the conv stacks.
        x = x.permute(0, 3, 1, 2).float()
        return x

    def encode(self, x):
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        x_encoder = DiagonalGaussianDistribution(x_encoder)
        x_encoder = x_encoder.sample()
        return x_encoder

    def forward(self, x, need_loss=False):
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        x_encoder = DiagonalGaussianDistribution(x_encoder)
        kl_loss = x_encoder.kl()
        x_encoder = x_encoder.sample()
        x_out = self.decoder(x_encoder)
        if need_loss:
            # sigma-VAE reconstruction term for better quality: fit a global
            # sigma to the batch RMSE, soft-clamped below at exp(-6).
            log_sigma = ((x - x_out) ** 2).mean([1,2,3], keepdim=True).sqrt().log()
            log_sigma = -6 + F.softplus(log_sigma - (-6))
            rec = 0.5 * torch.pow((x - x_out) / log_sigma.exp(), 2) + log_sigma
            rec = rec.sum(dim=(1,2,3))
            loss = {
                "rec": rec.mean(),
                "kl": kl_loss.mean()}
            return x_out, loss
        else:
            return x_out

    def decode(self, x):
        x_out = self.decoder(x)
        return x_out

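# Illustrative training step (not from the original source): the optimizer
# choice and the KL weight beta=1e-4 are assumptions, not values the source fixes.
def _demo_vae_step():
    model = VAE()
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    x = torch.randn(2, 16, 1, 3)                # dummy channels-last batch
    x_hat, loss = model(x, need_loss=True)
    total = loss["rec"] + 1e-4 * loss["kl"]     # beta-weighted ELBO (beta assumed)
    opt.zero_grad()
    total.backward()
    opt.step()
    return total.item()
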
#################################################################################
#                                    AE Zoos                                    #
#################################################################################
def ae(**kwargs):
    return AE(output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1), **kwargs)

def vae(**kwargs):
    return VAE(output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1), **kwargs)

AE_models = {
    'AE_Model': ae, 'VAE_Model': vae
}

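# Usage sketch (not from the original source): the registry constructors accept
# only parameters not already fixed above, e.g. input_width.
def _demo_registry():
    return AE_models['VAE_Model'](input_width=3)
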
#################################################################################
#                              Inner Architectures                              #
#################################################################################
class Encoder(nn.Module):
    def __init__(self, input_emb_width=3, output_emb_width=4, width=512, depth=3, in_ch_mult=(1,1), ch_mult=(1,1)):
        super().__init__()
        self.model = nn.ModuleList()
        # All kernels are (k, 1): convolutions act only along the height (time)
        # axis, and the explicit left-only padding keeps them causal.
        self.conv_in = nn.Conv2d(input_emb_width, width, (3, 1), (1, 1), (0, 0))
        block_in = width * in_ch_mult[0]
        for i in range(len(in_ch_mult)):
            block_in = width * in_ch_mult[i]
            block_out = width * ch_mult[i]
            # Causal stride-2 conv halves the time axis once per stage.
            self.model.append(CausalPad2d((0, 0, 2, 0)))
            self.model.append(nn.Conv2d(width, width, (4, 1), (2, 1), (0, 0)))
            for j in range(depth):
                # Dilation shrinks across the stack (3**2, 3**1, 3**0 for
                # depth=3); the 2-j schedule assumes depth <= 3.
                self.model.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dil=2-j))
                block_in = block_out
        self.conv_out = torch.nn.Conv2d(block_in, output_emb_width, (3, 1), (1, 1), (0, 0))

    def forward(self, x):
        x = F.pad(x, (0, 0, 2, 0))  # left-pad time so conv_in stays causal
        x = self.conv_in(x)
        for layer in self.model:
            x = layer(x)
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv_out(x)
        return x

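# Shape check (not from the original source; width=64 chosen only to keep the
# sketch light): the encoder halves the time axis once per in_ch_mult entry,
# so two stages downsample time by 4.
def _demo_encoder_shapes():
    enc = Encoder(3, 4, width=64, depth=3, in_ch_mult=(1, 1), ch_mult=(1, 1))
    z = enc(torch.randn(2, 3, 16, 1))  # channels-first, time axis = 16
    assert z.shape == (2, 4, 4, 1)     # 4 latent channels, time 16 -> 4
    return z
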
class Decoder(nn.Module):
    def __init__(self, input_emb_width=3, output_emb_width=4, width=512, depth=3, in_ch_mult=(1,1), ch_mult=(1,1)):
        super().__init__()
        self.model = nn.ModuleList()
        block_in = width * ch_mult[0]
        self.conv_in = nn.Conv2d(output_emb_width, block_in, (3, 1), (1, 1), (0, 0))
        for i in range(len(in_ch_mult)):
            block_in = width * ch_mult[i]
            block_out = width * in_ch_mult[i]
            for j in range(depth):
                self.model.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dil=2-j))
                block_in = block_out
            # Mirror of the encoder: one 2x upsample of the time axis per stage.
            self.model.append(Upsample(block_in))
        self.conv_out1 = torch.nn.Conv2d(block_in, block_in, (3, 1), (1, 1), (0, 0))
        self.conv_out2 = torch.nn.Conv2d(block_in, input_emb_width, (3, 1), (1, 1), (0, 0))

    def forward(self, x):
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv_in(x)
        for layer in self.model:
            x = layer(x)
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv_out1(x)
        x = x * torch.sigmoid(x)  # SiLU / swish activation
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv_out2(x)
        # Back to channels-last: (B, C, H, W) -> (B, H, W, C).
        return x.permute(0, 2, 3, 1)

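# Shape check (not from the original source; width=64 is again illustrative):
# the decoder upsamples time by 2 per stage and returns channels-last output.
def _demo_decoder_shapes():
    dec = Decoder(3, 4, width=64, depth=3, in_ch_mult=(1, 1), ch_mult=(1, 1))
    x = dec(torch.randn(2, 4, 4, 1))   # channels-first latent
    assert x.shape == (2, 16, 1, 3)    # (batch, time, 1, channels)
    return x
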
class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, (3, 1), (1, 1), (0, 0))

    def forward(self, x):
        # Nearest-neighbour 2x upsample along time, then a causal conv to smooth.
        x = torch.nn.functional.interpolate(x, scale_factor=(2.0, 1.0), mode="nearest")
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv(x)
        return x

class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, dil=0, conv_shortcut=False, dropout=0.2):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        # Left-only padding of 2 * 3**dil exactly absorbs the dilated (3, 1)
        # kernel below, so the block is causal and length-preserving.
        self.padd = CausalPad2d((0, 0, 2 * (3 ** dil), 0))
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=(3, 1),
                                     stride=(1, 1),
                                     padding=(0, 0),
                                     dilation=(3 ** dil, 1),
                                     )
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=(1, 1),
                                     stride=(1, 1),
                                     padding=(0, 0),
                                     )

    def forward(self, x):
        h = x
        h = h * torch.sigmoid(h)  # SiLU / swish
        h = self.padd(h)
        h = self.conv1(h)
        h = h * torch.sigmoid(h)
        h = self.conv2(h)
        h = self.dropout(h)
        # Identity residual: assumes in_channels == out_channels.
        return x + h

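# Length-preservation check (not from the original source): with the default
# schedule dil = 2, 1, 0, a depth-3 stack of these blocks reaches
# 2 * (3**2 + 3**1 + 3**0) = 26 past time steps while keeping the time length.
def _demo_resnet_block():
    block = ResnetBlock(in_channels=8, dil=2)
    y = block(torch.randn(1, 8, 16, 1))
    assert y.shape == (1, 8, 16, 1)    # causal padding preserves length
    return y
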
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        # The channel dim holds [mean, logvar]; clamp logvar for numerical safety.
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        # Reparameterization trick: z = mean + std * eps, eps ~ N(0, I).
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                # KL against the standard normal prior, summed per sample.
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean

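# Illustrative use (not from the original source): chunk an 8-channel parameter
# map into 4 mean + 4 logvar channels, sample with reparameterization, score KL.
def _demo_gaussian():
    params = torch.randn(2, 8, 4, 1)
    dist = DiagonalGaussianDistribution(params)
    z = dist.sample()                  # (2, 4, 4, 1)
    kl = dist.kl()                     # (2,): per-sample KL vs N(0, I)
    return z, kl
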
class CausalPad2d(nn.Module):
    """Asymmetric zero-padding; pad = (w_left, w_right, h_top, h_bottom).

    Used here with (0, 0, k, 0) to pad only the past side of the time (height)
    axis, which is what makes the convolutions above causal.
    """
    def __init__(self, pad):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        return F.pad(x, self.pad)