爆火的Swin Transformer到底是什么

news2024/11/27 2:22:12

文章目录

  • 一、名称解读
  • 二、ViT回顾
  • 三、Swin Transformer vs ViT
  • 四、Swin Transformer结构
    • 4.1 Patch Merging模块
    • 4.2 相对位置编码
    • 4.3 shifted window
  • 五 总结

  爆火的Swin Transformer究竟是个啥?今天本篇文章系统讲解下Swin t 结构、优点、位置编码、移位窗口shifted window,并附上部分代码的解释

论文名称:《Swin Transformer:Hierarchical Vision Transformer using Shifted Windows》,简称Swin Transformer、Swin T
论文下载:https://arxiv.org/pdf/2103.14030.pdf
代码地址:https://github.com/microsoft/Swin-Transformer

一、名称解读

  其实论文名字就很好的点出了Swin Transformer的特点,Swin是指Shifted window,使用移位窗口的多层级视觉Transformer,重点在于Hierarchical多层和Shifted Windows移位窗口

二、ViT回顾

  ViT是2020年Google团队提出的将Transformer应用在图像分类的模型,虽然不是第一篇将transformer应用在视觉任务的论文,但是因为其模型“简单”且效果好,可扩展性强(scalable,模型越大效果越好),成为了transformer在CV领域应用的里程碑著作。

图一 ViT结构

  ViT将二维图片切分为patch,然后序列输入,经过一个线性层(全连接),再加上position,对应图片是左边的输入Patch+Position embedding,在输入的下面还有一行小字Extralearnable [class] embedding,它是特殊字符CLS,借鉴Bert,根据它的输出做分类的判断(transformer的输入和输出维度相同,但是只要一个分类结果,所以增加了一个cls token)。然后经过一个标准的Transformer Encoder和MLP头,输出结果。

三、Swin Transformer vs ViT

Transformer应用到图像领域主要有两大挑战:

  • 视觉实体变化大,在不同场景下视觉Transformer性能未必很好
  • 图像分辨率高,像素点多,Transformer基于全局自注意力的计算导致计算量较大

  遇到高分辨率的图像,采用ViT处理会产生极大的计算复杂。而Swin Transformer的复杂度会比ViT低,主要通过下面两点降低计算复杂度:
  (1)分层特征图
  (2)局部窗口计算注意力

图二

   (a) 图表示Swin Transformer 通过合并更深层的图像块(以灰色显示)来构建分层特征图,并且由于仅在每个局部窗口内计算自注意力,因此对输入图像大小具有线性计算复杂度(红色)。
   (b) 图是ViT生成单个低分辨率的特征图,并且由于全局自注意力的计算,输入图像大小具有二次方计算复杂度。

四、Swin Transformer结构

图三 Swin Transformer结构

  Swin Transformer的结构还是比较简洁的,它的输入和ViT类似,也是将图片切patch序列输入。用4x4的大小切分成patch,则每个patch是4x4x3=48,输出是 H 4 × W 4 × 48 \frac{H}{4} \times \frac{W}{4} \times 48 4H×4W×48。在输入阶段,位置编码用了绝对位置编码,可以加也可以不加,作者代码中可通过self.ape参数进行选择。输入的二维图片切分patch,是通过卷积实现,将224x224x3的图片转换为56x56x96,输入部分的代码如下:

def forward(self, x):
    B, C, H, W = x.shape
    # FIXME look at relaxing size constraints
    assert H == self.img_size[0] and W == self.img_size[1], \
        f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

    # 1、self.proj卷积:kernel=4,stride=4,224x224x3->56x56x96,公式(w+2p-k)/s + 1
    # 2、flatten:将二维转为一维,[N, 96, 3136]
    # 3、transpose:维度转换,[N, 3136, 96]
    x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
    if self.norm is not None:
        x = self.norm(x)
    return x

  stage1:线性Embedding -> 2x Swin Transformer Block,输出 H 4 × W 4 × C \frac{H}{4} \times \frac{W}{4} \times C 4H×4W×C
  stage2:Patch Merging -> 2x Swin Transformer Block,输出 H 8 × W 8 × 2 C \frac{H}{8} \times \frac{W}{8} \times 2C 8H×8W×2C
  stage3:Patch Merging -> 6x Swin Transformer Block,输出 H 16 × W 16 × 4 C \frac{H}{16} \times \frac{W}{16} \times 4C 16H×16W×4C
  stage4:Patch Merging -> 2x Swin Transformer Block,输出 H 32 × W 32 × 8 C \frac{H}{32} \times \frac{W}{32} \times 8C 32H×32W×8C

4.1 Patch Merging模块

  Patch Merging的作用是降维、升通道,例如stage2,先经过Patch Merging,再经过两个Transformer Block结构,已知Transformer输入和输入维度大小不变,即输入Transformer结构的维度是 H 8 × W 8 × 2 C \frac{H}{8} \times \frac{W}{8} \times 2C 8H×8W×2C,但是输入Patch Merging的维度是 H 4 × W 4 × C \frac{H}{4} \times \frac{W}{4} \times C 4H×4W×C,说明Patch Merging把输入缩小了一半,维度增加了一倍。

图四 Patch Merging流程图

实现流程:
  (1)在行方向和列方向上,间隔2选取元素
  (2)拼接成张量,通道变成4 * dim
  (3)self.reduction:全连接,通道变成2*dim

 def forward(self, x):
    """
    x: B, H*W, C
    """
    H, W = self.input_resolution
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"
    assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

    x = x.view(B, H, W, C)

    x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
    x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
    x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
    x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
    x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
    x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

    x = self.norm(x)
    x = self.reduction(x)

    return x

  看代码似乎还是不太明白,这里(1)在行方向和列方向上,间隔2选取元素,用了::,没用过这个方法的,可以看下这个例子,间隔n位取数

图五 ::用法

4.2 相对位置编码

  Swin Transformer的相对位置编码并不是用在输入部分,而是在Attention计算过程中,QK计算得到attn张量,再加上位置编码。
  若特征图按7x7的窗口划分,每个窗口有49个token,他们之间是有一定的位置关系。下面为了方便展示,用2x2大小的图表示。例如2x2的特征图,经过attention的QK计算,变成4x4,那么相对位置编码的大小也是4x4,图六就是相对位置编码。

图六 相对位置编码

  图六这个编码是怎么得到的?我们一步步详细解释。
(1)绝对位置编码:对于2x2的特征图,它的绝对位置编码是二维的,行和列用0和1表示,如图七。

图七 绝对位置

(2)以不同颜色为起点,其他像素的相对位置如图八

图八 相对位置

(3)将(2)中的图拉直拼接,得到一个4x4大小的图,如图9

图九 拉直拼接

(4)可以发现图9中的数值,既有0、1,也有负数-1,为了使值都大于等于0,行列都加上(M-1),如图十

图十

(5)图十中的位置都是二维,为了得到一维的结果,可以想到的一个方法是将行和列加起来(不同位置数值相同,方法不可取,如图11),另外一个方法是行坐标都乘上2M-1,再和列相加,如图12,得到的结果就跟图六相同。

图11 行列相加
图12

代码实现,可以把代码拷贝到脚本里跑下,跟上面的图做对比:

window_size = [2, 2]
# coords_h、coords_w分别用来代表行列的值,绝对位置
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])   # tensor([0, 1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) #torch.Size([2, 2, 2])
coords_flatten = torch.flatten(coords, 1)    # torch.Size([2, 4])
"""
tensor([[0, 0, 1, 1],
        [0, 1, 0, 1]])
"""
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  
# torch.Size([2, 4, 4]),得到行列的相对位置
"""
tensor([[[ 0,  0, -1, -1],
         [ 0,  0, -1, -1],
         [ 1,  1,  0,  0],
         [ 1,  1,  0,  0]],

        [[ 0, -1,  0, -1],
         [ 1,  0,  1,  0],
         [ 0, -1,  0, -1],
         [ 1,  0,  1,  0]]])
"""
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  
# torch.Size([4, 4, 2])
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1   # 横坐标乘上(2M-1)
relative_position_index = relative_coords.sum(-1)  # 行、列相加
"""
tensor([[4, 3, 1, 0],
        [5, 4, 2, 1],
        [7, 6, 4, 3],
        [8, 7, 5, 4]])
"""

4.3 shifted window

  Transformer block用了两种attention,W-MSA(window-multi-head self attention modules,常规attention)和SW-MSA(shifted window-multi-head self attention modules), 图13

图13 transformer

  假设在原图中,被分为4个窗口,向左向下移位两格,变成9个窗口,图14。

图14 移位窗口

  前面提到,attetion只在各个小窗口中计算,那么原本需要4个q、k、v计算的attention,变成了9个q、k、v。在实际代码中,作者通过对特征图移位,并给 Attention 设置 mask 来间接实现的。能在保持原有的 window 个数下,最后的计算结果等价。这是什么意思?我们给九个移位后的窗口用0-8进行编码,如下面图15的左图,再将窗口进行移位,重新拼接成只有四个窗口的图,如右图。
.

图15

  你可能会说,attetion只在小窗口中计算,那例如把窗口5和3拼接成一个窗口,就不能实现窗口attetion了。这里,作者通过设置合理的 mask,让Shifted Window Attention在与Window Attention相同的窗口个数下,达到等价的计算结果,图16,窗口4,拉伸成一维,进行QK计算得到attention向量,而对比5/3窗口,通过QK计算后,其实得到的应该是[5, 53, 5, 53]、[35, 3, 35, 3]、[5, 53, 5, 53]、[35, 3, 35, 3],但是5窗口attention只计算自己,就mask掉,对应图片灰色格子无数字部分。

图16 mask attention

五 总结

  Swin Transformer的两个重点就是位置编码和mask attention,在四中做了详细的介绍。作者提供的代码中,Swin t 可以实现很多任务,分类、目标检测、分割、半监督、特征蒸馏等。
  如果文章对您有所帮助,记得点赞、收藏、评论探讨✌️

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1167022.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

材料高温环境电磁参数测试系统 (1GHz-500GHz)

材料高温环境 电磁参数测试系统 (1GHz-500GHz) 材料高温环境电磁参数测试系统测试频率范围可达1GHz~500GHz,最高测试温度达1200℃,可实现高温环境下材料复介电常数、复磁导率、反射率等参数测试。系统由矢量网络分析…

分享四款简单实用的视频下载工具

今天来和大家分享几个视频下载工具。毕竟,在这个自媒体时代,非常多的小伙伴有下载视频的需求,话不多说,直接上干货! 1.Downni 这是一个在线视频下载网站。它支持几乎所有的主流视频平台;然后它还支持多种分…

UNI-APP中如何通过配置访问代理,解决跨域问题

主要思路 通过配置manifest.json文件中的h5选项来完成设置h5是一级属性。 修改manifest.json文件 以下两种都能满足要求 "h5" : {"devServer" : {"https" : false,"port" : 8081,//见备注1"disableHostCheck" : true,&q…

四旋翼无人机PID控制Simulink仿真

底部有完整文件地址 整体采用内外环方式对四旋翼的位置和姿态进行控制 Simulink整体模型图 Matlab版本:R2022a 姿态控制效果 滚转角 ϕ \phi ϕ: 俯仰角 θ \theta

**LEEDCODE 498对角线遍历

class Solution { public:vector<int> findDiagonalOrder(vector<vector<int>>& mat) {int n mat.size();int m mat[0].size();std::vector<int> a;for(int i 0; i < mn-1; i){// 偶数 下往上if(i % 2 0){// 起点 x min(i, n - 1) …

材质之选:找到适合你的地毯

当谈到家居装饰时&#xff0c;地毯是一个经常被忽视的重要元素。但事实上&#xff0c;地毯在家居中扮演了至关重要的角色&#xff0c;不仅可以增加舒适感&#xff0c;还可以改善室内的整体氛围。在这篇文章中&#xff0c;我们将探讨地毯的选择、尺寸、形状和材质&#xff0c;以…

Domino中和邮件安全有关的SPF、DKIM介绍

大家好&#xff0c;才是真的好。 首先&#xff0c;偷偷告诉大家一个非常好的消息&#xff0c;2023年12月7号上午10点&#xff08;美国东部时间&#xff09;&#xff0c;HCL将正式发布Domino 14.0和Sametime 12.0.2版本&#xff01; 发布时间居然能精确到几点钟&#xff0c;很…

同城信息发布小程序制作

信息发布系统的应用较为广泛&#xff0c;如同城信息系统、企业获取各业线索、内部交流、论坛社区小程序、校园交流平台等&#xff0c;如果运营得到往往能起到不错的效果。 它可以连接供需双方&#xff0c;用户获取和发布需求信息、参与互动交流&#xff0c;高效打造有竞争力的…

去库存化进入尾声,内存模组厂加速复苏 | 百能云芯

近期&#xff0c;半导体市场开始传递复苏信号&#xff0c;国内大型半导体制造商纷纷提到明年将迎来多款消费性电子产品的需求回升&#xff0c;这表明半导体设计库存的去化即将进入尾声。另一个重要的迹象是内存市场&#xff0c;自去年以来&#xff0c;国外许多记忆体制造商面临…

数据库-引擎

存储引擎&#xff1a; mysql当中数据用各中不同的技术存储在文件中&#xff0c;每一种技术都使用的是不同的存储机制&#xff0c;索引技巧 索引水平&#xff0c;以及最终提供的不同功能和能力&#xff0c;这些就是我们说的引擎。 功能&#xff1a; mysql将数据存储在文件系统…

vscode设置保存后,自动格式化代码

第一步&#xff1a;打开setting.json文件 第二步&#xff1a;在setting.json中加入以下代码 "editor.formatOnType": true, "editor.formatOnSave": true, "editor.formatOnPaste": true

文心一言 VS 讯飞星火 VS chatgpt (127)-- 算法导论11.2 2题

二、用go语言&#xff0c;位向量(bit vector)是一个仅包含0和1的数组。长度为m的位向量所占空间要比包含m个指针的数组少得多。请说明如何用一个位向量来表示一个包含不同元素(无卫星数据)的动态集合。字典操作的运行时间应为 O(1)。 文心一言&#xff0c;代码正常运行&#x…

企业级人像背景分割解决方案

人工智能和机器学习已经在各个行业领域展现出了巨大的潜力。其中&#xff0c;人像背景分割技术作为一项具有广泛应用价值的解决方案&#xff0c;正在受到越来越多企业的关注。美摄科技&#xff0c;作为一家专注于图像和视频处理的人工智能公司&#xff0c;以其卓越的技术实力和…

携手海外网红营销:出海品牌2023黑色星期五的制胜法宝

作为一年一度的购物狂欢节&#xff0c;2023年黑色星期五即将到来。随着全球化的加速&#xff0c;越来越多的品牌开始将目光投向海外市场&#xff0c;寻求在这个特殊的时刻实现销售的飙升。其中&#xff0c;海外网红已经成为数字营销的中坚力量&#xff0c;能够为品牌带来广泛的…

Linux离线安装MySQL8报缺少perl包问题

前言&#xff1a;Linux在线安装MySQL是比较简单的&#xff0c;这里主要介绍离线安装 linux版本为CentOS7&#xff0c;具体为&#xff1a;CentOS-7-x86_64-DVD-2009.iso mysql版本为8&#xff0c;具体为&#xff1a;mysql-8.2.0-1.el7.x86_64.rpm-bundle.tar 准备工作 安装之前…

Excel根据给定值,锁定所在行

使用方法 使用函数&#xff1a;MATCH(给定的值,查找的列) 适用示例 在表格中查询“2023/1/15”所在的行&#xff1a; 最终结果&#xff1a;

windows mysql安装

1、首先去官网下载mysql安装包&#xff0c;官网地址&#xff1a;MySQL :: Download MySQL Community Server 2&#xff1a;把安装包放到你安装mysql的地方&#xff0c;然后进行解压缩&#xff0c;注意&#xff0c;解压后的mysql没有配置文件&#xff0c;我们需要创建配置文件 配…

上传LaTeX版本的NeurIPS文章到arXiv总是Failed的解决方案

往arXiv上传NeurIPS模版文章时&#xff0c;一直出现两处报错&#xff0c;一处是下图中的图片错误&#xff1a; 但是&#xff0c;我怀疑是不是图片并排放置的minipage不可用&#xff0c;于是改成了正常的图片形式来测试&#xff1a; 仍然是相同的错误&#xff0c;于是我又尝试去…

一个使用uniapp+vue3+ts+pinia+uview-plus开发小程序的基础模板

uniappuviewPlusvue3tspiniavite 开发基础模板 使用 uniapp vue3 ts pinia vite 开发基础模板&#xff0c;拿来即可使用&#xff0c;不要删除 yarn.lock 文件&#xff0c;否则会启动报错&#xff0c;这个可能和 pinia 的版本有关&#xff0c;所以不要随意修改。 拉取代码…

Redis中String类型的命令

目录 Redis中的内部编码 redis的数据结构和内部编码 Redis中的String类型 String类型的常见命令 set get mget mset String类型的计数命令 incr incrby decr incrbyfloat 其他命令 append getrange setrange strlen String类型的内部编码 Redis中的内部编码…