【Paper Note】Swin Transformer: Hierarchical ViT using Shifted Windows

news2025/3/13 3:08:22

Swin Transformer: Hierarchical ViT using Shifted Windows

  • 概述
  • 核心思想
  • 整体结构
    • 名词解释
    • 与vit区别
  • 模型处理过程
    • 概括
    • Patch Embedding
    • BasicLayer
      • Patch Merging
      • Swin Transform Block
        • Window Attention
        • Shifted Window Attention
        • 小结
  • 模型使用及代码
    • 模型使用
      • 环境配置
      • SwinT
    • 代码
      • Patch Embedding
      • Patch Merging
      • Mask

概述

1.SwinTransformer想设计一个可以作为密集预测任务的Transformer Backbone,其采用PatchMerging的策略,构建了层次化的特征,使得其可以作为密集预测任务的Backbone。
2.同时考虑到密集预测任务中,tokens数目太多导致计算量过大的问题,其采用一种在local window内部计算Self-Attention的机制去降低计算复杂度,使得整体计算复杂度由O(N^2)降低至O(N)水平。
3.为了弥补Local Self-Attention带来了远程依赖关系缺失的问题,其创新性地采用了Shift Window操作,引入了不同window之间的关系,并且在精度以及速度上都超越了简单的Sliding Window的方法。

核心思想

Swin Transformer就是想让 Vision Transformer像卷积神经网络一样,也能够分成几个 block(分组计算),也能做层级式的特征提取,从而导致提出来的特征有多尺度的概念

分组计算的复杂度优势

  • 原生 Transformer 对 N 个 token 做 Self-Attention ,复杂度为 NxN,
    0 Swin Transformer 将 N 个 token 拆为 N/n 组, 每组 n (n设为常数)个token 进行计算,复杂度降为 [N*nxn] ,考虑到 n 是常数,那么复杂度其实为N。

分组计算导致的问题和解决方式

  • 其一是分组后 Transformer 的视野局限于 n 个token,看不到全局信息
    • 对于问题一,Swin Transformer 的解决方案即 Hierarchical,每个 stage 后对 2x2 组的特征向量进行融合和压缩(空间尺寸HxW变成0.5Hx0.5W,特征维度C->4C->2C ),这样视野就和 CNN-based 的结构一样,随着 stage 逐渐变大。
  • 其二是组与组之间的信息缺乏交互。
    • 对于问题二,Swin Transformer 的解决方法是 Shifted Windows,

整个Swin Transformer其实最重要的就两个点:

  • 相对位置信息
    • 核心点在于可以把每种相对位置信息和att对应的一行信息对应上
  • 移动窗口注意力机制
    • 移动窗口注意力机制核心点在于mask,mask矩阵的生成是通过窗口索引tensor相减得到的;

综合就是两个优点:

  1. 相比于ViT,Swin Transfomer 计算复杂度大幅度降低,具有输入图像大小线性计算复杂度
  2. Swin Transformer随着深度加深,逐渐合并图像块来构建层次化Transformer,可以作为通用的视觉骨干网络,应用于图像分类、目标检测和语义分割等任务。

Swin transformer和viT的架构不同之处:
在这里插入图片描述

整体结构

在这里插入图片描述
上图有四个stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。
流程解释:

  • 在输入开始的时候,做了一个Patch Embedding,将图片切成一个个图块,并嵌入到Embedding。
  • 在每个Stage里,由Patch Merging和多个Block组成
  • Patch Merging模块主要在每个Stage一开始降低图片分辨率
  • Block具体结构如右图所示,主要是LayerNorm(LN),MLP(Multilayer Perceptron多层感知器),Window Attention 和 Shifted Window Attention组成

名词解释

假设输入图片的尺寸为224X224,先划分成多个大小为4x4像素的小片,每个小片之间没有交集。

  • patch:224/4=56,那么一共可以划分56x56个小片。每一个小片就叫一个patch,
  • token:每一个patch将会被对待成一个token。所以patch=token。
  • window:而一张图被划分为7x7个window,每个window之间也没有交集。那么每个window就会包含8x8个patch

与vit区别

  1. patch大小:与ViT一样对于输入的图像信息先做一个PatchEmbed操作将图像进行切分后打成多个patches传入进行后续的处理,但与ViT不同的是初始的切分不再以16 * 16的大小,而是以4 * 4的大小(为了看到更多细节)
  2. PatchMerging且后续通过PatchMerging的操作不断增加尺寸,进而可以得到多尺度信息便于在目标检测和语义分割中的使用
  3. 位置编码:ViT在输入会给embedding进行位置编码。
    Swin-T这里则是作为一个可选项(self.ape),Swin-T是在计算Attention的时候做了一个相对位置编码
  4. 分类:ViT会单独加上一个可学习参数,作为分类的token。
    Swin-T则是直接做平均,输出分类,有点类似CNN最后的全局平均池化层

模型处理过程

概括

PatchEmbed将图像换分为多个patches,
之后接入多个BasicLayer进行处理(默认是和上述结构图一致,4个虚线框中的结构),
再然后将结果做avgpool输出计算结果
最后再进行分类操作(所以这里与ViT中不一样的是并没有采用一个cls token来进行分类而是对多个tokens取均值参与最终的分类运算)

Patch Embedding

不能直接将一整幅图片作为一个patch,所以需要对图像进行切分然后处理为一个patch,但与ViT不同的是,Swin-T不在以16*16作为一个切割大小,而是以4 * 4作为切分大小,并通过后续的Patch Merging操作不断增大每个Patch的大小,进而实现多尺度变化

BasicLayer

生成Patch之后就进入Swin- Transformer的核心模块部分了,每个basiclayer主要是由若干个Swin-Transformer Block和一个Patch Merging

Patch Merging

  • 作用:是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数 ,类似于CNN中Pooling层。进而形成层次化的设计,同时也能节省一定运算量。
  • 启发:在做Window Attention这个操作时,数据的维度变换是和CNN是有些相似的地方的,当然SwinTransformer的初衷也是想让Transformer能像CNN一样能够分成多个Block,进而在不同层级的Block之间提取到分辨率不同的特征信息
  • 实现:SwinTransformer引入了Patch Merging操作来实现,类似于CNN的池化的操作
    在CNN中,则是在每个Stage开始前用stride=2的卷积/池化层来降低分辨率。
    每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素。
    然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍

Swin Transform Block

这部分是整个程序的核心,它由窗口多头自注意层(window multi-head self-attention, W-MSA)和移位窗口多头自注意层(shifted-window multi-head self-attention, SW-MSA)组成
包含了论文中的很多知识点,涉及到相对位置编码、mask、window self-attention、shifted window self-attention

整体流程如下:

  • 输入到该stage的特征 z的l-1 先经过LN进行归一化
  • 经过W-MSA进行特征的学习
  • 接着的是一个残差操作得到 z ^ l z\hat{}^l z^l的估计值(头上带个帽子就是估计值的意思)
  • 接着是一个LN,一个MLP以及一个残差,得到这一层的输出特征 z l z^l zl
  • SW-MSA层的结构和W-MSA层类似,不同的是计算特征部分分别使用了SW-MSA和W-MSA,
    可以从上面的源码中看出它们除了shifted的这个bool值不同之外,其它的值是保持完全一致的。这一部分可以表示为式(2)

在这里插入图片描述

在这里插入图片描述
Swin Transformer使用window self-attention降低了计算复杂度,为了保证不重叠窗口之间有联系,采用了shifted window self-attention的方式重新计算一遍窗口偏移之后的自注意力,所以Swin Transformer Block都是成对出现的 (W-MSA + SW-MSA为一对) ,不同大小的Swin Transformer的Block个数也都为偶数,Block的数量不可能为奇数。

Window Attention

传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。
Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。
Window Attention与传统的Attention主要区别是在原始计算Attention的公式中的Q,K时加入了相对位置编码

在这里插入图片描述

绝对位置编码是在进行self-attention计算之前为每一个token添加一个可学习的参数,
相对位置编码如上式所示,是在进行self-attention计算时,在计算过程中添加一个可学习的相对位置参数B。

实际上这里在参与Attention计算的B 是relative_position_bias_table这个可学习的参数,而relative_position_index则是作为一个index去取relative_position_bias_table中的值来参与运算
有了相对位置索引(relative_position_index)之后,后续将相对位置bias(relative_position_bias_table)加入 Q K T QK^T QKT
这里比较难理解的就是relative_position_index的生成代码,如下图所示为整个relative_position_index的生成过程:

假设window_size = 2*2即每个窗口有4个token [M=2] ,如图1所示,在计算self-attention时,每个token都要与所有的token计算QK值,如图2所示,当位置1的token计算self-attention时,要计算位置1与位置(1,2,3,4)的QK值,即以位置1的token为中心点,中心点位置坐标(0,0),其他位置计算与当前位置坐标的偏移量。

在这里插入图片描述
第一行就是以蓝色为中心的坐标,第二行是以紫色框为中心各颜色框的坐标,以此类推

下图没有明确的计算过程但更加清晰
在这里插入图片描述
然后再最后一维上进行求和,展开成一个一维坐标,并注册为一个不参与网络学习的变量

Shifted Window Attention

前面的Window Attention是在每个窗口下计算注意力的,为了更好的和其他window进行信息交互,Swin Transformer还引入了shifted window操作。

shifted window也就是把左侧的“规则”windows变为右侧“不规则”的windows,因为这样就能实现左侧“规则”windows之间的“信息交流”

在这里插入图片描述

左边是没有重叠的Window Attention,而右边则是将窗口进行移位的Shift Window Attention。可以看到移位后的窗口包含了原本相邻窗口的元素。但这也引入了一个新问题,即window的个数翻倍了,由原本四个窗口变成了9个窗口。

为此论文提出了一种针对于shifted window Attention更加高效的计算方式,如下图所示,为论文提供的高效计算shifted window Attention的示意图

在实际代码里,我们是通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价。

在这里插入图片描述

  1. 将特征数据进行cyclic shift操作,这个操作具体的代码中是使用的torch.roll实现的,如下图,通过将A B C三个区域的数据移动到如图的位置,那么整个窗口的划分就变得大小一致了
    在这里插入图片描述2. Attention Mask:通过设置合理的mask,让Shifted Window Attention在与Window Attention相同的窗口个数下,达到等价的计算结果。得到大小一致的窗口之后,再进行带掩码的MSA操作,因为shift之后windows的大小都一致,所以在进行Attention计算时就比较好并行计算,同时通过掩码的作用,原本不属于同一个窗口的数据进行Attention之后也不会得到较高的注意力(比如蓝天和草原之间的Attention值就不会高)。

如下图,window_size=2,shift_size=-1,最左侧方块所示,我们分别对这9个方块编号为0~8,那么经过roll处理以后,每个区域的位置分布就如第二个方块所示;

再以window_size在每个window内做带掩码的MSA,具体而言就是相同编号的区域做MSA时就没有mask,不同区域之间做MSA就需要有掩码,例如

右下侧的那个window内一共有4个区域的数据(8,6,2,0),那么区域8的Q只和区域8的K^ T相乘时才不带掩码,与其他区域的K^T相乘都需要带掩码,计算结果就如右下侧的红色框中所示:
在这里插入图片描述
3. reverse cyclic shift
把之前cyclic shift的shift参数设置成对应的正数就行

小结

首先我们对Shift Window后的每个窗口都给上index,并且做一个roll操作(window_size=2, shift_size=-1)

在这里插入图片描述

希望在计算Attention的时候,让具有相同index QK进行计算,而忽略不同index QK计算结果。
而要想在原始四个窗口下得到正确的结果,我们就必须给Attention的结果加入一个mask(如下图最右边所示)
最后正确的结果如下图所示

在这里插入图片描述
引入window这一个概念,将CNN的局部性引入,还能控制模型整体计算量。
在Shift Window Attention部分,用一个mask和移位操作,很巧妙的实现计算等价。

模型使用及代码

模型使用

环境配置

环境配置参考Swin Transformer算法环境配置(语义分割)

SwinT

Swin-Transformer最核心的部分制成了一个类似于nn.Conv2D的接口并命名为SwinT。其输入、输出数据形状完全和Conv2D(CNN)一样,这极大的方便了使用Transformer来编写模型代码。

参考SwinT-让Swin-Transformer的使用变得和CNN一样方便快捷

代码

代码讲解参考Swin-Transformer(原理 + 代码)详解

非常详细的原理和代码展示【深度学习】详解 Swin Transformer (SwinT)

Patch Embedding

  • Patch Partition
    作用:将RGB图转为非重叠的patch块。这里的patch尺寸为 4x4,乘上对应的RGB通道可得大小为4 x 4 x3=48。

  • Linear Embedding
    作用:将处理好的patch投影到指定的维度,这里embed_dim=96。

  • 核心代码实现
    通过设定固定大小(4*4)的patch进行卷积,实现Patch Partition,再设定输出通道实现 Linear Embedding

self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size,stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

Patch Merging

作用:将传入矩阵划分为2 x 2 大小的窗口,每个窗口的对应位置(例如下图中的同色块[^3])相merge,再对merge后的四个特征矩阵相concatenate。最后经过layer normalization和linear layer降维。
在这里插入图片描述
Layer normalization和Linear layer的初始化

self.norm = norm_layer(4 * dim)
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

其中由图可知,每一层通道在传递给LayerNorm时都是原通道的4倍。传递给Linear时同理,Linear的输入为原通道的4倍,输出为原通道的2倍。

Merging的实现

  def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x

其中12-17行的作用是对行数或者列数是奇数的层进行扩充;
19-24完成的是Merging操作,即每隔2行2列取一次元素并将这些元素沿最后一个维度(通道维度)concat

Mask

构建Mask是为了以后SW-MSA移动后窗口只对连续部分做self-attention,整个构建过程分为两步。

  def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        # 保证Hp和Wp是window_size的整数倍,起到了padding的作用
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # 拥有和feature map一样的通道排列顺序,方便后续window_partition
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/433214.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

跨境卖家都要知道的:对话式销售

买家可以用他们的指纹登录大多数东西,并通过与它交谈来管理他们的日历。这些人不会填写一份表格,如果他们填写的字段越多,表格的长度就会越长。如果他们知道只会受到骚扰,他们当然不会下载某些东西。 相反,他们更喜欢…

[Linux系统]系统安全及应用一

系统安全及应用一、账号安全基本措施1.1系统账号清理1.1.1将非登录用户的shell设为/sbin/nologin1.1.2锁定长期不使用的账号1.1.3删除无用的账号1.1.4锁定账号文件文件chattr1.1.5查看文件校验和md5sum1.2密码安全控制1.2.1设置密码有效期1.3历史命令限制1.3.1 减少记录命令的条…

5GHz无线局域网系统模拟

移动电视双天线分集接收技术 随着DVB-T在手机电视、车载电视、楼宇电视、地铁电视等户外广播领域内的发展,在这些接收范围内,多径衰落、多普勒频移等小范围衰落是不可避免的问题,解决这些衰落和干扰成为倍受关注的问题。为了解决衰落&#x…

rk3568 点亮LCD (BT656 BT1120)

rk3568 适配 BT656/BT1120 BT.656 TX 和 BT.1120 TX,是一种并行输出接口,而 Camera 对应的是 BT.656 RX和 BT.1120 RX,是一种并行输入接口,两则在协议上是一致的。与同为并口的RGB非常像,在rk3568 芯片上RGB和BT656/B…

【jenkins】Jenkins连接 Gitlab实现 push代码自动构建

目录 一、安装插件 二、构建任务 三、为任务配置触发器 四、到gitlab进行设置webhooks 4.1 设置网络 4.2 到jenkins对应项目的源码库 4.3 测试 4.3.1 点击测试--标签推送事件 4.3.2 点击编辑 一、安装插件 持续部署的第一步需要检查是否安装gitlab插件: gitla…

Doris(9):删除数据(Delete)

Delete不同于其他导入方式,它是一个同步过程。和Insert into相似,所有的Delete操作在Doris中是一个独立的导入作业,一般Delete语句需要指定表和分区以及删除的条件来筛选要删除的数据。 Doris 目前可以通过两种方式删除数据: DE…

记录-JavaScript常规加密技术

这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 当今Web开发中,数据安全是一个至关重要的问题,为了确保数据的安全性,我们需要使用加密技术。JavaScript作为一种客户端编程语言,可以很好地为数据进行加…

Spring Boot 安全

目录 1.概述 2.token 2.1.理论 2.2.使用 3.JWT 3.1.理论 3.2.使用 4.oauth 5.Spring Security 5.1.概述 5.2.基本认证授权 5.3.加密 1.概述 在后端来说,安全主要就是控制用户访问,让对应权限的用户能访问到对应的资源,主要是两点…

AOP通知中获取数据

AOP通知中获取数据 之前我们写AOP仅仅是在原始方法前后追加一些操作,接下来我们要说说AOP中数据相关的内容,我们将从获取参数、获取返回值和获取异常三个方面来研究切入点的相关信息。 获取切入点方法的参数:所有的通知类型都可以获取参数 …

Vulhub开源漏洞靶场用Java远程访问

事件起因,被迫参加某竞赛,中途发现,全员摸鱼,遂一起摸鱼Vulhub是一个面向大众的开源漏洞靶场,无需docker知识,简单执行一条命令即可编译、运行一个完整的漏洞靶场镜像。 Installation 在Ubuntu 20.04下安…

JVM 垃圾回收详解之内存分配和回收原则+死亡对象判断方法

前言 当需要排查各种内存溢出问题、当垃圾收集成为系统达到更高并发的瓶颈时,我们就需要对这些“自动化”的技术实施必要的监控和调节。 堆空间的基本结构 Java 的自动内存管理主要是针对对象内存的回收和对象内存的分配。同时,Java 自动内存管理最核…

【STM32】基础知识 第七课 存储器映射 寄存器映射

【STM32】基础知识 第七课 存储器映射 & 寄存器映射 STM32 寻址范围存储器映射存储器功能划分 (F1 为例)Block 0Block 1Block 2寄存器映射 寄存器映射 (F1 为例)寄存器映射举例寄存器地址计算GPIO 外设基地址及偏移量寄存器地址及偏移量寄存器地址计算过程 使用结构体映射寄…

《2-数组》

数组 1.简介: 数组(Array)是一种固定长度的存储相同数据类型在连续内存空间中的数据结构 引出:[索引 (Index)]----元素在数组中的位置 2.初始化 写法:一般用到无初始值、给定初始值 在不给定…

中国制造业连续13年全球第一,MES管理系统,打造竞争新优势

根据工业和信息化部最近发布的数据,在2022年,中国的制造业增加值在全球的占比接近30%,制造业规模已连续13年位居世界第一。根据国家统计局的最新数字,一到二月份,我国的生产值与去年同期相比上升了2.1&…

实现声明式锁,支持分布式锁自定义锁、SpEL和结合事务

目录 2.实现 2.1 定义注解2.2 定义锁接口2.3 锁的实现 2.3.1 什么是SPI2.3.2 通过SPI实现锁的多个实现类2.3.3 通过SPI自定义实现锁3.定义切面 3.1 切面实现3.2 SpEL表达式获取动态key3.3 锁与事务的结合4.测试 4.1 ReentrantLock测试4.2 RedissonClient测试4.3 自定义锁测试5…

移动硬盘如何分区?教您快速解决!

案例:怎么对移动硬盘进行分区? 【我平常找一个文件需要耗费很长时间,十分麻烦。我现在想通过对移动硬盘进行分区的方式,整理好我的文件,方便使用时查找。有没有人知道移动硬盘怎么分区?教教我!…

深入浅出JS定时器:从setTimeout到setInterval

前言 当谈到 JavaScript 编程语言最基本的概念时,定时器就是一个必须掌握的知识点。在编写网站时,你经常会遇到需要在一定时间间隔内执行一些代码的情况。这时候,JavaScript 定时器就可以派上用场了。 什么是定时器? JS 定时器是…

[Gitops--2]Argocd和Gitlab-runner安装配置

ArgoCd Argo是一组k8s原生工具集,用于运行和管理k8s上的作业和应用程序.Argo提供了一种在k8s上创建工作和应用的三种计算模式:服务模式,工作流模式和基于事件模式.所有的Argo工具都实现为了创建控制器和自定义资源. 为什么选ArgoCD 应用程序的定义,配置和环境都应该是声明性…

ChatGPT和GPT-4帮你写人物传记

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…

研读Rust圣经解析——Rust learn-11(测试,迭代器,闭包)

研读Rust圣经解析——Rust learn-11(测试,迭代器,闭包) 测试编写测试模块声明test模块编写测试方法执行测试测试结果检查 闭包定义一个闭包完整写法闭包可以捕获环境闭包类比函数闭包类型推断闭包获取所有权将被捕获的值移出闭包和…