高精度压缩Transformer,NNI剪枝一站式指南

news2024/11/18 14:51:22

无论在学术界还是产业界,今年人工智能大模型都是爆款话题。但面对这些动不动就数十亿级别参数的模型,使用传统方法微调,宛如水中捞月、海底捞针。作为微软亚洲研究院为科研人员和算法工程师量身定制的一站式 AutoML(自动机器学习)工具, NNI(Neural Network Intelligence)在过去的三年间不断迭代更新,加强了对各种分布式训练环境的支持,成为了最热门的 AutoML 开源项目之一。

近日,微软亚洲研究院对 NNI 进行了更新。在最新的版本中,NNI 集成了大量前沿的剪枝算法,如 TaylorFO Weight、Movement 等。基于现有的经典预训练模型,研究员们通过大量实验,发现了既能降低模型参数量和计算量,又能保持模型较高精度的剪枝步骤与算法组合,获得超越 SOTA 的模型剪枝效果。

今天我们就以 Transformer 系列的预训练模型和数据集 GLUE-MNLI 为例,为大家介绍一下 NNI 的 pruner 剪枝流程和使用的剪枝算法组合。

剪枝流程

在正式介绍剪枝流程前,我们需要先了解什么是 pruner,mask 和 SpeedUp。

  • pruner:使用具体的剪枝算法实例化的剪枝器。

  • mask:在剪枝过程中,pruner 会生成一个和目标子模块大小相同的 mask(全1)矩阵,并在 mask 矩阵中将目标子模块中需要剪掉的部分的对应位置置为0。最后通过将目标子模块和对应的 mask 矩阵相乘,即可得到模拟剪枝后的模型效果。

  • SpeedUp:从上述描述可以看出,在剪枝过程中,实际上只是将需要剪枝的部分用0进行了替换,因此使用 SpeedUp 模块是修剪上述目标子模块中需要剪掉的参数,而不是用0替代,从而实现真正意义上的减少参数量。

在使用 NNI Compression 模块中的 pruner 进行剪枝操作时,用户只需完成数据/模型等的准备、pruner 的构建,以及模型剪枝和再训练,即可为模型构建一个剪枝的 pipeline。

以 Transformer 系列的预训练模型为例,其剪枝流程共包含4步:首先准备数据/模型等,接着针对多头自注意力机制(Multi-head Attention)、嵌入层(embedding)和前馈神经网络(FFN)分别剪枝和再训练模型。

图1:Transformer 系列模型的剪枝流程示意图

1. 准备数据/模型等

在正式构建剪枝过程之前,用户需要加载预训练模型,对数据预处理并创建相应的 dataloader,同时设计相应的训练/评估函数,以用于后期对模型的训练和评估。其流程如图2所示,共包含5步:

图2:数据/模型准备过程的流程示意图

具体来说,首先需要从 Transformers 库中加载预训练模型,然后对数据 GLUE-MNLI 进行处理,并得到相应的 dataloader。随后,针对模型和数据集 GLUE-MNLI,构建相应的训练/评估函数。最后将模型在 GLUE-MNLI 数据集上进行微调。

完成以上步骤就相当于完成了数据/模型等的准备工作,可以得到预训练模型在 MNLI 数据集上微调后的模型。考虑到 Transformer 系列预训练模型的模型参数中的大头为嵌入层,且编码层/解码层中包含了多头自注意力机制和前馈神经网络。因此,在之后的步骤中需要分别对多头自注意力机制、嵌入层和前馈神经网络剪枝,并引入动态蒸馏机制对剪枝后的模型再训练。

2. 多头自注意力机制的剪枝和基于动态蒸馏机制的模型再训练

多头自注意力模块的剪枝和模型再训练分为3步,如图3所示:首先要构建 pruner,接着对多头自注意力模块进行剪枝,最后使用动态蒸馏机制再训练模型。

图3:多头自注意力机制的剪枝和再训练流程示意图

在进行剪枝前,用户需要选定一个剪枝算法并实例化相应的 pruner。所有的剪枝算法均需向模型中传入 config_list 参数,因为其定义了需要剪枝的运算名、运算类别及稀疏度等。具体到 Movement 剪枝算法,还需要设置其他的一些参数,如:evaluator 参数,用于训练感知的模型压缩过程;movement_mode 参数,共有“soft“和”hard“两种模式,若为”soft”,则难以精确地控制模型剪枝后的稀疏度,但是可以得到性能更好的模型。参数 regular_scale 用于控制剪枝的稀疏度,regular_scale 越大,模型剪枝后的稀疏度越高。更多其他参数可参阅https://nni.readthedocs.io/zh/stable/reference/compression/pruner.html#movement-pruner

接下来,要使用构造的剪枝算法实例 pruner 对多头自注意力模块进行剪枝。用户只需调用 pruner.compress() 即可执行对模型的剪枝过程,并得到剪枝后的模型和 attention_mask。其中 attention_mask 给出了需要剪枝的子模块的参数剪枝范围,0代表该位置被剪掉,1代表该位置被保留。

NNI 的 SpeedUp 模块可以将被 mask 住的参数和计算从模型中删除,具体的删除逻辑如图4所示,以 Query Linear 层的 weight(记作Q)为例,其维度为[768,768],那么 Q 的 weight 的 mask 矩阵维度也为[768, 768],将其记作 mask。首先将该 mask 矩阵的维度进行变换,第一维是多头数目8,其余的则是第二维,将变换后的 mask 矩阵记作 reshaped mask 矩阵。接着,对 reshaped mask 矩阵在第二维度上求和,并判断求和后的值是否为0,此时的 mask 矩阵维度变为[8],每个位置对应着一个多头。对于变换后的 mask 矩阵,若位置 i 的值为0,则代表在 Q 中的第 i 个多头需要被剪掉。在图中,位置0、3、7的值均为0,因此,在Q中的第0、3、7个多头需要被剪掉。最后,将[0,3,7]作为参数传入 prune_heads 函数中,对 Q 进行修剪。修剪后,Q 的维度为[576,768]。对 SpeedUp 更加全面的介绍可以参考发表于 OSDI 2022 的论文 SparTA。在即将发布的 NNI 3.0 中 SpeedUp 会对更多模型提供更加完善的支持。

图4:利用 prune_heads 函数修剪自注意力模块的过程示意图

在对多头自注意力模块剪枝后,以微调后的模型作为教师模型,以剪枝后的模型作为学生模型,然后借鉴 CoFi 中的动态蒸馏机制 [1] 对模型进行再训练,就可以得到新的模型。这里的动态蒸馏机制,是指教师模型的层和学生模型的层之间不是一个静态对应关系,每次蒸馏教师都可以选择从自身的高层动态蒸馏信息到学生模型低层中的一层里。

3. 嵌入层和前馈神经网络的剪枝,以及基于动态蒸馏机制的模型再训练

嵌入层和前馈神经网络的剪枝过程与多头自注意力模块的剪枝过程类似。此处使用 Taylor 剪枝算法 (https://nni.readthedocs.io/zh/stable/reference/compression/pruner.html#taylor-fo-weight-pruner ) 对嵌入层和前馈神经网络进行剪枝。同样地,研究员们定义了 config_list、evaluator 参数及 taylor_pruner_steps 参数。由于嵌入层的维度与后续模型中的维度具有相关性。因此,基于上述参数,在嵌入层的剪枝过程中研究员们将剪枝模式 mode 设置为了“dependency-aware”模式,并传入模型的输入 dummy_input,以帮助 pruner 捕捉和嵌入层维度具有依赖关系的子模型。

接下来,使用分别构造的 pruner 对前馈神经网络和嵌入层进行剪枝。和多头自注意力模块的剪枝不同的是,此处使用了迭代式剪枝法,即在模型基于动态蒸馏的再训练过程中,每2000步分别使用 pruner 对前馈神经网络和嵌入层剪枝一次,其中,前馈神经网络共剪枝19/24次,嵌入层共剪枝3次。每次剪枝后,使用 ModelSpeedUp 对前馈神经网络层进行剪枝,以实现真正意义上的修剪参数,而不是将需要修剪的参数用0替换。

实验结果

通过调整 regular_scale 参数的值和前馈神经网络的剪枝次数,研究员们得到了具有不同稀疏度和性能的模型。该过程使用了1张 A100 进行实验,并设置 batch_size 为32。

图5:实验结果

从上图实验结果可以看出:

  1. 随着 regular_scale 的增加,模型总的稀疏度有所增加。当 regular_scale 大于等于10时,模型总的稀疏度超过了69%,性能损失超过1%。

  1. 随着前馈神经网络剪枝次数的增加,模型总的稀疏度有所增加,同时模型的性能有所下降,且随着模型总稀疏度的增加,模型的性能下降程度逐渐增大。

  1. 对嵌入层剪枝3次,能够将模型的维度从768减小至561,在一定程度上提升了模型总的稀疏度。

实验结果与平台对比

进一步分析实验结果可以发现,使用 NNI 对 BERT 在 MNLI 数据集上剪枝后的性能好于 nn pruning 框架(图6(a)),且当模型总的稀疏度低于65%时,NNI 和 CoFi 对 BERT 在 MNLI 数据集上剪枝的性能差距较小,当模型总的稀疏度大于65%时,使用 NNI 对 BERT 在 MNLI 数据集上剪枝后的性能好于 CoFi。图6(b)和图6(c)分别展示了 NNI 在 T5 和 ViT 模型上的剪枝性能。从图中可以看出,当模型相应部分的稀疏度超过了75%后,模型性能下降约为3%,当模型相应部分的稀疏度低于50%时,模型性能下降较少。

(a)

(b)

(c)

图6:NNI 在经典预训练模型下的剪枝性能示意图

三个平台(Paper)的详细比较结果,如表1所示。可以看出,NNI 的 Compression 模块不仅具有完整的教程实例,同时还提供了 SpeedUp 模块,能够实现真正意义上的减少模型参数量,而非将需要修剪的参数置为0。

同时,NNI 支持 BERT、RoBerta、GPT、BART、T5、ViT 等主流模型,并提供了 Taylor、Movement、ADMM、Slim、AGP、Activation APoZ、Activation Mean 等16种前沿剪枝算法,能够更好地满足用户的需求,具有较强的通用性。

表1:各平台(Paper)功能对比总结

展望未来

在 NNI 3.0 版本中,微软亚洲研究院的研究员们还将引入蒸馏模块,更好地为用户提供集剪枝、蒸馏为一体的压缩工具,同时 SpeedUp 模块也将更全面地支持对 Transformer 的修剪。敬请期待!

关于最新版 NNI 的完整代码和 tutorial,请参见:

https://nni.readthedocs.io/zh/stable/tutorials/pruning_bert_glue.html

NNI 快速入门视频教程:https://space.bilibili.com/1649051673

参考论文:

[1] https://arxiv.org/pdf/2204.00408.pdf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/192032.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vertical-align属性值区分

先简单看一下官方文档上的Vertical-align这些属性值的描述,虽然只有简单的描述,但其实描述的很清楚,但有时只看文字并不能很好的理解其含义。 下面结合代码图说明: 首先,我建造了一个背景颜色为绿色的div盒子&#xf…

9、app稳定性测试之monkey工具使用

简介 方法:利用Monkey工具,选择某些场景做持续反复操作,以衡量系统的稳定性 工具: monkey 友盟埋点 日志分析工具 系统监控工具GT 场景: * 随机测试 可以用monkey模拟 * 多个运行中app切换测试; * 各种事件打扰,如插拔数据线、电话打扰、收…

软件测试之冒烟测试须知

冒烟测试的介入时间? 开发编码完成,自测通过以后为最佳介入时间。 如果开发无自测直接提交,一般冒烟测试通过率会很低【除非你遇到的是大内高手】 什么需求需要做冒烟测试? 理论上,所有的需求均可以做冒烟测试。 冒烟测试需要做几轮? 一轮…

分享会上狂吹MySQL的4大索引结构,没想到大家的鉴赏能力如此的~~~~

索引(index)是帮助MySQL高效获取数据的数据结构(有序)。在数据之外,数据库系统还维护着满足 特定查找算法的数据结构,这些数据结构以某种方式引用(指向)数据, 这样就可以在这些数据结构 上实现高…

企业需要做哪些准备,来落地商业智能 BI 系统

随着新一代信息化、数字化技术的应用,引发了新一轮的科技革命,现代化社会和数字化的联系越来越紧密,数据也变成继土地、劳动力、资本、技术之后的第五大生产要素,这一切都表明世界已经找准未来方向,前沿科技也与落地并…

中国电子学会2022年09月份青少年软件编程Scratch图形化等级考试试卷四级真题(含答案)

2022-09 Scratch四级真题 分数:100 题数:29 测试时长:60min 一、单选题(共15题,共30分) 1.运行下列程序,说法正确的是?(D) A.列表中的数字全部小于11 B.列表的长度为10 C.变量…

项目管理:甘特图的作用是什么?

在我们工作和学习中,有一个提高工作效率,简单又实用的神器——甘特图。 甘特图以表格进度条,展示任务列表和时间表示出项目的持续时间及进度。并根据实际执行时间和工期对计划进行动态调整的进度控制方法。 甘特图将各个任务的完成情况在时间…

论Unity_InputSystem如何使用(三)

PlayerInput InputSystem提供专门用来处理玩家输入的组件,通过关联配置输入文件,可以不需要编写设备输入的相关逻辑,专注于编写输入触发后的逻辑。 如何添加 创建一个Cube,点击Add Component,搜索Player Input即可添…

【C语言 数据机构】时间复杂度与空间复杂度

文章目录时间复杂度空间复杂度时间复杂度 判断一个算法所编程序运行时间的多少,并不是将程序编写出来,通过在计算机上运行所消耗的时间来度量。原因很简单,一方面,解决一个问题的算法可能有很多种,一一实现的工作量无疑…

cocoapods安装失败到成功的记录贴

mac系统版本:10.15.5 (19F101) 最优解安装顺序:Xcode > HomeBrew > RVM > Ruby > CocoaPods 1. 安装方案1(百度常用法) 1.1 更新gems和换国产源: RubyGems 镜像 - Ruby Chinahttps://gems.ruby-china.co…

使用VBA获取电脑MAC地址

实例需求:如何使用VBA读取电脑的MAC地址,包含有线网卡和无线网卡。 这个需求看似有些无厘头,为嘛要用VBA来读取MAC地址,存在的就是合理的。例如使用MAC地址和其他硬件信息可以生成电脑的唯一识别号,用于软件注册和实现…

Vue Node

Vue配置代理服务器 一、运行后台服务 启动后台Node服务器,运行后台程序,学习资料node代码,服务5000开启 FeHelper - Awesome 二、Ajax请求 xhr 【不常用】Windows 内部 new XMLHttpRequest()xhr.open() xhr.send()内部公司封装xhr开源封装…

力扣 2325. 解密消息

题目 给你字符串 key 和 message ,分别表示一个加密密钥和一段加密消息。解密 message 的步骤如下: 使用 key 中 26 个英文小写字母第一次出现的顺序作为替换表中的字母 顺序 。 将替换表与普通英文字母表对齐,形成对照表。 按照对照表 替换…

OAuth2 01

目录 1.什么是OAuth 2.OAuth2中的角色 3.认证流程 4.生活中的OAuth2思维 5.令牌的特点 6.OAuth2的授权方式 6.1 OAuth2授权码 6.2 隐藏方式 6.3 密码方式 6.4 凭证方式 1.什么是OAuth2 1.OAuth2.0介绍 OAuth(Open Authorization)是一个关于授权&…

Android 抓包相关 SSL相关

https无法明文抓包 Android P版本开始强制App使用Https协议,否则访问崩溃如下所示错误: java.lang.ClassCastException: com.android.okhttp.internal.huc.HttpURLConnectionImpl cannot be cast to javax.net.ssl.HttpsURLConnection可参阅&#xff…

C 语言零基础入门教程(二十三)

C 可变参数 有时,您可能会碰到这样的情况,您希望函数带有可变数量的参数,而不是预定义数量的参数。C 语言为这种情况提供了一个解决方案,它允许您定义一个函数,能根据具体的需求接受可变数量的参数。下面的实例演示了…

Centos8中安装配置php

一、问题描述Centos8中我们在使用Apache部署配置网站的时候,发现Apache服务已经正常启动且网站也配置完成到Apache主目录中,但是访问时网站却不能正常运行【即:只能够以列表的方式列出所有网站的资源文件,而不是以网页的形式展现】…

关于荧光素76863-28-0,FITC-5-thiosemicarbazide,荧光素-5-氨基硫脲 相关知识分享

荧光素-5-氨基硫脲,Fluorescein-5-thiosemicarbazide,FITC-5-thiosemicarbazide荧光素-5-氨基硫脲是一种含胺的荧光探针,可用于标记糖和蛋白质羰基衍生物Product specifications:1.CAS No:76863-28-02.Molecular formu…

超越OCR的富文档内容解析神器LayoutParser

论文题目:《A unified toolkit for Deep Learning Based Document Image Analysis》 论文链接:https://arxiv.org/abs/2103.15348 论文官方网站:https://layout-parser.github.io/ 论文开源项目:https://github.com/Layout-Par…

Ubuntu 18.04安装配置MySQL数据库

文章目录1. 安装MySQL数据库2. 配置MySQL数据库3. 远程访问设置4. Navicat连接MySQL数据库1. 安装MySQL数据库 这里可以通过包管理工具apt安装MySQL数据库,在ubuntu18.04下mysql版本默认为5.7。 安装命令如下: sudo apt-get install mysql-server安装…