聊聊拉长LLaMA的一些经验

news2024/12/25 20:24:06

Sequence Length是指LLM能够处理的文本的最大长度,越长,自然越有优势:

  1. 更强的记忆性。更多轮的历史对话被拼接到对话中,减少出现遗忘现象

  2. 长文本场景下体验更佳。比如文档问答、小说续写等

当今开源LLM中的当红炸子鸡——LLaMA,第一版上下文长度是2048,第二版长度是4096。相比之下ChatGPT、GPT4已经支持到16k,Claude甚至支持到了100k。足以见得将LLaMA拉长是如此的任重而道远。本文将会介绍三种在旋转位置编码(RoPE)基础上扩充上下文的高性价比方案,在文末会介绍我的实践经验。

线性插值法

Kaiokendev的博客[1]中提到了方法,和Meta的一篇工作[2]不谋而合,其思想主要是将目标长度压缩到原始长度。如下图所示,LLaMA-1预训练的长度为2048,如果我们想把它拉长到4096:

  • 方法一:推理时直接拉长到4096。这考虑位置编码的外推性(即在短文本上训练,长文本上推理的能力[2]),而RoPE的外推性则是相当一般[2]。由于训练时长度都是小于2048的,超过2048部分Attention分数会飙升,导致困惑度急剧上升。

  • 方法二:在原始模型基础上做长度为4096的继续训练。这里先岔开介绍另一款模型——MPT-30B的做法,根据官方博客[3]的介绍:

    As mentioned earlier, MPT-30B was trained with a long context window of 8k tokens (vs. 2k for LLaMa and Falcon) and can handle arbitrarily long context windows via ALiBi or with fine-tuning. To build 8k support into MPT-30B efficiently, we first pre-trained on 1T tokens using sequences that were 2k tokens long, and continued training for an additional 50B tokens using sequences that were 8k tokens long.

    MPT-30B采用ALiBi位置编码(外推性优于RoPE),在2k的长度进行1T token的训练,然后在8k长度上进行50B token的预训练——这是在外推性强于RoPE的ALiBi上的情况。LLaMA-1预训练的token数是1T以上,想要在长度为4096样本上效果不下降,那需要训练足够多的token数才行,这就需要较大的成本了。

  • 方法三:另一种思路则是将4096的位置编码通过线性插值法压缩到2048内,这样只需要在少量的长度为4096的数据上进行继续预训练,便可达到不错的效果。

来自论文[2]

来自论文[2]

代码实现

线性插值法的实现代码相当的简单,这需要在原始RoPE上进行微小的改动,即加上下图的scale参数。

来自[7],scaled_rope/LlamaLinearScaledRotaryEmbedding.py

来自[7],scaled_rope/LlamaLinearScaledRotaryEmbedding.py

效果

Meta的工作[2]中进行了充足实验和公式推导证明,如果想看具体的代码,建议看lmsys.org(Vicuna的出品方)的一篇工作[4]:How Long Can Open-Source LLMs Truly Promise on Context Length? 他们对比了商用模型、号称支持长文本的开源模型和“Vicuna+线性插值法”的效果,并给出了几个结论:

  1. 商用模型在长文本的效果上很能打!而那些号称支持长文本的开源模型,在长文本上则表现不佳。

  2. 随着文本长度的增加,越接近边界,Vicuna+线性插值法的效果降低越明显。这可能是因为训练数据存在短文本的情况。

来自[4]

来自[4]

来自[4]

来自[4]

写这篇文章的同时,ChatGLM团队更新了ChatGLM2-6B-32K,也是使用了插值法。同时推出了长文本的中英评测集LongBench,在这个评测集上ChatGLM2-6B-32K展示了强大的实力,但值得注意的是,该评测集的评测方式是使用ChatGLM2-6B来进行评估的。

NTK插值法

NTK插值法的提出于一篇Reddit帖子[5],它提出使用Neural Tangent Kernel (NTK)来解决这个问题。

if you apply Neural Tangent Kernel (NTK) theory to this problem, it becomes clear that simply interpolating the RoPE's fourier space "linearly" is very sub-optimal, as it prevents the network to distinguish the order and positions of tokens that are very close by. Borrowing from NTK literature, scaling down the fourier features too much will eventually even prevent succesful finetunes (this is corroborated by the recent paper by Meta that suggests an upper bound of ~600x)

Instead of the simple linear interpolation scheme, I've tried to design a nonlinear interpolation scheme using tools from NTK literature. Basically this interpolation scheme changes the base of the RoPE instead of the scale, which intuitively changes the "spinning" speed which each of the RoPE's dimension vectors compared to the next. Because it does not scale the fourier features directly, all the positions are perfectly distinguishable from eachother, even when taken to the extreme (eg. streched 1million times, which is effectively a context size of 2 Billion)

帖子中作者还用了时钟的例子来解释线性插值和NTK插值的异同:

RoPE behaves like a clock. Your 12 hours wall clock is basically a RoPE of dimension 3 with a base of 60. So for each second, the minute hand turns 1/60th of a minute, and for each minute, the hour hand turns 1/60th.

Now if you slowed down time by a factor of 4x, that is a linear RoPE scaling used in SuperHOT. Unfortunately now it is really really hard to distinguish each second, because now the seconds hand barely moves each second. So if someone gave you two different times, which is only different by a single second, you won't be able to distinguish them from afar (let's say the NNs have myopia because that's basically what NTK predicts)

Now NTK-Aware RoPE scaling does not slow down the seconds. One second is still one second, but it slows down minutes by a factor of let's say 1.5, and the hours by a factor of 2. This way you can fit 90 minutes in a hour, and fit 24 hours in half a day. So now you basically have a clock that can measure 129.6k seconds instead of 43.2k seconds.

Because you don't need a precise measurement of the hour hand when looking at the time, scaling the hours more compared to seconds is crucial. You don't want to lose the precision of the seconds hand, but you can afford to lose precision on the minutes hand and even more on the hours hand.

Then, it's just a matter of deriving the base change formula in order to obtain such a scaling. (where less precise dimensions are scaled more and more)

代码实现

NTK的实现则更加简单了,根据超参数alpha,对应修改base变量即可:

来自[7],scaled-rope/scaled_rope/LlamaNTKScaledRotaryEmbedding.py

来自[7],scaled-rope/scaled_rope/LlamaNTKScaledRotaryEmbedding.py

效果

在效果上,帖子中也给出了NTK插值法和线性插值法的PPL比较,可以看到,在二者都不做Finetune的情况下,NTK插值法具备更低的PPL。

来自[5]

来自[5]

动态插值法

动态插值法同样出自于一篇Reddit帖子[6],它的出发点很简单:

My idea was to use the exact position values for the first 2k context (after all, why mess with a good thing?) and then re-calculate the position vector for every new sequence length as the model generates token by token.

这种做法可以和先前的两种方法相结合,[7]中也给出了详细的代码实现。

效果

作者和前两种方法做了对比,展示了动态插值法在PPL下降上的优势。

实践经验

我在实践的过程中,评估效果主要使用longChat[4]中使用的评估方式,以下是一些takeaway tips,欢迎大家一起交流。

  1. 线性插值法具备完整的理论支持和大量的实验证明,在我的实践中,“线性插值法+Finetune”取得了最佳效果。

  2. NTK插值法的实验中,对比的是不做Finetune的情况,在我的实践中,“NTK插值+Finetune”效果会明显优于单独的“NTK插值”,但它的收敛速度会慢于“线性插值法+Finetune”。

  3. 动态插值法的实验同样是在不做Finetune的情况对比的,目前为止我并没有尝试过这种方法。在Reddit的评论区有人提出一个很好的问题:如果采取这种方法,逐token推理时,文本的长度是在变化的,则导致无法使用kv-cache,这会对性能产生很大的影响。

最后,拉长LLaMA的方案可以不从RoPE入手(如:LongLLaMA[8]),但“线性插值法+Finetune”无疑是一种性价比很高的方案,推荐大家尝试!

——2023.07.31

Reference

[1] Extending Context is Hard…but not Impossible†

[2] EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION

[3] MPT-30B: Raising the bar for open-source foundation models

[4] How Long Can Open-Source LLMs Truly Promise on Context Length?

[5] NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.

[6] Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning

[7] GitHub - jquesnelle/scaled-rope

[8] Focused Transformer: Contrastive Training for Context Scaling

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/817258.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

开放麒麟1.0发布一个月后,到底怎么样?另一款操作系统引发热议

具有里程碑意义 7月5日,国产首个开源桌面操作系统“开放麒麟1.0”正式发布。 标志着我国拥有了操作系统组件自主选型、操作系统独立构建的能力,填补了我国在这一领域的空白。 举国欢庆,算的上是里程碑意义了! 发布后用着如何&a…

Linux系统下U盘打不开: No application is registered as handling this file

简述 系统是之前就安装好使用的Ubuntu14.04,不过由于某些原因只安装到了机械硬盘中;最近新买了一块固态硬盘,所以打算把Ubuntu系统迁移到新的固态硬盘上; 当成功的迁移了系统之后发现其引导有点问题,导致多个系统启动不…

所有流的知识都有,IO流原理及流的分类

1、Java IO流原理 I/O是Input/Output的缩写, I/O技术是非常实用的技术,用于处理设备之间的数据传输。如读/写文件,网络通讯等。 Java程序中,对于数据的输入/输出操作以”流(stream)” 的方式进行。java.io包下提供了各种“流”类…

C++语法(27)--- 类型转换和C++线程库

C语法(26)--- 特殊类设计_哈里沃克的博客-CSDN博客https://blog.csdn.net/m0_63488627/article/details/131879800?spm1001.2014.3001.5501 目录 1.类型转换 1.C语言的转换模式 2.C四种类型转换 1.static_cast 2.reinterpret_cast 3.const_cast …

ALLEGRO之Logic

本文主要讲述ALLEGRO的Logic菜单。 (1)Net Logic:暂不清楚; (2)Net Schedule:暂不清楚; (3)AssignDifferential Pair:暂不清楚; &a…

OR-Tool 报INFEASIBLE

OR-Tool 使用Minimum Cost Flows报 There was an issue with the min cost flow input. Status: Status.INFEASIBLE 这是因为node的编号需要是连续的,比如下面这样不行 修改为连续的

【已解决】如果将MySQL数据库中的表生成PDM

数据库表PDM关系图 | 原创作者/编辑:凯哥Java | 分类:经验分享 有时候,我们需要MySQL数据库中的表生成对应的PDM文件,这里凯哥就讲讲第一种将MySQL数据库的表生成对应的PDM文件。 环境准备: MySQL数据库连接客户端&…

中文多模态医学大模型智能分析X光片,实现影像诊断,完成医生问诊多轮对话

项目设计集合(人工智能方向):助力新人快速实战掌握技能、自主完成项目设计升级,提升自身的硬实力(不仅限NLP、知识图谱、计算机视觉等领域):汇总有意义的项目设计集合,助力新人快速实…

费舍尔线性分辩分析(Fisher‘s Linear Discriminant Analysis, FLDA)

费舍尔线性分辩分析(Fisher’s Linear Discriminant Analysis, FLDA) 目录 费舍尔线性分辩分析(Fishers Linear Discriminant Analysis, FLDA)1. 问题描述2. 二分类情况3. 多分类情况4. 代码实现4.1 二分类情况4.2 多分类情况 5. 参考资料 1. 问题描述 为解决两个或多个类别的…

ROS-PyQt小案例

前言:目前还在学习ROS无人机框架中,,, 更多更新文章详见我的个人博客主页【前往】 ROS与PyQt5结合的小demo,用于学习如何设计一个界面,并与ROS中的Service和Topic结合,从而控制多个小乌龟的运动…

从零开始搭建Vue3框架(二):Vue-Router4.0使用与配置

前言 上篇文章我们创建了模板项目并成功运行,但是运行后的页面只是一个静态页面,并没有页面间跳转。 对于Vue这种单页应用来说,最要紧的就是控制整个系统的页面路由。因为我们使用Vue3的框架,所以这里使用Vue-Router4.0版本。 …

1992-2021年全国及31省对外开放度测算数据含原始数据和计算过程(无缺失)

1992-2021年全国及31省对外开放度测算数据含原始数据和计算过程(无缺失) 1、时间:1992-2021年 2、范围:全国及31省 3、指标:进出口总额、国内生产总值、年均汇率 4、计算方法:对外开放度进出口总额/GDP…

【Git系列】Git配置SSH免密登录

🐳Git配置SSH免密登录 🧊1.设置用户名和邮箱🧊2. 生成密钥🧊3.远程仓库配置密钥🧊2. 免密登录 在以上push操作过程中,我们第一次push时,是需要进行录入用户名和密码的,比较麻烦。而且…

【数据分析专栏之Python篇】四、pandas介绍

前言 在上一篇中我们安装和使用了Numpy。本期我们来学习使用 核心数据分析支持库 Pandas。 一、pandas概述 1.1 pandas 简介 Pandas 是 Python 的 核心数据分析支持库,提供了快速、灵活、明确的数据结构,旨在简单、直观地处理关系型、标记型数据。 …

Resnet与Pytorch花图像分类

1、介绍 1.1数据集介绍 flower_data├── train│ └── 1-102(102个文件夹)│ └── XXX.jpg(每个文件夹含若干张图像)├── valid│ └── 1-102(102个文件夹)└── ─── └── XXX.jp…

如何使用免费敏捷工具Leangoo领歌管理Sprint Backlog

什么是Sprint Backlog? Sprint Backlog是Scrum的主要工件之一。在Scrum中,团队按照迭代的方式工作,每个迭代称为一个Sprint。在Sprint开始之前,PO会准备好产品Backlog,准备好的产品Backlog应该是经过梳理、估算和优先…

ffmpeg安装

简介 FFmpeg是一个开源的音视频处理库,它提供了一系列的工具和API,可以用于处理音视频文件。你可以使用FFmpeg的命令行工具来执行各种音视频处理操作,比如转码、剪辑、合并等。FFmpeg的命令格式通常是:ffmpeg [全局选项] {[输入文…

章节5:SQL注入之WAF绕过

章节5:SQL注入之WAF绕过 5.1 SQL注入之WAF绕过上 WAF拦截原理:WAF从规则库中匹配敏感字符进行拦截。 5.2 SQL注入之WAF绕过下 (原理简单了解) 关键词大小写绕过 有的WAF因为规则设计的问题,只匹配纯大写或纯小写的…

B. Binary Cafe(二进制的妙用)

题目:Problem - B - Codeforces 总结: 对于该题最简单的方法为使用二进制的数表示状态 例如: 对于一个数7的二进制:111 它的每一位都可表示两种状态我们可以理解为取或者不取 对于7这个数字它可以表示一种状态即在三个位置都…

使用Roles模块搭建LNMP架构

使用Roles模块搭建LNMP架构 1.Ansible-playbook中部署Nginx角色2.Ansible-playbook中部署PHP角色3.Ansible-playbook中部署MySQL角色4.启动安装分布式LNMP 1.Ansible-playbook中部署Nginx角色 创建nginx角色所需要的工作目录; mkdir -p /etc/ansible/playbook/rol…