Swin-Transformer详解
- 0. 前言
- 1. Swin-Transformer结构简介
- 2. Swin-Transformer结构详解
- 2.1 Patch Partition
- 2.2 Patch Merging
- 2.3 Swin Transformer Block
- 2.3.1 W-MSA
- 2.3.2 SW-MSA
- 3. 模型配置
- 总结
0. 前言
Swin-Transformer是2021年微软研究院发表在ICCV上的一篇文章,并且已经获得ICCV 2021 best paper
的荣誉称号。虽然Vision Transformer (ViT)
在图像分类方面的结果令人鼓舞,但是由于其低分辨率特性映射和复杂度随图像大小的二次增长
,其结构不适合作为密集视觉任务
或高分辨率输入图像
的通过骨干网路。为了最佳的精度和速度的权衡
,提出了Swin-Transformer结构。
论文名称:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
原论文地址: https://arxiv.org/abs/2103.14030
官方开源代码地址:https://github.com/microsoft/Swin-Transformer
Pytorch实现代码: pytorch_classification/swin_transformer
Tensorflow2实现代码:tensorflow_classification/swin_transformer
1. Swin-Transformer结构简介
如下图所示为:Swin-Transformer与ViT的对比结构。
从上图中可以看出两种网络结构的部分区别:
- 采样方式
- Swin-Transformer开始采用4倍下采样的方式,后续采用8倍下采样,最终采用16倍下采样
- ViT则一开始就使用16倍下采样
- 目标检测机制
- Swin-Transformer中,通过
4倍、8倍、16倍
下采样的结果分别作为目标检测所用数据,可以使网络以不同感受野
训练目标检测任务,实现对大目标、小目标的检测 - ViT则只使用
16倍下采样
,只有单一分辨率特征
- Swin-Transformer中,通过
接下来,简单看下原论文中给出的关于Swin Transformer(Swin-T)网络的架构图。其中,图(a)表示Swin Transformer的网络结构流程,图(b)表示两阶段的Swin Transformer Block结构
。注意:在Swin Transformer中,每个阶段的Swin Transformer Block结构都是2的倍数
,因为里面使用的都是两阶段的Swin Transformer Block结构。
2. Swin-Transformer结构详解
首先,介绍Swin-Transformer的基础流程。
- 输入一张图片 [ H ∗ W ∗ 3 ] [H*W*3] [H∗W∗3]
- 图片经过
Patch Partition
层进行图片分割 - 分割后的数据经过
Linear Embedding
层进行特征映射 - 将特征映射后的数据输入具有改进的自关注计算的
Transformer块(Swin Transformer块)
,并与Linear Embedding
一起被称为第1阶段
。 - 与阶段1不同,阶段2-4在输入模型前需要进行
Patch Merging
进行下采样,产生分层表示。 - 最终将经过阶段4的数据经过输出模块(包括一个
LayerNorm层
、一个AdaptiveAvgPool1d层
和一个全连接层
)进行分类。
2.1 Patch Partition
Patch Partition
结构是将图片数据进行分割成不重叠的M*M
补丁。每个补丁被视为一个“标记”,其特征被设置为原始像素RGB值的串联。在论文中,使用4 × 4的patch大小,因此每个patch的特征维数为4 × 4 × 3 = 48。在此原始值特征上应用线性嵌入层(Linear Embedding
),将其投影到任意维度(记为C)。
|
|
注意:在实际操作中,Patch Partition
和Linear Embedding
通过一个二维的卷积层
(输出通道为Embedding维度
,卷积核大小为patch_size
,stride大小为patch_size
)实现。
2.2 Patch Merging
Patch Merging层主要是进行下采样,产生分层表示。随着网络的深入,通过Patch Merging层
来减少令牌的数量。第一个补丁合并层将每组2 × 2
相邻补丁的特征进行拼接,并在拼接后的4c维特征上应用线性层。这将令牌的数量减少2×2 = 4的倍数(分辨率的2倍降采样,长和宽分别变为原来的1/2
),并将输出维度设置为2C
。之后使用Swin Transformer块进行特征变换,分辨率保持在h8 × w8
。这第一个块的补丁合并和特征转换被称为“第二阶段”。该过程重复两次,作为“阶段3”和“阶段4”,输出分辨率分别为h16 × w16
和h32 × w32
。由上述的说明,可以得知:数据在经过Patch Merging
层后,长宽变为原来的1/2,深度变为原来的2倍。
2.3 Swin Transformer Block
Swin Transformer Block
一般以2阶段的串联结构出现,在第一阶段使用Window based Multi-headed Self-Attention(W-MSA)
,第二阶段使用 Shifted Window based Multi-headed Self-Attention(SW-MSA)
,根据当前是奇数还是偶数的Swin Transformer Block
来选择不同的自关注计算方式。
2.3.1 W-MSA
W-MSA全称为:Window based Multi-headed Self-Attention。从名字可以看出,W-MSA是一个窗口化的多头自注意力,与全局自注意力相比,减少了大量的计算量。直观上来说:假如说是4*4的数据,划分后每个窗口包括 M ∗ M M*M M∗M 块,这里假设 M = 2 M=2 M=2。如果进行MSA计算大概需要 ( 4 ∗ 4 ) 2 (4*4)^2 (4∗4)2的计算量,而进行W-MSA则大概需要 ( 2 ∗ 2 ) ∗ ( 2 ∗ 2 ) 2 (2*2)*(2*2)^2 (2∗2)∗(2∗2)2。这样一对比瞬间计算的复杂度就降低了很多(当然上述只是为了方便简单的理解,下面就详细介绍W-MSA降低了多少复杂度)。
|
|
- h代表feature map的高度
- w代表feature map的宽度
- C代表feature map的深度
- M代表每个窗口(Windows)的大小
注意:前者与长宽 h w 成二次关系,后者在 M 固定时为线性关系(默认为7)。
-
首先介绍下Self-Attention的计算
Self-Attention的公式如下所示:
A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T d ) V Attention(Q,K,V)=SoftMax(\frac{QK^T}{\sqrt{d} })V Attention(Q,K,V)=SoftMax(dQKT)V -
计算Self-Attention的复杂度
首先,Q、K、V的计算如下所示:
Q h w ∗ C = X h w ∗ C ∗ W Q C ∗ C K h w ∗ C = X h w ∗ C ∗ W K C ∗ C V h w ∗ C = X h w ∗ C ∗ W V C ∗ C Q^{hw*C}=X^{hw*C}*W_Q^{C*C} \\\ K^{hw*C}=X^{hw*C}*W_K^{C*C} \\\ V^{hw*C}=X^{hw*C}*W_V^{C*C} Qhw∗C=Xhw∗C∗WQC∗C Khw∗C=Xhw∗C∗WKC∗C Vhw∗C=Xhw∗C∗WVC∗C- X h w ∗ C X^{hw*C} Xhw∗C 表示将所有像素(token)拼接在一起得到的矩阵(一共有hw个像素,每个像素的深度为C)
- W Q C ∗ C W_Q^{C*C} WQC∗C、 W K C ∗ C W_K^{C*C} WKC∗C、 W V C ∗ C W_V^{C*C} WVC∗C 分别表示生成Q、K、V的变换矩阵
因此,由矩阵复杂度计算公式可知Q、K、V的复杂度均为 h w ∗ C 2 hw*C^2 hw∗C2,此时总复杂度为 3 h w ∗ C 2 3hw*C^2 3hw∗C2。
然后,由Self-Attention的计算公式可知, Q K T QK^T QKT 的计算量如下所示:
Q h w ∗ C K T ( C ∗ h w ) = A h w ∗ h w Q^{hw*C}K^{T(C*hw)} = A^{hw*hw} Qhw∗CKT(C∗hw)=Ahw∗hw
因此, Q K T QK^T QKT 的计算量为 C ∗ h w ∗ h w C*hw*hw C∗hw∗hw, 即 C ∗ ( h w ) 2 C*(hw)^2 C∗(hw)2 。忽略 d \sqrt{d} d 和 S o f t M a x SoftMax SoftMax操作, A ∗ V A*V A∗V的计算量如下所示:
A h w ∗ h w V h w ∗ C = A t t e n t i o n h w ∗ C A^{hw*hw}V^{hw*C} = Attention^{hw*C} Ahw∗hwVhw∗C=Attentionhw∗C
因此, A ∗ V A*V A∗V 的计算量为 h w ∗ C ∗ h w hw*C*hw hw∗C∗hw, 即 C ∗ ( h w ) 2 C*(hw)^2 C∗(hw)2 。所以,Self-Attention公式的复杂度为 2 C ( h w ) 2 2C(hw)^2 2C(hw)2。Self-Attention总的复杂度为 2 C ( h w ) 2 + 3 h w ∗ C 2 2C(hw)^2+3hw*C^2 2C(hw)2+3hw∗C2 -
计算MSA的复杂度
多头注意力计算复杂度与自注意力复杂度
仅缺少一个 ∗ V 0 *V_0 ∗V0 的操作,因此总体复杂度缺少 h w ∗ C 2 hw*C^2 hw∗C2。所以MSA的复杂度为 2 C ( h w ) 2 + 4 h w ∗ C 2 2C(hw)^2+4hw*C^2 2C(hw)2+4hw∗C2。 -
计算W-MSA的复杂度
对于W-MSA模块首先要将feature map划分到一个个窗口(Windows)中,假设每个窗口的宽高都是M,那么总共会得到 h M × w M \frac {h} {M} \times \frac {w} {M} Mh×Mw个窗口,然后对每个窗口内使用多头注意力模块。刚刚计算高为h,宽为w,深度为C的feature map的计算量为 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C 4hwC2+2(hw)2C,这里每个窗口的高为M宽为M,带入公式得:
4 ( M C ) 2 + 2 ( M ) 4 C 4(MC)^2 + 2(M)^4C 4(MC)2+2(M)4C
又因为有 h M × w M \frac {h} {M} \times \frac {w} {M} Mh×Mw个窗口,则:
h M × w M × ( 4 ( M C ) 2 + 2 ( M ) 4 C ) = 4 h w C 2 + 2 M 2 h w C \frac {h} {M} \times \frac {w} {M} \times (4(MC)^2 + 2(M)^4C)=4hwC^2 + 2M^2 hwC Mh×Mw×(4(MC)2+2(M)4C)=4hwC2+2M2hwC
故使用W-MSA模块的计算量为:
4 h w C 2 + 2 M 2 h w C 4hwC^2 + 2M^2 hwC 4hwC2+2M2hwC
假设feature map的h、w都为112,M=7,C=128,采用W-MSA模块相比MSA模块能够节省约40124743680 FLOPs:
2 ( h w ) 2 C − 2 M 2 h w C = 2 × 11 2 4 × 128 − 2 × 7 2 × 11 2 2 × 128 = 40124743680 2(hw)^2C-2M^2 hwC=2 \times 112^4 \times 128 - 2 \times 7^2 \times 112^2 \times 128=40124743680 2(hw)2C−2M2hwC=2×1124×128−2×72×1122×128=40124743680
2.3.2 SW-MSA
由于W-MSA只能关注窗口本身的内容,而不允许跨窗口连接,窗口与窗口之间是无法进行信息传递的。而SW-MSA通过移位窗口的方式,引入跨窗口连接的同时保持非重叠窗口的高效计算。如下图左所示为第 l
层使用W-MSA的方式,而在下一层 l+1
层必定为 SW-MSA的方式(如右图所示),两者合在一起作为一个2阶段的 Swin Transformer Block模块。两幅图进行对比可以发现:右图相对于左图进行了偏移,长宽分别偏移了
M
2
\frac{M}{2}
2M 个像素单位(每个窗口为
M
∗
M
M*M
M∗M 像素)。
可以看出,偏移后的图像窗口变为了9个。为了提高计算的效率,作者提出了一种更有效的批处理计算方法,即向左上方向循环移位
,如下图所示。在此转换之后,批处理窗口可能由特征映射中不相邻的几个子窗口组成,因此采用屏蔽机制(NLP中的masking 屏蔽不应该需要的信息)
将自关注计算限制在每个子窗口内。
为了更方便地理解左上方向循环移位的操作,这里将具体过程做了一个图,具体内容如下图所示。
从上图可以看出,原始图像在进行移位后,A部分移动到右下角,B部分移位到最右边,C部分移位到最下边。然后将每个部分进行合并合并为等同于移位前窗口大小的窗口。
注意:移位后的信息会产生乱序
,对于该问题,原文作者使用了Mask的方案。
3. 模型配置
最后,对Swin-Transformer各个版本的参数进行介绍。
其中,
win. sz 7x7
表示窗口大小为7x7
dim
表示feature map的channel深度(或者说token的向量长度)head
表示多头注意力模块中head的个数
总结
关于Swin-Transformer模型中大多数内容都已经详细介绍了。当然,还有部分不重要的内容以及如何与代码想匹配没有介绍。后续可能会出一篇文章专门介绍相关代码说明。如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。