ChatGLM-6B 主要代码分析 RotaryEmbedding

news2024/9/20 22:46:17

ChatGLM-6B 主要代码分析 RotaryEmbedding

flyfish
在这里插入图片描述

图片链接地址

传统的 Transformer 位置编码(Positional Encoding)被称为绝对位置编码 ,而 Rotary Embedding 被称为相对位置编码 ,主要是因为它们编码位置信息的方式不同,进而影响模型对序列中元素之间位置关系的理解。

1. 传统 Transformer 位置编码:绝对位置编码

在传统的 Transformer 模型中,位置编码使用正弦和余弦函数将每个位置 t t t 映射到一个固定的向量: P E ( t , 2 i ) = sin ⁡ ( t 1000 0 2 i / d ) PE(t, 2i) = \sin\left(\frac{t}{10000^{2i/d}}\right) PE(t,2i)=sin(100002i/dt)

P E ( t , 2 i + 1 ) = cos ⁡ ( t 1000 0 2 i / d ) PE(t, 2i+1) = \cos\left(\frac{t}{10000^{2i/d}}\right) PE(t,2i+1)=cos(100002i/dt)
其中, t t t 是序列中的位置索引, i i i 是维度索引, d d d 是嵌入维度。

特点:
  • 固定位置编码 :每个位置 t t t 的编码是固定的,无论它出现在序列的哪个部分,其编码都是由位置 t t t 唯一确定的。

  • 不变性 :这种编码方式不会随着序列的变化而变化,意味着同一位置的编码在每次出现时都是相同的。

绝对性:
  • 绝对位置感知 :由于位置编码与序列中的具体位置 t t t 紧密关联,模型在训练时会将这些编码与特定的序列模式联系起来。这种方式能够让模型感知到序列中每个元素的绝对位置,但对元素之间的相对位置(如相对距离)缺乏直接的建模能力。

  • 难以处理相对位置信息 :在绝对位置编码下,如果需要感知两个元素之间的相对距离或关系,模型必须通过训练学习到这些关系,而不是通过位置编码直接得到。

2. Rotary Embedding:相对位置编码

Rotary Embedding 的核心思想是通过旋转操作,将位置信息嵌入到序列的每个元素中,从而使模型能够自然地感知到序列中元素之间的相对位置关系。

工作原理:
  1. 旋转矩阵 :Rotary Embedding 将位置信息与特征向量通过旋转矩阵结合。假设 x 1 x_1 x1 x 2 x_2 x2 是在位置 t t t t + 1 t+1 t+1 的特征向量,那么旋转操作后的位置编码变换为: R ( θ ) ⋅ x = [ cos ⁡ ( θ ) − sin ⁡ ( θ ) sin ⁡ ( θ ) cos ⁡ ( θ ) ] ⋅ [ x 1 x 2 ] R(\theta) \cdot x = \begin{bmatrix} \cos(\theta) & -\sin(\theta) \\ \sin(\theta) & \cos(\theta) \end{bmatrix} \cdot \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} R(θ)x=[cos(θ)sin(θ)sin(θ)cos(θ)][x1x2]
    其中 θ \theta θ 是根据位置计算得到的旋转角度。

  2. 相对位置感知 :当两个位置 t t t t + 1 t+1 t+1 的特征向量进行旋转变换时,模型可以通过旋转角度的差异自然感知到这两个位置之间的相对关系,而无需依赖绝对位置编码。

相对性:
  • 相对位置感知 :Rotary Embedding 通过旋转矩阵直接捕捉相邻元素之间的相对位置信息。例如,元素 x 1 x_1 x1 x 2 x_2 x2 在相邻位置 t t t t + 1 t+1 t+1 之间的相对关系可以通过旋转角度的差异直接表达。

  • 位置编码灵活性 :由于旋转矩阵使得位置编码可以灵活变化,因此模型能够更自然地处理不同长度的序列和不同的相对位置关系。

3. 绝对 vs. 相对位置编码

  • 绝对位置编码 (传统 Transformer):编码固定,适合处理具体位置相关的任务,但难以直接处理相对位置关系。

  • 相对位置编码 (Rotary Embedding):编码与序列中的相对位置变化相关,更加灵活,适合处理长序列和需要相对位置信息的任务。

Rotary Embedding 与传统位置编码的比较

特点传统位置编码 (Positional Encoding)Rotary Embedding
编码方式正弦和余弦函数的绝对位置编码旋转矩阵的相对位置编码
位置关系只能表示绝对位置更好地表示相对位置
对长序列的处理长序列时可能失效能够有效处理长序列
模型适应性需要在训练期间观察到所有可能位置更具扩展性,适应超长序列
应用场景适用于大多数任务尤其适用于需要处理长序列和复杂依赖关系的任务
import torch
class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)
    
    


# 初始化 RotaryEmbedding 模块
dim = 64  # 嵌入维度
rotary_emb = RotaryEmbedding(dim=dim)

# 模拟输入张量
batch_size = 2
seq_len = 10
embedding_dim = dim
x = torch.randn(batch_size, seq_len, embedding_dim)

# 调用 forward 方法
cos, sin = rotary_emb(x)

# 输出 cos 和 sin 的形状
print("Cosine Embedding Shape:", cos.shape)
print("Sine Embedding Shape:", sin.shape)

输出

Cosine Embedding Shape: torch.Size([10, 1, 64])
Sine Embedding Shape: torch.Size([10, 1, 64])

Rotary Embedding 的设计思想是将位置编码嵌入到一个旋转的向量空间中,从而为序列建模提供更强的相对位置感知能力。

1. 三角函数基础

三角函数 cossin 描述了一个角度在单位圆上的投影,定义如下: cos ⁡ ( θ ) = 邻边 斜边 , sin ⁡ ( θ ) = 对边 斜边 \cos(\theta) = \frac{\text{邻边}}{\text{斜边}}, \quad \sin(\theta) = \frac{\text{对边}}{\text{斜边}} cos(θ)=斜边邻边,sin(θ)=斜边对边
这些函数具有周期性,对于任何角度 θ \theta θ,都有以下性质: cos ⁡ ( θ + 2 π ) = cos ⁡ ( θ ) , sin ⁡ ( θ + 2 π ) = sin ⁡ ( θ ) \cos(\theta + 2\pi) = \cos(\theta), \quad \sin(\theta + 2\pi) = \sin(\theta) cos(θ+2π)=cos(θ),sin(θ+2π)=sin(θ)

2. 位置编码(Positional Encoding)

在传统的 Transformer 模型中,位置编码通过 sincos 函数来表示输入序列中的位置信息。对于一个给定的位置 t t t,对应的编码可以表示为: P E ( t , 2 i ) = sin ⁡ ( t 1000 0 2 i / d ) PE(t, 2i) = \sin\left(\frac{t}{10000^{2i/d}}\right) PE(t,2i)=sin(100002i/dt)

P E ( t , 2 i + 1 ) = cos ⁡ ( t 1000 0 2 i / d ) PE(t, 2i+1) = \cos\left(\frac{t}{10000^{2i/d}}\right) PE(t,2i+1)=cos(100002i/dt)
其中, t t t 是序列中的位置, i i i 是维度索引, d d d 是嵌入维度。这个编码方式保证了不同维度具有不同的频率,以便模型能够感知到位置的不同。

3. 旋转嵌入(Rotary Embedding)

Rotary Embedding 是一种改进的相对位置编码方法,其核心思想是将位置信息通过旋转矩阵嵌入到序列中的每个特征向量中。它通过以下步骤实现:

1). 逆频率生成
首先,生成一个逆频率向量 inv_freq inv_freq j = 1 base 2 j d \text{inv\_freq}_j = \frac{1}{\text{base}^{\frac{2j}{d}}} inv_freqj=based2j1
其中 base 通常取 10000,j 是维度索引,d 是嵌入维度。

2). 频率矩阵生成
接下来,计算频率矩阵 freqs,将逆频率与时间步长(即序列位置)相乘: freqs i , j = t i × inv_freq j \text{freqs}_{i,j} = t_i \times \text{inv\_freq}_j freqsi,j=ti×inv_freqj
其中 t i t_i ti 是序列位置。

3). 三角函数编码
频率矩阵的每个元素通过 cossin 进行编码,并合并为一个编码矩阵: emb = [ cos ⁡ ( freqs ) , sin ⁡ ( freqs ) ] \text{emb} = [\cos(\text{freqs}), \sin(\text{freqs})] emb=[cos(freqs),sin(freqs)]

4). 旋转变换
在旋转嵌入中,编码后的 cossin 矢量与输入向量进行旋转变换。给定一个输入向量 x x x 及其旋转矩阵 R ( θ ) R(\theta) R(θ) R ( θ ) ⋅ x = [ cos ⁡ ( θ ) − sin ⁡ ( θ ) sin ⁡ ( θ ) cos ⁡ ( θ ) ] ⋅ [ x 1 x 2 ] R(\theta) \cdot x = \begin{bmatrix} \cos(\theta) & -\sin(\theta) \\ \sin(\theta) & \cos(\theta) \end{bmatrix} \cdot \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} R(θ)x=[cos(θ)sin(θ)sin(θ)cos(θ)][x1x2]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2051966.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python创建项目环境及项目打包

目录 创建项目环境conda创建环境常用命令创建项目虚拟环境创建虚拟环境激活虚拟环境安装第三方库 pyinstaller 打包常用参数组合 嵌入式打包下载嵌入式版本的python配置环境无参调用可完善 nuitka打包 创建项目环境 conda创建环境常用命令 conda create -n py310 python3.10.…

《学会 SpringBoot · 依赖管理机制》

📢 大家好,我是 【战神刘玉栋】,有10多年的研发经验,致力于前后端技术栈的知识沉淀和传播。 💗 🌻 CSDN入驻不久,希望大家多多支持,后续会继续提升文章质量,绝不滥竽充数…

IntelliJ IDEA 集成 ShardingSphere-JDBC 访问分库分表

背景 众所周知,IntelliJ IDEA 是 Java 领域常用的开发工具之一,IDEA Ultimate(旗舰版)或其他例如 DataGrip 等 Intellij 平台的工具都集成了对数据库的访问能力。 但是,对于做了分库分表的项目,直接使用 …

微信支付代理商-自助提交资料源码之结算信息页面—微信支付商机版

一、支付代理上自助提交资料 一般在都在小程序完成提交 在网页中异常提示alert 但是很多小程序禁用了这个函数 并且不好看 那么久自定义一个组件每次直接调用 二、提示技术代码 function 未来之窗_VOS_通用提醒(msg){var 未来之窗内容message<cyberdiv style"font…

选择排序(直接选择排序与堆排序的比较)

选择排序 选择排序时间复杂度 1. 直接选择排序思考⾮常好理解&#xff0c;但是效率不是很好。实际中很少使用&#xff0c;思路是先进行遍历找到元最小的元素&#xff0c;然后与第一个进行交换 2. 时间复杂度&#xff1a;O&#xff08;&#xff09; 3. 空间复杂度&#…

gmapping算法核心部分

processScan函数 参考&#xff1a;https://blog.csdn.net/CV_Autobot/article/details/131058981 drawFromMotion:根据运动模型更新粒子位姿 scanMatch:进行扫描匹配 resample:重采样 逐步分解并详细解释代码 1. 获取当前扫描的相对位姿 OrientedPoint relPose reading.…

舜宇光学科技社招校招入职测评:商业推理测验真题汇总、答题要求、高分技巧

舜宇光学科技&#xff08;集团&#xff09;有限公司&#xff0c;成立于1984年&#xff0c;是全球领先的综合光学零件及产品制造商。2007年在香港联交所主板上市&#xff0c;股票代码2382.HK。公司专注于光学产品的设计、研发、生产及销售&#xff0c;产品广泛应用于手机、汽车、…

BEM架构

视频 总结&#xff1a; BEM架构&#xff1a;一个命名类的规范而已&#xff0c;说白了就是如何给类起名字使用sass的目的&#xff1a;在<style>中模块化的使用类名&#xff0c;同时减少代码数量 1、 BEM架构 &#xff08;通义灵码查询结果&#xff09; BEM (Block Ele…

【hot100篇-python刷题记录】【和为 K 的子数组】

R5-子串篇 目录 思路&#xff1a; 优化&#xff1a; tip: 代码&#xff1a; 结果&#xff1a; ps: 思路&#xff1a; 滑动&#xff0c;应该可以使用滑动窗口来解题。 貌似前缀和也可以&#xff0c;left&#xff0c;right两个指针&#xff0c;right的前缀和-left的前缀…

【学习笔记】printf中%m的含义

【学习笔记】printf中%m的含义 在有些代码中会看到如下的写法&#xff1a; printf("%m\n");printf中使用了%m来打印输出&#xff0c;那么%m又是什么意思呢&#xff1f; 其实%m 并不是在所有的 printf 实现中都通用或标准化的选项&#xff0c;而是在某些特定的编程语…

vue的markdown编辑器插件比对

vue的markdown编辑器插件比对 文章说明md-editor-v3的使用及效果展示vditor的使用及效果展示 文章说明 文章比对 md-editor-v3、vditor 这两个插件的使用及效果体验 md-editor-v3的使用及效果展示 安装 npm install md-editor-v3使用 <script setup> import {reactive} f…

图神经网络(Graph Neural Networks)是什么?

图神经网络&#xff08;Graph Neural Networks&#xff09;是什么&#xff1f; 引言 在数据科学和机器学习的广阔领域中&#xff0c;图结构数据以其独特的复杂性和丰富性成为了一个重要的研究方向。从社交网络中的用户关系&#xff0c;到生物信息学中的蛋白质交互网络&#x…

跨进程通信使用 Zenoh中间件 进行高效数据传输的测试和分析

文章目录 1. 引言2. Zenoh C 使用指南2.1 安装 Zenoh C 库2.2 编写基本的 Zenoh C 程序订阅示例发布示例 2.3 编译和运行程序 3. Zenoh 与 ROS2 集成3.1 安装 Zenoh3.2 安装 ROS2 的 Zenoh RMW 实现3.3 设置 RMW 实现为 Zenoh3.4 验证配置 4. 编写基于 Zenoh 的 ROS2 应用程序4…

Linux系统编程 --- 多线程

线程&#xff1a;是进程内的一个执行分支&#xff0c;线程的执行粒度&#xff0c;要比进程要细。 一、线程的概念 1、Linux中线程该如何理解 地址空间就是进程的资源窗口。 在一个程序里的一个执行路线就叫做线程&#xff08;thread&#xff09;。更准确的定义是&#xff1…

浏览器遇到的问题

下载的时候遇到&#xff0c;需要授权&#xff0c;无法下下载 将隐私里面的全部关掉

虚幻5|AI巡逻宠物伴随及定点巡逻—初步篇

一.建立AI基本三件套 1.建立AI基本三件套 二.使用AI的基本设置 1.打开我们想要用的AI宠物的蓝图&#xff0c;选中自我Actor,右侧细节处找到AI&#xff0c;选中对应的AI控制器 三.打开AI控制器 写如下 四&#xff0c;AI行为树 1.新建一个任务&#xff0c;命名含巡逻二字即可…

BigInteger与BigDecimal

BigInteger BigInteger构造方法 public BigInteger(int num, Random rnd) 获取随机大整数&#xff0c;范围&#xff1a;[0 ~ 2的num次方-1] public BigInteger(String val) 获取指定的大整数 public BigInteger(String val, int radix) 获取指定进制的大整数 构造方法小结…

Power Query抓取多页数据导入到Excel

原文链接 举例网站&#xff1a;http://vip.stock.finance.sina.com.cn/q/go.php/vLHBData/kind/ggtj/index.phtml?last5&p1 操作步骤 &#xff08;版本为&#xff1a;Excel2010&#xff09;&#xff1a; Step-01&#xff1a;单击【Power Query】-【从Web】&#xff0c;…

Java之文件操作和IO

目录 File类 属性 构造方法 方法 文件内容的读写 InputStream OutputStream File类 属性 修饰符及类型属性说明static StringpathSeparator依赖于系统的路径分隔符&#xff0c;String类型的表示static charpathSeparator依赖于系统的路径分隔符&#xff0c;char类型的…

ps磨皮滤镜插件Imagenomic Portraiture 4.5 Build 4501中文版

PS磨皮神器更新为Portraiture 中文汉化版&#xff08;支持PS 2024&#xff09; 。Portraiture 4.5 Build 4501中文绿色破解版是一款非常强大的适用于Photoshop&#xff0c;Lightroom&#xff0c;Aperture的人物磨皮&#xff08;人物润色&#xff09;插件。Portraiture插件被经常…