知识蒸馏开山之作(部分解读)—Distilling the Knowledge in a Neural Network

news2024/12/27 1:33:03

1、蒸馏温度T

正常的模型学习到的就是在正确的类别上得到最大的概率,但是不正确的分类上也会得到一些概率尽管有时这些概率很小,但是在这些不正确的分类中,有一些分类的可能性仍然是其他类别的很多倍。但是对这些非正确类别的预测概率也能反应模型的泛化能力,例如,一辆宝马车的图片,只有很小的概率被误识别成垃圾车,但是被识别成垃圾车的概率还是比错误识别成胡萝卜的概率高很多倍。(例如一个车,猫,狗3分类的模型识别一张猫的图片,最后结果是:(cat,99%) ; (dog,0.95%)(car,0.05%)错误类别 dog 上的概率仍是错误类别 car 的概率的19倍 )

知识蒸馏

这里一个可行的办法是使用大模型生成的模型类别概率作为“soft targets”(使用蒸馏算法以后的概率,相对应的 head targets 就是正常的原始训练数据集)来训练小模型,由于 soft targets 包含了更多的信息熵,所以每个训练样本都提供给小模型更多的信息用来学习,这样小模型就只需要用更少的样本,及更高的学习率去训练了。

仍然是上面的错误分类概率的例子,在 MNIST 数据集上训练的一个大模型基本都能达到 99 % 以上的准确率,假如现在有一个数字 2 的图片输入到大模型中分类,在得到的结果是数字 3 的概率为 10e-6, 是数字 7 的概率为 10e-9,这就表示了相比于 7 ,3更接近于 2,这从侧面也可以表现数据之间的相关性,但是在迁移阶段,这样的概率在交叉熵损失函数(cross-entropy loss function)只有很小的影响,因为它们的概率都基本为0。 所以这里,本文提出了 “distillation” 的概念, 来软化上述的结果。

上面的公式就是蒸馏后的 softmax,其中 T 代表 temperature, 蒸馏的温度。那么 T 有什么作用呢?

假设现在有一个数组 x=[2,7,10] ,当T = 1,即为正常的 Softmax函数 输入上式中可得:

T = 1 ——>  y = [0.00032,0.04741,0.95227]

可以理解为上述的一个车,猫,狗3分类网络,输入一张猫的图片,预测为汽车的概率为0.00032, 预测为狗的概率为 0.04741, 预测为猫的概率为 0.95227。
下面再看一下改变 T 的值概率的输出:

 T = 5   ——>   y = [0.11532, 0.31348, 0.5712]   
 T = 10  ——>   y = [0.20516, 0.33825, 0.45659]  
 T = 20  ——>   y = [0.26484, 0.34006, 0.3951]

下面是在(-10,10)之间随机取多个点然后在 不同的 T 值下绘制的图像。

 可以看到当 T = 1 是就是常规的 Softmax,而升温T,对softmax进行蒸馏,函数的图像会变得越来越平滑,这也是文中提高的 soft targets 的 soft 一词来源吧。

假设你是每次都是进行负重登山,虽然过程很辛苦,但是当有一天你取下负重,正常的登山的时候,你就会变得非常轻松,可以比别人登得高登得远。我们知道对于一个复杂网络来说往往能够得到很好的分类效果,错误的概率比正确的概率会小很多很多,但是对于一个小网络来说它是无法学成这个效果的。我们为了去帮助小网络进行学习,就在小网络的softmax加一个T参数,加上这个T参数以后错误分类再经过softmax以后输出会变大,同样的正确分类会变小。这就人为的加大了训练的难度,一旦将T重新设置为1,分类结果会非常的接近于大网络的分类效果。

最后将小模型在 soft targets 上训练得到的交叉熵损失函数,加上在真实带标签数据(hard targets)上训练得到的交叉熵损失函数乘以 1/T^2 加在一起作为最后总的损失函数。这里hard targets 上面乘以一个系数是因为 soft targets 生成过程中蒸馏后的 softmax 求导会有一个 1/T^2 的系数,为了保持两个 Loss 所产生的影响接近一样(各 50%)。

训练过程

假设这里选取的 T = 10;

Teacher 模型:
( a ) Softmax(T=10)的输出,生成“Soft targets”

Student 模型:
( a ) 对 Softmax(T = 10)的输出与Teacher 模型的Softmax(T = 10)的输出求 Loss1
( b ) 对 Softmax(T = 1)的输出与原始label 求 Loss2
( c ) Loss = Loss1 + (1/T^2)Loss2

 使用soft target会增加信息量,熵高

发现:T参数越大,soft target的分布越均匀。因此,我们可以:

  1. 首先用较大的T值来训练模型,这时候复杂的神经网络能够产生更均匀分布(更容易让小网络学习)的soft target;
  2. 之后小规模的神经网络用相同的T值来学习由大规模神经网络产生的soft target,接近这个soft target从而学习到数据的结构分布特征;
  3. 最后在实际应用中,将T值恢复到1,让类别概率偏向正确类别。

在大数据集上训练专家模型

Training ensembles of specialists on very big datasets 

可以用无限大的数据集来使用教师网络训练学生网络

  1. 当数据集非常巨大以及模型非常复杂时,训练多个模型所需要的资源是难以想象的,因此作者提出了一种新的集成模型(ensemble)方法:
  • 一个generalist model:使用全部数据训练。
  • 多个specialist model(专家模型):对某些容易混淆的类别进行训练。
  1. specialist model的训练集中,一半是由训练集中包含某些特定类别的子集(special subset)组成,剩下一半是从剩余数据集中随机选取的。
  2. 这个ensemble的方法中,只有generalist model是使用完整数据集训练的,时间较长,而剩余的所有specialist model由于训练数据相对较少,且相互独立,可以并行训练,因此训练模型的总时间可以节约很多。
  3. specialist model由于只使用特定类别的数据进行训练,因此模型对别的类别的判断能力几乎为0,导致非常容易过拟合。
  • 解决办法:当 specialist model 通过 hard targets 训练完成后,再使用由 generalist model 生成的 soft targets 进行微调。这样做是因为 soft targets 保留了一些对于其他类别数据的信息,因此模型可以在原来基础上学到更多知识,有效避免了过拟合。
     

实现流程:

此部分很有意思,但是不知道具体细节,需要再去看论文。 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/919434.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ext JS 之Microloader(微加载器)

“Microloader”是 Sencha 数据驱动的 JavaScript 和 CSS 动态加载器的名称。 清单 app.json 用于应用的设置,Sencha Cmd 在构建的时候会读取这个文件。 Sencha Cmd 转换“app.json”的内容并将生成的清单传递给 Microloader 以在运行时使用。 最后,Ext JS 本身也会查阅运…

(WAF)Web应用程序防火墙介绍

(WAF)Web应用程序防火墙介绍 1. WAF概述 ​ Web应用程序防火墙(WAF)是一种关键的网络安全解决方案,用于保护Web应用程序免受各种网络攻击和威胁。随着互联网的不断发展,Web应用程序变得越来越复杂&#x…

Java代码审计12之JDNI注入以及rmi和Ldap的利用

文章目录 1、Jndi、Ldap、Rmi协议1.1、什么是ladp协议1.2、jndi协议1.3、rmi协议 2、jndi注入2.1、简介与jdk版本限制2.2、rmi协议的利用2.2.1、更换idea的执行jdk版本2.2.2、生成恶意class文件payload2.2.3、模拟测试低版本jdk2.2.4、模拟高版本测试 2.3、rmi攻击的疑问之两个…

关于ios Universal Links apple-app-site-association文件 Not Found的问题

1. 背景说明 1.1 Universal Links 是什么 Support Universal Links 里面有说到 Universal Links 是什么、注意点、以及如何配置的。简单来说就是 当您支持通用链接时,iOS 用户可以点击指向您网站的链接,并无缝重定向到您安装的应用程序 大白话就是说&am…

3D旅游情景实训教学展示

随着科技的不断发展,情景实训教学在教育领域中的应用越来越广泛。通过虚拟现实技术,3D视觉技术,计算机技术等为学生提供了一个身临其境的学习环境,让他们能够在模拟的场景中学习和实践,从而更好地理解和掌握知识。 3D虚…

机器学习基础之《分类算法(4)—案例:预测facebook签到位置》

一、背景 1、说明 2、数据集 row_id:签到行为的编码 x y:坐标系,人所在的位置 accuracy:定位的准确率 time:时间戳 place_id:预测用户将要签到的位置 3、数据集下载 https://www.kaggle.com/navoshta/gr…

微信小程序创建项目以及注意事项

1.申请账号并完善信息 2.下载安装开发工具 3.开发小程序 4.上传代码 5.提交审核 6.发布 创建项目 根据需求选择模板,也可以不选择模板 创建完毕之后 进入页面点击终端 然后新建终端 输入npm init 一直按回车即可 安装成功 出现package.json 如何使用组件&#x…

Spring Cache的介绍以及怎么使用(redis)

Spring Cache 文章目录 Spring Cache1、Spring Cache介绍2、Spring Cache常用注解2.1、EnableCaching注解2.2、CachePut注解2.3、CacheEvict注解2.4、Cacheable注解 3、Spring Cache使用方式--redis 1、Spring Cache介绍 Spring Cache是一个框架,实现了基于注解的缓…

【LeetCode】模拟实现FILE以及认识缓冲区

模拟实现FILE以及认识缓冲区 刷新缓冲逻辑图自定义实现如何强制刷新内核缓冲区例子 刷新缓冲逻辑图 自定义实现 mystdio.h #pragma once #include <stdio.h>#define NUM 1024 #define BUFF_NOME 0x1 #define BUFF_LINE 0x2 #define BUFF_ALL 0x4typedef struct _MY_FIL…

对《VB.NET通过VB6 ActiveX DLL调用PowerBasic及FreeBasic动态库》的改进

《VB.NET通过VB6 ActiveX DLL调用PowerBasic及FreeBasic动态库》使用的Activex DLL公共对象是需要先注册的。https://blog.csdn.net/weixin_45707491/article/details/132437502?spm1001.2014.3001.5501 Activex DLL事前注册&#xff0c;一次多用说起来也不是啥大问题&#x…

C语言小白急救 指针进阶讲解1

文章目录 指针一、 字符指针二、 指针数组三、数组指针1.数组的地址2.数组指针3.数组指针的应用 四、数组参数、指针参数1. 一维数组传参2.二维数组传参3.一级指针传参4.二级指针传参 五、函数指针1.函数的地址2.函数指针3.练习 指针 指针的概念&#xff1a; 1.指针就是个变量…

数据库(DQL,多表设计,事务,索引)

目录 查询数据库表中数据 where 条件列表 group by 分组查询 having 分组后条件列表 order by 排序字段列表 limit 分页参数 多表设计 一对多 多对多 一对一 多表查询 事物 索引 查询数据库表中数据 关键字&#xff1a;SELECT 中间有空格&#xff0c;加引…

H.265视频无插件流媒体播放器EasyPlayer.js播放webrtc断流重连的异常修复

H5无插件流媒体播放器EasyPlayer属于一款高效、精炼、稳定且免费的流媒体播放器&#xff0c;可支持多种流媒体协议播放&#xff0c;可支持H.264与H.265编码格式&#xff0c;性能稳定、播放流畅&#xff0c;能支持WebSocket-FLV、HTTP-FLV&#xff0c;HLS&#xff08;m3u8&#…

FxFactory 8 Pro Mac 苹果电脑版 fcpx/ae/motion视觉特效软件包

FxFactory pro for mac是应用在Mac上的fcpx/ae/pr视觉特效插件包&#xff0c;包含了成百上千的视觉效果&#xff0c;打包了很多插件&#xff0c;如调色插件&#xff0c;转场插件&#xff0c;视觉插件&#xff0c;特效插件&#xff0c;文字插件&#xff0c;音频插件&#xff0c;…

百望云华为云共建零售数字化新生态 聚焦数智新消费升级

零售业是一个充满活力和创新的行业&#xff0c;但也是当前面临很大新挑战和新机遇的行业。数智新消费时代&#xff0c;数字化转型已经成为零售企业必须面对的重要课题。 8 月 20 日-21日&#xff0c;以“云上创新 韧性增长”为主题的华为云数智新消费创新峰会2023在成都隆重召…

stm32之10.系统定时器

delay_s()延时秒 delay_ms()毫秒*1000 delay_us()微秒*1000000 微秒定时器代码 void delay_us(uint32_t n) { SysTick->CTRL 0; // Disable SysTick&#xff0c;关闭系统定时器 SysTick->LOAD SystemCoreClock/1000000*n-1; // 就是nus SysTick->LOAD Sys…

有趣的数学 数学建模入门二 一些理论基础

一、什么是数学建模? 现实世界中混乱的问题可以用数学来解决&#xff0c;从而产生一系列可能的解决方案来帮助指导决策。大多数人对数学建模的概念感到不舒服&#xff0c;因为它是如此开放。如此多的未知信息似乎令人望而却步。哪些因素最相关&#xff1f;但正是现实世界问题的…

C语言基础之——操作符(上)

本篇文章&#xff0c;我们将展开讲解C语言中的各种常用操作符&#xff0c;帮助大家更容易的解决一些运算类问题。 这里提醒一下小伙伴们&#xff0c;本章知识会大量涉及到二进制序列&#xff0c;不清楚二进制序列的小伙伴&#xff0c;可以去阅读我的另一篇文章《数据在内存中的…

Go【gin和gorm框架】实现紧急事件登记的接口

简单来说&#xff0c;就是接受前端微信小程序发来的数据保存到数据库&#xff0c;这是我写的第二个接口&#xff0c;相比前一个要稍微简单一些&#xff0c;而且因为前端页面也是我写的&#xff0c;参数类型自然是无缝对接_ 前端页面大概长这个样子 先用apifox模拟发送请求测试…

数据结构 day1

1>x.mind 2>间接定义结构体数组&#xff0c;进行4种方式的定义和初始化 3>定义结构体存储10辆车&#xff08;车的信息&#xff1a;品牌、单价、颜色&#xff09; 1.定义函数&#xff0c;实现循环输入 2.定义函数&#xff0c;实现排序 3.定义函数&#xff0c;计算红色车…