目录
ASPP结构介绍
ASPP在代码中的构建
参考资料
ASPP结构介绍
ASPP:Atrous Spatial Pyramid Pooling,空洞空间卷积池化金字塔。
简单理解就是个至尊版池化层,其目的与普通的池化层一致,尽可能地去提取特征。
利用主干特征提取网络,会得到一个浅层特征和一个深层特征,这一篇主要以如何对较深层特征进行加强特征提取,也就是在Encoder中所看到的部分。
它就叫做ASPP,主要有5个部分:
- 1x1卷积
- 膨胀率为6的3x3卷积
- 膨胀率为12的3x3卷积
- 膨胀率为18的3x3卷积
- 对输入进去的特征层进行池化
接着会对这五个部分进行一个堆叠,再利用一个1x1卷积对通道数进行调整,获得上图中绿色的特征。
ASPP在代码中的构建
import torch
import torch.nn as nn
import torch.nn.functional as F
class ASPP(nn.Module):
def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
super(ASPP, self).__init__()
self.branch1 = nn.Sequential(
nn.Conv2d(dim_in, dim_out, kernel_size=(1,1), stride=(1,1), padding=0, dilation=rate, bias=True),
nn.BatchNorm2d(dim_out, momentum=bn_mom),
nn.ReLU(inplace=True),
)
self.branch2 = nn.Sequential(
nn.Conv2d(dim_in, dim_out, kernel_size=(3,3), stride=(1,1), padding=6 * rate, dilation=6 * rate, bias=True),
nn.BatchNorm2d(dim_out, momentum=bn_mom),
nn.ReLU(inplace=True),
)
self.branch3 = nn.Sequential(
nn.Conv2d(dim_in, dim_out, kernel_size=(3,3), stride=(1,1), padding=12 * rate, dilation=12 * rate, bias=True),
nn.BatchNorm2d(dim_out, momentum=bn_mom),
nn.ReLU(inplace=True),
)
self.branch4 = nn.Sequential(
nn.Conv2d(dim_in, dim_out, kernel_size=(3,3), stride=(1,1), padding=18 * rate, dilation=18 * rate, bias=True),
nn.BatchNorm2d(dim_out, momentum=bn_mom),
nn.ReLU(inplace=True),
)
self.branch5_conv = nn.Conv2d(dim_in, dim_out, kernel_size=(1,1), stride=(1,1), padding=0, bias=True)
self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
self.branch5_relu = nn.ReLU(inplace=True)
self.conv_cat = nn.Sequential(
nn.Conv2d(dim_out * 5, dim_out ,kernel_size=(1,1), stride=(1,1), padding=0, bias=True),
nn.BatchNorm2d(dim_out, momentum=bn_mom),
nn.ReLU(inplace=True),
)
def forward(self, x):
[b, c, row, col] = x.size()
# 五个分支
conv1x1 = self.branch1(x)
conv3x3_1 = self.branch2(x)
conv3x3_2 = self.branch3(x)
conv3x3_3 = self.branch4(x)
# 第五个分支,进行全局平均池化+卷积
global_feature = torch.mean(x, 2, True)
global_feature = torch.mean(global_feature, 3, True)
global_feature = self.branch5_conv(global_feature)
global_feature = self.branch5_bn(global_feature)
global_feature = self.branch5_relu(global_feature)
global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
# 五个分支的内容堆叠起来,然后1x1卷积整合特征。
feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
result = self.conv_cat(feature_cat)
return result
if __name__ == "__main__":
model = ASPP(dim_in=320, dim_out=256, rate=16//16)
print(model)
那么从这里来看的话,也是相当清晰的,branch*(1、2、3、4、5)分别代表了ASPP五个部分在def __init__()可以体现,对于每一个都是卷积、标准化、激活函数。
第五个部分可以看到def forward中,首先呢,是要进行一个全局平均池化,再用1x1卷积通道数的整合,标准化、激活函数,接着采用上采样的方法,把它的大小调整成和我们上面获得的分支一样大小的特征层,这样我们才可以将五个部分进行一个堆叠,使用的是torch.cat()函数实现,最后,利用1x1卷积,对输入进来的特征层进行一个通道数的调整,获得想上图中绿色的部分,接着就会将这个具有较高语义信息的有效特征层就会传入到Decoder当中。
参考资料
(6条消息) Pytorch-torchvision源码解读:ASPP_xiongxyowo的博客-CSDN博客_aspp代码
DeepLabV3-/deeplabv3+.pdf at main · Auorui/DeepLabV3- (github.com)