知识蒸馏(paper翻译)

news2024/10/7 12:27:11

paper:Distilling the Knowledge in a Neural Network

摘要:

提高几乎所有机器学习算法性能的一个非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均[3]。
不幸的是,使用整个模型集合进行预测非常麻烦,并且计算成本可能太高,无法部署到大量用户,尤其是在单个模型是大型神经网络的情况下。
Caruana 和他的合作者 [1] 已经证明,可以将集成中的知识压缩到单个模型中,该模型更容易部署,并且我们使用不同的压缩技术进一步开发了这种方法。
我们在 MNIST 上取得了一些令人惊讶的结果,并且表明我们可以通过将模型集合中的知识提炼为单个模型来显着改进频繁使用的商业系统的声学模型。
我们还引入了一种由一个或多个完整模型和许多专业模型组成的新型集成,这些模型学习区分完整模型混淆的细粒度类别。 与专家的混合不同,这些专业模型可以快速并行地进行训练。

Introduction

许多昆虫都有幼虫形态和完全不同的成虫形态,幼虫形态经过优化,可以从环境中获取能量和营养,而成虫形态则可以满足不同的旅行和繁殖要求。

在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同:对于语音和对象识别等任务,训练必须从非常大、高度冗余的数据集中提取结构,但它并不需要实时操作,因此可以使用大量的计算。

然而,部署到大量用户对延迟和计算资源有更严格的要求。 与昆虫的类比表明,如果可以更轻松地从数据中提取结构,我们应该愿意训练非常繁琐的模型(后面称为大模型)。
大模型可能是单独训练的模型的集合,也可能是使用非常强大的正则化器(例如 dropout)训练的单个非常大的模型[9]。

一旦繁琐的模型训练出来,我们就可以使用不同类型的训练,我们称之为“蒸馏”,将知识从繁琐的模型转移到更适合部署的小模型。 Rich Caruana 及其合作者已经率先提出了该策略的一个版本 [1]。 在他们的重要论文中,他们令人信服地证明,通过大型模型集合获得的知识可以转移到单个小型模型中。

可能阻止对这种非常有前途的方法进行更多研究的一个概念障碍是,我们倾向于使用学习到的参数值来识别经过训练的模型中的知识,这使得我们很难看到如何改变模型的形式但保持相同的知识。

知识的一个更抽象的观点是,它是从输入向量到输出向量的学习映射,将其从任何特定的实例化中解放出来。
对于学习区分大量类别的繁琐模型,正常的训练目标是最大化正确答案的平均对数概率,但学习的副作用是训练后的模型将概率分配给所有不正确的答案,即使这些概率非常小,其中一些也比其他概率大得多。

错误答案的相对概率告诉我们很多关于大模型如何泛化的信息。 例如,BMW的图像可能只有很小的机会被误认为是垃圾车,但这种错误的可能性仍然比将其误认为是胡萝卜高很多倍。

一般认为,用于训练的目标函数应尽可能地反映用户的真实目标。尽管如此,模型通常被训练为优化训练数据上的性能,而真正的目标是要对新数据具有良好的泛化能力。
显然,更好的做法是训练模型以便它们能够很好地泛化,但这需要关于正确泛化方式的信息,而这些信息通常是不可用的。

然而,当我们将大模型的知识提炼到小模型时,可以训练小模型与大型模型相同的方式进行泛化。
如果大模型泛化得好,例如,因为它是多个不同模型大型集合的平均,那么训练小模型以相同方式泛化,在测试数据上通常会比按照常规方式在同一个训练集上训练的小模型表现更好,训练集就是训练大模型的集合的。

将大模型的泛化能力转移到小模型的一个明显方法是使用大模型产生的class probability作为训练小模型的“soft targets”。
对于这个转移阶段,我们可以使用相同的训练集或单独的“转移”集。 当大模型是简单模型的大型集合时,我们可以使用它们各自的预测分布的算术或几何平均值作为soft targets。

当soft targets具有高熵时,它们在每个训练case中提供的信息比hard targets多得多,并且训练case之间梯度的方差要小得多,因此小模型可以用更少的数据,更大的learning rate进行训练。

对于像MNIST这样的任务,大模型几乎总以很高的置信度得出正确答案,大量关于学习function的信息寄存在soft targets中非常小概率的比率里。例如,一个版本中,2可能以10-6的概率被认为是3,10-9的概率被认为是7,而另一个版本可能恰好相反。这是有用的信息,因为它定义了数据的丰富的类似结构(即它指出哪些2看起来像3,哪些看起来像7),但在transfer阶段它对交叉熵损失函数的影响非常小,因为这些概率接近于零。

Caruana及其合作者通过使用logits(最后的softmax层的input)而不是用由softmax产生的概率作为学习小模型的target 来避开这个问题,并且他们最小化大模型和小模型产生的logits之间的平方差。更通用的解决方案,称为“蒸馏”,是将最后的softmax层的温度提高,直到大模型产出一套合适的soft target。然后训练小模型时用相同的高温,以匹配这些soft targets。我们稍后将展示,匹配大模型的logits实际上是蒸馏的一个特殊case。
这里的 “温度” 在后面的公式中体现

用于训练小模型的转移集可以完全由未标记数据组成[1],或者我们可以使用原始训练集。我们发现使用原始训练集效果很好,尤其是如果我们在目标函数中增加一个小项,鼓励小模型预测真实的target, 并且匹配由大模型提供的soft target。

通常,小模型无法完全匹配soft target,而在正确答案的方向上犯错被证明是有帮助的。

蒸馏

softmax的input称为logits, 用 z i z_{i} zi表示,
softmax的output称为概率,用 q i q_{i} qi表示。
神经网络通常用一个softmax层把logits转为概率,通过把 z i z_{i} zi与其他概率作比较。
在这里插入图片描述
公式里面的T就是上面说的蒸馏的温度。T通常是1. 更高的T产生更加soft的概率分布。

如何设置温度T?

在最简单的蒸馏形式中,准备一个transfer set数据集,它的label是大模型通过调高T产生的soft target,训练蒸馏模型时也要用同样的T,训练完成后T=1.
通过在transfet set上训练蒸馏模型,知识就被转移到了蒸馏模型。

同时使用label和soft target

当所有或部分transfer set的正确label已知时,还可以通过训练蒸馏模型来生成正确的标签来显着改进该方法。
一种方法是使用正确的label来修改soft target,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。
第一个目标函数是与soft target的交叉熵,并且该交叉熵是让蒸馏模型和产生soft target的 大模型用相同的温度T(softmax中)。softmax 中与用于从繁琐模型生成软目标相同的高温来计算的。
第二个目标函数是和正确label的交叉熵。 这是在蒸馏模型的 softmax 中还是用完全相同的logits计算,但T= 1。
我们发现,通常通过在第二个目标函数上使用相当低的权重来获得最佳结果。

由于soft target产生的梯度幅度 相当于缩放了 1/T2 ,因此在同时使用hard 和 soft targets时将其乘以 T 2 非常重要。 这确保了如果在元参数实验时用于蒸馏的温度T发生变化,hard和soft target的相对贡献保持大致不变。

Matching logits是蒸馏的一种特殊形式

PS: 前面introduction部分提到过,用softmax的input, 也叫logits, 代替softmax输出的概率作为学习小模型的target,来避开概率过小的问题,通过最小化大模型和小模型产生的logits之间的平方差。
现在说明这种方法为什么是蒸馏的一种形式。

transfer set中每个case都对蒸馏模型的每个logits z i z_{i} zi贡献出cross-entropy梯度 d C / d z i dC/dz_{i} dC/dzi.
如果大模型有logits v i v_{i} vi, 产生了soft target概率 p i p_{i} pi, 训练在温度T下完成.
那么梯度为:
在这里插入图片描述
如果温度比logits的幅度大,那么可以近似为:
在这里插入图片描述

假设每个transfer case的logits都是0均值的,即在这里插入图片描述
那么(3)可以简化为:

在这里插入图片描述

所以在温度T高时,如果logits对每个tranfer case都是0 均值,那么蒸馏等同于最小化 1 / 2 ( z i − v i ) 1/2(z_{i} - v_{i}) 1/2(zivi).
在T比较低时,蒸馏在matching logits上的attetion就少很多,因为它们比平均值负很多。
这是潜在的优势,因为这些logits几乎完全不受大模型的cost function的约束,因此它们可能非常noisy。
另一方面,非常负的logits可能会传达有关通过大模型获得的知识的有用信息。 这些影响中哪一个占主导地位是一个经验问题。 我们表明,当蒸馏模型太小而无法捕获大模型中的所有知识时,中间温度效果最好,这强烈表明忽略大的负logtis可能会有所帮助。

MNIST实验

为了了解蒸馏的效果如何,我们在所有 60,000 个训练案例上训练了一个大型神经网络,该神经网络具有两个隐藏层,每个隐藏层包含 1200 个校正线性隐藏单元。 该网络使用 dropout 和权重约束进行了强烈正则化,如 [5] 中所述。 Dropout 可以被视为训练共享权重的指数级大模型集合的一种方法。 此外,输入图像在任何方向上抖动最多两个像素。 该网络出现了 67 个测试错误,而具有两个隐藏层(由 800 个校正线性隐藏单元且无正则化)的较小网络出现了 146 个错误。 但是,如果仅通过添加在 20 ℃ 的温度下匹配大网络产生的软目标的附加任务来对较小的网络进行正则化,则它会出现 74 个测试错误。 这表明soft target可以将大量知识转移到蒸馏模型中,包括如何泛化从translated训练数据中学到的知识,即使转移集不包含任何translations。

当蒸馏网络的两个隐藏层中每个都有 300 个或更多units时,所有高于 8 的温度都会给出相当相似的结果。 但当这从根本上减少到每层 30 个units时,2.5 至 4 范围内的温度明显优于更高或更低的温度。

然后,我们尝试从传输集中省略数字 3 的所有示例。 所以从蒸馏模型的角度来看,3是一个它从未见过的神话数字。 尽管如此,蒸馏模型仅出现 206 个测试错误,其中 133 个位于测试集中的 1010 个三元组上。

大多数错误是由于3这个类别的学习bias太低而引起的。 如果此偏差增加 3.5(这会优化测试集的整体性能),则蒸馏模型会出现 109 个错误,其中 14 个错误位于 3 上。 因此,在正确的偏差下,尽管在训练期间从未见过 3,但蒸馏模型在测试 3 中的正确率达到 98.6%。 如果传输集仅包含训练集中的 7 和 8,则蒸馏模型的测试误差为 47.3%,但当 7 和 8 的偏差减少 7.6 以优化测试性能时,测试误差将降至 13.2%。

Discussion

我们已经证明,蒸馏对于将知识从集成或从大型高度正则化模型转移到较小的蒸馏模型非常有效。 在 MNIST 上,即使用于训练蒸馏模型的传输集缺少一个或多个类的任何示例,蒸馏也能表现得非常好。 对于 Android 语音搜索所使用的深度声学模型版本,我们已经证明,通过训练深度神经网络集合所实现的几乎所有改进都可以被提炼为相同大小的单个神经网络, 部署起来要容易得多。
对于非常大的神经网络,甚至训练一个完整的集合也是不可行的,但是我们已经证明,经过很长时间训练的单个非常大的网络的performance 可以通过学习大量的专家网络来显着提高 ,每个专家网络都学会区分高度混乱的集群中的类别(通过大量专家网络进一步区分类别,是帮助的性质,并不是蒸馏)。 我们还没有证明我们可以将专家的知识蒸馏回单一的大网络中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1423508.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

养老院|基于Springboot的养老院管理系统设计与实现(源码+数据库+文档)

养老院管理系统目录 目录 基于Springboot的养老院管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、老人信息管理 2、家属信息管理 3、公告类型管理 4、公告信息管理 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、最新计算机毕设选…

2023强网杯复现

强网先锋 SpeedUp 要求2的27次方的阶乘的逐位之和 在A244060 - OEIS 然后我们将4495662081进行sha256加密 就得到了flag flag{bbdee5c548fddfc76617c562952a3a3b03d423985c095521a8661d248fad3797} MISC easyfuzz 通过尝试输入字符串判断该程序对输入字符的验证规则为9…

计算机设计大赛 深度学习 python opencv 动物识别与检测

文章目录 0 前言1 深度学习实现动物识别与检测2 卷积神经网络2.1卷积层2.2 池化层2.3 激活函数2.4 全连接层2.5 使用tensorflow中keras模块实现卷积神经网络 3 YOLOV53.1 网络架构图3.2 输入端3.3 基准网络3.4 Neck网络3.5 Head输出层 4 数据集准备4.1 数据标注简介4.2 数据保存…

万物简单AIoT 端云一体实战案例学习 之 智能小车

学物联网,来万物简单IoT物联网!! 下图是本案的3步导学,每个步骤中实现的功能请参考图中的说明。 1、简介 1.1、背景 市面上各种遥控的小车很多,小车的性能不同具备的能力也不一样,大概实现的逻辑就是通过遥控器控制小车的前进、后退、左转或者右转。遥控小车具备一定…

精通Python第14篇—Pyecharts神奇妙笔,绘制多彩词云世界

文章目录 安装Pyecharts基本的词云图绘制自定义词云图样式多种词云图合并高级词云图定制与交互1. 添加背景图片2. 添加交互效果 使用自定义字体和颜色从文本文件生成词云图总结: 在数据可视化领域,词云图是一种极具表现力和趣味性的图表,能够…

IDEA 取消参数名称提示、IDEA如何去掉变量类型提醒

一、IDEA 取消参数名称显示 取消显示形参名提示 例如这样的提示信息 二、解决方法 1、File—>Setting–>Editor—>Inlay Hints—>Java 去掉 Show Parameter hints for 前面的勾即可,然后Apply—>Ok 2、右键Disable Hints

java 图书管理系统 spring boot项目

java 图书管理系统ssm框架 spring boot项目 功能有管理员模块:图书管理,读者管理,借阅管理,登录,修改密码 读者端:可查看图书信息,借阅记录,登录,修改密码 技术&#…

离散数学5

集合的基本概念 集合间的关系 特殊集合 集合的运算 以上都是高一学过的内容。 有穷集的计数&#xff08;容斥定理&#xff09; 序偶与集合的笛卡尔积 二元关系及其表示法 二元关系的性质 前件<x,y>,<y,z>后件<x,z>通过前件能推出后件&#xff0c;只有前真…

【51单片机系列】应用设计——8路抢答器的设计

51单片机应用——8路抢答器设计 文章设计文件及代码&#xff1a;资源链接。 文章目录 要求&#xff1a;设计思路软件设计仿真结果 要求&#xff1a; &#xff08;1&#xff09; 按下”开始“按键后才开始抢答&#xff0c;且抢答允许指示灯亮&#xff1b; &#xff08;2&…

空间域:空间组学的耶路撒冷

文章目录 环境配置与数据SquidpySpaGCN将基因表达和组织学整合到一个图上基因表达数据质控与预处理SpaGCN的超参优化空间域 参考文献 空间组学不能没有空间域&#xff0c;就如同蛋白质不能没有结构域。 摘要&#xff1a; 空间域是反映细胞在基因表达方面的相似性以及空间邻近性…

vulnhub靶场之Matrix-Breakout 2 Morpheus

一.环境搭建 1.靶场描述 This is the second in the Matrix-Breakout series, subtitled Morpheus:1. It’s themed as a throwback to the first Matrix movie. You play Trinity, trying to investigate a computer on the Nebuchadnezzar that Cypher has locked everyone…

微信小程序如何实现实时显示输入内容

如下所示&#xff0c;在许多场景中需要实时显示用户输入&#xff0c;具体实现见下文。 .wxml <input type"text" placeholder"请输入{{item.value}}(必填)" style"width:80%;" bindinput"get_required_value" data-info"{{it…

HarmonyOS应用开发者基础认证考试答案

HarmonyOS应用开发者基础认证考试答案 一、判断题 1.Ability是系统调度应用的最小单元&#xff0c;是能够完成一个独立功能的组件。一个应用可以包含一个或多个Ability。 正确(True) 2.所有使用Component修饰的自定义组件都支持onPageShow&#xff0c;onBackPress和onPageHide…

linux -- 中断管理 -- softirq机制

softirq的起始 do_IRQ();--> irq_enter(); //HARDIRQ部分的开始 更新系统中的一些统计量 标识出HARDIRQ上下文--> generic_irq_handler(); --> irq_exit(); //softirq部分的起始irq_exit /** Exit an interrupt context. Process softirqs if needed and possibl…

MOS栅极驱动和运放所需注意的关键参数

FD6288Q_&#xff08;JSMSEMI(杰盛微)&#xff09;FD6288Q中文资料_价格_PDF手册-立创电子商城 (szlcsc.com) MOS栅极驱动芯片&#xff1a; 自举电路&#xff1a; 电容的两个重要参数&#xff1a; ESR&#xff08;等效串联电阻&#xff09;和ESL&#xff08;等效串联电感&…

基于javaEE的社区食堂管理-计算机毕业设计源码48691

摘 要 随着餐饮业强劲发展的趋势&#xff0c;企业对食堂的管理也更加严格。面对材料成本的提高&#xff0c;人才资源匮乏&#xff0c;租金成本提高等问题&#xff0c;企业如何改善食堂管理系统将成为挑战。 一个高效便捷的食堂管理系统&#xff0c;能为食堂管理者带来极大的便利…

【HarmonyOS应用开发】ArkUI 开发框架-进阶篇-管理组件状态(九)

管理组件状态 一、概述 在应用中&#xff0c;界面通常都是动态的。下图所示&#xff0c;在子目标列表中&#xff0c;当用户点击目标一&#xff0c;目标一会呈现展开状态&#xff0c;再次点击目标一&#xff0c;目标一呈现收起状态。界面会根据不同的状态展示不一样的效果。 Ar…

XUbuntu22.04之如何创建、切换多个工作区(二百零九)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

1.31总结

为什么和以前标题不一样了呢&#xff0c;是因为今天我感觉学到的东西太少了&#xff0c;很难按专题发&#xff0c;索性就直接写个总结水一篇好了 第一题&#xff1a;遍历问题 题解&#xff1a;真的纯思维题目&#xff0c;真的没啥&#xff0c;可说的&#xff0c;中序遍历取决于…

双目模组 - IMSEE SDK的配置实践:含Opencv的详细编译配置

IMSEE 的环境要求: CMake(3.0以上)(需要支持vs2019) Visual Studio 2019 opencv3.3.1 IMSEE-SDK 官网参考: Windows 源码安装 — IMSEE SDK 1.4.2 文档 (imsee-sdk-docs.readthedocs.io) 【案】按照IMSEE的建议进行安装: 1 Windows 安装: 1.1 环境准备: 1.1.1 CMake:in…