[原理理解] Swin Transformer相对位置编码理解

news2024/12/26 22:42:12

文章目录

  • 简述
  • 相对位置编码的意义
  • 直观理解
    • 注意力
    • 相对位置获取必要性
    • 当前位置初步获取
    • 利用广播机制获取相对位置索引XY
    • 获取最后相对位置1
    • 获取最后相对位置2
    • 最终的相对位置值嵌入

简述

在看Swin Transformer的时候,一开始在相对位置编码这一块的理解上卡壳了挺久,也没有充分理解为什么这么做,在这记录一下自己的一些理解,以防之后忘记。

相对位置编码的意义

GPT : 用来表示一个像素或特征点相对于另一个像素或特征点的位置关系。在处理窗口(window)或局部区域时,计算相对位置索引可以帮助模型更好地捕捉局部结构和上下文信息。

直观理解

注意力

例如现在是2x2的像素窗口,想要计算他们的相对位置关系,那怎么计算?首先需要先理解一下多头自注意力机制是在搞啥子,用语言的理解就是,当前这一句话和自己的关系(简单理解)。例如下面这句话形成的注意力第一行就是“我"和 “我”、“在”、“吃”、“饭” 四个字的注意力关系。
在这里插入图片描述
以图像像素理解:由于像素是二维形式,存在行列关系,因此通常需要裁减成小窗口(windows),计算注意力关系,小窗口又想像一维计算注意力这么方便,那只能把像素进行平铺,以2x2窗口为例,需要平铺成4x1计算他们的注意力,跟上图类似,“我在吃饭”,可以理解为一个2x2窗口。

相对位置获取必要性

由于我们在像素上进行平铺,我们想要在注意力上加上位置信息,为啥要这样做?我的理解是像素平铺的方式,把原本像素的间隔拉大了,可能会加大网络学习的难度,以下图为例,左手左脚本来上下仅间隔一个像素,有强位置约束关系,但经过像素平铺为一维后,间隔变大,可能会比较难找到两者的关系。所以需要对位置进行编码,让网络知道左脚和左手位置相近。
在这里插入图片描述

当前位置初步获取

由于像素本身存在行列关系,因此使用行列进行位置编码是最合适的方式。首先使用torch.arangetorch.meshgridtorch.stack 函数形成 行列坐标,然后将行列坐标使用torch.flatten平铺成一维。这个时候整个坐标尺寸是[2,4] ,其中2代表x、y两个坐标,4代表4个像素,如图所示,按照顺序分别对应(0,0),(0,1),(1,0),(1,1)四个像素坐标。
在这里插入图片描述
代码如下:

import torch
## 一步步理解,以2x2的window size 为例

# 步骤1 得到当前window 下的xy坐标
window_size=(2,2)
coords_h = torch.arange(window_size[0]) # 0-1行
coords_w = torch.arange(window_size[1]) #0-1列
coords = torch.meshgrid([coords_h, coords_w]) #形成两个坐标,分别对应 行、列
coords = torch.stack(coords) ## -> 2*(wh, ww) #将两个坐标堆叠起来,得到某个位置的xy坐标
print("coords shape:",coords.shape) #torch.Size([2, 2, 2]),第一维的2代表x,y坐标
print(coords)
## 步骤2,将纵坐标h,横坐标w,平摊,做成2维张量
coords_flatten = torch.flatten(coords, 1) 
print(coords_flatten.shape)# torch.Size([2, 4]),第一维的2代表x,y坐标
print(coords_flatten) # 2, Wh*Ww ,整个窗口所有的h,w索引

利用广播机制获取相对位置索引XY

现在,想获取当前像素和其他像素的相对位置,应该怎么操作?可以直接利用广播机制,列扩展维度作为当前像素位置,行扩展维度作为其他像素的位置,两者相减得到相对xy坐标。 图1,第一行的每一列都是第一个像素;图2,第一行的4列对应4个像素位置;两者相减得到图3,第一行代表是第一个像素和所以4个像素的相对位置关系。
在这里插入图片描述
代码:

## 步骤3,利用广播机制,得到相对位置,举例图
relative_coords_first = coords_flatten[:, :, None]  # 2, wh*ww, 1   # 当前窗口扩展列
relative_coords_second = coords_flatten[:, None, :] # 2, 1, wh*ww   #当前窗口扩展行
relative_coords = relative_coords_first - relative_coords_second # 最终得到 2, wh*ww, wh*ww 

relative_coords = relative_coords.permute(1, 2, 0).contiguous() #为u都变换,变成Wh*Ww, Wh*Ww, 2,相对坐标
print(relative_coords[:,:,0])
print(relative_coords[:,:,1])

获取最后相对位置1

现在,我们想要获取非负数的位置索引,怎么做呢?首先我们需要先知道相对位置最小,最大值是多少?
最大值就是当前像素是第一个像素的时候最后一个像素的位置(windowsize -1 , windowsize -1)
最小值就是当前像素是最后一个像素时候第一个像素的位置(-(windowsize-1) ,-(windowsize-1))
因此,对负数进行偏移需要X、Y 各自加上 windowsize-1
在这里插入图片描述
现在,我们已经获取到非负的xy相对位置索引,需要做最后一个步骤,把两个索引映射成单一的维度的索引。能想到的最简单方式就是x+y,但是这个方式是不行的。如下图所示,如果直接两者相加,那么针对同一个像素,其他像素跟他的相对位置索引就会重复。例如第一个像素 和 (第二个像素、第三个像素)索引位置都是1.
在这里插入图片描述

获取最后相对位置2

那么,需要使用什么计算方法,才能让二维索引映射成单维索引呢?回想起二维数组 reshape成一维数组,其他行的单索引是 y*len + x,其中len是列数目,也就是一行有几个数。在相对位置索引中,一行最多的数目又是多少呢?
原本第一个元素和最后一个元素的索引值相差最多,分别是 windowsize - 1-(windowsize -1)如下图所示。
在这里插入图片描述
也就是每一行的值范围是[(-(windowsize -1)) , (windowsize-1) ] ,在上面步骤中,我们让偏移的最小值变成0,也就是索引值范围是 [0 , 2*windowsize -2],总的有 2*windowsize -1 个数目,所以下一行第一个索引的值是,y*(2*windowsize -1)+ x
想象一下,如果现在有一个3x3 数组,那么第二行第一个元素的索引是不是 1*3 + 0 = 3
在这里插入图片描述
这就是为啥相对位置索引要乘以 2 * window_size[1] - 1。
具体代码如下

relative_coords[:, :, 0] += window_size[0] - 1 # 
relative_coords[:, :, 1] += window_size[1] - 1
print(relative_coords[:,:,0])
print(relative_coords[:,:,1])
print(relative_coords[:,:,0]+relative_coords[:,:,1])
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
print(relative_position_index)

最终的相对位置值嵌入

具体的相对位置值加入注意力并不是直接依靠这个索引,而是创建一个可学习参数的table,利用上面的位置索引到这个table里面去找相应值。第一个元素和最后一个元素位置相差最多,正向距离是 windowsize -1 ,反向距离是 -(windowsize -1) ,在加上本身相对位置是0,所以总的相对位置有(2 * window_size[0] - 1) * (2 * window_size[1] - 1) 个值,而不是window_size[0]*window_size[1] * window_size[0]*window_size[1]个,这也是这个可学习table的维度。
最后根据相对位置索引去找table的值就可以啦

relative_position_bias_table = torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
print(relative_position_bias_table.shape)
print(relative_position_bias_table[relative_position_index.view(-1)].shape) #16x6

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2083469.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

27 Combobox组件

Tkinter ttk.Combobox 组件使用指南 ttk.Combobox 是 Tkinter 的一个高级控件,它结合了文本框和下拉列表的功能,允许用户从预定义的选项列表中选择一个值。ttk 模块是 Tkinter 的一个扩展,提供了更现代的控件外观和行为。以下是对 ttk.Combo…

hyperf json-rpc

安装 安装docker hyperf 安装 hyperf-rpc-server-v8 (服务端) docker run --name hyperf-rpc-server-v8 \ -v /www/docker/hyperf-rpc-server:/data/project \ -w /data/project \ -p 9508:9501 -it \ --privileged -u root \ --entrypoint /bin/sh \…

港口行业大数据BI建设方案(24页PPT)

方案简介: 港口行业BI建设方案旨在通过数据整合、分析、可视化及智能化决策支持等手段,提升港口运营效率与管理水平。它的建设实施有利推动港口数字化转型、是提升竞争力的关键举措。通过构建高效、智能的BI系统,港口企业能够实现对运营数据…

软设例题—哈夫曼树

哈夫曼树基本概念: 叶子结点的路径长度:结点到根的分支数量 树的路径长度:所有叶子结点路径长度之和 权:叶子结点的数值 叶子结点的带权路径长度:权重*路径 树的带权路径长度:所有叶子结点带权路径之和…

# Windows 系统安装 virtualbox/vmware 虚拟机教程

Windows 系统安装 virtualbox/vmware虚拟机教程 段子手-168 2024-8-28 一、virtualbox/vmware 简介 1、VirtualBox VirtualBox 是开源的、免费虚拟机软件。VirtualBox 是由德国 Innotek 公司开发,由 Sun Microsystems 公司出品的软件,号称是最强的免…

前端学习笔记-Web APIs篇-01

变量声明 变量声明有三个 var let 和 const 建议: const 优先,尽量使用const, 原因是: const 语义化更好很多变量我们声明的时候就知道他不会被更改了,那为什么不用 const呢?实际开发中也是&#xff0c…

如何使用ssm实现基于ssm的软考系统+vue

TOC ssm321基于ssm的软考系统vue 系统概述 1.1 研究背景 如今互联网高速发展,网络遍布全球,通过互联网发布的消息能快而方便的传播到世界每个角落,并且互联网上能传播的信息也很广,比如文字、图片、声音、视频等。从而&#x…

11 索引

目录 没有索引,可能会有什么问题认识磁盘 1. 没有索引,可能会有什么问题 所以:提高数据库的性能,索引是物美价廉的东西。不用加内存,不用改程序,不用调sql,只要执行正确的create index&#x…

Python 数据分析笔记— Numpy 基本操作

文章目录 学习内容:一、什么是数组、矩阵二、创建与访问数组三、矩阵基本操作 学习内容: 一、什么是数组、矩阵 数组(Array):是有序的元素序列,可以是一维、二维、多维。 array1 [1,2,3] 或[a, b, c, d…

Littorine生物合成糖基转移酶和酰基转移酶-文献精读39

Functional genomics analysis reveals two novel genes required for littorine biosynthesis 功能基因组学分析揭示了两个Littorine生物合成所需的新基因,基因组挖掘很有效果~ 摘要 一些茄科药用植物能够生产药用莨菪烷类生物碱(TAs)&am…

MYSQL:简述对B树和B+树的认识

MySQL的索引使用B树结构。 1、B树 在说B树之前,先说说B树,B树是一个多路平衡查找树,相较于普通的二叉树,不会发生极度不平衡的状况,同时也是多路的。 B树的特点是:他会将数据也保存在非叶子节点。而这个…

样本存储需要注意的事项

在实验室和研究机构中,有一些样本是非常重要且需要特殊保护的,这些样本可能包括珍贵的细胞培养物、生物医学样本、药物试剂等等,为了保证这些样本的质量和完整性,采取一些特殊的措施来进行存储管理非常重要。 一旦这些珍贵样本出…

Undertow 性能、配置

一、性能对比 Tomcat vs Jetty vs Undertow性能对比,详细文章: Tomcat vs Jetty vs Undertow性能对比-腾讯云开发者社区-腾讯云 (tencent.com)https://cloud.tencent.com/developer/article/1699803压测指标的结果: 吞吐量:Undertow > Jetty > Tomcat响应时间&…

World of Warcraft [CLASSIC][80][Shushia] Call to Arms: Alterac Valley

Alterac Valley 奥特兰克山谷 明明能拿7000-9000荣誉,白送的大战场,废材太多,看不下去了,动不动就杀女人,丢墓地,最终拿什么3000荣誉,也不知道脑子装啥。 我们55级的时候就能把联盟打的不要不要…

物料类型 UNBW 和 NLAG

业务示例 公司的广告部门负责采购广告业务并承担相应的费用。这些宣传册不应该存储在广告部门;而应该存储在物料仓库中。并且需要基于数量而不是金额进行库存管理。因此这些物料的物料类型为未评估物料(UNBW)。 物料类型 UNBW 物料类型UNBW表示未评估物料。可以通…

第20讲 动画讲解轻松学会STM32的PWM

来源:【STM32】动画讲解轻松学会STM32的PWM_哔哩哔哩_bilibili 基本概念 周期/频率 计算公式:PWM周期1个高电平用时1个低电平用时 PWM的频率1/周期 如图所示此时周期为1ms,即1s内存在1000组这样的高低电平,PWM的频率为1000hz。…

selenium启动总报错 WebDriverManager总是异常

我的环境用这个自动管理驱动的工具 WebDriverManager 总是报错 尝试过很多方法都没有,只好手动指定浏览器的位置 System.setProperty("webdriver.chrome.driver", "C:\\Users\\27224\\.cache\\selenium\\chromedriver\\win64\\128.0.6613.84\\chrome…

030集—CAD 实现钟表时针动态转动效果——vba代码实现

cad图中显示动图案例如下: 部分代码如下: (按下Esc键可退出) #If VBA7 Then 64位系统声明Declare PtrSafe Sub Sleep Lib "kernel32" (ByVal dwMilliseconds As Long) #Else 32位系统声明Declare Sub Sleep Lib "k…

95.SAP MII功能详解(08)Workbench-Transaction介绍

目录 1.Transaction 2.Properties of transaction 1.Transaction You use transactions to access data from multiple sources and execute processes, which are triggered synchronously or asynchronously.您可以使用事务从多个源访问数据并执行同步或异步触发的流程。…

期权新手交易必看!50ETF期权和沪深300ETF期权分享

今天带你了解期权新手交易必看!50ETF期权和沪深300ETF期权分享。上证 50ETF期权和沪深 300ETF期权是国内ETF期权最早上市的两个品种,也是交易量及活跃度最高的两个品种。 50ETF期权 上证50ETF期权就是在你支付一定额度的权利金后,获得了在未…