基础论文学习(3)——SwinTransformer

news2025/1/15 23:30:49

目前Transformer应用到图像领域的挑战:

  • 图像分辨率高,像素点多,如果需要更多特征就必须构建很长的序列,但Transformer基于全局自注意力的计算导致计算量较大,能否用窗口+分层的形式代替长序列,实现类似CNN感受野的效果?

针对上述问题,我们提出了一种包含滑窗操作,具有层级设计的Swin Transformer,逐层合并tokens。
在这里插入图片描述

其中滑窗操作包括不重叠的local window + 重叠的cross-window将注意力计算限制在一个窗口中,一方面能引入CNN卷积操作的局部性,另一方面能节省计算量

1. SwinTransformer总体架构

整个模型采取层次化的设计,一共包含4个Stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。

  • 在输入开始的时候,做了一个Patch Embedding,将图片切成一个个图块(对image进行卷积,然后对特征图切分为patch),并嵌入到Embedding,构建token序列。
  • 在每个Stage里,由Patch Merging和多个Block组成。
  • 其中Patch Merging模块主要在每个Stage一开始进行下采样(W和H不断减小,C不断增大),降低图片分辨率。
  • 而Block具体结构如右图所示,主要是LayerNormMLPWindow AttentionShifted Window Attention组成 (提供了2种attention计算方法)
    在这里插入图片描述
class SwinTransformer(nn.Module):
    def __init__(...):
        super().__init__()
        ...
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            
        self.pos_drop = nn.Dropout(p=drop_rate)

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(...)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

其中有几个地方处理方法与ViT不同:

  • ViT在输入会给embedding进行位置编码。而Swin-T这里则是作为一个可选项(self.ape),Swin-T是在计算Attention的时候做了一个相对位置编码
  • ViT会单独加上一个可学习参数,作为分类的token。而Swin-T则是直接做平均,输出分类,有点类似CNN最后的全局平均池化层

1.1 Patch Embedding

在输入进Block前,我们需要将图片切成一个个patch,然后嵌入向量。

具体做法是对原始图片(224,224,3)裁成一个个 patch_size * patch_size的窗口大小,然后进行嵌入。

这里可以将stride=4,kernel_size=4设置为patch_size=4大小,按照VIT中patch embedding的方式(不重叠卷积)得到每一个图像块patch对应长度为embed_dim的向量。设定输出通道来确定嵌入向量的大小。最后将H,W维度展开,并移动到第一维度。输出(3136, 96)相当于3136个长度为96的token,j将tokens序列排列为正方形即(56*56, 96)

import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size) # -> (img_size, img_size)
        patch_size = to_2tuple(patch_size) # -> (patch_size, patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)  # 这里!!
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        # 假设采取默认参数
        x = self.proj(x) # 出来的是(N, 96, 224/4, 224/4) 
        x = torch.flatten(x, 2) # 把HW维展开,(N, 96, 56*56)
        x = torch.transpose(x, 1, 2)  # 把通道维放到最后 (N, 56*56, 96)
        if self.norm is not None:
            x = self.norm(x)
        return x

1.2 Window Partition/Reverse

window partition函数是用于对张量按非重叠窗口大学window_size划分为一条条tokens,指定窗口大小。将原本的张量从 N H W C, 划分成 num_windows*B, window_size, window_size, C,其中 num_windows = H*W / (window_size*window_size),即窗口的个数。

如输入特征图(56,56,96),默认window_size=7x7,所以分为8x8个窗口,num_windows=64,输出特征图(64, 7, 7, 96),之前的单位是token(共56x56=3136个token),现在的单位是窗口(共8x8=64个window,每个window聚集了7x7=49个token),最后把每个window内的token聚合展平为一个大token,每个大token的shape=(49,96)

window reverse函数则是对应的逆过程。这两个函数会在后面的Window Attention用到。
在这里插入图片描述
实现起来,window partition和window reverse没有可学习参数,因而不需要继承其他的类,写成函数就行。上面windows_partition是将送进来的特征进行window_size的划分,最终变为一条条tokens(对应示意图!!!)

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

1.3 W-MSA 和 SW-MSA

两者串联起来就是一个Swin Transformer Block:

  • W-MSA 窗口多头自注意力机制(windows multi-head self attention):窗口内部multi-head self-attention
  • SW-MSA 滑动窗口多头自注意力机制(shift windows multi-head self attention):窗口之间multi-head self-attention
    在这里插入图片描述

W-MSA

传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。

输入特征图(64, 7, 7, 96),window size=7(包含7x7个长度96的token),共64个窗口。
在这里插入图片描述

swin transformer是按照window size内的小方格计算self-attention的,比如上图中的windows size=7,也就是每7*7个tokens(红色框)之间计算多头self-attention(head=3)。

3个qkv矩阵放在一起的shape=(3, 64, 3, 49, 32),3个矩阵,64个window,head=3, 窗口大小=7x7=49,每个head特征长度96/3=32,64个窗口自己的attention结果是(64, 3, 49, 49)。

这里注意,计算self-attention的输入tokens的数量和维度都不变换,因此最终的输出特征图依旧是(64, 49, 96),64个窗口,每个窗口7x7个token,每个96维的token都会学习到了窗口内的自注意力。

SW-MSA

前面的Window Attention是在每个窗口下计算注意力的,为了更好的和其他window进行信息交互,Swin Transformer还引入了shifted window操作。

左边是没有重叠的Window Attention,而右边则是将窗口进行移位的Shift Window Attention。可以看到移位后的窗口包含了原本相邻窗口的元素。但这也引入了一个新问题,即window的个数翻倍了,由原本4个窗口变成了9个窗口。

在这里插入图片描述
在实际代码里,我们是通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价
在这里插入图片描述

特征图移位+Mask操作

对特征图位移(torch.roll)之后,还是按照4个窗口计算attention,但是会有冗余计算结果,直接设置对应位置mask为负无穷(softmax后为0),忽略不需要的attetion部分(图中灰色部分),输出的结果同W-MSA 也是(56, 56, 96,不要忘记计算完对特征图还原平移)。
在这里插入图片描述

我们看下Block的前向代码:

def forward(self, x):
    H, W = self.input_resolution
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"

    shortcut = x
    x = self.norm1(x)
    x = x.view(B, H, W, C)

    # cyclic shift
    if self.shift_size > 0:
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    else:
        shifted_x = x

    # partition windows
    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

    # W-MSA/SW-MSA
    attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

    # merge windows
    attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
    shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

    # reverse cyclic shift
    if self.shift_size > 0:
        x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
    else:
        x = shifted_x
    x = x.view(B, H * W, C)

    # FFN
    x = shortcut + self.drop_path(x)
    x = x + self.drop_path(self.mlp(self.norm2(x)))

    return x

整体流程如下

  • 先对特征图进行LayerNorm
  • 通过self.shift_size决定是否需要对特征图进行shift
  • 然后将特征图切成一个个窗口
  • 计算Attention,通过self.attn_mask来区分Window Attention还是Shift Window Attention
  • 将各个窗口合并回来
  • 如果之前有做shift操作,此时进行reverse shift,把之前的shift操作恢复
  • 做dropout和残差连接
  • 再通过一层LayerNorm+全连接层,以及dropout和残差连接

1.4 Patch Merging

该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。
在这里插入图片描述

在CNN中,则是在每个Stage开始前用stride=2的卷积/池化层来降低分辨率。

每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素。

然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍。如输入(56, 56, c),变为(28, 28, 4c),全连接输出(28, 28, 2c),这样就使得下一个stage的窗口数量减少了。

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

下面是一个示意图(输入张量B=1, H=W=8, C=1,不包含最后的全连接层调整)
在这里插入图片描述

2. 实验分析

在这里插入图片描述
在ImageNet22K数据集上,准确率能达到惊人的86.4%。另外在检测,分割等任务上表现也很优异。这篇文章创新点很棒,引入window这一个概念,将CNN的局部性引入,还能控制模型整体计算量。在Shift Window Attention部分,用一个mask和移位操作,很巧妙的实现计算等价。作者的代码也写得十分赏心悦目,推荐阅读!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/908052.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Php“牵手”淘宝商品详情页数据采集方法,淘宝API接口申请指南

淘宝天猫详情接口 API 是开放平台提供的一种 API 接口,它可以帮助开发者获取商品的详细信息,包括商品的标题、描述、图片等信息。在电商平台的开发中,详情接口API是非常常用的 API,因此本文将详细介绍详情接口 API 的使用。 一、…

C语言,Linux,静态库编写方法,makefile与shell脚本的关系。

静态库编写: 编写.o文件gcc -c(小写) seqlist.c(需要和头文件、main.c文件在同一文件目录下) libs.a->去掉lib与.a剩下的为库的名称‘s’。 -ls是指库名为s。 -L库的路径。 makefile文件编写: CFLAGS-Wall -O2 -g -I ./inc/ LDFLAGS-L./lib/ -l…

华为开源自研AI框架昇思MindSpore应用案例:PFLD实时人脸关键点检测算法

目录 一、环境准备1.进入ModelArts官网2.使用CodeLab体验Notebook实例 二、案例实现 人脸关键点检测是一个非常核心的算法业务,其在许多场景中都有应用。比如我们常用的换脸、换妆、人脸识别等2C APP中的功能,都需要先进行人脸关键点的检测,然…

Python随机密码生成。编写程序,在26个字母大小写和10个数字随机生成10个8位密码。

题目:随机密码生成。编写程序,在26个字母大小写和10个数字随机生成10个8位密码。 样例:类似AB12cdHi的十组8位密码。 代码: import random def passwords():a, b, c ord(a), ord(A), ord(1)r list(range(a , a 26)) list(ra…

Comparable和Comparator区别

Comparable和Comparator接口都是实现集合中元素的比较、排序的,众所周知,诸如Integer,double等基本数据类型,java可以对他们进行比较,而对于类的比较,需要人工定义比较用到的字段比较逻辑。总体来讲&#x…

电脑找不到vcruntime140.dll文件怎么解决?教你解决这个问题

​vcruntime140.dll是Microsoft Visual C 2015 Redistributable Package中的一个文件,它包含了运行C应用程序所需的运行时库。如果在计算机上找不到这个文件,可能会导致一些应用程序无法正常运行。本文将介绍如何修复vcruntime140.dll丢失的问题以及一些…

【计算机网络八股】计算机网络(一)

目录 计算机网络的各层协议及作用?TCP和UDP的区别?UDP 和 TCP 对应的应用场景是什么?详细介绍一下 TCP 的三次握手机制?为什么需要三次握手,而不是两次?为什么要三次握手,而不是四次&#xff1f…

鼠标拖拽盒子移动

目录 需求思路代码页面展示【补充】纯js实现 需求 浮动的盒子添加鼠标拖拽功能 思路 给需要拖动的盒子添加鼠标按下事件鼠标按下后获取鼠标点击位置与盒子边缘的距离给 document 添加鼠标移动事件鼠标移动过程中,将盒子的位置进行重新定位侦听 document 鼠标弹起&a…

线性代数的学习和整理7:各种特殊效果矩阵特例(草稿-----未完成)

目录 1 矩阵 1.1 1维的矩阵 1.2 2维的矩阵 1.3 没有3维的矩阵---3维的是3阶张量 2 方阵 3 单位矩阵 3.1 单位矩阵的定义 3.2 单位矩阵的特性 3.3 为什么单位矩阵I是 [1,0;0,1] 而不是[0,1;1,0] 或[1,1;1,1] 3.4 零矩阵 3.4 看下这个矩阵 [0,1;1,0] 3.5 看下这个矩阵…

很好的启用window10专业版系统自带的远程桌面

启用window10专业版系统自带的远程桌面 文章目录 启用window10专业版系统自带的远程桌面前言1.找到远程桌面的开关2. 找到“应用”项目3. 打开需要远程操作的电脑远程桌面功能 总结 前言 Windows操作系统作为应用最广泛的个人电脑操作系统,在我们身边几乎随处可见。…

微信占内存?教你一招,瞬间释放手机内存

想必大家都有遇到手机内存不足的问题,而微信作为一款功能强大的应用,需要在手机上存储大量的数据以保证其正常运行。 具体来说,微信存储了大量的聊天记录、图片、视频、音频等多媒体文件,还需要存储用户的账号信息、联系人列表、表…

vue3 清空/重置reactive

序: 1、适用场景:表单切换验证如下图。 我举个例子,如果下拉选银行卡,提交表单的时候所属银行是要必填验证,但是如果选支付宝,那所属银行就非必填了,然而很多时候from的rules是以props来传的&a…

Tomcat运行后localhost:8080访问自己编写的网页

主要是注意项目结构,home.html放在src/resources/templates下的home.html下,application.properties可以不做任何配置。还有就是关于web包的位置,作者一开始将web包与tabtab包平行,访问8080出现了此类报错: Whitelabel…

C++11 新特性 ---- 静态断言 static_assert

1.断言 assert 在C11中&#xff0c;您可以使用assert关键字来检查运行时条件是否满足。assert声明了一个断言&#xff0c;它将在运行时检查给定的条件是否成立。如果条件不成立&#xff0c;将输出一个错误消息并可能终止程序。 在程序中包含头文件<cassert> 或 <ass…

axios 各种方式的请求 示例

GET请求 示例一&#xff1a; 服务端代码 GetMapping("/f11") public String f11(Integer pageNum, Integer pageSize) {return pageNum " : " pageSize; }前端代码 <template><div class"home"><button click"getFun1…

ARM--day7(cortex_M4核LED实验流程、异常源、异常处理模式、异常向量表、异常处理流程、软中断编程、cortex_A7核中断实验)

软中断代码&#xff1a;&#xff08;keil软件&#xff09; .text .global _start _start:1.构建异常向量表b resetb undef_interruptb software_interruptb prefetch_dataabortb data_abortb .b irqb fiq reset:2.系统一上电&#xff0c;程序运行在SVC模式1>>初始化SVC模…

pytorch lightning和pytorch版本对应

参见官方文档&#xff1a; https://lightning.ai/docs/pytorch/latest/versioning.html#compatibility-matrix 下图左一列&#xff08;lightning.pytorch&#xff09;安装命令&#xff1a;pip install lightning --use-feature2020-resolver 下图左一列&#xff08;pytorch_lig…

MySQL——基础——外连接

一、外连接查询语法&#xff1a;(实际开发中,左外连接的使用频率要高于右外连接) 左外连接 SELECT 字段列表 FROM 表1 LEFT [OUTER] JOIN 表2 ON 条件...; 相当于查询表1(左表)的所有数据 包含 表1和表2交集部分的数据 右外连接 SELECT 字段列表 FROM 表1 RIGHT [OUTER] JOIN …

在Qt窗口中添加右键菜单

在Qt窗口中添加右键菜单 基于鼠标的事件实现流程demo 基于窗口的菜单策略实现Qt::DefaultContextMenuQt::ActionsContextMenuQt::CustomContextMenu信号API 基于鼠标的事件实现 流程 需要使用:事件处理器函数(回调函数) 在当前窗口类中重写鼠标操作相关的的事件处理器函数&a…

Python支持下最新Noah-MP陆面模式站点、区域模拟及可视化分析技术教程

详情点击公众号链接&#xff1a;Python支持下最新Noah-MP陆面模式站点、区域模拟及可视化分析技术教程 Noah-MP 5.0模型&模型所需环境的搭建 陆面过程的主要研究内容&#xff08;陆表能量平衡、水循环、碳循环等&#xff09;&#xff0c;陆面过程研究的重要性。 图 1 陆面…