【HuggingFace Transformers】LlamaRotaryEmbedding源码解析

news2024/12/23 5:18:37

LlamaRotaryEmbedding源码解析

  • 1. LlamaRotaryEmbedding类 介绍
  • 2. 逆频率向量
  • 3. LlamaRotaryEmbedding类 源码解析
    • 3.1 transformers v4.44.2版
    • 3.2 transformers v4.41.1版

1. LlamaRotaryEmbedding类 介绍

LLaMa模型中,LlamaRotaryEmbedding类实现了Rotary Position Embedding(RoPE)的方法。RoPE的核心思想是对位置编码应用旋转变换,使得不同位置之间的相对位置关系在编码过程中得到保留。这种旋转变换不仅能捕捉到序列的绝对位置,还能捕捉到相对位置,从而更好地处理长距离依赖。以下是旋转变换的过程:
在这里插入图片描述
图片来源:RoFormer: Enhanced Transformer with Rotary Position Embedding

2. 逆频率向量

频率向量在旋转位置编码 (Rotary Position Embedding, RoPE) 中是用于表示不同位置的频率信息的向量。这些向量帮助模型在处理序列数据时能够区分不同位置的相对关系。频率向量的计算和应用可以增强模型对位置的感知,从而改进模型的性能。频率向量通常由基数、向量的维度和位置索引生成。例如,频率向量可以表示为:
f r e q u e n c y = b a s e 2 i d frequency=base^{\frac{2i}{d}} frequency=based2i
则逆频率向量表示为:
i n v _ f r e q = 1 b a s e 2 i d inv\_freq=\frac{1}{base^{\frac{2i}{d}}} inv_freq=based2i1
其中:

  • i 是频率向量中的索引(例如第 i 个频率分量)。
  • d 是嵌入维度的大小。
  • base 是一个预定义的基数,通常为 10000

下面举个例子,说明一下逆频率向量的构建过程:

假设维度大小 dim = 8,基数 base = 10000
步骤如下:

  • 1.生成频率索引: 对于维度为 8,步长为 2,生成索引 [0, 2, 4, 6]。
  • 2.计算比例: 将索引除以维度大小得到比例 [0/8, 2/8, 4/8, 6/8] = [0, 0.25, 0.5, 0.75]。
  • 3.计算频率: 使用基数 10000,对比例应用幂运算,生成频率向量:
    f r e q u e n c y = [ 1000 0 0 , 1000 0 0.25 , 1000 0 0.5 , 1000 0 0.75 ] frequency=[10000^0,10000^{0.25} ,10000^{0.5} ,10000 ^{0.75} ] frequency=[100000,100000.25,100000.5,100000.75]
    计算结果约为:
    f r e q u e n c y ≈ [ 1 , 17.78 , 316.23 , 5623.41 ] frequency≈[1,17.78,316.23,5623.41] frequency[1,17.78,316.23,5623.41]
  • 4.取倒数生成逆频率向量。最后,对这些频率值取倒数,生成逆频率向量:
    i n v _ f r e q ≈ [ 1.0 , 0.056 , 0.00316 , 0.000178 ] {inv\_freq}≈[1.0,0.056,0.00316,0.000178] inv_freq[1.0,0.056,0.00316,0.000178]

参考代码:_compute_default_rope_parameters

def _compute_default_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple["torch.Tensor", float]:
    """
    计算原始 RoPE 实现中的逆频率参数

    参数:
        config ([`~transformers.PretrainedConfig`]):
            模型配置,用于从中获取 RoPE 参数(如 base 和 dim)。
        device (`torch.device`):
            初始化逆频率参数时使用的设备(如 GPU 或 CPU)。
        seq_len (`int`, *optional*):
            当前序列长度。对于此类型的 RoPE 实现,该参数未被使用。
        rope_kwargs (`Dict`, *optional*):
            向后兼容以前的 RoPE 类实例化方式,该参数将在 v4.45 中移除。

    返回:
        包含 (`torch.Tensor`, `float`) 元组,其中包括 RoPE 嵌入的逆频率和应用于计算出的 cos/sin 的后处理缩放因子
        (此类型的 RoPE 中未使用)。
    """
    # 如果 config 参数不为 None 且 rope_kwargs 参数有值,抛出异常,二者是互斥的
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )

    # 使用传入的 rope_kwargs 来初始化 base 和 dim 参数
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
    # 使用 config 配置中的参数来初始化 base 和 dim 参数
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        dim = int(head_dim * partial_rotary_factor)

    # RoPE 后处理的缩放因子,默认为 1.0(未在此类型 RoPE 中使用)
    attention_factor = 1.0  # Unused in this type of RoPE

    # 计算逆频率参数
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
    
    # 返回计算出的逆频率和缩放因子
    return inv_freq, attention_factor

3. LlamaRotaryEmbedding类 源码解析

算法核心为:

  • 计算inv_freq
  • 扩展inv_freq和position_ids
  • inv_freq_expanded与position_ids_expanded相乘并转置得到freqs
  • 拼接freqs
  • 计算cos值 和 sin值

3.1 transformers v4.44.2版

源码地址:transformers/src/transformers/models/llama/modeling_llama.py

# -*- coding: utf-8 -*-
# @time: 2024/8/28 14:52
# @transformers.version: v4.44.2
import torch

from typing import Optional
from torch import nn
from transformers import LlamaConfig
from transformers.utils import logging
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

logger = logging.get_logger(__name__)


class LlamaRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim=None,  # 嵌入维度
        max_position_embeddings=2048,  # 最大位置嵌入数
        base=10000,  # 基数,用于计算逆频率
        device=None,
        scaling_factor=1.0,  # 缩放因子
        rope_type="default",  # RoPE 类型
        config: Optional[LlamaConfig] = None,  # 可选的 Llama 配置
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        # 移除下面的 if,此代码仅用于向后兼容(BC)
        self.rope_kwargs = {}  # 初始化存储 RoPE 参数的字典
        if config is None:
            logger.warning_once(
                "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.45"
            )
            # 如果没有传入配置对象,使用传入的参数进行初始化
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings  # 初始化缓存的最大序列长度
            self.original_max_seq_len = max_position_embeddings  # 保存原始最大序列长度
        else:
            # BC: "rope_type" was originally "type"
            # 如果传入了配置对象,向后兼容: "rope_type" 原来被称为 "type"
            if config.rope_scaling is not None:  # 如果配置中有 rope_scaling 参数
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))  # 获取 RoPE 类型
            else:
                self.rope_type = "default"  # 如果没有指定,使用默认类型
            self.max_seq_len_cached = config.max_position_embeddings  # 从配置中读取最大位置嵌入数
            self.original_max_seq_len = config.max_position_embeddings  # 保存原始最大序列长度

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]  # 根据 RoPE 类型选择初始化函数

        # 使用初始化函数计算逆频率 (inv_freq) 和注意力缩放 (attention_scaling)
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)

        self.register_buffer("inv_freq", inv_freq, persistent=False)  # 注册逆频率缓冲区(不会被持久化)
        self.original_inv_freq = self.inv_freq  # 保存原始逆频率

    def _dynamic_frequency_update(self, position_ids, device):
        """
        对于动态 RoPE 层,应在以下情况下重新计算 `inv_freq`:
        1 - 超过缓存的序列长度(允许缩放)
        2 - 当前序列长度在原始尺度内(避免对小序列失去精度)
        """
        seq_len = torch.max(position_ids) + 1  # 计算当前序列长度
        if seq_len > self.max_seq_len_cached:  # 序列长度增长时更新逆频率
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # 重新注册逆频率缓冲区  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len  # 更新缓存的最大序列长度

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # 如果序列长度变小,重置逆频率
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)  # 恢复原始逆频率
            self.max_seq_len_cached = self.original_max_seq_len  # 恢复缓存的最大序列长度

    @torch.no_grad()  # 禁用梯度计算以提高性能
    def forward(self, x, position_ids):
        """
        :param x: [bs, num_attention_heads, seq_len, head_size]
        :param position_ids: [bs, seq_len]
        """
        # 如果是动态 RoPE 类型,更新逆频率
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # 核心 RoPE 计算块
        # inv_freq: [dim/2] -> inv_freq_expanded: [batch_size, dim/2, 1]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)  # 扩展逆频率
        # position_ids: [batch_size, seq_len] -> position_ids_expanded: [batch_size, 1, seq_len]
        position_ids_expanded = position_ids[:, None, :].float()  # 扩展位置 ID

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

        with torch.autocast(device_type=device_type, enabled=False):  # 关闭自动混合精度
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)  # 计算频率。这里的 @ 符号是矩阵乘法的符号,表示对两个张量进行矩阵乘法运算。
            # freqs: [batch_size, seq_len, dim/2]
            emb = torch.cat((freqs, freqs), dim=-1)  # 拼接频率
            # emb: [batch_size, seq_len, dim]
            cos = emb.cos()  # 计算 cos 嵌入
            sin = emb.sin()  # 计算 sin 嵌入

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        # 高级 RoPE 类型应用后处理缩放因子,等价于缩放注意力
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        # 返回 cos 和 sin 嵌入
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

3.2 transformers v4.41.1版

如果v4.44.2看着有些复杂,可以参考v4.41.1

# -*- coding: utf-8 -*-
# @time: 2024/8/28 14:52
# @transformers.version: v4.41.1
import torch

from torch import nn


class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor  # 缩放因子
        self.dim = dim  # 嵌入维度
        self.max_position_embeddings = max_position_embeddings  # 最大位置嵌入
        self.base = base  # 基数
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) # 计算逆频率向量
        self.register_buffer("inv_freq", inv_freq, persistent=False)  # 将逆频率向量注册到缓存
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings  # 最大序列长度缓存

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # -----------------------核心 RoPE 计算块----------------------- #
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2103575.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Elasticsearch 向量数据库本地部署 及操作方法

elasticsearch是个分布式向量数据库&#xff0c;支持多种查找模式。此外还拥有 Metadata、Filtering、Hybrid Search、Delete、Store Documents、Async等能力。本文仅是记录本地测试途中遇到的问题。 一&#xff0c;环境部署 下载软件 首先去官网&#xff0c;选择适合平台下…

Kafka-设计原理

ControllerLeader - PartitionRebalance消息发布机制HW与LEO日志分段 Controller Kafka核心总控制器Controller&#xff1a;在Kafka集群中会有一个或者多个broker&#xff0c;其中有一个broker会被选举为控制器&#xff08;Kafka Controller&#xff09;&#xff0c;它负责管理…

Hyper-v 安装 centOS

一.Hyper-v安装 1. 右键此电脑&#xff0c;点击属性&#xff0c;查看自己的window版本 如果是专业版或者企业版&#xff0c;则无需额外操作&#xff0c;如果是家庭版&#xff0c;则需要先运行一个脚本来进行安装。 参考这一篇&#xff1a;window10 家庭版如何开启Hyper-v-CSDN…

FPGA开发:初识FPGA

FPGA是什么&#xff1f; FPGA的全称是现场可编程门阵列&#xff08;Field Programmable Gate Array&#xff09;&#xff0c;一种以数字电路为主的集成芯片&#xff0c;属于可编程逻辑器件PLD的一种。简单来说&#xff0c;就是能用代码编程&#xff0c;直接修改FPGA芯片中数字…

OceanBase 关于 place_group_by HINT的使用

PLACE_GROUP_BY Hint 表示在多表关联时&#xff0c;如果满足单表查询后直接进行group by 的情形下&#xff0c;在跟其它表进行关联统计&#xff0c;减少表内部联接。 NO_PLACE_GROUP_BY Hint 表示在多表关联时&#xff0c;在关联后才对结果进行group by。 使用place_group_by …

二百五十九、Java——采集Kafka数据,解析成一条条数据,写入另一Kafka中(一般JSON)

一、目的 由于部分数据类型频率为1s&#xff0c;从而数据规模特别大&#xff0c;因此完整的JSON放在Hive中解析起来&#xff0c;尤其是在单机环境下&#xff0c;效率特别慢&#xff0c;无法满足业务需求。 而Flume的拦截器并不能很好的转换数据&#xff0c;因为只能采用Java方…

启动.cmd文件一闪而过,看不到报错信息

在window的环境中&#xff0c;双击.cmd文件&#xff0c;有报错信息&#xff0c;但是一闪而过 例如启动zookeeper时&#xff0c;没有zoo.cfg文件会报错&#xff0c;但是启动一闪而过&#xff0c;你看不到报错信息 有文本工具编辑cmd文件&#xff0c;在最后添加 pause 再次启…

Linux 之 lsblk 【可用块的设备信息】

功能介绍 在 Linux 系统中&#xff0c;“lsblk”&#xff08;list block devices&#xff09;命令用于列出所有可用的块设备信息 应用场景 查看存储设备信息&#xff1a;“lsblk” 命令可以帮助你快速了解系统中的存储设备&#xff0c;包括硬盘、固态硬盘、U 盘等。你可以查…

9_4_QTextEdit

QTextEdit //核心属性//获取文本 toPlainText(); toHtml(); toMarkdown(); //输入框为空时的提示功能 placeHolderText(); //只读 readOnly();//定义文本光标 QTextcursor cursorcursor.position(); cursor.selectedText();//核心信号//文本改变 textChanged(); //选中范围 se…

【黑马点评】附近商户

需求 选择商铺类型后&#xff0c;按照距离当前用户所在位置从近到远的顺序&#xff0c;分页展示该类型的所有商铺。 接口&#xff1a; 参数&#xff1a; typeId&#xff1a;商铺类型current&#xff1a;页码x&#xff1a;经度y&#xff1a;纬度 返回值&#xff1a;所有typeId…

LVS 负载均衡集群指南

1. 引言 LVS (Linux Virtual Server) 虚拟服务器&#xff0c;是 Linux 内核中实现的负载均衡技术&#xff0c;以其高性能、高可靠性和高可用性而闻名。LVS 工作在 TCP/IP 协议栈的第四层 (传输层)&#xff0c;通过将流量分配到多个后端服务器&#xff0c;提高系统性能、可用性…

硬件工程师笔试面试知识器件篇——电阻

目录 1、电阻 1.1 基础 电阻原理图 阻实物图 1.1.1、定义 1.1.2、工作原理 1.1.3、类型 1.1.4、材料 1.1.5、标记 1.1.6、应用 1.1.7、特性 1.1.8、测量 1.1.9、计算 1.1.10、颜色编码 1.1.11、公差 1.1.12、功率 1.1.13、重要性 1.2、相关问题 1.2.1、电阻…

数组和指针 笔试题(1)

目录 0.复习 1.笔试题1 2.笔试题2 3.笔试题3 4.笔试题4 5.笔试题5 0.复习 在做笔试题之前&#xff0c;我们首先复习一下数组名的理解 数组名的所有情况&#xff1a; 1.&数组名&#xff0c;取出的是整个数组的地址 2.sizeof&#xff08;数组名&#xff09;&#x…

LLM常见问题(Attention 优化部分)

1. 传统 Attention 存在哪些问题&#xff1f; 传统的 Attention 机制忽略了源端或目标端句子中词与词之间的依赖关系。传统的 Attention 机制过度依赖 Encoder-Decoder 架构上。传统的 Attention 机制依赖于Decoder的循环解码器&#xff0c;所以依赖于 RNN,LSTM 等循环结构。传…

【Transformer】Tokenization

文章目录 直观理解分词方式词粒度-Word字粒度-Character子词粒度-Subword&#xff08;目前最常使用&#xff09; 词表大小的影响参考资料 直观理解 在理解Transformer或者大模型对输入进行tokenize之前&#xff0c;需要理解什么是token&#xff1f; 理工科的兄弟姐妹们应该都…

027集——goto语句用法——C#学习笔记

goto语句可指定代码的跳行运行&#xff1a; 实例如下&#xff1a; 代码如下&#xff1a; using System; using System.Collections.Generic; using System.Linq; using System.Security.Policy; using System.Text; using System.Threading.Tasks;namespace ConsoleApp2 { //…

采用基于企业服务总线(ESB)的面向服务架构(SOA)集成方案实现统一管理维护的银行信息系统

目录 案例 【题目】 【问题 1】(7 分) 【问题 2】(12 分) 【问题 3】(6 分) 【答案】 【问题 1】解析 【问题 2】解析 【问题 3】解析 相关推荐 案例 阅读以下关于 Web 系统设计的叙述&#xff0c;在答题纸上回答问题 1 至问题 3。 【题目】 某银行拟将以分行为主体…

是噱头还是低成本新宠?加州大学用视觉追踪实现跨平台的机器手全掌控?

导读&#xff1a; 在当今科技飞速发展的时代&#xff0c;机器人的应用越来越广泛。从工业生产到医疗保健&#xff0c;从物流运输到家庭服务&#xff0c;机器人正在逐渐改变我们的生活方式。而机器人的有效操作和控制&#xff0c;离不开高效的遥操作系统。今天&#xff0c;我们要…

OHIF Viewer (3.9版本最新版) 适配移动端——最后一篇

根据一些调用资料和尝试,OHIF 的底层用的是Cornerstonejs ,这个是基于web端写的,如果说写在微信小程序里,确实有很多报错, 第一个问题就是 npm下载的依赖, 一、运行环境差异 微信小程序的运行环境与传统的 Node.js 环境有很大不同。小程序在微信客户端中运行,有严格的…

传输大咖38 | 如何应对汽车行业内外网文件交换挑战?

在数字化浪潮的推动下&#xff0c;汽车行业正经历着前所未有的变革。随着智能网联汽车的兴起&#xff0c;内外网文件的安全交换成为了一个亟待解决的问题。本文将探讨汽车行业在内外网文件交换中遇到的难题&#xff0c;并介绍镭速如何为这些问题提供有效的解决方案。 跨网文件传…