通过知识蒸馏提升大模型训练效率

news2024/9/20 18:44:55

 人工智能咨询培训老师叶梓 转载标明出处

随着模型规模的不断扩大,如GPT-4这样的模型拥有约1.7万亿参数,其预训练所需的巨大能源和计算资源引发了对可持续发展AI解决方案的迫切需求。麦吉尔大学的研究团队介绍了一种创新的方法来解决与LLMs预训练相关的效率问题,即通过知识蒸馏实现跨架构的知识转移。研究团队提出了一种名为Hyena的机制,该机制通过替代变换器模型中的注意力头,提供了一种成本效益更高的替代传统预训练的方法。与传统的压缩方法不同,该技术不仅提高了推理速度,而且在准确性和效率方面都超越了预训练。

方法

Hyena算子是本文的核心创新之一,由Poli等人在2023年提出。它旨在作为次线性(subquadratic)替代方案,以替换变换器中的注意力(attention)操作。与H3等其他状态空间模型不同,Hyena直接对滤波器进行参数化,这相当于线性时不变(LTI)系统的脉冲响应。

具体来说,Hyena算子首先对时间索引应用位置嵌入,其中df​是嵌入维度。然后,通过前馈神经网络(FFN):,其中dm​是模型的维度,并将结果乘以一个窗口函数以获得滤波器h[n]。数学表达式为:

Hyena算子​使用这样的滤波器ℎh来聚合长上下文窗口的上下文,并通过对乘法门控机制引入非线性。首先通过投影操作P(x,θ)获得三个投影q,k,v,该操作由参数θ控制。投影操作包括一个线性投影​,然后是一个短的深度卷积,使用短滤波器​进行局部信息交换。然后使用逐元素乘法,接着是卷积和第二个逐元素乘法来计算Hyena算子的输出:其中∗表示卷积操作,⊙表示逐元素乘法。注意,通过使用不同数量的投影,可以进一步泛化该算子。

在进行实验时,研究团队选择了70M参数版本的GPT-NeoX模型,这是一个仅解码器的变换器模型,其架构与GPT-3非常相似,但存在一些关键差异:

  • 传统GPT模型中的位置嵌入被旋转位置嵌入(RoPE)所替代,它使用旋转矩阵对token的位置信息进行编码。
  • 通常在传统GPT模型中串行发现的注意力和前馈层在GPT-NeoX中为了效率而并行计算。
  • 所有的前馈层都是密集的,与GPT-3中密集和稀疏层的交替不同。

值得注意的是,GPT-NeoX的架构与GPT-J非常相似。图1展示了模型架构的详细图示,其中包括:

  • A) GPT NEO X层架构:70M GPT NEO X中的6层堆叠注意力和多层感知机(MLPs)。
  • B) 使用Hyena算子替换注意力头的Hyena-Distilled NEO GPT X层架构,用于蒸馏任务。
  • C) 来自Vaswani等人(2017)的注意力算子的视觉表示。
  • D) 来自Poli等人(2023)的Hyena算子的视觉表示。

本文的目标是将注意力机制替换为Hyena机制。由于Hyena算子已经保留了其输入token的位置信息,因此Hyena版本的模型不包括旋转位置嵌入。研究使用了Biderman等人在2023年实现的Pythia模型,并在开源的Pile数据集上进行了训练。

研究采用了逐步知识转移(Progressive Knowledge Transfer)的方法来逐步训练学生模型。对于每一层,首先在教师模型上对一个token数据集X进行推理,以获得一个蒸馏数据集,其中x是token索引序列,​是教师模型在第i层的输出。然后,最小化均方误差损失,使用​——学生模型在第i层的输出,一次训练一层。对于最后一层,可以通过在文本数据上进行无监督训练来额外微调模型:

所有语言建模实验都使用了OpenWebText数据集。通过从OpenWebText中随机抽取200万个示例来获得一个标记化的预训练数据集,每个预训练示例的上下文长度为1024。数据集被分为训练集和验证集,其中0.1%被保留用于验证。对于蒸馏实验,从训练集中采样了4000万个token来获得用于训练每层的蒸馏数据集。

所有实验都使用了与70M教师模型相同的6层GPTNeoX风格架构。研究者首先基于Pythia和Hyena模型的超参数,从头开始对模型进行预训练,使用了10亿个token。研究者定义预训练为从随机初始化的模型开始,在文本数据上进行无监督学习的过程。同样,研究者定义无监督微调(CE-tinune)为从模型检查点开始,在文本数据上进行无监督学习的过程。在预训练阶段,研究者实现了一个线性预热,跨越300个训练步骤,然后使用余弦衰减在2000次迭代中降低学习率。这种衰减持续到达到最大学习率的10%,此时学习率保持不变。类似地,在蒸馏过程中,研究者在总训练步骤的2.5%上实施线性预热,然后在整个步骤集上衰减,直到达到最大学习率的10%。研究者尝试只进行蒸馏(MSE)以及微调(CE-tinune)。所有实验都设计在RTX 3090上运行5小时。

在Pythia模型的解码器层上进行渐进式知识转移的图示

结果与分析

困惑度(Perplexity)作为衡量语言模型性能的关键指标,用于评估模型对真实数据分布的预测准确性。研究者使用了OpenWebText和WikiText数据集来计算所有模型的困惑度得分。他们采用了与预训练数据集相同的验证集来计算得分,并且所有模型的困惑度得分都是在1024个token的上下文长度下获得的。

表1展示了四种不同模型的困惑度得分:

  • PYTHIA-70M (TEACHER): 教师模型,使用传统的注意力机制,其在WikiText和OpenWebText上的困惑度得分分别为51.4和35.3。
  • PRE-TRAINED: 直接预训练的Hyena模型,得分较高,分别为230和64.9。
  • MSE: 使用均方误差(MSE)损失进行蒸馏后的Hyena学生模型,得分有所下降,分别为155.8和63.5。
  • CE FINE-TUNE: 在蒸馏后进行交叉熵(CE)微调的Hyena学生模型,其困惑度得分进一步降低,分别为121.2和49.6。

这些结果表明,经过蒸馏和微调的学生模型在语言建模任务上的性能有了显著提升,尤其是在OpenWebText数据集上,其困惑度得分接近教师模型。

研究者进一步在三个模型上应用了一系列自然语言任务,以评估它们在不同任务上的表现:

  1. 使用Hyena替代注意力机制的GPT模型。
  2. 使用传统注意力机制的Pythia 70M教师模型。
  3. 使用Hyena并通过联合知识转移(JKT)进行蒸馏的Pythia 70M学生模型。

他们使用了语言模型评估工具(lm eval)对这三个模型在多个不同的自然语言任务上进行了基准测试。测试结果如表2所示,所有结果都是在32位浮点精度下测量的,以确保可重复性并最小化由于低精度引起的机器误差。

表2中列出了不同任务的准确率(ACC)和标准偏差,包括ARC挑战、ARC简单、LOGIQA、PIQA、SCIQ、WINOGRANDE和WSC任务。从表中可以看出,使用Hyena的学生模型在某些任务上的表现略低于教师模型,但在Arc挑战和WSC任务上,学生模型的表现则略高于或显著高于其他两个模型。

表1的实验结果表明,在相同的GPU小时预算内,逐步知识转移与传统的预训练方法相比,在模型性能上具有优势。本方法在没有额外无监督学习的情况下取得了更好的性能,这表明了逐步知识转移策略的效率。

另外研究结果揭示了蒸馏作为无监督学习前的一个初始化步骤的潜力。这种方法在与传统预训练和纯知识转移相同的训练成本下提供了提高的性能。这表明知识蒸馏方法不仅提供了改进的初始性能,而且还允许在不增加额外训练费用的情况下进行额外的优化。

对结果的进一步检查强调了知识蒸馏对模型泛化的重大影响。的确,使用蒸馏在WikiText困惑度得分上的提高强调了本方法在增强模型用教师模型的知识对未见数据进行外推的能力方面的有效性。这为知识蒸馏在机器学习场景中的更广泛适用性和鲁棒性提供了宝贵的见解,特别是与传统的预训练策略相比。

表2表明,使用Hyena预训练的GPT模型通常具有与使用Hyena的Pythia 70M模型相似但略低的准确率。这些结果表明,使用Hyena的LLM通常能够像基于注意力的LLM模型一样表现良好,尽管基于Hyena的模型通常具有略低的测量性能。学生Pythia 70M JKT模型通常比预训练的基于注意力的Pythia 70M模型表现略差,尽管模型性能通常在相似的范围内,除了Sciq任务,学生模型的准确率明显低于GPT Hyena和教师模型。然而,在Arc挑战和Wsc任务中,Pythia 70M学生模型略微优于并显著优于其他两个模型。

结果表明,学生Hyena模型上的联合知识转移通常保留了其教师模型的语言能力,并且学生Hyena模型在某些情况下可以优于其教师模型。因为Hyena在直接比较时比注意力更有计算效率,并且因为联合知识转移可能比传统预训练更有计算效率,结果表明Hyena学生模型上的联合知识转移提供了一种计算效率高的替代传统基于注意力的LLMs预训练的方法。

论文链接:https://arxiv.org/abs/2401.17574

项目链接:

  • Pythia:本文中使用的模型实现之一。
  • The Pile:本文中用于训练的数据集之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2132290.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL5.7基于mysqldump、xtrbackup、innobackupex工具进行全量备份/恢复、增量备份/恢复

mysql全量备份脚本 文章目录 前言一、数据库备份分类二、为什么需要备份?三、备份工具示例1.逻辑备份工具1.1.使用场景1.2.备份操作示例1.3.恢复操作示例 2.物理备份工具2.1.xtrbackup介绍2.2.使用场景2.3.安装percona-xtrabackup2.4.xtrbackup备份原理2.5.percona-…

西门子PLC读取时间相差8小时

当前时间与PLC读取到的时间相差8小时,如下图所示 原因:指令问题 模块时间总是存储在 CPU 时钟中,而不带因子“本地时区”或“夏令时”。之后,CPU 时钟将基于模块时间计算 CPU 时钟的本地时间。 解决办法:将读取时间指…

leetcode hot100_part01_哈希

1.两数之和 遍历数组,map中存在target - nums[i]就返回结果,不存在就把当前元素存入map; 49.字母异位词分组 分组,怎么分,用hashMap, key为每一组的标识,value为每一组包含的字符串(属于同一组的&#xff…

【笔记】第一节. 引言

• 轨道用钢的加工过程 • 钢轨结构及其标准 • 轨道结构特点 • 钢轨的商业化及其发展趋势 轨道用钢的加工过程 钢轨形式及其标准 钢轨的基本结构 轨头、轨腰、轨底。 钢轨的技术标准 • 铁道行业标准《TB/T2344-2003:43~75 kg/m 热轧钢轨订货技术…

腾讯云Ubuntu系统安装宝塔,配置Java环境,运行spring boot项目

致谢 本次学习宝塔部署spring boot项目,参考如下资料 https://www.cnblogs.com/daen/p/15997872.html 系统安装宝塔 直接用的腾讯云云服务器面板上的登录,你可以换成 xshell 进入宝塔官网: https://www.bt.cn/new/download.html 我们采…

【Android】Handler用法及原理解析

文章目录 用处基本用法用法一:使用sendMessage和handleMessage方法用法二:使用post方法 法一工作原理Handler的sendMessageMessage成员变量 MessageQueueLooper主线程自动初始化子线程手动创建**prepare**loop Handler的dispatchMessage 法二工作原理Han…

机器学习中最常见的50个问题(进阶篇)

机器学习中最常见的50个问题 进阶篇 1.解释SVM的工作原理。 SVM,全称支持向量机(Support Vector Machine),是一种有监督学习算法,主要用于解决数据挖掘或模式识别领域中的数据分类问题。 SVM的工作原理是建立一个最…

TypeScript 扩展

扩展 ?:可选参数 可选链事实上并不是TypeScript独有的特性,它是ES11(ES2020)中增加的特性 可选链使用可选链操作符 ? 作用是当对象的属性不存在时,会短路,直接返回undefined,如果存在,那么…

小程序开发设计-小程序简介①

1.小程序与普通网页开发的区别: 1.运行环境不同: 网页运行在浏览器环境中。 小程序运行在微信环境中。 2.API不同: 由于运行环境不同,所以小程序中,无法调用DOM和BOM的API。但是,小程序中可以调用微信环境提…

摊牌了!一文教会你轻松上手豆包MarsCode 编程助手!

豆包MarsCode 编程助手是豆包旗下的 AI 编程助手,提供以智能代码补全为代表的 AI 功能。豆包MarsCode 编程助手支持主流的编程语言和 IDE,在开发过程中提供单行代码或整个函数的编写建议。此外,它还支持代码解释、单测生成和问题修复等功能&a…

收藏!6个PPT素材模板网站,快速做出好看的PPT

找PPT模板一定要收藏好这6个网站,能让你快速做出好看的PPT,重点十可以免费下载,赶紧收藏! 1、菜鸟图库 ppt模板免费下载|ppt背景图片 - 菜鸟图库 菜鸟图库网有非常丰富的免费素材,像设计类、办公类、自媒体类等素材都…

时序必读论文05|PatchTST : 时序数据Patch已成趋势【ICLR 2023】

书接上回,我们在之前的文章已经分析了直接把transformer应用到时间序列预测问题的不足,其中我们总结了4个不足:分别是: 注意力机制的计算复杂度高,为 O(N^2),并且计算得出的权重仅有少部分有用;…

【TCP三次握手+四次挥手(个人理解版本)】

TCP协议介绍 TCP(传输控制协议)是一种面向连接的、可靠的、基于字节流的传输层通信协议(它是全双工工作模式)。以下是对它的具体介绍: 基本概念 定义:TCP是Transmission Control Protocol的缩写&#xff…

PHP无缝对接预订无忧场馆预订系统小程序源码

无缝对接,预订无忧 —— 场馆预订系统,让每一次活动都完美启航! 一、告别繁琐流程,预订从未如此简单 你是否曾经为了预订一个合适的场馆而焦头烂额?繁琐的咨询、确认、支付流程,让人心力交瘁。但现在&…

如何利用Java进行快速的足球大小球及亚盘数据处理与分析

在当今信息爆炸的时代,大量的数据产生和积累,对于企业和个人来说,如何高效地处理和分析这些数据成为了一项重要的任务。Java作为一门强大的编程语言,提供了丰富的工具和库,可以帮助我们快速进行数据处理与分析。下面将…

vue3中实现拖拽排序(vue-draggable-next的使用)

1.安装插件 npm i vue-draggable-next 2.引入使用 <template> <vue-draggable-next v-model"list" tag"div" handle".warn-card" group"warngroup" ghost-class"ghost"class"mb10 warn-card-box" ani…

【mysql】逻辑运算符

逻辑运算符 逻辑运算符主要是为了判断表达式的真假,返回结果也是1,0,null OR 这里面或就是两个条件或的关系,比如我要department_id等于10和等于20的情况就可以使用或. SELECT last_name,salary,department_id FROM employees WHERE department_id10 OR department_id20 …

Unreal游戏初始化流程

前言 本文主要是总结Unreal在游戏启动时的初始化流程&#xff0c;包括讨论PIE和Standalone的区别&#xff0c;避免把一些初始化逻辑放在不合适的位置&#xff0c;比如我希望在所有Actor BeginPlay后执行某个逻辑&#xff0c;那我如果把它放在Subsystem的initialize中显然就会搞…

Golang使用ReverseProxy实现反向代理

目录 1.源码结构体 2.官方单机示例 3.使用示例 4.简单的http服务&#xff08;用于测试&#xff09; 1.源码结构体 type ReverseProxy struct {// Rewrite 必须是一个函数&#xff0c;用于将请求修改为要使用 Transport 发送的新请求。然后&#xff0c;其响应将原封不动地…

打造古风炫酷个人网页:用HTML和CSS3传递笔墨韵味

需要用到的背景大家可以自己找喜欢的风格!!! 当然俺把俺用的背景放到文章最后了哦&#xff01;&#xff01;&#xff01;&#xff01;&#xff01; 感谢关注和支持 长期更新哦~~~ 1. 简洁的页面布局&#xff1a;保持优雅和对称 在古风设计中&#xff0c;布局的对称性非常重要…