import torch
import torch.nn as nn

from .skeleton_DME import SkeletonConv, SkeletonPool, SkeletonUnpool
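
# Building blocks for a residual, skeleton-aware convolutional encoder/decoder:
# plain 1D residual blocks (ResidualBlock / ResidualBlockTranspose) and their
# skeleton-aware counterparts (SkeletonResidual / SkeletonResidualTranspose),
# which convolve per joint with SkeletonConv and merge/split joints with
# SkeletonPool / SkeletonUnpool.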

def calc_node_depth(topology):
    # Depth of every joint in the kinematic tree, where topology[i] is the
    # parent index of joint i and a negative value marks the root.
    def dfs(node, topology):
        if topology[node] < 0:
            return 0
        return 1 + dfs(topology[node], topology)

    depth = []
    for i in range(len(topology)):
        depth.append(dfs(i, topology))
    return depth


def residual_ratio(k):
    # Blend weight for a residual branch; the blocks below pair it with
    # 1 - residual_ratio for the shortcut branch.
    return 1 / (k + 1)

class Affine(nn.Module):
    # Learnable per-channel scale and bias, broadcast along dim 1 of the input.
    def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0):
        super(Affine, self).__init__()
        if scale:
            self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init)
        else:
            self.register_parameter("scale", None)

        if bias:
            self.bias = nn.Parameter(torch.zeros(num_parameters))
        else:
            self.register_parameter("bias", None)

    def forward(self, input):
        output = input
        if self.scale is not None:
            scale = self.scale.unsqueeze(0)
            while scale.dim() < input.dim():
                scale = scale.unsqueeze(2)
            output = output.mul(scale)

        if self.bias is not None:
            bias = self.bias.unsqueeze(0)
            while bias.dim() < input.dim():
                bias = bias.unsqueeze(2)
            # Out-of-place add so the caller's tensor is never modified in place.
            output = output + bias

        return output

class BatchStatistics(nn.Module):
    # Pass-through module that records a regularisation loss pushing the
    # activation statistics towards zero mean and unit variance, optionally
    # followed by a learnable Affine transform.
    def __init__(self, affine=-1):
        super(BatchStatistics, self).__init__()
        self.affine = nn.Sequential() if affine == -1 else Affine(affine)
        self.loss = 0

    def clear_loss(self):
        self.loss = 0

    def compute_loss(self, input):
        input_flat = input.view(input.size(1), input.numel() // input.size(1))
        mu = input_flat.mean(1)
        # note: `logvar` here is the log of the standard deviation
        logvar = (input_flat.pow(2).mean(1) - mu.pow(2)).sqrt().log()
        self.loss = mu.pow(2).mean() + logvar.pow(2).mean()

    def forward(self, input):
        self.compute_loss(input)
        return self.affine(input)
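
# One way to use the statistics loss during training is to sum it over all
# BatchStatistics submodules of a network after a forward pass, e.g.
# total = collect_batch_statistics_loss(model). The helper below is an
# illustrative sketch and is not required by the blocks themselves.
def collect_batch_statistics_loss(model):
    loss = 0
    for module in model.modules():
        if isinstance(module, BatchStatistics):
            loss = loss + module.loss
            module.clear_loss()
    return loss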

class ResidualBlock(nn.Module):
    # 1D convolutional residual block: a conv branch and a 1x1 shortcut,
    # blended with weights residual_ratio and 1 - residual_ratio.
    def __init__(
        self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation, batch_statistics=False, last_layer=False
    ):
        super(ResidualBlock, self).__init__()

        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        residual = []
        residual.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding))
        if batch_statistics:
            residual.append(BatchStatistics(out_channels))
        if not last_layer:
            residual.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        self.residual = nn.Sequential(*residual)

        self.shortcut = nn.Sequential(
            nn.AvgPool1d(kernel_size=2) if stride == 2 else nn.Sequential(),
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            BatchStatistics(out_channels) if (in_channels != out_channels and batch_statistics is True) else nn.Sequential(),
        )

    def forward(self, input):
        return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio)

class ResidualBlockTranspose(nn.Module):
    # Upsampling counterpart of ResidualBlock: a transposed convolution in the
    # residual branch and linear upsampling plus a 1x1 conv in the shortcut.
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation):
        super(ResidualBlockTranspose, self).__init__()

        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        self.residual = nn.Sequential(
            nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding), nn.PReLU() if activation == "relu" else nn.Tanh()
        )

        self.shortcut = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="linear", align_corners=False) if stride == 2 else nn.Sequential(),
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, input):
        return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio)

class SkeletonResidual(nn.Module):
    # Downsampling skeleton-aware residual block: per-joint convolutions halve
    # the temporal length and double the per-joint channels, then SkeletonPool
    # merges joints according to the topology.
    def __init__(
        self,
        topology,
        neighbour_list,
        joint_num,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        padding_mode,
        bias,
        extra_conv,
        pooling_mode,
        activation,
        last_pool,
    ):
        super(SkeletonResidual, self).__init__()

        kernel_even = kernel_size % 2 == 0

        seq = []
        for _ in range(extra_conv):
            # (T, J, D) => (T, J, D)
            seq.append(
                SkeletonConv(
                    neighbour_list,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    joint_num=joint_num,
                    kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                    stride=1,
                    padding=padding,
                    padding_mode=padding_mode,
                    bias=bias,
                )
            )
            seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        # (T, J, D) => (T/2, J, 2D)
        seq.append(
            SkeletonConv(
                neighbour_list,
                in_channels=in_channels,
                out_channels=out_channels,
                joint_num=joint_num,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                padding_mode=padding_mode,
                bias=bias,
                add_offset=False,
            )
        )
        seq.append(nn.GroupNorm(10, out_channels))  # FIXME: REMEMBER TO CHANGE BACK !!!
        self.residual = nn.Sequential(*seq)

        # (T, J, D) => (T/2, J, 2D)
        self.shortcut = SkeletonConv(
            neighbour_list,
            in_channels=in_channels,
            out_channels=out_channels,
            joint_num=joint_num,
            kernel_size=1,
            stride=stride,
            padding=0,
            bias=True,
            add_offset=False,
        )

        seq = []
        # (T/2, J, 2D) => (T/2, J', 2D)
        pool = SkeletonPool(
            edges=topology, pooling_mode=pooling_mode, channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool
        )
        # Only insert the pooling layer if it actually merges joints.
        if len(pool.pooling_list) != pool.edge_num:
            seq.append(pool)
        seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        self.common = nn.Sequential(*seq)

    def forward(self, input):
        output = self.residual(input) + self.shortcut(input)

        return self.common(output)

class SkeletonResidualTranspose(nn.Module):
    # Upsampling skeleton-aware residual block: optional temporal upsampling,
    # SkeletonUnpool to restore joints, then per-joint convolutions that halve
    # the per-joint channels.
    def __init__(
        self,
        neighbour_list,
        joint_num,
        in_channels,
        out_channels,
        kernel_size,
        padding,
        padding_mode,
        bias,
        extra_conv,
        pooling_list,
        upsampling,
        activation,
        last_layer,
    ):
        super(SkeletonResidualTranspose, self).__init__()

        kernel_even = kernel_size % 2 == 0

        seq = []
        # (T, J, D) => (2T, J, D)
        if upsampling is not None:
            seq.append(nn.Upsample(scale_factor=2, mode=upsampling, align_corners=False))
        # (2T, J, D) => (2T, J', D)
        unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list))
        # Only insert the unpooling layer if it actually adds joints back.
        if unpool.input_edge_num != unpool.output_edge_num:
            seq.append(unpool)
        self.common = nn.Sequential(*seq)

        seq = []
        for _ in range(extra_conv):
            # (2T, J', D) => (2T, J', D)
            seq.append(
                SkeletonConv(
                    neighbour_list,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    joint_num=joint_num,
                    kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                    stride=1,
                    padding=padding,
                    padding_mode=padding_mode,
                    bias=bias,
                )
            )
            seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        # (2T, J', D) => (2T, J', D/2)
        seq.append(
            SkeletonConv(
                neighbour_list,
                in_channels=in_channels,
                out_channels=out_channels,
                joint_num=joint_num,
                kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                stride=1,
                padding=padding,
                padding_mode=padding_mode,
                bias=bias,
                add_offset=False,
            )
        )
        self.residual = nn.Sequential(*seq)

        # (2T, J', D) => (2T, J', D/2)
        self.shortcut = SkeletonConv(
            neighbour_list,
            in_channels=in_channels,
            out_channels=out_channels,
            joint_num=joint_num,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            add_offset=False,
        )

        # No activation after the final (output) layer of the decoder.
        if last_layer:
            self.activation = None
        else:
            self.activation = nn.PReLU() if activation == "relu" else nn.Tanh()

    def forward(self, input):
        output = self.common(input)
        output = self.residual(output) + self.shortcut(output)

        if self.activation is not None:
            return self.activation(output)
        else:
            return output
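

# Minimal shape-check sketch for the generic 1D blocks (the skeleton-aware
# blocks additionally need a topology and neighbour list from skeleton_DME).
# The shapes below are illustrative assumptions, not values used by any model;
# because of the relative import above, run this via `python -m ...` from the
# package root rather than as a standalone script.
if __name__ == "__main__":
    x = torch.randn(4, 8, 64)  # (batch, channels, frames)
    down = ResidualBlock(8, 16, kernel_size=3, stride=2, padding=1,
                         residual_ratio=residual_ratio(1), activation="relu")
    up = ResidualBlockTranspose(16, 8, kernel_size=4, stride=2, padding=1,
                                residual_ratio=residual_ratio(1), activation="relu")
    y = down(x)                      # halves the temporal length: (4, 16, 32)
    print(y.shape, up(y).shape)      # (4, 16, 32) and (4, 8, 64)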