一文解答Swin Transformer + 代码【详解】

news2024/9/20 6:34:41

文章目录

    • 1、Swin Transformer的介绍
      • 1.1 Swin Transformer解决图像问题的挑战
      • 1.2 Swin Transformer解决图像问题的方法
    • 2、Swin Transformer的具体过程
      • 2.1 Patch Partition 和 Linear Embedding
      • 2.2 W-MSA、SW-MSA
      • 2.3 Swin Transformer代码解析
        • 2.3.1 代码解释
      • 2.4 W-MSA和SW-MSA的softmax计算
        • 2.4.1 代码实现
      • 2.5 patch merging 实现过程
    • 3、不同Swin Transformer版本的区别

1、Swin Transformer的介绍

下面是Swin Transformer论文的Abstract,

在这里插入图片描述

1.1 Swin Transformer解决图像问题的挑战

原论文中的Abstract讲到:Challenges in adapting Transformer from language to vision arise from differences between the two domains, such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text。 也就是说在一张图片中,不同的object有不同的尺度,以及与NLP相比,图像处理需要更多的token,下面来看在CNN中是如何解决不同的object有不同的尺度这个问题的,CNN是通过Conv和Pool对特征图中像素的不同感受野来检测不同尺度的物体,如下图所示,

在这里插入图片描述

为什么多尺度检测对于Swin Transformer是有挑战性的,请看下图最左边的Transformer结构,当图片的宽高像素是 h × w \sf{h} \times \sf{w} h×w,所以一共有 h × w \sf{h} \times \sf{w} h×w 个token,输入给Transformer时,输出也是 h × w \sf{h} \times \sf{w} h×w ,所以Transformer并没有做下采样操作,下图中间的是ViT的结构,下图最右边的是Swin Transformer结构,MSA是multi-head self-attention的缩写,

在这里插入图片描述

1.2 Swin Transformer解决图像问题的方法

Swin Transformer如何解决相比NLP中更多的token呢,Swin Transformer提出了hierarchical Transformer,hierarchical Transformer是由shifted windows计算出来的,

下面是MSA,W-MSA,SW-MSA的示意图, 如果对于MSA处理56x56个像素的图片,需要计算 313 6 2 3136^2 31362次的相似度,计算量较大;W-MSA将56x56个像素分割成8x8个windows,每个windows由7x7个像素组成,每个windows单独计算MSA,作用是减少计算量,缺点是windows之间没有信息交互,计算相似度为;SW-MSA可以解决windows之间没有信息交互的问题,例如下图最右边的示例,把一个8x8的图片的分成4个windows,然后分割grid向右下角移动 window // 2,这样分割的windows就可以包含多个windows之间的信息,

在这里插入图片描述

如下图中最左边部分的1,2,3包含多个windows的信息,

在这里插入图片描述

下图黄色虚线框标记的是Swin Transformer处理模块,每次经过Swin Transformer模块,patch Merging使特征层减少一半,通道数增加一倍,

在这里插入图片描述
在这里插入图片描述

与ViT特征层不变不同,每次经过Swin Transformer模块,特征层减少一半,通道数增加一倍,

在这里插入图片描述

2、Swin Transformer的具体过程

2.1 Patch Partition 和 Linear Embedding

如下图,黄色虚线框的作用是下采样,

在这里插入图片描述

将224x224x3的特征图中的4x4x3的像素patch成1个token,然后通过concat操作,就得到了4x4x3=48个特征维度,把这个48个维度的向量看成一个token,所以224x224x3的特征图有56x56个token,每个token的特征维度是48,再经过一个映射把48个维度映射成96个维度,如下图所示,

在这里插入图片描述

Patch Partition 和 Linear Embedding在上面的示意图上是2个步骤,但是在代码中却是一步完成的,如下代码中的self.proj,

在这里插入图片描述

2.2 W-MSA、SW-MSA

如下图,2个连续的swin transformer结构与ViT的结构只有W-MSA、SW-MSA和MSA的区别,其他的MLP、Norm结构都相同,

在这里插入图片描述

下面是MSA和W-MSA的示意图,相较于MSA,W-MSA计算量较小,但是windows之间没有信息交互,

在这里插入图片描述

下面是SW-MSA的示意图,SW-MSA采用新的 window 划分方式,将 windows 向右下角移动 window size // 2,移动之后会发现windows的大小不一样,先把A+B向右移动,然后A+C再向下移动就解决了windows的大小不一样的问题,这样就可以并行地计算SW-MSA了,
需要注意到,移动之后每个window的像素并不是连续的,如何解决这个问题,请看下面的解释,

在这里插入图片描述

如下图,0作为query,其本身和其他window内的数字作为k,示意图中矩形框表示0和移动的像素的乘积,这个乘积再作 − 100 -100 100计算,取softmax之后,值就接近于0了,

在这里插入图片描述

计算完成之后,再把移动之后的像素还原,

在这里插入图片描述

W-MSA、SW-MSA与MSA的不同还在于softmax的计算方式,W-MSA、SW-MSA与MSA的不同有W-MSA、SW-MSA是在softmax里面加入相对位置偏置,MSA是在输入加入位置编码,

在这里插入图片描述

2.3 Swin Transformer代码解析

代码:

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # [nW*B, Mh*Mw, C]

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # [B, H', W', C]

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
2.3.1 代码解释

代码解释:

在这里插入图片描述
接上面的代码:
在这里插入图片描述

window_partition函数:

在这里插入图片描述

window_reverse函数:

在这里插入图片描述

MLP函数:

在这里插入图片描述

2.4 W-MSA和SW-MSA的softmax计算

在这里插入图片描述

W-MSA和SW-MSA的softmax计算分为2个步骤,
对于window size为2的矩阵,计算其relative position index矩阵,矩阵大小为 w i n d o w s i z e 2 \sf{windowsize}^2 windowsize2 x w i n d o w s i z e 2 \sf{windowsize}^2 windowsize2,relative position index矩阵的size和相似度矩阵的size一致,计算结果不用赘述,很容易看出来,相对位置的矩阵size是 2 M − 1 \sf{2M-1} 2M1 x 2 M − 1 \sf{2M-1} 2M1

在这里插入图片描述

relative position bias matrix是一个可学习的参数,是依照相对位置的所有组合方式而来, relative position bias B 的值来自 relative position bias matrix B ^ \sf{\hat{B}} B^ 的值,

在这里插入图片描述

2.4.1 代码实现

因为2维矩阵需要考虑2个维度的值,不方便计算,所以把2维坐标转为1维坐标,计算过程如下,

在这里插入图片描述

在这里插入图片描述

代码以及代码解释:

下图右边的代码假设window size为7,

在这里插入图片描述

在这里插入图片描述

2.5 patch merging 实现过程

patch merging在swin transformer中的位置如下,

在这里插入图片描述

patch merging的实现过程如下图,原来的像素在每隔1个像素组合,经过concat和LayerNorm以及Linear层之后size减倍,channel数加倍,

在这里插入图片描述

代码:

在这里插入图片描述

3、不同Swin Transformer版本的区别

区别在于stage1的channel数,stage3的Swin Transformer计算的次数,

在这里插入图片描述

参考:
1、原论文:https://arxiv.org/pdf/2103.14030
2、哔哩哔哩:https://www.bilibili.com/video/BV1xF41197E4/?p=6&spm_id_from=pageDriver

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2118918.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Elasticsearch入门安装

1、下载安装 &#xff08;1&#xff09;安装Elasticsearch 下载地址&#xff1a;https://www.elastic.co/cn/downloads/elasticsearch 解压后运行 /bin/elasticsearch.bat 运行后访问 http://127.0.0.1:9200/即可 ps1&#xff1a;若无法访问且控制台打印 received plaintex…

计算机基础之-TCP 别再问我啦

TCP 协议 格式及部分含义根据端口号找到上一层的进程如何解包-整个包长度 ACK应答机制序号-确认序号-实现TCP的可靠传输和流量控制:为什么要有两个字段(序号和确认序号)&#xff1f; 16位窗口大小-缓冲区流量控制16位紧急指针三次握手四次挥手为啥 TIME_WAIT 滑动窗口 协议 格式…

代码随想录冲冲冲 Day40 动态规划Part8

121. 买卖股票的最佳时机 dp[i][0] 代表第i天持有股票手上的金额 dp[i][1] 代表第i天不持有股票手上的金额 初始化&#xff1a; dp[0][0] 持有所以是-prices[0] dp[0][1] 不持有所以是0&#xff1b; 递推公式: dp[i][0] 既然是i天时持有&#xff0c;那么就是之前就持有&…

开放式耳机具备什么特点?2024排行前十的四款百元蓝牙耳机推荐

开放式耳机具有以下特点&#xff1a; 佩戴舒适&#xff1a; 开放式耳机通常不需要插入耳道&#xff0c;能减少对耳道的压迫和摩擦&#xff0c;长时间佩戴也不易产生闷热、疼痛或瘙痒等不适&#xff0c;对于耳道敏感或不喜欢入耳式耳机压迫感的人来说是很好的选择。 这类耳机…

无线通信里的一些参数(dB dBm RSRP RSRQ RSSI SIN) / 天线增益

目录 历史由来dB和dBmRSRP RSRQ RSSI SNRRSSI在实际测试环境中的应用天线增益 详细阅读&#xff1a; 一文搞懂dB、dBm、dBw、dBi的来龙去脉 无线通信中 RSRP RSRQ RSSI SINR的定义和区别 RSRP RSRQ RSSI SNR的含义和区别 历史由来 dB展开应写为decibel&#xff0c;其中“deci…

【震撼】8岁女孩用Cursor编程,你还在等什么

1. Cursor: 革新性的AI代码编辑器 1.1 Cursor的崛起 近期&#xff0c;AI驱动的代码编辑器Cursor在开发者社区中引起了广泛关注。其火爆程度不仅源于AI大咖Andrej Karpathy在社交平台X上的推荐&#xff0c;更因一则令人惊叹的新闻&#xff1a;一位年仅8岁的小女孩利用Cursor和其…

金士顿NV2 2TB假固态硬盘抢救记,RL6577/RTS5765DL量产工具,RTS5765DL+B47R扩容开卡修复

之前因为很长时间不买固态硬盘&#xff0c;没注意到NVME的固态盘也有了假货和扩容盘&#xff0c;花200多块买了个2TB的金士顿NV2固态硬盘&#xff0c;我原本以为NV1的假货最多是用黑片冒充正片&#xff0c;结果没想到NV2居然有扩容的。后来发现是扩容盘的时候&#xff0c;已经过…

C++ 音频

一、采样频率 当前主流的采样频率为22.05KHz、44.1KHz、48KHz 22.05KHz&#xff1a;为FM广播声音品质 44.1KHz&#xff1a;为理论上最高的CD声音品质&#xff08;直播&#xff0c;录像&#xff0c;acc&#xff09; 48KHz&#xff1a;人耳可分辨的最高采样频率 &#xff08;…

使用PXE实现自动化安装rockylinux8.10

PXE 一、简介 实现多台服务器自动化安装系统。 二、部署 这里宿主机是 centos7&#xff0c;PXE 部署的是 rockylinux8.10。宿主机需提前关闭 selinux 和防火墙。 2.1 部署 dhcp 安装 dhcp [roottest-server ~]# yum install -y dhcp修改配置文件 # 复制默认的配置文件 …

极光出席深圳国际人工智能展并荣获“最具投资价值人工智能奖”

9月8-10日&#xff0c;由深圳市工业和信息化局、深圳市发展和改革委员会、深圳市科技创新局、深圳市政务服务和数据管理局、深圳市中小企业服务局共同指导&#xff0c;深圳市人工智能行业协会主办的第五届深圳国际人工智能展正式开幕。作为中国领先的客户互动和营销科技服务商&…

基于人工智能的智能农业监控系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 智能农业是利用现代信息技术和人工智能进行农业生产的优化管理&#xff0c;通过实时监控和预测系统&#xff0c;可以改善作物的生产效…

工控安全需求分析与安全保护工程

工控系统安全威胁与需求分析 工控系统&#xff08;ICS&#xff09;&#xff0c;是由各种控制组件、监测组件、数据处理与展示组件共同构成的对工业生产过程进行控制和监控的业务流程管控系统 分类&#xff1a;离散制造类和过程控制类 工控系统安全保护机制与技术 工控系统安全…

无人机加速度计的详解!!!

一、加速度计的基本定义 加速度计是一种用于测量物体加速度的传感器。它能够感知物体在各个方向上的加速度变化&#xff0c;并将这些变化转换为电信号进行输出。 二、加速度计的工作原理 加速度计的工作原理基于牛顿第二定律&#xff0c;即力等于质量乘以加速度&#xff08;…

Pygame中Sprite类实现多帧动画3-2

3.2.3 设置帧的宽度、高度、范围及列数 通过如图6所示的代码设置帧的宽度、高度、范围及列数。 图6 设置帧的宽度、高度、范围及列数的代码 其中&#xff0c;frame_width、frame_height、rect和columns都是MySprite类的属性&#xff0c;在其__init__()方法中定义&#xff0c;…

计算机毕业设计选题推荐-产品委托配送系统-Java/Python项目实战

✨作者主页&#xff1a;IT毕设梦工厂✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

【项目综合】基于 Boost 库的站内搜索引擎(保姆式讲解,小白包看包会!)

目录 一、项目背景 1&#xff09;搜索引擎是什么 2&#xff09;Boost 库是什么 3&#xff09;搜索的结果是什么 二、项目原理 1&#xff09;宏观原理和整体流程 2&#xff09;正序索引与倒序索引 3&#xff09;所用技术栈和项目环境 4&#xff09;项目源码地址&#x…

[【人工智能学习笔记】4_3 深度学习基础之卷积神经网络

卷积神经网络概述 卷积神经网络(Convolutional Neural Network, CNN)一种带有卷积结构的深度神经网络,通过特征提取和分类识别完成对输入数据的判别;在1989年提出,早期被成功用于手写字符图像识别;2012年更深层次的AlexNet网络取得成功,伺候卷积神经网络被广泛应用于各…

5G毫米波阵列天线仿真——CDF计算(方法一)

累计分布函数&#xff08;CDF&#xff09;在统计学上是一个由0增长到1的曲线。5G中CDF被3GPP标准推荐使用&#xff0c;5G 天线阵的有效全向辐射功率EIRP的CDF函数被用来评价设备的质量和性能。由于EIRP是在某一个方向角theta, phi上的辐射功率&#xff0c;幅值由天线增益与激励…

微波无源器件1 一种用于紧凑双极化波束形成网络的新型双模定向耦合器

摘要&#xff1a; 在本文中提出了一种用于实现紧凑双极化波束形成网络的新型定向耦合器。此器件的功能为两个用于矩形波导TE01和TE10模式的独立定向耦合器。这两个模式之间并不耦合。可以获得两个模式的不同耦合值。这个耦合器可以两次用于两个正交计划。因此可以获得此完整网络…

Centos7 安装RocketMQ(二进制版)

一、介绍 RocketMQ&#xff1a;云原生“消息、事件、流”实时数据处理平台&#xff0c;覆盖云边端一体化数据处理场景 在阿里孕育 RocketMQ的雏形时期&#xff0c;我们将其用于异步通信、搜索、社交网络活动流、数据管道&#xff0c;贸易流程中。随着我们的贸易业务吞吐量的上…