相对位置偏置代码解析

news2025/4/21 17:07:46

1. 初始化相对位置偏置嵌入

self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

     假设window_size=7、slef.heads=4,则 2 * window_size - 1 = 13;嵌入层的大小为13*13=169,创建一个大小为169*4的嵌入矩阵。

2. 创建位置索引

pos = torch.arange(window_size)   # tensor([0, 1, 2, 3, 4, 5, 6])

   pos 是一个从 0window_size-1 的一维张量

3. 生成二维网格

grid = torch.meshgrid(pos, pos, indexing='ij')

        torch.meshgrid(pos, pos, indexing='ij') 创建一个二维网格,表示每个位置的(x,y)坐标。pos 是形状为 (N,) 的张量,那么输出将是两个形状为 (N, N) 的张量。第一个张量沿着行变化,第二个张量沿着列变化。

# torch.meshgrid(pos, pos, indexing='ij')返回两个张量,并作为一个元组返回
(tensor([[0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4, 4],
        [5, 5, 5, 5, 5, 5, 5],
        [6, 6, 6, 6, 6, 6, 6]]), tensor([[0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6]]))

# 如果想要输出的是两个单独的张量,则可以使用下面的代码
# grid_x, grid_y = torch.meshgrid(pos, pos, indexing='ij')

# 打印单独的张量
# print(grid_x)
# print(grid_y)

3.1 torch.stack 将上述两个二维张量沿新的维度堆叠 

grid = torch.stack(torch.meshgrid(pos, pos, indexing='ij'))
tensor([[[0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4, 4, 4],
         [5, 5, 5, 5, 5, 5, 5],
         [6, 6, 6, 6, 6, 6, 6]],

        [[0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6]]])

3.2 Rearrange函数 

# 其中 c=2 代表两个网格(x 和 y 坐标),i=7 和 j=7 代表网格的维度
grid = rearrange(grid, 'c i j -> (i j) c')

# (i j) c 表示将原始张量的维度重新排列成 (i j) 和 c。
# 即将网格的每一个点((i, j))展平为一维,并将每个点的 c 个值(这里是两个值)放在新的维度 c 中

        张量 grid 的形状从 (2, window_size, window_size) 变为 (window_size * window_size, 2)指的意思就是将第二个维度i和第三个维度j合并为一个维度。形状为 (49, 2)。

此时输出的维度为2 

tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4],
        [0, 5],
        [0, 6],
        [1, 0],
        [1, 1],
        [1, 2],
        [1, 3],
        [1, 4],
        [1, 5],
        [1, 6],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 3],
        [2, 4],
        [2, 5],
        [2, 6],
        [3, 0],
        [3, 1],
        [3, 2],
        [3, 3],
        [3, 4],
        [3, 5],
        [3, 6],
        [4, 0],
        [4, 1],
        [4, 2],
        [4, 3],
        [4, 4],
        [4, 5],
        [4, 6],
        [5, 0],
        [5, 1],
        [5, 2],
        [5, 3],
        [5, 4],
        [5, 5],
        [5, 6],
        [6, 0],
        [6, 1],
        [6, 2],
        [6, 3],
        [6, 4],
        [6, 5],
        [6, 6]])

        在重排后的张量中,每个元素代表窗口内一个位置的 (x, y) 坐标对。通过这种方式,可以方便地处理窗口内位置的相对关系。 

3.2.1 将其展平为一维

# 使用reshape方法转换为一维张量
one_dim = grid.reshape(-1)

print(one_dim)
tensor([0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 1, 0, 1, 1, 1, 2, 1, 3, 1, 4,
        1, 5, 1, 6, 2, 0, 2, 1, 2, 2, 2, 3, 2, 4, 2, 5, 2, 6, 3, 0, 3, 1, 3, 2,
        3, 3, 3, 4, 3, 5, 3, 6, 4, 0, 4, 1, 4, 2, 4, 3, 4, 4, 4, 5, 4, 6, 5, 0,
        5, 1, 5, 2, 5, 3, 5, 4, 5, 5, 5, 6, 6, 0, 6, 1, 6, 2, 6, 3, 6, 4, 6, 5,
        6, 6])

4. 计算相对位置

rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
rearrange(grid, 'i ... -> i 1 ...') 

        将这个张量的第一个维度保持不变,同时在第二个维度的位置插入一个新的维度,新的维度大小为1。

rearrange(grid, 'i ... -> i 1 ...'):
tensor([
  [[0, 0]], [[0, 1]], [[0, 2]], [[0, 3]], [[0, 4]], [[0, 5]], [[0, 6]],
  [[1, 0]], [[1, 1]], [[1, 2]], [[1, 3]], [[1, 4]], [[1, 5]], [[1, 6]],
  [[2, 0]], [[2, 1]], [[2, 2]], [[2, 3]], [[2, 4]], [[2, 5]], [[2, 6]],
  [[3, 0]], [[3, 1]], [[3, 2]], [[3, 3]], [[3, 4]], [[3, 5]], [[3, 6]],
  [[4, 0]], [[4, 1]], [[4, 2]], [[4, 3]], [[4, 4]], [[4, 5]], [[4, 6]],
  [[5, 0]], [[5, 1]], [[5, 2]], [[5, 3]], [[5, 4]], [[5, 5]], [[5, 6]],
  [[6, 0]], [[6, 1]], [[6, 2]], [[6, 3]], [[6, 4]], [[6, 5]], [[6, 6]]
])

rearrange(grid, 'j ... -> 1 j ...'):
tensor([
  [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6],
   [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6],
   [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6],
   [3, 0], [3, 1], [3, 2], [3, 3], [3, 4], [3, 5], [3, 6],
   [4, 0], [4, 1], [4, 2], [4, 3], [4, 4], [4, 5], [4, 6],
   [5, 0], [5, 1], [5, 2], [5, 3], [5, 4], [5, 5], [5, 6],
   [6, 0], [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [6, 6]]
])

        grid变形为(49,1,2)和(1,49,2),则相减时, i 维度的 49 会与第二个张量的 49 对齐,第二个维度 11 对齐。... 维度大小为 2,这两个维度自然对齐。

减法操作通过广播机制进行:

        对每个 (i, j) 对应的元素,计算 A[i, 0, :] - B[0, j, :]最终形状: (49, 49, 2)。因为 49(第一个维度)和 49(第二个维度)都被保留下来,并且 2 是原始的最后一个维度。

rel_pos += window_size - 1
# 相减得到相对位置,并加上 window_size - 1 调整索引为正:
tensor([
  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]],

  [[6, 6], [6, 5], [6, 4], [6, 3], [6, 2], [6, 1], [6, 0],
   [5, 6], [5, 5], [5, 4], [5, 3], [5, 2], [5, 1], [5, 0],
   [4, 6], [4, 5], [4, 4], [4, 3], [4, 2], [4, 1], [4, 0],
   [3, 6], [3, 5], [3, 4], [3, 3], [3, 2], [3, 1], [3, 0],
   [2, 6], [2, 5], [2, 4], [2, 3], [2, 2], [2, 1], [2, 0],
   [1, 6], [1, 5], [1, 4], [1, 3], [1, 2], [1, 1], [1, 0],
   [0, 6], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], [0, 0]]
])

5. 计算相对位置索引

rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim=-1)

# rel_pose的形状为torch.Size([49, 49, 2])
# torch.tensor([13, 1]):这是一个形状为 (2,) 的一维张量。它包含两个元素:13 和 1

# 当其相乘时,torch.tensor([13, 1])权重张量会被扩展为形状为(1,1,2)的张量,结果仍然是(49,49,2)的三维张量
# 三维张量中每个 (x, y) 对都被替换为其与权重张量的乘积,这一步将对每个 (13*x, 1*y) 对进行求和

 相对位置乘以 [13, 1] 并求和得到唯一索引:

# 给定的张量表示了一个二维网格中每个位置的线性索引。每个数值在张量中指示了在一维数组中的线性位置
tensor([
  [84, 83, 82, 81, 80, 79, 78, 71, 70, 69, 68, 67, 66, 65],
  [72, 71, 70, 69, 68, 67, 66, 59, 58, 57, 56, 55, 54, 53],
  [60, 59, 58, 57, 56, 55, 54, 47, 46, 45, 44, 43, 42, 41],
  [48, 47, 46, 45, 44, 43, 42, 35, 34, 33, 32, 31, 30, 29],
  [36, 35, 34, 33, 32, 31, 30, 23, 22, 21, 20, 19, 18, 17],
  [24, 23, 22, 21, 20, 19, 18, 11, 10, 9,  8,  7,  6,  5],
  [12, 11, 10,  9,  8,  7,  6, -1, -2, -3, -4, -5, -6, -7]
])

   rel_pos 是一个张量,通常表示元素之间的相对位置关系。

6. 注册缓冲区

self.register_buffer('rel_pos_indices', rel_pos_indices, persistent=False)

        注册缓冲区(或称缓冲区,Buffer)的主要作用是在内存中预留指定大小的存储空间,用于对输入/输出(I/O)的数据进行临时存储。

        rel_pos_indices: 这是一个张量(tensor),它将被注册为缓冲区。这个张量可以是任何需要在前向传播中使用的、但不希望被优化器更新的数据。

    persistent=False: 这是一个可选参数。默认情况下,persistent 是 True。当 persistent=True 时,这个缓冲区会被保存在模型的 state_dict 中,这样当模型被保存和加载时,这个缓冲区也会被保存和恢复。当 persistent=False 时,这个缓冲区不会被保存在 state_dict 中。这意味着如果你保存模型然后加载它,这个缓冲区的数据将不会被恢复。

        为什么有时我们需要一个 persistent=False 的缓冲区呢?因为这个缓冲区的数据可以在模型初始化时重新计算或获取,或者这个缓冲区中的数据不是模型状态的重要部分,因此不需要在模型保存和加载时一起保存。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1954208.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OpenGL入门第六步:材质

目录 结果显示 材质介绍 函数解析 具体代码 结果显示 材质介绍 当描述一个表面时,我们可以分别为三个光照分量定义一个材质颜色(Material Color):环境光照(Ambient Lighting)、漫反射光照(Diffuse Lighting)和镜面光照(Specular Lighting)。通过为每…

springboot整合 knife4j 接口文档

第一步&#xff1a;引入依赖 <dependency><groupId>com.github.xiaoymin</groupId><artifactId>knife4j-openapi2-spring-boot-starter</artifactId><version>4.4.0</version></dependency> 第二步&#xff1a;写入配置 方…

国内NAT服务器docker方式搭建rustdesk服务

前言 如果遇到10054,就不要设置id服务器!!! 由于遇到大带宽,但是又贵,所以就NAT的啦,但是只有ipv4共享和一个ipv6,带宽50MB(活动免费会升130MB~) https://bigchick.xyz/aff.php?aff322 月付-5 循环 &#xff1a;CM-CQ-Monthly-5 年付-60循环&#xff1a;CM-CQ-Annually-60官方…

2024后端开发面试题总结

一、前言 上一篇离职贴发布之后仿佛登上了热门&#xff0c;就连曾经阿里的师兄都看到了我的分享&#xff0c;这波流量真是受宠若惊&#xff01; 回到正题&#xff0c;文章火之后&#xff0c;一些同学急切想要让我分享一下面试内容&#xff0c;回忆了几个晚上顺便总结一下&#…

SQL数据库:通过在视频监控平台服务器上直接使用SQL存储过程,在海量记录中查询特定时间段内-某个摄像头的所有视频片段

目录 一、背景 1、存储过程 2、视频监控系统 二、需求和数据表 1、具体要求 2、数据表 3、部分数据 三、实现 1、目标 2、创建存储过程 &#xff08;1&#xff09;存储过程代码 &#xff08;2&#xff09;创建成功 3、存储过程的解释 4、SQL命令调用方式 5、调用…

【FunClip】阿里开源AI视频剪辑神器:全面体验与教程

目录 引言1. FunClip概览1.1 什么是FunClip1.2 FunClip的市场定位1.3 FunClip的创新意义 2. FunClip的功能特性3. FunClip的实际应用案例4. FunClip的使用教程4.1 在线体验FunClip4.2 本地部署Gradio版本4.3 命令行运行 结语参考引用 引言 随着数字媒体的蓬勃发展&#xff0c;…

OpenGL笔记十五之GLM叠加实验

OpenGL笔记十五之GLM叠加实验 —— 2024-07-27 晚上 bilibili赵新政老师的教程看后笔记 code review! 文章目录 OpenGL笔记十五之GLM叠加实验1.每一帧都旋转的三角形2.每一帧在旋转前&#xff0c;都重置为一次单位矩阵&#xff0c;这要只会旋转1度3.每一帧旋转前&#xff0c…

centos stream 9安装 Kubernetes v1.30 集群

1、版本说明&#xff1a; 系统版本&#xff1a;centos stream 9 Kubernetes版本&#xff1a;最新版(v1.30) docker版本&#xff1a;27.1.1 节点主机名ip主节点k8s-master172.31.0.10节点1k8s-node1172.31.0.11节点2k8s-node2172.31.0.12 2、首先&#xff0c;使用Vagrant和Virt…

XSS漏洞:xss.haozi.me靶场1-12 | A-F

目录 0x00&#xff08;无限制&#xff09; 0x01&#xff08;闭合标签绕过&#xff09; 0x02&#xff08;双引号闭合绕过&#xff09; 0x03&#xff08;过滤括号&#xff09; 0x04&#xff08;编码绕过&#xff09; 0x05&#xff08;注释闭合绕过&#xff09; 0x06&#…

【网络爬虫技术】(1·绪论)

&#x1f308; 个人主页&#xff1a;十二月的猫-CSDN博客 &#x1f525; 系列专栏&#xff1a; &#x1f3c0;网络爬虫开发技术入门_十二月的猫的博客-CSDN博客 &#x1f4aa;&#x1f3fb; 十二月的寒冬阻挡不了春天的脚步&#xff0c;十二点的黑夜遮蔽不住黎明的曙光 目录 …

【C语言】两个数组比较详解

目录 C语言中两个数组比较详解1. 逐元素比较1.1 示例代码1.2 输出结果1.3 分析 2. 内置函数的使用2.1 示例代码2.2 输出结果2.3 分析 3. 在嵌入式系统中的应用3.1 示例代码3.2 输出结果3.3 分析 4. 拓展技巧4.1 使用指针优化比较4.2 输出结果4.3 分析 5. 表格总结6. 结论7. 结束…

yolov8 训练模型

一、准备数据 1.1 收集数据 以拳皇为例&#xff0c;将录制的游戏视频进行抽帧。 import cv2 import os# 视频文件路径 video_path 1.mp4# 输出帧的保存目录 output_dir ./output_frames os.makedirs(output_dir, exist_okTrue)# 读取视频 cap cv2.VideoCapture(video_pa…

7-25学习笔记

一、锁对象 Lock接口 1、创建锁对象 ReentrantLock类 Lock locknew ReentrantLock(true); 默认创建的是非公平锁 在创建锁对象时传入一个true参数 便会创建公平锁 先来后到 是重入锁 排他锁 加锁后不允许其它线程进入 2、加锁、解锁 &#xff08;1&#xff09;loc…

服务器搭建总结

服务器搭建好初期要记得开放端口&#xff0c;配置安全组 &#xff0c;主要的有22&#xff0c;80&#xff0c;3389&#xff0c;8888等&#xff0c;宝塔连接的端口在8888上&#xff0c;不开放无法连接 由于时使用的腾讯云服务器&#xff0c;所以在宝塔选择上使用了Windows的面板…

Autodesk Revit v2025 激解锁版下载及安装教程 (三维建模软件)

前言 Revit是欧特克公司知名的三维建模软件&#xff0c;是建筑业BIM体系中使用最广泛的软件之一&#xff0c;其核心功能是三维建筑模型参数化设计、渲染效果图、算量&#xff0c;土建建模、机电建模、用来帮助工程师在施工前精确模拟阶段。 一、下载地址 下载链接&#xff1…

消息摘要算法:MD5加密

&#x1f31f; 主题简介 今天&#xff0c;我们将深入探讨一种经典且广泛应用的加密算法——MD5。通过案例形式了解其原理、实现方法及注意细节。无论你是Python爱好者还是JavaScript高手&#xff0c;这篇内容都将为你揭开MD5的神秘面纱。 &#x1f4da; 内容介绍 MD5&#xf…

大话成像公众号文章阅读学习(一)

系列文章目录 文章目录 系列文章目录前言一、扫射拍摄二、索尼Alpha 9 III2.1. 视频果冻效应2.2 闪光灯同步速度2.3 其他功能 三 A9III 局限性总结 前言 大话成像是一个专注成像的公众号&#xff0c;文章都很好。 今天看的这篇是 特朗普遭枪击后“大片”出自它 文章地址 htt…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《考虑电动汽车动态拥堵的配电网灵活性资源双层优化调度 》

本专栏栏目提供文章与程序复现思路&#xff0c;具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

揭秘企业为何钟情定制红酒:品牌形象与不同的礼品的双重魅力

在商务世界的广阔天地里&#xff0c;红酒不仅仅是一种饮品&#xff0c;更是一种传递情感、展示品味的不同媒介。近年来&#xff0c;越来越多的企业开始钟情于定制红酒&#xff0c;其中洒派红酒&#xff08;Bold & Generous&#xff09;通过其品质和个性化的定制服务&#x…

深入源码:解析SpotBugs (4)如何自定义一个 SpotBugs plugin

自定义一个 spotbugs 的插件&#xff0c;官方有比较详细的说明&#xff1a; https://spotbugs.readthedocs.io/en/stable/implement-plugin.html 本篇是跟随官网demo的足迹&#xff0c;略显无聊&#xff0c;可跳过。 创建工程 执行maven 命令 mvn archetype:generate -Darche…