LlamaRotaryEmbedding源码解析
- 1. LlamaRotaryEmbedding类 介绍
- 2. 逆频率向量
- 3. LlamaRotaryEmbedding类 源码解析
- 3.1 transformers v4.44.2版
- 3.2 transformers v4.41.1版
1. LlamaRotaryEmbedding类 介绍
在LLaMa模型中,LlamaRotaryEmbedding类实现了Rotary Position Embedding(RoPE)的方法。RoPE的核心思想是对位置编码应用旋转变换,使得不同位置之间的相对位置关系在编码过程中得到保留。这种旋转变换不仅能捕捉到序列的绝对位置,还能捕捉到相对位置,从而更好地处理长距离依赖。以下是旋转变换的过程:
图片来源:RoFormer: Enhanced Transformer with Rotary Position Embedding
2. 逆频率向量
频率向量在旋转位置编码 (Rotary Position Embedding, RoPE) 中是用于表示不同位置的频率信息的向量。这些向量帮助模型在处理序列数据时能够区分不同位置的相对关系。频率向量的计算和应用可以增强模型对位置的感知,从而改进模型的性能。频率向量通常由基数、向量的维度和位置索引生成。例如,频率向量可以表示为:
f
r
e
q
u
e
n
c
y
=
b
a
s
e
2
i
d
frequency=base^{\frac{2i}{d}}
frequency=based2i
则逆频率向量表示为:
i
n
v
_
f
r
e
q
=
1
b
a
s
e
2
i
d
inv\_freq=\frac{1}{base^{\frac{2i}{d}}}
inv_freq=based2i1
其中:
i
是频率向量中的索引(例如第i
个频率分量)。d
是嵌入维度的大小。base
是一个预定义的基数,通常为 10000。
下面举个例子,说明一下逆频率向量的构建过程:
假设维度大小 dim = 8,基数 base = 10000
步骤如下:
- 1.生成频率索引: 对于维度为 8,步长为 2,生成索引 [0, 2, 4, 6]。
- 2.计算比例: 将索引除以维度大小得到比例 [0/8, 2/8, 4/8, 6/8] = [0, 0.25, 0.5, 0.75]。
- 3.计算频率: 使用基数 10000,对比例应用幂运算,生成频率向量:
f r e q u e n c y = [ 1000 0 0 , 1000 0 0.25 , 1000 0 0.5 , 1000 0 0.75 ] frequency=[10000^0,10000^{0.25} ,10000^{0.5} ,10000 ^{0.75} ] frequency=[100000,100000.25,100000.5,100000.75]
计算结果约为:
f r e q u e n c y ≈ [ 1 , 17.78 , 316.23 , 5623.41 ] frequency≈[1,17.78,316.23,5623.41] frequency≈[1,17.78,316.23,5623.41]- 4.取倒数生成逆频率向量。最后,对这些频率值取倒数,生成逆频率向量:
i n v _ f r e q ≈ [ 1.0 , 0.056 , 0.00316 , 0.000178 ] {inv\_freq}≈[1.0,0.056,0.00316,0.000178] inv_freq≈[1.0,0.056,0.00316,0.000178]
参考代码:_compute_default_rope_parameters
def _compute_default_rope_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
计算原始 RoPE 实现中的逆频率参数
参数:
config ([`~transformers.PretrainedConfig`]):
模型配置,用于从中获取 RoPE 参数(如 base 和 dim)。
device (`torch.device`):
初始化逆频率参数时使用的设备(如 GPU 或 CPU)。
seq_len (`int`, *optional*):
当前序列长度。对于此类型的 RoPE 实现,该参数未被使用。
rope_kwargs (`Dict`, *optional*):
向后兼容以前的 RoPE 类实例化方式,该参数将在 v4.45 中移除。
返回:
包含 (`torch.Tensor`, `float`) 元组,其中包括 RoPE 嵌入的逆频率和应用于计算出的 cos/sin 的后处理缩放因子
(此类型的 RoPE 中未使用)。
"""
# 如果 config 参数不为 None 且 rope_kwargs 参数有值,抛出异常,二者是互斥的
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
)
# 使用传入的 rope_kwargs 来初始化 base 和 dim 参数
if len(rope_kwargs) > 0:
base = rope_kwargs["base"]
dim = rope_kwargs["dim"]
# 使用 config 配置中的参数来初始化 base 和 dim 参数
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
# RoPE 后处理的缩放因子,默认为 1.0(未在此类型 RoPE 中使用)
attention_factor = 1.0 # Unused in this type of RoPE
# 计算逆频率参数
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
# 返回计算出的逆频率和缩放因子
return inv_freq, attention_factor
3. LlamaRotaryEmbedding类 源码解析
算法核心为:
- 计算inv_freq
- 扩展inv_freq和position_ids
- inv_freq_expanded与position_ids_expanded相乘并转置得到freqs
- 拼接freqs
- 计算cos值 和 sin值
3.1 transformers v4.44.2版
源码地址:transformers/src/transformers/models/llama/modeling_llama.py
# -*- coding: utf-8 -*-
# @time: 2024/8/28 14:52
# @transformers.version: v4.44.2
import torch
from typing import Optional
from torch import nn
from transformers import LlamaConfig
from transformers.utils import logging
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
logger = logging.get_logger(__name__)
class LlamaRotaryEmbedding(nn.Module):
def __init__(
self,
dim=None, # 嵌入维度
max_position_embeddings=2048, # 最大位置嵌入数
base=10000, # 基数,用于计算逆频率
device=None,
scaling_factor=1.0, # 缩放因子
rope_type="default", # RoPE 类型
config: Optional[LlamaConfig] = None, # 可选的 Llama 配置
):
super().__init__()
# TODO (joao): remove the `if` below, only used for BC
# 移除下面的 if,此代码仅用于向后兼容(BC)
self.rope_kwargs = {} # 初始化存储 RoPE 参数的字典
if config is None:
logger.warning_once(
"`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.45"
)
# 如果没有传入配置对象,使用传入的参数进行初始化
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings # 初始化缓存的最大序列长度
self.original_max_seq_len = max_position_embeddings # 保存原始最大序列长度
else:
# BC: "rope_type" was originally "type"
# 如果传入了配置对象,向后兼容: "rope_type" 原来被称为 "type"
if config.rope_scaling is not None: # 如果配置中有 rope_scaling 参数
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) # 获取 RoPE 类型
else:
self.rope_type = "default" # 如果没有指定,使用默认类型
self.max_seq_len_cached = config.max_position_embeddings # 从配置中读取最大位置嵌入数
self.original_max_seq_len = config.max_position_embeddings # 保存原始最大序列长度
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] # 根据 RoPE 类型选择初始化函数
# 使用初始化函数计算逆频率 (inv_freq) 和注意力缩放 (attention_scaling)
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False) # 注册逆频率缓冲区(不会被持久化)
self.original_inv_freq = self.inv_freq # 保存原始逆频率
def _dynamic_frequency_update(self, position_ids, device):
"""
对于动态 RoPE 层,应在以下情况下重新计算 `inv_freq`:
1 - 超过缓存的序列长度(允许缩放)
2 - 当前序列长度在原始尺度内(避免对小序列失去精度)
"""
seq_len = torch.max(position_ids) + 1 # 计算当前序列长度
if seq_len > self.max_seq_len_cached: # 序列长度增长时更新逆频率
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # 重新注册逆频率缓冲区 # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len # 更新缓存的最大序列长度
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # 如果序列长度变小,重置逆频率
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) # 恢复原始逆频率
self.max_seq_len_cached = self.original_max_seq_len # 恢复缓存的最大序列长度
@torch.no_grad() # 禁用梯度计算以提高性能
def forward(self, x, position_ids):
"""
:param x: [bs, num_attention_heads, seq_len, head_size]
:param position_ids: [bs, seq_len]
"""
# 如果是动态 RoPE 类型,更新逆频率
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# 核心 RoPE 计算块
# inv_freq: [dim/2] -> inv_freq_expanded: [batch_size, dim/2, 1]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) # 扩展逆频率
# position_ids: [batch_size, seq_len] -> position_ids_expanded: [batch_size, 1, seq_len]
position_ids_expanded = position_ids[:, None, :].float() # 扩展位置 ID
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # 关闭自动混合精度
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) # 计算频率。这里的 @ 符号是矩阵乘法的符号,表示对两个张量进行矩阵乘法运算。
# freqs: [batch_size, seq_len, dim/2]
emb = torch.cat((freqs, freqs), dim=-1) # 拼接频率
# emb: [batch_size, seq_len, dim]
cos = emb.cos() # 计算 cos 嵌入
sin = emb.sin() # 计算 sin 嵌入
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
# 高级 RoPE 类型应用后处理缩放因子,等价于缩放注意力
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
# 返回 cos 和 sin 嵌入
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
3.2 transformers v4.41.1版
如果v4.44.2看着有些复杂,可以参考v4.41.1
# -*- coding: utf-8 -*-
# @time: 2024/8/28 14:52
# @transformers.version: v4.41.1
import torch
from torch import nn
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
self.scaling_factor = scaling_factor # 缩放因子
self.dim = dim # 嵌入维度
self.max_position_embeddings = max_position_embeddings # 最大位置嵌入
self.base = base # 基数
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) # 计算逆频率向量
self.register_buffer("inv_freq", inv_freq, persistent=False) # 将逆频率向量注册到缓存
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings # 最大序列长度缓存
@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
# -----------------------核心 RoPE 计算块----------------------- #
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)