LLama2源码分析——Rotary Position Embedding分析

news2024/11/26 22:26:42

参考:一文看懂 LLaMA 中的旋转式位置编码(Rotary Position Embedding)

原理推导参考自上文,以下结合huggingface代码分析公式计算过程

1 旋转角度计算

计算公式如下,其中d为词嵌入维度,这部分和论文原文一样
θ j = 1000 0 − 2 ( j − 1 ) / d , j ∈ [ 1 , 2 , … , d / 2 ] \theta_j=10000^{-2(j-1)/d},j\in [1,2,\ldots,d/2] θj=100002(j1)/d,j[1,2,,d/2]

# 计算词向量元素两两分组之后,每组元素对应的旋转角度
# 维度:[dim / 2]
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))

2 计算整个seq的cos_sin矩阵

def _set_cos_sin_cache(self, seq_len, device, dtype):
    self.max_seq_len_cached = seq_len
    # 生成token长度序列
    t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
    # 计算两个矩阵的外积,结果维度[seq_len, dim // 2]
    freqs = torch.einsum("i,j->ij", t, self.inv_freq)
    # 类似[[0, 2, 4, ..., 0, 2, 4, ...], ...]形式,旋转角度两两一组相同
    emb = torch.cat((freqs, freqs), dim=-1)
    self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
    self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

3 计算旋转式位置编码

f q ( x m , m ) = ( W q x m ) e i m θ f k ( x n , n ) = ( W k x n ) e i n θ \begin{aligned}f_q(x_m,m)&=(W_qx_m)e^{im\theta} \\f_k(x_n,n)&=(W_kx_n)e^{in\theta}\end{aligned} fq(xm,m)fk(xn,n)=(Wqxm)eimθ=(Wkxn)einθ
公式根据欧拉公式转化后为
( q m ( 1 ) + i q m ( 2 ) ) ∗ ( cos ⁡ ( m θ ) + i sin ⁡ ( m θ ) ) (q_{m}^{(1)}+iq_{m}^{(2)})*(\cos(m\theta)+i\sin(m\theta)) (qm(1)+iqm(2))(cos(mθ)+isin(mθ))

展开后将结果重新表示为实数向量即为
q m e i m θ = [ q m ( 1 ) cos ⁡ ( m θ ) − q m ( 2 ) sin ⁡ ( m θ ) , q m ( 2 ) cos ⁡ ( m θ ) + q m ( 1 ) sin ⁡ ( m θ ) ] q_me^{im\theta}=[q_m^{(1)}\cos(m\theta)-q_m^{(2)}\sin(m\theta),q_m^{(2)}\cos(m\theta)+q_m^{(1)}\sin(m\theta)] qmeimθ=[qm(1)cos(mθ)qm(2)sin(mθ),qm(2)cos(mθ)+qm(1)sin(mθ)]
key的计算同理,以上公式是2维embedding的旋转编码计算,实际代码中是将高纬度的embedding两两分组按照上述公式计算,同一组内的旋转角度相同,此处Llama代码中的分组计算方式与论文原文有所区别,论文原文中是将embedding_dim维度(最后一维)的向量按照相邻两个位置数字为一组,可以按照如下代码理解

>>> a
tensor([[1, 2, 3, 4, 5, 6, 7, 8],
        [1, 2, 3, 4, 5, 6, 7, 8]])
>>> a.view(2, -1, 2)
tensor([[[1, 2],
         [3, 4],
         [5, 6],
         [7, 8]],

        [[1, 2],
         [3, 4],
         [5, 6],
         [7, 8]]])

因此,单个token的位置编码是如下图方式计算
image
但以上的R矩阵比较稀疏,计算时浪费大量算力,因此Llama中采用不同的方式计算

  • Llama源码中计算方法

( q 0 q 1 ⋮ q d / 2 − 1 q d / 2 q d / 2 + 1 ⋮ q d − 1 ) ⊗ ( cos ⁡ m θ 0 cos ⁡ m θ 2 cos ⁡ m θ 4 ⋮ cos ⁡ m θ d − 2 cos ⁡ m θ 0 cos ⁡ m θ 2 ⋮ cos ⁡ m θ d − 2 ) + ( − q d / 2 − q d / 2 + 1 ⋮ − q d − 1 q 1 q 2 ⋮ q d / 2 − 1 ) ⊗ ( sin ⁡ m θ 0 sin ⁡ m θ 2 sin ⁡ m θ 4 ⋮ sin ⁡ m θ d − 2 sin ⁡ m θ 0 sin ⁡ m θ 2 ⋮ sin ⁡ m θ d − 2 ) \begin{pmatrix} {q_0}\\{q_1}\\{\vdots}\\{q_{d/2-1}}\\{q_{d/2}}\\{q_{d/2+1}}\\{\vdots}\\{q_{d-1}} \end{pmatrix} \otimes \begin{pmatrix} \cos m\theta_0\\\cos m\theta_2\\\cos m\theta_4\\\vdots\\\cos m\theta_{d-2}\\\cos m\theta_0\\\cos m\theta_2\\\vdots\\\cos m\theta_{d-2} \end{pmatrix} + \begin{pmatrix} {-q_{d/2}}\\{-q_{d/2+1}}\\\vdots\\{-q_{d-1}}\\{q_{1}}\\{q_{2}}\\\vdots\\{q_{d/2-1}} \end{pmatrix} \otimes \begin{pmatrix} \sin m\theta_0\\\sin m\theta_2\\\sin m\theta_4\\\vdots\\\sin m\theta_{d-2}\\\sin m\theta_0\\\sin m\theta_2\\\vdots\\\sin m\theta_{d-2} \end{pmatrix} q0q1qd/21qd/2qd/2+1qd1 cosmθ0cosmθ2cosmθ4cosmθd2cosmθ0cosmθ2cosmθd2 + qd/2qd/2+1qd1q1q2qd/21 sinmθ0sinmθ2sinmθ4sinmθd2sinmθ0sinmθ2sinmθd2

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1795287.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue——监听器简单使用与注意事项

文章目录 前言编写简单demo注意事项 前言 监听器,在官网中称为侦听器,个人还是喜欢称之为监听器。官方文档如下: vue 官网 侦听器 编写简单demo 侦听器在项目中通常用于监听某个属性变量值的变化,并根据该变化做出一些处理操作。…

冯喜运:6.7今日黄金原油行情分析及独家操作策略

【黄金消息面分析】:周三(6月5日),金价回升逾1.2%,收盘报每盎司2,355.49美元,全面收复前一交易日的跌幅。周三当天前公布的美国民间就业数据弱于预期,增强了美联储将在今年晚些时候降息的预期&a…

AI大数据处理与分析实战--体育问卷分析

AI大数据处理与分析实战–体育问卷分析 前言:前一段时间接了一个需求,使用AI进行数据分析与处理,遂整理了一下大致过程和大致简要结果(更详细就不方便放了)。 文章目录 AI大数据处理与分析实战--体育问卷分析一、数据…

[Python]用Qt6和Pillow实现截图小工具

本文章主要讲述的内容是,使用python语言借助PyQt6和Pillow库进行简单截图工具的开发,含义一个简单的范围裁剪和软件界面。 主要解决的问题是,在高DPI显示屏下,坐标点的偏差导致QWidget显示图片不全、剪裁范围偏差问题。 适合有一点…

Spark MLlib 机器学习详解

目录 🍉引言 🍉Spark MLlib 简介 🍈 主要特点 🍈常见应用场景 🍉安装与配置 🍉数据处理与准备 🍈加载数据 🍈数据预处理 🍉分类模型 🍈逻辑回归 &a…

【Linux网络】传输层协议 - UDP

文章目录 一、传输层(运输层)运输层的特点复用和分用再谈端口号端口号范围划分认识知名端口号(Well-Know Port Number)两个问题① 一个进程是否可以绑定多个端口号?② 一个端口号是否可以被多个进程绑定? n…

Java Web学习笔记15——DOM对象

DOM: 概念:Document Object Model: 文档对象模型 将标记语言的各个组成部分封装为对应的对象: Document: 整个文档对象 Element:元素对象 Attribute: 属性对象 Text:文本对象 Comment&a…

logback删除日志文件和文件夹

​​​​​一,事由和源码 logback版本1.2.11 网上找了很多都是无法删除文件夹的,原先使用的TimeBasedRollingPolicy无法删除日志的文件夹,有很多空的日期文件夹,于是查看TimeBasedRollingPolicy源码发现有校验不删除文件夹&#x…

docker-compose部署 kafka 3.7 集群(3台服务器)并启用账号密码认证

文章目录 1. 规划2. 服务部署2.1 kafka-012.2 kafka-022.3 kafka-032.4 启动服务 3. 测试3.1 kafkamap搭建(测试工具)3.2 测试 1. 规划 服务IPkafka-0110.10.xxx.199kafka-0210.10.xxx.198kafka-0310.10.xxx.197kafkamp10.10.xxx.199 2. 服务部署 2.1…

MySQL报ERROR 2002 (HY000)解决

今天在连接客户服务器时MySQL的时候报: ERROR 2002 (HY000): Can’t connect to local MySQL server through socket ‘/tmp/mysql/mysql.sock’ (2) [rootXXX ~]# mysql -uroot -p Enter password: ERROR 2002 (HY000): Can’t connect to local MySQL server through socket…

香港服务器无法访问是什么情况?

香港服务器无法访问是什么情况?简单来说,这意味着香港服务器没有响应请求,客户端无法访问。此错误可能由于多种原因而发生,包括网络连接问题、服务器停机、防火墙限制和 DNS 错误。当发生服务器无法访问错误时,它会影响您网站的性…

【Linux】进程切换环境变量

目录 一.进程切换 1.进程特性 2.进程切换 1.进程切换的现象 2.如何实现 3.现实例子 2.环境变量 一.基本概念 二.常见环境变量 三.查询常见环境变量的方法 四.和环境变量相关的命令 五.环境变量表的组织方式 六.使用系统调用接口方式查询环境变量 1.getenv 2.反思 …

【自然语言处理】【Scaling Law】语言模型物理学 第3.3部分:知识容量Scaling Laws

语言模型物理学3.3:知识容量Scaling Laws 论文名称:Physics of Language Models: Part 3.3, Knowledge Capacity Scaling Laws 论文地址:https://arxiv.org/pdf/2404.05405 相关博客 【自然语言处理】【Scaling Law】Observational Scaling …

RTA_OS基础功能讲解 2.8-Tick计数器

RTA_OS基础功能讲解 2.8-Tick计数器 文章目录 RTA_OS基础功能讲解 2.8-Tick计数器一、计数器简介二、计数器配置三、计数器驱动3.1 软件计数器驱动3.1.1 递增软件计数器3.1.2 静态计数器接口3.2 硬件计数器驱动3.2.1 Advancing硬件计数器3.2.2 回调函数四、在运行时访问计数器属…

Xcode 打包报错Command PhaseScriptExecution failed with a nonzero exit code

解决办法: 1、在Xcode项目中 Pods -> Targets Support Files -> Pods-项目名 -> Pods-项目名-frameworks 中(大约在第44行) 加上 -f 2、CocoaPods版本太旧了,可以尝试升级CocoaPods版本 使用sudo gem update cocoapods更新cocoapods,问题将在1.12.1版本已…

lua vm 二: 查看字节码、看懂字节码

本文讲一讲如何查看 lua 的字节码(bytecode),以及如何看懂字节码。 以下分析基于 lua-5.4.6,下载地址:https://lua.org/ftp/ 。 1. 查看字节码 1.1 方法一:使用 luac luac 是 lua 自带的编译程序&#x…

无线和移动网络

背景 两个重要的挑战 无线:通过无线链路通信移动:需要网络处理移动(不同变换所接入的网络)用户 无线网络中的组件 无线主机(无线并不总是意味着移动的)基站(base station 或者叫AP&#xff0…

芝麻IP好用吗?来测试了!

作为老牌代理IP服务厂商,芝麻IP和青果网络代理IP都做的不错,市场上几乎可以是有口皆碑了,上次测试了青果网络的代理IP,效果表现得还挺不错,和他们自己宣传的以及客户对他们的评价大差不差。 总的来说,他们家…

纷享销客安全体系:物理与环境安全

纷享销客的物理设备托管在经过严格准入制度授权的TIER3级别以上的专业数据中心,这些数据中心均通过了等保三级与IS027001安全认证,确保电力、制冷等基础设施提供相应级别的冗余,以增强IDC环境的安全性。 业务操作系统平台采用当前广泛使用的…

解决 iOS 端小程序「saveVideoToPhotosAlbum:fail invalid video」问题

场景复现: const url https://mobvoi-digitalhuman-video-public.weta365.com/1788148372310446080.mp4uni.downloadFile({url,success: (res) > {uni.saveVideoToPhotosAlbum({filePath: res.tempFilePath,success: (res) > {console.log("res > &…