Similarity-Preserving KD(ICCV 2019)原理与代码解析

news2024/9/24 1:15:32

paper:Similarity-Preserving Knowledge Distillation

code:https://github.com/megvii-research/mdistiller/blob/master/mdistiller/distillers/SP.py

背景

本文的灵感来源于作者观察到在一个训练好的网络中,语义上相似的输入倾向于引起相似的激活模式。下图是CIFAR-10测试集在教师网络WideResNet-16-2的最后一个卷积层的每个通道的平均激活的可视化结果。横坐标是测试图片index,按类别进行了分组,例如1-1000张是类别1,1000-2000张是类别2。纵坐标是采样后的通道激活均值。从图中可以看出,来自同一类别的图像倾向于激活相似的通道。在教师网络中,不同图像之间的激活相似性包含了网络学习到的有用信息,因此作者本文研究了这些相似性是否可以为知识蒸馏提供监督信息。

本文的创新点

基于上述观察,作者假设如果两个输入在教师网络中产生了高度相似的激活,那么引导学生网络对于同样两个输入也产生相似的激活是有益的。相反如果两个输入在教师网络中产生了不同的激活,那么我们希望这些输入在学生网络中也产生不同的激活。因此,本文引入了保持相似性(similarity-preserving)的知识蒸馏,这是一种新的知识蒸馏形式,它使用教师网络中每个mini-batch里两两激活的相似性来引导学生网络的训练。

方法介绍

给定一个mini-batch的输入,教师网络 \(T\) 的某一层 \(l\) 的激活图activation map表示为 \(A^{(l)}_{T}\in \mathbf{R}^{b\times c\times h\times w}\),学生网络 \(S\) 对应层 \(l'\) 的输出表示为 \(A^{(l')}_{S}\in \mathbf{R}^{b'\times c'\times h'\times w'}\),这里教师网络和学生网络对应输出的通道、宽高都不一定要相等。为了引导学生网络学习教师网络学习到的激活相关性,我们定义了一个蒸馏损失,它惩罚 \(A^{(l)}_{T}\) 和 \(A^{(l')}_{S}\) L2标准化的外积(outer products)之间的差异

其中 \(Q^{(l)}_{T}\in \mathbf{R}^{b\times chw}\) 是 \(A^{(l)}_{T}\) reshape的结果,因此 \(\tilde{G} ^{(l)}_{T}\) 是一个 \(b\times b\) 的矩阵。\(\tilde{G} ^{(l)}_{T}\) 中的 \((i,j)\) 项编码了mini-batch中第 \(i\) 张图片和第 \(j\) 张图片在教师网络中的激活相似度。然后沿行进行L2-normalization得到 \(G ^{(l)}_{T}\),\([i,:]\) 表示矩阵中的第 \(i\) 行。同样定义学生网络的激活相似度矩阵

然后定义similarity-preserving的知识蒸馏的损失如下

其中 \(\mathcal{I}\) 表示教师网络和学生网络所有对应的层 \((l,l')\),\(\left \| \cdot \right \| _{F}\) 表示Frobenius范数。最后学生网络的完整损失函数如下

其中 \(\gamma\) 是权重超参。

下图是CIFAR-10测试集中几个batch的G矩阵的可视化结果,每一列表示一个相同的batch,每个batch中的图片都按类别进行了进行了分组,激活值取自网络的最后一个卷积层,颜色越亮表明相似度越高,图中方块状的亮的区域表明了网络最后一层的激活在同一类别中基本是相似的,而在不同的类别中是不同的。其中同一张图中方块大小不同是因为一个batch中各类别的图片数量不同。另外可以看出WideResNet-40-2中方块状的区域更明显亮度值更大表明了该网络提取数据集语义信息的能力更强。

实验结果

下图是三种不同的蒸馏方法在不同的教师和学生网络中的效果对比,可以看出本文提出的similarity-preserving在五种中的四种都取得了最优的效果。

代码解析

import torch
import torch.nn as nn
import torch.nn.functional as F

from ._base import Distiller


def sp_loss(g_s, g_t):
    return sum([similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)])


def similarity_loss(f_s, f_t):
    bsz = f_s.shape[0]  # 64
    f_s = f_s.view(bsz, -1)  # (64,16384)
    f_t = f_t.view(bsz, -1)  # (64,16384)

    G_s = torch.mm(f_s, torch.t(f_s))  # (64,64)
    G_s = torch.nn.functional.normalize(G_s)
    G_t = torch.mm(f_t, torch.t(f_t))  # (64,64)
    G_t = torch.nn.functional.normalize(G_t)

    G_diff = G_t - G_s
    loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)  # (64,64)*(64,64)->(4096,1)->(1)
    return loss


class SP(Distiller):
    """Similarity-Preserving Knowledge Distillation, ICCV2019"""

    def __init__(self, student, teacher, cfg):
        super(SP, self).__init__(student, teacher)
        self.ce_loss_weight = cfg.SP.LOSS.CE_WEIGHT
        self.feat_loss_weight = cfg.SP.LOSS.FEAT_WEIGHT

    def forward_train(self, image, target, **kwargs):
        logits_student, feature_student = self.student(image)
        with torch.no_grad():
            _, feature_teacher = self.teacher(image)

        # losses
        loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
        loss_feat = self.feat_loss_weight * sp_loss(
            [feature_student["feats"][-1]], [feature_teacher["feats"][-1]]  # (64,256,8,8),(64,256,8,8)
        )
        losses_dict = {
            "loss_ce": loss_ce,
            "loss_kd": loss_feat,
        }
        return logits_student, losses_dict

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/356846.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[python入门㊾] - python异常中的断言

目录 ❤ 断言的功能与语法 ❤ 常用断言 ❤ 常用的断言表达方式 ❤ 异常断言 ❤ 正则断言 ❤ 检查断言装饰器 ❤ 断言的功能与语法 Python assert(断言)用于判断一个表达式,在表达式条件为 False 的时候触发异常 断言可以在条件…

2023Python接口自动化测试实战教程,附视频实战讲解

这两天一直在找直接用python做接口自动化的方法,在网上也搜了一些博客参考,今天自己动手试了一下。 一、整体结构 上图是项目的目录结构,下面主要介绍下每个目录的作用。 Common:公共方法:主要放置公共的操作的类,比如数据库sql…

js中数字运算结果与预期不一致的问题和解决方案

本文主要是和大家聊聊关于js中经常出现数字运算结果与预期结果不一致的问题,与及解决该问题的的方案。 一、问题现象 如:0.1 0.2的预期结果是0.3,但是在js中得到的计算结果却是0.30000000000000004,如下图所示 如:0…

GEE学习笔记 六十九:【GEE之Python版教程三】Python基础编程一

环境配置完成后,那么可以开始正式讲解编程知识。之前我在文章中也讲过,GEE的python版接口它是依赖python语言的。目前很多小伙伴是刚开始学习GEE编程,之前或者没有编程基础,或者是没有学习过python。为了照顾这批小伙伴&#xff0…

15种NLP数据增强方法总结与对比

数据增强的方法 数据增强(Data Augmentation,简称DA),是指根据现有数据,合成新数据的一类方法。毕竟数据才是真正的效果天花板,有了更多数据后可以提升效果、增强模型泛化能力、提高鲁棒性等。然而由于NLP…

光伏VSG-基于虚拟同步发电机的光伏并网逆变器系统MATLAB仿真

采用MATLAB2021b仿真!!!仿真模型1光伏电池模块(采用MATLAB自带光伏模块)、MPPT控制模块、升压模块、VSG控制模块、电流滞环控制模块。2s时改变光照强度 !!!VSG输出有功功率、无功功率…

6.3 使用 Swagger 生成 Web API 文档

第6章 构建 RESTful 服务 6.1 RESTful 简介 6.2 构建 RESTful 应用接口 6.3 使用 Swagger 生成 Web API 文档 6.4 实战:实现 Web API 版本控制 6.3 使用 Swagger 生成 Web API 文档 高质量的 API 文档在系统开发的过程中非常重要。本节介绍什么是 Swagger&#xff…

15-基础加强-2-xml(约束)枚举注解

文章目录1.xml1.1概述【理解】(不用看)1.2标签的规则【应用】1.3语法规则【应用】1.4xml解析【应用】1.5DTD约束【理解】1.5.1 引入DTD约束的三种方法1.5.2 DTD语法(会阅读,然后根据约束来写)1.6 schema约束【理解】1.6.1 编写schema约束1.6.…

基于高频方波电压信号注入的永磁同步电机无传感器控制仿真及其原理介绍

基于方波信号注入的永磁同步电机无传感器控制仿真及其原理介绍 注入的高频方波信号为: 可以得到估计轴的高频响应电流为: 当向定子绕组注入高频电压信号时,所注入的高频信号频率远高于基波信号频率。因此,IPMSM 在a-β轴的电压模型可以表示为: 假定…

二叉树OJ(一)二叉树的最大深度 二叉搜索树与双向链表 对称的二叉树

二叉树的最大深度 二叉树中和为某一值的路径(一) 二叉搜索树与双向链表 对称的二叉树 二叉树的最大深度 描述 求给定二叉树的最大深度, 深度是指树的根节点到任一叶子节点路径上节点的数量。 最大深度是所有叶子节点的深度的最大值。 (注:…

Xcode Archives打包上传 / 导出ipa 发布至TestFlight

Xcode自带的Archives工具可以傻瓜式上传到App Store Connect分发这里以分发到TestFlight为例进行操作。 环境:Xcode 14 一:Archives打包 选择Xcode菜单栏的Product,Archives选项,需要等待编译完成,进入如下界面&…

【C语言】初识结构体

☃️内容专栏:【C语言】初阶部分 ☃️本文概括:继初识C语言,对C语言结构体初阶部分进行归纳与总结。 ☃️本文作者:花香碟自来_ ☃️发布时间:2023.2.19 一、结构体的声明 结构体(类型)是一些…

字符设备驱动基础(二)

目录 一、五种IO模型------读写外设数据的方式 二、阻塞与非阻塞 三、多路复用 3.1 应用层:三套接口select、poll、epoll 3.2 驱动层:实现poll函数 四、信号驱动 4.1 应用层:信号注册fcntl 4.2 驱动层:实现fasync函数 一、…

CSAPP学习笔记——虚拟内存(二)

案例研究 Intel Core i7 该处理底层的Haswell微体系结构允许64位的虚拟和物理地址空间,而现在的Core i7实现支持48位(256TB)虚拟地址空间和52位(4PB)物理地址空间,这对目前来说已经完全够用了。&#xff…

Liunx(狂神课堂笔记)

一.常用命令 1. cd 切换目录 cd ./* 当前目录cd /* 绝对路径cd .. 返回上一级目录cd ~ 回到当前目录pwd …

定点数的表示和运算

文章目录真值(有正负号)和机器数(0正1负)原码整数小数补码负数的补数正数的补数[y]~补~ > [-y]~补~反码小结移码移位运算加减法运算溢出判断真值(有正负号)和机器数(0正1负) 无符…

链表OJ(六)链表相加(一) 链表相加(二)

目录 链表相加(一) 链表相加(二) 描述 二与一相比多了俩次反转而已 链表相加(一) 描述 给定两个非空链表逆序存储的的非负整数,每个节点只存储一位数组。 请你把两个链表相加以下相同方法返回链表,保证两个数都不会以 0 开头。 【我的解法】长到…

实例五:MATLAB APP design-APP登录界面的设计

一、APP 界面设计展示 注:在账号和密码提示框输入相应的账号和密码后,点击登录按钮,即可跳转到程序中设计的工作界面。 二、APP设计界面运行结果展示

使用继承的虚函数表

​ 代码 #include <iostream> using namespace std;class Father { public:virtual void func1() { cout << "Father::func1" << endl; }virtual void func2() { cout << "Father::func2" << endl; }virtual void func3()…

一文彻底理解大小端和位域 BIGENDIAN LITTLEENDIAN

一文彻底理解大小端和位域 为什么有大小端 人们一直认为大道至简&#xff0c;就好像物理学上的世界追求使用一个理论来统一所有的现象。为什么cpu存在大小端之分&#xff0c;一言以蔽之&#xff0c;这两种模式各有各的优点&#xff0c;其各自的优点就是对方的缺点&#xff0c…