# ACMDM_Motion_Generation/models/AE_2D_Causal.py
# Uploaded by sourxbhh -- commit "Add model directory" (0f34fb9)
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
#################################################################################
# AE #
#################################################################################
class AE(nn.Module):
    """Deterministic causal 2D autoencoder assembled from ``Encoder`` and ``Decoder``.

    Inputs are expected channels-last: ``preprocess`` moves dim 3 to dim 1
    before the encoder's convolutions, and the decoder permutes its output
    back so that channels are the last dimension again.
    """

    def __init__(self, input_width=3, output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1)):
        """Build the encoder/decoder pair.

        Args:
            input_width: channel count of the data (size of input dim 3).
            output_emb_width: channel count of the latent.
            width: base channel width of the hidden layers.
            depth: number of residual blocks per resolution level.
            ch_mult: width multipliers; consecutive pairs define one level.
        """
        super().__init__()
        self.output_emb_width = output_emb_width
        # The decoder consumes the multipliers in the opposite order.
        mirrored = ch_mult[::-1]
        self.encoder = Encoder(
            input_width, output_emb_width, width, depth,
            in_ch_mult=ch_mult[:-1], ch_mult=ch_mult[1:],
        )
        self.decoder = Decoder(
            input_width, output_emb_width, width, depth,
            in_ch_mult=mirrored[1:], ch_mult=mirrored[:-1],
        )

    def preprocess(self, x):
        """Return ``x`` as float with dim 3 moved to dim 1 (channels-first)."""
        return x.permute(0, 3, 1, 2).float()

    def encode(self, x):
        """Map a channels-last input to its latent representation."""
        return self.encoder(self.preprocess(x))

    def forward(self, x):
        """Encode and immediately decode ``x``; returns the reconstruction."""
        return self.decode(self.encode(x))

    def decode(self, x):
        """Map a latent back to data space (channels-last output)."""
        return self.decoder(x)
#################################################################################
# VAE #
#################################################################################
class VAE(nn.Module):
    """Causal 2D variational autoencoder.

    The encoder emits ``2 * output_emb_width`` channels that are interpreted
    by ``DiagonalGaussianDistribution`` as mean and log-variance; the decoder
    reconstructs from a sample of that distribution.
    """

    def __init__(self, input_width=3, output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1)):
        """Build the encoder (mean + log-variance channels) and the decoder.

        Args:
            input_width: channel count of the data (size of input dim 3).
            output_emb_width: channel count of the sampled latent.
            width: base channel width of the hidden layers.
            depth: number of residual blocks per resolution level.
            ch_mult: width multipliers; consecutive pairs define one level.
        """
        super().__init__()
        self.output_emb_width = output_emb_width
        # The decoder consumes the multipliers in the opposite order.
        mirrored = ch_mult[::-1]
        self.encoder = Encoder(
            input_width, output_emb_width * 2, width, depth,
            in_ch_mult=ch_mult[:-1], ch_mult=ch_mult[1:],
        )
        self.decoder = Decoder(
            input_width, output_emb_width, width, depth,
            in_ch_mult=mirrored[1:], ch_mult=mirrored[:-1],
        )

    def preprocess(self, x):
        """Return ``x`` as float with dim 3 moved to dim 1 (channels-first)."""
        return x.permute(0, 3, 1, 2).float()

    def encode(self, x):
        """Return one sample from the posterior predicted for ``x``."""
        moments = self.encoder(self.preprocess(x))
        return DiagonalGaussianDistribution(moments).sample()

    def forward(self, x, need_loss=False):
        """Reconstruct ``x`` from a posterior sample.

        Args:
            x: channels-last input batch.
            need_loss: when true, also return the training losses.

        Returns:
            The reconstruction, or ``(reconstruction, loss)`` where ``loss``
            is a dict with batch-averaged ``"rec"`` and ``"kl"`` terms.
        """
        posterior = DiagonalGaussianDistribution(self.encoder(self.preprocess(x)))
        kl_loss = posterior.kl()
        x_out = self.decoder(posterior.sample())
        if not need_loss:
            return x_out

        # sigma vae for better quality: the Gaussian likelihood uses a
        # per-sample sigma set to the RMS reconstruction error, with its log
        # softly bounded below at -6 via softplus.
        residual = x - x_out
        log_sigma = (residual ** 2).mean([1,2,3], keepdim=True).sqrt().log()
        log_sigma = -6 + F.softplus(log_sigma + 6)
        rec = 0.5 * torch.pow(residual / log_sigma.exp(), 2) + log_sigma
        rec = rec.sum(dim=(1,2,3))
        return x_out, {"rec": rec.mean(), "kl": kl_loss.mean()}

    def decode(self, x):
        """Map a latent back to data space (channels-last output)."""
        return self.decoder(x)
#################################################################################
# AE Zoos #
#################################################################################
def ae(**kwargs):
    """Build the preset ``AE``; ``kwargs`` supply the remaining constructor arguments."""
    preset = dict(output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1))
    return AE(**preset, **kwargs)
def vae(**kwargs):
    """Build the preset ``VAE``; ``kwargs`` supply the remaining constructor arguments."""
    preset = dict(output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1))
    return VAE(**preset, **kwargs)
# Registry mapping a model name to the factory function that builds it.
AE_models = dict(AE_Model=ae, VAE_Model=vae)
#################################################################################
# Inner Architectures #
#################################################################################
class Encoder(nn.Module):
    """Causal convolutional encoder.

    Layout: a stem convolution, then for every entry of ``in_ch_mult`` a
    stride-2 downsampling convolution followed by ``depth`` residual blocks,
    then an output convolution. Every kernel has size 1 along the last axis,
    so convolutions only mix along dim 2, and zero padding is added only at
    the front of dim 2.

    Args:
        input_emb_width: channels of the (channels-first) input.
        output_emb_width: channels of the produced latent.
        width: base channel width.
        depth: residual blocks per level.
        in_ch_mult: per-level multiplier of ``width`` for the channels
            entering the level.
        ch_mult: per-level multiplier of ``width`` for the channels leaving
            the level.
    """

    def __init__(self, input_emb_width=3, output_emb_width=4, width=512, depth=3, in_ch_mult=(1,1), ch_mult=(1,1)):
        super().__init__()
        self.model = nn.ModuleList()
        block_in = width * in_ch_mult[0]
        # The stem must emit the channel count the first level is built for
        # (previously hard-coded to `width`, which only matched when the
        # first multiplier was 1).
        self.conv_in = nn.Conv2d(input_emb_width, block_in, (3, 1), (1, 1), (0, 0))
        for i, in_mult in enumerate(in_ch_mult):
            block_in = width * in_mult
            block_out = width * ch_mult[i]
            # Stride-2 downsampling along dim 2, front-padded by 2. It keeps
            # the level's incoming channel count so the following residual
            # block receives `block_in` channels, mirroring how Decoder sizes
            # its layers from `block_in`.
            self.model.append(CausalPad2d((0, 0, 2, 0)))
            self.model.append(nn.Conv2d(block_in, block_in, (4, 1), (2, 1), (0, 0)))
            for j in range(depth):
                # Dilation exponent decreases with block index: 2, 1, 0, ...
                self.model.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dil=2-j))
                block_in = block_out
        self.conv_out = torch.nn.Conv2d(block_in, output_emb_width, (3, 1), (1, 1), (0, 0))

    def forward(self, x):
        """Encode a channels-first batch; dim 2 is halved once per level."""
        # Two zeros in front of dim 2 keep the kernel-3 stem causal and
        # length-preserving.
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv_in(x)
        for layer in self.model:
            x = layer(x)
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv_out(x)
        return x
class Decoder(nn.Module):
    """Causal convolutional decoder.

    Layout: an input convolution, then for every entry of ``in_ch_mult``
    ``depth`` residual blocks followed by an ``Upsample`` stage, then two
    output convolutions separated by an ``x * sigmoid(x)`` activation. The
    result is permuted so that channels become the last dimension.
    """

    def __init__(self, input_emb_width=3, output_emb_width=4, width=512, depth=3, in_ch_mult=(1,1), ch_mult=(1,1)):
        """Create the layers.

        Args:
            input_emb_width: channels of the reconstructed data.
            output_emb_width: channels of the latent fed to the decoder.
            width: base channel width.
            depth: residual blocks per level.
            in_ch_mult: per-level multiplier of ``width`` for the channels
                leaving the level.
            ch_mult: per-level multiplier of ``width`` for the channels
                entering the level.
        """
        super().__init__()
        self.model = nn.ModuleList()
        block_in = width * ch_mult[0]
        self.conv_in = nn.Conv2d(output_emb_width, block_in, (3,1), (1,1), (0,0))
        for level, out_mult in enumerate(in_ch_mult):
            block_in = width * ch_mult[level]
            block_out = width * out_mult
            # Dilation exponents run 2, 1, 0, ... across the blocks of a level.
            for dil in range(2, 2 - depth, -1):
                self.model.append(
                    ResnetBlock(in_channels=block_in, out_channels=block_out, dil=dil)
                )
                block_in = block_out
            self.model.append(Upsample(block_in))
        self.conv_out1 = torch.nn.Conv2d(block_in, block_in, (3, 1), (1,1), (0,0))
        self.conv_out2 = torch.nn.Conv2d(block_in, input_emb_width, (3, 1), (1, 1), (0, 0))

    @staticmethod
    def _causal_pad(x):
        """Prepend two zero rows along dim 2 (front-only padding for kernel 3)."""
        return F.pad(x, (0, 0, 2, 0))

    def forward(self, x):
        """Decode a channels-first latent into a channels-last output."""
        h = self.conv_in(self._causal_pad(x))
        for stage in self.model:
            h = stage(h)
        h = self.conv_out1(self._causal_pad(h))
        h = h * torch.sigmoid(h)
        h = self.conv_out2(self._causal_pad(h))
        return h.permute(0,2,3,1)
class Upsample(nn.Module):
    """Double the length of dim 2 by nearest-neighbour repetition, then smooth.

    The smoothing convolution has kernel ``(3, 1)`` and is preceded by two
    rows of zero padding at the front of dim 2 only, so the doubled length is
    preserved and each output row depends only on rows at or before it.
    """

    def __init__(self, in_channels):
        """Create the channel-preserving ``(3, 1)`` convolution."""
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=(3, 1),
            stride=(1, 1),
            padding=(0, 0),
        )

    def forward(self, x):
        """Return ``x`` upsampled 2x along dim 2 and passed through the conv."""
        stretched = F.interpolate(x, scale_factor=(2.0, 1.0), mode="nearest")
        return self.conv(F.pad(stretched, (0, 0, 2, 0)))
class ResnetBlock(nn.Module):
    """Residual block with a causal dilated convolution along dim 2.

    Main branch: ``x * sigmoid(x)`` -> front zero padding -> dilated
    ``(3, 1)`` convolution -> ``h * sigmoid(h)`` -> ``(1, 1)`` convolution ->
    dropout. The padding equals ``2 * 3 ** dil`` rows, which exactly
    compensates for the dilated kernel, so the length of dim 2 is preserved.

    When ``in_channels`` equals ``out_channels`` the skip connection is the
    identity. When they differ, the input is projected to ``out_channels``
    so the residual sum is well defined: by a ``(1, 1)`` convolution, or by
    a front-padded ``(3, 1)`` convolution if ``conv_shortcut`` is true.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output; defaults to ``in_channels``.
        dil: non-negative exponent; the dilation along dim 2 is ``3 ** dil``.
        conv_shortcut: choose the ``(3, 1)`` projection instead of the
            ``(1, 1)`` one when the channel counts differ.
        dropout: dropout probability applied at the end of the main branch.

    Raises:
        ValueError: if ``dil`` is negative.
    """

    def __init__(self, *, in_channels, out_channels=None, dil=0, conv_shortcut=False, dropout=0.2):
        super().__init__()
        if dil < 0:
            # 3 ** negative is fractional and would only fail later, inside
            # the padding/convolution calls, with an obscure error.
            raise ValueError(f"dil must be non-negative, got {dil}")
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        dilation = 3 ** dil
        self.padd = CausalPad2d((0, 0, 2 * dilation, 0))
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=(3, 1),
                                     stride=(1, 1),
                                     padding=(0, 0),
                                     dilation=(dilation, 1),
                                     )
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=(1, 1),
                                     stride=(1, 1),
                                     padding=(0, 0),
                                     )
        # Projection for the skip path; created only when needed so blocks
        # with matching channel counts keep their existing parameters.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=(3, 1),
                                                     stride=(1, 1),
                                                     padding=(0, 0),
                                                     )
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=(1, 1),
                                                    stride=(1, 1),
                                                    padding=(0, 0),
                                                    )

    def forward(self, x):
        """Return ``skip(x) + branch(x)`` with the shape of the branch output."""
        h = x
        h = h*torch.sigmoid(h)
        h = self.padd(h)
        h = self.conv1(h)
        h = h*torch.sigmoid(h)
        h = self.conv2(h)
        h = self.dropout(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                # Front-only padding keeps the projection causal as well.
                x = self.conv_shortcut(F.pad(x, (0, 0, 2, 0)))
            else:
                x = self.nin_shortcut(x)
        return x+h
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterised by a tensor of stacked moments.

    ``parameters`` is split into two equal halves along dim 1: the first half
    is the mean, the second the log-variance (clamped to ``[-30, 20]``).

    Args:
        parameters: tensor whose dim 1 holds mean and log-variance channels.
        deterministic: if true, the variance is treated as zero, so
            ``sample()`` returns the mean and ``kl()`` / ``nll()`` return a
            single zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)
        else:
            self.std = torch.exp(0.5 * self.logvar)
            self.var = torch.exp(self.logvar)

    def _zero(self):
        """Return a one-element zero on the parameters' device and dtype."""
        return torch.zeros(1, device=self.mean.device, dtype=self.mean.dtype)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps`` standard normal.

        The noise is drawn exactly as before (default generator, default
        dtype) so seeded runs stay reproducible, then cast to the mean's
        device and dtype so low-precision parameters are not silently
        promoted to float32.
        """
        noise = torch.randn(self.mean.shape).to(device=self.mean.device, dtype=self.mean.dtype)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """Return the KL divergence summed over dims 1-3.

        Args:
            other: distribution to compare against; ``None`` means the
                standard normal.

        Returns:
            One value per batch element, or a single zero when the
            distribution is deterministic.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=(1, 2, 3))
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=(1, 2, 3))

    def nll(self, sample, dims=(1, 2, 3)):
        """Return the negative log-likelihood of ``sample`` summed over ``dims``.

        Returns a single zero when the distribution is deterministic.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=tuple(dims))

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean
class CausalPad2d(nn.Module):
    """Module wrapper around ``F.pad`` with a fixed padding specification.

    ``pad`` uses the ``F.pad`` convention, i.e. pairs of (before, after)
    amounts starting from the last dimension. Callers in this file pass
    ``(0, 0, k, 0)``, which prepends ``k`` zero rows along dim 2 only.
    """

    def __init__(self, pad):
        """Store the padding tuple that ``forward`` hands to ``F.pad``."""
        super().__init__()
        self.pad = pad

    def forward(self, x):
        """Return ``x`` padded according to ``self.pad``."""
        padded = F.pad(x, self.pad)
        return padded