Transformer之Swin-Transformer结构解读

news2024/11/15 16:00:32

写在最前面之如何只用nn.Linear实现nn.Conv2d的功能

很多人说,Swin-Transformer就是另一种Convolution,但是解释得真就是一坨shit,这里我郑重解释一下,这是为什么?
首先,Convolution是什么?

Convolution是一种矩形区域内参数共享的Linear
在这里插入图片描述

这么说可能不好理解,那么我们上代码

import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
    """
    为了简单且便于理解,我们设定图片的Size是Kernel_size的整数倍,且Kernel_size等于Stride
    """
        super(LinearConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride

        # 计算权重矩阵的维度
        weight_size = in_channels * kernel_size * kernel_size
        self.linear = nn.Linear(weight_size, out_channels, bias=False)

    def forward(self, x):
        # 计算输出特征图的尺寸
        B, C, H, W = x.size()
        output_height = H // self.stride
        output_width = W // self.stride

        # 展开输入特征,沿着kernel_size的窗口展开
        x_flatten = x.view(B, H // self.kernel_size, self.kernel_size, W // self.kernel_size, self.kernel_size, C)
        x_flatten = x_flatten.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.kernel_size, self.kernel_size, C)

        # 应用线性变换
        output_flatten = self.linear(x_flatten)

        # 重塑输出形状
        output = output_flatten.view(B, self.out_channels, output_height, output_width)

        return output

# 使用nn.Linear实现nn.Conv2d(256, 256, k=7, s=7)
conv2d_manual = Conv2D(256, 256, 7, 7)

# 创建一个随机初始化的输入张量,确保尺寸是7的整数倍
input_tensor = torch.randn(1, 256, 56, 56)  # 假设输入图像大小为56x56,56是7的倍数

# 应用卷积操作
output = conv2d_manual(input_tensor)
# 输出形状应为[1, 256, 8, 8]
print(output.shape)  

上述代码通过了使用输入数据的维度变换,实现了利用nn.Linear来进行nn.Conv2d的过程,当然,nn.Conv1d甚至nn.Conv3d等也是同样操作。这里我们先记住,后面我们详细解释

Swin-Transformer为什么这么叫

首先,需要理解为什么叫Swin!
作者依然使用了Vision Transformer的主题架构,核心区别是对数据处理的区别!
在Vision Transformer中,数据根据spatial维度进行拉伸,并成为[Batch, HW, C]的样子,如图所示,具体参考Transformer之Vision Transformer结构解读
在这里插入图片描述而在Swin-Transformer中,额外增加了一步,就是把维度为 [ B a t c h , H × W , C ] [Batch, H\times W, C] [Batch,H×W,C]的patch_embedding,进行二次分割,变成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C],如图所示,

  • 第一张图片就是经过patch_embed的patch_embedding
  • 第二张图片就是经过window_partrition分割后的图片
  • 第三张图片就是处理成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C]的图片
    在这里插入图片描述这里还有一个操作,就是在第偶数个Attention-Block中,把输入的patch_embedding进行torch.roll操作,这个操作就是循环位移
    在这里插入图片描述
    这时候就可以解释为什么说Swin-Transformer就是另一种形式的CNN
    从上面的图片中可以看到如下过程:
  • 一张图片,经过nn.Conv2d(k=patch_size, stride=patch_size),将其分割成 N 2 N^2 N2个patch_embedding
  • patch_embedding经过维度重整,从 [ B , H × W , C ] [B, H\times W, C] [B,H×W,C]变成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C],然后送入nn.Linear()。这里的维度重整加上nn.Linear(),等于nn.Conv2d,可以通过写在最前面的"如何只用nn.Linear()实现nn.Conv2d的功能"看出
  • 上一步可以总结为:经过nn.Conv2d的patch_embedding继续经过若干nn.Conv2d

Swin-Transformer的位置编码

绝对位置编码

详情参考Transformer之位置编码的通俗理解
在patch_embedding过程中,依然将Token和PE相加,如上图二所示。
但是既然有了相对位置编码,为什么还要加上绝对位置编码呢?

  • 数学解释如下:

Q E + P E × K E + P E T = X E + P E × W q × [ X E + P E × W k ] T = X E + P E × W q × W k T × X E + P E T = ( X q + P E q ) × W q × W k T × ( X k + P E k ) T = X q × W q ⏞ Q u e r y × W k T × X k T ⏞ K e y ⏟ 第一项 + P E q × W q ⏞ a × W k T × X k T ⏞ K e y ⏟ 第二项 + X q × W q ⏞ Q u e r y × W k T × P E k T ⏞ b ⏟ 第三项 + P E q × W q ⏞ a × W k T × P E k T ⏞ b ⏟ 第四项 \begin{array}{ccl} Q_{E+PE} \times K_{E+PE}^T &= & X_{E + PE} \times W_q \times \Big[X_{E + PE} \times W_k \Big]^T \\ && \\ &= & X_{E + PE} \times W_q \times W_k^T \times X^T_{E + PE} \\ && \\ & = &(X_q+PE_q) \times W_q \times W_k^T \times (X_k+PE_k)^T \\ &&\\ &= &\underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第一项}+ \underbrace{ \overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第二项} + \underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第三项} + \underbrace{\overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第四项} \end{array} QE+PE×KE+PET====XE+PE×Wq×[XE+PE×Wk]TXE+PE×Wq×WkT×XE+PET(Xq+PEq)×Wq×WkT×(Xk+PEk)T第一项 Xq×Wq Query×WkT×XkT Key+第二项 PEq×Wq a×WkT×XkT Key+第三项 Xq×Wq Query×WkT×PEkT b+第四项 PEq×Wq a×WkT×PEkT b
绝对位置编码只能消去第三项和第四项中的d项,依然需要第二项中的a项,才能具有完整的偏置

  • 直觉解释如下
    如果只有相对位置编码,也就是相当于只有相对位置偏置,这个过程和只有绝对位置偏置的意义是相同的,所以只有同时具有相对位置编码和绝对位置编码,才能避免两者是等效的

相对位置编码

详情参考Transformer之位置编码的通俗理解
相对位置编码,实际上是Attention机制的偏置的位置编码:
A t t = s o f t m a x ( Q × K T D i m + r e l a t i v e _ p o s i t i o n _ b i a s ) × V Att = softmax\Big( \frac{Q \times K^T}{\sqrt{Dim}} + relative\_position\_bias\Big) \times V Att=softmax(Dim Q×KT+relative_position_bias)×V
在这里插入图片描述
这里受到CSDN图片尺寸的限制,只能发这种清晰度的,点击这里下载无损svg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1943138.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GaussianPro使用笔记

1. 介绍 GaussianPro: 3D Gaussian Splatting with Progressive Propagation 3D高斯分布(3DGS)最近以其高保真度和效率彻底改变了神经渲染领域。然而,3DGS在很大程度上依赖于运动结构(SfM)技术生成的初始化点云。当处理不可避免地包含无纹理…

语音识别 语音识别项目相关笔记内容

语音识别 语音识别项目相关笔记内容 语音识别应用范畴语音识别框架语音基本操作使用scipy.io.wavfile读取wav音频文件获取采样率、长度、通道数使用numpy读取pcm格式音频文件读取wav音频文件,并绘制图像读取双声道的wav音频文件,分别绘制不同声道的波形图读取一个采样率为16k…

拍立淘API返回值:图像搜索技术的商品信息获取指南

拍立淘API是基于图像搜索技术的商品信息获取工具,广泛应用于阿里巴巴旗下的电商平台如淘宝、天猫等。这一API通过用户上传的商品图片,利用深度学习、计算机视觉等先进技术自动识别图片中的商品信息,并返回与之相关的搜索结果。以下是对拍立淘…

go语言Gin框架的学习路线(八)

目录 GORM Model定义 使用 Model 结构体的自定义数据模型 理解并记忆 GORM 的 Model 结构体可以通过以下几个步骤和技巧: 1. 理解基本概念 2. 熟悉基本字段 3. 记忆技巧 4. 使用场景 结构体标记 支持的结构体标记(Struct tags) 关联…

浏览器渲染揭秘:从加载到显示的全过程;浏览器工作原理与详细流程

目录 浏览器工作原理与流程 一、渲染开始时间点 二、渲染主线程的渲染流程 2.1、渲染流程总览 2.2、渲染具体步骤 ①解析html-Parse HTML ②样式计算-Recalculate Style ③布局-Layout ④分层-Layer 相关拓展 ⑤绘制-Paint ⑥分块-Tiling ⑦光栅化-Raster ⑧画-D…

四川赤橙宏海商务信息咨询有限公司引领抖音电商新趋势

随着互联网的飞速发展,电子商务已成为企业竞争的新高地。在众多电商平台中,抖音以其独特的短视频直播形式,吸引了亿万用户的目光,成为电商领域的新宠。在这样的背景下,四川赤橙宏海商务信息咨询有限公司凭借其专业的电…

新零售电商:订单管理系统设计

传统电商依托于线上流量产生消费者,新零售则是通过网上商城、小程序以及其他应用程序相结合形成网店,同时与线下实体门店和现代物流进行深度整合,最终形成了新的销售模式。 订单管理系统上下游对接系统繁杂,包括:商品中…

【Java语法基础】1、变量、运算符、输入输出

1.变量、运算符、输入输出 跟C一样,先把必须写的框架写出来: package org.example; public class Main{public static void main(String[] args){//在里面写实际的代码} }变量 必须先定义,才能使用。与C、C差不多。 没有赋初值的变量无法…

如何实现Web服务只允许特定客户端访问

如何实现Web服务只允许特定客户端访问 需求来源 为了满足B/S系统给客户演示的需要,需要部署一套系统允许公网能够访问,便于业务人员到客户哪里进行系统演示,但是目前网络安全非常重要,希望能防止暴力破解或者端口扫描等黑客攻击…

全网最实用--神经网络各个组件以及效率指标 (含代码助理解,粘贴即用)

文章目录 一、神经网络相关组件0.前奏1.全连接层(Fully Connected Layer, FC)/密集层(Dense Layer):2.卷积层(Convolutional Layer, Conv):a.一维卷积b.二维卷积c.分组卷积 3.池化层(Pooling La…

汇编语言例题分析

以下数据段定义了如下数据,对应内存图请填空,写出每个内存字节中的2位16进制数(注意写准确,2位16进制数,末尾不带h)。 Data1 segment x db 1,2,3 y db “ABa” z dw 1,2 Data1 ends 物理地址从0000开始&…

C语言指针超详解——最终篇二

C语言指针系列文章目录 入门篇 强化篇 进阶篇 最终篇一 最终篇二 文章目录 C语言指针系列文章目录1. sizeof 与 strlen1.1 字符数组1.2 二维数组 2. 指针运算笔试题解析 以上接指针最终篇一 1. sizeof 与 strlen 1.1 字符数组 代码三&#xff1a; #include<stdio.h>…

『 Linux 』System V共享内存

文章目录 System V IPCSystem V 共享内存的直接原理System V共享内存的创建挂接共享内存取消挂接共享内存释放共享内存利用System V共享内存进行进程间通信共享内存的特性共享内存的属性通过命名管道为共享内存添加同步互斥机制 System V IPC System v 是一种操作系统和相关技术…

不懂这些,面试都不敢说自己熟悉Redis

点赞再看&#xff0c;Java进阶一大半 下面这位就是Redis的创始人&#xff0c;他叫antirez&#xff0c;让我们Java开发者又要多学一门Redis的始作俑者。 我们肯定很难想象Redis创始人竟然学的是是建筑专业&#xff0c;而当年antirez是为了帮网站管理员监控访问者的实时行为才开发…

22集 如何minimax密钥和groupid-《MCU嵌入式AI开发笔记》

22集 如何获取minimax密钥和groupid-《MCU嵌入式AI开发笔记》 minimax密钥获取 https://www.minimaxi.com/platform 进入minimax网站&#xff0c;注册登录后&#xff0c;进入“账户管理”&#xff0c; 然后再点击“接口密钥”&#xff0c;然后再点击“创建新的密钥”。 之…

linux系统安装python3和pip

一、安装python 1、安装依赖环境 yum install gcc -y yum -y install zlib-devel bzip2-devel openssl-devel ncurses-devel sqlite-devel readline-devel tk-devel gdbm-devel db4-devel libpcap-devel xz-devel yum install zlib zlib-devel openssl -y yum install openssl…

什么是信创沙箱?信创沙箱的原理是什么?

在这个数字化高速发展的时代&#xff0c;信息安全问题愈发显得重要。我们每天都在为数据的安全性、隐私性和完整性操心。有时候&#xff0c;感觉就像是一场没有终点的马拉松。而在这场马拉松中&#xff0c;深信达信创沙箱&#xff08;Trusted Computing Sandbox&#xff09;无疑…

谷粒商城实战笔记-52~53-商品服务-API-三级分类-新增-修改

文章目录 一&#xff0c;52-商品服务-API-三级分类-新增-新增效果完成1&#xff0c;点击Append按钮&#xff0c;显示弹窗2&#xff0c;测试完整代码 二&#xff0c;53-商品服务-API-三级分类-修改-修改效果完成1&#xff0c;添加Edit按钮并绑定事件2&#xff0c;修改弹窗确定按…

Windows 11 家庭中文版 安装 VMWare 报 安装程序检测到主机启用了Hyper-V或Device

1、问题 我的操作系统信息如下&#xff1a; 我在安装 VMWare 的时候&#xff0c;报&#xff1a; 因为我之前安装了 docker 桌面版&#xff0c;所以才报这个提示。 安装程序检测到主机启用了 Hyper-v或 Device/credential Guard。要在启用了Hyper-或 Device/Credential Guard …

如何防止热插拔烧坏单片机

大家都知道一般USB接口属于热插拔&#xff0c;实际任意带电进行连接的操作都可以属于热插拔。我们前面讲过芯片烧坏的原理&#xff0c;那么热插拔就是导致芯片烧坏的一个主要原因之一。 在电子产品的整个装配过程、以及产品使用过程经常会面临接口热插拔或者类似热插拔的过程。…