Swin-Transformer算法解析

news2024/12/22 22:50:31

本文参考:

SwinTransformer:使用shifted window的层级Transformer(ICCV2021)_tzc_fly的博客-CSDN博客

https://zhuanlan.zhihu.com/p/430047908

目录

1 为什么在视觉中使用Transformer

2 Swin-Transformer算法总体架构

3 Swin-Transformer Block详述

3.1 transformer通用结构

3.2 Swin-Transformer Block 结构

3.2.1 为什么区分W-MSA模式和SW-MSA模式

3.2.2 W-MSA模式

3.2.3 SW-MSA模式

3.2.4 SW-MSA模式为什么要进行cyclic shift/reverse cyclic shift操作

4 Window Attention详述

4.1 MSA计算

4.2 相对位置编码

5 Patch Merging详述

6 算法总结


1 为什么在视觉中使用Transformer

计算机视觉一直由CNN主导,CNN作为各种视觉任务的backbone网络,这些体系结构的进步导致了性能的提高,从而广泛提升了整个领域。

NLP中网络体系结构的演变走了一条不同的道路,今天流行的体系结构是Transformer。Transformer是为sequence modeling而设计的,它以关注数据中的长期依赖关系而闻名。它在语言领域的巨大成功促使研究人员研究了它对计算机视觉的适应性,在图像分类和联合视觉-语言建模效果较号。

我们相信,跨越计算机视觉和自然语言处理的统一体系可以使这两个领域受益。Swin transformer在各种视觉问题上的出色表现能鼓励视觉和语言信号的统一建模。

2 Swin-Transformer算法总体架构

过程描述:

(1)输入图像为(batch=2, channel=3, height, width),经过resize等预处理后变为(2, 3, 224, 224)

(2)给图像赋予embedding,以patch_size=4*4*3(3通道的4*4像素)为一个patch(NLP中的token),reshape后变为(2,56*56,96),即56*56个token,每个token的embedding维数为96.

(3)在Swin-Transformer Block模块中,首先以window_size=7切分patches得到。然后在window内通过W-MSA和SW-MSA进行多头注意力机制的计算。接着再还原回(B,H,W,C)。所以在整个阶段只优化了特征权重,并没有改变维度。

(4)Patch Merging阶段主要是进行降采样(缩小分辨率),并且调整通道数(embedding维度)

(5)经过多次第(4)、(3)步之后变为(B, H/32, W/32, 8C),即(2, 7*7, 768),reshape之后变为(2, 768, 49),在最后一维avgpooling后再squeeze变为(2, 768)

(6)然后通过全连接将最后一维变为分类数量,比如(2, 10)

3 Swin-Transformer Block详述

3.1 transformer通用结构

Swin-Transformer的核心采用了transformer的通用结构如下:

Swin-Transformer使用了成对的transformer结构,一个是W-MSA,另一个是SW-MSA。

W-MSA和SW-MSA分别是具有规则配置和移动窗口配置的多头自注意力模块。

W-MSA: Window Multi-head Self Attention

SW-MSA: Shifted Window Multi-head Self Attention

3.2 Swin-Transformer Block 结构

当Block为W-MSA模式时,cyclic shift和reverse cyclic shift是不存在的。当Block为SW-MSA模式时,以上步骤全部存在。

3.2.1 为什么区分W-MSA模式和SW-MSA模式

首先为了降低自注意力计算的复杂性,我们通过将图片按照window切分并只在window内计算注意力。

在上图中,Layer L为W-MSA模式计算自注意力,此时的缺陷就是缺乏跨窗口连接从而限制了建模能力。

而在Layer L+1层中,窗口分区被移动,从而产生新的窗口,新窗口中的自注意力计算跨越了层l中以前窗口的边界,从而提供了它们之间的连接,显著增强了建模能力。

通过交替使用W-MSA和SW-MSA,使得自注意力计算仅局限在窗口内部,同时又允许跨窗口连接。

3.2.2 W-MSA模式

假如Block输入的x为(2, 56, 56, 96),对应(batch, height, width, embedding)信息。

首先,Window Partition根据window_size=7将x分为一个个window,后续在window内进行MSA。经过该步骤后数据维度为(2*8*8, 7* 7, 96)。2*8*8解释:2为样本数量,8为height切分window_size后的数量,另一个8为width切分window_size后的数量。

然后,Window Attention在7*7窗口内进行MSA,该过程只进行权重更新不会更改数据维度,输出还是(128, 49, 96)。

最后,将数据还原到输入图像大小,即(2, 56, 56, 96)。

3.2.3 SW-MSA模式

假如Block输入的x仍为(2, 56, 56, 96),依然对应(batch, height, width, embedding)信息。

首先,通过cyclic shift对窗口元素进行移位,得到的数据维度仍然为(2, 56, 56, 96)

其次,Window Partition根据window_size=7将x分为一个个window,后续在window内进行MSA。经过该步骤后数据维度为(2*8*8, 7* 7, 96)。

然后,Window Attention在7*7窗口内进行MSA,该过程只进行权重更新不会更改数据维度,输出还是(128, 49, 96)。

接着,将数据还原到输入图像大小,即(2, 56, 56, 96)。

最后,通过reverse cyclic shift对之前移位的窗口进行反向移位操作,得到的(2, 56, 56, 96)。

3.2.4 SW-MSA模式为什么要进行cyclic shift/reverse cyclic shift操作

如上图所示,在SW-MSA模式下移动窗口会增加窗口的数量,从(h/M, w/M)个变成(h/M+1, w/M+1),如上图从4个增加到了9个,并且有些窗口大小是小于M*M的。

解决方法一:将大小不足M*M的窗口填充到M*M的大小,并在计算注意力时屏蔽这些填充值。但是这种native的做法会增加很多计算量(比如计算的窗口数量从2*2变成3*3)。

解决方法二:向左上角循环移位,在这个移位之后,一个批处理窗口可能由几个在特征图中不相邻的子窗口组成,使用mask机制将自注意力计算限制在每个子窗口内。通过循环移位,批处理窗口的数量与常规窗口分区的数量相同。

将图中浅色ABC windows转移到深色ABC的填充部分,这个操作可以用两次torch.roll实现,第一次将第一行移动到最后一行,第二次将第一列移动到最后一列。从而使得最后的feature map依然为2*2的windows,保持原有的计算量,然后再使用图中紫色部分的masked MSA进行计算。结束之后,再reverse cyclic shift。

当我们做cyclic shift 后有:

对于第二个特征图,我们可以很几何地按照之前地window划分方式(标准的2*2个窗口)去计算,但是对于3个窗口:即(4+6),(2+8),(1+3+7+9)的attention会混在一起,所以需要在计算每个窗口时进行mask MSA。

以窗口(4+6)为例,假设该窗口一共有4个patch:

当自注意力计算时,重点在于QKT,为了保证信息只在cyclic shift前的window内交互,我们要确保只存在属于window 4和window 4的两个patch计算attention,换言之就是在计算注意力时候只有行和列属于相同编号的元素才保留,其他元素都mask。

至此,我们利用shift后的feature,和上面说的mask结合,就能得到正确的MSA结果。

我们最后把shift还原,即reverse shift。

4 Window Attention详述

传统的Transformer都是基于全局来计算注意力,而swin-transformer则将注意力的计算限制在每个窗口内,从而减少了计算量。

计算公式为:

Swin-transformer使用attention机制与原始计算的区别在于公式中的QK^{T}计算后加入了相对位置编码。

4.1 MSA计算

假设输入数据维度为(128, 49, 96)。

首先,通过全连接将embedding维度乘以3,得到(128, 49, 288)。

然后,将维度变为(128, 49, 3, num_heads, 288/num_heads),继续变为(3, 128, num_heads, 49, 288/num_heads),然后q, k, v分别为(128, num_heads, 49, 288/num_heads)=(128, 3, 49, 96)。

接着,qk^{T}​​​​​​​得到attention值(128, 3, 49, 49),加上相对位置编码信息(1, 3, 49, 49)后,再进行softmax后乘以v得到最终值,输出维度仍然为(128, 49, 96)。

4.2 相对位置编码

此处相对位置编码并不是固定的,是需要训练的。在模型定义中初始化相对位置编码参数,在后续训练中更新这部分参数。

首先我们计算每个token的相对位置索引信息,这个是一次性的,和window_size相关。相对位置索引=基准点的绝对位置索引-该点的绝对位置索引。对于2*2的window来说,基准点的绝对位置索引分别为(0,0),(0,1),(1,0),(1,1)。

得到了相对位置索引之后,就可以根据索引拿到对应的编码权重值。

5 Patch Merging详述

降采样操作,用于缩小分辨率,调整通道数进而形成层次化的设计。

每次降采样时在行和列方向间隔2选取元素,然后在embedding维度拼接。因为H,W各缩小为1/2,所以embedding维度会变成原先的4倍。最后再通过一个全连接调整embedding维度为原来的2倍。

整体流程如下:

6 算法总结

(1)时间复杂度减少,主要是在于分别只在window内计算注意力

(2)为了融合不同窗口之间的信息,采用shifted window划分策略,不用于过去的sliding window这种native的方式,shifted window的特点在于cyclic shift + masked MSA + reverse cyclic shift,这实现了无padding且不增加窗口数量的情况下达到sliding window的效果

(3)层级的架构可以考虑不同尺度的window,从而获得多尺度信息

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/185916.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C# 源码 等值线(等高线)云图绘制 ,图上含等高线数值

C# 源码 数据格式为XYZ数据,XY为坐标,Z为对应的值 X Y Z -0.671053 -0.850000 83.330742 -0.671053 -0.850000 85.469604 -0.671053 -0.760526 89.225899 -0.671053 -0.760526 86.994576 -0.671053 -0.760526 86.994576 -0.671053 -0.760526 89.225899 -…

【解读】《云事件响应框架》:云服务用户响应和管理事件首选指南

微信搜索”国际云安全联盟“,回复关键词“云事件”下载本报告 当今互联时代,全面的事件响应策略对于需要管理与降低风险的组织必不可少。然而,在基于云的基础设施和系统的事件响应策略方面,部分由于云的责任共担特性,…

sql的四种连接——左外连接、右外连接、内连接、全连接

一、内连接 满足条件的记录才会出现在结果集中。 二、 左外连接(left outer join,outer可省略) 左表全部出现在结果集中,若右表无对应记录,则相应字段为NULL 举例说明: 客户表: 订单表&#x…

2023年2月系统集成项目管理工程师认证【报名入口】

系统集成项目管理工程师是全国计算机技术与软件专业技术资格(水平)考试(简称软考)项目之一,是由国家人力资源和社会保障部、工业和信息化部共同组织的国家级考试,既属于国家职业资格考试,又是职…

Qt 根据参数 自动生成vs 工程

一,需求 给算法部门提供一套代码框架,让其写算法dll。为了使dll能融入主工程,其框架对格式有一定要求,为了增加算法部门的快发效率,因此开发一个小工具,用于自动生成这套框架。 运行后,只需要…

cdh+dolphinscheduler开启kerberos

搭建环境多台linux主机搭建集群CDH 6.3.2 (Parcel)版本dolphinscheduler 1.3.2版本本流程在CDH已搭建完成并可正常使用后,开启kerberos功能dolphinscheduler用于大数据任务管理与执行,是很不错的任务调度平台,是否提前部署均可开启kerberos目…

数据结构与算法:二叉树的学习

1.了解树形结构 1.概念 树是一种非线性的数据结构,它是由n(n>0)个有限结点组成一个具有层次关系的集合。把它叫做树是因为它看起来像一棵倒挂的树,也就是说它是根朝上,而叶朝下的。它具有以下的特点: …

《Unity Shader 入门精要》 第7章 基础纹理

第7章 基础纹理 纹理最初的目的就是使用一张图片来控制模型的外观。使用纹理映射技术(texture mapping),我们可以把一张图黏在模型表面,逐纹素(texel)(纹素的名字是为了和像素进行区分)地控制模型的颜色。…

爱了爱了,这是什么神仙级Apache Dubbo实战资料,清晰!齐全!已跪!

都2026年了 还没有用过Dubbo? Dubbo是国内最出名的分布式服务框架,也是 Java 程序员必备的必会的框架之一。Dubbo 更是中高级面试过程中经常会问的技术,面试的时候是不是经常不能让面试官满意?无论你是否用过,你都必须…

Postman(2): postman发送带参数的GET请求

发送带参数的GET请求示例:微信公众号获取access_token接口,业务操作步骤1、打开微信公众平台,微信扫码登录:https://mp.weixin.qq.com/debug/cgi-bin/sandbox?tsandbox/login2、打开微信开放文档,找到获取access_toek…

运放电路中各种电阻的计算-运算放大器

运放电路中各种电阻的计算 在学习运算放大器电路的时候,经常需要计算电路的: 输入阻抗Ri, 输出阻抗Ro, 同相端对地等效电阻RP, 反相端对地等效电阻RN, 这些参数很重要,在学习运放相关电路的时候经常要用到&#…

mysql8+mybatis-plus 查询json格式数据

sql 测试json表CREATE TABLE testjson (id int NOT NULL AUTO_INCREMENT,json_obj json DEFAULT NULL,json_arr json DEFAULT NULL,json_str varchar(100) DEFAULT NULL,PRIMARY KEY (id) ) ENGINEInnoDB AUTO_INCREMENT2 DEFAULT CHARSETutf8mb4 COLLATEutf8mb4_0900_ai_ci;IN…

API 网关策略二三事

作者暴渊,API7.ai 技术工程师,Apache APISIX Committer。 近些年随着云原生和微服务架构的日趋发展,API 网关以流量入口的角色在技术架构中扮演着越来越重要的作用。API 网关主要负责接收所有请求的流量并进行处理转发至上游服务,…

【数据结构和算法】认识队列,并实现循环队列

上接前文,我们学习了栈的相关知识内容,接下来,来认识一个与栈类似的,另一种特殊的线性表,队列,本文目的是了解并认识队列这一概念,并实现循环队列 目录 一、认识队列 1.队列的概念 2.队列的实…

入门力扣自学笔记232 C++ (题目编号:1669)

1669. 合并两个链表 题目: 给你两个链表 list1 和 list2 ,它们包含的元素分别为 n 个和 m 个。 请你将 list1 中下标从 a 到 b 的全部节点都删除,并将list2 接在被删除节点的位置。 下图中蓝色边和节点展示了操作后的结果: 请…

Docker-harbor私有仓库部署与管理

目录 前言 一、Harbor概述 二、Harbor的特性 三、Harbor的构成 四、Harbor构建Docker私有仓库 环境配置 部署Harbor服务 物理机访问server IP 添加项目并填写项目名称 通过127.0.0.1来登陆和推送镜像 其他客户端上传镜像到Harbor 维护管理Harbor 创建Harbor用户 …

JavaWeb_JavaScript

一、简介 JavaScript 是一门跨平台、面向对象的脚本语言,而Java语言也是跨平台的、面向对象的语言,只不过Java是编译语言,是需要编译成字节码文件才能运行的;JavaScript是脚本语言,不需要编译,由浏览器直接…

GPT-3是精神病患者吗?从心理学角度评估大型语言模型

原文链接:https://www.techbeat.net/article-info?id4494 作者:seven_ 20世纪60年代,麻省理工学院人工智能实验室的Joseph Weizenbaum编写了第一个自然语言处理(NLP)聊天机器人ELIZA[1],ELIZA通过使用模式…

linux Redis 集群搭建

在单例模式下继续执行,新增文件夹将之前解压后的文件复制到新增的文件夹中修改配置文件,并放入bin中bind 10.88.99.251(ip设置)protected-mode yes(默认yes,开启保护模式,限制为本地访问&#x…

ASEMI整流桥GBU808在选型的过程中需要注意几点

编辑-Z 型号:GBU808 最大重复峰值反向电压(VRRM):800V 最大RMS电桥输入电压(VRMS):560V 最大直流阻断电压(VDC):800V 最大平均正向整流输出电流&#xf…