SimKD

news2024/11/17 9:45:01

又搬来一个简单高效的知识蒸馏技术哦~~直接复用教师分类器还能显著减小性能差距的~

在分类器的上一层通过特征对齐来训练学生模型,并直接复用教师分类器到学生模型中,再使用L2损失进行特征对齐。来自浙江大学的复用教师模型的方法哦~~ 浙大好厉害~~

论文名称: Knowledge Distillation with the Reused Teacher Classifier

论文地址:https://arxiv.org/pdf/2203.14001.pdf

给定一个参数量较大的教师模型,知识蒸馏 (KD) 的目标是帮助另一个参数量较少的学生模型获得与较大的教师模型相似的泛化能力。实现这一目标的一种直接方法是,给定相同的输入,尽量减少它们输出预测结果的差距。原始 KD 策略的一个不足之处是,教师模型和学生模型性能的差距依然很大。

之前有一些相关的知识蒸馏的方法,如[1][2][3][4][5][6]。这些方法利用了一些中间层特征的信息,同时也获益于精心设计的知识蒸馏的特征 (比如蒸馏注意力[7],蒸馏相关性[8][9],蒸馏教师模型和学生模型的互信息[10]等)。

这些知识蒸馏策略的确能够带来某些性能的提升,但是它们要么基于不那么鲁棒的超参,要么依赖于精心设计的蒸馏特征。

本文作者提出了一个简单的知识蒸馏技术,可以显著弥合教师和学生模型之间的性能差距,称为 SimKD,如下图1所示。作者认为,教师模型强大的预测能力不仅归功于更强的特征提取能力,最后的分类器 (Classifier) 也同样重要。基于这一点,作者在分类器的上一层通过特征对齐 (Feature Alignment) 来训练学生模型,并直接复用 (Reuse) 教师分类器到学生模型中。

图1:SimKD 简介

原始 KD 方法

深度神经网络模型可以看成是一个特征提取器 + 最后的分类层。特征提取器通常是由很多个非线性层组成,分类层一般是由一个 Fully Connected Layer 加上一个 softmax 激活函数构成。它们的参数通过反向传播算法更新。

 图2:原始 KD 方法

 Simple KD 方法

Simple KD 方法是基于特征蒸馏方法,特征蒸馏方法如下图3所示,特征蒸馏主要是收集和传输 teacher 和 student 模型的额外梯度信息,以更好地训练学生的特征 Encoder。然而,特征蒸馏方法很大程度上依赖于特征类型的选择,比如是蒸馏注意力特征还是隐藏层特征。同时,由于涉及到的特征类型较多,特征蒸馏还对超参数的选择比较敏感。结合以上两个缺点,特征蒸馏方法比较耗时,同时我们很难直接做出判断哪种模型适合什么类型的特征蒸馏。

图3:特征蒸馏技术

SimKD 是一种简单的知识蒸馏技术,如下图4所示,它一个关键组成部分是 "分类器复用" 操作,即我们直接借用预先训练好的教师分类器进行学生推理,而不是训练一个新的分类器。这样就不需要用标签信息来计算交叉熵损失,使得特征对齐损失成为产生梯度的唯一来源

作者认为,精心训练好的教师模型中包含的判别能力是非常重要的,但在很多 KD 方法中被很大程度上忽略了。作者是这么理解的:当一个模型被要求处理几个具有不同数据分布的任务,一个基本的做法是冻结或共享一些浅层作为跨不同任务的特征提取器,同时微调最后一层分类器以学习特定于任务的信息[11][12]。在这种单模型多任务的设置中,现有的研究一般认为:

  • task-invariant 的信息可以在不同模型之间共享,而 task-specific 的信息则需要独立识别,通常由最终的分类器进行识别。

推广到 KD 领域,不同能力的教师和学生模型在相同的数据集上进行训练,作者认为:

  • capability-invariant 的信息可以在教师和学生模型之间共享,而 capability-specific 的信息则学生模型很难独立地学好,通常这些信息在网络的深层,尤其是最后的分类器。

图4:Simple KD 方法 

通过这种简单的技术,KD 中的性能下降将得到极大的缓解。而且,来自预训练的教师模型的特征复用允许合并更多的层,不限于最终的分类器。通常情况下,重用的层数越多,学生的准确率越高,但是会增加额外的推理负担。

与其他 KD 方法的精度对比

数据集:CIFAR100,ImageNet。优化器:SGD 0.9 Momentum,CIFAR100 和 ImageNet 分别训练 240 和 120 Epochs。

对比的其他 KD 方法:FitNet,AT,SP, VID,CRD,SRRL,SemCKD

SimKD 的性能始终优于所有竞争对手,在某些情况下提高相当显著。例如,对于 "ResNet8x4 & ResNet-32x4" 的组合,SimKD 在 ImageNet 上的准确率提高了 3.66%。作者还发现,在 "ResNet-8x4 & WRN-40-2" 和 "ShuffleNetV2 & ResNet110x2" 组合的情况下,用 SimKD 训练的学生模型比教师模型的精度更高,这有点令人困惑,因为即使是特征对齐损失训练到了零,也只能保证它们的准确性完全相同。自蒸馏 (Self-distillation) 的一个可能解释是,式3损失函数可以帮助特征重建,或许可以帮助学生模型变得更稳健,从而获得更好的结果。

分类器复用操作分析

"分类器重用" 操作是本文取得成功的关键。为了更好地理解它的作用,作者用两种可选策略进行了几个实验来处理学生模型的 Encoder 和分类器:

(1) 联合训练:不再复用教师模型的分类器,而联合训练学生模型的 Encoder 和分类器

图8:联合训练实验结果

(2) 顺序训练:先使用3式的损失函数训练好学生的特征提取器,再冻结其参数,即冻结提取的特征,用常规训练过程训练随机初始化的学生分类器

以上做法与自监督训练中的 Linear Probing 做法一致。实验结果如下图9所示。可以发现,除了 "WRN-40-1 & WRN-40-2" 和 "ResNet-110/116 & ResNet-110x2",其他学生模型的测试精度出现了急剧下降。而几次调优初始学习速率只对性能产生了轻微的影响。图9的结果表明,即使提取的特征已经对齐,训练一个令人满意的学生分类器仍然是一个挑战。相比之下,直接重用预先训练好的教师分类器显得简单,而且性价比高。

图9:顺序训练实验结果

(3) 复用更多的层:以 ResNet 架构为例,除了复用最后的教师模型分类器之外,还复用最后一个 Building Block (SimKD+),和倒数第二个 Building Block (SimKD++)

实验结果如下图10所示。SimKD+ 和 SimKD++ 进一步提升了性能,但是复杂性也增加了。这些结果支持了 SimKD 的假设,即重用深层教师层有利于学生模型性能的提升,可能是因为其中包含了大多数特定能力的信息。另一种解释是,重用更深层的教师层将使浅层教师层的近似更容易实现,从而减少性能下降。在实践中,只重用最终的教师分类器可以很好地平衡性能和参数复杂性。

图10:复用更多的层实验结果

投影层分析

图11:投影层消融实验结果

总结

 whaosoft aiot http://143ai.com 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/78189.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023年五面蚂蚁、三面拼多多、字节跳动最终拿offer入职拼多多

文章有点长,请耐心看完,绝对有收获!不想听我BB直接进入面试分享: 准备过程 蚂蚁金服面试分享 拼多多面试分享 字节跳动面试分享 总结 说起来开始进行面试是年前倒数第二周,上午9点,我还在去公司的公交…

ERD Online 4.0.5 在线数据库建模、元数据管理(免费、私有部署)

4.0.5版本来袭❝ fix(erd): 增加数据库数据查询功能,支持多数据源切换查询,查看sql执行计划fix(erd): 数据查询功能,保留历史查询记录,格式化sql,多级树结构保存历史查询fix(erd): 依赖ERD加密手段,导出保留…

vdbench测试SSD快速入门

介绍 vdbench是一个I/O工作负载生成器,通常用于验证数据完整性和度量直接附加(或网络连接)存储性能。它可以运行在windows、linux环境,可用于测试文件系统或块设备基准性能。我们下面主要以块设备为介绍对象。 下载及安装 下载…

Linux 在过去几年发生的六种变化

随着时间的推移,Linux 桌面已经发生了变化,这种变化是逐渐发生的,因此这里汇总了过去十年中 Linux 桌面体验发生变化的一些具体方式。资深用户知道 Linux 桌面已经走过了漫长的道路。从前端应用程序设计到后端 Linux 组件,近年来发…

驱动无模块注入dll

文章目录实现效果三环无模块注入的方案反射型dll注入方式的改进零环无模块注入方案petoshellcode无模块注入流程实现代码Xenos注入方案研究IT_MMap注入IT_Thread注入IT_Apc注入火绒的注入思路实现效果 可以看到dll已经成功执行,但是在内存区域里面并没有我们的模块&…

《野球少年》:投捕搭档·棒球联盟

中文名 野球少年 原版名称 バッテリー 别 名 棒球伙伴、Battery 动画制作 ZERO-G 类 型 青春、运动、棒球 剧情简介 身为一名投手,原田巧是位拥有着拔群棒球才能的少年。在初中入学时移居的山间城镇新田市,巧与接住自己全力投球的捕手永仓豪相遇了。…

13 个你应该知道的 Webpack 优化技巧

Webpack 是目前前端开发最重要的构建工具。无论是自己的日常开发,还是准备面试,都应该掌握一些关于 Webpack 的优化技巧。 在这篇文章中,我将从三个方面分享一些我常用的技巧: 提高优化速度 压缩打包文件的大小 改善用户体验。…

[附源码]Python计算机毕业设计SSM基于框架的动漫设计(程序+LW)

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

MarkDown语法浅析(基础语法)

本篇学习笔记简述MarkDown基础语法。掌握了“MarkDown基本语法简单HTML5标签”的综合运用,就可以把CSDN博文搞得美美哒✌ (本文获得CSDN质量评分【92】)【学习的细节是欢悦的历程】Python 官网:https://www.python.org/ Free:大咖免费“圣经…

SpringMVC笔记

文章目录一、SpringMVC简介1、什么是MVC2、什么是SpringMVC3、SpringMVC的特点二、HelloWorld1、开发环境2、创建maven工程a>添加web模块b>打包方式:warc>引入依赖3、配置web.xmla>默认配置方式b>扩展配置方式4、创建请求控制器5、创建springMVC的配…

Android开发中的服务发现技术

我们的日常开发中充满了InterfaceRegistry这种模式的代码,其中: Interface为定义的服务接口,可能是业务功能服务也可能是日志服务、数据解析服务、特定功能引擎等各种抽象层(abstract layer);Registry为特…

线性表→顺序表→链表 逐个击破

一. 线性表 1. 前言 线性表,全名为线性存储结构。使用线性表存储数据的方式可以这样理解,即 “ 把所有(一对一逻辑关系的)数据用一根线儿串起来,再存储到物理空间中 ”。这根线有两种串联形式,如下图,即顺序存储(集中…

【收藏级】MySQL基本操作的所有内容(常看常新)

文章目录前言一、ER模型二、数据类型三、字段命名规范四、数据库创建与管理4.1、创建数据库4.2、删除数据库4.3、列出数据库4.4、备份数据库4.5、还原数据库4.6、使用某个数据库五、数据表创建与管理5.1、创建表、结构5.2、查看表结构5.3、查看数据表5.4、复制表结构5.5、复制表…

m基于PSO粒子群算法的重采样算法仿真,对比随机重采样,多项式重采样,分层重采样,系统重采样,残差重采样,MSV重采样

目录 1.算法描述 2.仿真效果预览 3.MATLAB核心程序 4.完整MATLAB 1.算法描述 重采样的主要方法有随机重采样,多项式重采样,分层重采样,系统重采样,残差重采样,MSV重采样等。 a.随机采样是一种利用分层统计思想设计出来的,将空间均匀划分,粒子打点后…

Lecture6:激活函数、权值初始化、数据预处理、批量归一化、超参数选择

目录 1.最小梯度下降(Mini-batch SGD) 2.激活函数 2.1 sigmoid 2.2 tanh 2.3 ReLU 2.4 Leaky ReLU 2.5 ELU 2.6 最大输出神经元 2.7 建议 3.数据预处理 4. 如何初始化网络的权值 5. 批量归一化 6.超参数的选择 1.最小梯度下降&#xf…

Flowable定时器与实时流程图

1. 定时器 1.1. 流程定义定时激活 在之前松哥给小伙伴们介绍流程定义的时候,流程都是定义好之后立马就激活了,其实在流程定义的这个过程中,我们还可以设置一个激活时间,也就是流程定义好之后,并不会立马激活&#xf…

Java一些面试题(简单向)

以下全部简单化回答(本人新手,很多都是直接百度粘贴收集得来的,如有不对请留下正确答案,谢谢) (问题来源https://www.bilibili.com/video/BV1XL4y1t7LL/?spm_id_from333.337.search-card.all.click&vd_source3cf72bb393b8cc11b96c6d4bfbcbd890) 1.重写 重载的区别 重写(ov…

dubbo3.0使用

dubbo3.0使用 介绍 官方网址:https://dubbo.apache.org/ 本文基于springCloud依赖的方式演示相关示例:https://github.com/alibaba/spring-cloud-alibaba/wiki/Dubbo-Spring-Cloud dubbo示例项目:https://github.com/apache/dubbo-sample…

9 内中断

内中断 任何一个通用的CPU,比如8086 ,都具备一种能力,可以在执行完当前正在执行的指令之后,检测到从CPU 外部发送过来的或内部产生的一种特殊信息,并且可以立即对所接收到的信息进行处理。这种特殊的信息,…

S7-200SMART高速脉冲输出的使用方法和示例

S7-200SMART高速脉冲输出的使用方法和示例 S7-200SMART PLC内部集成了高速脉冲发生器,不同的CPU型号,高速脉冲发生器的数量不同。 具体型号可参考下图: 注意:要输出高速脉冲的话,必须选择ST晶体管型号的PLC,SR继电器型的不支持。 S7-200SMART PLC能产生2种类型的高速脉冲…