Meta更低的训练成本取得更好的性能: 多token预测(Multi-Token Prediction)

news2024/12/29 8:53:57

Meta提出了一种透过多token预测(Multi-token Prediction)来训练更好、更快的大型语言模型的方法。这篇论文的重点如下:

训练语言模型同时预测多个未来的token,可以提高样本效率(sample efficiency)。
在推论阶段,使用多token预测可以达到最高3倍的加速。

在这里插入图片描述

论文的主要贡献包括:

  • 提出了一种简单的多token预测架构,在训练时间和内存使用上没有额外开销。
    实验证明,这种训练范式在大规模模型(最高达130亿参数)上是有效的,平均可以解决大约15%以上的编程问题
  • 多token预测使得自我推测解碼(self-speculative decoding)成为可能,在各种批次大小下将模型的推论速度提高了最多3倍。

https://arxiv.org/pdf/2404.19737

动机与目的

传统的语言模型通常使用下一个token预测(next-token prediction)的方式进行训练,即根据前面的token序列,预测下一个最可能出现的token。然而,这种训练方式可能导致模型过度关注局部的模式,忽略了长程的依赖关系。为了解决这个问题,本文提出了多token预测(multi-token prediction)的训练方法,同时预测未来的多个token,以提升语言模型的训练效率和性能。

在这里插入图片描述

方法原理

模型架构

语言模型使用一个共享的模型主体(shared model trunk),并在其上添加n个独立的输出头(output head),分别预测未来的n个token。
在训练时,模型在每个位置同时预测未来的n个token,使用n个独立的loss项。
为了减少GPU内存用量,作者巧妙地调整了前向/反向传播的顺序。模型依序计算每个输出头的前向和反向传播,同时累积主体的梯度,避免同时储存所有n个庞大的logit向量。
推论时,可以只用第一个输出头(也就是下一个token的预测),其余输出头可选择性地用于加速推论(称为self-speculative decoding)。

在这里插入图片描述

训练目标

在训练时,模型在每个位置同时预测未来的 n n n个token,使用 n n n个独立的cross-entropy loss项。假设输入的token序列为 x 1 , x 2 , . . . , x t , x_1, x_2, ..., x_t, x1,x2,...,xt,模型的训练目标可以表示为:

L n = − Σ t l o g P ( x t + 1 , . . . , x t + n ∣ x 1 , . . . , x t ) L_n = - Σ_t log P(x_{t+1}, ..., x_{t+n} | x_1, ..., x_t) Ln=ΣtlogP(xt+1,...,xt+nx1,...,xt)

其中, P ( x t + 1 , . . . , x t + n ∣ x 1 , . . . , x t ) P(x_{t+1}, ..., x_{t+n} | x_1, ..., x_t) P(xt+1,...,xt+nx1,...,xt)表示在给定前 t t t个token的条件下,未来 n n n个token的联合概率分布。将这个联合概率分解为 n n n个条件概率的乘积,可以得到:

L n = − Σ t [ l o g P ( x t + 1 ∣ x 1 , . . . , x t ) + l o g P ( x t + 2 ∣ x 1 , . . . , x t ) + . . . + l o g P ( x t + n ∣ x 1 , . . . , x t ) L_n = - Σ_t [log P(x_{t+1} | x_1, ..., x_t) + log P(x_{t+2} | x_1, ..., x_t) + ... + log P(x_{t+n} | x_1, ..., x_t) Ln=Σt[logP(xt+1x1,...,xt)+logP(xt+2x1,...,xt)+...+logP(xt+nx1,...,xt))

每个条件概率 P ( x t + i ∣ x 1 , . . . , x t ) P(x_{t+i} | x_1, ..., x_t) P(xt+ix1,...,xt)由一个独立的输出头计算得到。

训练技巧

为了减少GPU内存的使用量,作者巧妙地调整了前向/反向传播的顺序。模型依序计算每个输出头的前向和反向传播,同时累积主体的梯度,避免同时储存所有n个庞大的logit向量。这种技巧使得多token预测模型的训练几乎不增加额外的计算和存储开销。

在这里插入图片描述

推论过程

在推论阶段,可以只使用第一个输出头(即下一个token的预测),其余输出头可选择性地用于加速推论。这种加速技术称为self-speculative decoding,通过并行计算多个输出头的预测结果,可以提高推论的效率。

实验结果

作者在多个编码和自然语言任务上评估了多token预测模型的性能,并与传统的下一个token预测模型进行了比较。

在这里插入图片描述

编码任务

在HumanEval和MBPP两个编码数据集上,多token预测模型显著优于基准模型,尤其在大模型(如13B参数)上提升更加明显。4个token的预测在综合表现上最佳,在HumanEval上pass@100提升了4.1%,在MBPP上pass@1提升了3.8%。此外,训练多个epoch时,多token预测的优势仍然存在。

自然语言任务

在自然语言任务上,多token预测也带来了改进,特别是在需要生成较长文本的摘要和自然语言数学任务。在8个摘要数据集上,2个token的预测平均将ROUGE-L提升了0.51,4个token的预测平均提升了0.46。在GSM8K自然语言数学数据集上,2个token的预测模型显著优于基准模型。

字符级训练

在这里插入图片描述

为了验证多token预测有助于学习更长程的依赖关系,作者进行了字符级(byte-level)的训练实验。结果表明,8个字符的多token预测模型在HumanEval上pass@1的表现比下一个字符预测模型高出20%,在MBPP上高出67%。这说明多token预测能够捕捉更长距离的模式和依赖关系。

模型微调

使用预训练的多token预测模型进行微调,也能在下游任务上取得优于基准模型的成果。在CodeContests数据集上,4个token预训练的模型在pass@k上全面超过了下一个token预训练的模型。

在这里插入图片描述

在编码(coding)任务上,多token预测模型在HumanEval和MBPP数据集上的表现显著优于基准模型,尤其在大模型(如13B参数)上提升更加明显。
在自然语言任务上,多token预测也带来了改进,特别是在需要生成较长文本的摘要和自然语言数学任务。
多token预测有助于模型学习更长程的依赖关系。在字符级(byte-level)的训练中,8个字符的多token预测大幅优于下一个字符预测。
实验显示,4个token的预测在综合表现上最佳。此外,训练多个epoch时,多token预测的优势仍然存在。
使用训练好的多token预测模型进行微调(如在CodeContests数据集上),也能取得优于基准模型的成果。
额外的输出头可用于self-speculative decoding,在推论阶段提供最高3倍的加速。

在这里插入图片描述

结论与讨论

本文提出了一种简单而有效的语言模型训练方法——多token预测,通过同时预测未来的多个token,促进模型学习更长程的依赖关系。实验结果表明,这种方法在编码和自然语言任务上带来了显著的性能提升,尤其对大模型和较长文本的生成任务效果更佳。多token预测几乎不增加训练成本,却能提高训练和推论效率,值得进一步探索。

在这里插入图片描述

作者认为,这项工作为寻找更有效的语言模型训练方法开辟了新的方向。未来的研究可以探索以下几个方面:

  1. 在更大规模的数据集和模型上验证多token预测的有效性。
  2. 研究最优的token预测数量n,以及如何自适应地选择n。
  3. 设计更高效的多token预测架构,如使用单一的输出头来预测多个token。
  4. 将多token预测与其他辅助训练目标结合,如掩码语言建模(masked language modeling)。在这里插入图片描述
    在这里插入图片描述

多token预测是一种前景广阔的语言模型训练方法,有望帮助构建更强大、更连贯的语言模型,推动自然语言处理领域的发展。

以下是我对这项工作的一些想法:

Meta最近提出了一种简单而有效的语言模型训练方法—多token预测(Multi-Token Prediction,简称MTP)。传统的语言模型通常每次只预测一个token,而MTP则在每个时间步预测多个token,从而提高训练效率。
核心思想:

在每个时间步,模型预测接下来的n个token,而不是1个
将这n个token打包成一个单独的预测目标,用一个特殊的分隔符隔开
模型的输出是长度为n的token序列,用交叉熵损失函数优化

优点:

预测多个token,捕捉更长距离的依赖,学到更强的上下文表征
并行化程度高,加快训练速度,节省显存
实现简单,几乎不增加模型参数量
在下游任务上finetune,相比传统方法能取得更好的效果

实验结果表明,相比标准的next token prediction,MTP能以更低的训练成本取得更好的性能。比如在相同的计算预算下,MTP的WikiText-103困惑度比传统方法低15%以上。
总之,多token预测是一种简洁而强大的语言模型训练范式。通过预测多个token,它能学到更丰富的上下文信息。同时并行化程度高,训练高效。Meta的这项工作为语言模型的训练提供了新的思路。

多token预测利用了语言的长程依赖关系,通过同时预测多个未来的token,促使模型学习更全面、更连贯的表示。这种方法与人类语言学习的过程更为相似,因为我们在理解和生成语言时,也是基于对未来一段文本的预期,而不仅仅依赖于前一个词。

该方法在编程任务上取得了显著的性能提升,这可能是因为编程语言具有更强的结构性和逻辑性,多token预测更容易捕捉到其中的模式和依赖关系。在自然语言任务上的改进相对较小,可能是因为自然语言的不确定性和灵活性更高,单纯增加预测的token数量效果有限,需要更细致的建模方法。

多token预测在推论阶段带来的加速效果非常可观,这对于实际应用中的延迟敏感场景(如实时对话、同步翻译等)具有重要价值。不过,这种加速方法对模型性能的影响还需要进一步评估,确保生成质量不会显著下降。

论文中的实验主要集中在编程和自然语言文本上,未来可以考虑将多token预测应用于其他类型的序列数据,如时间序列、生物序列等,探索它在更广泛领域的有效性。

多token预测作为一种辅助的训练目标,与其他方法(如对比学习、知识蒸馏等)结合使用,可能会产生更好的协同效果。探索多种训练策略的组合,有望进一步提升语言模型的性能和泛化能力。

我认为这项工作为改进大型语言模型的训练和推理效率提供了一个简单而有效的思路,具有广阔的应用前景。未来可以在更大规模的数据集和模型上验证这种方法的有效性,并探索与其他技术结合的可能性,推动语言模型的进一步发展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1655187.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

并发问题系统学习(更新中)

进程、线程 进程:进程是代码在数据集合上的一次运行活动,是系统进行资源分配和调度的基本单位。可以理解为一个java应用。 线程:线程是进程的一个执行路径,一个进程中至少有一个线程,进程中的多个线程共享进程的资源。…

CleanMyMac X 4.15.3 版本发布

CleanMyMac X 4.15.3 版本发布,一款苹果 macOS 系统好用的伴侣软件,其包含 1.一键深度清理。2.系统垃圾专清。3.大/旧文件专清。4.系统提速。5.性能悬浮窗。6.恶意软件防护。7.隐私保护。8.软件卸载器。9.软件更新器等 9 大功能,为您的苹果电…

VSCode-vue3.0-安装与配置-export default简单例子

文章目录 1.下载VSCode2.修改语言为中文3.辅助插件列表4.vue3模板文件简单例子5.总结 1.下载VSCode 从官网下载VSCode,并按下一步安装成功。 2.修改语言为中文 点击确认修改,如下图所示: 或者打开命令面板:输入Configure Displ…

如何快速学习VCU电控开发

本课程基于实际项目案例和岗位需求技能制定教学大纲,以任务驱动方式引导学员,让学员快速掌握VCU开发知识。首先从VCU开发必备知识点和MATLAB/Simulink软件建模工具的使用入手,夯实学员基础。再通过策略设计、模型搭建和测试标定来指导学员完成…

关闭vscode保存自动格式化的功能

1 首先打开设置 搜索:editor.formatOnSave 取消勾选框 2 再打开 settings.json 搜索 editor 找到 settings.json 设置: "editor.formatOnSave": false

基于opencv的车辆统计

车辆统计) 一、项目背景二、整体流程三、常用滤波器的特点四、背景减除五、形态学开运算闭运算 六、项目完整代码七、参考资料 一、项目背景 检测并识别视频中来往车辆的数量 最终效果图: 二、整体流程 加载视频图像预处理(去噪、背景减除…

详解typora配置亚马逊云科技Amazon S3图床

欢迎免费试用亚马逊云科技产品:https://mic.anruicloud.com/url/1333 当前有很多不同的博客社区,不同的博客社区使用的编辑器也不尽相同,大概可以分为两种,一种是markdown格式,另外一种是富文本格式。例如华为云开发者…

【项目学习01_2024.05.08_Day06】

学习笔记 5 新增课程5.1 需求分析5.1.1 业务流程5.1.2 数据模型 5.2 接口定义5.3 接口开发5.3.1 保存课程基本信息5.3.2 保存营销信息 5.4 接口测试 5 新增课程 5.1 需求分析 5.1.1 业务流程 5.1.2 数据模型 5.2 接口定义 5.3 接口开发 根据需求分析,新增课程表…

Python中的类和对象的概念理解和创建方法1——基本概念的理解和具体程序实例

Python中的类和对象的概念理解和创建方法1——基本概念的理解和具体程序实例 目录 Python中的类和对象的概念理解和创建方法1——基本概念的理解和具体程序实例一、类和对象的概念二、类和对象的关系2.1 两者辩证关系2.2 两者内部的对应关系 三、类和对象的优势3.1 多态性3.2 封…

添加一个索引要投产,需要哪些步骤?

编程一生 致力于写大家都能看懂的、有深度的 技术文章 05/2024 01 开场白 亚马逊有个bar raiser文化。就是说新招来的人一定要超过之前入职人员的平均水平,宁缺毋滥。越来越多的公司在推行这种文化。在这种氛围下:“虽然我不懂,但是活儿是能出…

Spring自定义配置属性类

以一个minio的配置类为例 首先,由于minio模块被很多微服务需要,因此封装了一个starter,当背的微服务需要的时候就进行引入。 以下是starter模块的结构图 一、spring.factories文件 org.springframework.boot.autoconfigure.EnableAutoConf…

【管理篇】如何管理情绪?

目录标题 为什么要特别关注激动和愤怒两种情绪呢?管理自己的情绪大致的步骤三层脑结构爬行脑情绪脑视觉脑 大家说的情绪管理,基本上都是对于情绪激动、生气甚至是愤怒的管理;日常所说的情绪化,一般也是指某个人特别容易情绪激动&a…

Java | Leetcode Java题解之第78题子集

题目&#xff1a; 题解&#xff1a; class Solution {List<Integer> t new ArrayList<Integer>();List<List<Integer>> ans new ArrayList<List<Integer>>();public List<List<Integer>> subsets(int[] nums) {dfs(0, nums…

Ansible--Templates 模块 Tags模块 Roles模块

一 Templates 模块 ①Jinja是基于Python的模板引擎。Template类是Jinja的一个重要组件&#xff0c;可看作一个编译过的模 板文件&#xff0c;用来产生目标文本&#xff0c;传递Python的变量给模板去替换模板中的标记。 ②在配置文件中&#xff0c;会有一些数据&#xff08;如…

YOLOv8改进 | 独家创新篇 | 利用MobileNetV4的UIB模块二次创新C2f(全网独家首发)

一、本文介绍 本文给大家带来的改进机制是利用MobileNetV4的UIB模块二次创新C2f&#xff0c;其中UIB模块来自2024.5月发布的MobileNetV4网络&#xff0c;其是一种高度优化的神经网络架构&#xff0c;专为移动设备设计。它最新的改动总结主要有两点&#xff0c;采用了通用反向瓶…

rust打包编译为mac或者linux可执行文件,发送到别的电脑不能运行

如果使用rust项目编译为linux或者mac可执行文件&#xff0c;发送到别的电脑之后&#xff0c;不可以直接运行&#xff0c;而是显示一个空白文件&#xff0c;双击也没有反应&#xff0c;其实这是因为这个文件没有可执行权限导致的&#xff0c;添加可执行权限就可以了&#xff1a;…

沙盘Sandboxie v5.56.4

菜鸟高手裸奔工具沙盘Sandboxie是一款国外著名的系统安全工具&#xff0c;它可以让选定程序在安全的隔离环境下运行&#xff0c; 只要在此环境中运行的软件&#xff0c;浏览器或注册表信息等都可以完整的进行清空&#xff0c;不留一点痕迹。同时可以防御些 带有木马或者病毒的…

24证券从业资格报名照片要求✅如何上传?

✨24证券从业报名今天下午3点开始喽&#xff01; 话说&#xff0c;每次都有人证券报名照片不符合规格导致报名不通过&#xff0c;建议大家提前了解一下注意事项和要求&#xff01; 之前考过还需要上传照片吗&#xff1f; ✅老考生之前传过照片不用上传了。 ✅首次注册过但没有考…

EMAP的Root工程及其他工具

首先右击项目导航&#xff0c;新建EMAP系统配置 上方辅助工具功能&#xff1a; 1 2 3 4 5 6 7 8 9 10 查看重复数据模型:显示为放大镜标识&#xff0c;可以显示所有应用中相同…

mysql oceanbase数据库alter语句阻塞,解决方案

获取当前阻塞事件 select d.trx_started, a.thread_id, b.processlist_id, a.SQL_text from performance_schema.events_statements_current ajoin performance_schema.threads b on a.thread_id b.thread_idjoin information_schema.processlist c on b.processlist_id c.i…