-
🔥 RoPE为苏剑林大佬之作,最早应用于他自研的RoFormer (Rotary Transformer),属于相对位置编码。效果优于绝对位置编码和经典式相对位置编码。出自论文:《RoFormer: Enhanced Transformer with Rotary Position Embedding》
-
🔥 据我了解,最近发布的大语言模型:Meta的LLaMA、清华的ChatGLM都采用了RoPE。这也足以证明了RoPE的优势。
-
🔥 本文讲解下个人对RoPE原理的理解以及自己用torch复现了一下,更详细地请参阅苏神的原文(文末已附上链接)。
-
😄 如对RoPE公式推导有任何疑问,可评论区或私信反馈,我将做出详细解答。
文章目录
- 1、RoPE 动机
- 1.1、绝对位置编码
- 1.2、相对位置编码
- 1.3、RoPE
- 2、RoPE 原理
- 2.1、将待解问题公式化(提出假设)
- 2.2、推导求解
- 2.3、RoPE的编码形式
- 3、RoPE 代码实现(torch版)
- Reference
1、RoPE 动机
1.1、绝对位置编码
-
最原始的正余弦位置编码(即sinusoidal位置编码)是一种绝对位置编码,但从其原理中的正余弦的和差化积公式来看,引入的其实也是相对位置编码。
-
绝对位置编码的讲解可看我的博客:随记·手撕coding | absolute positional embedding
-
优势: 实现简单,可预先计算好,不用参与训练,速度快。
-
劣势: 没有外推性,即如果预训练最大长度为512的话,那么最多就只能处理长度为512的句子,再长就处理不了了。当然,也可以将超过512的位置向量随机初始化,然后继续微调。
1.2、相对位置编码
- 经典相对位置编码RPR式的讲解可看我的博客:相对位置编码之RPR式:《Self-Attention with Relative Position Representations》论文笔记 【在k, v中注入相对位置信息】
- 优势: 直接地体现了相对位置信号,效果更好。具有外推性,处理长文本能力更强。
1.3、RoPE
- RoPE通过绝对位置编码的方式实现相对位置编码,综合了绝对位置编码和相对位置编码的优点。
- 主要就是对attention中的q, k向量注入了绝对位置信息,然后用更新的q,k向量做attention中的内积就会引入相对位置信息了。
2、RoPE 原理
⭐ 那rope是怎么在q,k中注入这种相对位置信息的呢?我看了苏神的推导。大概是这样的:先假设q,k是二维的情形,因为复数可用二维向量表示,所以借助复数域来求解。在推导的过程中,用的最多的一句话就是:“为简单起见,假设xxx” 这对推导十分关键。
- 有关复数相关基础知识可看这:数学 | 复数的代数、向量、矩阵、极坐标、指数形式 | 复数相乘的物理意义【旋转+缩放】
2.1、将待解问题公式化(提出假设)
首先,假设新的qk向量(即假设已注入绝对位置信息)的内积会引入相对位置信息。并在最后假设合理的初始化条件:
2.2、推导求解
不是一般性,考虑其q,k向量为二维的情形,借助复数域推导出为q,k向量编码绝对位置信息的函数 f 。
别看公式多,理解起来并不难。下面我细说一下其中几个关键的推导步骤:
- 式(8) 的推导:
2.3、RoPE的编码形式
上面我们设了q,k的绝对位置编码函数为:
然后又求出了:
而:
那带入(4)式就可以得出q,k的绝对位置编码函数了(下面以q为例,k同理)
为避免这个正交矩阵过于稀疏,浪费算力,代码实现时都是依据下面公式来计算RoPE:
注:苏神在θ的选择上沿用了tansformer的θi = 10000-2i/d 。因为苏神实验发现,在RoPE中采用这个θ也可以带来一定的远程衰减性(意思就是token之间的依赖关系会随着距离的变远而衰减,这也符合我们的直观理解)。当然别的θ也可,只要满足远程衰减。
3、RoPE 代码实现(torch版)
- 代码实现基于torch,代码中也写好详细注释。如有错误,评论区或私信我反馈,谢谢~
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# %%
def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
# (max_len, 1)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
# (output_dim//2)
ids = torch.arange(0, output_dim // 2, dtype=torch.float) # 即公式里的i, i的范围是 [0,d/2]
theta = torch.pow(10000, -2 * ids / output_dim)
# (max_len, output_dim//2)
embeddings = position * theta # 即公式里的:pos / (10000^(2i/d))
# (max_len, output_dim//2, 2)
embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
# (bs, head, max_len, output_dim//2, 2)
embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape)))) # 在bs维度重复,其他维度都是1不重复
# (bs, head, max_len, output_dim)
# reshape后就是:偶数sin, 奇数cos了
embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))
embeddings = embeddings.to(device)
return embeddings
# %%
def RoPE(q, k):
# q,k: (bs, head, max_len, output_dim)
batch_size = q.shape[0]
nums_head = q.shape[1]
max_len = q.shape[2]
output_dim = q.shape[-1]
# (bs, head, max_len, output_dim)
pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)
# cos_pos,sin_pos: (bs, head, max_len, output_dim)
# 看rope公式可知,相邻cos,sin之间是相同的,所以复制一遍。如(1,2,3)变成(1,1,2,2,3,3)
cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1) # 将奇数列信息抽取出来也就是cos 拿出来并复制
sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1) # 将偶数列信息抽取出来也就是sin 拿出来并复制
# q,k: (bs, head, max_len, output_dim)
q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
q2 = q2.reshape(q.shape) # reshape后就是正负交替了
# 更新qw, *对应位置相乘
q = q * cos_pos + q2 * sin_pos
k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
k2 = k2.reshape(k.shape)
# 更新kw, *对应位置相乘
k = k * cos_pos + k2 * sin_pos
return q, k
# %%
def attention(q, k, v, mask=None, dropout=None, use_RoPE=True):
# q.shape: (bs, head, seq_len, dk)
# k.shape: (bs, head, seq_len, dk)
# v.shape: (bs, head, seq_len, dk)
if use_RoPE:
q, k = RoPE(q, k)
d_k = k.size()[-1]
att_logits = torch.matmul(q, k.transpose(-2, -1)) # (bs, head, seq_len, seq_len)
att_logits /= math.sqrt(d_k)
if mask is not None:
att_scores = att_logits.masked_fill(mask == 0, -1e-9) # mask掉为0的部分,设为负无穷大
att_scores = F.softmax(att_logits, dim=-1) # (bs, head, seq_len, seq_len)
if dropout is not None:
att_scores = dropout(att_scores)
# (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
return torch.matmul(att_scores, v), att_scores
if __name__ == '__main__':
# (bs, head, seq_len, dk)
q = torch.randn((8, 12, 10, 32))
k = torch.randn((8, 12, 10, 32))
v = torch.randn((8, 12, 10, 32))
res, att_scores = attention(q, k, v, mask=None, dropout=None, use_RoPE=True)
# (bs, head, seq_len, dk), (bs, head, seq_len, seq_len)
print(res.shape, att_scores.shape)
Reference
- Transformer升级之路:2、博采众长的旋转式位置编码
- 《RoFormer: Enhanced Transformer with Rotary Position Embedding》
- RoPE详细推导版
- Transformer升级之路:6、旋转位置编码的完备性分析
- 让研究人员绞尽脑汁的Transformer位置编码
- Transformer升级之路:4、二维位置的旋转式位置编码