LLM-Pruner: 剪枝+少量数据+少量训练 = 高效的LLM压缩

news2024/11/22 10:28:39

a488a185bfb99bd4180767338d9eefe2.gif

ebf59e31f229735aba23c2376c184f39.png

概要

大语言模型(LLMs, Large Language Models)在各种任务上展现出了惊人的能力,这些能力很大程度上来自于模型庞大的模型规模以及海量的训练语料。为了应对这些模型部署上存在的挑战,许多研究者开始关注大语言模型的轻量化问题。由于 LLM 模型庞大的参数量,我们希望能以最小的成本完成对模型的压缩,尽可能的减少压缩后训练的开销,实现高效的模型压缩。

因为,我们提出了一种基于自动化的结构化剪枝的方案,采用结构化剪枝可以尽最大可能保留模型已习得的知识,实现高效的大语言模型压缩。在实验中,我们仅使用了 5 万条训练语料以及单张 4090 24GB 显卡,就可以在 3 小时内完成 LLaMA,Vicuna 和 ChatGLM 等大语言模型的剪枝和训练。

fa1b1e22749d6f6cf17e995c6c7d8364.png

论文标题:

LLM-Pruner: On the Structural Pruning of Large Language Models

论文作者:

Xinyin Ma, Gongfan Fang, Xinchao Wang

单位:

新加坡国立大学

论文地址:

https://arxiv.org/abs/2305.11627

Github地址:

https://github.com/horseee/LLM-Pruner

1e92f2b6520152d55b0a124852a18ab5.png

大语言模型压缩面临的问题

首先,大语言模型的压缩与之前的语言模型(例如 BERT,RoBERTa 等)的压缩等有什么差异呢。这需要从模型/数据/任务三个角度来分析,

  • 模型规模:第一个主要差异来自 LLM 参数量规模远超之前的语言模型,这导致许多依赖重新训练的压缩方案,例如知识蒸馏,重新训练的开销较大。

  • 海量训练语料:许多 LLMs 经历了 1 万亿甚至更大规模的 tokens 上的训练,这导致许多依赖于原始数据或收集替代数据的方案变得尤其昂贵。

  • 任务无关的模型压缩:现有的压缩算法通常针对单一、特定的任务进行压缩,而 LLMs 是很优秀的多任务处理器,在压缩过程中我们不希望折损 LLM 的通用性和多功能性。

上述三种问题实际上对应了三个基本要求:

  • 降低训练规模:压缩算法需要尽可能少的依赖大规模的重新训练。

  • 减少数据依赖:压缩训练所需要的数据需要尽可能少

  • 保留模型原始能力:压缩模型过程中需要保留一定的原始模型能力。

由此,我们需要一种能够避免大规模重新训练、且能保持模型原有能力的压缩方法。现有的较为可行的两种方案是模型量化和结构化剪枝,他们的最主要特点都在于不需要完全从零开始训练。其中模型量化侧重于降低推理阶段的存储开销以及提升计算速度,而结构化剪枝则直接移除部分参数实现压缩,两种方案可以相互结合达到最优性能。本文主要介绍基于结构化剪枝的 LLM 压缩方法,通过保留 LLM 中的重要参数,降低压缩训练成本。

2618dee1ed4f4ddebdbef1fa0003b994.png

165e67e5c7f2d4c974eac8ff0b2a0109.png

方法

本文提出一种简洁高效的结构化剪枝方法,遵循了经典的“重要性估计-剪枝-微调”的策略,能够在有限资源下完成大语言模型的压缩:

  • 重要性估计:在 LLM 中存在着复杂的子结构,剪枝算法需要对网络的参数进行分组,找到可移除的最小单元(通常称为 Group),并对各个组进行重要性评估。不同于之前的研究,我们解决了依赖图自动的挖掘神经网络中存在依存关系的子结构,也因此我们的方法更易拓展到更多的大规模语言模型。

  • 剪枝:在得到各个分组的重要性后,我们将冗余的组整个移除,从而降低模型的参数量。

  • 微调:在剪枝后的模型上应用 LoRA 等高效微调策略,恢复模型性能。

3.1 参数组重要性分析

首先我们分析了 LLM 模型中参数之间的依赖关系,从而找到最小可移除的分组。这些依赖可以归纳为三类,下图展示了 LLaMA 模型中存在的层耦合关系,主要由 MLP 内部的耦合、MHA 内部的耦合(主要为 QKVO 四个映射层)与整个网络中的通道耦合。这些耦合导致虚线显示的参数、神经元需要被同时移除,从而确保剪枝后结构的正确性。

8f3f73253afc21b59608474005463ab5.png

为了自动化上述依赖分析,我们需要对神经元的依赖进行形式化建模。假设 和 是模型中的两个神经元, 和 表示 被指向和指向的所有神经元。结构之间的依赖性可以定义为:

4083bddcdc65cf743639e15bd462baba.png

其中 表示神经元的入度。注意这里的依赖关系并非是前向计算流向上的依赖关系,而是神经元结构上的依赖关系。这种依赖性是方向性的,我们因此可以得到另一种依赖性:

5567864462d489e1a4a6bc3a7989f57e.png

其中 表示神经元的出度。依赖性在网络剪枝中的重要性在于,如果当前神经元(例如 )仅依赖于另一个神经元(例如 ),且神经元 被剪枝,那么神经元 也必须被剪枝。此时如果不裁剪 ,那么网络结构中的维度就会不匹配。基于上述规则,我们可以实现一个自动识别参数依赖的程序,帮助我们剪枝各种 LLM 模型。同时,我们注意到 LLM 的最小可移除子结构通常包含了多个层,这就要求我们设计一种针对整个组的重要性估计策略。

3.2 组重要性估计

目前,我们已经对模型内部的耦合结构进行了建模。为了找到冗余分组,我们需要对将整个组作为一个最小单元进行重要性估计。鉴于对训练数据集的访问权限有限,我们探索使用可用的公共数据集或手动创建的样本作为替代资源。尽管这些数据集的领域可能与训练集不完全一致,但它们仍提供了评估结构组重要性的宝贵信息。

a)权重向量的一阶重要性估计:给定一个数据集 ,其中 N 是样本数量。在我们的实验中,我们设置 N 等于 10,也就是仅需 10 个样本。一个组(如前所述,被定义为一组耦合结构)包含一组互相耦合的参数 ,其中 M 是一个组中耦合结构的数量, 是组内第i层的权重。

在修剪时,我们的目标是移除对模型预测影响最小的结构,因此,参数的重要性可以通过损失函数的剪枝后扰动来判断。我们对损失函数进行泰勒展开并进行移项,可以得到一种简单的重要性估计指标 ,这类重要性评估策略需要我们尽可能精准的对损失函数进行建模,因此我们采用了二阶的泰勒展开:

c651c4514a2032d5fa297d5d0da7fa87.png

其中, 表示 Hessian 矩阵, 表示 Next-token Prediction Loss。通常来说,重要性的一阶项由于模型在训练数据集已完全收敛,此项通常为 0,即 。然而,由于我们选取的数据 并非来自原始数据集,重要性评估中我们通常可以得到非 0 的一阶项,即 (实验中发现我们的数据上的 Loss 高于充分收敛样本上的 Loss)。

重要性估计的第二项要求我们对 Hessian 矩阵 进行计算。然而,由于 Hessian 矩阵的计算复杂度过高,这对参数了巨大的 LLM 而言是不现实的。因此,我们不直接对参数计算 Hessian 矩阵,而是仅考虑估计 hessian 矩阵的对角线元素,这需要引入单个参数标量的重要性评估。

b)单个参数的二阶重要性估计:上述过程对整个权重向量 的进行了估计。实际上,我们可以在更细粒度上得到另一种重要性度量,其中 参数矩阵内的每个数值都被独立地执行重要性估计。

4d2032e3eb2417e3951ddf7592bbdae3.png

这里, 用于索引 中的第 k 个参数。我们使用 Fisher Information Matrix 来近似 Hessian 矩阵的对角线 ,重要性可以被定义为:

e35f893b48b97aaf2f151cca5147f5fd.png

3.3 组重要性聚合

在获得权重向量重要性 以及单个参数的重要性 后,我们还需要将其进行聚合,得到整个分组的重要性。本文提出四种重要性聚合的策略:

  • 求和 Sum: 或者 

  • 乘积 Prod: 或者 

  • 极值 Max: 或者 ;

  • 门控 Last-Only: 或者 

上述聚合方式实际上包含了不同的偏置,例如求和策略认为不同参数的贡献是独立且可叠加的,乘积策略则假设不同层的重要性会相互影响,最大值策略的偏置在于层的重要性由某一层主导。最后,门控策略则认为组内的最后一层主导了整个组的重要性,因为通过将该层设为 0 我们可以使得整个组不再参与网络预测。在评估每个组的重要性之后,我们根据预定义的修剪比例对每个组的重要性进行排序,并修剪重要性较低的组。

3.4 剪枝模型的低秩近似快速恢复

为了加速模型的压缩后训练,同时提高有限数据下的训练效率,我们需要降低优化的参数量。为了实现这一目标,我们将 LoRA 与剪枝模型相结合。模型中的每个可学习的权重矩阵,表示为 ,包含 LLM 中所有剪枝和未剪枝的线性投影。 的更新值 可以被分解为 ,其中 and 。前向计算现在表达为:

820000ecffe056d0d9570996efe01bb7.png

其中 b 是稠密层中的偏置。仅训练 P 和 Q 减少了整体训练复杂性,从而减少了对大规模训练数据的需求。此外,额外的参数 P 和 Q 可以被重新参数化为 ,这不会在最终压缩模型中造成额外的参数。

3.5 算法总结

上述过程给出了 LLM 模型的依赖分析、重要性评估以及后训练的完整方案,通过聚合一阶和二阶泰勒展开,我们可以得到更加鲁棒的重要性评估策略。对于大模型而言,重要性评估指标是尤其重要的,因为剪枝造成的性能损失越大,后训练恢复所需要的数据量、训练时间也就越多

5085565da76a0585bb338500f867ef4a.png

实验

本文对三种开源的 LLM 进行剪枝实验,包括 LLaMA-7B,Vicuna-7B 和 ChatGLM-6B,剪枝前后的模型参数量、MACs 和内存占用如表所示。

eb502ce2f9bcf9ee6f4ee46c1aaebf8a.png

4.1 LLaMA-7B剪枝后模型的Zero-shot能力验证

对于大模型而言,保留原始的多任务处理以及零样本能力是尤其重要的,我们完整评测了不同多种方案,包括权重剪枝、随机剪枝以及本文方法的剪枝效果。我们发现在剪枝 20% 参数的情况下,大模型依旧能保持一定的 zero-shot 能力,同时经过少量微调,zero-shot 性能可以快速提升,约达到基座模型的 94%。

a93b1a1bb9d4d0e9d9240923f67b5a96.png

4.2 Vicuna-7B剪枝后模型的Zero-shot能力验证

ed7758265e05b227050f2d887f28cb37.png

4.3 ChatGLM-6B剪枝后模型的Zero-shot能力验证

dcf0b2c232fbf2d35cb7fcf02dda4a5a.png

4.4 剪枝后模型的生成结果

79962223c5b509c42621058e183e7745.png

更多分析结果请参考论文。

59d36c0360f3ad134a8511b23f3ffbe5.png

总结

在本文中,我们提出了 LLM-Pruner,一种用于大型语言模型的结构化剪枝方法。LLM-Pruner 旨在以任务无关的方式压缩庞大的语言模型,同时尽量减少对原始训练语料库的依赖,并保留 LLM 的语言能力。LLM-Pruner 通过迭代地检查模型中的每个神经元作为识别依赖组的触发器,从而构建 LLM 的依赖图。随后,LLM-Pruner 使用参数级和权重级估计来评估这些组的重要性。

最后,我们利用 LoRA 对被剪枝模型进行快速恢复和调整。我们使用多个 zero-shot 数据集评估了 LLM-Pruner 在三个不同模型(LLaMA,Vicuna 和 ChatGLM)上的有效性。我们的实验结果表明,LLM-Pruner 成功地剪枝了模型,在保留 zero-shot 能力的同时减轻了计算负担。

欢迎试用我们的项目:

https://github.com/horseee/LLM-Pruner

更多阅读

f54db103c29dc53746a4db8be60df07f.png

8a35b0a2cc0d5c94b2fdadc1244a4163.png

333a1861624b6c1b6381f5616a716a87.png

ed60ad066ef98ea21e60b8925b7e2f1e.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

e4d033600eb2b139c17b440a8be1b0c7.png

△长按添加PaperWeekly小编

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

·

e3d99d47acc61ab35de0ea09d6cb8d2a.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/618648.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华为认证 | HCIE-存储 V3.0 即将发布!

华为认证HCIE-Storage V3.0(中文版)预计将于2023年6月30日正式对外发布。为了帮助您做好学习、培训和考试计划,现进行预发布通知,请您关注。 01 发布概述 基于“平台生态”战略,围绕“云-管-端”协同的新ICT技术架构&…

5周年更新 | OpenVINO™  2023.0,让AI部署和加速更容易

时光匆匆,岁月荏苒,OpenVINO™迎来了5岁生日。5岁,对于OpenVINO™来说还是个很年轻的年纪,一如正在茁壮成长的少年,每天都迸发着无穷的生命力。 在这5年里,OpenVINO™密切关注市场需求,着眼未来…

JavaScript拖动元素在一个范围内移动

基于 jQuery移动范围由 div 搭建(div 模仿表格),卡片的移动不允许超出该范围移动卡片会有一个淡蓝色卡片的标记出将要放置的位置有禁止放置标记的位置,不允许卡片放置(会放到前一个可放置的位置)卡片放置会覆盖单元格中的文字卡片…

TSS半导体放电管八大属性总结

​之前在写关于GDT放电管与TSS放电管之间的差异时,其实有谈到TSS(固体放电管)它拥有的一些特性,今天优恩小编还是想重复一下,希望更多小伙伴能够记住。 TSS,有人叫它固体放电管、也有人叫它半导体放电管&am…

智能交通车路协同系统的应用场景和发展趋势

随着城市化进程的加速和汽车保有量的增加,城市交通拥堵、交通事故等问题日益突出。为了解决这些问题,智能交通车路协同系统应运而生。智能交通车路协同系统是一种基于车载终端、路侧设备和交通管理中心等多个组成部分构成的智能交通系统,可以…

io之netty

写在前面 netty当前是网络io框架的事实标准,基于nio实现,框架的作者是韩国一位姓李的朋友,开始我们这位行李的韩国朋友开发一个io框架mina,但后来其离职,mina也就和其没有关系了,所以后来其改进了mina的不…

Maxcompute数据上云一致性比对

我写过很多如何去对数、如何批量对数的技术文档,最近项目遇到这个问题,我才发现在官方博客上还没有发布过这个课题的文章。这就像灯下黑,太长用到的知识点,反而没有意识到其重要性。 注:这里对数的场景就是指在阿里云…

docker 装机/卸载 Mysql

1、首先,需要安装Docker。可以使用以下命令安装: > yum install docker 2、安装完成后,启动Docker服务: > systemctl start docker3、CentOS7环境下的Docker使用 docker快速部署mysql数据库并初始化 docker快速部署mysq…

Power BI API调用注意事项 (By Power Automate)

注:本文最初发布于https://d-bi.gitee.io和medium, 2023年6月迁移至CSDN 前述 本站关于实现Power BI REST API的博文已有许多,包括: Power BI REST API有多强大?PBI开发者必读Power BI REST API实战教程:PowerQuery为…

基于SSM的便利店系统

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取项目下载方式🍅 一、项目背景介绍: 基于SSH的便利店系统是…

Java中方法的重载与重写

文章目录 前言方法重载方法重写 前言 提示:这里可以添加本文要记录的大概内容: 方法的重载与重写容易混,所以单独拿出来比较 提示:以下是本篇文章正文内容,下面案例可供参考 方法重载 在同一个类中,允…

springcloud-alibaba (06)RocketMQ控制台安装与启动

RocketMQ控制台 ✨让你的消息传输更高效✨ 如果你是一名开发者,或者是对消息传输有需求的企业用户,那么你肯定不陌生于 RocketMQ,它是一个高可用、高可靠、高性能、分布式消息中间件。但是有时候,在 Windows 上安装和启动 Rocke…

生产环境可用的 Seata-go 1.2.0 来啦!!!

文|刘月财(GitHub ID:luky116) 360 服务端开发专家 Seata-go 项目负责人 本文 2752 字 阅读 7 分钟 发布概览 Seata-go 1.2.0 版本支持 XA 模式。XA 协议是由 X/Open 组织提出的分布式事务处理规范,其优点是对业务代码无侵入。当前…

小巧长续航的主动降噪耳机,更轻更好用,QCY ArcBuds上手

我平时听歌、玩游戏的时候喜欢戴上一副蓝牙耳机,这种耳机选择很多,这几年进步还很快,市面上有很多价格合理、音质出色的选择。我目前用的是一款QCY ArcBuds,这款耳机支持主动降噪,户外使用体验不错,而且它做…

Dockerfile实现LNMP

systemctl stop firewalld systemctl disable firewalld setenforce 0 docker network create --subnet172.18.0.0/16 --opt "com.docker.network.bridge.name""docker1" mynetwork #部署nginx(容器IP 为 172.18.0.10) mkdir /o…

一文讲解 基于C++手写Rpc项目

目录 github 预备知识 集群和分布式 单机聊天服务器 集群聊天服务器 分布式聊天服务器 从集群式 到 分布式聊天服务器 看来只有好处 ,但代价是什么? rpc 的 通信原理 remote procedure call 分布式通信 手写的rpc部分 protobuf>json 好处? 介绍protobuf protob…

RabbitMQ - 死信队列

RabbitMQ - 死信队列 死信的概念死信的来源死信实战死信之TTl死信之最大长度死信之消息被拒 死信的概念 先从概念解释上搞清楚这个定义,死信,顾名思义就是无法被消费的消息,字面意思可以这样理 解,一般来说,producer …

【进程间通信:管道】

目录 1 进程间通信介绍 1.1 进程间通信目的 1.2 进程间通信发展 1.3 进程间通信分类 2 管道 2.1 什么是管道 2.2 匿名管道 2.2.1 匿名管道的使用 2.2.2 使用匿名管道创建进程池 2.3 管道读写规则 2.4 匿名管道特点 2.5 命名管道 2.5.1 概念 2.5.2 使用 1 进程间通…

Learning C++ No.28 【C++11语法实战】

引言: 北京时间:2023/6/5/9:25,今天8点45分起床,一种怎么都睡不够的感觉,特别是周末,但是如果按照我以前的睡觉时间来看,妥妥的是多睡了好久好久,并且昨天也睡了一天,哈…

C#,码海拾贝(32)——计算“实对称三对角阵的全部特征值与特征向量的”之C#源代码

using System; namespace Zhou.CSharp.Algorithm { /// <summary> /// 矩阵类 /// 作者&#xff1a;周长发 /// 改进&#xff1a;深度混淆 /// https://blog.csdn.net/beijinghorn /// </summary> public partial class Matrix {…