Decoupled Knowledge Distillation(CVPR 2022)原理与代码解析

news2025/1/13 7:33:26

paper:Decoupled Knowledge Distillation

code:https://github.com/megvii-research/mdistiller/blob/master/mdistiller/distillers/DKD.py

背景

与基于响应logits-based的蒸馏方法相比,基于特征feature-based的蒸馏方法在各种任务上的表现更好,因此对基于响应的知识蒸馏的研究越来越少。然而,基于特征的方法的训练成本并不令人满意,因为在训练期间引入了额外的计算和存储占用(如网络模块和复杂的操作)来提取特征。基于响应的蒸馏所需的计算和存储都较小,但性能较差。直觉上来说,logit-based蒸馏方法应当达到与feature-based方法相当的性能,因为logits处与更深的层有更丰富的语义特征。作者猜测logit-based蒸馏的性能受到了未知原因的限制,导致表现不理想。

本文的创新点

本文作者深入研究了KD的作用机制,将分类预测拆分为两个层次:(1)对目标类和所有非目标类进行二分类预测。(2)对每个非目标类进行多分类预测。进而将原始的KD损失也拆分为两部分,一种是针对目标类的二分类蒸馏,另一种是针对非目标类的多分类蒸馏。并分别称为target classification knowledge distillation(TCKD)和non-target classification knowledge distillation(NCKD)。通过分别单独研究两部分对性能的影响,作者发现NCKD中包含了重要的知识,而原始KD对两部分耦合的方式抑制了NCKD的作用,也限制了平衡这两部分的灵活性。

为了解决这些问题,本文提出了一种新的logit蒸馏方法Decoupled Knowledge Distillation(DKD),将TCKD和NCKD进行解耦,使得它们之间的权重可调,从而解除了对NCKD的抑制,提升了蒸馏的性能。

方法介绍

Reformulating KD

Notions

对于一个属于第 \(t\) 类的样本,分类概率可以表示为 \(\mathbf{p}=[p_{1},p_{2},...,p_{t},...,p_{C}]\in \mathbb{R}^{1\times C}\),其中 \(p_{i}\) 是第 \(i\) 类的概率,\(C\) 是类别数。\(\mathbf{p}\) 中的每个元素都可以通过softmax函数得到

其中 \(z_{i}\) 表示第 \(i\) 类的logit。

为了将与目标类相关和无关的预测分开,定义 \(\mathbf{b}=[p_{t},p_{\setminus t}]\in \mathbb{R}^{1\times 2}\) 表示二分类概率,其中 \(p_{t}\) 表示目标类的概率,\(p_{\setminus t}\) 表示非目标类的概率(所有其它类的概率和),可按下式分别计算得到

同时定义 \(\hat{\mathbf{p}}=[\hat{p}_{1},...,\hat{p}_{t-1},\hat{p}_{t+1},...,\hat{p}_{C}]\in \mathbb{R}^{1\times (C-1)}\) 来单独建模非目标类别的概率(即不考虑第 \(t\) 类),其中每个元素按下式得到

Reformulation

\(\mathcal{T}\) 和 \(\mathcal{S}\) 分别表示教师和学生网络,根据上面定义的二分类概率 \(\mathbf{b}\) 和非目标类的多分类概率 \(\hat{\mathbf{p}}\),原始KD中的KL散度损失函数可以重写成下面的形式

根据式(1)和(2),我们有 \(\hat{p}_{i}=p_{i}/p_{\setminus t}\),式(3)可以重写成如下

然后式(4)又可以重写成如下

这里根据式(1)(2)(3)推导式(4)(5)的具体过程如下

由式(5)可以看出,KD loss可以看作两项的加权和,其中第一项表示教师和学生网络对目标类别预测概率之间的相似性,因此称之为Target Class Knowledge Distillation(TCKD)。第二项表示教师和学生网络对非目标类别预测概率之间的相似性,称为Non-Target Class Knowledge Distillation(NCKD)。因此式(5)可以重写成如下

显然,NCKD和 \(p_{t}^{\mathcal{T}}\) 是耦合的。

Effects of TCKD and NCKD

Performance gain of each part

作者在CIFAR-100数据集上分别研究了TCKD和NCKD的影响,结果如下表所示,可以看出,单独使用TCKD对学生模型的提升非常小甚至还会降低精度,而单独使用NCKD可以得到与完整KD相似甚至更高的精度,由此可以看出相比于TCKD,NCKD对学生网络精度的提升更加重要。

TCKD transfers the knowledge concerning the “difficulty” of training samples.

根据式(5)推测TCKD可能将关于样本“难度”的知识传递给了学生网络,例如,相比于 \(p_{t}^{\mathcal{T}}=0.75\) 的样本 \(p_{t}^{\mathcal{T}}=0.99\) 的样本对学生网络来说是更容易学习的样本。由于TCKD传递了样本的难度知识,推测当训练样本更难时TCKD的有效性就会彰显出来,因为CIFAR-100的数据比较简单,TCKD包含的难度知识也相对较少,因此作者通过三个角度进行实验,来验证观点:训练样本越难,TCKD提供的难度知识就越有用。

数据增强是一种增加训练样本难度很直接的方法,作者对CIFAR-100进行了AutoAugment增强,然后进行蒸馏的结果如下所示,可以看出进行数据增强后,TCKD对性能的提升更加明显。

噪声标签也会增加数据的训练难度,对数据添加噪声标签后结果如下所示,结果表明TCKD在噪声更大的训练数据上获得了更大的性能提升。

作者还考虑了更难的数据集比如ImageNet,在ImageNet上TCKD获得了0.32的性能提升。

通过上述实验,作者证明了TCKD在困难数据上的有效性,当在更困难的样本上进行蒸馏时,关于样本难度的知识更有用。

NCKD is the prominent reason why logit distillation works but is greatly suppressed.

从表(1)中可以看出单独使用NCKD时其性能和完整的KD相当甚至更好,这表明非目标类别的知识对logit蒸馏至关重要。但是从式(5)可以看出,NCKD和 \((1-p_{t}^{\mathcal{T}})\) 耦合,\(p_{t}^{\mathcal{T}}\) 表明教师对目标类别的置信度,因此置信度越高会导致NCKD的权重越小。作者认为教师模型对训练样本的置信度越高,它所能提供的知识应该越可靠越有价值,但实际上高置信度确抑制了损失的权重,因此作者将logit蒸馏性能不高的原因归结为原始的KD损失对NCKD的抑制。

作者设计了一个消融实验来验证预测准确即置信度高的样本确实比置信度低的样本包含更有用的知识。首先根据 \(p_{t}^{\mathcal{T}}\) 对训练样本进行排序,将其均分为两个子集,一个子集包含了 \(p_{t}^{\mathcal{T}}\) 前50%的样本,另一个子集包含 \(p_{t}^{\mathcal{T}}\) 后50%的样本。然后在每个子集上用NCKD训练学生网络来比较性能的增益。结果如下表所示,可以看出,对 \(p_{t}^{\mathcal{T}}\) 50%的样本使用NCKD获得了更好的性能,表明了预测准确的样本确实包含了更丰富的知识。

Decoupled Knowledge Distillation

针对上述问题,作者提出了解耦知识蒸馏Decoupled Knowledge Distillation(DKD),如下所示

具体来说,引入了超参 \(\alpha\) 和 \(\beta\) 分别作为TCKD和NCKD的权重。

实验结果

下表是采用不同的 \(\alpha\) 和 \(\beta\) 时学生网络的精度,表1中 \(\alpha\) 固定为1.0,表2中 \(\beta\) 固定为8.0。从结果可以看出解耦 \((1-p_{t}^{\mathcal{T}})\) 和NCKD可以带来显著的性能提升(73.64% vs. 74.79%),解耦TCKD和NCKD的权重获得了进一步的性能提升(74.79% vs. 76.32%)。第二个表表明TCKD是不可或缺的,同时当 \(\alpha\) 在1.0附近波动时,TCKD的提升比较稳定没有太大的波动。

下表是在CIFAR-100验证集上的结果,其中 \(\alpha\) 固定为1,对于不同的教师模型 \(\beta\) 值不同,具体后面会讲。

下面是在ImageNet上的结果

Guidance for tuning \(\beta\)

作者认为NCKD在知识传递中的重要性与教师网络的信心有关,教师网络越有信息,NCKD的重要性就越大,\(\beta\) 值就应该越大。如果目标类的logit值远大于所有非目标类,那么可以认为教师非常有信心,\(\beta\) 值也应该设置的更大。因此作者假定 \(\beta\) 值与目标类和所有非目标类之间的logit差有关。目标类的logit用 \(z_{t}\) 表示,其中 \(t\) 表示目标类别,\(z_{max}\) 表示所有非目标类的logit的最大值即 \(z_{max}=max(\left \{ z_{i}|i\ne t \right \} )\)。

作者选用ShuffleNet-v1作为学生网络,比较了选用不同的教师网络和不同的 \(\beta\) 值的精度,并且给出了所有训练样本上 \(z_{t}-z_{max}\) 的均值,结果如下

从结果可以看出最优的 \(\beta\) 值与 \(z_{t}-z_{max}\) 成正相关的关系。基于此,表6和表7中不同的教师网络对应的 \(\beta\) 值如下

代码解析

下面是官方实现,其中函数_get_gt_mask中tensor.scatter_()的用法具体见Torch.Tensor.scatter_( ) 用法解读_00000cj的博客-CSDN博客。在求nckd的输入pred_teacher_part2和log_pred_student_part2中都有一个- 1000.0 * gt_mask的操作,这里官方在issue里有解答https://github.com/megvii-research/mdistiller/issues/1,原本的应该是logits[1-gt_mask] / temperature计算所有非目标类别的softmax,因为这里index操作比较慢,因此改成logits/temperature - 1000 * gt_mask,gt_mask中非目标类别处全为0,因此相当于没减。目标类别的logit减去了1000,相当于softmax中分子和分母各加上 \(e^{-1000}\) 约等于0,等价于没加。

import torch
import torch.nn as nn
import torch.nn.functional as F

from ._base import Distiller


def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
    # (64,100),(64,100),(64),1,8,4
    gt_mask = _get_gt_mask(logits_student, target)  # (64,100),除了每个样本对应target索引处为True, 其它都为False
    other_mask = _get_other_mask(logits_student, target)
    pred_student = F.softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    pred_student = cat_mask(pred_student, gt_mask, other_mask)  # (64,2), 第一列是目标类别的logit, 第二列是所有非目标类别的logit的和
    pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
    log_pred_student = torch.log(pred_student)
    tckd_loss = (
        F.kl_div(log_pred_student, pred_teacher, size_average=False)
        * (temperature**2)
        / target.shape[0]
    )
    # https://github.com/megvii-research/mdistiller/issues/1
    # e^{-1000}非常小约等于0,等价于把这一项去掉了
    pred_teacher_part2 = F.softmax(
        logits_teacher / temperature - 1000.0 * gt_mask, dim=1
    )
    log_pred_student_part2 = F.log_softmax(
        logits_student / temperature - 1000.0 * gt_mask, dim=1
    )
    nckd_loss = (
        F.kl_div(log_pred_student_part2, pred_teacher_part2, size_average=False)
        * (temperature**2)
        / target.shape[0]
    )
    return alpha * tckd_loss + beta * nckd_loss


def _get_gt_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    return mask


def _get_other_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
    return mask


def cat_mask(t, mask1, mask2):
    t1 = (t * mask1).sum(dim=1, keepdims=True)  # (64,1)
    t2 = (t * mask2).sum(1, keepdims=True)  # (64,1)
    rt = torch.cat([t1, t2], dim=1)  # (64,2)
    return rt


class DKD(Distiller):
    """Decoupled Knowledge Distillation(CVPR 2022)"""

    def __init__(self, student, teacher, cfg):
        super(DKD, self).__init__(student, teacher)
        self.ce_loss_weight = cfg.DKD.CE_WEIGHT
        self.alpha = cfg.DKD.ALPHA
        self.beta = cfg.DKD.BETA
        self.temperature = cfg.DKD.T
        self.warmup = cfg.DKD.WARMUP

    def forward_train(self, image, target, **kwargs):
        logits_student, _ = self.student(image)
        with torch.no_grad():
            logits_teacher, _ = self.teacher(image)

        # losses
        loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
        loss_dkd = min(kwargs["epoch"] / self.warmup, 1.0) * dkd_loss(
            logits_student,
            logits_teacher,
            target,
            self.alpha,
            self.beta,
            self.temperature,
        )
        losses_dict = {
            "loss_ce": loss_ce,
            "loss_kd": loss_dkd,
        }
        return logits_student, losses_dict

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/388018.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【教学典型案例】14.课程推送页面整理-增加定时功能

目录一:背景介绍1、代码可读性差,结构混乱2、逻辑边界不清晰,封装意识缺乏![在这里插入图片描述](https://img-blog.csdnimg.cn/bbfc5f04902541db993944ced6b62793.png)3、展示效果不美观二:案例问题分析以及解决过程1、代码可读性…

现代操作系统——Linux架构与学习

小白的疑惑 在我决定从事嵌入式(应用层)方面的工作时,我查询了大量资料该如何学习,几乎所有观点不约而同的都指向了学习好Linux,大部分工作都是在Linux环境下来进行工作的。于是我雄心勃勃的去下载Linux,可…

GEE开发之降雨(CHIRPS)数据获取和分析

GEE开发之降雨CHIRPS数据获取和分析1.数据介绍2.初识CHIRPS2.1 代码一2.2 代码二3.逐日数据分析和获取4.逐月数据分析和获取4.1 代码一4.2 代码二(简洁)5.逐年数据分析和获取5.1 代码一5.2 代码二(简洁)前言:主要获取和分析UCSB-CHG/CHIRPS/DAILY的日数据、月数据和…

一文带你入门,领略angular风采(上)!!!

话不多说,上代码!!! 一、脚手架创建项目 1.安装脚手架指令 npm install -g angular/cli 2.创建项目 ng new my-app(ng new 项目名) 3.功能选择 4.切换到创建好的项目上 cd my-app 5.安装依赖 npm install 6.运行项目 npm start或…

32 openEuler使用LVM管理硬盘-管理卷组

文章目录32 openEuler使用LVM管理硬盘-管理卷组32.1 创建卷组32.2 查看卷组32.3 修改卷组属性32.4 扩展卷组32.5 收缩卷组32.6 删除卷组32 openEuler使用LVM管理硬盘-管理卷组 32.1 创建卷组 可在root权限下通过vgcreate命令创建卷组。 vgcreate [option] vgname pvname ...…

曹云金郭德纲关系迎曙光,新剧《猎黑行动》被德云社弟子齐点赞

话说天下大势,分久必合,合久必分。这句话经过了历史的证明,如今依然感觉非常实用。 就拿郭德纲和曹云金来说,曾经后者是前者的得门生,两个人不但情同父子,曹云金还是郭德纲默认接班人。然而随着时间的流逝&…

数据库基本概念及常见的数据库简介

数据库基本概念 【1】数据库基本概念 (1)数据 所谓数据(Data)是指对客观事物进行描述并可以鉴别的符号,这些符号是可识别的、抽象的。它不仅仅指狭义上的数字,而是有多种表现形式:字母、文字…

设计模式-策略模式

前言 作为一名合格的前端开发工程师,全面的掌握面向对象的设计思想非常重要,而“设计模式”是众多软件开发人员经过相当长的一段时间的试验和错误总结出来的,代表了面向对象设计思想的最佳实践。正如《HeadFirst设计模式》中说的一句话&…

【Verilog】——模块,常量,变量

目录 1.模块 1.描述电路的逻辑功能 2. 门级描述 3.模块的模板​编辑 2.关键字 3.标识符 4.Verilog源代码的编写标准 5.数据类型 1.整数常量​ 2.参数传递的两种方法 3.变量 4.reg和wire的区别 5.沿触发和电平触发的区别​ 6.memory型变脸和reg型变量的区别​ 1.模块 1.描…

Mybatis一级缓存与二级缓存

一、MyBatis 缓存缓存就是内存中的数据,常常来自对数据库查询结果的保存。使用缓存,我们可以避免频繁与数据库进行交互,从而提高响应速度。MyBatis 也提供了对缓存的支持,分为一级缓存和二级缓存,来看下下面这张图&…

docker安装即docker连接mysql(window)

一 安装docker 1.什么是docker Docker容器与虚拟机类似,但二者在原理上不同。容器是将操作系统层虚拟化,虚拟机则是虚拟化硬件,因此容器更具有便携性、高效地利用服务器。 2.WSL2 WSL,即Windows Subsystem on Linux,中…

JavaScript高级 XHR - Fetch

1. 前端数据请求方式 早期的网页都是通过后端渲染来完成的:服务器端渲染(SSR,server side render) 客户端发出请求 -> 服务端接收请求并返回相应HTML文档 -> 页面刷新,客户端加载新的HTML文档 当用户点击页面中…

C++:哈希:闭散列哈希表

哈希的概念 哈希表就是通过哈希映射,让key值与存储位置建立关联。比如,一堆整型{3,5,7,8,2,4}在哈希表的存储位置如图所示: 插入数据的操作: 在插入数据的时候,计算数据相应的位置并进行插入。 查找数据的操作&…

从企业数字化发展的四个阶段,看数字化创新战略

《Edge: Value-Driven Digital Transformation》一书根据信息技术与企业业务发展的关系把企业的数字化分为了四个阶段: 技术与业务无关技术作为服务提供者开始合作科技引领差异化优势以技术为业务核心 下图展示了这四个阶段的特点: 通过了解和分析各个…

[ant-design-vue] tree 组件功能使用

[ant-design-vue] tree 组件功能使用描述环境信息相关代码参数说明描述 是希望展现一个树形的菜单,并且对应的菜单前有复选框功能,但是对比官网的例子,我们在使用的过程中涉及到对半选中情况的处理: 半选中状态: 选中…

NodeJS安装

一、简介Node.js是一个让JavaScript运行在服务端的开发平台,Node.js不是一种独立的语言,简单的说 Node.js 就是运行在服务端的 JavaScript。npm其实是Node.js的包管理工具(package manager),类似与 maven。二、安装步骤…

并发下的可见性、原子性、有序性还不懂?

CPU、内存、I/O速度大比拼CPU的读写速度是内存的100倍左右,而内存的读写速度又是I/O的10倍左右。根据"木桶理论",速度取决于最慢的I/O。为了解决速度不匹配的问题,通常在CPU和主内存间增加了缓存,内存和I/O之间增加了操…

C语言学习之路--操作符篇,从知识到实战

目录一、前言二、操作符分类三、算术操作符四、移位操作符1、左移操作符2、右移操作符五、位操作符拓展1、不能创建临时变量(第三个变量),实现两个数的交换。2、编写代码实现:求一个整数存储在内存中的二进制中1的个数。六、赋值操…

http客户端Feign

Feign替代RestTemplate RestTemplate方式调用存在的缺陷 String url"http://userservice/user/"order.getUserId();User user restTemplate.getForObject(url, User.class); 代码可读性差,变成体验不统一; 参数复杂的时候URL难以维护。 &l…

Gem5模拟器,一些运行的小tips(十一)

一些基础知识,下面提到的东西与前面的文章有一定的关系,感兴趣的小伙伴可以看一下: (21条消息) Gem5模拟器,全流程运行Chiplet-Gem5-SharedMemory-main(十)_好啊啊啊啊的博客-CSDN博客 Gem5模拟器&#xf…