SwinTransformer的相对位置索引的原理以及源码分析

news2024/10/5 1:22:11

文章目录

  • 1. 理论分析
  • 2. 完整代码


引用:参考博客链接


1. 理论分析

根据论文中提供的公式可知是在 Q Q Q K K K进行匹配并除以 d \sqrt d d 后加上了相对位置偏执 B B B

A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d + B ) V \begin{aligned} &Attention(Q,K,V) = Softmax(\frac{QK^T}{\sqrt d}+B)V \end{aligned} Attention(Q,K,V)=Softmax(d QKT+B)V

    如下图,假设输入的feature map高宽都为2,那么首先我们可以构建出每个像素的绝对位置(左下方的矩阵),对于每个像素的绝对位置是使用行号和列号表示的。比如蓝色的像素对应的是第0行第0列所以绝对位置索引是(0,0),接下来再看看相对位置索引。首先看下蓝色的像素,在蓝色像素使用q与所有像素k进行匹配过程中,是以蓝色像素为参考点。然后用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引。例如黄色像素的绝对位置索引是(0,1),则它相对蓝色像素的相对位置索引为 ( 0 , 0 ) − ( 0 , 1 ) = ( 0 , − 1 ) (0,0)−(0,1)=(0,−1) (0,0)(0,1)=(0,1) 。那么同理可以得到其他位置相对蓝色像素的相对位置索引矩阵。同样,也能得到相对黄色,红色以及绿色像素的相对位置索引矩阵。接下来将每个相对位置索引矩阵按行展平,并拼接在一起可以得到下面的4x4矩阵 。

在这里插入图片描述

对应源代码为:

import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_

window_size = [2,2]
num_heads = 3

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 绝对位置索引

coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

    请注意,这里描述的一直是相对位置索引,并不是相对位置偏执参数。因为后面我们会根据相对位置索引去取对应的参数。比如说黄色像素是在蓝色像素的右边,所以相对蓝色像素的相对位置索引为(0,−1)。绿色像素是在红色像素的右边,所以相对红色像素的相对位置索引为(0,−1)。可以发现这两者的相对位置索引都是(0,−1),所以他们使用的相对位置偏执参数都是一样的。

    其实讲到这基本已经讲完了,但在源码中作者为了方便把二维索引给转成了一维索引。具体这么转的呢,有人肯定想到,简单啊直接把行、列索引相加不就变一维了吗?比如上面的相对位置索引中有(0,−1)和(−1,0)在二维的相对位置索引中明显是代表不同的位置,但如果简单相加都等于-1那不就出问题了吗?接下来我们看看源码中是怎么做的。首先在原始的相对位置索引上加上 ( M − 1 ) (M-1) (M1) ( M M M为窗口的大小,在本示例中 M M M=2),加上之后索引中就不会有负数了。
    
在这里插入图片描述

对应源代码为:

relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

    
接着将所有的行标都乘上2M-1。

在这里插入图片描述

    
对应源代码为:

relative_coords[:, :, 0] *= 2 * window_size[1] - 1

最后将行标和列标进行相加。这样即保证了相对位置关系,而且不会出现上述0 + ( − 1 ) = ( − 1 ) + 0 0+(-1)=(-1)+00+(−1)=(−1)+0的问题了,是不是很神奇。

在这里插入图片描述

relative_position_index = relative_coords.sum(-1)

    刚刚上面也说了,之前计算的是相对位置索引,并不是相对位置偏执参数。真正使用到的可训练参数 B ^ \hat{B} B^ 是保存在relative position bias table表里的,这个表的长度是等于 ( 2 M − 1 ) × ( 2 M − 1 ) (2M−1)×(2M−1) (2M1)×(2M1)的。那么上述公式中的相对位置偏执参数 B B B是根据上面的相对位置索引表根据查relative position bias table表得到的,如下图所示。

在这里插入图片描述

对应源代码为:

'''
relative_position_bias_table:
	其shape=((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
'''
relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)

'''
index:
	虽然shape=(window_size[0] * window_size[1], window_size[0] * window_size[1]),
	但是只有(2 * window_size[0] - 1) * (2 * window_size[1] - 1)个不同的元素。
	作为索引,正好能一一对应relative_position_bias_table中的元素
'''
index = relative_position_index.view(-1)
relative_position_bias = relative_position_bias_table[index] # index的每一个不同的元素对应relative_position_bias_table中一个值

relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

2. 完整代码

import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_

window_size = [2,2]
num_heads = 3

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # # 绝对位置索引 # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords_temp = relative_coords.numpy()
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

relative_coords[:, :, 0] *= 2 * window_size[1] - 1

relative_position_index = relative_coords.sum(-1)

relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)

print('relative_position_bias_table:',relative_position_bias_table.shape)
print(relative_position_index.shape)

index = relative_position_index.view(-1)
print('index:',index.shape)
relative_position_bias = relative_position_bias_table[index]
print('relative_position_bias_Noreshape:',relative_position_bias.shape)

relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)

relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

print('relative_position_bias:',relative_position_bias.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1901025.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

认识并理解webSocket

今天逛牛客,看到有大佬分享说前端面试的时候遇到了关于webSocket的问题,一看自己都没见过这个知识点,赶紧学习一下,在此记录! WebSocket 是一种网络通信协议,提供了全双工通信渠道,即客户端和服…

无法下载cuda

cuda下载不了 一、台式机电脑浏览器打不开cuda下载下面二、解决办法 一、台式机电脑浏览器打不开cuda下载下面 用360、chrome、Edge浏览器都打不开下载页面,有的人说后缀com改成cn,都不行。知乎上说是网络问题,电信换成换成移动/联通的网络会…

2229:Sumsets

网址如下&#xff1a; OpenJudge - 2229:Sumsets 这题不是我想出来的 在这里仅做记录 代码如下&#xff1a; #include<iostream> using namespace std;const int N 1000000000; int dp[1000010]; int n;int main() {cin >> n;dp[0] 1;dp[1] 1;for (int i 2…

Win11系统文件夹预览无法预览PDF文件,PDF阅读器是adobe acrobat

三步走 首先&#xff0c;打开文件夹预览功能 然后&#xff0c;设置adobe acrobat为默认PDF打开应用 最后&#xff0c;打开在Windows资源管理器中启用PDF缩略图&#xff0c;正常设定后&#xff0c;会显示配置文件&#xff0c;稍等一会。

5个实用的文章生成器,高效输出优质文章

在自媒体时代&#xff0c;优质内容的持续输出是吸引读者、提升影响力的关键。然而&#xff0c;对于许多自媒体创作者来说&#xff0c;频繁的创作难免会遭遇灵感枯竭、创作不出文章的困扰。此时&#xff0c;文章生成器便成为了得力的助手。文章生成器的优势能够快速自动生成高质…

7 系列 FPGA 引脚及封装(参考ug475)

目录 I/O BankPins引脚定义I/O and Multi-Function PinsPower Supply PinsDedicated XADC PinsTransceiver PinsDedicated Configuration PinsTemperature Sensor Pins Device 视图整个 FPGAIOBILOGIC,OLOGIC,IDELAY,ODELAYBUFIO,BUFR,IDELAYCTRLBUFMRCEBRAM,DSPIBUFDS_GTE2CLB…

Spring源码十四:Spring生命周期

上一篇我们在Spring源码十三&#xff1a;非懒加载单例Bean中看到了Spring会在refresh方法中去调用我们的finishBeanFactoryInitialization方法去实例化&#xff0c;所有非懒加载器单例的bean。并实例化后的实例放到单例缓存中。到此我们refresh方法已经接近尾声。 Spring的生命…

医疗器械企业CRM系统推荐清单(2024版)

近年来&#xff0c;我国医疗器械行业在国家政策支持、医改深入、人口老龄化和消费能力提升等因素推动下&#xff0c;得到了快速发展&#xff0c;正日益成为创新能力增强、市场需求旺盛的朝阳产业。然而&#xff0c;行业也面临价格压力、市场份额重新分配、合规风险以及产品和服…

【C语言】register 关键字

在C语言中&#xff0c;register关键字用于提示编译器将变量尽量存储在CPU的寄存器中&#xff0c;而不是在内存中。这是为了提高访问速度&#xff0c;因为寄存器的访问速度比内存快得多。使用register关键字的变量通常是频繁使用的局部变量。 基本用法 void example() {regist…

使用ChatGPT写学术论文的技巧和最佳实践指南

大家好&#xff0c;感谢关注。我是七哥&#xff0c;一个在高校里不务正业&#xff0c;折腾学术科研AI实操的学术人。关于使用ChatGPT等AI学术科研的相关问题可以和作者七哥&#xff08;yida985&#xff09;交流&#xff0c;多多交流&#xff0c;相互成就&#xff0c;共同进步&a…

ActiveMq工具之管理页面说明

文章目录 安装ActiveMQ一: 访问管理页面二: 进入管理页面&#xff0c;主页三: Queues页说明四: Topics页说明五: Subscribers页说明 安装ActiveMQ wget https://archive.apache.org/dist//activemq/5.13.3/apache-activemq-5.13.3-bin.tar.gz wget https://mirrors.huaweiclou…

ubuntu 系统中 使用docker 制作 Windows 系统,从此告别 vmware虚拟机

我的系统是 ubuntu 24 前期准备工作&#xff1a; 安装dockerdocker pull 或者 手动制作镜像 docker build 的话 必须要 科学上网&#xff0c; 好像阿里镜像都下不下来。需要 知道 docker 和docker compose 命令的使用方式 我是给docker 挂了 http代理 如果你能pull下来镜像 …

Mysql-常见DML-DQL-语句语法用法总结

1、常见DML语句 1.1 INSERT语句 说明&#xff1a;将数据插入到数据库表中。 INSERT INTO table_name (column1, column2, ...) VALUES (value1, value2, ...); 实例&#xff1a;添加C罗信息到数据库表中 insert into employee (ID, name, gender, entrydate, age) values …

MinIO - 从 环境搭建 -> SpringBoot实战 -> 演示,掌握 Bucket 和 Object 操作

目录 开始 Docker 部署 MinIO 中的基本概念 SpringBoot 集成 MinIO 依赖 配置 MinIO 时间差问题报错 The difference between the request time and the servers time is too large MinIO 中对 Bucket&#xff08;文件夹&#xff09; 的操作 是否存在 / 创建 查询所有…

Android 四大组件

1. Activity 应用程序中&#xff0c;一个Activity通常是一个单独的屏幕&#xff0c;它上面可以显示一些控件&#xff0c;也可以监听并对用户的事件做出响应。 Activity之间通过Intent进行通信&#xff0c;在Intent 的描述结构中&#xff0c;有两个最重要的部分&#xff1a;动…

嵌入式Linux系统编程 — 7.2 进程的环境变量

目录 1 什么是进程的环境变量 2 环境变量的作用 3 应用程序中获取环境变量 3.1 environ全局变量 3.2 获取指定环境变量 getenv 4 添加/删除/修改环境变量 4.1 putenv()函数添加环境变量 4.2 setenv()函数 4.3 unsetenv()函数 1 什么是进程的环境变量 每一个进程都有一…

Node.js 生成vue组件

在项目根目录下创建 create.js /*** 脚本生成vue组件* 主要是利用node自带的fs模块操作文件的写入* ===========================================* 准备步骤:* 1.输入作者名* 2.输入文件名* 3.输入菜单名* 4.输入文件地址* ============================================* 操…

【公益案例展】厦门大学附属成功医院——国产数据库在综合三甲医院核心系统的应用...

‍ 达梦数据公益案例 本项目案例由达梦数据投递并参与数据猿与上海大数据联盟联合推出的 #榜样的力量# 《2024中国数据智能产业最具社会责任感企业》榜单/奖项”评选。 大数据产业创新服务媒体 ——聚焦数据 改变商业 厦门大学附属成功医院是一所集医疗、教学、科研、保健、疗…

C++基础(八):类和对象 (下)

经过前面的学习&#xff0c;我们已经翻过了两座大山&#xff0c;类和对象入门知识就剩下这一讲了&#xff0c;加油吧&#xff0c;少年&#xff01; 目录 一、再谈构造函数 1.1 构造函数体赋值 1.2 初始化列表&#xff08;理解&#xff09; 1.3 explicit关键字&#xff08;C…

代码随想录算法训练营第13天|二叉树的递归遍历、二叉树的迭代遍历、二叉树的统一迭代法、102.二叉树的层序遍历

打卡Day13 1.理论基础2.二叉树的递归遍历3.二叉树的迭代遍历3.二叉树的统一迭代法4.102.二叉树的层序遍历扩展107. 二叉树的层序遍历 II199.二叉树的右视图637.二叉树的层平均值429.N叉树的层序遍历515.在每个树行中找最大值116.填充每个节点的下一个右侧节点指针117. 填充每个…