import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# Per-sample normalization, the usual choice for style-transfer generators.
norm_layer = nn.InstanceNorm2d
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      norm_layer(in_features),
                      nn.ReLU(inplace=True),
                      nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      norm_layer(in_features)]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)
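# Illustrative shape check: a ResidualBlock preserves both channel count and
# spatial size, since each 3x3 convolution is preceded by ReflectionPad2d(1).
#   >>> block = ResidualBlock(64)
#   >>> block(torch.zeros(1, 64, 32, 32)).shape
#   torch.Size([1, 64, 32, 32])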
class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()

        # Initial convolution block
        model0 = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, 64, 7),
                  norm_layer(64),
                  nn.ReLU(inplace=True)]
        self.model0 = nn.Sequential(*model0)

        # Downsampling
        model1 = []
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                       norm_layer(out_features),
                       nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features * 2
        self.model1 = nn.Sequential(*model1)

        # Residual blocks
        model2 = []
        for _ in range(n_residual_blocks):
            model2 += [ResidualBlock(in_features)]
        self.model2 = nn.Sequential(*model2)

        # Upsampling
        model3 = []
        out_features = in_features // 2
        for _ in range(2):
            model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                       norm_layer(out_features),
                       nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2
        self.model3 = nn.Sequential(*model3)

        # Output layer
        model4 = [nn.ReflectionPad2d(3),
                  nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]
        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)
        return out
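# Usage sketch (illustrative; a 3-channel image to a 1-channel output is one
# plausible configuration). The two stride-2 convolutions and the two
# ConvTranspose2d layers cancel out, so the output resolution matches the input:
#   >>> net = Generator(input_nc=3, output_nc=1)
#   >>> net(torch.zeros(1, 3, 256, 256)).shape
#   torch.Size([1, 1, 256, 256])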
# Define a resnet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out
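# Illustrative: all three padding_type options preserve shape; 'zero' simply
# relies on Conv2d's own padding=1 instead of an explicit pad layer.
#   >>> blk = ResnetBlock(128, padding_type='reflect', norm_layer=nn.BatchNorm2d)
#   >>> blk(torch.zeros(1, 128, 16, 16)).shape
#   torch.Size([1, 128, 16, 16])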
class GlobalGenerator2(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', use_sig=False, n_UPsampling=0):
        assert(n_blocks >= 0)
        super(GlobalGenerator2, self).__init__()

        activation = nn.ReLU(True)

        mult = 8
        model = [nn.ReflectionPad2d(4), nn.Conv2d(input_nc, ngf * mult, kernel_size=7, padding=0),
                 norm_layer(ngf * mult), activation]

        ### downsample the channel count (note: these are ConvTranspose2d layers,
        ### so each step halves the channels but doubles the spatial resolution)
        for i in range(n_downsampling):
            model += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, kernel_size=4, stride=2, padding=1),
                      norm_layer(ngf * mult // 2), activation]
            mult = mult // 2

        if n_UPsampling <= 0:
            n_UPsampling = n_downsampling

        ### resnet blocks
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]

        ### upsample
        for i in range(n_UPsampling):
            next_mult = mult // 2
            if next_mult == 0:
                next_mult = 1
                mult = 1
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * next_mult), kernel_size=3, stride=2, padding=1, output_padding=1),
                      norm_layer(int(ngf * next_mult)), activation]
            mult = next_mult

        if use_sig:
            model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Sigmoid()]
        else:
            model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input, cond=None):
        return self.model(input)
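# Worked example (hypothetical configuration): feeding a 17x17 feature map,
# e.g. 768-channel Inception features, through one "down" stage and three up
# stages. Every ConvTranspose2d here doubles the spatial size, so 17x17 grows
# to 304x304 (19 after the padded 7x7 stem, then four doublings: 19 * 16 = 304).
#   >>> net = GlobalGenerator2(768, 3, n_downsampling=1, n_blocks=1, n_UPsampling=3)
#   >>> net(torch.zeros(1, 768, 17, 17)).shape
#   torch.Size([1, 3, 304, 304])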
class InceptionV3(nn.Module):  # avg pool
    """Inception v3 expects (299, 299) sized images for training and has an
    auxiliary output."""

    def __init__(self, num_classes, isTrain, use_aux=True, pretrain=False, freeze=True, every_feat=False):
        super(InceptionV3, self).__init__()
        self.every_feat = every_feat

        self.model_ft = models.inception_v3(pretrained=pretrain)

        # Optionally freeze the first 17 child modules of a pretrained backbone.
        stop = 0
        if freeze and pretrain:
            for child in self.model_ft.children():
                if stop < 17:
                    for param in child.parameters():
                        param.requires_grad = False
                stop += 1

        # Replace both classifier heads to match num_classes.
        num_ftrs = self.model_ft.AuxLogits.fc.in_features  # 768
        self.model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)

        # Handle the primary net
        num_ftrs = self.model_ft.fc.in_features  # 2048
        self.model_ft.fc = nn.Linear(num_ftrs, num_classes)

        self.model_ft.input_size = 299

        self.isTrain = isTrain
        self.use_aux = use_aux

        if self.isTrain:
            self.model_ft.train()
        else:
            self.model_ft.eval()
    def forward(self, x, cond=None, catch_gates=False):
        # N x 3 x 299 x 299
        x = self.model_ft.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.model_ft.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.model_ft.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.model_ft.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.model_ft.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.model_ft.Mixed_5b(x)
        feat1 = x
        # N x 256 x 35 x 35
        x = self.model_ft.Mixed_5c(x)
        feat11 = x
        # N x 288 x 35 x 35
        x = self.model_ft.Mixed_5d(x)
        feat12 = x
        # N x 288 x 35 x 35
        x = self.model_ft.Mixed_6a(x)
        feat2 = x
        # N x 768 x 17 x 17
        x = self.model_ft.Mixed_6b(x)
        feat21 = x
        # N x 768 x 17 x 17
        x = self.model_ft.Mixed_6c(x)
        feat22 = x
        # N x 768 x 17 x 17
        x = self.model_ft.Mixed_6d(x)
        feat23 = x
        # N x 768 x 17 x 17
        x = self.model_ft.Mixed_6e(x)
        feat3 = x
        # N x 768 x 17 x 17
        aux_defined = self.isTrain and self.use_aux
        if aux_defined:
            aux = self.model_ft.AuxLogits(x)
        else:
            aux = None
        # N x 768 x 17 x 17
        x = self.model_ft.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.model_ft.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.model_ft.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 2048 x 1 x 1
        feats = F.dropout(x, training=self.isTrain)
        # N x 2048 x 1 x 1
        x = torch.flatten(feats, 1)
        # N x 2048
        x = self.model_ft.fc(x)
        # N x num_classes
        if self.every_feat:
            # return feat21, feats, x
            return x, feat21

        return x, aux
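

if __name__ == "__main__":
    # Minimal smoke test (illustrative; num_classes=55 is an arbitrary choice,
    # and pretrain=False avoids downloading torchvision weights). In eval mode
    # the auxiliary head is skipped, so aux comes back as None.
    net = InceptionV3(num_classes=55, isTrain=False, pretrain=False)
    with torch.no_grad():
        logits, aux = net(torch.zeros(1, 3, 299, 299))
    print(logits.shape, aux)  # torch.Size([1, 55]) None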