论文笔记--Distilling the Knowledge in a Neural Network

news2024/11/24 3:15:31

论文笔记--Distilling the Knowledge in a Neural Network

  • 1. 文章简介
  • 2. 文章概括
  • 3 文章重点技术
    • 3.1 Soft Target
    • 3.2 蒸馏Distillation
  • 4. 文章亮点
  • 5. 原文传送门

1. 文章简介

  • 标题:Distilling the Knowledge in a Neural Network
  • 作者:Hinton, Geoffrey, Oriol Vinyals, Jeff Dean
  • 日期:2015
  • 期刊:arxiv

2. 文章概括

  文章提出了一种将大模型压缩的新的思路:蒸馏distillation。通过蒸馏,可以将很大的模型压缩为轻量级的模型,从而提升推理阶段的速率。

3 文章重点技术

3.1 Soft Target

  随着模型的参数量越来越大,如何从训练好的大模型(教师模型)学习一个轻量级的小模型(学生模型)是一个重要课题。传统的hard-target训练直接学习大模型的预测结果,无法学习到不正确的类别之间的相对关系。比如给定一张宝马的照片,假设教师模型给出的预测结果为宝马,学生模型只从教师模型中学习到“宝马”这一个标签信息。事实上,教师模型还会给出其它类别的信息,比如将宝马预测为垃圾车为0.02,将宝马预测为胡萝卜的概率仅为0.0001,但学生模型没有学习到垃圾车和胡萝卜之间的区别。
  我们需要一种方法来使得学生学习到正确的标签,以及错误标签的相对关系。文章提出“soft-target",即通过学习教师模型的预测概率分布来训练小模型。

3.2 蒸馏Distillation

  对一个分类模型,假设教师模型的输出层给出的logits为 z i z_i zi,然后通过计算Softmax得到预测概率: q i = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) q_i = \frac {\exp (z_i/T)}{\sum_j \exp (z_j/T)} qi=jexp(zj/T)exp(zi/T),其中相比于传统的SoftMax增加了 T T T表示温度,用于控制输出概率分布的平滑度。 T T T越大,不同类别之间的差异越不明显,即分布越平滑。可以参考 e x p exp exp的函数曲线来理解:给定 x 1 , x 2 x_1, x_2 x1,x2,由当 T T T越大时, x 1 / T , x 2 / T x_1/T, x_2/T x1/T,x2/T对应的导数越小(导数即为 exp ⁡ ( x ) \exp(x) exp(x),也可参考下图),从而差距越小,分布越平滑。当 T = 1 T=1 T=1时,即传统的Softmax。
exp
  我们希望学生模型满足:1) 模型可以学习到教师模型的预测概率,即soft targets; 2)学生模型可以预测真实的标签。从而我们可以考虑2个目标函数: L hard \mathcal{L}_{\text{hard}} Lhard L soft \mathcal{L}_{\text{soft}} Lsoft。首先我们记学生模型和教师模型的logits分别为 z i , v i z_i, v_i zi,vi,预测概率分别为 q i , p i q_i, p_i qi,pi,真实标签为labels,则

  • L hard = Cross Entropy ( labels , arg max ⁡ i ( exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) ) ) \mathcal{L}_{\text{hard}}=\text{Cross Entropy}\left(\text{labels}, \argmax_i \left(\frac {\exp (z_i/T)}{\sum_j \exp (z_j/T)}\right)\right) Lhard=Cross Entropy(labels,argmaxi(jexp(zj/T)exp(zi/T)))
  • L soft = Cross Entropy ( p , q ) = Cross Entropy ( ( exp ⁡ ( v i / T ) ∑ j exp ⁡ ( v j / T ) ) , ( exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) ) ) \mathcal{L}_{\text{soft}}=\text{Cross Entropy}\left(p, q \right) =\text{Cross Entropy}\left((\frac {\exp (v_i/T)}{\sum_j \exp (v_j/T)}), (\frac {\exp (z_i/T)}{\sum_j \exp (z_j/T)})\right) Lsoft=Cross Entropy(p,q)=Cross Entropy((jexp(vj/T)exp(vi/T)),(jexp(zj/T)exp(zi/T)))
    考虑 L soft \mathcal{L}_{\text{soft}} Lsoft的梯度 ∂ L soft ∂ z k = ∂ ( − ∑ i p i log ⁡ q i ) ∂ z k = − ∑ i p i q i ∂ q i ∂ z k = − p k q k ∂ q k ∂ z k − ∑ i ≠ k p i q i ∂ q i ∂ z k = − 1 T p k q k q k ( 1 − q k ) + ∑ i ≠ k p i q i exp ⁡ ( z i / T ) ( ∑ j exp ⁡ ( z j / T ) ) 2 1 T exp ⁡ ( z k / T ) = 1 T ( − p k ( 1 − q k ) + ∑ i ≠ k p i q i q i q k ) = 1 T ( − p k + ∑ i p i q k ) = 1 T ( q k − p k ) = 1 T ( exp ⁡ ( z k / T ) ∑ j exp ⁡ ( z j / T ) − exp ⁡ ( v k / T ) ∑ j exp ⁡ ( v j / T ) ) \frac {\partial \mathcal{L}_{\text{soft}}}{\partial z_k} = \frac {\partial (-\sum_i p_i \log q_i)}{\partial z_k} = -\sum_i \frac {p_i}{q_i} \frac{\partial q_i}{\partial z_k} = -\frac {p_k}{q_k} \frac{\partial q_k}{\partial z_k}-\sum_{i\neq k} \frac {p_i}{q_i} \frac{\partial q_i}{\partial z_k} \\=-\frac 1T \frac {p_k}{q_k} q_k (1-q_k) +\sum_{i\neq k} \frac {p_i}{q_i} \frac {\exp (z_i/T)}{(\sum_j \exp (z_j/T))^2} \frac 1T \exp (z_k/T) \\= \frac 1T (-p_k (1-q_k) + \sum_{i\neq k} \frac {p_i}{q_i} q_i q_k )= \frac 1T (-p_k + \sum_i p_i q_k ) \\= \frac 1T (q_k - p_k) = \frac 1T \left(\frac {\exp (z_k/T)}{\sum_j \exp (z_j/T)} - \frac {\exp (v_k/T)}{\sum_j \exp (v_j/T)}\right) zkLsoft=zk(ipilogqi)=iqipizkqi=qkpkzkqki=kqipizkqi=T1qkpkqk(1qk)+i=kqipi(jexp(zj/T))2exp(zi/T)T1exp(zk/T)=T1(pk(1qk)+i=kqipiqiqk)=T1(pk+ipiqk)=T1(qkpk)=T1(jexp(zj/T)exp(zk/T)jexp(vj/T)exp(vk/T)),当 T T T相比于 z i , v i z_i, v_i zi,vi等logits量级比较高时,有 z i / T → 0 , v i / T → 0 z_i/T\to 0, v_i/T \to 0 zi/T0,vi/T0,从而由泰勒公式上式近似为 ∂ L soft ∂ z k ≈ 1 T ( 1 + z k / T N + ∑ j z j / T − 1 + v k / T N + ∑ j v j / T ) \frac {\partial \mathcal{L}_{\text{soft}}}{\partial z_k} \approx \frac 1T \left(\frac {1+z_k/T}{N + \sum_j z_j/T} - \frac {1+v_k/T}{N + \sum_j v_j/T}\right) zkLsoftT1(N+jzj/T1+zk/TN+jvj/T1+vk/T),假设logits都是零均值的,则有 ∂ L soft ∂ z k ≈ 1 N T 2 ( z k − v k ) \frac {\partial \mathcal{L}_{\text{soft}}}{\partial z_k} \approx \frac 1{NT^2} (z_k - v_k) zkLsoftNT21(zkvk)。从而当温度比较高时,我们的目标近似为最小化 1 2 ( z k − v k ) 2 \frac 12 (z_k - v_k)^2 21(zkvk)2(上式的原函数,不考虑常数项),即最小化logits的MSE函数。温度越低,我们越关注小于均值的logits。
      最终的损失函数为上述hard和soft损失的加权求和。

4. 文章亮点

  文章提出了基于soft-target的蒸馏方法,可以让学生模型学习到教师模型的预测概率分布,从而增强学生模型的泛化能力。实验表明,在MNIST和speech recognition数据上,基于soft target的学生模型可以提取到更多有用的信息,且可以有效防止过拟合的发生。

5. 原文传送门

Distilling the Knowledge in a Neural Network

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/784288.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Macbook M1编译安装Java OpenCV

OpenCV-4.8.0编辑安装 查询编译依赖 brew info opencv确保所有需要模块都打上了✔,未打✔的需要使用brew进行安装 下载OpenCV源码 在此处下载OpenCV源代码,选择Source,点击此处下载opencv_contrib-4.8.0 或者使用如下命令,通…

gerrit 从安装到出坑

一般公司在做代码审核的时候选择codereview gerrit来处理代码的入库的问题。 它是通过提交的时候产生Change-Id: If4e0107f3bd7c5df9e2dc72ee4beb187b07151b9 来决定是不是入库,一般如果不是通过这个管理,那么就是我们通常的操作 git add . git comm…

【算法与数据结构】110、LeetCode平衡二叉树

文章目录 一、题目二、解法三、完整代码 所有的LeetCode题解索引,可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、解法 思路分析:二叉树遍历一共有前中后遍历和层序遍历,这道题只有后序遍历适合,求深度是从上往…

【小黄碎碎念】如何解析和替换字符串中的 Markdown 文本?正则表达式与 flexmark-java 库

前言 本周,笔者将之前的基于 Servlet 的个人博客项目进行了迭代,基于 SpringBoot SpringMVC Mybatis Redis 进行实现。额外实现密码的明文加密处理(加盐算法)、修改博客、公共主页等功能,并将 session 存储到 Redis…

深度学习——批标准化Batch Normalization

什么是批标准化? 批标准化(Batch Normalization)是深度学习中常用的一种技术,旨在加速神经网络的训练过程并提高模型的收敛速度。 批标准化通过在神经网络的每一层中对输入数据进行标准化来实现。具体而言,对于每个输…

我在VScode学Python(Python函数,Python模块导入)

我的个人博客主页:如果’真能转义1️⃣说1️⃣的博客主页 (1)关于Python基本语法学习---->可以参考我的这篇博客《我在VScode学Python》 (2)pip是必须的在我们学习python这门语言的过程中Python ---->&a…

fl studio 20如何设置中文汉化汇总及flstudio21水果language选项中文设置方法

fl studio这是一个编曲软件,它有中文和英文两种语言供大家选择,对我们来说,中文版肯定更方便。fl studio如何设置中文?事实上,只需在设置中切换中文即可。 我们一起 fl studio 20如何设置中文一些方法 一、fl studio手…

Angular:动态依赖注入和静态依赖注入

问题描述: 自己写的服务依赖注入到组件时候是直接在构造器内初始化的。 直到看见代码中某大哥写的 private injector: Injector 动态依赖注入和静态依赖注入 在 Angular 中,使用构造函数注入的方式将服务注入到组件中是一种静态依赖注入的方式。这种方…

docker中搭建lnmp

目录 一:项目环境 1、主机ip需求 2、 任务需求 二:多级构建Dockerfile实验部署 lnmp 1、先部署一个有所有依赖包的镜像 2、搭建nginx 3、搭建mysql 4、搭建php 三:一级构建安装lnmp 1、构建自定义docker网络 2、构建nginx容器&#x…

办公室安全升级,如何保障人身财产安全?

视频监控,一种常见的安全措施,以监视和记录办公室内的活动。这项技术为企业提供了许多优势,包括保障员工和财产安全、帮助调查犯罪事件、提高业务管理效率以及应对突发事件。 因此,在合理范围内应用视频监控,将为企业提…

LINUX中的myaql(一)安装

目录 前言 一、概述 二、数据库类型 三、数据库模型 四、MYSQL的安装 (一)yum安装MYSQL (二)rpm安装MYSQL 五、MYSQL本地登录 rpm安装MYSQL本地登录 六、重置密码 总结 前言 MySQL是一种常用的开源关系型数据库管理系统&#xff…

泛微OA客戶管理融合呼叫中心系统功能

泛微OA,全程数字化运营平台。用户因业务管理需要,crm客户管理流程中,需要通过语音与客户沟通,对话务沟通过程实现流程和过程化管理。 泛微OA中,无需过多编程代码。通过呼叫中心接口快速开发,实现点击拨打&…

【计算机网络】第 3 课 - 计算机网络体系结构

欢迎来到博主 Apeiron 的博客,祝您旅程愉快 ! 时止则止,时行则行。动静不失其时,其道光明。 目录 1、常见的计算机网络体系结构 2、计算机网络体系结构分层的必要性 2.1、物理层 2.2、数据链路层 2.3、网路层 2.4、运输层 2…

融合正余弦和柯西变异的麻雀搜索算法优化CNN-BiLSTM,实现多输入单输出预测,MATLAB代码...

上期作者推出的融合正余弦和柯西变异的麻雀优化算法,效果着实不错,今天就用它来优化一下CNN-BiLSTM。CNN-BiLSTM的流程:将训练集数据输入CNN模型中,通过CNN的卷积层和池化 层的构建,用来特征提取,再经过BiL…

(20)操纵杆或游戏手柄

文章目录 前言 20.1 你将需要什么 20.2 校准 20.3 用任务规划器进行设置 20.4 飞行前测试控制装置 20.5 测试失控保护 20.6 减少控制的滞后性 前言 本文解释了如何用操纵杆或游戏手柄控制你的飞行器,使用任务计划器向飞行器发送"RC Override"消息…

【C++基础(六)】类和对象(中) --构造,析构函数

💓博主CSDN主页:杭电码农-NEO💓   ⏩专栏分类:C初阶之路⏪   🚚代码仓库:NEO的学习日记🚚   🌹关注我🫵带你学习C   🔝🔝 类和对象-中 1. 前言2. 构造函数3. 构造函数的特性4…

Opencv 细节补充

1.分辨率的解释 •像素:像素是分辨率的单位。像素是构成位图图像最基本的单元,每个像素都有自己的颜色。 •分辨率(解析度): a) 图像分辨率就是单位英寸内的像素点数。单位为PPI(Pixels Per Inch) b) PPI表示的是每英…

从零开始 Spring Cloud 7:Gateway

从零开始 Spring Cloud 7:Gateway 图源:laiketui.com Spring Cloud Gateway 是 Spring Cloud 的一个全新项目,该项目是基于 Spring 5.0,Spring Boot 2.0 和 Project Reactor 等响应式编程和事件流技术开发的网关,它旨…

系统集成|第三章(笔记)

目录 第三章 系统集成专业技术3.1 信息系统建设3.1.1 信息系统3.1.2 信息系统集成 3.2 信息系统设计3.3 软件工程3.4 面向对象系统分析与设计3.5 软件架构3.5.1 软件的架构模式3.5.2 软件中间件 3.6 典型应用集成技术3.6.1 数据库与数据仓库技术3.6.2 Web Services 技术3.6.3 J…