Transformer中高级位置编码的介绍和比较:Linear Rope、NTK、YaRN、CoPE

news2024/9/22 19:32:11

在处理诸如文本之类的序列时,排序信息显然是至关重要的。为了结合排序信息而不是将序列视为集合,对位置信息进行编码是至关重要的。位置编码通过为每个位置分配嵌入向量并将其添加到相应的标记表示来实现这一点。绝对和相对位置编码是最常见的两种位置编码方式,但是本文将要比较更高级的位置编码方法:

1、RoPE 位置编码及其变体

2、CoPE

旋转位置编码

旋转位置编码(Rotary Positional Encoding,RoPE)是一种在自然语言处理(NLP)中处理序列数据时使用的技术。它旨在通过旋转方式将位置信息编码到输入的表示中,使得模型能更好地理解序列中元素的位置关系。关键思想是通过将上下文表示与旋转矩阵相乘来编码相对位置。RoPE随相对距离的增加而衰减。

RoPE 的核心思想是通过在每个位置应用一个旋转矩阵到每个词元的嵌入上,从而将位置信息融入到词元的表示中。这种编码方式允许模型在处理序列数据时,能够更好地利用位置信息,提升语义理解和语言生成的质量。

我们简单的实现一下RoPE:

 defapply_rope(k, q, cis):
     # Idea suppose vector v = [x,y,x1,y1,...] # v.shape = dim
     # convert vetor into complex num # ie two vec one real, one imagery
     # [x,y,x1,y1,...] -> x+iy, x1+iy1
     # Multiplying by complex num == roatate vector
     # => (x + iy) * (cos + isin) -> x'+iy'
     # restack
     # x'+iy' -> [x',y',x1',y1'...]
     # you roated vector in chunks of two lfg!!!
     _, seq_len, _, _=q.shape
     freqs_cos, freqs_sin=cis
     freqs_cos, freqs_sin=freqs_cos[:seq_len], freqs_sin[:seq_len]
     #  rehsape a shape (...,n )-> (..., n//2,2)
     q_cis=q.float().reshape(
         q.shape[:-1] + (-1, 2)
     )  # (B,T,nhead,C) -> (B,T,nhead,Cc,2) # Cc = C//2
     k_cis=k.float().reshape(k.shape[:-1] + (-1, 2))  # (B,T,nhead,C) -> (B,T,nhead,Cc,2)
     xq_r, xq_i=q_cis.unbind(-1)  # (B,T,nhead,Cc,2) -> ((B,T,Cc), (B,T,Cc)) split into two tuple
     xk_r, xk_i=k_cis.unbind(-1)  # (B,T,nhead,Cc,2) -> ((B,T,Cc), (B,T,Cc))
     freqs_cos=reshape_for_broadcast(freqs_cos, xq_r)  # freqs.shape = (1,T,1,Cc)
     freqs_sin=reshape_for_broadcast(freqs_sin, xq_r)
     xq_out_r=xq_r*freqs_cos-xq_i*freqs_sin  # (ac-bd)   # shape =  # (B,T,nhead,Cc)
     xq_out_i=xq_r*freqs_sin+xq_i*freqs_cos  # (ad+bc) * i
     xk_out_r=xk_r*freqs_cos-xk_i*freqs_sin  # (ac-bd)
     xk_out_i=xk_r*freqs_sin+xk_i*freqs_cos  # (ad+bc) * i
     # now we stack r,i -> [r,i,r2,i2]
     xq_out=torch.stack([xq_out_r, xq_out_i], dim=-1)  # (B,T,nhead,Cc,2)
     xk_out=torch.stack([xk_out_r, xk_out_i], dim=-1)  # (B,T,nhead,Cc,2)
     # flatten last two dimensions
     xq_out=xq_out.flatten(3)  # (B,T,nhead,C)
     xk_out=xk_out.flatten(3)  # (B,T,nhead,C)
     returnxq_out.type_as(q), xk_out.type_as(q)

这是我们下面介绍的一些变体的基础,所以实现的比较简单。下面我们主要介绍一些变体:

基于旋转矩阵/旋转角度以及如何预先计算cos和sin频率,RoPE有三种变体。为了将模型的上下文长度扩展到预训练的极限之外,还会引入一些方法相关的函数。

线性旋转位置编码

在线性旋转位置编码中,通过引入以下方法相关函数g(m)和h(θ_d)来修改RoPE方程:

其中s为比例因子(扩展上下文长度与原始上下文长度之比),θ_d定义如下,b为底数(10000)

最后将波长(与频率成反比)描述为在维度d上嵌入RoPE以执行完整旋转(2π)所需的token长度。

实现如下:

 defprecompute_freqs_cis_linear(dim: int, end: int, theta: float=10000.0):
     freqs=1.0/ (theta** (torch.arange(0, dim, 2)[: (dim//2)].float() /dim))
     # [: (dim // 2)] for odd number truncation
     t=torch.arange(end, device=freqs.device)
     freqs=torch.outer(t, freqs).float()  # gives diffrent angle vector
     freqs_cos=torch.cos(freqs)  # real
     freqs_sin=torch.sin(freqs)  # imaginary
     
     returnfreqs_cos, freqs_sin

NTK

神经切线核(Neural Tangent Kernel,简称NTK)是一种在深度学习领域中被广泛研究的概念,它提供了一种框架来分析和理解神经网络训练过程中的动态行为。NTK是在无限宽度极限下的神经网络中定义的,即当网络的层宽度趋向于无限大时,网络的行为可以通过一个固定的核函数来描述。

NTK 核贡献在于将传统的神经网络训练过程与核方法联系起来。在无限宽度的假设下,神经网络在初始化后的行为可以被描述为一个线性模型,其权重通过NTK进行更新。这意味着,在这种情况下,神经网络的学习动态可以通过解析形式来精确计算,而这通常在有限宽度的网络中是不可能的。

NTK 感知插值解决了在插值RoPE嵌入时丢失高频信息的问题,通过减少对高频的缩放,增加对低频的缩放,这与将RoPE的每个维度均匀地缩放一个因子s不同,所以只需对θ的值执行基本变化即可完成,代码如下:

 defprecompute_freqs_cis_ntk(dim: int, end: int, theta: float=10000.0, alpha: int=16):
     theta=theta*alpha** (dim/ (dim-2))
     freqs=1.0/ (theta** (torch.arange(0, dim, 2)[: (dim//2)].float() /dim))\
     t=torch.arange(end, device=freqs.device)
     freqs=torch.outer(t, freqs).float()
     freqs_cos=torch.cos(freqs)  # real
     freqs_sin=torch.sin(freqs)  # imaginary
     returnfreqs_cos, freqs_sin

YaRN

YaRN(Yet another RoPE extensioN)是通过一种高效的计算方法来扩展模型的上下文窗口,比以前的方法减少10倍的令牌和2.5倍的训练步骤。它引入了一个ramp函数,并将该函数合并到方法依赖函数中,如下所示:


 defprecompute_freqs_cis_yarn(dim: int, original_max_position_embeddings: int, theta: float=10000.0, scale: int=16, beta_fast:int=32, beta_slow:int=1, mscale: float=0.707,  max_position_embeddings: int=2048):
     pos_freqs=theta** (torch.arange(0, dim, 2)[: (dim//2)].float() /dim)
     inv_freq_extrapolation=1.0/pos_freqs
     inv_freq_interpolation=1.0/ (scale*pos_freqs)
     low=max(math.floor(dim*math.log(original_max_position_embeddings/(beta_fast*2*math.pi)))/(2*math.log(theta)),0)
     high=min(math.ceil(dim*math.log(original_max_position_embeddings/(beta_slow*2*math.pi)))/(2*math.log(theta)),dim-1)
     linear_func= (torch.arange(dim//2, dtype=torch.float32) -low) / (high-low)
     ramp_func=torch.clamp(linear_func, 0, 1).float().to(device=pos_freqs.device)
     inv_freq_mask=1-ramp_func
     inv_freq=inv_freq_interpolation* (1-inv_freq_mask) +inv_freq_extrapolation*inv_freq_mask
     _mscale=float((0.1*math.log(scale) +1.0) *mscale)
     t=torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype) # torch.Size([2048])
     freqs=torch.outer(t, inv_freq) # torch.Size([2048, 48]) 
     dtype=torch.get_default_dtype()
     freqs_cos=freqs.cos() *_mscale
     freqs_sin=freqs.sin() *_mscale
     
     returnfreqs_cos, freqs_sin

上下文位置编码(CoPE)

上下文位置编码(Contextual Positional Encoding,简称CoPE)是一种在处理序列数据时用于提高模型性能的技术。这种编码方法在自然语言处理(NLP)和其他需要处理时间序列数据的任务中尤其重要,因为它可以更好地捕获序列中元素的上下文关系。

传统的位置编码(如Transformer中使用的正弦位置编码)通常是静态的,即对于给定的位置,位置编码总是相同的,不考虑序列的具体内容。而上下文位置编码(CoPE)则试图根据序列中的实际内容动态调整位置编码,使编码反映出序列中每个元素的上下文环境。

门控机制

门控决定包含哪些令牌,以便使用它们的上下文向量来计算位置编码,并为每个查询键对计算一个门控值。:

值为1表示标记号在位置计数中被考虑,而值为0表示它被忽略。

计算位置嵌入

要计算位置嵌入,需要添加当前令牌和之前所有令牌之间的门值。每个位置可以表示给定序列中的一个记号/单词/句子号。

为了计算有限的位置,即如果门是稀疏激活的(当计算句子时),可以用更少的位置覆盖序列长度T的整个上下文,并将每个位置夹在最大可能的位置内。

因为添加了sigmoid输出[0,1],得到的每个第i个位置值都是[0,i]内的浮点数。所以位置是不可学习的,不能由嵌入层计算。

位置嵌入的插值

为了克服上述由于位置值浮动而导致的学习嵌入层的限制,会对序列中的每个整数位置分配一个可学习的位置嵌入e[p],第ij个元素的位置嵌入将是由上述计算的分数位置值加权的两个最接近的整数嵌入之间进行简单的插值。

最后通过添加关键向量中的位置嵌入来计算注意力。

CoPE的实现

为了节省内存和计算,q.e[p]矩阵会被预先计算,这样可以进一步进行插值,然后添加到上下文中。插值计算如下:

 classCoPE(nn.Module):
     def__init__(self, npos_max, head_dim):
         super().__init__()
         self.npos_max=npos_max
         self.pos_emb=nn.Parameter(torch.zeros(1, head_dim, npos_max))
     
     defforward(self, query, attn_logits):
         # Compute positions
         gates=torch.sigmoid(attn_logits)
         pos=gates.flip(-1).cumsum(dim=-1).flip(-1)
         pos=pos.clamp(max=self.npos_max-1)
         
         # Interpolate from integer positions
         pos_ceil=pos.ceil().long()
         pos_floor=pos.floor().long()
         
         logits_int=torch.matmul(query, self.pos_emb)
         logits_ceil=logits_int.gather(-1, pos_ceil)
         logits_floor=logits_int.gather(-1, pos_floor)
         
         w=pos-pos_floor
         returnlogits_ceil*w+logits_floor* (1-w)

给定查询矩阵和查询键乘积,CoPE类的前向传播可以返回内插的位置嵌入。下面就是要将它们添加到Attention类中的attn_mtx上下文中。

 classAttention(nn.Module):
     def__init__(self, model_args: MOEConfig):
         super().__init__()
         d_model=model_args.d_model
         self.num_heads=model_args.num_heads
         self.head_dim=model_args.d_model//model_args.num_heads
         self.num_kv_heads= (
             model_args.num_headsifmodel_args.num_kv_heads==0elsemodel_args.num_kv_heads
         )
         assertself.num_heads%self.num_kv_heads==0
         self.num_queries_per_kv=self.num_heads//self.num_kv_heads
         self.cope=CoPE(model_args.seq_len,self.head_dim)
         self.key=nn.Linear(d_model, self.head_dim*self.num_heads)
         self.query=nn.Linear(d_model, self.head_dim*self.num_kv_heads)
         self.value=nn.Linear(d_model, self.head_dim*self.num_kv_heads)
         self.proj=nn.Linear(d_model, d_model, model_args.bias)
         self.attn_dropout=nn.Dropout(model_args.dropout)
         self.res_dropout=nn.Dropout(model_args.dropout)
         self.flash_attn=hasattr(torch.nn.functional, "scaled_dot_product_attention")
     defforward(self, x: torch.Tensor, mask: torch.Tensor, freqs_cis) ->torch.Tensor:
         batch, seq_len, d_model=x.shape
         k: torch.Tensor  
         q: torch.Tensor  
         v: torch.Tensor
         k=self.key(x)
         q=self.query(x)
         v=self.value(x)
         k=k.view(
             batch, seq_len, self.num_heads, self.head_dim
         )  # shape = (B, seq_len, num_heads, head_dim)
         q=q.view(batch, seq_len, self.num_heads, self.head_dim)
         v=v.view(batch, seq_len, self.num_heads, self.head_dim)
         q, k=apply_rope(q, k, freqs_cis)
         # Grouped Query Attention
         ifself.num_kv_heads!=self.num_heads:
             k=torch.repeat_interleave(k, self.num_queries_per_kv, dim=2)
             v=torch.repeat_interleave(v, self.num_queries_per_kv, dim=2)
         k=k.transpose(1, 2)  # shape = (B, num_heads, seq_len, head_dim)
         q=q.transpose(1, 2)
         v=v.transpose(1, 2)
         attn_mtx=torch.matmul(q, k.transpose(2, 3)) /math.sqrt(self.head_dim)
         attn_mtx=attn_mtx+mask[:, :, :seq_len, :seq_len]
         print("Before:", attn_mtx[0, 0, :3, :3])
         attn_mtx+=self.cope(q,attn_mtx)
         print("AFTER:", attn_mtx[0, 0, :3, :3])
         attn_mtx=F.softmax(attn_mtx.float(), dim=-1).type_as(k)
         attn_mtx=self.attn_dropout(attn_mtx)
         output=torch.matmul(attn_mtx, v)  # (batch, n_head, seq_len, head_dim)
         # restore time as batch dimension and concat heads
         output=output.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
         # final projection into the residual stream
         output=self.proj(output)
         output=self.res_dropout(output)
         returnoutput

attn_mtx += self.cope(q, attn_mtx)是将cope嵌入添加到上下文的地方。

CoPE通过引入与序列内容相关的动态位置信息,使模型能更准确地理解和处理语言中的长距离依赖关系,例如在复杂的句子或文档中正确解释词义和句子结构。在处理多样化或特定领域的数据时,CoPE可以通过适应不同的文本特征和结构,提高模型的灵活性和泛化能力。在一些需要高度上下文感知的任务中,如机器翻译、文本摘要或对话系统,CoPE能够显著提升模型的性能。

总结

以下是本文介绍的一些方法的论文,供参考:

https://avoid.overfit.cn/post/91fd4283a7944bebabb6017f5ee285e9

作者:Zain ul Abideen

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1930198.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

外贸行业汽车销售配件展示企业网站源码系统 带完整的源代码包以及搭建教程

系统概述 随着全球贸易的不断深化,外贸行业对于高效、专业的网站需求日益凸显。特别是对于汽车销售配件企业而言,一个功能全面、展示效果出色的网站源码系统,无疑是企业开拓海外市场、提升品牌形象的关键。本文将详细介绍一款专为外贸行业汽…

【Linux】文件管理常用命令【超详细】

文章目录 预防rm事故-血的教训😢1. 使用别名:2. 启用回收站:3. 只读文件系统: 一、文件管理1.1 touch-文件创建1.2 rm-文件删除1.3 mkdir-目录创建1.4 rmdir-目录删除1.5 pwd-显示当前目录1.6 cd-切换当前目录1.7 ls-列出文件和目…

鸿蒙语言基础类库:【@system.device (设备信息)】

设备信息 说明: 从API Version 6开始,该接口不再维护,推荐使用新接口[ohos.deviceInfo]进行设备信息查询。本模块首批接口从API version 3开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。 导入模块 import dev…

AI绘画Stable Diffusion 零基础入门 —AI 绘画原理与工具介绍,万字解析AI绘画的使用教程

大家好,我是设计师阿威 想要入门 AI 绘画,首先需要了解它的原理是什么样的。 其实很早就已经有人基于深度学习模型展开了对图像生成的研究了,但在那时,生成的图像分辨率和内容都非常抽象。 直到近两年,AI 产出的图像…

[openwrt-21.02]mt7981开启mwan3功能ping出现unreachable 问题分析及解决方案

mwan3 提供以下功能和能力 基于数值权重分配的出站 WAN 流量负载均衡或使用多个 WAN 接口进行故障转移 使用重复测试监控每个 WAN 连接,如果第一个 WAN 接口失去连接,则可以自动将出站流量路由到另一个 WAN 接口 创建出站流量规则以自定义哪些出站连接应使用哪个 WAN 接口(…

白门楼 下 | 第13集 | 曹操口头禅:故戏之耳 | 逐鹿群雄 | 三国演义

🙋大家好!我是毛毛张! 🌈个人首页: 神马都会亿点点的毛毛张 📌这篇博客分享的是《三国演义》文学剧本第Ⅰ部分《群雄逐鹿》的第13集《白门楼 下》的经典语句和文学剧本全集台词 文章目录 1.经典语句2.文学剧本台词 …

防火墙---带宽管理

防火墙的带宽管理:是指对防火墙设备的带宽进行管理和控制,以确保网络流量的合理分配和优化网络性能 带宽管理:是指限制网络流量的速率或控制网络流量的优先级,以确保网络的性能和可用性 核心: 带宽限制:…

环形数组复习

普通储存数据 接收数据 先要有个 缓存区 通常先建立一个数组 来保存数据 缓存区内存 如何分配和释放 此时 一包数据为 5字节 缓冲区为 17字节 方法一:每次清空缓冲区,重头开始存放数据 第一次 存放在 字节1-5 然后分析读取这次数据 后 先清除B…

2024华为数通HCIP-datacom最新题库(变题更新⑥)

请注意,华为HCIP-Datacom考试831已变题 请注意,华为HCIP-Datacom考试831已变题 请注意,华为HCIP-Datacom考试831已变题 近期打算考HCIP的朋友注意了,如果你准备去考试,还是用的之前的题库,切记暂缓。 1、…

《0基础》学习Python——第十三讲__面向对象

<类&#xff08;class&#xff09;> 一、面向对象概念 1、面向对象是一种编程思想和技术&#xff0c;它是一种将程序设计问题分解成对象的方式。每个对象都有自己的状态&#xff08;数据&#xff09;和行为&#xff08;方法&#xff09;&#xff0c;并且可以通过相互之间…

AMD software 将两个显示器合并为一个超宽显示器

最近玩游戏的时候&#xff0c;发现了一个骚操作。 可以将两个显示器&#xff08;更多个的自己去试&#xff0c;不知道&#xff09;组合为一个显示器&#xff0c;注意&#xff0c;这里说的不是将两个显示都连接电脑从而使用双屏显示器&#xff0c; 而是 将两个显示器组合为一个…

Logback格式简记

一、常见转换符 时间与日期 %d{pattern}&#xff1a;输出当前日期和时间。例如&#xff0c;%d{yyyy-MM-dd HH:mm:ss.SSS} 会输出 2024-07-11 15:34:55.123。 日志级别 %level 或 %p&#xff1a;输出日志级别&#xff0c;如 INFO, DEBUG, WARN, ERROR。 日志信息 %msg 或 …

【C++报错已解决】 “Undefined Reference“

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 前言 在编译程序时&#xff0c;遇到 “Undefined Reference” 报错总是令人头疼。这个错误提示通常意味着编译器找不到某个符号…

【Linux系统编程】shell命令以及运行原理 Linux权限

目录 一、shell命令以及运行原理 二、Linux权限的概念 2.1创建用户 2.2切换用户 2.3删除用户 三、Linux权限管理 3.1文件访问者的分类&#xff08;人&#xff09; 3.2文件类型和问权限&#xff08;事物属性&#xff09; 3.2.1文件类型 3.2.2基本权限代表的作用 3.…

泛微E-Cology WorkflowServiceXml SQL注入漏洞复现(QVD-2024-26136)

0x01 产品简介 泛微e-cology是一款由泛微网络科技开发的协同管理平台,支持人力资源、财务、行政等多功能管理和移动办公。 0x02 漏洞概述 2024年7月,泛微官方发布了新补丁,修复了一处SQL注入漏洞。经分析,攻击者无需认证即可利用该漏洞,建议受影响的客户尽快修复漏洞。…

mysql的主从复制(含位点复制和GTID复制)的代码实例

提示&#xff1a; master主库ip地址&#xff1a;192.168.137.2 从库s1的ip地址&#xff1a;192.168.137.11 从库s2的ip地址&#xff1a;192.168.137.22 主从复制的原理&#xff1a; MySQL主从复制是一个异步的复制过程&#xff0c;主要是通过二进制日志&#xff08;binary …

百度人脸识别Windows C++离线sdk C#接入

百度人脸识别Windows C离线sdk C#接入 目录 说明 设计背景 • 场景特点&#xff1a; • 客户特点&#xff1a; • 核心需求&#xff1a; SDK 包结构 效果 代码 说明 自己根据SDK封装了动态库&#xff0c;然后C#调用。 功能接口 设计背景 • 场景特点&#xff1a; -…

PTA - 接收n个关键字参数

接收n个以关键字形式传入的参数&#xff0c;按格式输出。 函数接口定义&#xff1a; def print_info (**keyargs) 提示&#xff1a;keyargs为可变参数&#xff0c;其可接受若干个关键字形式的实参值&#xff0c;并将接收到的值组装为一个字典。 裁判测试程序样例&#xff1…

Linux相关命令和安装软件

1.Linux命令 1.1 搜索文件或目录的命令 find 目录 -name "名称" 注意&#xff1a;名称可以使用通配符 *1.2 查看所有进程命令 ps -ef1.3 查看指定内容在文件中 grep "内容" 文件名1.4 管道符 | 1.5 查看端口号 netstat -tunlp | grep 端口号option说明…

框架设计MVVM

重点&#xff1a; 1.viewmodel 包含model 2.view包含viewmodel,通过驱动viewmodel去控制model的数据和业务逻辑 // Test.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。 //#include <iostream> #include <vector>using namespace std;#p…