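"""StackGAN-style two-stage conditional GAN for room impulse response (RIR)
generation. COND_NET maps an input embedding to a conditioning code,
STAGE1_G/STAGE1_D generate and score a coarse RIR, and STAGE2_G/STAGE2_D
refine it into a longer, higher-resolution RIR. All convolutions are 1-D;
hyperparameters come from miscc.config.cfg."""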
import torch
import torch.nn as nn

from miscc.config import cfg


def conv3x1(in_planes, out_planes, stride=1):
    """1-D convolution with kernel length 41 and length-preserving padding
    (for stride=1). The '3x1' name is historical; see old_conv3x1 below."""
    kernel_length = 41
    return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_length,
                     stride=stride, padding=20, bias=False)


def old_conv3x1(in_planes, out_planes, stride=1):
    """3x1 convolution with padding (the original kernel size, still used in
    the D_GET_LOGITS head)."""
    kernel_length = 3
    return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_length,
                     stride=stride, padding=1, bias=False)


def upBlock4(in_planes, out_planes):
    """Upsample the temporal dimension by 4x via a transposed convolution."""
    kernel_length = 41
    stride = 4
    block = nn.Sequential(
        nn.ConvTranspose1d(in_planes, out_planes, kernel_size=kernel_length,
                           stride=stride, padding=19, output_padding=1),
        nn.BatchNorm1d(out_planes),
        nn.PReLU())
    return block


def upBlock2(in_planes, out_planes):
    """Upsample the temporal dimension by 2x via a transposed convolution."""
    kernel_length = 41
    stride = 2
    block = nn.Sequential(
        nn.ConvTranspose1d(in_planes, out_planes, kernel_size=kernel_length,
                           stride=stride, padding=20, output_padding=1),
        nn.BatchNorm1d(out_planes),
        nn.PReLU())
    return block
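

# Note on the transposed-convolution arithmetic used by both blocks above:
#   L_out = (L_in - 1) * stride - 2 * padding + kernel_size + output_padding
# upBlock4: (L-1)*4 - 2*19 + 41 + 1 = 4*L; upBlock2: (L-1)*2 - 2*20 + 41 + 1 = 2*L,
# i.e. exact x4 / x2 temporal upsampling.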


def sameBlock(in_planes, out_planes):
    """Change the channel count while keeping the temporal length."""
    block = nn.Sequential(
        conv3x1(in_planes, out_planes),
        nn.BatchNorm1d(out_planes),
        nn.PReLU())
    return block


class ResBlock(nn.Module):
    """Residual block: two length-preserving convolutions plus an identity
    shortcut, followed by a PReLU."""

    def __init__(self, channel_num):
        super(ResBlock, self).__init__()
        self.block = nn.Sequential(
            conv3x1(channel_num, channel_num),
            nn.BatchNorm1d(channel_num),
            nn.PReLU(),
            conv3x1(channel_num, channel_num),
            nn.BatchNorm1d(channel_num))
        self.relu = nn.PReLU()

    def forward(self, x):
        residual = x
        out = self.block(x)
        out += residual
        out = self.relu(out)
        return out


class COND_NET(nn.Module):
    """Conditioning network: a single fully connected layer + PReLU that maps
    the input embedding (cfg.TEXT.DIMENSION) to the conditioning code
    (cfg.GAN.CONDITION_DIM)."""

    def __init__(self):
        super(COND_NET, self).__init__()
        self.t_dim = cfg.TEXT.DIMENSION
        self.c_dim = cfg.GAN.CONDITION_DIM
        self.fc = nn.Linear(self.t_dim, self.c_dim, bias=True)
        self.relu = nn.PReLU()

    def encode(self, text_embedding):
        x = self.relu(self.fc(text_embedding))
        return x

    def forward(self, text_embedding):
        c_code = self.encode(text_embedding)
        return c_code


class D_GET_LOGITS(nn.Module):
    """Discriminator head: maps the encoder's (B, ndf*8, 16) feature map to a
    per-sample real/fake probability, optionally conditioned on c_code."""

    def __init__(self, ndf, nef, bcondition=True):
        super(D_GET_LOGITS, self).__init__()
        self.df_dim = ndf
        self.ef_dim = nef
        self.bcondition = bcondition
        kernel_length = 41
        # Length-preserving (stride=1) channel reduction: ndf*8 -> ndf//2.
        self.convd1d = nn.ConvTranspose1d(ndf * 8, ndf // 2,
                                          kernel_size=kernel_length,
                                          stride=1, padding=20)
        if bcondition:
            self.outlogits = nn.Sequential(
                old_conv3x1(ndf // 2 + nef, ndf // 2),
                nn.BatchNorm1d(ndf // 2),
                nn.LeakyReLU(0.2, inplace=True),
                # kernel_size=16 over a length-16 map yields a single logit.
                nn.Conv1d(ndf // 2, 1, kernel_size=16, stride=4),
                nn.Sigmoid())
        else:
            self.outlogits = nn.Sequential(
                nn.Conv1d(ndf // 2, 1, kernel_size=16, stride=4),
                nn.Sigmoid())

    def forward(self, h_code, c_code=None):
        h_code = self.convd1d(h_code)
        if self.bcondition and c_code is not None:
            # Broadcast the conditioning code across the 16 time steps.
            c_code = c_code.view(-1, self.ef_dim, 1)
            c_code = c_code.repeat(1, 1, 16)
            h_c_code = torch.cat((h_code, c_code), 1)
        else:
            h_c_code = h_code
        output = self.outlogits(h_c_code)
        return output.view(-1)


class STAGE1_G(nn.Module):
    """Stage-I generator: embedding -> conditioning code -> coarse RIR of
    length 16 * 4^3 * 2^2 = 4096 samples."""

    def __init__(self):
        super(STAGE1_G, self).__init__()
        self.gf_dim = cfg.GAN.GF_DIM * 8
        self.ef_dim = cfg.GAN.CONDITION_DIM
        self.define_module()

    def define_module(self):
        kernel_length = 41
        ninput = self.ef_dim
        ngf = self.gf_dim
        self.cond_net = COND_NET()
        # Project the conditioning code to a (ngf, 16) seed feature map.
        self.fc = nn.Sequential(
            nn.Linear(ninput, ngf * 16, bias=False),
            nn.BatchNorm1d(ngf * 16),
            nn.PReLU())
        # Temporal length: 16 -> 64 -> 256 -> 1024 -> 2048 -> 4096.
        self.upsample1 = upBlock4(ngf, ngf // 2)
        self.upsample2 = upBlock4(ngf // 2, ngf // 4)
        self.upsample3 = upBlock4(ngf // 4, ngf // 8)
        self.upsample4 = upBlock2(ngf // 8, ngf // 16)
        self.upsample5 = upBlock2(ngf // 16, ngf // 16)
        self.RIR = nn.Sequential(
            nn.ConvTranspose1d(ngf // 16, 1, kernel_size=kernel_length,
                               stride=1, padding=20),
            nn.Tanh())

    def forward(self, text_embedding):
        c_code = self.cond_net(text_embedding)
        h_code = self.fc(c_code)
        h_code = h_code.view(-1, self.gf_dim, 16)
        h_code = self.upsample1(h_code)
        h_code = self.upsample2(h_code)
        h_code = self.upsample3(h_code)
        h_code = self.upsample4(h_code)
        h_code = self.upsample5(h_code)
        fake_RIR = self.RIR(h_code)
        # The leading None keeps the return signature aligned with STAGE2_G.
        return None, fake_RIR, text_embedding


class STAGE1_D(nn.Module):
    """Stage-I discriminator encoder: four stride-4 convolutions take a
    (B, 1, 4096) RIR down to a (B, ndf*8, 16) feature map; the logits
    come from get_cond_logits."""

    def __init__(self):
        super(STAGE1_D, self).__init__()
        self.df_dim = cfg.GAN.DF_DIM
        self.ef_dim = cfg.GAN.CONDITION_DIM
        self.define_module()

    def define_module(self):
        ndf, nef = self.df_dim, self.ef_dim
        kernel_length = 41
        self.encode_RIR = nn.Sequential(
            nn.Conv1d(1, ndf, kernel_length, 4, 20, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(ndf, ndf * 2, kernel_length, 4, 20, bias=False),
            nn.BatchNorm1d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(ndf * 2, ndf * 4, kernel_length, 4, 20, bias=False),
            nn.BatchNorm1d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(ndf * 4, ndf * 8, kernel_length, 4, 20, bias=False),
            nn.BatchNorm1d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True))
        self.get_cond_logits = D_GET_LOGITS(ndf, nef)
        self.get_uncond_logits = None

    def forward(self, RIRs):
        RIR_embedding = self.encode_RIR(RIRs)
        return RIR_embedding


class STAGE2_G(nn.Module):
    """Stage-II generator: re-encodes the (frozen) Stage-I RIR, fuses it with
    the conditioning code, and upsamples 256 -> 16384 samples."""

    def __init__(self, STAGE1_G):
        super(STAGE2_G, self).__init__()
        self.gf_dim = cfg.GAN.GF_DIM
        self.ef_dim = cfg.GAN.CONDITION_DIM
        self.STAGE1_G = STAGE1_G
        # The Stage-I generator stays fixed while Stage-II trains.
        for param in self.STAGE1_G.parameters():
            param.requires_grad = False
        self.define_module()

    def _make_layer(self, block, channel_num):
        layers = []
        for i in range(cfg.GAN.R_NUM):
            layers.append(block(channel_num))
        return nn.Sequential(*layers)

    def define_module(self):
        ngf = self.gf_dim
        self.cond_net = COND_NET()
        # Downsample the Stage-I RIR: 4096 -> 1024 -> 256.
        self.encoder = nn.Sequential(
            conv3x1(1, ngf),
            nn.ReLU(True),
            nn.Conv1d(ngf, ngf * 2, 16, 4, 6, bias=False),
            nn.BatchNorm1d(ngf * 2),
            nn.ReLU(True),
            nn.Conv1d(ngf * 2, ngf * 4, 16, 4, 6, bias=False),
            nn.BatchNorm1d(ngf * 4),
            nn.ReLU(True))
        self.hr_joint = nn.Sequential(
            conv3x1(self.ef_dim + ngf * 4, ngf * 4),
            nn.BatchNorm1d(ngf * 4),
            nn.ReLU(True))
        self.residual = self._make_layer(ResBlock, ngf * 4)
        # Temporal length: 256 -> 1024 -> 4096 -> 16384.
        self.upsample1 = upBlock4(ngf * 4, ngf * 2)
        self.upsample2 = upBlock4(ngf * 2, ngf)
        self.upsample3 = upBlock4(ngf, ngf // 2)
        self.upsample4 = sameBlock(ngf // 2, ngf // 4)
        self.RIR = nn.Sequential(
            conv3x1(ngf // 4, 1),
            nn.Tanh())

    def forward(self, text_embedding):
        _, stage1_RIR, _ = self.STAGE1_G(text_embedding)
        stage1_RIR = stage1_RIR.detach()
        encoded_RIR = self.encoder(stage1_RIR)
        c_code1 = self.cond_net(text_embedding)
        # Broadcast the conditioning code across the 256 time steps.
        c_code = c_code1.view(-1, self.ef_dim, 1)
        c_code = c_code.repeat(1, 1, 256)
        i_c_code = torch.cat([encoded_RIR, c_code], 1)
        h_code = self.hr_joint(i_c_code)
        h_code = self.residual(h_code)
        h_code = self.upsample1(h_code)
        h_code = self.upsample2(h_code)
        h_code = self.upsample3(h_code)
        h_code = self.upsample4(h_code)
        fake_RIR = self.RIR(h_code)
        return stage1_RIR, fake_RIR, c_code1


class STAGE2_D(nn.Module):
    """Stage-II discriminator encoder: five stride-4 convolutions take a
    (B, 1, 16384) RIR down to length 16, then two length-preserving
    convolutions reduce the channels to ndf*8 for the logits heads."""

    def __init__(self):
        super(STAGE2_D, self).__init__()
        self.df_dim = cfg.GAN.DF_DIM
        self.ef_dim = cfg.GAN.CONDITION_DIM
        self.define_module()

    def define_module(self):
        ndf, nef = self.df_dim, self.ef_dim
        self.encode_RIR = nn.Sequential(
            nn.Conv1d(1, ndf, 3, 1, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(ndf, ndf * 2, 16, 4, 6, bias=False),
            nn.BatchNorm1d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(ndf * 2, ndf * 4, 16, 4, 6, bias=False),
            nn.BatchNorm1d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(ndf * 4, ndf * 8, 16, 4, 6, bias=False),
            nn.BatchNorm1d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(ndf * 8, ndf * 16, 16, 4, 6, bias=False),
            nn.BatchNorm1d(ndf * 16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(ndf * 16, ndf * 32, 16, 4, 6, bias=False),
            nn.BatchNorm1d(ndf * 32),
            nn.LeakyReLU(0.2, inplace=True),
            conv3x1(ndf * 32, ndf * 16),
            nn.BatchNorm1d(ndf * 16),
            nn.LeakyReLU(0.2, inplace=True),
            conv3x1(ndf * 16, ndf * 8),
            nn.BatchNorm1d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True))
        self.get_cond_logits = D_GET_LOGITS(ndf, nef, bcondition=True)
        self.get_uncond_logits = D_GET_LOGITS(ndf, nef, bcondition=False)

    def forward(self, RIRs):
        RIR_embedding = self.encode_RIR(RIRs)
        return RIR_embedding
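

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original training pipeline): a minimal
# forward-pass smoke test. The cfg values below are illustrative assumptions;
# in the repo they are loaded from a YAML config via miscc.config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg.TEXT.DIMENSION = 10        # assumed input embedding size
    cfg.GAN.CONDITION_DIM = 128    # assumed conditioning-code size
    cfg.GAN.GF_DIM = 64            # assumed generator width
    cfg.GAN.DF_DIM = 64            # assumed discriminator width
    cfg.GAN.R_NUM = 2              # assumed number of residual blocks

    stage1_g = STAGE1_G()
    stage2_g = STAGE2_G(stage1_g)
    stage1_d = STAGE1_D()

    embedding = torch.randn(4, cfg.TEXT.DIMENSION)
    _, rir1, _ = stage1_g(embedding)                      # (4, 1, 4096)
    _, rir2, c_code = stage2_g(embedding)                 # (4, 1, 16384)
    features = stage1_d(rir1)                             # (4, DF_DIM * 8, 16)
    logits = stage1_d.get_cond_logits(features, c_code)   # (4,)
    print(rir1.shape, rir2.shape, logits.shape)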