文章目录
- 1.介绍
- 2.原理
- 3.代码
- 4.SE模块的应用
论文:Squeeze-and-Excitation Networks
论文链接:Squeeze-and-Excitation Networks
代码链接:Github
1.介绍
卷积算子使网络能够在每一层的局部感受野中融合空间(spatial)和通道(channel)信息来构造信息特征。本文将重点放在通道(channel)关系上,提出SE(Squeeze-and-Excitation Block)模块,其显式建模通道之间的相互依赖性,自适应的重新校准通道方向上的特征响应,来提高所提取特征的质量。将SE模块堆叠在一起,就形成了SENet(Squeeze-and-Excitation Networks)。
通俗来说,SENet的核心在于通过网络根据损失函数学习特征权重,使得特征图中有效通道的权重变大,无效或效果小的通道权重变小的方式训练模型达到更好的结果。而SE(Squeeze-and-Excitation Block)模块是一个子结构,可嵌入其他模型当中。
2.原理
给定输入
x
x
x,其经一系列卷积操作(定义为
F
t
r
(
⋅
;
θ
)
F_{tr}(·;θ)
Ftr(⋅;θ))后得到通道数为
c
w
c_w
cw的特征,其形状为
(
C
,
H
,
W
)
(C,H,W)
(C,H,W)。 之后通过三种运算来实现SE模块的功能:
【1.
S
q
u
e
e
z
e
Squeeze
Squeeze操作】
卷积核只能关注到局部感受野的空间信息,感受野区域之外的信息无法利用,这使得输出特征图就很难获得足够的信息来提取通道之间的关系。
S
q
u
e
e
z
e
Squeeze
Squeeze操作,定义为
F
s
q
(
⋅
)
F_{sq}(·)
Fsq(⋅),顺着空间维度来进行特征压缩,将每个二维的特征通道变成一个实数,这个实数某种程度上具有全局的感受野,并且输出的维度和输入的特征通道数相匹配。这一操作通过全局平均池化实现:
F
s
q
(
c
2
)
=
1
H
×
W
∑
i
=
1
H
∑
j
=
1
W
c
2
(
i
,
j
)
F_{sq}(c_2)=\frac{1}{H×W}\sum^{H}_{i=1}\sum^{W}_{j=1}c_2(i,j)
Fsq(c2)=H×W1i=1∑Hj=1∑Wc2(i,j)
特征图经过
F
s
q
(
)
F_{sq}()
Fsq()运算后得到全局统计向量,形状为
(
1
,
1
,
c
2
)
(1,1,c_2)
(1,1,c2)。此时一个像素值代表一个通道,从而屏蔽掉空间上的分布信息,更好的利用通道间的相关性。
【2.
E
x
c
i
t
a
t
i
o
n
Excitation
Excitation操作】
E
x
c
i
t
a
t
i
o
n
Excitation
Excitation操作,定义为
F
e
x
(
⋅
;
w
)
F_{ex}(·;w)
Fex(⋅;w),用于捕获通道之间的依赖关系。这里使用了神经网络的门机制,即使用两个全连接层+两个激活函数组成的结构输出和输入与特征同样数目的权重值,也就是每个特征通道的权重系数。并且,为了限制模型复杂度和辅助泛化,在构造全连接层时对通道
c
2
c_2
c2进行了降维处理,降维比例为
r
r
r。计算公式:
其中,
W
1
∈
R
C
r
×
C
,
W
2
∈
R
C
r
×
C
W_1∈R^{\frac{C}{r}}×C,W_2∈R^{\frac{C}{r}}×C
W1∈RrC×C,W2∈RrC×C,两个激活函数依次为
R
e
L
U
、
s
i
g
m
o
i
d
ReLU、sigmoid
ReLU、sigmoid。原理图:
【3.
S
c
a
l
e
Scale
Scale操作】
S
c
a
l
e
Scale
Scale操作定义为
F
s
c
a
l
e
(
⋅
,
⋅
)
F_{scale}(·,·)
Fscale(⋅,⋅),用于将前面得到的注意力权重加权到每个通道的特征上。论文中通过逐通道乘以权重系数,即在在通道维度上引入attention机制来实现。如下图所示:
不同颜色代表不同通道的重要程度。
3.代码
import torch.nn as nn
class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
#channel:输入通道数;reduction:缩减比率
super(SELayer, self).__init__()
#1.Squeeze
self.avg_pool = nn.AdaptiveAvgPool2d(1)
#2.Excitation
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
#3.Scale
return x * y.expand_as(x)
4.SE模块的应用
例如,可将SE模块集成在残差块中:
以此形成集成后的ResNet网络: