Channel Distillation: Channel-Wise Attention for Knowledge Distillation 原理与代码解析

news2025/1/11 10:54:29

paper:Channel Distillation: Channel-Wise Attention for Knowledge Distillation

official implementation:https://github.com/zhouzaida/channel-distillation 

存在的问题

  1. 教师模型传递的知识不够好,学生模型无法准确地从教师模型学习到最重要的信息。
  2. 教师模型的预测并不完全是正确的,在训练过程中,如果完全参考教师模型的输出,教师模型的错误预测会对学生模型产生不利的影响。
  3. 教师模型和学生模型结构不同,如果总是让教师模型来监督学生模型,就会使学生模型无法找到自己的优化空间。 

本文的创新点

为了解决上述三个问题,本文提出了一些优化方法

  1. 提出了一种新的知识蒸馏方法,通过通道蒸馏(Channel-Wise Distillation, CD)将知识传递给学生,以便学生可以更高效的提取特征。
  2. 为了避免教师模型的错误输出对学生的负面影响,提出了引导知识蒸馏(Guided Knowledge Distillation, GKD),只使用教师模型的正确输出来对学生模型进行监督。
  3. 最后,提出了Early Decay Teacher(EDT)在训练过程中逐步减小蒸馏损失的权重,确保学生模型可以找到自己的优化空间。

方法介绍

Channel Distillation 

本文受到SENet的启发,在SENet中,channel-wise attention使模型能够学习每个通道的权重,然后将权重与原始的通道相乘,那些重要通道的特征被增强,不重要的通道特征被减弱,从而使特征的提取更具方向性,网络的预测能力更好。每个通道的权重计算如下

其中 \(w_{c}\) 是 \(c^{th}\) 通道的权重,\(H,W\) 是特征图的空间维度,\(u_{c}(i,j)\) 是激活值。

feature map的每个通道对应一种视觉模式visual pattern,但每个通道视觉模式的重要性是不同的。因为教师模型的性能优于学生模型,作者认为教师模型学习到的视觉模式更加准确,因此希望学生去学习教师模型的视觉模式。具体而言,使用全局平均池化GAP来计算每个通道的重要性,它代表了每个通道的注意力信息attention information。然后将每个通道的注意力信息作为知识,传递给学生模型。通常教师模型和学生模型的层数是不同的,作者只在特征图分辨率降低的层进行通道蒸馏,如果通道数不匹配,采用1x1卷积将学生模型的通道数增加到与教师模型相同。CD loss的定义如下

其中 \(CD(s,t)\) 表示教师和学生模型之间的CD损失,\(w_{ij}\) 表示 \(i^{th}\) 样本 \(j^{th}\) 通道的权重,\(c\) 表示通道数。 

Guided Knowledge Distillation

本文提出的GKD是在KD的基础上设计的,KD的思想是计算教师模型和学生模型的预测分布,通过逐渐减小它们之间的差异,使得学生的输出分布逐渐接近教师。KD的计算公式如下

其中 \(p\) 是由logit \(a\) 计算的概率分布,\(T\) 是温度。\(n\) 是batch size。\(KD(s,t)\) 是 \(p^{i}_{s}\) 和 \(p^{i}_{t}\) 之间的KL散度的平均值。 

尽管教师模型的预测结果更加准确,但仍然有一些错误的预测。当教师预测错误时,将错误知识传递给学生,这会降低学生的表现。因此作者在KD的基础上进行改进得到了GKD,具体而言只对教师模型预测正确的样本的KD损失进行反向传播而忽略预测错误的样本的KD损失,GDK的定义如下

其中 \(I\) 是一个indicator function,当教师模型的输出等于label时 \(I(p^{i}_{t},y)\) 等于1,否则为0。例如,假设一个batch种有 \(n\) 个样本,教师模型只预测对了其中 \(n_{1}\) 个,GKD就只计算这 \(n_{1}\) 个样本的KD损失。

Early Decay Teacher

On the Efficacy of Knowledge Distillation这篇文章中提出蒸馏的影响并不总是积极的,在训练的早期,蒸馏可以帮助学生的训练,但在训练的后期会抑制学生的学习,因此在合适的时间停止教师模型的监督有助于学生模型的学习。实验结果表明,在某个epoch,交叉熵损失反而会开始上升,最好在这个节点停止教师模型的监督,但在实际应用中直接停止教师模型的监督比较困难,因此作者提出了一种相对较缓和的做法,随着学习率的降低逐步降低蒸馏损失的权重,定义如下

其中 \(\alpha\) 是蒸馏损失的初始权重,\(\lambda\) 是一个常量系数,\(n_{e}\) 表示完整训练过程中的 \(n^{th}\) epoch,\(n\) 是一个经验值,表示减小损失权重的epoch数量。

我们只减小CD损失的权重,对于GKD损失,因为它只传递正确的知识,因此整个训练过程中都不减小它的权重。

完整的损失函数如下

 

实验结果

表1、2分别是在ImageNet和CIFAR 100数据集上与原始KD的结果对比,可以看出,本文提出的三点创新CD、GKD、EDT都对学生模型的精度提高有帮助,当将三者结合起来时精度最高。

 

表3是与其它蒸馏方法的精度对比,可以看出CD+GKD+EDT取得了最优的性能。

 

代码解析

CD损失代码如下

import torch
import torch.nn as nn


class CDLoss(nn.Module):
    """Channel Distillation Loss"""

    def __init__(self):
        super().__init__()

    def forward(self, stu_features: list, tea_features: list):
        loss = 0.
        for s, t in zip(stu_features, tea_features):
            s = s.mean(dim=(2, 3), keepdim=False)
            t = t.mean(dim=(2, 3), keepdim=False)
            loss += torch.mean(torch.pow(s - t, 2))
        return loss

GKD损失代码如下 

class KDLossv2(nn.Module):
    """Guided Knowledge Distillation Loss"""

    def __init__(self, T):
        super().__init__()
        self.t = T

    def forward(self, stu_pred, tea_pred, label):
        s = F.log_softmax(stu_pred / self.t, dim=1)
        t = F.softmax(tea_pred / self.t, dim=1)
        t_argmax = torch.argmax(t, dim=1)
        mask = torch.eq(label, t_argmax).float()
        count = (mask[mask == 1]).size(0)
        mask = mask.unsqueeze(-1)
        correct_s = s.mul(mask)
        correct_t = t.mul(mask)
        correct_t[correct_t == 0.0] = 1.0

        loss = F.kl_div(correct_s, correct_t, reduction='sum') * (self.t**2) / count
        return loss

EDT代码如下

def adjust_loss_alpha(alpha, epoch, factor=0.9, loss_type="ce_family", loss_rate_decay="lrdv1", dataset_type="imagenet"):
    """Early Decay Teacher"""

    if dataset_type == "imagenet":
        if loss_rate_decay == "lrdv1":
            return alpha * (factor ** (epoch // 30))
        else:  # lrdv2
            if "ce" in loss_type or "kd" in loss_type:
                return 0 if epoch <= 30 else alpha * (factor ** (epoch // 30))
            else:
                return alpha * (factor ** (epoch // 30))
    else:  # cifar
        if loss_rate_decay == "lrdv1":
            return alpha
        else:  # lrdv2
            if epoch >= 160:
                exponent = 2
            elif epoch >= 60:
                exponent = 1
            else:
                exponent = 0
            if "ce" in loss_type or "kd" in loss_type:
                return 0 if epoch <= 60 else alpha * (factor**exponent)
            else:
                return alpha * (factor**exponent)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/475740.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java基础语法总复习思维导图 + 重难点+面试题

前言 小亭子正在努力的学习编程&#xff0c;接下来将开启javaEE的学习~~ 分享的文章都是学习的笔记和感悟&#xff0c;如有不妥之处希望大佬们批评指正~~ 同时如果本文对你有帮助的话&#xff0c;烦请点赞关注支持一波, 感激不尽~~ 【需要可修改的思维导图可以私信我&#xff0…

Packet Tracer - 静态路由故障排除

Packet Tracer - 静态路由故障排除 地址分配表 设备 接口 IPv4 地址 子网掩码 默认网关 R1 G0/0 172.31.1.1 255.255.255.128 不适用 S0/0/0 172.31.1.194 255.255.255.252 不适用 R2 G0/0 172.31.0.1 255.255.255.0 不适用 S0/0/0 172.31.1.193 255.255…

Windows server 2012 R2系统怎么安装IIS管理器?

Windows server 2012 R2系统怎么安装IIS管理器&#xff1f;今天飞飞和你分享。服务器大本营&#xff0c;技术文章内容集合站发车啦&#xff01; 首先我们用电脑自带的远程连接桌面工具进入服务器&#xff0c;在任务栏左下角有个服务器管理器&#xff0c;单击打开 打开后在右上…

【致敬未来的攻城狮计划】— 连续打卡第十六天:FSP固件库系统定时器(滴答定时器SysTick)每2秒LED闪烁一次

系列文章目录 1.连续打卡第一天&#xff1a;提前对CPK_RA2E1是瑞萨RA系列开发板的初体验&#xff0c;了解一下 2.开发环境的选择和调试&#xff08;从零开始&#xff0c;加油&#xff09; 3.欲速则不达&#xff0c;今天是对RA2E1 基础知识的补充学习。 4.e2 studio 使用教程 5.…

跟着杰哥学强化学习:多臂老虎机问题

多臂老虎机问题 现在有3台外观一模一样的老虎机,每个老虎机的赔率是不同的,摇动一次需要1块钱,现在给你100块钱,如何获取最大的收益。 如果我们知道了每个老虎的赔率,那么只要选择收益最高的那个老虎机就可以了,但现在问题是并不知道每个老虎机的收益。为了简单,我们假…

linux 安装rocketmq

首先准备虚拟机一台 下载linux 64位 jdk1.8(自行百度资源) 下载 | RocketMQ (apache.org) cd /usr/local/ #这是我的本机的现有目录 mkdir rocketmq mkdir jdk 1.借助linux客户端工具,上传刚下载好的jdk安装包到java文件夹 2.借助linux客户端工具,上传刚下载好的二进制安装包…

「C/C++」C/C++软件跨平台思维

博客主页&#xff1a;何曾参静谧的博客 文章专栏&#xff1a;「C/C」C/C学习 目录 相关术语一、编写可移植的代码&#xff1a;二、使用跨平台的C库和框架&#xff1a;三、进行兼容性测试&#xff1a;四、用户界面设计&#xff1a; 相关术语 跨平台思维&#xff1a;是指在软件开…

D-Link DSL-2888A 远程命令执行漏洞(CVE-2020-24581/24579)

漏洞描述 D-Link DSL-2888A AU_2.31_V1.1.47ae55之前版本存在安全漏洞&#xff0c;该漏洞源于包含一个execute cmd.cgi特性(不能通过web用户界面访问)&#xff0c;该特性允许经过身份验证的用户执行操作系统命令。 在该版本固件中同时存在着一个不安全认证漏洞&#xff08;CVE…

【软考网络管理员】2023年软考网管初级常见知识考点(2)- 数据通信技术

【写在前面】也是趁着五一假期前再写几篇分享类的文章给大家&#xff0c;希望看到我文章能给软考网络管理员备考的您带来一些帮助&#xff0c;5月27号也是全国计算机软件考试统一时间&#xff0c;也就不用去各个地方找资料和代码了。紧接着我就把我整理的一些资料分享给大家哈&…

04 KVM虚拟化网络概述

文章目录 04 KVM虚拟化网络概述4.1 Linux Bridge4.2 Open vSwitch 04 KVM虚拟化网络概述 为了使虚拟机可以与外部进行网络通信&#xff0c;需要为虚拟机配置网络环境。KVM虚拟化支持Linux Bridge、Open vSwitch网桥等多种类型的网桥。如图1所示&#xff0c;数据传输路径为“虚…

InstructGPT 论文阅读笔记

目录 简介 数据集 详细实现 实验结果 参考资料 简介 InstructGPT 模型是在论文《Training language models to follow instructions with human feedback》被提出的&#xff0c;OpenAI在2022年1月发布了这篇文章。 论文摘要翻译&#x…

AttributeError: ‘Document‘ object has no attribute ‘pageCount‘ PyMuPDF库

这可能是由于PyMuPDF库更新导致的&#xff0c;里面的一些函数名发生了变化 1. AttributeError: Document object has no attribute pageCount 将 pageCount改为 page_count 2. AttributeError: Matrix object has no attribute preRotate 将preRotate改为prerotate 3.Attribut…

关于FFMPEG中的filter滤镜的简单介绍

滤镜的作用主要是对原始的音视频数据进行处理以实现各种各样的效果。比如叠加水印&#xff0c;翻转缩放视频等。 下图表示的正常转码流程&#xff0c;滤镜在解码和编码中间&#xff0c;虚线表示可有可无。 使用命令查看ffmpeg支持的滤镜 ffmpeg -filters 查看某个滤镜的详细参…

k210点亮LED灯

开发板上自带的3个led灯接线如图。 点亮led灯主要使用两个模块&#xff0c;如下&#xff1a; fm.register(pin,function,forceFalse) 【pin】芯片外部 IO 【function】芯片功能 【force】True 则强制注册&#xff0c;清除之前的注册记录 例&#xff1a;fm.register(12, fm.f…

真题详解(有向图)-软件设计(六十二)

真题详解&#xff08;极限编程&#xff09;-软件设计&#xff08;六十一)https://blog.csdn.net/ke1ying/article/details/130435971 CMM指软件成熟度模型&#xff0c;一般1级成熟度最低&#xff0c;5级成熟度最高&#xff0c;采用更高级的CMM模型可以提高软件质量。 初始&am…

RepVGG学习笔记

RepVGG 0 前言1 结构重参数化1.1 结构重参数化第一步&#xff08;将 C o n v 2 D Conv2D Conv2D算子和 B N BN BN算子融合以及将只有 B N BN BN的分支转换成一个 C o n v 2 D Conv2D Conv2D算子&#xff09;1.2 结构重参数化第二步&#xff08;多分支的 3 3 3\times3 33卷积融…

安全运营 ldap监控域控信息

0x00 背景 公司有多个主域&#xff0c;子域&#xff0c;有的子域因为境外数据安全的问题无法把日志传输到境内。那么如何在没有日志的情况下监控子域或者互信域的组织单元(OU)信息呢。 由于访问互信域要在域控上进行&#xff0c;本文根据最小权限原则监控普通用户也可以访问的…

Packet Tracer - 配置和验证小型网络

Packet Tracer - 配置和验证小型网络 地址分配表 设备 接口 IP 地址 子网掩码 默认网关 RTA G0/0 10.10.10.1 255.255.255.0 不适用 G0/1 10.10.20.1 255.255.255.0 不适用 SW1 VLAN1 10.10.10.2 255.255.255.0 10.10.10.1 SW2 VLAN1 10.10.20.2 255.25…

【C++】set和map的使用

对于STL容器来说&#xff0c;有很多相似的功能&#xff0c;所以这里主要将与之前不同的功能说清楚 文章目录 1.对于set与set的简单理解2. setinsert迭代器遍历countmultisetinsertfindcount 3. mapinsert与迭代器的使用统计水果次数 operator []operator[]的实现理解对整体的拆…

Nginx:常见的面试题和答案

1. 什么是Nginx&#xff1f; 答&#xff1a;Nginx是一款高性能的Web服务器和反向代理服务器,用于HTTP、HTTPS、SMTP、POP3和IMAP协议&#xff0c;同时用于处理高并发的请求&#xff0c;提供快速、可靠的服务。 2. Nginx的优点是什么&#xff1f; Nginx的优点包括&#xff1a…