模块出处
[link] [code] [Pattern Recognition 2023] Cross-level Feature Aggregation Network for Polyp Segmentation
模块名称
Cross-level Feature Fusion (CFF)
模块作用
跨层级特征融合：将相邻两个层级的特征融合为一个特征图
模块结构
模块代码
import torch
import torch.nn as nn
class BasicConv2d(nn.Module):
    """Bias-free ``nn.Conv2d`` followed by ``nn.BatchNorm2d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels (also the BatchNorm width).
        kernel_size: convolution kernel size, forwarded to ``nn.Conv2d``.
        stride: convolution stride, forwarded to ``nn.Conv2d``.
        padding: convolution padding, forwarded to ``nn.Conv2d``.
        dilation: convolution dilation, forwarded to ``nn.Conv2d``.

    NOTE(review): ``self.relu`` is registered as a submodule, but ``forward``
    never applies it, so the output is the un-activated BatchNorm result.
    This is kept as-is to preserve existing outputs; confirm whether the
    missing activation is intended.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1,
                 padding=0, dilation=1):
        super().__init__()
        # The conv has no bias because the BatchNorm that follows adds its
        # own learnable shift.
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``bn(conv(x))``; no activation is applied."""
        return self.bn(self.conv(x))
class CFF(nn.Module):
    """Cross-level Feature Fusion: fuse two feature maps into one.

    Both inputs are first projected to ``out_channel // 2`` channels by 1x1
    :class:`BasicConv2d` layers. The two projections are concatenated and fed
    to a 3x3 branch and a 5x5 branch. The first-stage outputs are concatenated
    again and fed to a second 3x3 / 5x5 pair. The two second-stage outputs are
    multiplied element-wise, added to both projections as a residual, and
    mapped to ``out_channel`` channels by a final 3x3 conv + BN + ReLU.

    Args:
        in_channel1: number of channels of the first input ``x0``.
        in_channel2: number of channels of the second input ``x1``.
        out_channel: number of output channels. Must be at least 2 so that
            the intermediate width ``out_channel // 2`` is non-zero. Odd
            values are supported.

    Raises:
        ValueError: if ``out_channel < 2``.
    """

    def __init__(self, in_channel1, in_channel2, out_channel):
        super().__init__()
        if out_channel < 2:
            raise ValueError(f"out_channel must be >= 2, got {out_channel}")
        mid_channel = out_channel // 2
        # Width of a concatenation of two mid-width maps. For even out_channel
        # this equals out_channel (same layers as before). Using it instead of
        # out_channel stops odd values from causing a channel mismatch.
        cat_channel = 2 * mid_channel
        # ReLU has no parameters, so one instance can be shared by all branches.
        act_fn = nn.ReLU(inplace=True)
        self.layer0 = BasicConv2d(in_channel1, mid_channel, 1)
        self.layer1 = BasicConv2d(in_channel2, mid_channel, 1)
        self.layer3_1 = self._branch(cat_channel, mid_channel, 3, act_fn)
        self.layer3_2 = self._branch(cat_channel, mid_channel, 3, act_fn)
        self.layer5_1 = self._branch(cat_channel, mid_channel, 5, act_fn)
        self.layer5_2 = self._branch(cat_channel, mid_channel, 5, act_fn)
        self.layer_out = self._branch(mid_channel, out_channel, 3, act_fn)

    @staticmethod
    def _branch(in_ch, out_ch, kernel_size, act_fn):
        """Build Conv2d -> BatchNorm2d -> act with stride 1 and 'same' padding.

        ``padding = kernel_size // 2`` keeps the spatial size unchanged for the
        odd kernel sizes (3 and 5) used here.
        """
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=1,
                      padding=kernel_size // 2),
            nn.BatchNorm2d(out_ch),
            act_fn,
        )

    def forward(self, x0, x1):
        """Fuse ``x0`` and ``x1`` into a single ``out_channel`` feature map.

        Args:
            x0: 4-D ``(N, C, H, W)`` tensor with ``in_channel1`` channels.
            x1: 4-D tensor with ``in_channel2`` channels. Its batch and
                spatial size must match ``x0``: no resizing is done here, and
                the projections are concatenated and added element-wise.

        Returns:
            Tensor with ``out_channel`` channels and the same spatial size as
            the inputs, since every convolution uses stride 1 and size-preserving
            padding.
        """
        x0_1 = self.layer0(x0)
        x1_1 = self.layer1(x1)
        # First stage: both branches see both projections, concatenated in
        # opposite orders.
        x_3_1 = self.layer3_1(torch.cat((x0_1, x1_1), dim=1))
        x_5_1 = self.layer5_1(torch.cat((x1_1, x0_1), dim=1))
        # Second stage: each branch sees both first-stage outputs.
        x_3_2 = self.layer3_2(torch.cat((x_3_1, x_5_1), dim=1))
        x_5_2 = self.layer5_2(torch.cat((x_5_1, x_3_1), dim=1))
        # Residual sum of both projections plus the element-wise product of
        # the two second-stage branches.
        out = self.layer_out(x0_1 + x1_1 + torch.mul(x_3_2, x_5_2))
        return out
if __name__ == '__main__':
    # Smoke test: two feature maps with the same 16x16 spatial size but
    # different channel counts (256 and 512) are fused into 64 channels.
    feat_a = torch.randn([1, 256, 16, 16])
    feat_b = torch.randn([1, 512, 16, 16])
    fusion = CFF(in_channel1=256, in_channel2=512, out_channel=64)
    fused = fusion(feat_a, feat_b)
    print(fused.shape)  # torch.Size([1, 64, 16, 16])
原文表述
利用特征提取网络可以获得不同分辨率的多级特征，而有效整合这些多级特征非常重要，它能够提高不同尺度特征的表示能力。因此，我们提出了 CFF 模块来融合相邻两个层级的特征，并将融合结果输入分割网络。