import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.pvtv2 import pvt_v2_b2

class BasicConv2d(nn.Module):
    """Conv2d + BatchNorm block. Note that a ReLU is defined but not
    applied in forward(), so callers receive pre-activation features."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

class CFM(nn.Module):
    """Cascaded fusion module: fuses three pyramid features, from the
    coarsest map x1 down to the finest map x3, via upsampling and
    element-wise products."""

    def __init__(self, channel):
        super(CFM, self).__init__()
        self.relu = nn.ReLU(True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat2 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat3 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv4 = BasicConv2d(3 * channel, channel, 3, padding=1)

    def forward(self, x1, x2, x3):
        # x1 is the coarsest map, x3 the finest; each level is 2x the
        # spatial size of the previous one.
        x1_1 = x1
        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
        x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
               * self.conv_upsample3(self.upsample(x2)) * x3

        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
        x2_2 = self.conv_concat2(x2_2)

        x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
        x3_2 = self.conv_concat3(x3_2)

        x1 = self.conv4(x3_2)
        return x1
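
# A minimal shape sketch for CFM (sizes are illustrative, assuming a
# 352x352 input to the full model; channel=32 matches PolypPVT below):
#
#   cfm = CFM(32)
#   x1 = torch.randn(1, 32, 11, 11)   # coarsest level (1/32 scale)
#   x2 = torch.randn(1, 32, 22, 22)   # 1/16 scale
#   x3 = torch.randn(1, 32, 44, 44)   # 1/8 scale
#   out = cfm(x1, x2, x3)             # -> torch.Size([1, 32, 44, 44])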

class GCN(nn.Module):
    """Single graph-convolution layer over a fully connected graph:
    conv1 mixes information across nodes, conv2 across state channels."""

    def __init__(self, num_state, num_node, bias=False):
        super(GCN, self).__init__()
        self.conv1 = nn.Conv1d(num_node, num_node, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(num_state, num_state, kernel_size=1, bias=bias)

    def forward(self, x):
        # x: (batch, num_state, num_node)
        h = self.conv1(x.permute(0, 2, 1)).permute(0, 2, 1)  # mix across nodes
        h = h - x                                            # Laplacian-style residual
        h = self.relu(self.conv2(h))                         # mix across state channels
        return h
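
# A small sanity sketch for GCN (shapes only; values are arbitrary). The
# input is a (batch, num_state, num_node) tensor, matching how SAM below
# projects features onto num_n graph nodes with num_s state channels:
#
#   gcn = GCN(num_state=16, num_node=16)
#   v = torch.randn(2, 16, 16)
#   print(gcn(v).shape)   # torch.Size([2, 16, 16])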

class SAM(nn.Module):
    """Similarity aggregation module: projects edge-weighted features onto a
    small set of graph nodes, reasons over them with a GCN, and projects the
    result back onto the feature map."""

    def __init__(self, num_in=32, plane_mid=16, mids=4, normalize=False):
        super(SAM, self).__init__()
        self.normalize = normalize
        self.num_s = int(plane_mid)
        self.num_n = mids * mids
        self.priors = nn.AdaptiveAvgPool2d(output_size=(mids + 2, mids + 2))

        self.conv_state = nn.Conv2d(num_in, self.num_s, kernel_size=1)
        self.conv_proj = nn.Conv2d(num_in, self.num_s, kernel_size=1)
        self.gcn = GCN(num_state=self.num_s, num_node=self.num_n)
        self.conv_extend = nn.Conv2d(self.num_s, num_in, kernel_size=1, bias=False)

    def forward(self, x, edge):
        # F.upsample is deprecated; F.interpolate with the default (nearest)
        # mode is the drop-in replacement.
        edge = F.interpolate(edge, (x.size()[-2], x.size()[-1]))
        n, c, h, w = x.size()
        # Soft edge map: softmax over channels, keep channel 1 (edge must
        # therefore have at least two channels).
        edge = F.softmax(edge, dim=1)[:, 1, :, :].unsqueeze(1)

        x_state_reshaped = self.conv_state(x).view(n, self.num_s, -1)
        x_proj = self.conv_proj(x)
        x_mask = x_proj * edge

        # Pool the masked features to (mids+2, mids+2), crop the border, and
        # flatten into num_n anchor nodes.
        x_anchor = self.priors(x_mask)[:, :, 1:-1, 1:-1].reshape(n, self.num_s, -1)

        # Pixel-to-node assignment weights.
        x_proj_reshaped = torch.matmul(x_anchor.permute(0, 2, 1), x_proj.reshape(n, self.num_s, -1))
        x_proj_reshaped = F.softmax(x_proj_reshaped, dim=1)
        x_rproj_reshaped = x_proj_reshaped

        # Project pixels onto graph nodes, reason over the graph, then
        # project the node states back onto pixels.
        x_n_state = torch.matmul(x_state_reshaped, x_proj_reshaped.permute(0, 2, 1))
        if self.normalize:
            x_n_state = x_n_state * (1. / x_state_reshaped.size(2))

        x_n_rel = self.gcn(x_n_state)

        x_state_reshaped = torch.matmul(x_n_rel, x_rproj_reshaped)
        x_state = x_state_reshaped.view(n, self.num_s, *x.size()[2:])
        out = x + self.conv_extend(x_state)
        return out
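
# A minimal usage sketch for SAM (shapes match its use in PolypPVT below;
# edge is resized internally, so its spatial size need not match x):
#
#   sam = SAM(num_in=32)
#   x = torch.randn(1, 32, 44, 44)      # CFM output
#   edge = torch.randn(1, 32, 22, 22)   # downsampled low-level features
#   print(sam(x, edge).shape)           # torch.Size([1, 32, 44, 44])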

class ChannelAttention(nn.Module):
    """Channel attention as in CBAM: a shared MLP over average- and
    max-pooled descriptors, summed and squashed with a sigmoid."""

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Use the ratio argument instead of a hard-coded 16 so that
        # non-default reduction ratios actually take effect.
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    """Spatial attention as in CBAM: a 7x7 (or 3x3) conv over the
    channel-wise average and max maps, squashed with a sigmoid."""

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
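
# A short sketch of how the two attention blocks are chained in the CIM
# step of PolypPVT.forward() below (64 channels matches the first PVTv2
# stage; the spatial size is illustrative):
#
#   ca, sa = ChannelAttention(64), SpatialAttention()
#   x = torch.randn(1, 64, 88, 88)
#   x = ca(x) * x    # reweight channels
#   x = sa(x) * x    # reweight spatial positions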

class PolypPVT(nn.Module):
    def __init__(self, channel=32):
        super(PolypPVT, self).__init__()

        # PVTv2-b2 backbone; stage channels are [64, 128, 320, 512].
        self.backbone = pvt_v2_b2()
        path = './pretrained_pth/pvt_v2_b2.pth'
        # Load the pretrained weights on CPU and keep only the keys that
        # exist in the backbone's state dict.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        state_dict = {k: v for k, v in save_model.items() if k in model_dict}
        model_dict.update(state_dict)
        self.backbone.load_state_dict(model_dict)

        # 1x1 convs projecting each backbone stage to `channel` channels.
        self.Translayer2_0 = BasicConv2d(64, channel, 1)
        self.Translayer2_1 = BasicConv2d(128, channel, 1)
        self.Translayer3_1 = BasicConv2d(320, channel, 1)
        self.Translayer4_1 = BasicConv2d(512, channel, 1)

        self.CFM = CFM(channel)
        self.ca = ChannelAttention(64)
        self.sa = SpatialAttention()
        self.SAM = SAM()

        self.down05 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
        self.out_SAM = nn.Conv2d(channel, 1, 1)
        self.out_CFM = nn.Conv2d(channel, 1, 1)

    def forward(self, x):
        # Backbone: four feature stages at 1/4, 1/8, 1/16, 1/32 scale.
        pvt = self.backbone(x)
        x1, x2, x3, x4 = pvt[0], pvt[1], pvt[2], pvt[3]

        # CIM: channel attention followed by spatial attention on the
        # low-level features.
        x1 = self.ca(x1) * x1
        cim_feature = self.sa(x1) * x1

        # CFM: fuse the three high-level stages, coarsest first.
        x2_t = self.Translayer2_1(x2)
        x3_t = self.Translayer3_1(x3)
        x4_t = self.Translayer4_1(x4)
        cfm_feature = self.CFM(x4_t, x3_t, x2_t)

        # SAM: refine the fused map with the downsampled low-level features.
        T2 = self.Translayer2_0(cim_feature)
        T2 = self.down05(T2)
        sam_feature = self.SAM(cfm_feature, T2)

        # Two prediction heads, upsampled 8x back to the input resolution.
        prediction1 = self.out_CFM(cfm_feature)
        prediction2 = self.out_SAM(sam_feature)
        prediction1_8 = F.interpolate(prediction1, scale_factor=8, mode='bilinear')
        prediction2_8 = F.interpolate(prediction2, scale_factor=8, mode='bilinear')
        return prediction1_8, prediction2_8

if __name__ == '__main__':
    model = PolypPVT().cuda()
    input_tensor = torch.randn(1, 3, 352, 352).cuda()
    prediction1, prediction2 = model(input_tensor)
    print(prediction1.size(), prediction2.size())
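
# Expected output for the 352x352 smoke test above: both heads predict at
# 1/8 scale (44x44) and are upsampled 8x, so each tensor should print as
# torch.Size([1, 1, 352, 352]). Running it requires a CUDA device and the
# pretrained checkpoint at ./pretrained_pth/pvt_v2_b2.pth.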