【深度学习】深刻理解Swin Transformer

news2024/12/19 2:00:34

Swin Transformer 是一种基于 Transformer 的视觉模型,由 Microsoft 研究团队提出,旨在解决传统 Transformer 模型在计算机视觉任务中的高计算复杂度问题。其全称是 Shifted Window Transformer,通过引入分层架构和滑动窗口机制,Swin Transformer 在性能和效率之间取得了平衡,广泛应用于图像分类、目标检测、分割等视觉任务,称为新一代的backbone,可直接套用在各项下游任务中。在Swin Transformer中,提供大、中、小不同版本模型,可以进行自由选择合适的使用。

论文原文:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

1.  介绍

        Transformer 最初在自然语言处理(NLP)领域大获成功,但直接将 Transformer 应用于计算机视觉任务存在很多挑战。传统Transformer中,拿到了图像数据,将图片进行划分成一个个patch,尽可能patch细一些。但是图像中像素点太多了,如果需要更多的特征,就必须构建很长的序列。而越长的序列算起注意力肯定越慢,自注意力机制的计算复杂度是O(n^2),当处理高分辨率图像时,这种复杂度会快速增长,这就导致了效率问题。

        而且图像中许多视觉信息依赖于局部关系,而标准 Transformer 处理的是全局关系,可能无法有效捕获局部特征。Swin Transformer便采用窗口和分层的形式来替代长序列的方法,CNN中经常提到感受野,在Transformer中对应的就是分层。也就是说,我们可以将当前这件事做L次(Lx),每次都会两两进行合并,向量数越来越小(400个token-200个token-100个token),窗口的大小也会增大。分层操作也就是,第一层的时候token很多,第二层合并token,第三层合并token,就像我们的卷积和池化的操作。而在传统的Transformer中,第一层怎么做,第二层第三层也会采用同样的尺寸进行,都是一样的操作。

2. Swin Transformer 整体架构

2.1. Patch Embedding

        在 Swin Transformer 中,Patch Embedding 负责将输入图像分割成多个小块(patches),并将这些小块的像素值嵌入到一个高维空间中,形成适合 Transformer 处理的特征表示。在传统的卷积神经网络(CNN)中,卷积操作可以用来提取局部特征。在 Swin Transformer 中,为了将输入图像转化为适合 Transformer 模型处理的 patch 序列,首先对输入图像进行分块。假设输入图像的大小为 224x224x3,其通过一个卷积操作实现。卷积操作可以将每个局部区域的像素值映射为一个更高维的特征向量。假设输入图像大小为 224x224x3,应用一个卷积层,参数为 Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4)),这表示卷积核的大小是 4x4,步长是 4,输入的通道数是 3(RGB图像),输出的通道数是 96。卷积后,图像的空间维度会变小,输出的特征图的尺寸会变为 56x56(通过计算:(224 - 4) / 4 + 1 = 56)。所以,卷积后的输出大小是 56x56x96,这表示每个空间位置(56x56)都有一个96维的特征向量。

        在 Swin Transformer 中,通常将图像通过卷积操作分割成不重叠的小块(patches)。每个小块对应一个特征向量。例如,56x56x96 的输出可以视为有 3136 个 patch,每个 patch 是一个 96 维的向量。这些特征向量将作为 Transformer 模型的输入序列。根据不同的卷积参数(如 kernel_size 和 stride),你可以控制生成的 patch 的数量和每个 patch 的维度。例如,如果使用更小的卷积核和步长,可以得到更细粒度的 patch,反之则可以得到较大的 patch。

  • kernel_size 决定了每个 patch 的空间大小。
  • stride 决定了每个 patch 之间的间隔,即步长。

2.2. window_partition

        在 Swin Transformer 中,图像的特征表示不仅仅是通过 Patch Embedding 来获得,还通过 窗口划分(Window Partition) 来进一步细化和处理,通过窗口内的局部注意力机制来增强计算效率并捕捉局部特征。

        假设输入的图像经过卷积处理后得到了大小为 56x56x96 的特征图,将这个特征图划分为多个小窗口(window),每个窗口包含一部分局部信息,其中窗口大小为7x7,特征图大小为56x56。为了将特征图划分成大小为 7x7 的窗口,我们首先计算在空间维度(高和宽)上可以分成多少个窗口,水平和垂直方向上,每个 7x7 窗口可以覆盖 56 / 7 = 8 个窗口(总共 8x8 = 64 个窗口),窗口内部的特征图由 96 个通道组成。因此,在划分后,特征图的维度将变为 (64, 7, 7, 96),其中:

  • 64 表示窗口的数量(即 8x8 = 64 个窗口)。
  • 7x7 是每个窗口的空间维度。
  • 96 是每个窗口内的特征通道数。

        在 Swin Transformer 中,Token 通常指的是图像中的局部特征,每个 Token 是图像的一个小区域。在 Window Partition 过程中,我们将整个图像的 Token 重新组织成窗口(Window)。之前每个 Token 对应一个图像位置,现在每个 Token 对应一个窗口的内部特征。所以,原来每个 Token(如卷积后的每个空间位置)代表了图像的一部分信息,现在我们通过窗口划分来捕捉更大范围的局部信息。这种划分有助于模型专注于图像的局部结构,同时减少计算量,因为每个窗口只在局部范围内进行注意力计算。

2.3. W-MSA(Windwow multi-head self attention)

        在 Swin Transformer 中,W-MSA (Window Multi-Head Self Attention) 是关键的注意力机制,它通过在每个窗口内部独立地计算自注意力(Self-Attention)来减少计算复杂度,并捕捉局部特征。

        通过 Window Partition 将特征图划分为 64 个窗口,每个窗口的尺寸为 7x7,并且每个位置的特征通道数为 96,因此每个窗口的形状为 (7, 7, 96),这些窗口将作为 W-MSA 的输入。在 Multi-Head Self-Attention 中,首先需要将输入特征矩阵(窗口内的特征)通过三个不同的矩阵进行线性变换,得到 查询(Q)键(K)值(V),这三个矩阵用于计算注意力得分。对于每个头(Head),计算过程是独立的。假设有 3 个头,那么每个头的输入特征维度为 96 / 3 = 32,因为 96 维的输入被平均分成了 3 个头,每个头负责 32 维的特征。在 W-MSA 中,针对每个窗口独立计算自注意力得分,计算方法如下:

  • 对每个窗口中的 49 个像素点(即每个位置的特征向量)进行查询Q、键K、值V的计算。

  • 自注意力得分(Attention Score) 是通过计算查询与键的点积(或者其他相似度度量)得到的,这可以表示为:

    \text{Attention Score} = \frac{Q \cdot K^T}{\sqrt{d_k}}

    其中,d_k 是每个头的维度(在这里是 32),Q 和 K 的乘积衡量了每个位置之间的相似性。

  • Softmax:通过 Softmax 操作将得分归一化,使其成为概率分布,得到每个位置与其他位置的相关性。

  • 加权值(Weighted Sum):使用得分对值V进行加权求和,得到每个位置的最终输出表示。

        每个头的自注意力计算都会产生一个形状为 (64, 3, 49, 49) 的结果,其中,64 表示窗口的数量,3 表示头的数量,49 是每个窗口中位置的数量(7x7),49 代表每个位置对其他位置的注意力得分(自注意力矩阵)。因此,每个头会计算出每个窗口内所有位置之间的自注意力得分,输出的形状为 (64, 3, 49, 49)

2.4. window_reverse

   Window Reverse 操作的目的是将计算得到的 (64, 49, 96) 特征图恢复回原始的空间维度 (56, 56, 96)。为此,我们需要将每个窗口的 49 个位置(7x7)重新排列到原始的图像空间中。步骤:

  • Reshape 操作: 每个窗口的特征图形状是 (49, 96),我们将其转换成 (7, 7, 96) 的形状,表示每个窗口中的每个像素点都有一个 96 维的特征向量。

  • 按窗口拼接: 将所有 64 个窗口按照它们在特征图中的位置重新排列成 56x56 的大特征图。原始的输入特征图大小是 56x56,这意味着 64 个窗口将按照 8x8 的网格排列,并恢复到一个 (56, 56, 96) 的特征图。

        在 Window Reverse 操作后,恢复得到的特征图形状是 (56, 56, 96),这与卷积后的特征图的形状一致。56x56 是恢复后的空间维度,代表每个像素点在特征图中的位置;96 是每个像素点的特征维度,表示每个位置的特征信息。

2.5. SW-MASA

为什么要滑动窗口(Shifted Window)?

        原始的 Window MSA 将图像划分为固定的窗口(例如 7x7),并在每个窗口内计算自注意力。这样做的一个问题是每个窗口内部的信息相对封闭,没有与相邻窗口之间的信息交流。因此,模型容易局限于各自的小区域,无法充分捕捉不同窗口之间的关联。

        通过引入 滑动窗口Shifted Window)机制,窗口在原来位置的基础上向四个方向移动一部分,重叠区域与原窗口有交集。这样,原本相互独立的窗口就可以共享信息,增强了模型的表达能力和全局感知。

位移操作(Shift Operation)

        位移操作的细节如下:

  • 初始的窗口被划分为 4x4 的块(例如 7x7 窗口),每个块进行独立的自注意力计算。
  • 在进行位移时,原来 4x4 的窗口将被平移,变成新的大小为 9x9 的窗口,窗口重叠区域包含了不同窗口之间的信息。
  • 通过平移,模型能获取到更广泛的信息,使得窗口之间能够通过共享信息来融合彼此的特征,避免局部化。

        Shifted Window MSA 会导致计算量的增加,特别是在窗口滑动后,窗口数量从 4x4 变为 9x9,计算量几乎翻倍。为了控制计算量的增长,可以通过 mask 操作 来减少不必要的计算。在位移后,窗口之间会重叠。为了避免重复计算,我们可以使用 mask 来屏蔽掉不需要计算的部分。在计算注意力时,对于每个位置的 QK 的匹配,使用 softmax 时,设置不需要计算的位置的值为负无穷,这样对应位置的注意力值将接近零,不会对结果产生影响。

        在进行 SW-MSA 后,输出的特征图的形状仍然是 56x56x96,与输入特征图的大小一致。通过 shifted windowmask 操作,模型不仅保留了原始的窗口内的自注意力计算,还增强了窗口之间的信息交换和融合。即使窗口被移动了,经过计算后的特征也需要回到其原本的位置,也就是还原平移,保持图像的完整性。

2.6. PatchMerging

        PatchMergingSwin Transformer 中的一种下采样操作,但是不同于池化,这个相当于间接的(对H和W维度进行间隔采样后拼接在一起,得到H/2,W/2,C*4),目的是将输入特征图的空间维度(即高和宽)逐渐减小,同时增加通道数,从而在保持计算效率的同时获得更高层次的特征表示。它是下采样的过程,但与常规的池化操作不同,PatchMerging 通过将相邻的 patch 拼接在一起,并对拼接后的特征进行线性变换,从而实现下采样。具体来说,在 Swin Transformer 中,随着网络层数的加深,输入的特征图会逐渐减小其空间尺寸(即 H 和 W 维度),而同时增加其通道数(即 C 维度),以便模型可以捕捉到更为复杂的高层次信息。

        假设输入的特征图形状为 H x W x CPatchMerging 通过以下步骤来实现下采样和通道数扩展:

  • 分割和拼接(Splitting and Concatenation)

    • 输入的特征图会按照一定的步长(通常是 2)进行分割,即对每个 2x2 的 patch 进行合并。
    • 这样原本的 H x W 的空间尺寸会缩小一半,变成 H/2 x W/2
    • 然后,将每个 2x2 的 patch 内部的特征进行拼接,得到新的特征维度。假设原始通道数为 C,拼接后的通道数为 4C
  • 卷积操作

    • 对拼接后的特征进行 卷积,以进一步增强特征表达。卷积操作用于转换特征空间,虽然通道数增加了,但通过卷积,特征能够更加丰富。

2.7. 分层计算

        在 Swin Transformer 中,模型的每一层都会进行下采样操作,同时逐步增加通道数。每次 PatchMerging 后的特征图会作为输入进入下一层的 Attention 计算。通过这种方式,Swin Transformer 能够逐渐提取到越来越复杂的特征,同时保持计算效率。每一层的 PatchMerging 操作实际上是将输入的特征图通过 线性变换(通常是卷积)合并成更高维度的特征图,从而为后续的注意力计算提供更丰富的表示。

        从图中可以得到,通道数在每层中并不是从C变成4C而是2C,这是因为中间又加了一层卷积操作。

3. 实验结果

        在ImageNet22K数据集上,准确率能达到惊人的86.4%。另外在检测,分割等任务上表现也很优异。

参考资料:

【深度学习】详解 Swin Transformer (SwinT)-CSDN博客

深度学习之Swin Transformer学习篇(详细 - 附代码)_swintransformer训练-CSDN博客

图解Swin Transformer - 知乎

【论文精读】Swin Transformer - 知乎

ICCV2021最佳论文:Swin Transformer论文解读+源码复现,迪哥带你从零解读霸榜各大CV任务的Swin Transformer模型!_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2261878.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

uniCloud云开发视频教程-从基础入门到项目开发实战-uniapp进阶课文章管理系统(云函数/云数据库/云存储)

大家好,我是爱搞知识的咸虾米。 今天给大家带来一门uniCloud基础入门到项目开发实战的课程。 视频学习地址:https://www.bilibili.com/video/BV1PP411E7qG/ 开始学习这门课之前,最好先学习一下uniapp零基础入门这套课,相信很多同…

GLB格式转换为STL格式

GLB与STL格式简介 GLB格式 GLB代表“GL传输格式二进制文件”(GL Transmission Format Binary)。GLB主要用于共享3D数据,包含三维模型、场景、光源、材质、节点层次和动画等详细信息,是一种标准化的文件格式,适用于多…

Qt编译MySQL数据库驱动

目录 Qt编译MySQL数据库驱动 测试程序 Qt编译MySQL数据库驱动 (1)先找到MySQL安装路径以及Qt安装路径 C:\Program Files\MySQL\MySQL Server 8.0 D:\qt\5.12.12 (2)在D:\qt\5.12.12\Src\qtbase\src\plugins\sqldrivers\mysql下…

MySQL通过binlog日志进行数据恢复

记录一次阿里云MySQL通过binlog日志进行数据回滚 问题描述由于阿里云远程mysql没有做安全策略 所以服务器被别人远程攻击把数据库给删除,通过查看binlog日志可以看到进行了drop操作,下面将演示通过binlog日志进行数据回滚操作。 1、查询是否开始binlog …

如何在 Ubuntu 22.04 上安装和使用 Rust 编程语言环境

简介 Rust 是一门由 Mozilla 开发的系统编程语言,专注于性能、可靠性和内存安全。它在没有垃圾收集的情况下实现了内存安全,这使其成为构建对性能要求苛刻的应用程序(如操作系统、游戏引擎和嵌入式系统)的理想选择。 接下来&…

前端项目初始化搭建(二)

一、使用 Vite 创建 Vue 3 TypeScript 项目 PS E:\web\cursor-project\web> npm create vitelatest yf-blog -- --template vue-ts> npx > create-vite yf-blog --template vue-tsScaffolding project in E:\web\cursor-project\web\yf-blog...Done. Now run:cd yf-…

生活小妙招之UE CaptureRT改

需求,四个不同的相机拍摄结果同屏分屏显示 一般的想法是四个Capture拍四张RT,然后最后在面片/UI上组合。这样的开销是创建4张RT,材质中采样4次RT。 以更省的角度,想要对以上流程做优化,4个相机拍摄是必须的&#xff…

【AIGC进阶-ChatGPT提示词副业解析】探索生活的小确幸:在平凡中寻找幸福

引言 在这个快节奏的现代社会中,我们常常被各种压力和焦虑所困扰,忘记了生活中那些细小而珍贵的幸福时刻。本文将探讨如何在日常生活中发现和珍惜那些"小确幸",以及如何通过尝试新事物来丰富我们的生活体验。我们还将讨论保持神秘感和期待感对于维持生活乐趣的重要性…

C#编程报错- “ComboBox”是“...ComboBox”和“...ComboBox”之间的不明确的引用

1、问题描述 在学习使用C#中的Winform平台编写一个串口助手程序时, 在编写一个更新ComboBox列表是遇到了问题,出错的代码是 2、报错信息 CS1503 参数 2: 无法从“System.Windows.Forms.ComboBox”转换为“System.Windows.Forms.ComboBox” CS1503 …

ollama+open-webui,本地部署自己的大模型

目录 一、效果预览 二、部署ollama 1.ollama说明 2.安装流程 2.1 windows系统 2.1.1下载安装包 2.1.2验证安装结果 2.1.3设置模型文件保存地址 2.1.4拉取大模型镜像 2.2linux系统 2.2.1下载并安装ollama 2.2.2设置环境变量 2.2.3拉取模型文件 三、部署open-webui…

leetcode_203. 移除链表元素

203. 移除链表元素 - 力扣(LeetCode) 开始写的时候没有想明白的问题 1. 开始我是想头节点 尾节点 中间节点 分开处理 如果删除的是头节点 然后又要删除头节点的后继节点 那么 这样子的话头节点分开处理就毫无意义了 接着是尾节点 开始我定义的是curr h…

【大模型微调学习5】-大模型微调技术LoRA

【大模型微调学习5】-大模型微调技术LoRA LoRa微调1.现有 PEFT 方法的局限与挑战2.LoRA: 小模型有大智慧 (2021)3.AdaLoRA: 自适应权重矩阵的高效微调 (2023)4.QLoRA: 高效微调量化大模型 (2023) LoRa微调 1.现有 PEFT 方法的局限与挑战 Adapter方法,通过增加模型深…

.NET 技术系列 | 通过CreatePipe函数创建管道

01阅读须知 此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等(包括但不限于)进行检测或维护参考,未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此文所提供的信息而造成的直接或间接后果和损失&#xf…

DS18B20温度传感器(STM32)

一、介绍 DS18B20是一种常见的数字型温度传感器,具备独特的单总线接口方式。其控制命令和数据都是以数字信号的方式输入输出,相比较于模拟温度传感器,具有功能强大、硬件简单、易扩展、抗干扰性强等特点。 传感器参数 测温范围为-55℃到1…

shell编程2 永久环境变量和字符串显位

声明 学习视频来自B站UP主 泷羽sec 常见变量 echo $HOME (家目录 root用户) /root cd /root windows的环境变量可以去设置里去新建 为什么输入ls dir的命令的时候就会输出相应的内容呢 因为这些命令都有相应的变量 which ls 通过这个命令查看ls命令脚本…

MaskGCT——开源文本转语音模型,可模仿任何人说话声音

前期介绍过很多语音合成的模型,比如ChatTTS,微软语音合成大模型,字节跳动自家发布的语音合成模型Seed-TTS。其模型随着技术的不断发展,模型说话的声音也越来越像人类,虽然 seed-tts 可以进行语音合成等功能&#xff0c…

java全栈day16--Web后端实战(数据库)

一、数据库介绍 二、Mysql安装(自行在网上找,教程简单) 安装好了进行Mysql连接 连接语法:winr输入cmd,在命令行中再输入mysql -uroot -p密码 方法二:winr输入cmd,在命令行中再输入mysql -uroo…

geoserver 瓦片地图,tomcat和nginx实现负载均衡

在地理信息系统(GIS)领域,GeoServer作为一个强大的开源服务器,能够发布各种地图服务,包括瓦片地图服务。为了提高服务的可用性和扩展性,结合Tomcat和Nginx实现负载均衡成为了一个有效的解决方案。本文将详细…

达梦8-达梦数据的示例用户和表

1、示例库说明: 创建达梦数据的示例用户和表,导入测试数据。 在完成达梦数据库的安装之后,在/opt/dmdbms/samples/instance_script目录下有用于创建示例用户的SQL文件。samples目录前的路径根据实际安装情况进行修改,本文将达梦…

利用notepad++删除特定关键字所在的行

1、按组合键Ctrl H,查找模式选择 ‘正则表达式’,不选 ‘.匹配新行’ 2、查找目标输入 : ^.*关键字.*\r\n (不保留空行) ^.*关键字.*$ (保留空行)3、替换为:(空) 配置界面参考下图: ​​…