Relational KD(CVPR 2019)原理与代码解析

news2025/1/2 4:14:45

paper:Relational Knowledge Distillation

code:https://github.com/megvii-research/mdistiller/blob/master/mdistiller/distillers/RKD.py

背景

本文从语言结构主义的角度来重新审视知识蒸馏,前者主要关注一个符号学系统中的结构关系。索续尔关于符号关系身份的概念是结构主义理论的核心:“在一种语言中,和在其他符号系统中一样,区分一个符号的依据是构成它的要素(what distinguishes a sign is what constitutes it)。从这个角度来看,一个符号的意义取决于它与系统内其他符号的关系,一个符号没有独立于上下文的绝对意义。

本文的创新点

本文的中心思想是:相比于学习到的单个特征表示,用学习到的特征表示之间的关系(relation)作为知识是更好的选择。单个数据比如一张图像,在一个表示系统中获得了与其它数据相关联的特征表示,因此主要的信息包含在数据嵌入空间中的一个结构中的。基于此,本文引入了一种新的知识蒸馏方法,称为关系知识蒸馏(Relational Knowledge Distillation, RKD),他传递的是输出之间的结构关系而不是单个输出本身,如下图所示

具体来说,提出了两种RKD损失:基于距离的二阶(distance-wise, second-order)和基于角度的三阶(angle-wise, third-order)蒸馏损失,RKD可以看作是传统KD的一种泛化,由于其和传统KD的互补性,也可以与其它方法结合使用来提高模型性能。

方法介绍

对于一个教师模型 \(T\) 和一个学生模型 \(S\),\(f_{T}\) 和 \(f_{S}\) 分别表示教师和学生模型的函数。通常这些模型是深度神经网络,原则上函数 \(f\) 可以用网络任何层的输出来定义(比如隐藏层或softmax层)。\(\mathcal{X}^{N}\) 表示一组数据中的 \(N\) 个不同样本的元组,例如 \(\mathcal{X}^2=\left \{ \left ( x_{i},x_{j}|i\ne j \right ) \right \} \),\(\mathcal{X}^3=\left \{ \left ( x_{i},x_{j},x_{k}|i\ne j\ne k \right ) \right \} \)。RKD旨在使用教师模型输出表示中数据实例之间的相互关系来传递结构性知识,和传统方法不同,它计算所有数据中每 \(n\) 个样本子集的关系势(relation potential)\(\psi\),并且通过势将信息从教师模型传递给学生模型。

定义 \(t_{i}=f_{T}(x_{i})\),\(s_{i}=f_{S}(x_{i})\),RKD的目标损失如下

其中 \((x_{1},x_{2},...,x_{n})\) 是 \(\mathcal{X}^{N}\) 中的一个 \(n\) 元素的子集,\(\psi\) 是一个关系势函数用来计算给定 \(n\) 个样本子集的关系势能(relational energy),\(l\) 是用来惩罚教师和学生模型之间差异的损失函数。当关系是一元时,式(4)的RKD就变成了IKD(即传统的Individual KD),下图是RKD和IKD的对比

关系势函数 \(\psi\) 在RKD中起着至关重要的作用,RKD的有效性和效率依赖于势函数的选择。例如高阶(high-order)势在捕获高阶结构方面更有效但计算复杂度也更高。本文作者提出了两种简单有效的势函数和对应RKD损失:distance-wise和angle-wise损失,它们分别样本之间的二元关系(pairwise)和三元关系(ternary)。

Distance-wise distillation loss

给定两个训练样本,distance-wise势函数 \(\psi_{D}\) 计算输出表示空间中两者之间的欧式距离

其中 \(\mu\) 是距离的标准化因子,为了关注其它样本对之间的相对距离,\(\mu\) 设置为mini-batch中来自 \(\mathcal{X}^{2}\) 的所有样本对的平均距离

利用在教师模型和学生模型中分别计算的distance-wise potentials,distance-wise蒸馏损失定义如下

其中 \(l_{\delta}\) 是Hube loss,定义如下

距离蒸馏损失通过惩罚样本输出表示空间之间的距离差来传递样本之间的关系,与传统的KD不同,它不是让学生模型去匹配教师模型的输出,而是让学生模型关注输出的距离结构。

Angle-wise distillation loss

给定三个样本,基于角度的势函数计算输出表示空间中三个样本之间的角度

利用分别在教师模型和学生模型中计算的angle-wise potentials,基于角度的蒸馏损失定义如下

其中 \(l_{\delta}\) 是Hube loss,基于角度的蒸馏损失通过惩罚角度差异来传递训练样本之间的关系。由于角度是比距离更高阶(high-order)的属性,它可能更有效地传递关系信息,在训练中给学生模型更大的灵活性。作者在实验中也发现基于角度的损失收敛更快性能更好。

Training with RKD

训练过程中多个蒸馏损失包括本文提出的RKD可以单独使用,也可以和任务特定的损失结合使用,例如分类任务中的交叉熵损失。因此完整的损失形式如下

其中 \(\mathcal{L}_{task}\) 是特定任务相关的损失,\(\mathcal{L}_{KD}\) 是蒸馏损失,\(\lambda_{KD}\) 是权重超参。

代码解析

其中输入特征feature_student["pooled_feat"]是网络最后全连接层的输入,也就是最后一层卷积层的输出进行全局平局池化再reshape成 (batch_size, -1)的结果。函数_pdist计算样本输出特征之间的欧式距离,res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)是将 \((a-b)^2\) 展开成 \(a^{2}+b^{2}-2ab\)。40-41行计算式(9)中的 \(\mathbf{e}^{ij}\) 和 \(\mathbf{e}^{kj}\),\(\left \langle \cdot \right \rangle \) 是点积操作,42行是计算 \(\left \langle \mathbf{e}^{ij},\mathbf{e}^{kj} \right \rangle \)。

import torch
import torch.nn as nn
import torch.nn.functional as F

from ._base import Distiller


def _pdist(e, squared, eps):
    e_square = e.pow(2).sum(dim=1)  # (64,256)->(64)
    prod = e @ e.t()  # (64,64)
    res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
    # (64,1)+(1,64)->(64,64), -(64,64)->(64,64)

    if not squared:
        res = res.sqrt()

    res = res.clone()
    res[range(len(e)), range(len(e))] = 0
    return res


def rkd_loss(f_s, f_t, squared=False, eps=1e-12, distance_weight=25, angle_weight=50):
    stu = f_s.view(f_s.shape[0], -1)  # (64,256)->(64,256)
    tea = f_t.view(f_t.shape[0], -1)  # (64,256)->(64,256)

    # RKD distance loss
    with torch.no_grad():
        t_d = _pdist(tea, squared, eps)
        mean_td = t_d[t_d > 0].mean()
        t_d = t_d / mean_td

    d = _pdist(stu, squared, eps)
    mean_d = d[d > 0].mean()
    d = d / mean_d

    loss_d = F.smooth_l1_loss(d, t_d)

    # RKD Angle loss
    with torch.no_grad():
        td = tea.unsqueeze(0) - tea.unsqueeze(1)  # (1,64,256)-(64,1,256)->(64,64,256)
        norm_td = F.normalize(td, p=2, dim=2)  # (64,64,256)
        t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)  # (64,64,256),(64,256,64)->(64,64,64)->(262144)

    sd = stu.unsqueeze(0) - stu.unsqueeze(1)
    norm_sd = F.normalize(sd, p=2, dim=2)
    s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)

    loss_a = F.smooth_l1_loss(s_angle, t_angle)

    loss = distance_weight * loss_d + angle_weight * loss_a
    return loss


class RKD(Distiller):
    """Relational Knowledge Disitllation, CVPR2019"""

    def __init__(self, student, teacher, cfg):
        super(RKD, self).__init__(student, teacher)
        self.distance_weight = cfg.RKD.DISTANCE_WEIGHT
        self.angle_weight = cfg.RKD.ANGLE_WEIGHT
        self.ce_loss_weight = cfg.RKD.LOSS.CE_WEIGHT
        self.feat_loss_weight = cfg.RKD.LOSS.FEAT_WEIGHT
        self.eps = cfg.RKD.PDIST.EPSILON
        self.squared = cfg.RKD.PDIST.SQUARED

    def forward_train(self, image, target, **kwargs):
        logits_student, feature_student = self.student(image)  # (64,3,32,32)
        with torch.no_grad():
            _, feature_teacher = self.teacher(image)

        # losses
        loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
        loss_rkd = self.feat_loss_weight * rkd_loss(
            feature_student["pooled_feat"],  # (64,256)
            feature_teacher["pooled_feat"],  # (64,256)
            self.squared,  # False
            self.eps,  # 1e-12
            self.distance_weight,  # 25
            self.angle_weight,  # 50
        )
        losses_dict = {
            "loss_ce": loss_ce,
            "loss_kd": loss_rkd,
        }
        return logits_student, losses_dict

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/359816.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux系列 linux 常用命令(笔记)

作者简介:一名云计算网络运维人员、每天分享网络与运维的技术与干货。 座右铭:低头赶路,敬事如仪 个人主页:网络豆的主页​​​​​​ 目录 前言 一.linux 常用命令(目录文和件基本操作) 1.命令的分类…

一文读懂功率放大器(功率放大器的特性是什么意思)

功率放大器是一种电子放大器,旨在增加给定输入信号的功率幅度。功率放大器一般要求得到一定的不失真或者较小失真的输出功率,在大信号状态下进行工作,主要是输出较大功率。功率放大器的特性介绍:1、增益功率放大器的增益主要是指放…

postman利用newman生成测试报告

1.安装nodejs nodejs下载地址:https://nodejs.org/zh-cn/配置环境变量:在path环境变量中增加nodejs的安装路径 安装完成后,在控制台执行node -v检查是否安装成功2.安装newman 以管理员身份打开cmd控制台,执行如下命令安装newma…

测试团队都在用哪些不错的测试用例管理平台?盘点6大主流测试管理系统

测试团队使用的主流测试用例管理平台:1.PingCode ;2.TestRail;3.Testlink;4.ZephyrJira;5.TestCenter;6.飞蛾。目前市面上的测试用例管理工具有很多,但由于针对的项目、领域、目标用户&#xff…

速看!!!一套能直接拿捏大厂面试官的软件测试面试宝典

3.5.1、说说你们是怎么做自动化测试的☆☆☆☆☆我们的自动化测试主要是web UI的自动化测试,主要用于冒烟测试和主要功能的回归测试或者主流浏览器的兼容性测试,作为手工测试的一种补充,提高测试效率,减少一些重复性的测试工作。1…

kubectl

目录 一、陈述式资源管理方法 二、基本信息查看 2.1 基本信息查看格式 2.2 查看master节点组件状态 2.3 查看命名空间 2.4 创建/查看命名空间 2.5 删除(重启)命名空间/pod 2.6 查看资源的详细信息 2.7 创建副本控制器来启动Pod 2.8 查看指定命…

Linux-0.11 文件系统buffer.c详解

Linux-0.11 文件系统buffer.c详解 buffer_init void buffer_init(long buffer_end)该函数的作用主要是初始化磁盘的高速缓冲区。 刚开始使用h指针指向了start_buffer的位置。 struct buffer_head * h start_buffer; void * b; int i;start_buffer定义为end的位置&#xff…

数据结构与算法基础(王卓)(11):栈的定义及其基础操作(顺序表和链表的初始化、求长度,是否为空,清空和销毁、出栈、压栈)

栈的定义: stack:一堆,一摞;堆;垛; 顺序栈和链栈的设计参考: 数据结构与算法基础(王卓)(7):小结:关于链表和线性表的定义及操作_宇 -Yu的博客-C…

【免费教程】 SWMM在城市水环境治理中的应用及案例分析

SWMMSWMM(storm water management model,暴雨洪水管理模型)是一个动态的降水-径流模拟模型,主要用于模拟城市某一单一降水事件或长期的水量和水质模拟。EPA(Environmental Protection Agency,环境保护署&am…

QTCreator 设置编码格式

显示文件编码格式 选择“工具>首选项>文本编辑器>显示>显示文件编码” 全局设置 选择“工具>首选项>文本编辑器>行为>文件编码” 将文件编码设置为utf-8,UTF-8 BOM 选择存在则保留,最后选择apply。 打开项目设置 选择“项目&…

解析HTTP/2如何提升网络速度

我们知道HTTP/1.1 为网络效率做了大量的优化,最核心的有如下三种方式: 增加了持久连接;浏览器为每个域名最多同时维护 6 个 TCP 持久连接;使用 CDN 的实现域名分片机制。 虽然 HTTP/1.1 采取了很多优化资源加载速度的策略&#x…

[学习笔记]SQL server完全备份指南

方式一,使用SQL Server Management Studio 准备工作 连接目标数据库服务器 在目标数据库上右键->属性,将数据库的恢复模式设置为“简单”,兼容级别设置为“SQL Server 2016(130)” [可选]将表中将无用的业务数据删除,以减…

Java EE|TCP/IP协议栈之传输层UDP协议详解

文章目录一、对UDP协议的感性认识简介主要特点二、UDP的报文结构协议端格式概览报文结构详解源端口目的端口16位UDP报文长度16位校验和参考一、对UDP协议的感性认识 简介 UDP,是User Datagram Protocol的简称,中文名是用户数据报协议,是OSI…

Leetcode力扣秋招刷题路-0081

从0开始的秋招刷题路,记录下所刷每道题的题解,帮助自己回顾总结 81. 搜索旋转排序数组 II 已知存在一个按非降序排列的整数数组 nums ,数组中的值不必互不相同。 在传递给函数之前,nums 在预先未知的某个下标 k(0 &…

公安局靶场建设规划设计

随着我国国家安全形势的变化,公安工作也面临着越来越严峻的挑战。为了提高公安干警的专业技能和反恐能力,建设一座现代化的靶场已成为公安局的迫切需求。本文将介绍公安局靶场建设的重要性,靶场的规划与设计以及建设过程中需要注意的事项。 一…

Pyspark基础入门4_RDD转换算子

Pyspark 注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等,Spark Flink Kafka Hbase Hi…

Flex写法系列-Flex布局之基本语法

以前的传统布局,依赖盒装模型。即 display position float 属性。但是对于比较特殊的布局就不太容易实现,例如:垂直居中。下面主要介绍flex的基本语法,后续还有二期介绍Flex的写法。一、什么是Flex布局?Flex布局个人…

Vuex的创建和简单使用

Vuex 1.简介 1.1简介 1.框框里面才是Vuex state:状态数据action:处理异步mutations:处理同步,视图可以同步进行渲染1.2项目创建 1.vue create 名称 2.运行后 3.下载vuex。采用的是基于vue2的版本。 npm install vuex3 --save 4.vu…

Frequency Domain Model Augmentation for Adversarial Attack

原文:[2207.05382] Frequency Domain Model Augmentation for Adversarial Attack (arxiv.org)代码:https://github.com/yuyang-long/SSA.黑盒攻击替代模型与受攻击模型之间的差距通常较大,表现为攻击性能脆弱。基于同时攻击不同模型可以提高…

C++8:模拟实现list

目录 最基础的链表结构以及迭代器实现 链表节点结构 构造函数 push_back list的迭代器 增删查改功能实现 insert erase pop_front pop_back push_front clear 默认成员函数 析构函数 拷贝构造函数 赋值操作符重载 list的完善 const迭代器 赋值操作符重…