Decoupled Knowledge Distillation解耦知识蒸馏

news2024/9/28 7:19:16

Decoupled Knowledge Distillation解耦知识蒸馏

现有的蒸馏方法主要是基于从中间层提取深层特征,而忽略了Logit蒸馏的重要性为了给logit蒸馏研究提供一个新的视角,我们将经典的KD损失重新表述为两部分,即目标类知识蒸馏(TCKD)和非目标类知识蒸馏(NCKD)。我们实证研究并证明了两部分的效果:TCKD转移了关于训练样本“难度”的知识而NCKD是logit蒸馏有效的突出原因。更重要的是,我们揭示了经典KD损失是一个耦合公式,它(1)抑制了NCKD的有效性,(2)限制了平衡这两个部分的灵活性。为了解决这些问题,我们提出了解耦知识蒸馏(DKD),使TCKD和NCKD更有效和灵活地发挥其作用。

介绍

在过去的几十年里,计算机视觉领域已经被深度神经网络(DNN)彻底改变,它成功地促进了各种真实场景的任务,如图像分类、目标检测和语义分割。然而,大的网络通常受益于大的模型容量,引入了高计算和存储成本。在广泛部署轻量级模型的工业应用中,这样的成本并不可取。在文献中,降低成本的一个潜在方向是知识蒸馏(KD)。KD代表了一系列专注于将知识从重模型(教师)——转移到轻模型(学生)的方法,这可以在不引入额外成本的情况下提高轻模型的性能。

KD的概念在[12]中首次提出,通过最小化教师和学生预测logit之间的KL-Divergence来转移知识(图1a)。

image-20240303132727112

自[28]以来,大部分的研究注意力都集中在从中间层的深层特征中提取知识。与基于logit的方法相比,特征蒸馏的性能在各种任务上是否表现出色,因此,对logit蒸馏的研究很少涉及。然而,基于特征方法的训练成本并不令人满意,因为在训练期间引入了额外的计算和存储使用(例如,网络模块和复杂的操作)来提取深度特征。

Logit蒸馏需要边际的计算和存储成本,但性能较差。直观的说,logit蒸馏应该达到与特征蒸馏相当的性能,因为logit比深度特征处于更高的语义层。假设logit蒸馏的潜力收到未知原因的限制,导致性能不理想。为了振兴基于Logit的方法,我们通过深入研究KD的机制开始这项工作。首先,我们将分类预测分为两个层次(1)对目标类和所有非目标类进行二值预测;(2)对每个非目标类进行多类预测。在此基础上,我们将经典KD损失[12]重新表述为两部分,如图1b所示。一种是针对目标类的二元logit蒸馏另一种是针对非目标类的多类别logit蒸馏。为了简化期间,我们将其分别命名为目标分类和知识蒸馏(TCKD)和非目标分类知识蒸馏(NCKD)。重新配方使我们能够独立地研究这两部分的效果。

TCKD通过二元logit蒸馏传递知识,这意味这只提供目标类的预测,而每个非目标类的具体预测是未知的。一个合理的假设是,TCKD传递了关于训练样本“难易度”的知识,即知识描述了识别每个训练样本的难易程度。为了验证这一点,我们从三个方面设计实验来提高训练数据的“难度”,即更强的增强、更嘈杂的标签和具有固有挑战性的数据集。

NCKD只考虑非目标logit之间的知识。有趣的是,我们通过经验证明,仅应用NCKD就可以获得与经典KD相当甚至更好的结果,这表明非目标logit中包含的知识至关重要,这可能是突出的“暗知识”。

更重要的是,我们的重新表述表明,经典KD损失是一个高度耦合的表述(如图1b所示),这可能是logit蒸馏潜力有限的原因。首先,NCKD损失项被一个与教师对目标类别的预测置信度负相关的系数加权。因此较大的预测分数将导致较小的权重。这种耦合显著抑制了NCKD对良好预测训练样本的影响。这种抑制并不可取,因为教师对训练样本越有信息,可提供的知识越可靠越有价值。其次,TCKD和NCKD的意义是耦合的,即不允许分别对TCKD和NCKD进行加权。这种限制是不可取的,因为TCKD和NCKD应该分开考虑,因为它们的贡献来自不同的方面。

为了解决这些问题,我们提出了一种灵活高效的logit蒸馏方法,称为解耦知识蒸馏(DKD,图1b)DKD将NCKD损失从与教师置信度负相关的系数中解耦,将其替换为恒定值,从而提高了对预测良好的样本的蒸馏效率。同时,对NCKD和TCKD也进行了解耦,通过调整各部分权重,可以分别考虑NCKD和TCKD的重要性。

总的来说,我们的贡献总结如下:

(1)将经典的logit蒸馏分为TCKD和NCKD,为Logit蒸馏的研究提供了新的思路。

(2)我们揭示了由其高耦合公式引起的经典KD损失的局限性。NCKD与教师信心的耦合抑制了知识转移的有效性。TCKD与NCKD的耦合限制了平衡两部分的灵活性。

(3)为了克服这些局限性,我们提供了一种有效的logit蒸馏方法DKD。

重新思考知识蒸馏

在本节中,我们深入探讨知识蒸馏的机制。我们将KD损失重新表述为两部分的加权和,一部分与目标类相关,另一部分与目标类无关。我们探讨了知识蒸馏框架中每个部分的作用,并揭示了经典KD的一些局限性。受此启发,我们进一步提出了一种新的logit蒸馏方法,在各种任务上取得了显著的性能。

回顾KD

Notation对于第t类的训练样本,分类概率可以表示为P=image-20240303150007961,其中pi是第i类的概率,C是类的个数。p中的每个元素都可以通过softmax函数得到:
p i = e x p ( z i ) ∑ j = 1 C e x p ( z j ) p_i = \frac{exp(z_i)}{\sum_{j=1}^Cexp(z_j)} pi=j=1Cexp(zj)exp(zi)
其中zi代表第i类的对数。

为了区分于目标类相关和不相关的预测,我们定义了以下符号。b = image-20240303150447331表示目标类(pt)和其他所有非目标类(p\t)的二值概率,其计算公式为:

image-20240303150539198

同时,我们声明image-20240303150715008独立建模非目标类之间的概率(即,不考虑第t类)。每个元素的计算方法为:image-20240303150736308

Reformulation 在第一部分中,我们尝试用二元概率b和非目标类之间的概率p来重新表述KD。T和S分别表示老师和学生。经典KD使用kl散度作为损失函数,也可以写成2:

image-20240303151314489

根据等式(1)和等式(2)我们有image-20240303151721273,所以我们可以把等式(3)改写为:

image-20240303151806455

等式(4)可以改写为:

image-20240303151918101

如公式(5)所示,KD损失被重新表述为两项的加权和。image-20240303152823063表示目标类别的教师和学生的二元概率之间的相似度。因此,我们将其命名为目标类知识蒸馏(TCKD)。同时,image-20240303153038652表示非目标类中教师和学生概率的相似度,称为非目标类知识蒸馏(NVKD)。式(5)可以改写为:

image-20240303153129634

显然,NCKD的重建与image-20240303153158532是耦合的。

上述重新表述启发了我们对TCKD和NCKD的个体效应进行研究,这将揭示经典耦合表述的局限性。

TCKD和NCKD的影响

各部件的性能增益。我们分别研究了TCKD和NCKD对CIFAR-100的影响。选择ResNet、WideResNet(WRN)和ShuffleNet作为训练模型,其中考虑了相同和不同的架构。实验结果如表1,对于每个师生对,我们报告了(1)学生基线,(2)经典KD(其中同时使用TCKD和NCKD),(3)单一TCKD和(4)单一NCKD的结果。每个损失的权重设置为1.0(包括默认的交叉熵损失)。其它实现细节与第4节相同。

image-20240303155626665

直观地说,TCKD集中于与目标类相关的知识,因为相应的损失函数只考虑二进制概率。相反,NCKD侧重于非目标类别的知识。我们注意到单独使用TCKD对学生来说可能没有帮助(例如在ShufflerNet-V1上增加0.02%和0.12%)甚至是有害的(例如,在WRN-16-2上下降2.3%,在ResNet8-4上下降3.87%)。然而,NCKD的蒸馏性能与经典KD相当,甚至更好(例如,在ResNet8/4上,1.76% vs 1.13%)。消融结果表明靶类相关知识不如非靶类知识重要,为了深度研究这一现象,我们提供如下进一步的分析。

TCKD传递了关于训练样本“难度”的知识

根据等式(5),TCKD通过二值分类任务传递“暗知识”,这可能与样本的“难度“有关。例如,与image-20240303155803478的训练样本相比,image-20240303155813762的训练样本可能”更容易“让学生学习。由于TCKD传达了训练样本的“难度”,我们假设当训练数据变得具有挑战性时,有效性将被解释。然而,CIFRA-100训练集很容易过拟合。因此,教师提供的“难度”知识并不是信息性的。在这一部分中,我们从三个角度进行实验验证:训练数据越难,TCKD提供的好处越多。

(1)应用强增强是增加训练数据难度的一种直接方法。我们在CIFAR-100上使用AutoAugment训练ResNet32×4模型作为教师,获得了81.29%的top-1验证精度。对于学生,我们训练带/不带TCKD的ResNet8、4和ShufflerNetv1模型。表2的结果表明,如果应用强增强,TCKD可以获得显著的性能增益。

image-20240303161609391

(2)噪声标签也会增加训练数据的难度。我们在CIFAR-100上以{0.1,0.2,0.3}对称噪声比训练ResNet32×4模型作为教师,ResNet8×4模型作为学生,如下[7,35]。如表3所示,结果表明TCKD在噪声较大的训练数据上取得了更多的绩效提升。

image-20240303161939762

(3)挑战性的数据集(例如,ImageNet也被考虑。表4显示,TCKD可以在ImageNet上带来+0.32%的性能增益。

image-20240303162009395

最后,我们通过实验各种策略来增加训练数据的难度(如强增强、噪声标签、困难任务),证明了TCKD的有效性。结果证明,在提取更具挑战性的训练数据时,有关训练样本“难度“的知识可能更有用。

NCKD是logit蒸馏工作的重要原因,但受到很大的抑制。有趣的是,我们在表1中注意到,当仅应用NCKD时,性能与经典KD相当甚至更好。结果表明,非目标类的知识对logit蒸馏至关重要,可以成为突出的“暗知识”。然而,通过回顾方程(5),我们注意到NCKD损失与image-20240303162635731相耦合。其中,image-20240303162731869代表教师对目标类别的置信度。因此,更有置信度的预测会导致更小的NCKD权重。我们假设教师对训练样本越有信心,它提供的知识就越可靠,越有价值。然而,这种自信的预测高度抑制了损失权重。我们假设这一事实会限制知识蒸馏的有效性,这首先是由于我们在等式(5)中对KD的重新表述而研究的。

我们设计了一个消融实验来验证预测良好的样本确实比其他样本更好地传递知识。首先,我们根据image-20240303163021532对训练样本进行排序,并将其平均分成两个子集。为了清晰起见,一个子集包括image-20240303163212148前50%的样本,而其余样本在另一个子集中。然后,我们在每个子集上使用NCKD训练学生网络,以比较性能增益(而交叉熵损失仍然在整个集合上)。表5显示,在前50%的样本上使用NCKD可以获得更好的性能,这表明预测良好的样本的知识比其他样本更丰富。然而,预测良好的样本的损失权重被教师的高置信度所抑制。

image-20240303163329872

解耦知识蒸馏

至此,我们将经典KD损失重新表述为两个独立部分的加权和,进一步验证了TCKD的有效性,揭示了NCKD的抑制作用。具体来说,TCKD传递了关于训练样本“难度”的知识。TCKD可以在更具挑战性的训练数据上获得更显著的改进。NCKD在非目标类之间进行知识转移。当权重image-20240303163557024较小时,知识转移受到抑制。

本能地,TCKD和NCKD都是必不可少的,至关重要的。然而,在经典KD公式中,TCKD和NCKD从以下几个方面耦合。

(1)首先,NCKD与image-20240303163724801耦合,这可以抑制预测良好的样本上的NCKD。由于表5的结果表明,预测良好的样本可以带来更多的性能增益,因此耦合形式可能会限制NCKD的有效性。

(2)另一方面,在经典KD框架下,NCKD与TCKD的权重是耦合的。不允许为了平衡重要性而改变每个词的权重。我们认为TCKD和NCKD应该考虑他们的贡献来自不同的方面而分离。

基于我们对KD的重新表述,我们提出了一种新的logit蒸馏方法——解耦知识蒸馏(DKD)。我们提出的DKD在解耦公式中独立考虑了TCKD和NCKD。具体来说,我们分别引入了两个超参数作为TCKD和NCKD的权重,DKD的损失函数为:

image-20240303164149479

在DKD中,image-20240303164236706会抑制NCKD的有效性,使用image-20240303164247291代替。此外,还允许调整两个超参数以平衡TCKD和NCKD的重要性。DKD通过解耦NCKD和TCKD,为logit蒸馏提供了高效、灵活的方法。算法1提供了DKD的伪代码。

image-20240303164406442

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1486047.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaSec 基础之五大不安全组件

文章目录 不安全组件(框架)-Shiro&FastJson&Jackson&XStream&Log4jLog4jShiroJacksonFastJsonXStream 不安全组件(框架)-Shiro&FastJson&Jackson&XStream&Log4j Log4j Apache的一个开源项目,是一个基于Java的日志记录框架。 历史…

python学习笔记------元组

元组的定义 定义元组使用小括号,且使用逗号隔开各个数据,数据是不同的数据类型 定义元组字面量:(元素,元素,元素,......,元素) 例如:(1,"hello") 定义元组变量:变量名称(元素,元素,元素,......,元素)…

哈希表是什么?

一、哈希表是什么? 哈希表,也称为散列表,是一种根据关键码值(Key value)直接进行访问的数据结构。它通过把关键码值映射到表中一个位置来访问记录,从而加快查找速度。这个映射函数叫做散列函数&#xff08…

C#与VisionPro联合开发——单例模式

单例模式 单例模式是一种设计模式,用于确保类只有一个实例,并提供一个全局访问点来访问该实例。单例模式通常用于需要全局访问一个共享资源或状态的情况,以避免多个实例引入不必要的复杂性或资源浪费。 Form1 的代码展示 using System; usi…

初阶数据结构之---栈和队列(C语言)

引言 在顺序表和链表那篇博客中提到过,栈和队列也属于线性表 线性表: 线性表(linear list)是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结构。线性表在逻辑上是线性结构,也就是说是连…

c++之拷贝构造和赋值

如果一个构造函数中的第一个参数是类本身的引用,或者是其他的参数都有默认值,则该构造函数为拷贝构造函数。 那么什么是拷贝构造呢?利用同类对象构造一个新对象。 1,函数名和类必须同名。 2,没有返回值。 3&#x…

差分题练习(区间更新)

一、差分的特点和原理 对于一个数组a[],差分数组diff[]的定义是: 对差分数组做前缀和可以还原为原数组: 利用差分数组可以实现快速的区间修改,下面是将区间[l, r]都加上x的方法: diff[l] x; diff[r 1] - x;在修改完成后,需要做前缀和恢复…

4.关联式容器

关联式container STL中一些常见的容器: 序列式容器(Sequence Containers): vector(动态数组): 动态数组,支持随机访问和在尾部快速插入/删除。list(链表)&am…

奇舞周刊第521期:“一切非 Rust 项目均为非法”

奇舞推荐 ■ ■ ■ 拜登:“一切非 Rust 项目均为非法” 科技巨头要为Coding安全负责。这并不是拜登政府对内存安全语言的首次提倡。“程序员编写代码并非没有后果,他们的⼯作⽅式于国家利益而言至关重要。”白宫国家网络总监办公室(ONCD&…

Python3零基础教程之数学运算专题进阶

大家好,我是千与编程,今天已经进入我们Python3的零基础教程的第十节之数学运算专题进阶。上一次的数学运算中我们介绍了简单的基础四则运算,加减乘除运算。当涉及到数学运算的 Python 3 刷题使用时,进阶课程包含了许多重要的概念和技巧。下面是一个简单的教程,涵盖了一些常…

NOC2023软件创意编程(学而思赛道)python初中组决赛真题

目录 下载原文档打印做题: 软件创意编程 一、参赛范围 1.参赛组别:小学低年级组(1-3 年级)、小学高年级组(4-6 年级)、初中组。 2.参赛人数:1 人。 3.指导教师:1 人(可空缺)。 4.每人限参加 1 个赛项。 组别确定:以地方教育行政主管部门(教委、教育厅、教育局) 认…

嵌入式驱动学习第一周——linux的休眠与唤醒

前言 本文介绍进程的休眠与唤醒。 嵌入式驱动学习专栏将详细记录博主学习驱动的详细过程,未来预计四个月将高强度更新本专栏,喜欢的可以关注本博主并订阅本专栏,一起讨论一起学习。现在关注就是老粉啦! 行文目录 前言1. 阻塞和非阻…

Doris实战——美联物业数仓

目录 一、背景 1.1 企业背景 1.2 面临的问题 二、早期架构 三、新数仓架构 3.1 技术选型 3.2 运行架构 3.2.1 数据模型 纵向分域 横向分层 数据同步策略 3.2.2 数据同步策略 增量策略 全量策略 四、应用实践 4.1 业务模型 4.2 具体应用 五、实践经验 5.1 数据…

【Java EE】线程安全的集合类

目录 🌴多线程环境使用 ArrayList🎍多线程环境使⽤队列🍀多线程环境使⽤哈希表🌸 Hashtable🌸ConcurrentHashMap ⭕相关面试题🔥其他常⻅问题 原来的集合类, 大部分都不是线程安全的. Vector, Stack, HashT…

EndNote 21:文献整理与引用,一键轻松搞定 mac/win版

EndNote 21是一款功能强大的文献管理软件,专为学术研究者、学生和教师设计。它提供了全面的文献管理解决方案,帮助用户轻松整理、引用和分享学术文献。 EndNote 21软件获取 EndNote 21拥有直观的用户界面和强大的文献检索功能,用户可以轻松地…

昇腾ACL应用开发之硬件编解码dvpp

1.前言 在我们进行实际的应用开发时,都会随着对一款产品或者AI芯片的了解加深,大家都会想到有什么可以加速预处理啊或者后处理的手段?常见的不同厂家对于应用开发的时候,都会提供一个硬件解码和硬件编码的能力,这也是抛…

【C++干货基地】揭秘C++11常用特性:内联函数 | 范围for | auto自动识别 | nullptr指针空值

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 引入 哈喽各位铁汁们好啊,我是博主鸽芷咕《C干货基地》是由我的襄阳家乡零食基地有感而发,不知道各位的…

基于springboot实现校园爱心捐赠互助管理系统项目【项目源码+论文说明】计算机毕业设计

基于springboot实现校园爱心捐赠互助管理系统演示 摘要 随着互联网及电子商务平台的飞速发展,利用在线平台实现的二手商品交易以及在线捐赠已经非常普遍,很多高校目前还存在贫困生需要通过爱心人士的捐助来完成学业,同时很多高校的大学生也希…

【C++】STL学习之旅——初识STL,认识string类

string类 1 STL 简介2 STL怎么学习3 STL缺陷4 string4.1 初识 string4.2 初步使用构造函数成员函数 5 小试牛刀Thanks♪(・ω・)ノ谢谢阅读!!!下一篇文章见!!! 1 STL 简介 …

PyCharm如何添加python库

1.使用pip命令在国内源下载需要的库 下面使用清华源,在cmd中输入如下命令就可以了 pip install i https://pypi.tuna.tsinghua.edu.cn/simple 包名版本号2.如果出现报错信息,Cannot unpack file…这种情况,比如下面这种 ERROR: Cannot unpa…