1. 相对位置嵌入:给注意力机制加“人际关系记忆”
像班级座位表
想象全班同学(序列的各个元素)坐成一个圈,老师(模型)要记住每个人之间的相对位置:
- 传统方法:老师给每个座位贴绝对编号(第1排第2座…),但换教室就全乱了。
- 相对位置法:老师只记“小红在小明左边3个位置”,无论坐哪都适用。
实际怎么用?
- 计算注意力时,给“相邻近的词对”额外加分(比如相邻+1分,隔一个+0.5分)。
- 优势:换个长文章也能用,因为只关心词之间的距离,不关心具体在第几行。
💡 一句话总结:不记绝对位置,只记谁和谁挨得近。
2. 旋转位置编码(RoPE):用“转盘子”游戏理解
像旋转餐桌
假设每个词是一个盘子,盘子上有菜(词的特征):
- 初始状态:所有盘子朝同一方向摆好。
- 旋转操作:第1个盘子转10°,第2个转20°…第N个转N×10°。
- 匹配度检查:两个盘子夹角越小(位置越近),菜的味道越搭配(注意力分数越高)。
为什么好用?
- 自动适应长度:新加盘子继续转就行(第100个转1000°),不用重新设计。
- 长文章神器:ChatGPT这类模型就用这个方法处理超长文本。
💡 一句话总结:每个词转不同角度,越近的词转的角度越相似。
一、相对位置嵌入(Relative Positional Embedding)
1. 核心思想
不直接编码绝对位置,而是通过键-值对之间的相对距离动态计算位置偏置,使模型能灵活处理任意长度序列。即通过可学习的相对位置偏置矩阵**,在注意力计算中显式注入位置信息。公式表示为:
Attention
=
Softmax
(
Q
K
T
d
k
+
B
)
V
\text{Attention} = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}} + B\right)V
Attention=Softmax(dkQKT+B)V
其中
B
B
B 是相对位置偏置矩阵,
B
i
,
j
B_{i,j}
Bi,j 表示位置
i
i
i 和
j
j
j 之间的偏置。
2. 实现方式
以Swin Transformer为例:
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
# 定义可学习的相对位置偏置矩阵
self.bias_table = nn.Parameter(
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))
# 生成相对位置索引
coords = torch.arange(window_size)
relative_coords = coords[:, None] - coords[None, :] # [M, M]
relative_coords += window_size - 1 # 转换为非负数
self.register_buffer("relative_index", relative_coords)
def forward(self):
# 根据索引从表中取出偏置值 [M*M, num_heads]
bias = self.bias_table[self.relative_index.view(-1)]
return bias.view(1, self.relative_index.shape[0],
self.relative_index.shape[1], -1) # [1, M, M, num_heads]
3. 关键点
- 窗口限制:通常限制最大相对距离(如Swin中 M = 7 M=7 M=7),减少参数量
- 计算效率:偏置矩阵可预先计算并缓存
- 变长支持:通过插值或截断适应不同长度
二、旋转位置编码(RoPE)
1. 核心思想
通过旋转矩阵将绝对位置信息融入query和key的向量表示中。给定位置 ( m ),旋转操作定义为:
f
RoPE
(
x
,
m
)
=
(
cos
m
θ
−
sin
m
θ
sin
m
θ
cos
m
θ
)
(
x
0
x
1
)
f_{\text{RoPE}}(x, m) = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} x_0 \\ x_1 \end{pmatrix}
fRoPE(x,m)=(cosmθsinmθ−sinmθcosmθ)(x0x1)
其中
θ
\theta
θ是预设的频率基。
2. 完整实现
以LLaMA的实现为例:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
# 预计算旋转角度 [seq_len, dim//2]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs) # [seq_len, dim//2]
return torch.polar(torch.ones_like(freqs), freqs) # 复数形式
def apply_rotary_emb(
x: torch.Tensor, # [batch, seq_len, num_heads, head_dim]
freqs_cis: torch.Tensor
) -> torch.Tensor:
# 将x转换为复数形式
x_complex = torch.view_as_complex(
x.float().reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).contiguous()
)
# 应用旋转 [batch, seq_len, num_heads, head_dim//2]
x_rotated = x_complex * freqs_cis
# 转换回实数
return torch.view_as_real(x_rotated).flatten(-2)
3. 数学特性
- 距离衰减:注意力分数随相对距离 ∣ m − n ∣ |m-n| ∣m−n∣自然衰减
- 线性可加性:满足 ⟨ f RoPE ( x , m ) , f RoPE ( x , n ) ⟩ = g ( m − n ) \langle f_{\text{RoPE}}(x,m), f_{\text{RoPE}}(x,n) \rangle = g(m-n) ⟨fRoPE(x,m),fRoPE(x,n)⟩=g(m−n)
- 长程衰减:高频维度(大 θ \theta θ)对远距离更敏感
三、对比总结
特性 | 相对位置嵌入 | 旋转位置编码 |
---|---|---|
位置信息存储方式 | 可学习的偏置矩阵 | 预设的旋转角度 |
计算复杂度 | O ( L 2 ) O(L^2) O(L2) | O ( L ) O(L) O(L) |
长度扩展性 | 需插值或截断 | 天然支持任意长度 |
显式相对位置 | 是 | 通过旋转隐式包含 |
参数量 | 随窗口大小增长 | 零参数(仅计算) |
典型应用 | Swin Transformer, T5 | LLaMA, GPT-J, PaLM |
四、代码实现选择建议
-
图像处理:优先选择相对位置嵌入(适合局部注意力)
# Swin Transformer风格 attn = q @ k.transpose(-2, -1) + relative_bias
-
长文本处理:优先选择RoPE(适合全局注意力)
# LLaMA风格 q_rot = apply_rotary_emb(q, freqs_cis) k_rot = apply_rotary_emb(k, freqs_cis) attn = (q_rot @ k_rot.transpose(-2, -1)) * scaling_factor