文章目录
- 1、Swin-Transformer介绍
- 2、模型整体框架
- 3、Patch Mergeing详解
- 4、W-MSA模块详解
- MSA模块计算量
- W-MSA模块计算量
- 5、SW-MSA详解
- 6、Relative Position Bias详解
- 7、模型详细配置参数
1、Swin-Transformer介绍
自从ViT(Vision Transformer)出现之后,这种基于自注意力机制的视觉神经网络逐渐替代CNN称为主流backbone。但由于其参考的self-attention机制是来源于NLP(自然语言)领域的,而语言的文字是人类交流的结晶,本身语义信息高度集中。而CV(计算机视觉)不同,可能从图片上随机扣除一块区域,对整体的识别都没有影响。所以ViT中,很多计算是冗余的,并不需要全局的联系。Swin-Transformer基于ViT的结构进行改进,提出SW/W-MSA结构,有效降低计算量。
原论文地址: https://arxiv.org/abs/2103.14030
官方开源代码:https://github.com/microsoft/Swin-Transformer
Pytorch实现代码:https://github.com/Runist/Swin-Transformer
2、模型整体框架
在讲解之前,如果没有了解过ViT和self-attention的读者,建议还是先看一下前面的文章。相比于Vision Transformer来说,有两点不同:
- Swin Transformer采用了类似CNN的层次化构建方法(Hierarchical feature maps),特征图会随着层数加深,逐渐下采样至4倍,8倍,16倍等。 但Vision Transformer的结构特征图经过一次16倍下采样之后就不变了。
- Swin Transformer使用了Windows Multi-Head Self-Attention(W-MSA)的网络结构。这个结构限定了self-attention的范围,从全局进行qkv运算,变成只在给定窗口大小的范围内进行qkv计算,有效减少了计算量。但这样做也会隔绝了不同窗口之间的信息传递,所以作者又提出了Shifted Windows Multi-Head Self-Attention(SW-MSA)的概念,通过这个方法,可以让信息在相邻的窗口中传递。
网络的整体架构如下所示,由多个Swin Transformer Block堆叠而成:
- Patch Partition和Linear Embeding就是对应ViT的Embedding层,即将图像分块,映射成self-attention中一个个token。在代码中是这么做的,4x4的像素小块为一个Patch,加上其有3个通道,每个Patch就有16x3=48个像素,这在代码中是由一个4x4的卷积进行处理的。那么图像的shape就从[H, W, 3]变成了[H/4, W/4, 48]。然后再通过Linear Emdbeding层对channel维度的数据做线性变换,由48变成C,即[H/4, W/4, 48]变成了[H/4, W/4, C]。
- 模型通过Patch Merging进行下采样,每个Stage都会下采样4倍(除了第一个),同时在channel维度上翻倍。每个Stage都是堆叠Swin Transformer Block而成,这里的Block其实有两种结构,如图(b)中所示,这两种结构的不同之处仅在于一个使用了W-MSA结构,一个使用了SW-MSA结构。而且这两个结构是成对使用的,先使用一个W-MSA结构再使用一个SW-MSA结构。所以你会发现堆叠Swin Transformer Block的次数都是偶数(因为成对使用)。
- 对于分类网络在代码中,还有LayerNorm、AvgPooling和一个全连接层组成,这个在图中没有体现。
3、Patch Mergeing详解
每个Stage中会经过一个Patch Merging层进行下采样(Stage1除外)。如下图所示,每隔一个位置取一个像素,从而组成四个feature map。将这四个feature map在channel维度上进行concat拼接,再经过一个LayerNorm层。最后通过一个全连接层在feature map的channel维度上做线性变换,feature map的深度由4*C变成2*C。通过这个例子,可以看出,feature通过Patch Merging层之后,feature map的高和宽会减半,通道数翻倍。
实现代码:
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""
x: B, H*W, C
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
# 如果输入feature map的H,W不是2的整数倍,需要进行padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
# to pad the last 3 dimensions, starting from the last dimension and moving forward.
# (C_front, C_back, W_left, W_right, H_top, H_bottom)
# 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
# *::2,每隔1个取一个值
x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]
x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]
x = torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C]
x = x.view(B, -1, 4 * C) # [B, H/2*W/2, 4*C]
x = self.norm(x)
x = self.reduction(x) # [B, H/2*W/2, 2*C]
return x
4、W-MSA模块详解
引入Windows Multi-head Self-Attention(W-MSA)模块是为了减少计算量。如下图所示,左边是ViT用的Multi-head Self-Attention(MSA)模块,对于每个Patch都要和除了它自己之外的Patch去计算attention。在Windows Multi-head Self-Attention(W-MSA)模块中,我们会给定一个windows-size(下图windows-size=2),在一个windows内进行Self-Attention的计算。
这样就有效降低了计算量,具体相差多少呢?论文中给出公式:
Ω
(
M
S
A
)
=
4
h
w
C
2
+
2
(
h
w
)
2
C
Ω
(
W
−
M
S
A
)
=
4
h
w
C
2
+
2
M
2
(
h
w
)
2
C
\begin{aligned} & \Omega(MSA) = 4hwC^2 + 2(hw)^2C \\ & \Omega(W-MSA) = 4hwC^2 + 2M^2(hw)^2C \end{aligned}
Ω(MSA)=4hwC2+2(hw)2CΩ(W−MSA)=4hwC2+2M2(hw)2C
- h代表feature map的高度
- w代表feature map的宽度
- C代表feature map的通道数
- M代表window-size,一般设置为7,是固定的。
这两个公式的推导,原文没有细说,我们简单计算一下。首先看一下Self-Attention的公式:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
S
o
f
t
M
a
x
(
Q
K
T
d
)
V
Attention(Q,K,V)=SoftMax(\frac{QK^T}{\sqrt{d}})V
Attention(Q,K,V)=SoftMax(dQKT)V
MSA模块计算量
对于feature map的每个像素(或称为token,patch),都要通过
W
q
W_q
Wq,
W
k
W_k
Wk,
W
v
W_v
Wv生成对应qkv。这里假设q,k,v的向量长度与feature map的channel数量C保持一致。那么对应所有像素生成Q的过程如下:
X
h
w
×
C
⋅
W
q
C
×
C
=
Q
h
w
×
C
X^{hw \times C} \cdot W_q^{C \times C} = Q^{hw \times C}
Xhw×C⋅WqC×C=Qhw×C
- X h w × C X^{hw \times C} Xhw×C为所有token拼接一起得到的矩阵(一共有hw个像素,每个像素的深度为C
- W q C × C W_q^{C \times C} WqC×C为生成query的变换矩阵
- $ Q^{hw \times C} 为所有像素通过 为所有像素通过 为所有像素通过W_q^{C \times C}$得到的query拼接后的矩阵
根据矩阵运算的计算量公式可以得到生成
Q
Q
Q的计算量为
h
w
×
C
×
C
hw \times C \times C
hw×C×C,生成K和V的过程一样,同理都是
h
w
C
2
hwC^2
hwC2,那么总共是
3
h
w
C
2
3hwC^2
3hwC2。接下来
Q
Q
Q和
K
T
K^T
KT相乘,对应计算量为
(
h
w
)
2
C
(hw)^2C
(hw)2C:
Q
h
w
×
C
⋅
K
T
(
C
×
h
w
)
=
X
h
w
×
h
w
Q^{hw \times C} \cdot K^{T(C \times hw)} = X^{hw \times hw}
Qhw×C⋅KT(C×hw)=Xhw×hw
这里忽略除以
d
\sqrt{d}
d以及softmax的计算量,假设得到
A
h
w
×
h
w
A^{hw \times hw}
Ahw×hw,最后还要乘以
V
V
V,这里对应的计算量是
(
h
w
)
2
C
(hw)^2C
(hw)2C:
A
h
w
×
h
w
⋅
V
h
w
×
C
)
=
X
h
w
×
C
A^{hw \times hw} \cdot V^{hw \times C)} = X^{hw \times C}
Ahw×hw⋅Vhw×C)=Xhw×C
那么对应单头的Self-Attention模块,总共需要
3
h
w
C
2
+
(
h
w
)
2
C
+
(
h
w
)
2
C
=
3
h
w
C
2
+
2
(
h
w
)
2
C
3hwC^2 + (hw)^2C + (hw)^2C = 3hwC^2 + 2(hw)^2C
3hwC2+(hw)2C+(hw)2C=3hwC2+2(hw)2C。而在实际使用过程中,使用的是多头的Multi-head Self-Attention模块,在之前的文章中有进行过实验对比,多头注意力模块相比单头注意力模块的计算量仅多了最后一个融合矩阵
W
O
W_O
WO的计算量
h
w
C
2
hwC^2
hwC2。
所以总共加起来是: 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C 4hwC2+2(hw)2C
W-MSA模块计算量
对于W-MSA模块首先要将feature map划分到多个窗口中,假设每个窗口的宽高都是M,那么总共会得到
h
M
×
w
M
\frac{h}{M} \times \frac{w}{M}
Mh×Mw个窗口,在每个窗口内使用多头注意力模块。刚刚计算高度h,宽度为w,通道数为C的feature map计算量为
4
h
w
C
2
+
2
(
h
w
)
2
C
4hwC^2 + 2(hw)^2C
4hwC2+2(hw)2C,这里
h
w
hw
hw替换为
M
M
M,代入公式:
4
(
M
C
)
2
+
2
(
M
)
4
C
4(MC)^2 + 2(M)^4C
4(MC)2+2(M)4C
又因为又
h
M
×
w
M
\frac{h}{M} \times \frac{w}{M}
Mh×Mw个窗口,则:
h
M
×
w
M
×
(
4
(
M
C
)
2
+
2
(
M
)
4
C
)
=
4
h
w
C
2
+
2
(
M
)
2
h
w
C
\frac{h}{M} \times \frac{w}{M} \times(4(MC)^2 + 2(M)^4C) = 4hwC^2 + 2(M)^2hwC
Mh×Mw×(4(MC)2+2(M)4C)=4hwC2+2(M)2hwC
所以W-MSA模块的计算量为: 4 h w C 2 + 2 ( M ) 2 h w C 4hwC^2 + 2(M)^2hwC 4hwC2+2(M)2hwC
5、SW-MSA详解
前面又说,采用W-MSA模块时,只会在自己窗口下进行Self-Attention的计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块,即进行偏移的W-MSA。如下图所示,左边是上面说的W-MSA,右边用的是SW-MSA,两个模块是成对出现的。经过滚动的像素(token),其画窗口之后,框住的token就不同了,这样就使得不同窗口的信息有交流了。
这个图比较抽象,包括论文中出现的解析图,都画的原理不是很清晰。
我按照其他博主讲解的思路重新画了一个,我们先按照编号给每个窗口画上标记,左边A对应0区域,B对应3、6区域,C对应1、2区域。
首先将0、1、2移动至最后一行。
其次将3、6、0移动至最后一列。
移动完成之后,4是一个单独区域,5、4为一组,7、1为一组,8、6、2、0为一组。这样都是4x4的窗口,虽然我们这边解析看起来比较麻烦,但在代码中,只需要一个torch.roll()函数就可以实现。但在这里肯定有人回想,5、3本身是两个图像的边缘,混在一起计算不是乱了吗?一起计算也没问题,ViT也是全局计算的。但是Swin-Transformer为了防止这个问题,在代码中使用了masked MSA,这样就能够通过设置蒙板来隔绝不同区域的信息了。源码中具体的方法就是将不计算的位置元素减去100,让权重为0,不让其参与计算,这里就不细说了。
这里需要注意的是,在窗口数据进行滑动完之后,需要将数据还原回去,即挪回到原来的位置上。
对应的代码是:
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x, attn_mask):
H, W = self.H, self.W
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
# 把feature map给pad到window size的整数倍
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_t, pad_b, pad_l, pad_r))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
# paper中,滑动的size是窗口大小的/2(向下取整)
# torch.roll以H,W的维度为例子,负值往左上移动,正值往右下移动。溢出的值在对角方向出现。即循环移动。
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # [nW*B, Mh, Mw, C]
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # [nW*B, Mh*Mw, C]
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # [nW*B, Mh*Mw, C]
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # [nW*B, Mh, Mw, C]
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # [B, H', W', C]
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
# 把前面pad的数据移除掉
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
6、Relative Position Bias详解
关于相对位置偏执,论文里也没有细讲,就说了参考的哪些论文,然后说使用了相对位置偏执后给够带来明显的提升。根据原论文中的表4可以看出,在Imagenet数据集上如果不使用任何位置偏执,top-1为80.1,但使用了相对位置偏执(rel. pos.)后top-1为83.3,提升还是很明显的。
从论文中提供的公式,这个相对位置的偏执是加载softmax之前的:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
S
o
f
t
m
a
x
(
Q
K
T
(
d
)
+
B
)
V
Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt(d)} + B)V
Attention(Q,K,V)=Softmax((d)QKT+B)V
由于论文没有说明这个相对位置偏执编码如何计算出来的,这里根据源码解释一下。如图,假设我们现在有一个window-size=2的feature map,这里面如果用绝对位置来表示位置索引,左上角的token为(0, 0)右下角的token为(1, 1)其他位置以此类推。然后如果用相对位置表示,就会有4个情况,但分别都是以自己为(0, 0)计算其他token的相对位置。分别把4个相对位置展开,得到4x4的矩阵,如最下的矩阵所示。
请注意这里说的都是位置索引,并不是最后的位置编码。因为后面我们会根据相对位置索引去取对应位置的参数。取出来的值才是相对位置编码。源码中,作者还将二维索引给转成了一维索引。如果直接将行列相加,就变成一维了。但这样(0, 1)和(1, 0)得到的结果都是1,这样肯定不行。来看看源码的做法怎么做的:
首先,所有行列都加上M-1
其次将所有的行索引乘上2M-1
最后行索引和列索引相加,保证了相对位置关系,也不会出现0+1 = 1+0 的现象了。
刚刚也说了,之前计算的是相对位置索引,并不是实际位置偏执参数。真正使用到的数值需要从relative position bias table,这个表的长度是等于(2M-1)X(2M-1)的。在代码中它是一个可学习参数。
实现代码如下:
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # [Mh, Mw]
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # [2*Mh-1 * 2*Mw-1, nH]
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # [2, Mh, Mw]
coords_flatten = torch.flatten(coords, 1) # [2, Mh*Mw]
# [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Mh*Mw, Mh*Mw]
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # [Mh*Mw, Mh*Mw, 2]
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # [Mh*Mw, Mh*Mw]
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
7、模型详细配置参数
Swin Transformer的网络架构:
下图(表7)是原论文中给出的关于不同Swin Transformer的配置,T(Tiny),S(Small),B(Base),L(Large),其中:
win. sz. 7x7
表示使用的窗口(Windows)的大小dim
表示feature map的channel深度(或者说token的向量长度)head
表示多头注意力模块中head的个数