# Author: Muhammed Ömer ERKOÇ
# Commit fb36382: "Add app.py, requirements.txt, examples and model files 2"
# (file-view residue: raw | history blame | 7.8 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
from lib.pvtv2 import pvt_v2_b2
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicConv2d(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d.

    Args:
        in_planes: input channel count.
        out_planes: output channel count (also the BatchNorm width).
        kernel_size, stride, padding, dilation: forwarded to ``nn.Conv2d``.

    Note: a ReLU module is registered as ``self.relu`` but ``forward`` never
    applies it, so the output is the batch-normalised convolution with no
    activation.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,  # BatchNorm supplies the affine shift
        )
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)  # registered only; unused in forward

    def forward(self, x):
        return self.bn(self.conv(x))
class CFM(nn.Module):
    """Cascaded fusion of three feature maps at successively doubled resolution.

    Args:
        channel: channel count shared by all three inputs and by the output.

    ``forward(x1, x2, x3)``: ``x1`` is the coarsest input. Each fusion step
    upsamples by 2 and multiplies element-wise, so ``x2`` must match the
    2x-upsampled ``x1`` and ``x3`` the 2x-upsampled ``x2``. The result has
    ``channel`` channels at the spatial size of ``x3``.
    """

    def __init__(self, channel):
        super().__init__()

        def conv3x3(width):
            # Shape-preserving 3x3 conv + BN block.
            return BasicConv2d(width, width, 3, padding=1)

        self.relu = nn.ReLU(True)  # registered but not used in forward
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = conv3x3(channel)
        self.conv_upsample2 = conv3x3(channel)
        self.conv_upsample3 = conv3x3(channel)
        self.conv_upsample4 = conv3x3(channel)
        self.conv_upsample5 = conv3x3(2 * channel)
        self.conv_concat2 = conv3x3(2 * channel)
        self.conv_concat3 = conv3x3(3 * channel)
        self.conv4 = BasicConv2d(3 * channel, channel, 3, padding=1)

    def forward(self, x1, x2, x3):
        up = self.upsample
        coarse_x2 = up(x1)  # x1 brought to x2's resolution (reused twice)

        # Middle scale: gate x2 with the upsampled coarse features.
        mid = self.conv_upsample1(coarse_x2) * x2
        # Fine scale: gate x3 with both coarser inputs brought to its size.
        fine = self.conv_upsample2(up(coarse_x2)) * self.conv_upsample3(up(x2)) * x3

        # Concatenate each gated map with the path coming from the coarser level.
        mid = self.conv_concat2(torch.cat([mid, self.conv_upsample4(coarse_x2)], dim=1))
        fine = self.conv_concat3(torch.cat([fine, self.conv_upsample5(up(mid))], dim=1))
        return self.conv4(fine)
class GCN(nn.Module):
    """One graph-reasoning step over a ``(batch, num_state, num_node)`` tensor.

    ``conv1`` mixes information across the node axis, the input is subtracted
    from the mixed result, and ``conv2`` followed by ReLU mixes across the
    state axis.

    Args:
        num_state: size of axis 1 (channels seen by ``conv2``).
        num_node: size of axis 2 (channels seen by ``conv1`` after transposing).
        bias: whether ``conv2`` has a bias term (``conv1`` always does).
    """

    def __init__(self, num_state, num_node, bias=False):
        super().__init__()
        self.conv1 = nn.Conv1d(num_node, num_node, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(num_state, num_state, kernel_size=1, bias=bias)

    def forward(self, x):
        # Move nodes onto the channel axis so the 1x1 conv mixes across nodes.
        node_mixed = self.conv1(x.transpose(1, 2)).transpose(1, 2)
        return self.relu(self.conv2(node_mixed - x))
class SAM(nn.Module):
    """Graph-based refinement of ``x`` guided by an ``edge`` map.

    ``x`` is projected into ``plane_mid`` states. Anchor nodes come from
    adaptive pooling of the edge-masked projection: the input is pooled to
    ``(mids + 2) x (mids + 2)`` and the border is cropped, leaving ``mids**2``
    nodes. Pixels are softly assigned to the nodes, a GCN step runs on the
    node features, and the result is projected back and added residually
    to ``x``.

    Args:
        num_in: channel count of ``x`` (and of the output).
        plane_mid: number of intermediate states.
        mids: side length of the anchor grid (``mids * mids`` nodes).
        normalize: if True, scale node features by 1 / (number of pixels).

    ``forward(x, edge)``:
        ``edge`` is resized (nearest) to ``x``'s spatial size and
        soft-maxed over dim 1. Channel 1 is then used as the mask, so
        ``edge`` needs at least 2 channels. Returns a tensor shaped like ``x``.
    """

    def __init__(self, num_in=32, plane_mid=16, mids=4, normalize=False):
        super(SAM, self).__init__()
        self.normalize = normalize
        self.num_s = int(plane_mid)
        self.num_n = mids * mids
        # Pool to (mids + 2)^2 and later crop the border -> mids^2 anchors.
        self.priors = nn.AdaptiveAvgPool2d(output_size=(mids + 2, mids + 2))
        self.conv_state = nn.Conv2d(num_in, self.num_s, kernel_size=1)
        self.conv_proj = nn.Conv2d(num_in, self.num_s, kernel_size=1)
        self.gcn = GCN(num_state=self.num_s, num_node=self.num_n)
        self.conv_extend = nn.Conv2d(self.num_s, num_in, kernel_size=1, bias=False)

    def forward(self, x, edge):
        n = x.size(0)
        # F.upsample is deprecated; interpolate with mode='nearest' is the
        # exact equivalent of F.upsample's default behaviour.
        edge = F.interpolate(edge, size=(x.size(-2), x.size(-1)), mode='nearest')
        edge = F.softmax(edge, dim=1)[:, 1, :, :].unsqueeze(1)

        x_state_reshaped = self.conv_state(x).view(n, self.num_s, -1)
        x_proj = self.conv_proj(x)
        x_mask = x_proj * edge

        # Anchor node features: pooled masked projection with the border cropped.
        # (Computed once; the original also built two unused copies.)
        x_anchor = self.priors(x_mask)[:, :, 1:-1, 1:-1].reshape(n, self.num_s, -1)

        # Soft assignment of every pixel to the anchors: (n, num_nodes, h*w),
        # normalised over the node axis.
        x_proj_reshaped = torch.matmul(x_anchor.permute(0, 2, 1), x_proj.reshape(n, self.num_s, -1))
        x_proj_reshaped = F.softmax(x_proj_reshaped, dim=1)
        # The same assignment is reused to project nodes back to pixels.
        x_rproj_reshaped = x_proj_reshaped

        # Aggregate pixel states into node states: (n, num_s, num_nodes).
        x_n_state = torch.matmul(x_state_reshaped, x_proj_reshaped.permute(0, 2, 1))
        if self.normalize:
            x_n_state = x_n_state * (1. / x_state_reshaped.size(2))
        x_n_rel = self.gcn(x_n_state)

        # Scatter node features back onto the pixel grid and add residually.
        x_state_reshaped = torch.matmul(x_n_rel, x_rproj_reshaped)
        x_state = x_state_reshaped.view(n, self.num_s, *x.size()[2:])
        out = x + self.conv_extend(x_state)
        return out
class ChannelAttention(nn.Module):
    """Per-channel attention weights from average- and max-pooled descriptors.

    Both global descriptors go through the same bottleneck MLP (1x1 convs
    ``fc1`` -> ReLU -> ``fc2``). The outputs are summed and passed through a
    sigmoid.

    Args:
        in_planes: channel count of the input.
        ratio: bottleneck reduction factor. The hidden width is
            ``in_planes // ratio``, clamped to at least 1. Previously this
            argument was ignored and 16 was always used; the default keeps
            that behaviour.

    ``forward(x)`` returns a ``(batch, in_planes, 1, 1)`` tensor of
    sigmoid-activated weights, meant to be multiplied with ``x``.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        # Clamp so small in_planes never yields a zero-width bottleneck.
        hidden = max(1, in_planes // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)
class SpatialAttention(nn.Module):
    """Single-channel spatial attention map from channel-wise mean and max.

    Args:
        kernel_size: 3 or 7. The conv uses 'same' padding
            (``kernel_size // 2``), so the spatial size is preserved.

    ``forward(x)`` returns a ``(batch, 1, H, W)`` sigmoid-activated map,
    meant to be multiplied with ``x``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Stack the two channel statistics as a 2-channel image.
        stats = torch.cat(
            (x.mean(dim=1, keepdim=True), x.max(dim=1, keepdim=True).values),
            dim=1,
        )
        return self.sigmoid(self.conv1(stats))
class PolypPVT(nn.Module):
    """Segmentation network on a PVTv2-b2 backbone with CIM, CFM and SAM heads.

    * CIM: channel attention, then spatial attention, on the first backbone
      stage.
    * CFM: cascaded fusion of stages 2-4, each first projected to
      ``channel`` channels.
    * SAM: refines the CFM features, using the projected and downsampled CIM
      features as its guidance input.

    Args:
        channel: width every stage is projected to before fusion. The SAM
            module is now built with ``num_in=channel``; before, it was fixed
            at 32, so any other width crashed in ``forward``.
        pretrained_path: checkpoint loaded into the backbone at construction.
            Keys the backbone does not have are ignored. Pass ``None`` to skip
            loading, e.g. when a full-model checkpoint is restored afterwards.

    ``forward(x)`` returns ``(prediction1, prediction2)``. These are
    one-channel maps (no activation applied) from the CFM and SAM heads, each
    bilinearly upsampled by a factor of 8.
    """

    def __init__(self, channel=32, pretrained_path='./pretrained_pth/pvt_v2_b2.pth'):
        super(PolypPVT, self).__init__()
        self.backbone = pvt_v2_b2()  # [64, 128, 320, 512]
        if pretrained_path is not None:
            self._load_backbone_weights(pretrained_path)

        self.Translayer2_0 = BasicConv2d(64, channel, 1)
        self.Translayer2_1 = BasicConv2d(128, channel, 1)
        self.Translayer3_1 = BasicConv2d(320, channel, 1)
        self.Translayer4_1 = BasicConv2d(512, channel, 1)

        self.CFM = CFM(channel)
        self.ca = ChannelAttention(64)
        self.sa = SpatialAttention()
        # SAM consumes the CFM output, which has `channel` channels.
        self.SAM = SAM(num_in=channel)

        self.down05 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
        self.out_SAM = nn.Conv2d(channel, 1, 1)
        self.out_CFM = nn.Conv2d(channel, 1, 1)

    def _load_backbone_weights(self, path):
        """Copy matching entries of the checkpoint at ``path`` into the backbone."""
        # map_location='cpu' lets CUDA-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies into the (still CPU) backbone parameters.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        save_model = torch.load(path, map_location='cpu')
        model_dict = self.backbone.state_dict()
        model_dict.update({k: v for k, v in save_model.items() if k in model_dict})
        self.backbone.load_state_dict(model_dict)

    def forward(self, x):
        # backbone: four stage outputs (channels per the comment in __init__)
        pvt = self.backbone(x)
        x1 = pvt[0]
        x2 = pvt[1]
        x3 = pvt[2]
        x4 = pvt[3]

        # CIM
        x1 = self.ca(x1) * x1  # channel attention
        cim_feature = self.sa(x1) * x1  # spatial attention

        # CFM: coarsest stage first
        x2_t = self.Translayer2_1(x2)
        x3_t = self.Translayer3_1(x3)
        x4_t = self.Translayer4_1(x4)
        cfm_feature = self.CFM(x4_t, x3_t, x2_t)

        # SAM: guidance is the CIM feature projected to `channel` and halved in size
        T2 = self.Translayer2_0(cim_feature)
        T2 = self.down05(T2)
        sam_feature = self.SAM(cfm_feature, T2)

        prediction1 = self.out_CFM(cfm_feature)
        prediction2 = self.out_SAM(sam_feature)

        # align_corners=False is the default; passing it explicitly avoids the
        # legacy PyTorch warning without changing results.
        prediction1_8 = F.interpolate(prediction1, scale_factor=8, mode='bilinear', align_corners=False)
        prediction2_8 = F.interpolate(prediction2, scale_factor=8, mode='bilinear', align_corners=False)
        return prediction1_8, prediction2_8
if __name__ == '__main__':
    # Smoke test: build the model, run one dummy image and print the shapes of
    # both prediction maps. Use the GPU when available; the original
    # unconditional .cuda() crashed on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = PolypPVT().to(device).eval()
    input_tensor = torch.randn(1, 3, 352, 352, device=device)
    with torch.no_grad():  # inference only; no autograd graph needed
        prediction1, prediction2 = model(input_tensor)
    print(prediction1.size(), prediction2.size())