Swin-Transformer详解

news2024/11/24 17:52:55

Swin-Transformer详解

  • 0. 前言
  • 1. Swin-Transformer结构简介
  • 2. Swin-Transformer结构详解
    • 2.1 Patch Partition
    • 2.2 Patch Merging
    • 2.3 Swin Transformer Block
      • 2.3.1 W-MSA
      • 2.3.2 SW-MSA
  • 3. 模型配置
  • 总结

0. 前言

Swin-Transformer是2021年微软研究院发表在ICCV上的一篇文章,并且已经获得ICCV 2021 best paper的荣誉称号。虽然Vision Transformer (ViT)在图像分类方面的结果令人鼓舞,但是由于其低分辨率特性映射和复杂度随图像大小的二次增长,其结构不适合作为密集视觉任务高分辨率输入图像的通过骨干网路。为了最佳的精度和速度的权衡,提出了Swin-Transformer结构。

论文名称Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
原论文地址: https://arxiv.org/abs/2103.14030
官方开源代码地址:https://github.com/microsoft/Swin-Transformer
Pytorch实现代码: pytorch_classification/swin_transformer
Tensorflow2实现代码:tensorflow_classification/swin_transformer

1. Swin-Transformer结构简介

如下图所示为:Swin-Transformer与ViT的对比结构。
在这里插入图片描述
从上图中可以看出两种网络结构的部分区别:

  1. 采样方式
    • Swin-Transformer开始采用4倍下采样的方式,后续采用8倍下采样,最终采用16倍下采样
    • ViT则一开始就使用16倍下采样
  2. 目标检测机制
    • Swin-Transformer中,通过4倍、8倍、16倍下采样的结果分别作为目标检测所用数据,可以使网络以不同感受野训练目标检测任务,实现对大目标、小目标的检测
    • ViT则只使用16倍下采样,只有单一分辨率特征

接下来,简单看下原论文中给出的关于Swin Transformer(Swin-T)网络的架构图。其中,图(a)表示Swin Transformer的网络结构流程,图(b)表示两阶段的Swin Transformer Block结构。注意:在Swin Transformer中,每个阶段的Swin Transformer Block结构都是2的倍数,因为里面使用的都是两阶段的Swin Transformer Block结构。
在这里插入图片描述

2. Swin-Transformer结构详解

首先,介绍Swin-Transformer的基础流程。

  1. 输入一张图片 [ H ∗ W ∗ 3 ] [H*W*3] [HW3]
  2. 图片经过Patch Partition层进行图片分割
  3. 分割后的数据经过Linear Embedding层进行特征映射
  4. 将特征映射后的数据输入具有改进的自关注计算的Transformer块(Swin Transformer块),并与Linear Embedding一起被称为第1阶段
  5. 与阶段1不同,阶段2-4在输入模型前需要进行Patch Merging进行下采样,产生分层表示。
  6. 最终将经过阶段4的数据经过输出模块(包括一个LayerNorm层、一个AdaptiveAvgPool1d层和一个全连接层)进行分类。

2.1 Patch Partition

Patch Partition结构是将图片数据进行分割成不重叠的M*M补丁。每个补丁被视为一个“标记”,其特征被设置为原始像素RGB值的串联。在论文中,使用4 × 4的patch大小,因此每个patch的特征维数为4 × 4 × 3 = 48。在此原始值特征上应用线性嵌入层(Linear Embedding),将其投影到任意维度(记为C)。

图1 Patch Partition 分割
图2 符号标识

注意:在实际操作中,Patch PartitionLinear Embedding通过一个二维的卷积层输出通道为Embedding维度卷积核大小为patch_sizestride大小为patch_size)实现。

2.2 Patch Merging

Patch Merging层主要是进行下采样,产生分层表示。随着网络的深入,通过Patch Merging层来减少令牌的数量。第一个补丁合并层将每组2 × 2相邻补丁的特征进行拼接,并在拼接后的4c维特征上应用线性层。这将令牌的数量减少2×2 = 4的倍数(分辨率的2倍降采样,长和宽分别变为原来的1/2),并将输出维度设置为2C。之后使用Swin Transformer块进行特征变换,分辨率保持在h8 × w8。这第一个块的补丁合并和特征转换被称为“第二阶段”。该过程重复两次,作为“阶段3”和“阶段4”,输出分辨率分别为h16 × w16h32 × w32。由上述的说明,可以得知:数据在经过Patch Merging层后,长宽变为原来的1/2,深度变为原来的2倍。

在这里插入图片描述

2.3 Swin Transformer Block

Swin Transformer Block 一般以2阶段的串联结构出现,在第一阶段使用Window based Multi-headed Self-Attention(W-MSA),第二阶段使用 Shifted Window based Multi-headed Self-Attention(SW-MSA),根据当前是奇数还是偶数的Swin Transformer Block来选择不同的自关注计算方式。

2.3.1 W-MSA

W-MSA全称为:Window based Multi-headed Self-Attention。从名字可以看出,W-MSA是一个窗口化的多头自注意力,与全局自注意力相比,减少了大量的计算量。直观上来说:假如说是4*4的数据,划分后每个窗口包括 M ∗ M M*M MM 块,这里假设 M = 2 M=2 M=2。如果进行MSA计算大概需要 ( 4 ∗ 4 ) 2 (4*4)^2 442的计算量,而进行W-MSA则大概需要 ( 2 ∗ 2 ) ∗ ( 2 ∗ 2 ) 2 (2*2)*(2*2)^2 22222。这样一对比瞬间计算的复杂度就降低了很多(当然上述只是为了方便简单的理解,下面就详细介绍W-MSA降低了多少复杂度)。

MSA (每个红框表示计算一次注意力)
W-MSA (红框大小表示计算注意力像素大小)
对于一个 $h*w*C$ 的图像,被分割后每个窗口包括 $M*M$ 块。则对应的MSA和W-MSA的计算如下式所示: $$ Ω(MSA)=4hwC^2 +2(hw)^2C \quad\quad\quad\quad\quad\quad \quad\quad \ \ \ \ \ \ (1) \\\ Ω(W-MSA)=4hwC^2 +2M^2hwC \quad\quad\quad\quad\quad\quad\quad (2) $$
  • h代表feature map的高度
  • w代表feature map的宽度
  • C代表feature map的深度
  • M代表每个窗口(Windows)的大小

注意:前者与长宽 h w 成二次关系,后者在 M 固定时为线性关系(默认为7)。

  • 首先介绍下Self-Attention的计算
    Self-Attention的公式如下所示:
    A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T d ) V Attention(Q,K,V)=SoftMax(\frac{QK^T}{\sqrt{d} })V Attention(Q,K,V)=SoftMax(d QKT)V

  • 计算Self-Attention的复杂度
    首先,Q、K、V的计算如下所示:
    Q h w ∗ C = X h w ∗ C ∗ W Q C ∗ C   K h w ∗ C = X h w ∗ C ∗ W K C ∗ C   V h w ∗ C = X h w ∗ C ∗ W V C ∗ C Q^{hw*C}=X^{hw*C}*W_Q^{C*C} \\\ K^{hw*C}=X^{hw*C}*W_K^{C*C} \\\ V^{hw*C}=X^{hw*C}*W_V^{C*C} QhwC=XhwCWQCC KhwC=XhwCWKCC VhwC=XhwCWVCC

    • X h w ∗ C X^{hw*C} XhwC 表示将所有像素(token)拼接在一起得到的矩阵(一共有hw个像素,每个像素的深度为C)
    • W Q C ∗ C W_Q^{C*C} WQCC W K C ∗ C W_K^{C*C} WKCC W V C ∗ C W_V^{C*C} WVCC 分别表示生成Q、K、V的变换矩阵

    因此,由矩阵复杂度计算公式可知Q、K、V的复杂度均为 h w ∗ C 2 hw*C^2 hwC2,此时总复杂度为 3 h w ∗ C 2 3hw*C^2 3hwC2
    然后,由Self-Attention的计算公式可知, Q K T QK^T QKT 的计算量如下所示:
    Q h w ∗ C K T ( C ∗ h w ) = A h w ∗ h w Q^{hw*C}K^{T(C*hw)} = A^{hw*hw} QhwCKT(Chw)=Ahwhw
    因此, Q K T QK^T QKT 的计算量为 C ∗ h w ∗ h w C*hw*hw Chwhw, 即 C ∗ ( h w ) 2 C*(hw)^2 C(hw)2 。忽略 d \sqrt{d} d S o f t M a x SoftMax SoftMax操作, A ∗ V A*V AV的计算量如下所示:
    A h w ∗ h w V h w ∗ C = A t t e n t i o n h w ∗ C A^{hw*hw}V^{hw*C} = Attention^{hw*C} AhwhwVhwC=AttentionhwC
    因此, A ∗ V A*V AV 的计算量为 h w ∗ C ∗ h w hw*C*hw hwChw, 即 C ∗ ( h w ) 2 C*(hw)^2 C(hw)2 。所以,Self-Attention公式的复杂度为 2 C ( h w ) 2 2C(hw)^2 2C(hw)2。Self-Attention总的复杂度为 2 C ( h w ) 2 + 3 h w ∗ C 2 2C(hw)^2+3hw*C^2 2C(hw)2+3hwC2

  • 计算MSA的复杂度
    多头注意力计算复杂度与自注意力复杂度仅缺少一个 ∗ V 0 *V_0 V0 的操作,因此总体复杂度缺少 h w ∗ C 2 hw*C^2 hwC2。所以MSA的复杂度为 2 C ( h w ) 2 + 4 h w ∗ C 2 2C(hw)^2+4hw*C^2 2C(hw)2+4hwC2

  • 计算W-MSA的复杂度
    对于W-MSA模块首先要将feature map划分到一个个窗口(Windows)中,假设每个窗口的宽高都是M,那么总共会得到 h M × w M \frac {h} {M} \times \frac {w} {M} Mh×Mw个窗口,然后对每个窗口内使用多头注意力模块。刚刚计算高为h,宽为w,深度为C的feature map的计算量为 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C 4hwC2+2(hw)2C,这里每个窗口的高为M宽为M,带入公式得:
    4 ( M C ) 2 + 2 ( M ) 4 C 4(MC)^2 + 2(M)^4C 4(MC)2+2(M)4C
    又因为有 h M × w M \frac {h} {M} \times \frac {w} {M} Mh×Mw个窗口,则:
    h M × w M × ( 4 ( M C ) 2 + 2 ( M ) 4 C ) = 4 h w C 2 + 2 M 2 h w C \frac {h} {M} \times \frac {w} {M} \times (4(MC)^2 + 2(M)^4C)=4hwC^2 + 2M^2 hwC Mh×Mw×(4(MC)2+2(M)4C)=4hwC2+2M2hwC
    故使用W-MSA模块的计算量为:
    4 h w C 2 + 2 M 2 h w C 4hwC^2 + 2M^2 hwC 4hwC2+2M2hwC
    假设feature map的h、w都为112,M=7,C=128,采用W-MSA模块相比MSA模块能够节省约40124743680 FLOPs:
    2 ( h w ) 2 C − 2 M 2 h w C = 2 × 11 2 4 × 128 − 2 × 7 2 × 11 2 2 × 128 = 40124743680 2(hw)^2C-2M^2 hwC=2 \times 112^4 \times 128 - 2 \times 7^2 \times 112^2 \times 128=40124743680 2(hw)2C2M2hwC=2×1124×1282×72×1122×128=40124743680

2.3.2 SW-MSA

由于W-MSA只能关注窗口本身的内容,而不允许跨窗口连接,窗口与窗口之间是无法进行信息传递的。而SW-MSA通过移位窗口的方式,引入跨窗口连接的同时保持非重叠窗口的高效计算。如下图左所示为第 l 层使用W-MSA的方式,而在下一层 l+1 层必定为 SW-MSA的方式(如右图所示),两者合在一起作为一个2阶段的 Swin Transformer Block模块。两幅图进行对比可以发现:右图相对于左图进行了偏移,长宽分别偏移了 M 2 \frac{M}{2} 2M 个像素单位(每个窗口为 M ∗ M M*M MM 像素)。
在这里插入图片描述
可以看出,偏移后的图像窗口变为了9个。为了提高计算的效率,作者提出了一种更有效的批处理计算方法,即向左上方向循环移位,如下图所示。在此转换之后,批处理窗口可能由特征映射中不相邻的几个子窗口组成,因此采用屏蔽机制(NLP中的masking 屏蔽不应该需要的信息)将自关注计算限制在每个子窗口内。
在这里插入图片描述
为了更方便地理解左上方向循环移位的操作,这里将具体过程做了一个图,具体内容如下图所示。
在这里插入图片描述
从上图可以看出,原始图像在进行移位后,A部分移动到右下角,B部分移位到最右边,C部分移位到最下边。然后将每个部分进行合并合并为等同于移位前窗口大小的窗口。
注意:移位后的信息会产生乱序,对于该问题,原文作者使用了Mask的方案。

3. 模型配置

最后,对Swin-Transformer各个版本的参数进行介绍。
在这里插入图片描述
其中,

  • win. sz 7x7 表示窗口大小为7x7
  • dim表示feature map的channel深度(或者说token的向量长度)
  • head表示多头注意力模块中head的个数

总结

关于Swin-Transformer模型中大多数内容都已经详细介绍了。当然,还有部分不重要的内容以及如何与代码想匹配没有介绍。后续可能会出一篇文章专门介绍相关代码说明。如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/599646.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据的存储(浮点型)

目录 浮点型存储的规则 1.前面我们已经学过了整形在数据中的存储是以原码,反码,补码的形式在内存中存储的,那么浮点数是以什么样的形式存储的呢? 接下来我们通过一段代码来观察——> int main() {int n 9;float* p (float*…

String AOP的使用

面向切面编程,面向特定方法编程,以方法为对象,在不修改原方法的基础上,对方法进行操作扩展等,底层是通过动态代理实现的 使用开发步骤: 1、创建一个类,加上Aspect声明为一个AOP切面类&#xff…

2023 重新开始

感觉搞 IT 的日子最近都有点不太好过。 早上接到公司电话说今天是一个大日子。 为什么是大日子,相信所有人都是懂的。这次公司将会经历一次非常大的裁员,很不幸也在列表中。不过感觉这个好像也没有什么关系。 因为早就在意料之中的事情,经历…

c语言之结构体(初阶)

目录 1:结构体类型的声明 2:结构体初始化 3:结构体成员访问 4:结构体传参 1:结构体类型的声明 1:为啥要有结构体,因为当我们描述一个复杂对象的时候,可能平时我们的一个类型不能…

常见的五种排序

🐶博主主页:ᰔᩚ. 一怀明月ꦿ ❤️‍🔥专栏系列:线性代数,C初学者入门训练,题解C,C的使用文章,「初学」C 🔥座右铭:“不要等到什么都没有了,才下…

批量提取某音视频文案(二)

牙叔教程 简单易懂 之前写过一篇 批量提取某音视频文案 , 在之前的教程中, 我用的是微软的语音转文字功能, 今天我们换个方法, 使用 逗哥配音 的 文案提取 功能 准备工作 下载视频和音频 我在github找到的是这个仓库 https://github.com/Johnserf-Seed/TikTokDownload 注意一…

VLANIF虚接口案例实践

1)拓扑 2)需求: -所有PC能够ping通自己的网关 -实现vlan间互通,实现所有的PC互通 3)配置步骤: 第一步:给pc配置IP地址 第二步:交换机创建vlan,做access和trunk -所有的交换机都配…

传统图形学对nerf的对比与应用落地

作者今年参加了China3DV的盛会,大会的发表、线下讨论、学者、工业界等等的交流着实对于Nerf有了更深的思考,以下是作者的抛砖引玉,如有不当之处敬请指出~ 传统图形学与nerf的简介: 传统图形学:显示表达几何表达方式&…

【CloudCompare教程】010:点云的裁剪功能(分段、裁剪、筛选)

本文讲解CloudCompare点云的裁剪功能(分段、裁剪、筛选)。 文章目录 一、点云的分段二、点云的裁剪三、点云的筛选一、点云的分段 加载案例点云数据,如下图所示: 选中图层点云,点击工具栏中的【分割】工具。 点击【激活线状选择】工具: 在需要裁剪的点云上绘制现状裁剪范…

使用免费的SSL证书将nginx配置的普通网站修改为HTTPS网站

一、需求说明 已经在Centos8系统中使用nginx搭建了网站;但是该网站没有实现HTTPS协议不安全;现需要将网站升级为HTTPS站点。 Linux环境对Nginx开源版源码下载、编译、安装、开机自启https://blog.csdn.net/xiaochenXIHUA/article/details/130265983?spm=1001.2014.3001.5501

chatgpt赋能python:Python交易接口简介

Python交易接口简介 Python作为一种高级编程语言,被广泛用于各种不同的领域,其中包括金融市场交易。Python交易接口提供了一种优雅而简单的方式,使得交易者能够方便地执行自己的交易策略。 什么是Python交易接口? Python交易接…

Effective第三版 中英 | 第2章 创建和销毁对象 | 考虑静态工厂方法而不是构造函数

文章目录 Effective第三版第2章 创建和销毁对象前言考虑静态工厂方法而不是构造函数 Effective第三版 第2章 创建和销毁对象 前言 大家好,这里是 Rocky 编程日记 ,喜欢后端架构及中间件源码,目前正在阅读 effective-java 书籍。同时也把自己…

基于SSM的人才招聘网站

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

模拟实现库函数:strcpy

目录 通过cplusplus网站了解函数功能: 断言assert的使用: 关于const: 本篇你最应该了解的内容: 通过cplusplus网站了解函数功能: 要模拟实现库函数,首先我们需要了解这个函数的参数,函数的…

主机加固介绍

最近公司做服务器安全,开始在市场了解产品,对这一块算是短暂的研究了一段时间,有一点心得给大家分享一下。 主机加固 最近主机加固的概念被炒得火热,主机加固的功能也正在被致力于服务器安全的相关人士所关注。 那么究竟什么是主…

【CVPR2022】CSWin Transformer详解

【CVPR2022】CSWin Transformer详解 0. 引言1. 网络结构2. 创新点2.1 Cross-Shaped Window Self-Attention2.2 Locally-Enhanced Positional Encoding(LePE) 3. 实验总结 0. 引言 Transformer设计中一个具有挑战性的问题是,全局自注意力的计算成本非常高&#xff0…

chatgpt赋能python:Python代码怎么敲:了解Python编程语言

Python代码怎么敲:了解Python编程语言 Python是一种高级编程语言,具有易读易用和高效性等优点。这使得Python成为了程序员的最佳选择,并成为了广泛应用于机器学习、Web开发、数据分析等领域。 Python代码敲法:小技巧 Python代码…

chatgpt赋能python:Python主要语句介绍

Python主要语句介绍 Python是一种广泛使用的高级编程语言,其语法简介、易于学习,并有丰富的库和工具支持。在Python中,主要的语句可以帮助开发人员快速编写代码,实现各种各样的任务。在本文中,我们将介绍Python中的主…

性能优化之高Log file sync等待实战案例分享

故障情况 AWR报告如下: 之后他们把大部分业务停掉后,Log file sync等待事件还是非常高。 通过对比昨天跟今天相同时间的AWR,在业务量小非常多的情况,等待时间还是高非常大。 诊断过程 log file sync等待事件首先判断当前系统IO…

“微商城”项目(1环境搭建)

开发工具分享: 百度网盘: 链接:https://pan.baidu.com/s/1lSsCjf-_zx1ymu6uZeG26Q?pwdhuan 提取码:huan 一、环境搭建说明 本项目服务端环境要求为 Windows Apache PHP MySQL。 下面介绍如何搭建环境,部署服…