英伟达玩转剪枝、蒸馏:把Llama 3.1 8B参数减半,性能同尺寸更强

news2024/9/21 2:44:01

小模型崛起了。

上个月,Meta 发布了 Llama 3.1 系列模型,其中包括 Meta 迄今为止最大的 405B 模型,以及两个较小的模型,参数量分别为 700 亿和 80 亿。

Llama 3.1 被认为是引领了开源新时代。然而,新一代的模型虽然性能强大,但部署时仍需要大量计算资源。

因此,业界出现了另一种趋势,即开发小型语言模型 (SLM),这种模型在许多语言任务中表现足够出色,部署起来也非常便宜。

最近,英伟达研究表明,结构化权重剪枝与知识蒸馏相结合,可以从初始较大的模型中逐步获得较小的语言模型。

图片

图灵奖得主、Meta 首席 AI 科学家 Yann LeCun 也点赞转帖了该研究。

经过剪枝和蒸馏,英伟达研究团队将 Llama 3.1 8B 提炼为 Llama-3.1-Minitron 4B 开源了出来。这是英伟达在 Llama 3.1 开源系列中的第一个作品。

Llama-3.1-Minitron 4B 的表现优于类似大小的最先进的开源模型,包括 Minitron 4B、Phi-2 2.7B、Gemma2 2.6B 和 Qwen2-1.5B。

图片

这项研究的相关论文早在上个月已经放出了。

图片

  • 论文链接:https://www.arxiv.org/pdf/2407.14679

  • 论文标题:Compact Language Models via Pruning and Knowledge Distillation

剪枝和蒸馏

剪枝使模型变得更小、更精简,可以通过删除层(深度剪枝)或删除神经元和注意力头以及嵌入通道(宽度剪枝)来实现。剪枝通常伴随着一定程度的再训练,以恢复准确率。

模型蒸馏是一种将知识从大型复杂模型(通常称为教师模型)迁移到较小、较简单的学生模型的技术。目标是创建一个更高效的模型,该模型保留了原始较大模型的大部分预测能力,同时运行速度更快且资源消耗更少。

蒸馏方式主要包括两种:SDG 微调与经典知识蒸馏,这两种蒸馏方式互补。本文主要关注经典知识蒸馏方法。

英伟达采用将剪枝与经典知识蒸馏相结合的方式来构造大模型,下图展示了单个模型的剪枝和蒸馏过程(上)以及模型剪枝和蒸馏的链条(下)。具体过程如下:

1. 英伟达从 15B 模型开始,评估每个组件(层、神经元、头和嵌入通道)的重要性,然后对模型进行排序和剪枝,使其达到目标大小:8B 模型。

2. 接着使用模型蒸馏进行了轻度再训练,原始模型作为老师,剪枝后的模型作为学生。

3. 训练结束后,以小模型(8B)为起点,剪枝和蒸馏为更小的 4B 模型。

图片

从 15B 模型进行剪枝与蒸馏的过程。

需要注意的点是,在对模型剪枝之前,需要先了解模型的哪部分是重要的。英伟达提出了一种基于激活的纯重要性评估策略,该策略可以同时计算所有相关维度(深度、神经元、头和嵌入通道)的信息,使用一个包含 1024 个样本的小型校准数据集,并且只需要前向传播。这种方法相比依赖梯度信息并需要反向传播的策略更加简单且具有成本效益。 

在剪枝过程中,你可以针对给定轴或轴组合在剪枝和重要性估计之间进行迭代交替。实证研究显示,使用单次重要性估计就足够了,迭代估计不会带来额外的好处。

利用经典知识蒸馏进行重新训练

下图 2 展示了蒸馏过程,其中 N 层学生模型(剪枝后的模型)是从 M 层教师模型中(原始未剪枝模型)蒸馏而来。学生模型通过最小化嵌入输出损失、logit 损失以及映射到学生块 S 和教师块 T 的 Transformer 编码器特定损失组合来学习。

图片

图 2:蒸馏训练损失。

剪枝和蒸馏最佳实践

英伟达基于紧凑语言模型中剪枝和知识蒸馏的广泛消融研究,将自己的学习成果总结为以下几种结构化压缩最佳实践。

一是调整大小。

  • 要训练一组 LLM,首先训练最大的一个,然后迭代地剪枝和蒸馏以获得较小的 LLM。

  • 如果使用多阶段训练策略来训练最大的模型,最好剪枝并对训练最后阶段获得的模型进行重新训练。

  • 对最接近目标大小的可用源模型进行剪枝。

二是剪枝。

  • 优先考虑宽度剪枝而不是深度剪枝,这对于 15B 参数规模以下的模型效果很好。

  • 使用单样本(single-shot)重要性估计,因为迭代重要性估计没有任何好处。

三是重新训练。

  • 仅使用蒸馏损失进行重新训练,而不是常规训练。

  • 当深度明显减少时,使用 logit、中间状态和嵌入蒸馏。

  • 当深度没有明显减少时,使用 logit-only 蒸馏。

Llama-3.1-Minitron:将最佳实践付诸应用

Meta 最近推出了功能强大的 Llama 3.1 开源模型系列,在许多基准测试中可与闭源模型相媲美。Llama 3.1 的参数范围从巨大的 405B 到 70B、8B。

凭借 Nemotron 蒸馏的经验,英伟达着手将 Llama 3.1 8B 模型蒸馏为更小、更高效的 4B 模型,采取以下措施:

  • 教师微调

  • Depth-only 剪枝

  • Width-only 剪枝

  • 准确率基准

  • 性能基准

教师微调

为了纠正模型训练所基于的原始数据集的分布偏差,英伟达首先在他们的数据集上(94B token)对未剪枝的 8B 模型进行了微调。实验表明,如果不纠正分布偏差,教师模型在蒸馏时会为数据集提供次优指导。

Depth-only 剪枝

为了从 8B 降到 4B,英伟达剪枝了 16 层(50%)。他们首先通过从模型中删除每个层或连续子层组来评估它们的重要性,并观察下游任务中 LM 损失的增加或准确率的降低。

下图 5 显示了删除 1、2、8 或 16 层后验证集上的 LM 损失值。例如,第 16 层的红色图表示如果删除前 16 层,则出现 LM 损失。第 17 层表示如果保留第一层并删除第 2 至第 17 层,也出现 LM 损失。英伟达观察到:开始和结束的层是最重要的。

图片

图 5:depth-only 剪枝中层的重要性。

然而,英伟达观察到,这种 LM 损失不一定与下游性能直接相关。

下图 6 显示了每个剪枝模型的 Winogrande 准确率,它表明最好删除第 16 到第 31 层,其中第 31 层是倒数第二层,剪枝模型的 5-shot 准确率明显高于随机准确率 (0.5)。英伟达采纳了这一见解,删除了第 16 到第 31 层。

图片

图 6:当删除 16 层时,在 Winogrande 任务上的准确率。

Width-only 剪枝

英伟达沿宽度轴剪枝了嵌入(隐藏)和 MLP 中间维,以压缩 Llama 3.1 8B。具体来说,他们使用前面描述的基于激活的策略来计算每个注意头、嵌入通道和 MLP 隐藏维度的重要性分数。

在重要性估计之后,英伟达选择

  • 将 MLP 中间维从 14336 剪枝到 9216。

  • 将隐藏大小从 4096 剪枝到 3072。

  • 重新训练注意头数量和层数。

值得一提的是,在单样本剪枝之后,宽度剪枝的 LM 损失高于深度剪枝。然而,经过短暂的重新训练后,趋势发生了逆转。

准确率基准

英伟达使用以下参数对模型进行蒸馏

  • 峰值学习率 = 1e-4

  • 最小学习率 = 1e-5

  • 40 步线性预热

  • 余弦衰减计划

  • 全局批量大小 = 1152

下表 1 显示了 Llama-3.1-Minitron 4B 模型变体(宽度剪枝和深度剪枝)与原始 Llama 3.1 8B 模型、其他类似大小的模型在跨多个领域的基准测试中的性能比较。总体而言,英伟达再次证实了宽度剪枝策略相较于遵循最佳实践的深度剪枝的有效性。

图片

表 1:Minitron 4B base 模型相较于类似规模 base 模型的准确率比较。

为了验证蒸馏后的模型是否可以成为强大的指令模型,英伟达使用 NeMo-Aligner 对 Llama-3.1-Minitron 4B 模型进行了微调。

他们使用了 Nemotron-4 340B 的训练数据,在 IFEval、MT-Bench、ChatRAG-Bench 和 Berkeley Function Calling Leaderboard (BFCL) 上进行了评估,以测试指令遵循、角色扮演、RAG 和函数调用功能。最后确认 Llama-3.1-Minitron 4B 模型可以成为可靠的指令模型,其表现优于其他基线 SLM。

图片

表 2:对齐 Minitron 4B base 模型与类似规模的对齐模型的准确率比较。

性能基准

英伟达利用 NVIDIA TensorRT-LLM(一种用于优化 LLM 推理的开源工具包)优化了 Llama 3.1 8B 和 Llama-3.1-Minitron 4B 模型。

下两张图显示了不同模型在不同用例下以 FP8 和 FP16 精度每秒的吞吐量请求,表示为 8B 模型的 batch size 为 32 的输入序列长度 / 输出序列长度 (ISL/OSL) 组合以及 4B 模型的 batch size 为 64 的输入序列长度 / 输出序列长度 (ISL/OSL) 组合,这要归功于在一块英伟达 H100 80GB GPU 上,较小的权重允许较大的 batch size。

Llama-3.1-Minitron-4B-Depth-Base 变体是最快的,平均吞吐量约为 Llama 3.1 8B 的 2.7 倍,而 Llama-3.1-Minitron-4B-Width-Base 变体的平均吞吐量约为 Llama 3.1 8B 的 1.8 倍。与 BF16 相比,在 FP8 中部署还可使这三种型号的性能提高约 1.3 倍。

图片

图片

图 8:组合:Llama 3.1 8B 为 BS=32,Llama-3.1-Minitron 4B 型号为 BS=64。1x H100 80GB GPU。

结论

剪枝和经典知识提炼是一种非常经济高效的方法,可以逐步获得更小尺寸的 LLM,与在所有领域从头开始训练相比,可实现更高的准确性。与合成数据式微调或从头开始预训练相比,这是一种更有效且数据效率更高的方法。

Llama-3.1-Minitron 4B 是英伟达首次尝试使用最先进的开源 Llama 3.1 系列完成的探索。要在 NVIDIA NeMo 中使用 Llama-3.1 的 SDG 微调,可参阅 GitHub 上的 /sdg-law-title-generation 部分。

有关更多信息,请参阅以下资源:

点击访问我的技术博客https://ai.weoknow.comicon-default.png?t=N7T8https://ai.weoknow.com

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2047710.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(javaweb)SpringBootWeb案例(毕业设计)案例--文件上传

1.简介 前端程序和服务端程序 对于前端 html文件放在static目录下 location---文件提交的位置 右键--copy value -------------c盘目录下 2.本地上传--文件存储 1. 2. 使用uuid:保证文件名是唯一的 此时 并没有文件的拓展名--所以需要---写后缀 用字符串截取 此时图…

Java、python、php版的宠物美容预约服务系统的设计与实现 (源码、调试、LW、开题、PPT)

💕💕作者:计算机源码社 💕💕个人简介:本人 八年开发经验,擅长Java、Python、PHP、.NET、Node.js、Android、微信小程序、爬虫、大数据、机器学习等,大家有这一块的问题可以一起交流&…

【报告】从 YCombinator 支持的 400 家(2023年和2024年) AI 初创公司看AI行业

这份报告对 YC 2023 年和 2024 年队列中的 417 家人工智能公司进行了广泛的分析。对于那些不知道的人来说,YCombinator是一个领先的初创企业加速器,提供种子资金、指导和资源,以帮助早期初创企业取得成功,YCombinator (YC)在发现和…

SOMEIP_ETS_044: echoUTF16DYNAMIC_with_odd_number_after_termination

测试目的: 验证设备(DUT)是否能够正确处理一个在终止符之后多出一个字节的echoUTF16DYNAMIC字符串,并且能够去除这个多余的字节。 描述 本测试用例旨在检查DUT在接收到一个不符合UTF16DYNAMIC字符串规范(即在终止符…

【Python机器学习】利用PCA来简化数据——PCA

PCA(主成分分析)的优缺点: 优点:降低数据的复杂性,识别最重要的多个特征; 缺点:不一定需要,且可能损失有用信息; 适用数据类型:数值型数据。 移动坐标轴 如下…

【研发日记】嵌入式处理器技能解锁(四)——TI C2000 DSP的Memory

文章目录 前言 背景介绍 Memory映射 RAM ROM 外设Register Memory分配 应用实例 总结 参考资料 前言 见《【研发日记】嵌入式处理器技能解锁(一)——多任务异步执行调度的三种方法》 见《【研发日记】嵌入式处理器技能解锁(二)——TI C2000 DSP的SCI(串口)通信》 见《…

在线excel/csv转json数据

具体请访问:在线Csv/Excel(xls/xlsx)转Json格式工具

编程语言进化史

编程语言多到你想象不到。 图片来自: 程序设计语言概念 发展历史 自从1946年冯诺依曼原理被提出,计算机数据和指令是通过二进制形式以及后来的汇编语言(二进制助记符),但依然没有改变容易出错的本质。1951年Rutishauser提出的用编译程序实现高级语言的思…

开放平台: 签名密钥、回调地址、ip白名单管理。

文章目录 引言I 渠道信息管理(签名密钥)表设计渠道信息管理服务商API配置导出II 签名校验兼容图片上传接口验签规则方案2III 工具类开放平台字典服务接口txt文件的下载see also引言 需求: 提供给下游的开放平台,需要对接口做签名密钥、回调地址、ip白名单管理。 涉及的功…

JS实现一键点击按钮复制文本

JS实现一键点击按钮复制文本 背景描述JS代码实现 背景描述 现在有这样一个需求,想要在页面实现点击按钮,一键复制指定列表字段内容的操作,就像这样的效果 复制成功之后的内容在Notepad 粘贴可以看到 正式列表中链接地址字段的内容&#xf…

【学习笔记】Day 15

一、进度概述 1、《地震勘探原理》第八、九章 二、详情 对于第八章,主要讨论地震资料岩性解释的基本方法,对于利用地震信息进行储层的物性预测于解释、储层的含油性分析与解释、地震地层学解释、层序地层学解释、地球物理资料综合解释等内容。 第五、六…

【图解秒杀系列】秒杀技术点——静态化

【图解秒杀系列】秒杀技术点——静态化 什么是静态化、静态化的作用如何实现静态化FreeMarker、Thymleaf处理流程问题 OpenResty Lualua_shared_dict & lua-resty-template处理流程具体操作 什么是静态化、静态化的作用 静态化就是指通过某种静态化技术,将原本…

【动态规划、dp】P1091 [NOIP2004 提高组] 合唱队形 题解

题意 n n n 位同学站成一排,音乐老师要请其中的 n − k n−k n−k 位同学出列,使得剩下的 k k k 位同学排成合唱队形。 合唱队形是指这样的一种队形:设 k k k 位同学从左到右依次编号为 1 , 2 , … , k 1,2, …,k 1,2,…,k,他…

Qt-创建第一个Qt项目(3)

目录 新建项目 设置路径 选择构建工具 父类的选择 各个父类的介绍 国际化相关的选项 选择SDK Summary选择 项目初见 新建项目 这一点和在VS里面是一样的,我们首先都得创建一个项目出来 进去之后就是选择项目模板了 我们使用默认的就行了,左边…

顺丰科技25届秋季校园招聘常见问题答疑及校招网申测评笔试题型分析SHL题库Verify测评

Q:顺丰科技2025届校园招聘面向对象是? A:2025届应届毕业生,毕业时间段为2024年10月1日至2025年9月30日(不满足以上毕业时间的同学可以关注顺丰科技社会招聘或实习生招聘)。 Q:我可以投递几个岗…

涉密载体管控系统DW-S402|实现载体管控新模式

涉密载体管控系统DW-S402是用于对各种涉密载体进行有效管理的智能柜(智能管理系统),基于物联网技术实现对载体的智能化、规范化、标准化管理,广泛应用于保密、机要单位以及企事业单位等有载体保管需求的行业。 载体管控软件对涉密…

可用性检查和短缺部件检查

可用性检查 可用性检查有两种类型: “库存管理”中库存类型的可用性检查(静态可用性检查)从“物料需求计划”的角度检查可用库存(动态可用性检查) 库存类型的可用性检查(静态可用性检查) 此项…

全面解析ETL:数据仓库架构中的关键处理过程

目录 一、数据仓库架构中的ETL 二、数据抽取 (1)逻辑抽取 (2)物理抽取 (3)变化数据捕获 三、数据转换 四、数据装载 (1)提高装载效率 (2)处理装载失败 五、ET…

MacOS 下运行 GPT-SoVITS

系统环境: # 安装 ffmpeg brew install ffmpeg # 查看版本 ffmpeg -version # 拉取项目代码 git clone --depth1 https://github.com/RVC-Boss/GPT-SoVITS cd GPT-SoVITS # 安装好 Miniconda 之后,先创建一个虚拟环境: conda create -n GPT…

关于Qt的系统总结

查看详情http://100bcw.com/qt6.htm 编译环境与开发流程 开发QT有两种IDE可以使用,一种是使用 VS + Qt 的插件,另一种就是使用QtCreator工具。前一种是微软的工具,用的都比较多容易上手,缺点是信号槽的支持不太好,需要手写,不能自动生成,另外可能有中文编码的问题。后一…