基于模型蒸馏的模型加速方案总结

news2024/11/25 14:46:00

1.简介

1.1目的

在过去的一段时间里,对基于模型蒸馏技术的模型加速方案的方法在多个数据集上进行了一系列的实验。所谓的模型蒸馏技术,简单的来说就是利用一个设计简单的小网络去学习一个设计比较复杂的大网络。特别的有,本次实验针对每一个复杂的(teacher model)大模型,都设计了多个简单的(student model)小模型去学习,并且针对不同的超参数组合,本文给出了多组对比实验。详细的实验结果以及相应实验现象的分析和总结将在下文给出。

1.2范围

本文档描述的代码修改以及实验方法都是基于caffe框架进行的,添加的新层有SoftmaxWithLossWithSoftTargetLabel和SoftmaxWithLossWithLabelSmooth。主要的加速策略是利用参数少的(student model)小模型去学习参数多的(teacher model)大模型,所有的实验都在mnist数据集、cifar10数据集、以及年龄和性别属性相关数据集进行的,其中需要调节的超参数包括温度超参数T,loss比例超参数LAMDA。

1.3定义、首字母缩写词和缩略语

序号

术语或缩略语

说明性定义

1

2

3

4

5

6

7

1.4 参考资料

《基于模型蒸馏技术的模型加速方案实验设计v2.pdf》

Distill the Knowledge in a Neural Network

2.实验的方法——模型蒸馏

        本文档中的实验是基于caffe框架进行的,修改了其中的源码,并添加了新的层,使得这个框架可以按照制定的模型加速方案进行运行和测试。

2.1 为什么需要模型蒸馏

一个很大的DNN往往训练出来的效果会比较好,并且多个DNN一起ensemble的话效果会更加的好,但是当用在实际的应用中的话,过于庞大的DNN ensemble在一起会增大计算量,从而影响应用。于是一个问题就被提出了:有没有一个方法,能使降低网络的规模,但是保持(一定程度上)精确度呢?

Hinton举了一个仿生学的例子,就是昆虫在幼生期的时候往往都是一样的,适于它们从环境中摄取能量和营养;然而当它们成长到成熟期,会基于不同的环境或者身份,变成另外一种形态以适应这种环境。那么对于DNN是不是存在类似的方法?在一开始training的过程中比较的庞杂但是后来当需要拿去deploy的时候,可以转换成一个更小的模型。他把这种方法叫做Knowledge Distillation(KD)。

2.2模型蒸馏的基本原理

这里的distillation方法其实主要用的就是通过一个performance非常好的大网络(有可能是ensemble的)来教一个小网络进行学习。这里我们可以把大网络叫为:teacher network,小网络叫为:student network。至于为什么是希望通过大网络来教小网络而不是直接利用ground truth label来学习,hinton也给了一个例子:比如说在MNIST数据集中,有两个数字“2”,但是写法是不一样的:一个可能写的比较像3(后面多出了一点头),一个写的比较像7(出的头特别的短)。在这样的情况下,ground truth label都是“2”,然而一个学习的很好的大网络会给label “3” 和 “7” 都有一定的概率值,如图1所示。通常叫这种信息为 “soft targets”;相对的,ground truth label 是一种 “hard target” 因为它是one-hot label。总的来说就是,通过大网络的“soft targets”,能得到更加多的信息来更好的训练小网络。

图1 hard target vs. soft target

论文中所提出的上述soft target实际上就是已经训练好的复杂模型的softmax层的输出概率,而其中所提出的“蒸馏”方法在softmax层中引入了一个”温度”参数T,如公式(1)所示:

qi=expzi/Tjexpzj/T

(1)

      其中zi 表示的是logit,即softmax层的输入;qi 表示经过softmax层计算后的每个类别的概率;T 表示的就是上述的温度参数,通常设置为1。不过通过上述温度参数的调整,softmax层的映射曲线更加平缓,因而实例的概率映射将更为集中,便使得目标更加地"soft"。并且有论文中还指出,当transfer set中的标签可得时,将soft target和实际标签的两个目标共同使用作为目标函数将使得其性能更加提高。在训练过程中,作者将迁移样本集中样例输入原复杂模型并通过上述蒸馏softmax得到soft target,并将其作为目标,并在迭代过程中更新温度,训练出细粒度的模型。

蒸馏”最简单的形式就是:以从复杂模型得到的“软目标”为目标(这时T比较大),用“转化”训练集训练小模型。训练小模型时T不变仍然较大,训练完之后T改为1。 

当“转化”训练集中部分或者所有数据都有标签时,这种方式可以通过一起训练模型使得模型得到正确的标签来大大提升效果。一种实现方法是用正确标签来修正“软目标”,但是论文中发现一种更好的方法是:对两个目标函数设置权重系数。第一个目标函数是“软目标”的交叉熵,这个交叉熵用开始的那个比较大的T来计算。第二个目标函数是正确标签的交叉熵,这个交叉熵用小模型softmax层的logits来计算且T等于1。论文中指出当第二个目标函数权重较低时可以得到最好的结果。整体的结构如图2所示:

图2 模型蒸馏的整体结构

2.3为什么使用soft target会有用

图3 soft target的用处

信息量:

hard target 包含的信息量(信息熵)很低,soft target包含的信息量大,拥有不同类之间关系的信息(比如同时分类驴和马的时候,尽管某张图片是马,但是soft target就不会像hard target 那样只有马的index处的值为1,其余为0,而是在驴的部分也会有概率。)

软化:

问题是像图3左侧的红色0.001这部分,在cross entropy的loss function中对于权重的更新贡献微乎其微,这样就起不到作用。把soft target软化(整体除以一个数值后再softmax),就可以达到右侧绿色的0.1这个数值,这样在后来权重的更新中就有一定的贡献了。

3. 模型蒸馏实验设计

3.1 蒸馏模型训练过程

实验步骤:

1.根据提出的目标问题设计一个或者多个复杂的网络结构(N1,N2,…,Nt)。

2.收集足够多的训练数据,按照常规CNN模型训练流程,训练好1中的一个或者多个复杂网络得到(M1,M2,…,Mt),记为原始网络。

3.收集简单模型训练数据,此处的训练数据可以是训练原始网络的有标签数据,也可以是额外的无标签数据。

4.修改原始模型(M1,M2,…,Mt)的softmax层中温度参数T为一个较大值如T=20,将3中收集到的样本输入到原始复杂模型中。每一个样本在每个原始模型可以得到其最终的分类概率向量,选取其中概率至最大即为该模型对于当前样本的判定结果。对于t个原始模型就可以得到t个概率向量。那么对这t个概率向量求取均值作为当前样本最后的概率输出向量,记为soft_target label,最后保存到文件中。

5.根据(N1,N2,…,Nt)重新创建一个精简的小网络N0,该网络最后有两个loss,一个是hard loss,即传统的softmaxloss,使用one shot label;另外一个是soft loss,即T>1的softmaxloss,使用我们第4步保存下来的soft target label。

6.设置精简的小网络N0的softmax层温度参数与原始复杂模型产生soft target label时所采用的温度一致,如T=20,按照常规模型训练精简的小网络得到模型M0。

7.训练完成之后,在实际应用中将精简的小模型中的softmax温度参数重置为1,即采用最原始的softmax,来走前向作为最后输出的小模型。

上述的训练过程可以用图4简单表示:

图4 模型蒸馏的训练过程

3.2 论文中的经验

论文中作者认为,由于soft target具有更高的熵,它能比hard target提供更加多的信息,因此可以使用较少的数据以及较大的学习率。将hard和soft的target通过加权平均来作为学生网络的目标函数,soft target所占的权重更大一些。 论文中作者同时还指出,T值取一个中间值时,效果更好,而soft target所分配的权重应该为T^2,hard target的权重为1。 这样训练得到的小模型也就具有与复杂模型近似的性能效果,但是复杂度和计算量却要小很多。 
对于distilling而言,复杂模型的作用事实上是为了提高label包含的信息量。通过这种方法,可以把模型压缩到一个非常小的规模。模型压缩对模型的准确率没有造成太大影响,而且还可以应付部分信息缺失的情况。

4.实验结果及其分析

针对上述方法修改完毕之后的caffe框架,分别在mnist数据集,cifar10数据集以及年龄和性别属性识别数据集上分别对不同的温度超参数T以及loss比例超参数lamda做了多组对比实验。其中T的取值为0.5,1,2,3,5,10,20,lamda的取值为0.01,0.1,0.3。具体的实验结果以及实验分析如下所示。

4.1 mnist

图5 mnist上模型蒸馏的实验结果

    从图5中可以看出,teacher model一共有四层参数层,即两个卷积层以及两个全连接层,尺寸为20_50_500_10,其中的数字表示的为caffe中prototxt中每一层的num_output的大小。其精度可以达到很高的精度0.9914;student model这里设计了三组对比实验,相应的尺寸分别为10_25_250_10、4_10_100_10、2_5_50_10。

从图中可以看出:

  • 在设置特定的T和lamda的超参时,train_smallnet_from_kd的实验结果都要优于train_smallnet_from_scratch的实验结果,前者表示的是模型蒸馏的结果,后者表示从头训练小模型的结果,即图中的红色部分和small net的baseline进行对比。
  • 特别的有模型大小为10_25_250_10即尺寸为原始大模型的一半的时候,当lamda设置为0.3,T设置为3的时候,小模型经过对大模型的学习是可以达到大模型的精度的。
  • 第三对于当前的实验对于mnist数据集可以看出,最优值的基本上是在T超参数设置为3左右的时候出现的。

    综上所述,当大模型的精度很高的时候,模型蒸馏(知识提取)的效果可以达到很好,并且小模型经过学习是能够达到大模型的那种效果的。

4.2 cifar10

图6 cifar10数据集上模型蒸馏的实验结果

从图6中可以看出,teacher model一共有四层参数层,即两个卷积层以及两个全连接层,尺寸为32_32_64_10,其中的数字表示的为caffe中prototxt中每一层的num_output的大小。其精度不高只有0.7937;student model这里设计了三组对比实验,相应的尺寸分别为16_16_32_10、8_8_16_10、4_4_8_10。

从图中可以看出:

  • 在设置特定的T和lamda的超参时,train_smallnet_from_kd的实验结果部分优于train_smallnet_from_scratch的实验结果。这个和mnist的实验结果有点差距。
  • 所有的train_smallnet_from_kd的实验结果都达不到最初的大模型的效果。
  • 对于当前的实验对于mnist数据集可以看出,最优值的基本上是在T超参数设置为1、2左右。

综上所述,当大模型的精度不高的时候,在特定的小模型尺寸、温度参数T以及lamda参数设置后的模型蒸馏(知识提取)也可以达到一定效果,但是最终达不到原始大模型的精度。

   

4.3 年龄和性别属性

图7 年龄和性别属性数据集上模型蒸馏的实验结果

从图7中可以看出,teacher model是一个具有12个卷积层CaffeNetConv网络,年龄和性别属性的精度都挺高分别为0.912161和0.98991。而student model分成了三种。第一种为 具有6个卷积层的CaffeNetConv网络,仅对年龄的属性进行模型蒸馏(知识提取);第二种还是具有12个卷积层的CaffeNetConv,但是其每层的num_output的大小减半,仅对年龄的属性进行模型蒸馏(知识提取);第三种的网络结构和第二种的网络结构一致,并且对年龄和性别两个属性同时进行模型蒸馏(知识提取)。其中12layers_half表示的是只对age进行模型蒸馏,12layers_half_both表示的是对age和gender同时进行模型蒸馏。

从图中可以看出:

  • 在设置特定的T和lamda的超参时,train_smallnet_from_kd的结果都要优于train_smallnet_from_scratch的结果。这个和mnist的实验结论一样。
  • 6层小网络模型在经过多组对比实验中都不能达到原始12层大网络的效果,然而在保持原始深度12层,将参数量减半的小网络在经过模型蒸馏(知识提取)后,却可以达到甚至超过原始12层大网络的精度。比如,当lamda设置为0.3,T设置为3的时候的第二种小网络以及lamda设置为0.1,T设置为3的时候的第三种小网络。
  • 对于当前的实验对于年龄以及性别属性数据集可以看出,最优值的基本上是在T超参数设置为3左右。
  • 从第二种小网络的实验结果可以看出,随着相应T和lamda超参数的设定,会使得年龄属性的精度上升,但是相反会导致未进行模型蒸馏的另一个分类任务精度的下降。
  • 从第三种小网络的实验结果可以看出,将年龄和性别的分类任务都进行模型蒸馏,可以解决上一个问题,性别精度的都有所提高,但是有可能会使得年龄的精度有些许下降(很少)。

原本的年龄属性识别网络中就已经引入了label smooth的思想,这个和模型蒸馏(知识提取)的思想很类似,所以本身年龄属性识别模型蒸馏的效果可能会被弱化。

5.总结

   在经过三个数据集上对模型蒸馏(知识提取)的方法进行实验,都表明模型蒸馏方法的有效性。当原始模型精度很高的时候,模型蒸馏的效果往往都会很好,并且在特定的模型T和lamda超参数的组合下,小的student model可以很好的学习到大的teacher model,甚至会超过原始大的网络的精度。相反,当原始teacher model的精度就不高,如cifar10中的实验效果一样,模型蒸馏的效果要差些,可能达不到原始teacher model的精度,甚至差距还挺大。特别的在属性数据集上的实验中可以看出,同样参数量的两种student网络,保持深度缩小宽度的小网络要比缩小深度保持宽度的小网络模型蒸馏的效果会更好。特别的有当原始的teacher model是一个多分类任务的时候,我们的实验表明如果仅对一个任务进行模型蒸馏,会使得其他分类任务的精度下降,而对多个分类任务都进行模型蒸馏的话,虽然没有单独模型蒸馏的效果那么好,但是所有模型的精度都会上升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1870753.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机图形学入门18:阴影映射

1.前言 前面几篇关于光栅化的文章中介绍了如何计算物体表面的光照,但是着色并不会进行阴影的计算,阴影需要单独进行处理,目前最常用的阴影计算技术之一就是Shadow Mapping技术,也就是俗称的阴影映射技术。 2.阴影映射 Shadow Map…

压缩pdf文件大小的方法,如何压缩pdf格式的大小

pdf太大怎么压缩?当你需要通过电子邮件发送一个PDF文件,却发现文件太大无法成功发出时,这些情况下,我们都需要找到一种方法来压缩PDF文件,以便更便捷地进行分享和传输。PDF文件的大小通常与其中包含的图片、图形和文本…

【智能算法应用】麻雀搜索算法在物流配送中心选址的应用(无待选配送中心)

目录 1.算法原理2.数学模型3.结果展示4.参考文献5.代码获取 1.算法原理 【智能算法】麻雀搜索算法(SSA)原理及实现 2.数学模型 模型假设 待定物流配送中心的库存总能满足需求点的需求不考虑从工厂到待定物流配送中心的运输成本不考虑选定区域内待确定…

openinstall拥抱鸿蒙生态,SDK全面适配HarmonyOS NEXT

作为国内领先的App渠道统计与深度链接服务商,openinstall持续推动鸿蒙生态建设,近日正式发布openinstall HarmonyOS SDK,并成功入驻鸿蒙生态伙伴SDK专区,成为华为鸿蒙生态的合作伙伴,为鸿蒙应用开发者带来安全合规、高…

可用的搜索引擎

presearchhttps://presearch.com/yandexhttps://ya.ru

MySQL简介:开源数据库的基石(一)

目录 引言:数据库领域的革新者 一、MySQL的发展历程:从开源先锋到行业领袖 二、MySQL的核心特性:性能、安全与灵活性并重 三、MySQL的应用场景:从Web开发到企业级应用的全面覆盖 四、MySQL在开源数据库中的地位:开…

从爬取到分析:Faraday爬取Amazon音频后的数据处理

什么是Faraday? Faraday是一个简单、灵活的高级爬虫框架,支持多种编程语言。它提供了一套丰富的API,允许开发者快速构建复杂的爬虫应用。Faraday的主要特点包括: 模块化设计:易于扩展和自定义。多语言支持&#xff1…

如何寻找强势货币和弱势货币?

外汇交易的独特之处在于,它融合了两种货币的价值,其中一种货币的价值通过另一种货币来体现。举例来说,USDJPY外汇反映了美元与日元之间的价值关系,而EURUSD则代表了欧元与美元的价值对比。 通过开仓操作,我们预测一种…

ai assistant 是所有编程助手中最出色的一款 ?

ai assistant激活成功后,如图 ai assistant渠道:https://web.52shizhan.cn/activity/ai-assistant 在去年五月份的 Google I/O 2023 上,Google 为 Android Studio 推出了 Studio Bot 功能,使用了谷歌编码基础模型 Codey,Codey 是…

nacos漏洞汇总

1 nacos介绍 1.1 nacos是啥 Alibaba Nacos是阿里巴巴推出来的一个新开源项目,是一个更易于构建云原生应用的动态服务发现、配置管理和服务管理平台。致力于帮助发现、配置和管理微服务。Nacos提供了一组简单易用的特性集,可以快速实现动态服务发现、服…

【Java Web】Vite构建前端目录结构

目录 一、Vite概述 二、Vite构建Vue3工程化项目 三、ViteVue3项目目录结构 四、ViteVue3项目组件(SFC入门) 五、ViteVue3样式导入方式 六、ViteVue3响应式数据和setup语法糖 一、Vite概述 Vite是一种新型前端构建工具,能够显著提升前端开发体验;Vite结合…

PVE 8.2.2安装OpenWrt 23.05.3

1,下载官方openwrt 23.5.3镜像并解压 2,进入pve上传镜像 复制这段文字之后需要使用 创建虚拟机 删除磁盘 安装完毕后 shell 运行 qm importdisk 100 /var/lib/vz/template/iso/openwrt-23.05.3-x86-64-generic-ext4-combined-efi.img local-lvm 其中100是虚拟…

音频处理3_频率分贝

1. 频率分类 频率 (Frequency): 指声波振动的速度,通常以赫兹(Hz)为单位。频率决定了声音的高低,即音高。 高低频:高频声音听起来尖锐或高亢,而低频声音则沉闷或低沉。 听觉范围: 人类通常是20 Hz到20,000 Hz(20 kHz)。在这个范围内,声音的频率越高,听到的音调就越高…

基于ssh框架的个人博客源码

基于ssh的个人博客源码,页面清爽简洁,原先有部分bug,运行不了,现已修复 1.博客首页 (本地访问地址 :localhost:8080/Blog/index/index) 2.关于我 3.慢生活 4.留言板 5.我的相册 微信扫码下载源码

科林Linux7_网络爬虫

一、爬虫 网络资源的下载工具,工作与万维网环境,持续获取网页网站中的网络信息。可持续的数据采集机器人 1、搜索引擎技术使用爬虫 2、数据分析、数据挖掘领域,需要爬虫进行数据准备 3、数据批处理、采集,大量获取某些网站中的…

测试用例设计方法-判定表法

一、简介 测试用例设计的重要性和挑战 在软件开发过程中,测试用例的设计是确保软件质量的关键步骤之一。测试用例是描述如何验证软件功能是否按照预期工作的详细指导。它们不仅帮助发现潜在的缺陷和问题,还能确保软件在不同情况下的稳定性和可靠性。以下…

redis集群简单介绍及其搭建过程

Redis集群 1、哨兵模式 哨兵可以有多个,从服务器也可以有多个,从服务器也可以有多个,在Redis3.0以前的版本要实现集群一般是借助哨兵sentinel工具来监控master节点的状态,如果master节点异常,则会实现主从切换&#x…

Potato(土豆)一款轻量级的开源文本标注工具(二)

示例项目(模版) Potato 旨在提高数据标注的可复制性,并降低研究人员设置新标注任务的成本。因此,Potato 提供了一系列预定义的示例项目,并欢迎公众向项目中心贡献。如果您使用 Potato 进行了自己的标注工作&#xff0…

RK3568开发笔记(三):瑞芯微RK3588芯片介绍,入手开发板的核心板介绍

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/139905873 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…

临时挂载字体文件工具

一、简介 1、FontLoader是一款专为字体管理和快速加载设计的工具,它能够在不占用系统资源的情况下,实现字体的临时加载和快速切换。用户可以将字体文件存放在系统之外的硬盘分区,并通过FontLoader直接从内存中加载这些字体,从而避…