Rotary Position Embedding (RoPE, 旋转式位置编码) | 原理讲解+torch代码实现

news2024/12/25 13:40:18
  • 🔥 RoPE为苏剑林大佬之作,最早应用于他自研的RoFormer (Rotary Transformer),属于相对位置编码。效果优于绝对位置编码和经典式相对位置编码。出自论文:《RoFormer: Enhanced Transformer with Rotary Position Embedding》

  • 🔥 据我了解,最近发布的大语言模型:Meta的LLaMA、清华的ChatGLM都采用了RoPE。这也足以证明了RoPE的优势。

  • 🔥 本文讲解下个人对RoPE原理的理解以及自己用torch复现了一下,更详细地请参阅苏神的原文(文末已附上链接)。

  • 😄 如对RoPE公式推导有任何疑问,可评论区或私信反馈,我将做出详细解答。

文章目录

  • 1、RoPE 动机
    • 1.1、绝对位置编码
    • 1.2、相对位置编码
    • 1.3、RoPE
  • 2、RoPE 原理
    • 2.1、将待解问题公式化(提出假设)
    • 2.2、推导求解
    • 2.3、RoPE的编码形式
  • 3、RoPE 代码实现(torch版)
  • Reference

1、RoPE 动机

1.1、绝对位置编码

  • 最原始的正余弦位置编码(即sinusoidal位置编码)是一种绝对位置编码,但从其原理中的正余弦的和差化积公式来看,引入的其实也是相对位置编码。

  • 绝对位置编码的讲解可看我的博客:随记·手撕coding | absolute positional embedding

  • 优势: 实现简单,可预先计算好,不用参与训练,速度快。

  • 劣势: 没有外推性,即如果预训练最大长度为512的话,那么最多就只能处理长度为512的句子,再长就处理不了了。当然,也可以将超过512的位置向量随机初始化,然后继续微调。

1.2、相对位置编码

  • 经典相对位置编码RPR式的讲解可看我的博客:相对位置编码之RPR式:《Self-Attention with Relative Position Representations》论文笔记 【在k, v中注入相对位置信息】
  • 优势: 直接地体现了相对位置信号,效果更好。具有外推性,处理长文本能力更强。

1.3、RoPE

  • RoPE通过绝对位置编码的方式实现相对位置编码,综合了绝对位置编码和相对位置编码的优点。
  • 主要就是对attention中的q, k向量注入了绝对位置信息,然后用更新的q,k向量做attention中的内积就会引入相对位置信息了。

2、RoPE 原理

⭐ 那rope是怎么在q,k中注入这种相对位置信息的呢?我看了苏神的推导。大概是这样的:先假设q,k是二维的情形,因为复数可用二维向量表示,所以借助复数域来求解。在推导的过程中,用的最多的一句话就是:“为简单起见,假设xxx” 这对推导十分关键。

  • 有关复数相关基础知识可看这:数学 | 复数的代数、向量、矩阵、极坐标、指数形式 | 复数相乘的物理意义【旋转+缩放】

2.1、将待解问题公式化(提出假设)

首先,假设新的qk向量(即假设已注入绝对位置信息)的内积会引入相对位置信息。并在最后假设合理的初始化条件:
在这里插入图片描述

2.2、推导求解

不是一般性,考虑其q,k向量为二维的情形,借助复数域推导出为q,k向量编码绝对位置信息的函数 f 。
在这里插入图片描述

别看公式多,理解起来并不难。下面我细说一下其中几个关键的推导步骤:

  • 式(8) 的推导:
    在这里插入图片描述

2.3、RoPE的编码形式

上面我们设了q,k的绝对位置编码函数为:
在这里插入图片描述
然后又求出了:在这里插入图片描述
而:
在这里插入图片描述
那带入(4)式就可以得出q,k的绝对位置编码函数了(下面以q为例,k同理)
在这里插入图片描述
为避免这个正交矩阵过于稀疏,浪费算力,代码实现时都是依据下面公式来计算RoPE:
在这里插入图片描述
注:苏神在θ的选择上沿用了tansformer的θi = 10000-2i/d 。因为苏神实验发现,在RoPE中采用这个θ也可以带来一定的远程衰减性(意思就是token之间的依赖关系会随着距离的变远而衰减,这也符合我们的直观理解)。当然别的θ也可,只要满足远程衰减。

3、RoPE 代码实现(torch版)

  • 代码实现基于torch,代码中也写好详细注释。如有错误,评论区或私信我反馈,谢谢~
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# %%

def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
    # (max_len, 1)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
    # (output_dim//2)
    ids = torch.arange(0, output_dim // 2, dtype=torch.float)  # 即公式里的i, i的范围是 [0,d/2]
    theta = torch.pow(10000, -2 * ids / output_dim)

    # (max_len, output_dim//2)
    embeddings = position * theta  # 即公式里的:pos / (10000^(2i/d))

    # (max_len, output_dim//2, 2)
    embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)

    # (bs, head, max_len, output_dim//2, 2)
    embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))  # 在bs维度重复,其他维度都是1不重复

    # (bs, head, max_len, output_dim)
    # reshape后就是:偶数sin, 奇数cos了
    embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))
    embeddings = embeddings.to(device)
    return embeddings


# %%

def RoPE(q, k):
    # q,k: (bs, head, max_len, output_dim)
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    max_len = q.shape[2]
    output_dim = q.shape[-1]

    # (bs, head, max_len, output_dim)
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)


    # cos_pos,sin_pos: (bs, head, max_len, output_dim)
    # 看rope公式可知,相邻cos,sin之间是相同的,所以复制一遍。如(1,2,3)变成(1,1,2,2,3,3)
    cos_pos = pos_emb[...,  1::2].repeat_interleave(2, dim=-1)  # 将奇数列信息抽取出来也就是cos 拿出来并复制
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)  # 将偶数列信息抽取出来也就是sin 拿出来并复制

    # q,k: (bs, head, max_len, output_dim)
    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
    q2 = q2.reshape(q.shape)  # reshape后就是正负交替了



    # 更新qw, *对应位置相乘
    q = q * cos_pos + q2 * sin_pos

    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
    k2 = k2.reshape(k.shape)
    # 更新kw, *对应位置相乘
    k = k * cos_pos + k2 * sin_pos

    return q, k


# %%

def attention(q, k, v, mask=None, dropout=None, use_RoPE=True):
    # q.shape: (bs, head, seq_len, dk)
    # k.shape: (bs, head, seq_len, dk)
    # v.shape: (bs, head, seq_len, dk)

    if use_RoPE:
        q, k = RoPE(q, k)

    d_k = k.size()[-1]

    att_logits = torch.matmul(q, k.transpose(-2, -1))  # (bs, head, seq_len, seq_len)
    att_logits /= math.sqrt(d_k)

    if mask is not None:
        att_scores = att_logits.masked_fill(mask == 0, -1e-9)  # mask掉为0的部分,设为负无穷大

    att_scores = F.softmax(att_logits, dim=-1)  # (bs, head, seq_len, seq_len)

    if dropout is not None:
        att_scores = dropout(att_scores)

    # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
    return torch.matmul(att_scores, v), att_scores


if __name__ == '__main__':
    # (bs, head, seq_len, dk)
    q = torch.randn((8, 12, 10, 32))
    k = torch.randn((8, 12, 10, 32))
    v = torch.randn((8, 12, 10, 32))

    res, att_scores = attention(q, k, v, mask=None, dropout=None, use_RoPE=True)


    # (bs, head, seq_len, dk),  (bs, head, seq_len, seq_len)
    print(res.shape, att_scores.shape)







Reference

  • Transformer升级之路:2、博采众长的旋转式位置编码
  • 《RoFormer: Enhanced Transformer with Rotary Position Embedding》
  • RoPE详细推导版
  • Transformer升级之路:6、旋转位置编码的完备性分析
  • 让研究人员绞尽脑汁的Transformer位置编码
  • Transformer升级之路:4、二维位置的旋转式位置编码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/585117.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

轻松高效!三种方法教你音频转文字!

我们在日常生活中,总会遇到许多需要音频转文字的情况。这个时候大部分小伙伴会选择一边播放音频一边记录的方式来整理音频的内容,这样既麻烦又费时,整理的效率也不高。其实我们只需要使用软件来协助我们将音频转换成文字,就可以很…

2023年03月六级真题全3套【可复制可划线查词】共11页PDF

2023年03月六级真题全3套【可复制可划线查词】共11页PDF 2023年03月六级真题全3套【可复制可划线查词】共11页PDF 2023年03月六级真题全3套【可复制可划线查词】共11页PDF

html基础知识总结

(一)html 1、html html:超文本标签语言,专门用来制作网页的一门语言。超文本:就是它不仅可以放文本内容,还可以是图片,声音,视频,多媒体等等内容 2、 html标签的分类 …

ASEMI双向可控硅BT137性能特点, BT137应用及购买指南

编辑-Z 本文将详细介绍可控硅BT137的性能特点、应用领域以及购买时需要注意的事项,帮助您更好地了解和选择BT137可控硅。 一、BT137可控硅简介 可控硅(Silicon Controlled Rectifier,简称SCR)是一种四层三端半导体器件&#xff…

sql 优化----》1)分析与定位策略

https://www.cnblogs.com/cshaptx4869/p/10482500.html 1:通过 show status 了解各种的SQL的执行频率 2:定位执行频率低的SQL语句: 1):通过慢日志定位 慢日志:可以通过两个方式配置 方式一:配置文件,my.cnf show_query…

25 # eventloop 执行流程

浏览器事件环 1、浏览器的进程 进程是计算机调度的基本单位,进程中包含着线程,浏览器是多进程进程,大致有下面几种 每一个页卡都是进程(互不影响)浏览器也有一个主进程(用户界面)每一个页卡里…

聊一聊行业的前景、就业方向和薪资待遇

软件测试行业是和软件开发相辅相成得一个行业,但目前大家对于软件测试行业的了解并不多,甚至很多学了软件测试的朋友也不是很了解。今天,就来给大家说一说,软件测试行业的前景、就业方向和薪资待遇。 岗位前景 很多小伙伴都曾听…

【PHPWord】PHPWord 根据word模板生成的内容动态生成目录以及页码

文章目录 一、需求分析二、PHPWord 中模板页码的设置三、模板内生成目录四、总结一、需求分析 在实际业务中,我们可能需要根据一些比较复杂的业务模板,生成对应的Word 文件。 本文将掌握: 使用模板配置页码使用模板插入目录二、PHPWord 中模板页码的设置 1.配置页码 注意…

dex2jar 报错 com.googlecode.d2j.DexException: not support version

​ 目录 ​ 一.问题发现 二.调查原因: 三. 根本原因调查: 四.解决问题 一.问题发现 使用dex2jar工具反编的时候,一输入指令,结果报com.googlecode.d2j.DexException: not support version错误(如下图) 异常情况.png 二.调查…

Autosar之自签名证书与CA证书

文章目录 一、安全传输1.框架2.如何实现传输安全?3. 对称加密和非对称加密的区别?4.伪随机数和真随机数5.数字签名 —— 验证完整性 & 认证数据来源6.为什么使用摘要算法的数字签名可以验证完整性?7.为什么数字签名可以认证数据来源&…

南开大学计算机考研分析

关注我们的微信公众号 姚哥计算机考研 更多详情欢迎咨询 南开大学(B)考研难度(☆☆☆☆☆☆) 南开大学计算机学科的研究工作始于1958年,是在实力雄厚的数学学科和物理学科的基础上发展起来的,是我国最早…

关于Gitee上传代码以后主页没有显示贡献度(没有显示小绿块)

事情起因:在一个闲暇的下午,吃着火锅唱着歌,突然!我发现我的Gitee有一片白 起初,没有人在意这场灾难 当我首次发现这个问题的时候,我毫无波澜的认为是Gitee出现了BUG。因为我的这些空白天数里都是有提交的…

Linux fork—进程控制

程序和进程 程序:是指编译好的二进制文件,在磁盘上,不占用系统资源(cpu、内存、打开的文件、设备、锁…)。进程:是一个抽象的概念,与操作系统原理联系紧密,进程是活跃的程序,占用系统资源&…

【备战秋招】每日一题:4月1日美团春招(二批)第二题:题面+题目思路 + C++/python/js/Go/java带注释

2023大厂笔试模拟练习网站(含题解) www.codefun2000.com 最近我们一直在将收集到的各种大厂笔试的解题思路还原成题目并制作数据,挂载到我们的OJ上,供大家学习交流,体会笔试难度。现已录入200道互联网大厂模拟练习题&…

大数据分析案例-基于Adaboost算法构建糖尿病预测模型

🤵‍♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞&#x1f4…

华为OD机试之打印机队列(Java源码)

打印机队列 题目描述 有5台打印机打印文件,每台打印机有自己的待打印队列。 因为打印的文件内容有轻重缓急之分,所以队列中的文件有1~10不同的代先级,其中 数字越大优先级越高 打印机会从自己的待打印队列中选择优先级最高的文件来打印。 如…

5月29号软件资讯更新合集......

Paozhu C Admin 管理后台 1.4.0 版本发布 Paozhu C web 框架 1.4.0 版本发布。 提供一个完整的 admin 管理后台,支持图片管理,文件上传,修改百度开源编辑器 ueditor 上传管理程序为 c 框架自带 C ORM 框架,支持 HTTP/1 HTTP/2 …

InsCode AI 创作助手使用方法

CSDN最近推出了InsCode,可实现对话式AI辅助编程,能够帮助我们高效地创作文章,成倍提高生产力!让我们一起来看看如何使用吧! 首先,点击进入【发布】页面 右上角显示【创作助手】,可直接点击进入…

Tcl-10. 字符串比较,匹配,替换,类别,映射,string 相关

一、字符串比较:string compare, string equal 我们在 expr 和控制语句如 if、while 中可用比较运算符””、”!” 、“”、 “”等来进行字符串比较,但是如不注意的话就会产生问题。首先必须用双引号来将字符串值括起来,这样表达式语法分析…

​​​​Linux Shell 实现一键部署Oracle21 rpm包方式

oracle前言 Oracle开发的关系数据库产品因性能卓越而闻名,Oracle数据库产品为财富排行榜上的前1000家公司所采用,许多大型网站也选用了Oracle系统,是世界最好的数据库产品。此外,Oracle公司还开发其他应用程序和软件。同时&#…