文章目录
- 介绍
- 代码实现
介绍
ASPP(Atrous Spatial Pyramid Pooling),即空洞空间金字塔池化。可以把它简单理解为一个"至尊版"的池化层:其目的与普通的池化层一致,都是尽可能充分地提取特征;不同之处在于,它并行使用多个膨胀率不同的空洞卷积,从而同时获得多种感受野下的特征。ASPP 的结构如下:
如图所示,ASPP 本质上由一个 1×1 的卷积(最上)+ 池化金字塔(中间三个,即三个膨胀率不同的 3×3 空洞卷积)+ ASPP Pooling(最下面三层,对应全局平均池化、1×1 卷积和上采样)组成。其中池化金字塔各层的膨胀率(dilation rate)可以自定义,从而实现灵活的多尺度特征提取。
代码实现
'''
pytorch实现ASPP
'''
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.uniform import Uniform
from torch.nn import functional as F
from typing import Dict, List
class ASPPConv(nn.Sequential):
    """One atrous branch of ASPP: 3x3 dilated convolution -> BatchNorm -> ReLU.

    The convolution uses ``padding == dilation`` with a 3x3 kernel and the
    default stride of 1, so the spatial size of the output equals that of
    the input. The convolution has no bias term.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the branch.
        dilation: Dilation rate of the 3x3 kernel (also used as padding).
    """

    def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
        atrous_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU()
        # Positional registration keeps the sub-module names "0", "1", "2".
        super().__init__(atrous_conv, norm, activation)
class ASPPPooling(nn.Sequential):
    """Image-level branch of ASPP.

    Applies, in order: adaptive average pooling to a 1x1 map, a bias-free
    1x1 convolution, BatchNorm and ReLU. ``forward`` then bilinearly resizes
    the 1x1 result back to the spatial size of the input.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the branch.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        global_pool = nn.AdaptiveAvgPool2d(1)
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU()
        # Positional registration keeps the sub-module names "0" .. "3".
        super().__init__(global_pool, pointwise, norm, activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the pooled branch and resize its output to ``x``'s last two dims."""
        # Remember the input resolution before it is collapsed to 1x1.
        target_size = x.shape[-2:]
        # nn.Sequential.forward feeds x through the registered layers in order.
        pooled = super().forward(x)
        return F.interpolate(
            pooled, size=target_size, mode='bilinear', align_corners=False
        )
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Runs several branches on the same input and fuses them:

    * one 1x1 convolution branch (Conv -> BatchNorm -> ReLU),
    * one ``ASPPConv`` branch per entry of ``atrous_rates``,
    * one ``ASPPPooling`` branch.

    The branch outputs are concatenated along the channel dimension and
    projected back to ``out_channels`` by a 1x1 convolution followed by
    BatchNorm, ReLU and Dropout(0.5).

    Args:
        in_channels: Number of channels of the input feature map.
        atrous_rates: Dilation rates, one ``ASPPConv`` branch is built per rate.
        out_channels: Number of channels of every branch and of the output.
    """

    def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
        super().__init__()

        pointwise_branch = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        atrous_branches = [
            ASPPConv(in_channels, out_channels, rate)
            for rate in tuple(atrous_rates)
        ]
        pooling_branch = ASPPPooling(in_channels, out_channels)

        # Branch order (1x1, atrous..., pooling) fixes both the ModuleList
        # indices and the channel layout of the concatenated tensor.
        self.convs = nn.ModuleList([pointwise_branch, *atrous_branches, pooling_branch])

        # Every branch emits out_channels channels, so the fused tensor has
        # (number of branches) * out_channels channels.
        fused_channels = len(self.convs) * out_channels
        self.project = nn.Sequential(
            nn.Conv2d(fused_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all branches to ``x``, concatenate on dim 1 and project."""
        branch_outputs = [branch(x) for branch in self.convs]
        fused = torch.cat(branch_outputs, dim=1)
        return self.project(fused)
class EncoderAspp(nn.Module):
    """Thin wrapper that applies a fixed-configuration ``ASPP`` module.

    The wrapped ASPP takes 64 input channels, uses atrous rates 4, 6 and 8,
    and produces 64 output channels.

    Args:
        params: Accepted for interface compatibility; not read by this class.
    """

    def __init__(self, params):
        super().__init__()
        # atrous_rates must be passed by keyword here: the original call
        # placed the bare list after ``in_channels=...``, and a positional
        # argument following a keyword argument is a SyntaxError.
        self.aspp = ASPP(in_channels=64, atrous_rates=[4, 6, 8], out_channels=64)

    def forward(self, x):
        """Return the ASPP output for ``x``."""
        return self.aspp(x)