知识蒸馏(Knowledge Distillation)简述

news2024/11/25 23:53:28

知识蒸馏(Knowledge Distillation)简述

  • 结论

Reference:

  1. Distilling the Knowledge in a Neural Network
  2. 知识蒸馏(Knowledge Distillation)简述(一)

知识蒸馏被广泛用于模型压缩和迁移学习当中。开山之作应该是 Distilling the Knowledge in a Neural Network 。这篇文章中,作者的动机是找到一种方法,把多个模型的知识提炼给单个模型。

在大规模的机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管他们的要求非常不同:

  • 对于语音和目标识别等任务,训练必须从非常大的、高度荣冗余的数据集提取结构,但它不需要实时操作,因此可以使用大量的计算;
  • 然而,在部署到用户端时,对于延迟和性格上有着特别高的要求。

这里的蒸馏针对的是神经网络的知识。一般认为模型的参数保留了模型学到的知识,因此最常见的迁移学习的方式就是在一个大的数据集上先做预训练,然后使用预训练得到的参数在一个小的数据集上做微调(两个数据集往往领域不同或者任务不同)。例如先在 ImageNet 上做预训练,然后在 COCO 数据集上做检测。

在这篇论文中,作者认为可以将模型看成是黑盒,知识可以看成是输入到输出的映射关系。因此,我们可以先训练好一个 teacher 网络,然后将 teacher 的网络输出结果 q q q 作为 student 网络的目标,训练 student 网络,使得 student 网络的结果 p p p 接近 q q q

如果按照这里的想法,我们可以将损失函数写成:
L = C E ( y , p ) + α C E ( q , p ) L=CE(y,p)+\alpha CE(q,p) L=CE(y,p)+αCE(q,p)这里 CE 是 交叉熵(Cross Entropy) y y y 是真实标签的 onehot编码,即表示正确与否, q q q 是 teacher 网络的输出结果, p p p 是 student 网络的输出结果。

在模型学习区分大量类别时,正常的训练目标是最大化正确答案的平均对数概率,这是学习还有一个额外的作用,即训练模型会为所有错误答案分配概率,即使这些概率很小,其中一些也比其他的要大得多。错误答案一定程度表示模型倾向于如何去泛化信息。例如,一辆宝马的图像小概率会被误认为是一辆垃圾车,但这种错误也比将它误认为是胡萝卜的可能性大得多。训练模型优化性能的目的是更好的概括新数据,即,训练模型更好的泛化。但这需要一个正确的泛化方法,这些信息通常是无法获得的。使用 teacher 的原因在于,我们可以将这部分的知识,从一个大模型蒸馏(distill)(也可以理解成提炼)到一个小模型中,我们可以就训练小模型以与大模型相同的方式进行泛化。如果大模型泛化得很好,那么在相同训练集上,与大模型有相同泛化方式的小模型在测试数据上的表现通常比在用于训练集成的以正常方式训练的小模型要好得多。

将大模型的泛化能力转移到小模型的一个方法是使用大模型产生的概率作为小模型的“软目标”(soft target)。当大模型是几个简单模型的集合(bagging的思路),可以用他们各自预测分布的算术或几何均值作为软目标。当软目标具有高熵,它们在每个训练案例中提供的信息比硬目标多得多,且不同场景间的梯度有更小的方差,所以小模型往往可以用比原有方式更少的数据训练并且有着高不少的学习率。

综上,按照上面的说法直接使用 teacher 网络的 softmax 的输出结果 q q q 可能不大合适。一个大模型总是以非常高的置信度产生正确答案。例如,在 MINST 数据中,对于某个 2 2 2 的输入,对于 2 2 2 的预测概率会很高,而对于 2 2 2 类似的数字,例如 3 3 3 7 7 7 的 预测概率为 1 0 − 6 10^{-6} 106 1 0 − 9 10^{-9} 109。而这些信息是有价值的,它定义了丰富的相似性数据的结构(例如,它说哪些 2 2 2看起来像 3 3 3,哪些看起来像 7 7 7),而在上面的情况下,它几乎没有在传递阶段对交叉熵代价函数产生影响,因为它们的概率值接近 0 0 0

用于训练小模型的传输集可以完全由未标记的数据组成,或者可以使用原始训练集。文中测试使用原始训练集效果很好。

因此,这里就提出来一种一般解决方式,并将其称为“蒸馏”---------将最后的 softmax 温度升高,直到大模型产生一组合适的软目标。然后再训练小模型时使用相同的温度匹配这些软目标。

文中提出的蒸馏方式为 softmax-T,公式如下所示:
q i = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) q_i=\frac{\exp \left(z_i / T\right)}{\sum_j \exp \left(z_j / T\right)} qi=jexp(zj/T)exp(zi/T)这里 q i q_i qi 是 student 网络学习的对象(soft targets), z i z_i zi 是神经网络 softmax 前的输出 logit。如果将 T T T 1 1 1,这个公式就是 softmax,根据 logit 输出各类别的概率。如果 T T T 接近于 0 0 0,则最大值会越接近 1 1 1,其他值会接近 0 0 0,近似于 onehot编码。如果 T T T 越大,则输出的结果分布越平缓,相当于平滑的一个作用,起到保留相似信息的作用。如果 T T T 等于无穷,就是一个均匀分布。

最终文章根据上述的损失函数对网络进行训练:

  1. 在 MNIST 这个数据集上,先使用大的网络进行训练,测试集错误 67 67 67 个;小网络训练,测试集错误 146 146 146 个。加入 soft targets 到目标函数中,相当于正则项,测试集的错误降低到了 74 74 74 个。这证明了 teacher 网络确实把知识转移到了 student 网络,使结果变好了;
  2. 第二个实验是在 speech recognition 领域,使用不同的参数训练了 10 10 10 个 DNN,对这 10 10 10 个模型的预测结果求平均作为 emsemble 的结果,相比于单个模型有一定的提升。然后将这 10 10 10 个模型作为 teacher 网络,训练 student 网络,得到的 Distilled Single model 相比于直接的单个网络,也有一定的提升,结果见下表:
    在这里插入图片描述

结论

知识蒸馏,可以将一个网络的知识转移到另一个网络,两个网络可以是同构或者异构。做法是先训练一个 teacher 网络,然后使用这个 teacher 网络的输出和数据的真实标签取训练 student 网络。知识蒸馏可以用来将网络从大网络转化成一个小网络,并保留接近于大网络的性能;也可以将多个网络学习到的知识转移到一个网络中,使得单个网络的性能接近 emsemble 的结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1101305.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【STM32】--基础了解

一、STM32来历背景 1.从51到STM32 (1)单片机有很多种 (2)STM32内核使用ARM,但是ARM不一定是STM32 (3)ATM32是当前主流的32位高性能单片机 (4)STM32的特点:高…

数据发现工具

数据发现是从非结构化和结构化数据源中查找特定数据子集的过程,必须查明业务存储库中有哪些数据以及位置,数据发现与数据分类,这是根据其敏感性和脆弱性对不同类型的数据进行排序的过程,敏感数据发现和分类本身就是不同的过程&…

测试网线的仪器叫什么?

测试网线的仪器有哪些?测试网线的仪器叫什么?很多小伙伴对此有疑问。咱们逐一分析,并做出简单的讨论。 测试网线的仪器大概几类,从携带是否方便上来说,手持式和台式。从测试功能上来说,分为物理常量、电气…

利用在线培训系统提升员工技能,助力企业发展

近年来,随着互联网技术的发展,在线培训系统逐渐成为企业提升员工技能的利器。这种新型的培训方式打破了时间和空间的限制,为企业提供了更加灵活和高效的培训解决方案。下面,我们将详细介绍如何利用在线培训系统提升员工技能&#…

驱动数字化转型,Doris Summit Asia 2023 智慧金融与政企论坛精彩预告!

峰会官网已上线,最新议程请关注:doris-summit.org.cn 即刻报名 Doris Summit 是 Apache Doris 社区一年一度的技术盛会,由飞轮科技联合 Apache Doris 社区的众多开发者、企业用户和合作伙伴共同发起,专注于传播推广开源 OLAP 与…

一图看懂CodeArts Governance 三大特性,带你玩转开源治理服务

华为云开源治理服务CodeArts Governance是针对软件研发提供的一站式开源软件治理服务,凝聚华为在开源治理上的优秀实践经验,提供开源软件元数据及软件成分分析、恶意代码检测等能力,从合法合规、网络安全、供应安全等维度消减开源软件使用风险…

金媒人提问:为何还有男生觉得精致女人不顾家?

广东金媒人小编分析:大部分男生会觉得越精致漂亮的女生,不适合做老婆,认为找老婆还是朴实的女生好,起码是顾家、实在的。 为什么还会有这么多男生觉得?如果每天追求精致的生活和自己,很容易沉浸在化妆打扮上…

【分享Python代码】图片转化为素描画

哈喽,大家好,我是木易巷~ 代码生成效果图 原图: 生成图: 原图: 生成图: 准备工作 Python编程首先需要安装环境,下面是详细步骤: 会的小伙伴可自行跳过,代码在最后 1…

泊车功能专题介绍 ———— AVP系统定义应用场景

文章目录 介绍术语 系统定义系统架构系统类型 应用场景安全场景简介安全场景定义介绍安全场景外部环境 安全场景定义开启场景结束场景车位被占用搜索车位无空闲车位路口/出入口/跨层通道减速障碍物阻挡发生碰撞车辆离线光线变化天气变化环境变化常见障碍物类型 行人安全场景车辆…

USART使用

USART软件配置 具体步骤如下:(USART 相关库函数在 stm32f10x_usart.c 和 stm32f10x_usart.h 文件中) (1)使能串口时钟及 GPIO 端口时钟 前面说过 STM32F103C8T6 芯片具有 3 个串口,对应不同的引脚&#…

Vue3分支语法-登录注销

点击登录 点击注销登录 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title><!-- vue.js --><script src"https://unpkg.com/vue3/dist/vue.global.js"><…

基于RuoYi-Flowable-Plus的若依ruoyi-nbcio支持自定义业务表单流程(四)

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码&#xff1a; https://gitee.com/nbacheng/ruoyi-nbcio 演示地址&#xff1a;RuoYi-Nbcio后台管理系统 自定义业务表单里的流程历史需要单独设计&#xff0c;所以下面就这部分进行介绍。 1、后端部分&#xff…

基于SSM的流浪狗收容领养管理平台设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

汇川Easy521PLC与压力传感器485通讯实例

本例是汇川Easy521PLC与支持485通讯的压力传感器进行通讯的实例记录。对于初次使用汇川PLC的朋友,可能有借鉴的意义。 配置: 1、汇川Easy521PLC 2、美控压力变送器 3、汇川Autoshop编程软件 将压力变送器的485线与PLC本体的485端子一一连接: 485+:A+ 485-:B- 一般485的标…

数字孪生技术在智慧城市应用的推进建议

&#xff08;一&#xff09;坚持需求牵引&#xff0c;强场景重实效 必须始终坚持以人为本、场景导向、需求牵引&#xff0c;站在供给侧结构性改革的角度&#xff0c;突出以用促建&#xff0c;强调建用并重&#xff0c;真正发挥数字孪生城市应用建设的实效。从构建数字孪生创新…

手撕Vue-查找指令和模板

接着上一篇文章&#xff0c;我们已经实现了提取元素到内存的过程&#xff0c;接下来我们要实现的是查找指令和模板。 大致的思路是这样的&#xff1a; 遍历所有的节点需要判断当前遍历到的节点是一个元素还是一个文本如果是一个元素, 我们需要判断有没有v-model属性如果是一个文…

Linux——多线程,互斥与同步

目录 一.linux互斥 1.进程线程间的互斥相关背景概念 2.互斥量mutex 3.加锁互斥锁mutex 4.锁的底层原理 二.可重入VS线程安全 1.概念 2.常见的线程不安全的情况 3.常见的线程安全的情况 4.常见不可重入的情况 5..常见可重入的情况 6.可重入与线程安全联系 三.死锁 …

华为云应用中间件DCS系列—Redis实现(视频直播)消息弹幕

云服务、API、SDK&#xff0c;调试&#xff0c;查看&#xff0c;我都行 阅读短文您可以学习到&#xff1a;应用中间件系列之Redis实现&#xff08;视频直播&#xff09;消息弹幕 1 什么是DEVKIT 华为云开发者插件&#xff08;Huawei Cloud Toolkit&#xff09;&#xff…

为什么智能相机需要搭配镜头使用?

镜头作用是将光学图像聚焦在图像传感器的光敏面阵上。不同类型的工业镜头&#xff0c;成像质量也各不相同&#xff0c;成像质量也有差异&#xff0c;影响工业镜头的因素有哪些呢 图像中心与边缘的影响图像中心较边缘分辨率高&#xff1b;图像中心较边缘光场照度高&#xff1b;像…

springboot配置swagger

springboot配置swagger Swagger 是什么Swagger配置springboot代码展示总结 Swagger 是什么 Swagger 是一个用于构建、文档和调用 RESTful Web 服务的强大工具。它提供了以下几方面的好处&#xff1a; 自动生成 API 文档: Swagger 可以自动生成 API 文档&#xff0c;包括接口的…