Channel-wise Knowledge Distillation for Dense Prediction(ICCV 2021)原理与代码解析

news2024/12/23 1:07:46

paper:Channel-wise Knowledge Distillation for Dense Prediction

official implementation:https://github.com/irfanICMLL/TorchDistiller/tree/main/SemSeg-distill 

摘要 

之前大多数用于密集预测dense prediction任务的蒸馏方法在空间域spatial domain中将教师和学生网络的激活图进行对齐,通过标准化normalize每个空间位置的激活值,并减小point-wise and/or pair-wise之间的差异来实现知识的传递。

和之前的方法不同,本文提出对每个通道的激活图进行标准化从而得到一个soft probabality map,通过减小两个网络channel-wise概率图之间的KL散度,使蒸馏过程更关注每个通道最显著的区域,这对密集预测任务非常有价值。

背景

密集预测任务是像素级的预测问题,比图像级的分类问题更具挑战性。以往的研究发现,直接将分类中的蒸馏方法应用到语义分割中得到的效果令人无法满意。严格地对齐教师和学生网络之间poit-wise分类得分或特征图可能会施加过于严格的约束,只能得到次优解。

最近的一些研究主要关注于加强不同空间位置之间的相关性。如图2(a)所示,每个空间位置上的激活值被标准化然后通过聚合不同空间位置的子集来得到一些特定任务的关系,比如pair-wise relations和inter-class relations。这些方法在捕获空间结构信息和提高学生网络的性能方面可能比point-wise对齐效果更好。然而,激活图中的每个空间位置对知识转移的贡献相同,这可能会从教师网络带来冗余的信息。 

本文提出了一种新的通道级channel-wise知识蒸馏方法,通过对每个通道中的激活图进行标准化,用于密集预测任务,如图2(b)所示。然后减小教师和学生网络标准化后的通道激活图之间的KL散度。一个channel-wise distribution的例子如图2(c)所示,可以看出每个通道的激活图倾向于编码特定场景类别的显著区域。对于每个通道,引导学生网络更加关注于模拟激活值大的区域,从而在密集预测任务中实现更准确的定位。比如在目标检测任务中,让学生网络更关注于学习前景区域的激活。

本文的贡献

  • 与现有的spatial蒸馏方法不同,本文提出了一种新的通道蒸馏范式用于密集预测任务,方法简单有效。
  • 在语义分割和目标检测方面,本文提出的通道级蒸馏方法显著优于最先进的KD方法。
  • 我们在语义分割和目标检测任务上的四个基准数据集上进行了一致的改进,证明了我们的方法是通用的。鉴于它的简单性和有效性,我们相信我们的方法可以作为一个strong baseline KD方法用于密集预测任务。

方法介绍

为了更好的利用每个通道中的知识,作者提出softly对齐教师网络和学生网络对应通道的激活。为此,首先将一个通道的激活转换为概率分布,这样就可以使用一个概率距离度量例如KL散度来衡量差异。如图2(c)所示,不同通道的激活倾向于编码输入图像中某个特定类别场景的显著saliency区域。此外,一个训练好的语义分割教师模型在每个通道显示出了清晰的类别特定掩码的激活图,这是符合预期的,如图1右侧所示。因此作者提出了一种新的通道蒸馏范式来指导学生从一个训练有素的教师那里学习知识。

首先定义教师网络和学生网络分别为 \(T\) 和 \(S\),\(T\) 和 \(S\) 的激活分别表示为 \(y^{T}\) 和 \(y^{S}\),通道蒸馏损失的一般表示形式如下

其中 \(\phi(\cdot)\) 用于将激活值转换为概率分布,如下

其中 \(c=1,2,...,C\) 表示通道,\(i\) 示一个通道的spatial location的索引,\(\mathcal{T}\) 是温度超参。如果我们使用更大的 \(\mathcal{T}\),概率分布会变得更softer,意味着每个通道中关注的spatial region更加wider。通过使用softmax归一化,消除了大网络和小网络之间尺度的差异。如果教师网络和学生网络的通道不匹配,则使用1x1卷积上采样小网络的通道数使两者相等。\(\varphi(\cdot)\) 评估教师模型和学生模型通道分布之间的差异,具体使用KL散度

KL散度是一种非对称度量。从式4可以看出,如果 \(\phi(y^{T}_{c,i})\) 很大,\(\phi(y^{S}_{c,i})\) 应该像 \(\phi(y^{T}_{c,i})\) 一样大从而减小KL散度。相反,如果 \(\phi(y^{T}_{c,i})\) 非常小,KL散度对减小 \(\phi(y^{S}_{c,i})\) 的关注相对较少。通过教师网络的监督,学生网络倾向于在前景显著区域产生和教师网络相似的激活分布,而教师网络背景区域中的激活对学生网络的影响较小。作者认为KL这种不对称性有助于密集预测任务中蒸馏的学习。

和分类任务中的channel-wise蒸馏的区别

在Channel Distillation: Channel-Wise Attention for Knowledge Distillation这篇文章中,作者也提出了使用通道蒸馏的方式,但主要应用于分类任务。受SENet的启发,通过全局平局池化将一个通道的特征图转换为一个标量,然后应用KL散度衡量教师网络和学生网络对应通道的标量之间的差异。

而本文主要考虑的是密集预测任务,GAP或许对image-level的分类任务有帮助,但所有空间位置的权重相同,丢失了空间信息,不适用于密集预测任务。本文通过softmax标准化的方式,考虑到了不同空间位置重要性的不同,保留了空间位置信息,因此更适用于密集预测任务。

实验结果

表2是本文提出的channel蒸馏和其它spatial蒸馏在Cityscapes数据集上复杂度和验证集上的mIoU的对比,可以看出通道蒸馏的精度最高,且复杂度较小

表5是在Cityscapes数据集上,不同的学生模型用不同的蒸馏方法的精度对比,可以看出,对于不同结构的学生网络,本文提出的通道蒸馏的效果都要好于其它蒸馏方法。

 

表6是在目标检测任务上与其它蒸馏方法的对比,可以看出,在两阶段、单阶段、anchor-free不同结构的目标检测模型中,本文提出的通道蒸馏的效果都是最好的。

代码解析

官方实现如下,其中归一化采用channel_norm损失采用KL损失是论文中给出的方法,官方实现中还给出了其它的归一化方法和损失函数的选择。 

import torch.nn as nn


class ChannelNorm(nn.Module):
    def __init__(self):
        super(ChannelNorm, self).__init__()

    def forward(self, featmap):
        n, c, h, w = featmap.shape
        featmap = featmap.reshape((n, c, -1))
        featmap = featmap.softmax(dim=-1)
        return featmap


class CriterionCWD(nn.Module):

    def __init__(self, norm_type='none', divergence='mse', temperature=1.0):

        super(CriterionCWD, self).__init__()

        # define normalize function
        if norm_type == 'channel':
            self.normalize = ChannelNorm()
        elif norm_type == 'spatial':
            self.normalize = nn.Softmax(dim=1)
        elif norm_type == 'channel_mean':
            self.normalize = lambda x: x.view(x.size(0), x.size(1), -1).mean(-1)
        else:
            self.normalize = None
        self.norm_type = norm_type

        self.temperature = 1.0

        # define loss function
        if divergence == 'mse':
            self.criterion = nn.MSELoss(reduction='sum')
        elif divergence == 'kl':
            self.criterion = nn.KLDivLoss(reduction='sum')
            self.temperature = temperature
        self.divergence = divergence

    def forward(self, preds_S, preds_T):

        n, c, h, w = preds_S.shape
        # import pdb;pdb.set_trace()
        if self.normalize is not None:
            norm_s = self.normalize(preds_S / self.temperature)
            norm_t = self.normalize(preds_T.detach() / self.temperature)
        else:
            norm_s = preds_S[0]
            norm_t = preds_T[0].detach()

        if self.divergence == 'kl':
            norm_s = norm_s.log()
        loss = self.criterion(norm_s, norm_t)

        # item_loss = [round(self.criterion(norm_t[0][0].log(),norm_t[0][i]).item(),4) for i in range(c)]
        # import pdb;pdb.set_trace()
        if self.norm_type == 'channel' or self.norm_type == 'channel_mean':
            loss /= n * c
            # loss /= n * h * w
        else:
            loss /= n * h * w

        return loss * (self.temperature ** 2)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/477649.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(求正数数组的最小不可组成和,养兔子)笔试强训

博主简介:想进大厂的打工人博主主页:xyk:所属专栏: JavaEE初阶 目录 文章目录 一、选择题1 二、[编程题]养兔子 三、[编程题]求正数数组的最小不可组成和 一、选择题1 reflection是如何工作的__牛客网 (nowcoder.com) 考虑下面这个简单的例子&…

大数据Doris(八):Broker部署和集群启停脚本

文章目录 Broker部署和集群启停脚本 一、Broker部署 1、准备Broker 安装包 2、启动 Broker

PyQt6剑指未来-日期和时间

前言 时间和日期是软件开发中非常重要的概念。在PyQt6中,时间和日期模块提供了处理日期、时间和日期时间的类和函数,以及管理时区和夏令时的特性。这些模块提供了可靠和易于使用的工具,使得在PyQt6中处理和呈现时间和日期的操作变得轻松起来…

Java中Lambda表达式(初学到精通)

目录 一、Lambda表达式是什么?什么场景下使用Lambda? 1.Lambda 表达式是什么 2.函数式接口是什么 第二章、怎么用Lambda 1.必须有一个函数式接口 2.省略规则 3.Lambda经常用来和匿名内部类比较 第三章、具体使用举例() 1.案…

跳跃游戏类题目 总结篇

一.跳跃游戏类题目简单介绍 跳跃游戏是一种典型的算法题目,经常是给定一数组arr,从数组的某一位置i出发,根据一定的跳跃规则,比如从i位置能跳arr[i]步,或者小于arr[i]步,或者固定步数,直到到达某…

C++ 链表概述

背景 当需要存储大量数据并需要对其进行操作时,常常需要使用到链表这种数据结构。它可以用来存储一系列的元素并支持插入、删除、遍历等操作。 概念 一般来说,链表是由若干个节点组成的,每个节点包含了两个部分的内容:存储的数…

【嵌入式环境下linux内核及驱动学习笔记-(6-内核 I/O)-阻塞与非阻塞】

目录 1、阻塞与非阻塞1.1 以对recvfrom函数的调用及执行过程来说明阻塞的操作。1.2 以对recvfrom函数的不断轮询调用为例,说明非阻塞时进程的行为。1.3 简单介绍内核链表及等待队列1.4 等待队列1.4.1 定义等待队列头部(wait_queue_head_t)1.4…

vue动态添加多组数据添加正则限制

如图新增多条数据,如果删除其中一条正则校验失败的数据,提示不会随之删除,若想提示删除并不清空数据 delete (item, index) {this.applicationForm.reserveInfo.forEach((v, i) > {if (i index) {this.$refs.formValidate.fields.forEac…

UFT——操作模块

示例一 创建一个可重复利用的登录测试更改Action的名称。使用本地数据表。创建一个主调用测试。建立测试迭代。处理缺失的Action。 分析:就是创建一个只有登录的测试起名为login,然后在创建一个主测试起名字比如main,在main中,调用…

微信小程序定义模板

微信小程序提供模板(template)功能,把一些可以共用的,复用的代码在模板中定义为代码片段,然后在不同的地方调用,可以实现一次编写,多次引用的效果。 首先我们看一下官网是如何操作的 一般的情…

笔记:对多维torch进行任意维度的多“行”操作

如何取出多维torch指定维度的指定“行” 从二维torch开始新建torch取出某一行取出某一列一次性取出多行取出连续的多行取出不连续的多行 一次取出多列取出连续的多列取出不连续的多列 考虑三维torch取出三维torch的任意两行(means 在dim0上操作)取出连续…

( 字符串) 9. 回文数 ——【Leetcode每日一题】

❓9. 回文数 难度:简单 给你一个整数 x ,如果 x 是一个回文整数,返回 true ;否则,返回 false 。 回文数是指正序(从左向右)和倒序(从右向左)读都是一样的整数。 例如…

Git的安装与使用+Git在IDEA中的使用

文章目录 一、Git概述1、版本控制器的方式2、Git的工作流程图 二、Git的安装与常用命令1、Git环境安装2、Git环境基本配置3、获取本地仓库4、基础操作指令 三、分支1、常用指令2、解决合并冲突 四、Git远程仓库1、创建远程仓库2、远程操作仓库3、冲突处理 四、IDEA中使用Git1、…

数据结构——二叉树

二叉树 1 二叉树的种类 1.1 满二叉树 节点数量为 2^k - 1 (k是树的深度,底层的叶子节点都是满的) 1.2 完全二叉树 完全二叉树是指除了下面一层外,其余层的节点都是满的; 且最下面一层的叶子节点是从左到右连续的。 下面这个…

pci总线协议学习笔记——PCI总线基本概念

1、pci总线概述 (1)PCI,外设组件互连标准(Peripheral Component Interconnection),是一种由英特尔(Intel)公司1991年推出的用于定义局部总线的标准; (2)最早提出的PCI总线工作在33MHz频率之下,传输带宽达到133MB/s(33M…

【LeetCode】236. 二叉树的最近公共祖先

1.问题 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为:“对于有根树 T 的两个节点 p、q,最近公共祖先表示为一个节点 x,满足 x 是 p、q 的祖先且 x 的深度尽可能大(一个节点也可以是…

1992-2022年31省GDP、第一产业增加值、第二产业增加值 第三产业增加值

1992-2022年31省GDP、第一产业增加值、第二产业增加值 第三产业增加值 1、时间:1992-2022年 2、范围:包括31省 3、指标:省GDP、省第一产业增加值、省第二产业增加值、省第三产业增加值 4、缺失情况说明:无缺失 5、来源&#…

【python知识】__init__.py的来龙去脉

一、说明 我们常见__init__.py文件,但说不清楚它的用途,在本文,我将首先把它的来龙去脉说清楚,然后告诉大家,如何编制python工程,培养全局的编程格局。 二、包-模块-函数结构 在Python工程里,当…

playwright连接已有浏览器操作

文章目录 playwright连接已有浏览器操作前置准备打开本地已有缓存的Chrome(理解)指定端口打开浏览器连接指定端口已启动浏览器(推荐) playwright连接已有浏览器操作 前置准备 pip install playwright # 安装playwright的python…