AI学习记录 - 旋转位置编码

news2024/9/19 19:04:03

创作不易,有用点赞,写作有利于锻炼一门新的技能,有很大一部分是我自己总结的新视角

1、前置条件:要理解旋转位置编码前,要熟悉自注意力机制,否则很难看得懂,在我的系列文章中有对自注意力机制的画图解释。

先说重要的结论(下面 q向量 和 k向量 是自注意力矩阵诞生的,不懂先去看注意力机制):

结论1:旋转位置编码本身是绝对位置编码,但是和自注意力机制中的一个qk向量结合之后,就变成了相对位置编码。因为自注意力机制中qk会计算点积,正是恰好这个内积,顺带把旋转位置编码变成了相对位置编码,所以一般说旋转位置编码既包含了绝对位置编码含义,也包含了相对位置编码含义。
结论2:假设没有位置编码这个东西,自注意力机制中,qk向量进行内积的时候,经过反向传播,会逐渐得出词汇与词汇的关联度矩阵,假设10个词汇计算内积,当两个词汇关联度越高,这两个词汇的内积(q * k)越大,重点来了:当对q 和 k叠加上旋转位置编码之后,那不仅仅是两个词汇关联度越高,内积越大,并且当两个词汇位置距离越近,内积也越大。
结论3:原来词向量跟词向量的内积大小只跟词汇的语义相关,内积越大,两个词汇的语义关联度越高。叠加上旋转位置编码后,距离相近的词向量内积也大。当一个句子中,两个词汇距离很远但是语义强相关,那他们的内积就是大;当两个词汇语义没啥关联但是距离很近,内积也是大;当两个词汇距离又近,语义有强相关,内积就是大大的。

2、经过上面的结论,其实我们知道了旋转位置编码在哪个位置起到的作用,就是得出 q 和 k 向量之后。

在说旋转位置编码怎么旋转之前,数学界已经就有了怎么对一个向量进行旋转,举个例子

在这里插入图片描述

如果你本身对位置编码不熟悉,在了解旋转位置编码之前,建议先去看我的另一篇博客,有个传统绝对位置编码的解释,旋转位置编码在没和qk叠加之前,其实和绝对位置编码差不多,你会发现他们的公式在某些地方非常的接近。如果这个所谓的旋转位置编码和传统绝对位置编码通过一样的方式叠加到词向量上面,旋转位置编码还是一个绝对位置编码,关键在于叠加方式不一样。当然传统位置编码使用旋转位置编码的叠加方式,也没有产生相对位置含义,所以旋转位置编码的计算公式和他的叠加方式是相互相成的。
传统绝对位置编码公式:

在这里插入图片描述

旋转位置编码公式:

在这里插入图片描述

对应的图片变化:

在这里插入图片描述

3、上面知道如果向量需要旋转,其实需要一个二维向量,但是 q 和 k 都是一维向量,怎么办呢,通过如下叠加,把 q 和 k 向量都按照如下图所示变成二维向量:

在这里插入图片描述

然后把q的每一列当成(x,y)取出来,下图所示,一共有8个(x,y),所有的q向量都进行这样子的计算,计算完成之后,我们就说q叠加上了旋转位置编码。

在这里插入图片描述

然后又转换回来,这个q叠加上了旋转位置编码

在这里插入图片描述

4、我简单提供一个证明,证明在向量在旋转位置编码之后,词汇距离越近,内积就越大,假设两个token的q向量都一样。

假设两个token的初始表示为相同的向量:𝑣=[1,0,1,0]

旋转矩阵为:

下面我们来套用上面说到的公式计算:

当这个向量位置为 1

在这里插入图片描述

当这个向量位置为 3

在这里插入图片描述

在这里插入图片描述

5、最后代码实现,在这里我也是拿某些大佬的,我在这里写了很多print形状,从观察矩阵形状变化去理解比较好

我这里提一下,就是你会发现代码其实有点难以看懂,这是因为涉及到批次计算,多头,导致矩阵代码中做了很多的矩阵变换,但是本质的流程还是我上面所说的,只是在实现过程中,考虑到优化导致的代码难以按照我上面所述的流程看懂,但是本质和上面一样。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# %%

def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
    # batch_size = 8
    # nums_head = 12
    # max_len = 10
    # output_dim = 32
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
    ids = torch.arange(0, output_dim // 2, dtype=torch.float)  # 即公式里的i, i的范围是 [0,d/2]
    theta = torch.pow(10000, -2 * ids / output_dim)

    print(position) # [[0.],[1.],[2.],[3.],[4.],[5.],[6.],[7.],[8.],[9.]]
    print(output_dim) # 32
    print(theta) # tensor([1.0000e+00, 5.6234e-01, 3.1623e-01, 1.7783e-01, 1.0000e-01, 5.6234e-02,
                    # 3.1623e-02, 1.7783e-02, 1.0000e-02, 5.6234e-03, 3.1623e-03, 1.7783e-03,
                    # 1.0000e-03, 5.6234e-04, 3.1623e-04, 1.7783e-04])

    print(theta.size()) # torch.Size([16])
    print(position.size()) # torch.Size([10, 1])
    embeddings = position * theta  # 即公式里的:pos / (10000^(2i/d))
    print(embeddings.size()) # torch.Size([10, 16])
    # (max_len, output_dim//2, 2)
    embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
    # For example:
    # torch.sin(embeddings) = tensor([[ 0.0000,  0.8415,  0.9093,  0.1411, -0.7568, -0.9589]])
    # torch.cos(embeddings) = tensor([[ 1.0000,  0.5403, -0.4161, -0.9900, -0.6536,  0.2837]])
    # torch.stack = tensor([[[ 0.0000,  1.0000],
    #                        [ 0.8415,  0.5403],
    #                        [ 0.9093, -0.4161],
    #                        [ 0.1411, -0.9900],
    #                        [-0.7568, -0.6536],
    #                        [-0.9589,  0.2837]]])
    print(embeddings.size()) # torch.Size([10, 16, 2])
    embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))  # 在bs维度重复,其他维度都是1不重复
    print(embeddings.size()) # torch.Size([8, 12, 10, 16, 2])
    # reshape后就是:偶数sin, 奇数cos了
    embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))



    print(embeddings.size()) # torch.Size([8, 12, 10, 32])
    embeddings = embeddings.to(device)
    return embeddings


# %%

def RoPE(q, k):
    # q,k: (bs, head, max_len, output_dim)
    batch_size = q.shape[0] # batch_size = 8
    nums_head = q.shape[1] # nums_head = 12
    max_len = q.shape[2] # max_len = 10
    output_dim = q.shape[3] # output_dim = 32
    

    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)
    print(pos_emb.size()) # torch.Size([8, 12, 10, 32])




    # 看rope公式可知,相邻cos,sin之间是相同的,所以复制一遍。如(1,2,3)变成(1,1,2,2,3,3)
    cos_pos = pos_emb[...,  1::2].repeat_interleave(2, dim=-1)  # 将奇数列信息抽取出来也就是cos 拿出来并复制
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)  # 将偶数列信息抽取出来也就是sin 拿出来并复制



    print(cos_pos.size()) # torch.Size([8, 12, 10, 32])
    print(sin_pos.size()) # torch.Size([8, 12, 10, 32])



    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
    print(q2.size()) # torch.Size([8, 12, 10, 16, 2])
    q2 = q2.reshape(q.shape)  # reshape后就是正负交替了
    print(q2.size()) # torch.Size([8, 12, 10, 32])
    # 更新qw, *对应位置相乘
    q = q * cos_pos + q2 * sin_pos
    print(q.size()) # torch.Size([8, 12, 10, 32])




    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
    k2 = k2.reshape(k.shape)
    # 更新kw, *对应位置相乘
    k = k * cos_pos + k2 * sin_pos

    return q, k


# %%

def attention(q, k, v, mask=None, dropout=None, use_RoPE=True):
    # q.shape: (bs, head, seq_len, dk)
    # k.shape: (bs, head, seq_len, dk)
    # v.shape: (bs, head, seq_len, dk)

    if use_RoPE:
        q, k = RoPE(q, k)

    d_k = k.size()[-1]

    att_logits = torch.matmul(q, k.transpose(-2, -1))  # (bs, head, seq_len, seq_len)
    att_logits /= math.sqrt(d_k)

    if mask is not None:
        att_logits = att_logits.masked_fill(mask == 0, -1e9)  # mask掉为0的部分,设为无穷大

    att_scores = F.softmax(att_logits, dim=-1)  # (bs, head, seq_len, seq_len)

    if dropout is not None:
        att_scores = dropout(att_scores)

    # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
    return torch.matmul(att_scores, v), att_scores


if __name__ == '__main__':
    # (bs, head, seq_len, dk)
    q = torch.randn((8, 12, 10, 32))
    k = torch.randn((8, 12, 10, 32))
    v = torch.randn((8, 12, 10, 32))

    res, att_scores = attention(q, k, v, mask=None, dropout=None, use_RoPE=True)


    # (bs, head, seq_len, dk),  (bs, head, seq_len, seq_len)
    print(res.shape, att_scores.shape)





本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2115273.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Win32函数调用约定(Calling Convention)

平常我们在C#中使用DllImportAttribute引入函数时,不指明函数调用约定(CallingConvention)这个参数,也可以正常调用。如FindWindow函数 [DllImport("user32.dll", EntryPoint"FindWindow", SetLastError true)] public static ext…

来啦| LVMH路威酩轩25届校招智鼎高潜人才思维能力测验高分攻略

路威酩轩香水化妆品(上海)有限公司是LVMH集团于2000年成立,负责集团旗下的部分香水化妆品品牌在中国的销售包括迪奥、娇兰、纪梵希、贝玲妃、玫珂菲、凯卓、帕尔马之水以及馥蕾诗等。作为目前全球最大的奢侈品集团LVMH 集团秉承悠久的历史,不断打破常规&…

群晖最新版(DSM 7.2) 下使用 Web Station 部署 flask 项目

0. 需求由来 为了在 DSM 7.2 版本下的群晖 NAS 里运行我基于 flask 3.0.2 编写的网页应用程序,我上网查了非常多资料,也踩了很多坑。最主要的就是 7.2 版本的界面与旧版略有不同,而网络上的资料大多基于旧版界面,且大部分仅仅说明…

记忆化搜索【下】

375. 猜数字大小II 题目分析 题目链接:375. 猜数字大小 II - 力扣(LeetCode) 题目比较长,大致意思就是给一个数,比如说10,定的数字是7,让我们在[1, 10]这个区间猜。 如果猜大或猜小都会说明…

2024AI绘画工具排行榜:探索最受欢迎的AI绘图软件特点与选择指南

AI绘画工具各有优势,从开放性到对特定语言和文化的支持,以及对图像细节和艺术性的不同关注点,根据具体需求选择合适的工具 MidJourney 图片品质卓越,充满独特创意,初期能够免费获取数十账高质量图片,整个生…

C++20中支持的非类型模板参数

C20中支持将类类型作为非类型模板参数:作为模板参数传入的对象具有const T类型,其中T是对象的类型,并且具有静态存储持续时间(static storage duration)。 在C20之前,非类型模板参数仅限于:左值引用类型、整数类型、指…

VMware Fusion Pro 13 Mac版虚拟机 安装Win11系统教程

Mac分享吧 文章目录 Win11安装完成,软件打开效果一、VMware安装Windows11虚拟机1️⃣:准备镜像2️⃣:创建虚拟机3️⃣:虚拟机设置4️⃣:安装虚拟机5️⃣:解决连不上网问题 安装完成!&#xff0…

用Pytho解决分类问题_DBSCAN聚类算法模板

一:DBSCAN聚类算法的介绍 DBSCAN(Density-Based Spatial Clustering of Applications with Noise)是一种基于密度的聚类算法,DBSCAN算法的核心思想是将具有足够高密度的区域划分为簇,并能够在具有噪声的空间数据库中发…

关于SpringMVC的理解

1、SpringMVC 应用 1.1、简介 1.1.1、MVC 体系结构 三层架构: 我们的开发架构⼀般都是基于两种形式,⼀种是 C/S 架构,也就是客户端/服务器;另⼀种是 B/S 架构,也就是浏览器服务器。在 JavaEE 开发中,⼏乎…

陪护系统|陪护系统源码|护理陪护小程序

随着医疗水平的不断提高,人们对护理服务的需求也越来越高。为了更好地满足患者和家属的需求,陪护系统定制开发应运而生。 陪护系统定制开发是根据医疗机构的实际需求,设计并开发一套专门用于陪护服务的系统。该系统拥有一系列丰富的功能&…

基于人工智能的图片生成系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图片生成是计算机视觉领域的一个重要任务,基于生成对抗网络(GAN)的图片生成系统能够从噪声中生成逼…

大数据-119 - Flink Window总览 窗口机制-滚动时间窗口-基于时间驱动基于事件驱动

点一下关注吧!!!非常感谢!!持续更新!!! 目前已经更新到了: Hadoop(已更完)HDFS(已更完)MapReduce(已更完&am…

揭秘 AMD GPU 上 PyTorch Profiler 的性能洞察

Unveiling performance insights with PyTorch Profiler on an AMD GPU — ROCm Blogs 2024年5月29日,作者:Phillip Dang。 在机器学习领域,优化性能通常和改进模型架构一样重要。在本文中,我们将深入探讨 PyTorch Profiler&#…

小白建立个人网站初步尝试

一、VScode 代码是在VScode上运行的&#xff0c;可以看作者另一篇文章&#xff1a;http://t.csdnimg.cn/mOmdF 二、代码基本框架 代码解释<!DOCTYPE html>声明为HTML5文档<html><head>头部元素&#xff0c;不显示在页面<meta charset"utf-8"…

数学建模强化宝典(14)Fisher 最优分割法

前言 Fisher最优分割法是一种对有序样品进行聚类的方法&#xff0c;它在分类过程中不允许打破样品的顺序。这种方法的目标是找到一种分割方式&#xff0c;使得各段内样品之间的差异最小&#xff0c;而各段之间的差异最大。以下是关于Fisher最优分割法的详细介绍&#xff1a; 一…

【LeetCode热题100】前缀和

这篇博客共记录了8道前缀和算法相关的题目&#xff0c;分别是&#xff1a;【模版】前缀和、【模版】二维前缀和、寻找数组的中心下标、除自身以外数组的乘积、和为K的子数组、和可被K整除的子数组、连续数组、矩阵区域和。 #include <iostream> #include <vector> …

【408数据结构】散列 (哈希)知识点集合复习考点题目

苏泽 “弃工从研”的路上很孤独&#xff0c;于是我记下了些许笔记相伴&#xff0c;希望能够帮助到大家 知识点 1. 散列查找 散列查找是一种高效的查找方法&#xff0c;它通过散列函数将关键字映射到数组的一个位置&#xff0c;从而实现快速查找。这种方法的时间复杂度平均为…

自我指导:提升语言模型自我生成指令的能力

人工智能咨询培训老师叶梓 转载标明出处 传统的语言模型&#xff0c;尤其是经过指令微调的大型模型&#xff0c;虽然在零样本&#xff08;zero-shot&#xff09;任务泛化上表现出色&#xff0c;但它们高度依赖于人类编写的指令数据。这些数据往往数量有限、多样性不足&#xf…

配置Java(JDK)环境变量

一、配置JDK环境变量 将JDK-22压缩包加压缩到指定目录下面&#xff0c;本机路径是&#xff1a;C:\Program Files\Java&#xff08;可以加压缩到自己的指定路径&#xff0c;记住这个路径&#xff0c;配置环境变量时候要使用&#xff09;。 鼠标右键“此电脑”&#xff0c;点击“…

独立按键单击检测(延时消抖+定时器扫描)

目录 独立按键简介 按键抖动 模块接线 延时消抖 Key.h Key.c 定时器扫描按键代码 Key.h Key.c main.c 思考 MultiButton按键驱动 独立按键简介 ​ 轻触按键相当于一种电子开关&#xff0c;按下时开关接通&#xff0c;松开时开关断开&#xff0c;实现原理是通过轻触…