# A modified version of "Fully Convolutional Mesh Autoencoder using Efficient Spatially Varying Kernels"
# https://arxiv.org/abs/2006.04325
# and thanks to this more modern implementation as well
# https://github.com/g-fiche/Mesh-VQ-VAE
# https://arxiv.org/abs/2312.08291

import os

import numpy as np
import torch
import torch.nn as nn


#################################################################################
#                                      AE                                       #
#################################################################################


class AE(nn.Module):
    def __init__(self, model, bs=16, num_vertices=6890):
        super().__init__()
        # currently only set up for SMPL-H meshes (6890 vertices)
        self.num_vertices = num_vertices
        self.bs = bs
        self.encoder = Encoder(model)
        self.decoder = Decoder(model)

    def encode(self, x):
        B, L = x.shape[0], x.shape[1]
        x = x.view(B * L, self.num_vertices, 3)
        x_encoder = self.encoder(x)
        return x_encoder

    def forward(self, x):
        B, L = x.shape[0], x.shape[1]
        x = x.view(B * L, self.num_vertices, 3)
        x_encoder = self.encoder(x)
        x_out = self.decoder(x_encoder)
        x_out = x_out.view(B, L, self.num_vertices, 3)
        return x_out

    def decode(self, x):
        T = x.shape[1]
        if x.shape[1] % self.bs != 0:
            x = torch.cat(
                [x, torch.zeros_like(x[:, : self.bs - x.shape[1] % self.bs])], dim=1
            )
        outputs = []
        for i in range(x.shape[0]):
            outputss = []
            for j in range(0, x.shape[1], self.bs):
                chunk = x[i, j : j + self.bs]
                out = self.decoder(chunk)
                outputss.append(out)
            outputs.append(torch.cat(outputss, dim=0)[:T])
        x_out = torch.stack(outputs, dim=0)
        return x_out


#################################################################################
#                                    AE Zoos                                    #
#################################################################################


def ae(**kwargs):
    config_model = {
        "batch": 16,
        "connection_folder": "body_models/ConnectionMatrices/",
        "initial_connection_fn": "body_models/ConnectionMatrices/_pool0.npy",
        "connection_layer_lst": [
            "pool0", "pool1", "pool2", "pool3", "pool4", "pool5", "pool6", "pool7_28",
            "unpool7_28", "unpool6", "unpool5", "unpool4", "unpool3", "unpool2", "unpool1", "unpool0",
        ],
        "channel_lst": [64, 64, 128, 128, 256, 256, 512, 12, 512, 256, 256, 128, 128, 64, 64, 3],
        "weight_num_lst": [9, 0, 9, 0, 9, 0, 9, 0, 0, 9, 0, 9, 0, 9, 0, 9],
        "residual_rate_lst": [0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0],
    }
    return AE(FullyConvAE(config_model, **kwargs), bs=config_model["batch"])


AE_models = {"AE_Model": ae}


class Encoder(nn.Module):
    def __init__(self, model):
        super(Encoder, self).__init__()
        self.model = model

    def forward(self, x):
        out = self.model.forward_till_layer_n(x, len(self.model.channel_lst) // 2)
        return out


class Decoder(nn.Module):
    def __init__(self, model):
        super(Decoder, self).__init__()
        self.model = model

    def forward(self, x):
        out = self.model.forward_from_layer_n(x, len(self.model.channel_lst) // 2)
        return out

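
# The FullyConvAE below builds each convolution from a small basis of kernels that is
# blended differently at every output vertex ("spatially varying kernels"). The helper
# below is only an illustrative, self-contained sketch of that core operation on random
# data; it is not used by the model, and the name _spatially_varying_conv_demo is not
# part of the original code.
def _spatially_varying_conv_demo():
    batch, out_pn, max_nb = 2, 5, 4    # batch size, output points, max neighbors per point
    in_c, out_c, weight_num = 3, 8, 9  # channel sizes and number of basis kernels

    # shared kernel basis and per-point, per-neighbor blending coefficients
    basis = torch.randn(weight_num, out_c * in_c)
    w_weights = torch.randn(out_pn, max_nb, weight_num)

    # blend the basis into one dense kernel per (output point, neighbor) pair
    kernels = torch.einsum("pmw,wc->pmc", w_weights, basis)
    kernels = kernels.view(out_pn, max_nb, out_c, in_c)

    # gathered neighbor features; apply the per-neighbor kernels and sum over the neighborhood
    in_neighbors = torch.randn(batch, out_pn, max_nb, in_c)
    out = torch.einsum("pmoi,bpmi->bpmo", kernels, in_neighbors).sum(2)
    return out  # batch*out_pn*out_c
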
class FullyConvAE(nn.Module):
    def __init__(
        self, config_model=None, test_mode=False
    ):  # layer_info_lst = [(point_num, feature_dim)]
        super(FullyConvAE, self).__init__()
        self.test_mode = test_mode

        self.channel_lst = config_model["channel_lst"]
        self.residual_rate_lst = config_model["residual_rate_lst"]
        self.weight_num_lst = config_model["weight_num_lst"]
        self.initial_connection_fn = config_model["initial_connection_fn"]

        data = np.load(self.initial_connection_fn)  # point_num*(1+2*neighbor_num)

        neighbor_id_dist_lstlst = data[:, 1:]  # point_num*(2*neighbor_num)
        self.point_num = data.shape[0]
        self.neighbor_id_lstlst = neighbor_id_dist_lstlst.reshape(
            (self.point_num, -1, 2)
        )[:, :, 0]  # point_num*neighbor_num
        self.neighbor_num_lst = np.array(data[:, 0])  # point_num

        self.relu = nn.ELU()

        self.batch = config_model["batch"]

        ##### For Laplace computation ######
        self.initial_neighbor_id_lstlst = torch.LongTensor(
            self.neighbor_id_lstlst
        ).cuda()  # point_num*max_neighbor_num
        self.initial_neighbor_num_lst = torch.FloatTensor(
            self.neighbor_num_lst
        ).cuda()  # point_num

        self.connection_folder = config_model["connection_folder"]
        self.connection_layer_fn_lst = []
        fn_lst = os.listdir(self.connection_folder)
        self.connection_layer_lst = config_model["connection_layer_lst"]
        for layer_name in self.connection_layer_lst:
            layer_name = "_" + layer_name + "."
            find_fn = False
            for fn in fn_lst:
                if (layer_name in fn) and ((".npy" in fn) or (".npz" in fn)):
                    self.connection_layer_fn_lst += [self.connection_folder + fn]
                    find_fn = True
                    break
            if find_fn == False:
                print("!!!ERROR: cannot find the connection layer fn", layer_name)

        self.init_layers(self.batch)

        self.initial_max_neighbor_num = self.initial_neighbor_id_lstlst.shape[1]

    def init_layers(self, batch):
        ## each entry of layer_lst: (in_channel, out_channel, in_point_num, out_point_num,
        ## weight_num, max_neighbor_num, neighbor_num_lst, neighbor_id_lstlst, conv_layer,
        ## residual_layer, residual_rate, neighbor_mask_lst, zeros_batch_outpn_outchannel)
        self.layer_lst = []

        self.layer_num = len(self.channel_lst)

        in_point_num = self.point_num
        in_channel = 3

        for l in range(self.layer_num):
            out_channel = self.channel_lst[l]
            weight_num = self.weight_num_lst[l]
            residual_rate = self.residual_rate_lst[l]

            connection_info = np.load(self.connection_layer_fn_lst[l])
            out_point_num = connection_info.shape[0]
            neighbor_num_lst = torch.FloatTensor(
                connection_info[:, 0].astype(float)
            ).cuda()  # out_point_num*1
            neighbor_id_dist_lstlst = connection_info[
                :, 1:
            ]  # out_point_num*(max_neighbor_num*2)
            print(self.connection_layer_fn_lst[l])
            print()
            neighbor_id_lstlst = neighbor_id_dist_lstlst.reshape(
                (out_point_num, -1, 2)
            )[:, :, 0]  # out_point_num*max_neighbor_num
            neighbor_id_lstlst = torch.LongTensor(neighbor_id_lstlst).cuda()
            max_neighbor_num = neighbor_id_lstlst.shape[1]

            avg_neighbor_num = round(neighbor_num_lst.mean().item())

            effective_w_weights_rate = neighbor_num_lst.sum() / float(
                max_neighbor_num * out_point_num
            )
            effective_w_weights_rate = round(effective_w_weights_rate.item(), 3)

            pc_mask = torch.ones(in_point_num + 1).cuda()
            pc_mask[in_point_num] = 0
            neighbor_mask_lst = pc_mask[
                neighbor_id_lstlst
            ].contiguous()  # out_pn*max_neighbor_num, 1 for a real neighbor, 0 for padding

            zeros_batch_outpn_outchannel = torch.zeros(
                (batch, out_point_num, out_channel)
            ).cuda()

            if (residual_rate < 0) or (residual_rate > 1):
                print("Invalid residual rate", residual_rate)

            #### parameters for conv ####
            conv_layer = ""
            if residual_rate < 1:
                weights = torch.randn(weight_num, out_channel * in_channel).cuda()
                weights = nn.Parameter(weights)
                self.register_parameter("weights" + str(l), weights)

                bias = nn.Parameter(torch.zeros(out_channel).cuda())
                self.register_parameter("bias" + str(l), bias)

                w_weights = torch.randn(out_point_num, max_neighbor_num, weight_num) / (
                    avg_neighbor_num * weight_num
                )
                w_weights = nn.Parameter(w_weights.cuda())
                self.register_parameter("w_weights" + str(l), w_weights)

                conv_layer = (weights, bias, w_weights)

            #### parameters for residual ####
            ## a layer with residual_rate == 1 has no convolution branch: it copies the
            ## input when out_point_num == in_point_num and acts as a pooling or
            ## unpooling layer when the point count changes
            residual_layer = ""
            if residual_rate > 0:
                p_neighbors = ""
                weight_res = ""

                if out_point_num != in_point_num:
                    p_neighbors = nn.Parameter(
                        (
                            torch.randn(out_point_num, max_neighbor_num)
                            / avg_neighbor_num
                        ).cuda()
                    )
                    self.register_parameter("p_neighbors" + str(l), p_neighbors)

                if out_channel != in_channel:
                    weight_res = torch.randn(out_channel, in_channel)
                    # self.normalize_weights(weight_res)
                    weight_res = weight_res / out_channel
                    weight_res = nn.Parameter(weight_res.cuda())
                    self.register_parameter("weight_res" + str(l), weight_res)

                residual_layer = (weight_res, p_neighbors)

            ##### put everything together
            layer = (
                in_channel,
                out_channel,
                in_point_num,
                out_point_num,
                weight_num,
                max_neighbor_num,
                neighbor_num_lst,
                neighbor_id_lstlst,
                conv_layer,
                residual_layer,
                residual_rate,
                neighbor_mask_lst,
                zeros_batch_outpn_outchannel,
            )

            self.layer_lst += [layer]

            in_point_num = out_point_num
            in_channel = out_channel
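
    # How the per-layer parameters above are used: in conv_layer = (weights, bias,
    # w_weights), `weights` is a basis of weight_num kernels shared by all output
    # points, and `w_weights` blends that basis per output point and neighbor,
    # which is what makes the kernels spatially varying. In residual_layer =
    # (weight_res, p_neighbors), `weight_res` matches channel counts and
    # `p_neighbors` holds learned pooling / unpooling weights over each output
    # point's neighborhood. The forward functions below mix the two branches as
    # sqrt(1 - residual_rate) * conv + sqrt(residual_rate) * residual.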
    # precompute the parameters so as to accelerate forwarding in testing mode
    def init_test_mode(self):
        for l in range(len(self.layer_lst)):
            layer_info = self.layer_lst[l]
            (
                in_channel,
                out_channel,
                in_pn,
                out_pn,
                weight_num,
                max_neighbor_num,
                neighbor_num_lst,
                neighbor_id_lstlst,
                conv_layer,
                residual_layer,
                residual_rate,
                neighbor_mask_lst,
                zeros_batch_outpn_outchannel,
            ) = layer_info

            if len(conv_layer) != 0:
                (
                    weights,
                    bias,
                    raw_w_weights,
                ) = conv_layer  # weight_num*(out_channel*in_channel), out_point_num*max_neighbor_num*weight_num

                w_weights = raw_w_weights * neighbor_mask_lst.view(
                    out_pn, max_neighbor_num, 1
                ).repeat(1, 1, weight_num)  # out_pn*max_neighbor_num*weight_num

                weights = torch.einsum(
                    "pmw,wc->pmc", [w_weights, weights]
                )  # out_pn*max_neighbor_num*(out_channel*in_channel)

                weights = weights.view(
                    out_pn, max_neighbor_num, out_channel, in_channel
                )

                conv_layer = weights, bias

            #### precompute the normalized pooling weights of the residual layer ####
            if len(residual_layer) != 0:
                (
                    weight_res,
                    p_neighbors_raw,
                ) = residual_layer  # out_channel*in_channel, out_pn*max_neighbor_num

                if in_pn != out_pn:
                    p_neighbors = torch.abs(p_neighbors_raw) * neighbor_mask_lst
                    p_neighbors_sum = p_neighbors.sum(1) + 1e-8  # out_pn
                    p_neighbors = p_neighbors / p_neighbors_sum.view(out_pn, 1).repeat(
                        1, max_neighbor_num
                    )

                    residual_layer = weight_res, p_neighbors

            self.layer_lst[l] = (
                in_channel,
                out_channel,
                in_pn,
                out_pn,
                weight_num,
                max_neighbor_num,
                neighbor_num_lst,
                neighbor_id_lstlst,
                conv_layer,
                residual_layer,
                residual_rate,
                neighbor_mask_lst,
                zeros_batch_outpn_outchannel,
            )
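
    # init_test_mode bakes the current parameter values into plain tensors: the conv
    # kernels become a dense out_pn*max_neighbor_num*out_channel*in_channel tensor and
    # p_neighbors is pre-normalized, so the test-time forward below only needs a single
    # einsum per layer. Call it only after the trained weights have been loaded, since
    # later parameter updates are not reflected in the precomputed tensors.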
    # a faster mode for testing
    # input_pc  batch*in_pn*in_channel
    # out_pc    batch*out_pn*out_channel
    def forward_one_conv_layer_batch_during_test(
        self, in_pc, layer_info, is_final_layer=False
    ):
        batch = in_pc.shape[0]
        (
            in_channel,
            out_channel,
            in_pn,
            out_pn,
            weight_num,
            max_neighbor_num,
            neighbor_num_lst,
            neighbor_id_lstlst,
            conv_layer,
            residual_layer,
            residual_rate,
            neighbor_mask_lst,
            zeros_batch_outpn_outchannel,
        ) = layer_info

        device = in_pc.get_device()
        if device < 0:
            device = "cpu"

        in_pc_pad = torch.cat(
            (in_pc, torch.zeros(batch, 1, in_channel).to(device)), 1
        )  # batch*(in_pn+1)*in_channel

        in_neighbors = in_pc_pad[
            :, neighbor_id_lstlst.to(device)
        ]  # batch*out_pn*max_neighbor_num*in_channel

        #### compute output of convolution layer ####
        out_pc_conv = zeros_batch_outpn_outchannel.clone()
        if len(conv_layer) != 0:
            (
                weights,
                bias,
            ) = conv_layer  # precomputed: out_pn*max_neighbor_num*out_channel*in_channel, out_channel

            out_neighbors = torch.einsum(
                "pmoi,bpmi->bpmo", [weights.to(device), in_neighbors]
            )  # batch*out_pn*max_neighbor_num*out_channel

            out_pc_conv = out_neighbors.sum(2)
            out_pc_conv = out_pc_conv + bias

            if is_final_layer == False:
                out_pc_conv = self.relu(
                    out_pc_conv
                )  ## self.relu is defined in the init function

        # if(self.residual_rate==0):
        #     return out_pc

        #### compute output of residual layer ####
        out_pc_res = zeros_batch_outpn_outchannel.clone()
        if len(residual_layer) != 0:
            (
                weight_res,
                p_neighbors,
            ) = residual_layer  # out_channel*in_channel, out_pn*max_neighbor_num

            if in_channel != out_channel:
                in_pc_pad = torch.einsum("oi,bpi->bpo", [weight_res, in_pc_pad])

            if in_pn == out_pn:
                out_pc_res = in_pc_pad[:, 0:in_pn].clone()
            else:
                in_neighbors = in_pc_pad[
                    :, neighbor_id_lstlst.to(device)
                ]  # batch*out_pn*max_neighbor_num*out_channel
                out_pc_res = torch.einsum(
                    "pm,bpmo->bpo", [p_neighbors.to(device), in_neighbors]
                )

        out_pc = out_pc_conv.to(device) * np.sqrt(1 - residual_rate) + out_pc_res.to(
            device
        ) * np.sqrt(residual_rate)

        return out_pc

    # used in train mode; slower than test mode
    # input_pc  batch*in_pn*in_channel
    # out_pc    batch*out_pn*out_channel
    def forward_one_conv_layer_batch(self, in_pc, layer_info, is_final_layer=False):
        batch = in_pc.shape[0]
        (
            in_channel,
            out_channel,
            in_pn,
            out_pn,
            weight_num,
            max_neighbor_num,
            neighbor_num_lst,
            neighbor_id_lstlst,
            conv_layer,
            residual_layer,
            residual_rate,
            neighbor_mask_lst,
            zeros_batch_outpn_outchannel,
        ) = layer_info

        in_pc_pad = torch.cat(
            (in_pc, torch.zeros(batch, 1, in_channel).cuda()), 1
        )  # batch*(in_pn+1)*in_channel

        in_neighbors = in_pc_pad[
            :, neighbor_id_lstlst
        ]  # batch*out_pn*max_neighbor_num*in_channel

        #### compute output of convolution layer ####
        out_pc_conv = zeros_batch_outpn_outchannel.clone()
        if len(conv_layer) != 0:
            (
                weights,
                bias,
                raw_w_weights,
            ) = conv_layer  # weight_num*(out_channel*in_channel), out_point_num*max_neighbor_num*weight_num

            w_weights = raw_w_weights * neighbor_mask_lst.view(
                out_pn, max_neighbor_num, 1
            ).repeat(1, 1, weight_num)  # out_pn*max_neighbor_num*weight_num

            weights = torch.einsum(
                "pmw,wc->pmc", [w_weights, weights]
            )  # out_pn*max_neighbor_num*(out_channel*in_channel)

            weights = weights.view(out_pn, max_neighbor_num, out_channel, in_channel)

            out_neighbors = torch.einsum(
                "pmoi,bpmi->bpmo", [weights, in_neighbors]
            )  # batch*out_pn*max_neighbor_num*out_channel

            out_pc_conv = out_neighbors.sum(2)
            out_pc_conv = out_pc_conv + bias

            if is_final_layer == False:
                out_pc_conv = self.relu(
                    out_pc_conv
                )  ## self.relu is defined in the init function

        #### compute output of residual layer ####
        out_pc_res = zeros_batch_outpn_outchannel.clone()
        if len(residual_layer) != 0:
            (
                weight_res,
                p_neighbors_raw,
            ) = residual_layer  # out_channel*in_channel, out_pn*max_neighbor_num

            if in_channel != out_channel:
                in_pc_pad = torch.einsum("oi,bpi->bpo", [weight_res, in_pc_pad])

            if in_pn == out_pn:
                out_pc_res = in_pc_pad[:, 0:in_pn].clone()
            else:
                in_neighbors = in_pc_pad[
                    :, neighbor_id_lstlst
                ]  # batch*out_pn*max_neighbor_num*out_channel

                p_neighbors = torch.abs(p_neighbors_raw) * neighbor_mask_lst
                p_neighbors_sum = p_neighbors.sum(1) + 1e-8  # out_pn
                p_neighbors = p_neighbors / p_neighbors_sum.view(out_pn, 1).repeat(
                    1, max_neighbor_num
                )

                out_pc_res = torch.einsum("pm,bpmo->bpo", [p_neighbors, in_neighbors])

        # print(out_pc_conv.shape, out_pc_res.shape)
        out_pc = out_pc_conv * np.sqrt(1 - residual_rate) + out_pc_res * np.sqrt(
            residual_rate
        )

        return out_pc

    def forward_till_layer_n(self, in_pc, layer_n):
        out_pc = in_pc.clone()

        for i in range(layer_n):
            if self.test_mode == False:
                out_pc = self.forward_one_conv_layer_batch(out_pc, self.layer_lst[i])
            else:
                out_pc = self.forward_one_conv_layer_batch_during_test(
                    out_pc, self.layer_lst[i]
                )

        # out_pc = self.final_linear(out_pc.transpose(1, 2)).transpose(1, 2)  # batch*3*point_num

        return out_pc
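
    # The AE wrapper splits the network at len(channel_lst) // 2: the Encoder runs
    # layers [0, layer_num // 2) through forward_till_layer_n above, and the Decoder
    # runs layers [layer_num // 2, layer_num) through forward_from_layer_n below, so
    # with the default config the latent code is the 12-channel feature map at the
    # coarsest pooling level.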
    def forward_from_layer_n(self, in_pc, layer_n):
        out_pc = in_pc.clone()

        for i in range(layer_n, self.layer_num):
            if i < (self.layer_num - 1):
                if self.test_mode == False:
                    out_pc = self.forward_one_conv_layer_batch(
                        out_pc, self.layer_lst[i]
                    )
                else:
                    out_pc = self.forward_one_conv_layer_batch_during_test(
                        out_pc, self.layer_lst[i]
                    )
            else:
                if self.test_mode == False:
                    out_pc = self.forward_one_conv_layer_batch(
                        out_pc, self.layer_lst[i], is_final_layer=True
                    )
                else:
                    out_pc = self.forward_one_conv_layer_batch_during_test(
                        out_pc, self.layer_lst[i], is_final_layer=True
                    )

        return out_pc

    def forward_layer_n(self, in_pc, layer_n):
        out_pc = in_pc.clone()

        if layer_n < (self.layer_num - 1):
            if self.test_mode == False:
                out_pc = self.forward_one_conv_layer_batch(
                    out_pc, self.layer_lst[layer_n]
                )
            else:
                out_pc = self.forward_one_conv_layer_batch_during_test(
                    out_pc, self.layer_lst[layer_n]
                )
        else:
            if self.test_mode == False:
                out_pc = self.forward_one_conv_layer_batch(
                    out_pc, self.layer_lst[layer_n], is_final_layer=True
                )
            else:
                out_pc = self.forward_one_conv_layer_batch_during_test(
                    out_pc, self.layer_lst[layer_n], is_final_layer=True
                )

        return out_pc
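
# Minimal usage sketch (not part of the original module): it assumes the connection
# matrices referenced in ae() exist under body_models/ConnectionMatrices/ and that a
# CUDA device is available, since the layers allocate their tensors with .cuda().
# B * L must equal the configured batch of 16 because each layer preallocates a zero
# buffer of that batch size.
if __name__ == "__main__":
    model = AE_models["AE_Model"]()
    # one sequence of 16 SMPL-H meshes: B=1, L=16, 6890 vertices, xyz coordinates
    verts = torch.randn(1, 16, 6890, 3).cuda()
    latent = model.encode(verts)  # (B*L, points at the coarsest level, 12)
    recon = model(verts)          # (1, 16, 6890, 3)
    print(latent.shape, recon.shape)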