Lory: 推进大型语言模型训练的新篇章

news2024/12/23 18:41:03

人工智能咨询培训老师叶梓 转载标明出处

随着模型规模的增长,如何有效训练并利用这些模型成为了一个挑战。陈丹琦团队一项新的研究提出了一种创新的预训练方法——Lory,旨在解决大模型在混合专家(MoE)架构中的可微分性和计算效率问题。Lory通过专家合并技术,实现了模型的完全可微分,并通过因果段路由和基于相似性的数据批处理策略,显著提高了训练效率和模型性能。这项工作不仅推动了大模型的研究边界,也为未来更高效、更强大的AI系统的发展奠定了基础。

方法

Lory方法提出了一种可微分的MoE(Mixture of Experts)架构,专门为自回归语言模型设计。这一架构的核心在于专家合并技术,该技术允许模型在保持自回归特性的同时进行有效的训练。Figure 1展示了Lory的整体架构,其中包括两个关键技术:因果段路由策略和基于相似性的数据批处理方法。

因果段路由策略是Lory方法的首个关键技术。这一策略的核心思想是在处理语言模型时,不是对每个标记单独进行专家合并,而是在段级别上进行。这样做的好处是显著减少了合并操作的次数,从可能的L次减少到N次,其中N是将输入序列分割成的段的数量。这种策略不仅减少了计算量,而且通过仅使用前一段的信息来指导当前段的专家合并,有效地防止了信息泄露的问题。在Figure 1中,每个段的专家合并都是基于前一段的平均隐藏表示来计算的,这种设计确保了模型的自回归特性得以保持。

基于相似性的数据批处理方法是Lory方法的第二个关键技术。这一方法的目的是通过构建具有高相似性的段来提高专家在特定领域的专业化能力。在传统的数据批处理方法中,文档通常是随机连接的,这可能导致相邻段之间的标记来自完全不同的文档,从而影响了专家的专业化。Lory方法通过顺序连接相似的文档来构建训练实例,这种方法鼓励了相邻段之间的高相似性,使得专家能够更加专注于特定的领域或主题。在Figure 1基于相似性的数据批处理方法的示意图中,其中相似的文档被有意识地组合在一起,形成了具有高相似性的段,从而促进了专家的专业化。

通过结合这两个关键技术,Lory方法不仅提高了训练效率,还增强了模型在特定领域的专业化能力。

想要掌握如何将大模型的力量发挥到极致吗?2024年10月26日叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具。

留言“参加”即可来叶老师的直播间互动,1小时讲解让您轻松上手,学习如何使用 Llama Factory 微调模型。

实验

实验设置

模型:实验中训练了两种规模的解码器仅Transformer模型,分别含有0.3B和1.5B个活跃参数。对于每个FFN层,用具有相同架构的MoE层替换,其中E代表专家的数量,取值为8、16或32。模型配置和总参数计数详见附录D。所有模型使用4096-token的上下文窗口进行训练,并且采用因果段路由策略,每段长度设置为T=256。

训练细节:使用AdamW优化器进行训练,其中β1=0.9和β2=0.95,学习率设置为2e-4,并采用余弦学习率调度器。所有模型的批量大小为100万个token。采用ZeRO优化进行分布式训练。训练初期,使用与参数匹配的密集模型,并复制FFN层作为MoE模型的初始化。实验中,使用前5%的训练步骤作为预热来初始化MoE权重。此外,对学习率调度器的前5%训练步骤应用线性预热。训练过程中使用了多达64个A100 GPU。

训练数据集:从Commoncrawl数据集中随机抽取150亿个token作为训练数据。使用基于相似性的数据批处理方法构建所有训练实例。

评估数据集:在从arXiv、Books、Wikipedia、C4和Python代码(Github的Python子集)中抽取的保留评估数据集上,通过测量训练模型的困惑度来评估所有模型在语言建模任务上的性能。每个评估数据集包含1K个样本,每个样本由4096个token组成。此外,还在下游任务中评估了模型的性能,包括常识推理、阅读理解、闭卷问答和文本分类。

结果

训练效率和收敛性:Figure 2(左)展示了密集模型和具有不同模型大小的MoE模型的训练损失曲线。首先,可以发现在相同的训练token数量下,Lory模型明显比密集模型基线有更好的训练损失。对于0.3B和1.5B模型,具有32个专家的模型在训练token数量不到一半的情况下达到了相同的损失水平。这表明Lory方法在相同的训练计算量下实现了更好的性能。还观察到,使用更多的专家可以获得更多的改进。

语言建模:在语言建模任务上的评估表明,MoE模型一致性地优于密集基线,显著降低了所有领域的困惑度。例如,0.3B/32E模型在Books上的困惑度比0.3B密集模型提高了13.9%。值得注意的是,在与训练数据不同的测试领域(例如Python)中,改进最为显著,这表明了强烈的专家专业化,这在第5节中进一步探讨。

下游任务:Table 1展示了模型在下游任务上的性能。可以观察到在所有任务上都有显著的性能提升。例如,0.3B/32E模型在常识推理上平均提高了3.7%,在阅读理解上提高了3.3%,在闭卷问答上提高了1.5%,在文本分类上提高了11.1%。

消融研究和分析

在前缀路由策略中,仅根据第一个段来执行整个序列的专家合并,然后使用合并的FFN来处理序列的其余部分,而不再进行更新。Figure 3显示,仅使用前缀进行路由会导致性能大幅下降,与因果段路由相比,后者在每个段级别上提供更强大的训练信号,从而获得更好的性能。

为了研究基于相似性的数据批处理方法的重要性,比较了使用和不使用这种方法的MoE模型相对于密集模型的性能提升。Figure 4(左)展示了使用相似性批处理(sim batch)和随机批处理(rand batch)数据的密集(0.3B)和MoE模型(0.3B/8E)的训练损失。在两种设置中,MoE模型均优于密集模型。然而,使用相似性批处理时,损失改进(即密集模型和MoE模型之间的损失差异)更大,并且随着训练数据的增加,这种效果更加明显(Figure 4(右))。这些结果强烈支持了基于相似性批处理对于有效训练MoE模型的重要性。

将Lory方法与现有的Expert Choice (EC) MoE方法进行了比较,EC方法通过让每个专家根据路由权重选择top-k输入来确保训练期间的均衡负载。在推理期间,为了避免利用全局信息进行路由,每个标记被路由到top-k专家中。Figure 5展示了Lory方法与EC方法(包括段级别和标记级别路由)的训练损失曲线。Lory方法在相同的路由设置下显著优于段级别的EC,表明完全可微分的架构比使用相同路由策略的稀疏MoE更有效。与标记级别的EC模型相比,尽管Lory使用方法使用段级别路由且不需要任何高级训练技术,但Lory取得了有竞争力的结果。这些结果突出了Lory方法的显著潜力。

在分析中还探讨了专家利用率和专业化问题。利用率方面,尽管没有使用任何辅助损失来平衡负载,Lory方法仍然能够实现高专家利用率,防止MoE模型退化为密集模型。专业化方面,通过研究0.3B/8E模型在不同层的平均路由权重,发现即使在没有额外的领域级监督的情况下,训练出的MoE模型中也存在清晰的领域级专家专业化。例如,第11层的专家7专门处理arXiv领域的输入。此外,还观察到arXiv和Python代码的路由权重更为相似,而与Books和Wikipedia的相似度较低,这可能是因为LaTex代码和Python代码与自然语言的差异较大。

Figure 6 展示了0.3B/8E模型在不同领域(Books、arXiv、Python、Wikipedia)上的第0层、第11层和第23层的平均路由权重。观察到MoE模型中的专家学习到了领域级专业化,尤其是在中间和更高层次的层中。

随着未来研究的深入,这种方法有望在更广泛的应用场景中发挥重要作用,为人工智能领域带来新的突破。

https://arxiv.org/pdf/2405.03133

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2205970.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

开关打开输入框才能输入文字,否则为禁用状态

页面开关默认为关闭状态&#xff0c;输入框为禁用状态。 当点击开关&#xff0c;打开开关后&#xff0c;输入框禁用状态解除&#xff0c;才可以在输入框内输入。 html结构: <div class"page_top"><!-- 第一行 --><div class"top_first">…

使用three.js 实现一个 马赛克得 shader

使用three.js 实现一个 马赛克得 shader 源链接&#xff1a;https://threehub.cn/#/codeMirror?navigationThreeJS&classifyshader&idmosaicShader 国内站点预览&#xff1a;http://threehub.cn github地址: https://github.com/z2586300277/three-cesium-example…

HTML的介绍

HTML HTML是一种超文本标记语言,超文本是指,除了文本之外,还可能包含图片,音频,或者评注等的 文本形式,比文本强大,通过链接和交互方式来组织和呈现信息.标记语言是指,由标签构成的语言.HTML定义了多种不同的标签,用来表示不同的内容. 标签的介绍: 1.<h3> 三级 </h3&…

增强AI查询:使用Rewrite Retrieve Read框架优化RAG

增强AI查询&#xff1a;使用Rewrite Retrieve Read框架优化RAG 引言 在大规模语言模型&#xff08;LLM&#xff09;中&#xff0c;通过查询重写来提升检索增强生成&#xff08;RAG&#xff09;的性能是一个热门研究领域。本文将介绍如何使用rewrite_retrieve_read模板来优化R…

基于SpringBoot的图书推荐系统的设计与实现(论文+源码)_kaic

摘 要 网络信息技术的高速发展&#xff0c;使得高校图书馆的服务空间日益扩大&#xff0c;依据个人特点的针对性服务逐渐成为新服务模式的主导趋势。对于大多数用户而言&#xff0c;很难在大量的学术图书馆中快速找到他们想要的材料。另外&#xff0c;随着时代的不断发展&am…

Mysql的LSN是什么?

LSN的含义 ​ LSN全称为 Log Sequence Number&#xff0c;即日志序列号。它是一个不断递增的数字&#xff0c;用来标识事务日志中的每个操作或事件。LSN是一个64位的数字&#xff0c;每一个LSN值都是唯一的&#xff0c;并且随时间线性增加。 ​ 通过SHOW ENGINE INNODB STATUS;…

GADBench Revisiting and Benchmarking Supervised Graph Anomaly Detection

Neurips 23 推荐指数&#xff1a; #paper/⭐⭐⭐ 领域&#xff1a;图异常检测 胡言乱语&#xff1a; neurips 的benchmark模块的文章总能给人一些启发性的理解&#xff0c;这篇的insight真有意思。个人感兴趣的地方会加粗。此外&#xff0c;这篇文章和腾讯AIlab合作&#xff…

嵌入式基本知识

文章目录 调试接口仿真器MCU实际的调试接口 调试接口 调试接口用于对MCU进行编程和调试&#xff0c;这里的编程指将源代码编译后的.hex文件写入MCU闪存特定地址中&#xff0c;调试指MCU运行代码debug的过程。 不同的接口协议有不同的接口类型。SWD协议调试接口的引脚主要有&a…

卡码网C++基础课 |20. 排队取奶茶

目录 前言 一、题目描述 二、解题思路 1.队列 2.队列的操作 三、完整代码 总结 前言 仅个人记录所用 源自卡码网的C基础课 “这门C基础课 帮助 编程零基础学员快速学习刷算法题所需要的基础语法知识&#xff0c;学完之后&#xff0c;再来刷代码随想录&#xff0c;或者自己去…

CentOS 7.9 局域网配置指定同步时间服务器

在 CentOS 7.9 中&#xff0c;默认的时间同步工具是 chrony。以下是如何配置 NTP 服务器地址并使用 chrony 进行时间同步的步骤&#xff1a; 1. 安装 chrony&#xff08;通常已经预装可忽略&#xff09; 通过systemctl status chronyd检查是否已经安装启动 如果没网可以直接…

npm安装依赖报错npm ERR! Unexpected token ‘.

电脑是windows的&#xff0c;因为有多个项目做开发&#xff0c;每个项目需要的node版本不一样&#xff0c;所以使用了nvm做node管理。 电脑的nvm是1.1.7版本的。 新项目在安装依赖时突然报错如下&#xff1a; npm ERR! Unexpected token .在网上查了很多都说是nvm版本太低了&…

【MLP-Mixer】核心方法解读

abstract&#xff1a; 我们提出MLP-Mixer架构(或简称“Mixer”)&#xff0c;这是一个具有竞争力但在概念和技术上都很简单的替代方案&#xff0c;它不使用卷积或自关注。相反&#xff0c;Mixer的架构完全基于多层感知器(mlp)&#xff0c;这些感知器可以在空间位置或特征通道上…

渗透测试 之 域渗透手法【域内用户枚举】手法 Kerbrute msf pyKerbrute 工具使用详解

说明一下: 域内用户枚举工具使用说说&#xff1a; Kerbrute pyKerbrute MSF模块的使用 域内用户名枚举原理分析&#xff1a; 域内用户枚举攻击防御&#xff1a; 流量检测&#xff1a; 日志层面&#xff1a; 说明一下: 域环境或者内网环境下&#xff0c;可以在没有域环…

深入理解Transformer的笔记记录(精简版本)---- ELMO->GPT->BERT

1、ELMO word embedding无法区分多义词的不同语义,其本质上是个静态的方式,所谓静态指的是训练好之后每个单词的表达就固定住了,以后使用的时候,不论新句子上下文单词是什么,这个单词的Word Embedding不会跟着上下文场景的变化而改变 ELMO根据当前上下文对Word Embed…

有趣的python库:用 difflib 实现文本差异的可视化

一&#xff0c;介绍 difflib 模块是Python标准库的一部分&#xff0c;提供了一系列用于比较序列的类和函数&#xff0c;特别适用于文本比较任务。这个模块可以帮助用户发现两个文本文件或字符串序列之间的差异&#xff0c;并以多种格式展示这些差异&#xff0c;比如这样&#…

400行程序写一个实时操作系统RTOS(开篇)

笔者之前突发奇想&#xff0c;准备写一个极其微小的实时操作系统内核&#xff0c;在经过数天的努力后&#xff0c;这个RTOS诞生了。令读者比较意外的是&#xff0c;它的程序只有400行左右。但就是这短短的400行&#xff0c;完成了动态内存管理、多线程、优先级、低功耗管理、调…

深度学习--------------------------------使用注意力机制的seq2seq

目录 动机加入注意力Bahdanau注意力的架构 总结Bahdanau注意力代码带有注意力机制的解码器基本接口实现带有Bahdanau注意力的循环神经网络解码器测试Bahdanau注意力解码器该部分总代码 训练从零实现总代码简洁实现代码 将几个英语句子翻译成法语该部分总代码 将注意力权重序列进…

BUG修复(不断整理想起什么就整理什么)

声明&#xff1a;此篇博文是记录本人从开始学习计算机过程中遇到的各种类型的报错以解决办法,希望给同道中人提供一点绵薄的帮助&#xff0c;也欢迎大家在评论区讨论或私信我交流问题 共同进步&#xff01; 一、FPGA系列 1.Synthesis failed 错误&#xff1a;综合失败&#…

Python | Leetcode Python题解之第468题验证IP地址

题目&#xff1a; 题解&#xff1a; class Solution:def validIPAddress(self, queryIP: str) -> str:if queryIP.find(".") ! -1:# IPv4last -1for i in range(4):cur (len(queryIP) if i 3 else queryIP.find(".", last 1))if cur -1:return &q…

测试工作能干到退休!从会写一份成长型测试周报开始

测试周报则是反映团队工作进展和专业态度的一扇窗口。通过周报&#xff0c;我们不仅可以展示一周内的工作成果&#xff0c;更可以体现团队的工作心态——是积极进取、不断学习的成长型心态&#xff0c;还是仅仅满足于现状、缺乏动力的躺平型心态。本文将带您深入了解这两种不同…