大模型上下文长度的超强扩展:从LongLora到LongQLora

news2024/11/26 21:28:03

前言

本文一开始是《七月论文审稿GPT第2版:从Meta Nougat、GPT4审稿到Mistral、LongLora Llama》中4.3节的内容,但考虑到

  • 一方面,LongLora的实用性较高
  • 二方面,为了把LongLora和LongQLora更好的写清楚,而不至于受篇幅之限制
  • 三方面,独立成文可以有更好的排版,而更好的排版可以有更高的可读性(哪怕一个小小的换行都能提高可读性,更何况独立成文带来的可读性的提高)

故把这部分的内容抽取出来独立成本文

第一部分 LongLora

具体而言,LongLora是港中文和MIT的研究者通过此篇论文《LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models》于23年9月底提出的(这是其GitHub),其显著特点有三

  1. longlora的作者团队认为:尽管在推理过程中需要密集的全局注意力,但通过稀疏局部注意力(sparse local attention)也可以高效地完成模型的微调,比如他们提出的移位稀疏注意力(shifted sparse attention,简称S2-Attn)可有效地实现上下文扩展且显著节省计算资源,具有与使用vanilla注意力(vanilla attention)进行微调相似的性能

    简言之,用sparse local attention替换掉dense global attention,类似检索,不需要把所有的东西都拿过来,把相似度高的,匹配度高的一部分context拿来就可以了
  2. 他们发现,LoRA加到embedding matrix以及normalization的子网络上的时候,效果更好
    啥意思?这点在于常规操作是lora一般加到query, key, value等部分上,而这里是加到embedding matrix上,以及normaliztion上了
  3. LongLoRA在保留原始架构的同时扩展了模型的上下文,并且与大多数现有技术(如Flash-Attention2)兼容
    此外,还进一步发布了使用LongLoRA技术的长指令遵循数据集LongAlpaca,以进行监督微调(we further conduct supervised fine-tuning with LongLoRA and our long instruction-following LongAlpaca dataset)

1.1 LoRA在长文本上的不足

通过本博客内的多篇文章可知,原始transformer的计算复杂度虽序列长度的二次方成正比,这一点一直导致模型的长下文长度不好扩展(比如把长度从2048扩展到8192,复杂度得上升4x4 = 16倍),对于该问题 很多研究者或团队做了各种改进与探索

  • 比如Flash-Attention、Flash-Attention2(详见此文《通透理解FlashAttention与FlashAttention2:让大模型上下文长度突破32K的技术之一)
  • 再比如Position Interpolation (详见此文《大模型上下文扩展之YaRN解析:从直接外推ALiBi、位置插值、NTK-aware插值、YaRN》的2.3节) spent 32 A100 GPUs to extend LLaMA models from 2k to 8k context,当然了,这种资源开销即便是七月项目团队也不一定舍得耗(其实,我司项目团队一直在“低成本 高效果”的方向上探索,过程中积攒了这方面的很多经验),更别说一般个人了

如何降低资源开销呢?一种直接的方法是通过LoRA对预训练的LLM进行微调

  • 对于预训练的权重矩阵W∈Rd×k,它通过低秩分解W +∆W = W + BA进行更新,其中B∈Rd×r和A∈Rr×k。秩r≪min(d, k),在训练过程中,W被冻结,没有梯度更新,而A和B是可训练的(关于LoRA的更多说明,详见此文《LLM高效参数微调方法:从Prefix Tuning、Prompt Tuning、P-Tuning V1/V2到LoRA、QLoRA(含对模型量化的解释)》的第4部分)

    For a pre-trained weight matrix W ∈ R d×k , it is updated with a low-rank decomposition W + ∆W = W + BA, where B ∈ R d×r and A ∈ R r×k .

    The rank r ≪ min(d, k). During training, W is frozen with no gradient updates, while A and B are trainable. This is the reason why LoRA training is much more efficient than full fine-tuning.

  • 在Transformer结构中,LoRA只关注权重(Wq、Wk、Wv、Wo),而冻结所有其他层,包括MLP层和归一化层

    In the Transformer structure, LoRA only adapts the attention weights (Wq, Wk, Wv, Wo) and freezes all other layers, including MLP and normalization layers

LoRA利用低秩矩阵对自注意块中的线性投影层进行修改,从而减少了可训练参数的数量(LoRA modifies the linear projection layers in self-attention blocks by utilizing low-rank matrices, which are generally efficient and reduce the number of trainable parameters)

  1. 然而,单纯的低秩自适应会导致长上下文扩展的困惑度(perplexityin,简称PPL)很高,如下表所示,且即便将秩增加到一个更高的值,例如rank = 256,也并不能缓解这个问题
    那咋办呢?让embedding层和Norm层也添加LoRA训练之后,困惑度PPL可以显著降低

  2. 在效率方面,无论是否采用LoRA,计算成本都会随着上下文规模的扩大而急剧增加,这主要是由于标准的自注意机制所导致的(Vaswani et al., 2017)。如下图所示,即便使用LoRA,当上下文窗口扩展时,Llama2模型的训练时间也会大大增加

    为此,他们提出shifted sparse attention(S2-Attn)以替代标准自注意力机制

1.2  shifted sparse attention(S2-Attn)

1.2.1 S2-Attn的原理解释

如下图所示

  1. 将上下文长度分成几个组,并在每个组中单独计算注意力。在半注意力头中,将token按半组大小进行移位,这保证了相邻组之间的信息流动(In half attention heads, we shift the tokens by half group size, which ensures the information flow between neighboring groups)
  2. 例如,使用组大小为2048的S2-Attn来近似总共8192个上下文长度训练,这与Swin Transformer具有高度的相似(详见此文《AI绘画能力的起源:从VAE、扩散模型DDPM、DETR到ViT/Swin transformer》的第五部分)

上面的描述还是不够形象具体,那到底怎么理解这个S2-Attn呢?如下图所示(值得一提的是,这个图是论文v2版的,和论文v1版稍有细微差别,当然 不影响本质)

  1. 首先,它将沿头部维度的特征分成两大块(即it splits features along the head dimension into two chunks,比如8行4列,8行相当于8个token,4列可以认为是有4个头,然后竖着一切为二)

    相当于[L, H, D], L=token num=8, H=head num=4, D=dimension of expression=1(可暂且认为是1了,毕竟一个方块,算是长度为1的一个向量)
    执行完操作之后是:[L, H, D] -> [L, H/2, D] and [L, H/2, D],即被竖着切成了左右两个part
  2. 其次,其中一个块中的标记被移动组大小的一半(tokens in one of the chunks are shifted by half of the group size)
    如上图step 2的shift所示,shift the 2^{nd} part by half group,相当于
    \rightarrow  第2个part的第8个token的后一半表示(也即原始inputs第8个token的后两个heads)移动到第2个part的第1行
    \rightarrow  而第2个part中原来的「第1-7个token的后一半表示」整体往下移动一行
  3. 第三,将token分组并重塑为批量维度,注意力只在每个组内计算,信息通过移位在不同组之间流动。虽然移位可能会引入潜在的信息泄漏,但这可以通过对注意力掩码进行微调来避免
    Third, we split tokens into groups and reshape them into batch dimensions. Attention only computes in each group in ours while the information flows between groups via shifting. Potential information leakage might be introduced by shifting, while this is easy to prevent via a small modification on the attention mask.

    相当于把两个part连起来后,然后横着切三刀切成了4个group,每个group有8个小方块
    第一个group相当于包含:第一part的前两行,和第二part中更新之后的前两行
    然后计算该group内的注意力,类似于做了“cross-over”,正因为只是计算group内部的几个tokens之间的attention,所以称之为short attention

为方便大家更快的理解,特再补充两点

  1. 为形象起见,举个例子,假定这8个单词是i am learning Machine Learning by julyedu online,然后上述过程可用下表表示
    i 前一半(表示)i 后一半(表示)i 前一半online 后一半:line
    am 前一半am 后一半am 前一半i 后一半
    learning 前一半learning 后一半learning 前一半am 后一半
    Machine 前一半Machine 后一半Machine 前一半learning 后一半
    Learning 前一半Learning 后一半Learning 前一半Machine 后一半
    by 前一半by 后一半by 前一半Learning 后一半
    julyedu 前一半julyedu 后一半julyedu 前一半by 后一半
    online 前一半online 后一半online 前一半:onjulyedu 后一半
  2. 针对上面那个S2-Attn示意图
    该图的左边部分 上文已经解释的很清楚了,那右侧的两个图呢?
    咋一看,比较抽象,其实仔细琢磨之后,右侧的两个图描述的注意力范围,pattern2相对于pattern1的注意力窗口是“移位”了的

    具体到某个token来观察会清楚一点,除了“pattern1中q1”和“pattern2中q1”的注意力范围是一致 都是k1之外
    pattern1中q2的注意力范围是[k1,k2],pattern2中q2的注意力范围变成了仅[k2];
    pattern1中q3的注意力范围仅是[k3],pattern2中q3的注意力范围变成了[k2,k3];
    pattern1中q4的注意力范围是[k3,k4],pattern2中q4的注意力范围变成了仅[k4];
    pattern1中q5的注意力范围是仅[k5],pattern2中q5的注意力范围变成了[k4,k5];
    ...
    两个pattern从最开始的token注意力范围就是错位的,所以后续token注意力范围就一直是错开的,这样错开的形式使得两个pattern聚合起来就可以让组外信息有机会产生交互

1.2.2 S2-Attn的伪代码表示

如下图所示

  1. 第一步,B=batch size, N=sequence length, 3=q,k,v,H=head num,D=每个head的表示维度
    例如:qkv=[1, 4, 3, 4, 1]
    即batch size=1,一共一个序列;4=4个tokens,3=q,k,v,4=head num,1=dim of a head
    1head2head3head4head1head2head34
    21
    32
    43
  2. qkv.chunk(2, 3),得到的是一个tuple,包括两个张量,[1, 4, 3, 2, 1]左边的part,以及[1, 4, 3, 2, 1]是右边的part
    qkv.chunk(2, 3)[0],即左边的包括两个heads的part
    qkv.chunk(2,3)[1], 即右边的包括两个heads的part,这里是对其shift 1个token了
  3. 接下来,按照group分别计算group内的tokens的注意力
  4. 最后,复原

1.2.3 LongAlpaca-13B

在llama 13B上应用longlora技术,便是LongAlpaca-13B

第二部分 LongQLora

// 待更

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1354411.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[数据结构 C++] AVL树的模拟实现

文章目录 1、AVL树1.1 AVL树的概念 2、AVL树节点的定义3、AVL树的插入和旋转3.1 左单旋左旋代码实现 3.2 右单旋右旋代码实现 3.3 右左双旋右左双旋的代码实现 3.4 左右双旋左右双旋的代码实现 3.5 insert接口实现 4、判断是否为AVL树判断AVL树的代码实现 5、AVL树的性能 问题引…

Maven Resources Compiler: Maven project configuration required for module

Maven Resources Compiler: Maven project configuration required for module ‘cc-pdf’ isn’t available. Compilation of Maven projects is supported only if external build is started from an IDE. 报错原因是,我在git建立一个新空仓库,然后把…

C++线程池的原理(画图)及简单实现+例子(加深理解)

1.为什么线程池会出现,解决什么问题? C线程池(ThreadPool)的出现主要是为了解决以下几个问题: 1.性能:创建和销毁线程都是相对昂贵的操作,特别是在高并发场景下,频繁地创建和销毁线…

Linux下误删除后的恢复操作测试之extundelete工具使用

一、工具介绍 extundelete命令的功能可用于系统删除文件的恢复。在使用前,需要先将要恢复的分区卸载,以防数据被意外覆盖。 语法格式:extundelete [参数] 文件或目录名 常用参数: --after 只恢复指定时间后被删除的文件 --bef…

Linux学习(9)——RAID与服务器的常见故障

目录 一、服务器常见故障 1、系统不停重启进入不了系统 2、卡在开机界面右下角有fA B2 H8 3、系统安装不上 4、如何进入服务器的bios 5、一般进入阵列卡的快捷键 6.网络不通 7.硬盘不识别 二、RAID相关知识 1、RAID的概念 2、RAID功能实现 3、RAID实现的方式 三、…

机器学习笔记 - 偏最小二乘回归 (PLSR)

一、偏最小二乘回归:简介 PLS 方法构成了一个非常大的方法族。虽然回归方法可能是最流行的 PLS 技术,但它绝不是唯一的一种。即使在 PLSR 中,也有多种不同的算法可以获得解决方案。PLS 回归主要由斯堪的纳维亚化学计量学家 Svante Wold 和 Harald Martens 在 20 世纪 80 年代…

海外服务器2核2G/4G/8G和4核8G配置16M公网带宽优惠价格表

腾讯云海外服务器租用优惠价格表,2核2G10M带宽、2核4G12M、2核8G14M、4核8G16M配置可选,可以选择Linux操作系统或Linux系统,相比较Linux服务器价格要更优惠一些,腾讯云服务器网txyfwq.com分享腾讯云国外服务器租用配置报价&#x…

ByteTrack算法流程的简单示例

ByteTrack ByteTrack算法是将t帧检测出来的检测框集合 D t {\mathcal{D}_{t}} Dt​ 和t-1帧预测轨迹集合 T ~ t − 1 {\tilde{T}_{t-1}} T~t−1​ 进行匹配关联得到t帧的轨迹集合 T t {T_{t}} Tt​。 首先使用检测器检测t帧的图像得到检测框集合 D t {\mathcal{D}_{t}} …

手机技巧:分享10个vivo手机实用小技巧技巧,值得收藏

目录 1. 快速切换应用 2、智能助手Jovi 3. 轻按唤醒屏幕 4. 快速启动相机 5. 分屏功能 6. 手势操作 7. 一键清理 8.忘记密码 9.玩游戏耗电快 10.手机丢失后该怎么办 1. 快速切换应用 向右或向左滑动底部的虚拟按键即可。 2、智能助手Jovi vivo手机自带智能助手Jovi…

【Java EE初阶八】多线程案例(计时器模型)

1. java标准库的计时器 1.1 关于计时器 计时器类似闹钟,有定时的功能,其主要是到时间就会执行某一操作,即可以指定时间,去执行某一逻辑(某一代码)。 1.2 计时器的简单介绍 在java标准库中,提供…

CMake入门教程【核心篇】添加应用程序(add_executable)

😈「CSDN主页」:传送门 😈「Bilibil首页」:传送门 😈「本文的内容」:CMake入门教程 😈「动动你的小手」:点赞👍收藏⭐️评论📝 文章目录 1. 概述2. 使用方法2…

【计算机视觉】常用图像数据集

图像数据集 模型需要好的数据才能训练出结果,本文总结了机器学习图像方面常用数据集。 MNIST 机器学习入门的标准数据集(Hello World!),10个类别,0-9 手写数字。包含了60,000 张 28x28 的二值训练图像,10…

计算机网络(2)

计算机网络(2) 小程一言专栏链接: [link](http://t.csdnimg.cn/ZUTXU) 计算机网络和因特网(2)分组交换网中的时延、丢包和吞吐量时延丢包吞吐量总结 协议层次及其服务模型模型类型OSI模型分析TCP/IP模型分析 追溯历史 小程一言 我…

Graphics Control

Graphics Control提供了一个易于使用的图形设置管理解决方案,帮助您加快开发。它附带了一个常用设置库,如分辨率、垂直同步、全屏模式、光晕、颗粒、环境光遮挡等。我们的可自定义设置面板UI预制件为您提供了一个可用的UI面板,支持完整的游戏手柄和键盘输入。图形控制还附带…

【前沿技术杂谈:ChatGPT】ChatGPT——热潮背后的反思

【前沿技术杂谈:ChatGPT】ChatGPT——热潮背后的反思 缘起:无中生有,涅槃重生人工智能技术人工智能的发展史无中生有内容自动生成技术的发展代表企业OpenAI-GPT系列技术的发展历程ChatGPT新特点 热潮:万众瞩目,群雄逐鹿…

Unity | Shader基础知识番外(向量数学知识速成)

目录 一、向量定义 二、计算向量 三、向量的加法(连续行走) 四、向量的长度 五、单位向量 六、向量的点积 1 计算 2 作用 七、向量的叉乘 1 承上启下 2 叉乘结论 3 叉乘的计算(这里看不懂就百度叉乘计算) 八、欢迎收…

Vue3地图选点组件

Vue3地图选点组件 <template><div style"width: 100%; height: 500px"><div class"search-container"><el-autocompletev-model"suggestionKeyWord"class"search-container__input"clearable:fetch-suggestion…

【已解决】You have an error in your SQL syntax

报错讯息 java.sql.SQLSyntaxErrorException: You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near ‘desc,target_url,sort,status,create_by,modify_by,created,last_update_time FROM…

图像分割 分水岭法 watershed

版权声明&#xff1a;本文为博主原创文章&#xff0c;转载请在显著位置标明本文出处以及作者网名&#xff0c;未经作者允许不得用于商业目的。 本文的C#版本请访问&#xff1a;图像分割 分水岭法 watershed&#xff08;C#&#xff09;-CSDN博客 Watershed算法是一种图像处理算…

SSM的校园二手交易平台----计算机毕业设计

项目介绍 本次设计的是一个校园二手交易平台&#xff08;C2C&#xff09;&#xff0c;C2C指个人与个人之间的电子商务&#xff0c;买家可以查看所有卖家发布的商品&#xff0c;并且根据分类进行商品过滤&#xff0c;也可以根据站内搜索引擎进行商品的查询&#xff0c;并且与卖…