Neuron Selectivity Transfer 原理与代码解析

news2025/1/20 18:31:25

paper:Like What You Like: Knowledge Distill via Neuron Selectivity Transfer

code:https://github.com/megvii-research/mdistiller/blob/master/mdistiller/distillers/NST.py

本文的创新点

本文探索了一种新型的知识 - 神经元的选择性知识,并将其传递给学生模型。这个模型背后的直觉相当简单:每个神经元本质上从原始输入提取与特定任务相关的某种模式,因此,如果一个神经元在某些区域或样本中被激活,这意味着这些区域或样本共享一些与该任务相关的特性。这种聚类知识对学生模型非常有价值,因为它为教室模型的最终预测结果提供了一种解释。因此,作者提出对齐教师模型和学生模型神经元选择模式的分布。

背景

Notions

假定教师模型和学生模型都是卷积神经网络,并将教师模型表示为 \(T\),学生模型表示为 \(S\)。CNN中某一层的输出特征图表示为 \(\mathbf{F}\in \mathbb{R}^{C\times HW}\),\(\mathbf{F}\) 的每一行即每个通道的特征图表示为 \(\mathbf{f}^{k\cdot}\in \mathbb{R}^{HW}\),\(\mathbf{F}\) 的每一列即每个空间位置沿所有通道的激活表示为 \(\mathbf{f}^{\cdot k}\in \mathbb{R}^{C}\)。\(\mathbf{F}_{T}\) 和 \(\mathbf{F}_{S}\) 分别表示教师模型和学生模型中某一层的特征图,不失一般性,我们假设 \(\mathbf{F}_{T}\) 和 \(\mathbf{F}_{S}\) 的大小相等,如果不相等则可以通过插值使它们相等。

Maximum Mean Discrepancy

最大平均差异(Maximum Mean Discrepancy,MMD)可以看作是一种概率分布间的距离度量,基于从它们采样的样本。假设我们有两组分别从分布 \(p\) 和 \(q\) 中采样的样本 \(\mathcal{X}=\left \{ x^{i} \right \}^{N}_{i=1} \) 和 \(\mathcal{Y}=\left \{ y^{j} \right \}^{M}_{j=1} \),那么 \(p\) 和 \(q\) 之间的MMD距离的平方如下

其中 \(\phi \left ( \cdot \right ) \) 是一个显示映射函数,通过进一步扩展并应用核技巧(kernel trick),式(1)可以表示为

其中 \(k(\cdot,\cdot)\) 是一个核函数,它将样本向量投射到一个更高维或是无限维的特征空间中。

最小化MMD等价于最小化 \(p\) 和 \(q\) 之间的距离。

方法介绍

Motivation

下面是两张叠加了热力图(heat map)的图片,其中热力图是根据VGG16 Conv5_3中的某个神经元得到的。从图中很容易看出这两个神经元具有很强的选择性:左图的神经元对猴子的脸部非常敏感,右侧的神经元对字符非常敏感。这种激活实际上意味着神经元的选择性,即什么样的输入可以触发这些神经元。换句话说,一个神经元高激活的区域可能共享一些与任务相关的相似特性,尽管这些特性可能对于人类没有非常直观的解释。

为了捕获这些相似特性,在学生模型中也应该有神经元模仿这些激活模式。因此本文提出了一种新的知识类型:神经元选择性(neuron selectivities)或者叫做共激活(co-activations),并将其传递给学生模型。

Formulation

每个通道的特征图 \(\mathbf{f}^{k\cdot}\) 示一个特定神经元的selectivity知识,我们定义Neuron Selectivity Transfer,NST损失如下

其中 \(\mathcal{H}\) 是交叉熵损失,\(\mathbb{y}_{true}\) 是ground truth标签,\(p_{S}\) 是学生模型的输出概率。

MMD损失可以扩展如下

其中用 \(l_{2}\) 标准化后的 \(\frac{\mathbf{f}^{k\cdot} }{\left \|\mathbf{f}^{k\cdot} \right \|_{2} } \) 替代了 \(\mathbf{f}^{k\cdot}\),这是为了使每个样本具有相同的尺度。最小化MMD损失就等价于将神经元的选择性知识传递给学生模型

Choice of Kernels

本文选用以下三种核函数

对于多项式核,本文设置 \(d=2,c=0\)。对于高斯核,\(\sigma ^{2}\) 设置为对应距离的平方。

代码解析

import torch
import torch.nn as nn
import torch.nn.functional as F

from ._base import Distiller


def nst_loss(g_s, g_t):
    return sum([single_stage_nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)])


def single_stage_nst_loss(f_s, f_t):
    s_H, t_H = f_s.shape[2], f_t.shape[2]
    if s_H > t_H:
        f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
    elif s_H < t_H:
        f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
    f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)  # (64,64,32,32)->(64,64,1024)
    f_s = F.normalize(f_s, dim=2)
    f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
    f_t = F.normalize(f_t, dim=2)

    return (
        poly_kernel(f_t, f_t).mean().detach()
        + poly_kernel(f_s, f_s).mean()
        - 2 * poly_kernel(f_s, f_t).mean()
    )


def poly_kernel(a, b):
    a = a.unsqueeze(1)  # (64,64,1024)->(64,1,64,1024)
    b = b.unsqueeze(2)  # (64,64,1024)->(64,64,1,1024)
    res = (a * b).sum(-1).pow(2)  # (64,64,64,1024)->(64,64,64)
    return res


class NST(Distiller):
    """
    Like What You Like: Knowledge Distill via Neuron Selectivity Transfer
    """

    def __init__(self, student, teacher, cfg):
        super(NST, self).__init__(student, teacher)
        self.ce_loss_weight = cfg.NST.LOSS.CE_WEIGHT
        self.feat_loss_weight = cfg.NST.LOSS.FEAT_WEIGHT

    def forward_train(self, image, target, **kwargs):
        logits_student, feature_student = self.student(image)  # (64,3,32,32)
        with torch.no_grad():
            _, feature_teacher = self.teacher(image)

        # losses
        loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
        loss_feat = self.feat_loss_weight * nst_loss(
            feature_student["feats"][1:], feature_teacher["feats"][1:]
            # [torch.Size([64, 64, 32, 32]), torch.Size([64, 128, 16, 16]), torch.Size([64, 256, 8, 8])]
            # [torch.Size([64, 64, 32, 32]), torch.Size([64, 128, 16, 16]), torch.Size([64, 256, 8, 8])]
        )
        losses_dict = {
            "loss_ce": loss_ce,
            "loss_kd": loss_feat,
        }
        return logits_student, losses_dict

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/364333.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

整型在内存中的存储(详细剖析大小端)——“C”

各位CSDN的uu们你们好呀&#xff0c;今天小雅兰的内容是整型在内存中的存储噢&#xff0c;现在&#xff0c;就让我们进入整型在内存中的存储的世界吧 数据类型详细介绍 整型在内存中的存储&#xff1a;原码、反码、补码 大小端字节序介绍及判断 数据类型介绍 前面我们已经学…

【扫盲】数字货币科普对于完全不了解啥叫比特币的小伙伴需要的聊天谈资

很多人并不清楚&#xff0c;我们时常听说的比特币&#xff0c;以太坊币&#xff0c;等等这些东西到底是一场骗局还是一场货币革命&#xff1f; 下面就围绕这数字货币的历史以及一些应用场景开始分析这个问题。 一、 开端 一切从2008年中本聪&#xff08;Satoshi Nakamoto&…

shiro反序列化漏洞与无依赖CB链分析

环境搭建 git clone https://github.com/apache/shiro cd shiro git checkout shiro-root-1.2.4将 shiro/samples/web/pom.xml 中的jstl依赖改为1.2: <dependency><groupId>javax.servlet</groupId><artifactId>jstl</artifactId><version&g…

【数据结构与算法】3.(单向、无向、带权)图,广度、深度优先搜索,贪心算法

文章目录1.图简介2.图的存储方式2.1.邻接矩阵存储方法2.2.邻接表存储方法3.有向、无向图和查询算法3.1.数据结构3.2.广度优先算法BFS3.3.深度优先算法DFS3.3.1.DFS查询单条路径3.3.2.DFS查询所有路径4.带权图和贪心算法4.1.贪心算法4.2.基于带权无向图使用贪心算法查询最优路径…

混合精度训练,FP16加速训练,降低内存消耗

计算机中的浮点数表示&#xff0c;按照IEEE754可以分为三种&#xff0c;分别是半精度浮点数、单精度浮点数和双精度浮点数。三种格式的浮点数因占用的存储位数不同&#xff0c;能够表示的数据精度也不同。 Signed bit用于控制浮点数的正负&#xff0c;0表示正数&#xff0c;1表…

MAC地址IP地址 端口

网络结构&#xff1a; 服务器-客户机&#xff08;C/S&#xff09;Client-Server结构&#xff0c;如QQ,LOL都拥有客户端 优点&#xff1a;响应速度快&#xff0c;形式多样&#xff0c;安全新较高缺点&#xff1a;安装软件和维护&#xff0c;不能跨平台LINUX/windows/MAC浏览器-…

C语言——柔性数组

目录0. 前言1. 思维导图2. 柔性数组的特点3. 柔性数组的使用4. 柔性数组的优势5. 结语0. 前言 柔性数组是在C99标准时引入&#xff1a; 结构中的最后一个元素允许是未知大小的数组&#xff0c;这就叫柔性数组成员。 代码示例&#xff1a; typedef struct flexible_arr {int a…

leetcode 1011. Capacity To Ship Packages Within D Days(D天内运送包裹的容量)

数组的每个元素代表每个货物的重量&#xff0c;注意这个货物是有先后顺序的&#xff0c;先来的要先运输&#xff0c;所以不能改变这些元素的顺序。 要days天内把这些货物全部运输出去&#xff0c;问所需船的最小载重量。 思路&#xff1a; 数组内数字顺序不能变&#xff0c;就…

【Storm】【一】简介

介绍 1.1 简介 Storm 是 Apache 旗下免费开源的分布式实时计算框架。Storm可以轻松、可靠地处理无限数据流&#xff0c;对实时分析、在线机器学习、连续计算、分布式RPC&#xff0c;ETL等提供高效、可靠的支持。 1.2 什么是分布式计算 分布式计算&#xff0c;将一个任务分解…

云解析专家解密《狂飙》剧中京海市人民政府域名访问真相

这段时间&#xff0c;最火的国产剧当属《狂飙》无疑。尽管不久前迎来了大结局&#xff0c;但这部剧的热度依然不减&#xff0c;成为大家茶余饭后热议的话题。 出于对这部剧的喜爱&#xff0c;小编开启了二刷模式&#xff0c;在重温剧情的同时&#xff0c;对于其中的一些细节也…

Windows 10注册表损坏怎么办?

注册表是一个复杂的数据库&#xff0c;如果不进行维护&#xff0c;它就会填充损坏的和孤立的注册表项。尤其是对Windows进行升级时&#xff0c;损坏或丢失的注册表项也会不断累积&#xff0c;从而影响您的系统性能。如果您的Windows 10系统正在经历注册表损坏的问题&#xff0c…

SpringBoot(powernode)

SpringBoot&#xff08;powernode&#xff09; 目录SpringBoot&#xff08;powernode&#xff09;一、第一个SpringBoot项目二、打jar包启动测试三、个性化banner四、常用注解4.1回顾spring和java的注解4.1.1 spring标注类的注解&#xff1a;4.1.2 spring标注方法的注解&#x…

linux的文件权限介绍

文件权限 在linux终端输入 ls -lh 出现下面界面 介绍 基本信息 其中的开头代表着文件类型和权限 而 root 和kali 则分别代表用户名和用户组名用户名顾名思义就是这个文件属于哪一个用户用户组是说自己在写好一个文件后&#xff0c;这个文件是属于该用户所有&#xff0c;…

剑指 Offer 63. 股票的最大利润

剑指 Offer 63. 股票的最大利润 难度&#xff1a;middle\color{orange}{middle}middle 题目描述 假设把某股票的价格按照时间先后顺序存储在数组中&#xff0c;请问买卖该股票一次可能获得的最大利润是多少&#xff1f; 示例 1: 输入: [7,1,5,3,6,4] 输出: 5 解释: 在第 2 …

尚硅谷《Redis7》(小白篇)

尚硅谷《Redis7 》&#xff08;小白篇&#xff09; 02 redis 是什么 官方网站&#xff1a; https://redis.io/ 作者 Git Hub https://github.com/antirez 03 04 05 能做什么 06 去哪下 Download https://redis.io/download/ redis中文文档 https://www.redis.com.cn/docu…

详解量子计算:相位反冲与相位反转

前言 本文需要对量子计算有一定的了解。需要的请翻阅我的量子专栏&#xff0c;这里不再涉及基础知识的科普。 量子相位反冲是什么&#xff1f; 相位反转&#xff08;phase kickback&#xff09;是量子计算中的一种现象&#xff0c;通常在量子算法中使用&#xff0c;例如量子…

亲测实现PopupWindow显示FlowLayout流式布局带固定文本/按钮(位置可改)

实现&#xff1a;动态绘制并带固定文本/按钮&#xff0c;固定文本/按钮固定在最后一行的右边且垂直居中&#xff0c;若最后一行放不下&#xff0c;则新开一行放到新行的右边且垂直居中&#xff08;新行的行高跟前面的一样&#xff09;&#xff0c;可单选、多选、重置。 注&…

SQL零基础入门学习(六)

SQL零基础入门学习&#xff08;六&#xff09; SQL零基础入门学习&#xff08;五&#xff09; SQL 通配符 通配符可用于替代字符串中的任何其他字符。 SQL 通配符用于搜索表中的数据。 在 SQL 中&#xff0c;可使用以下通配符&#xff1a; 演示数据库 在本教程中&#xff…

robotframework自动化测试环境搭建

环境说明 win10 python版本&#xff1a;3.8.3rc1 安装清单 安装配置 selenium安装 首先检查pip命令是否安装&#xff1a; C:\Users\name>pipUsage:pip <command> [options]Commands:install Install packages.download Do…

掌握这10个测试方法,软件测试已登堂入室

当然还有很多测试方法&#xff0c;这些要根据实际不同应用场景而变化&#xff0c;这里就以键盘为例子进行测试方法的讲解。 1.需求测试 需求这种大家都知道这种主要是就是甲方或者项目经理写的&#xff0c;或者某些人需要什么我们就给什么&#xff0c;一般来讲一个东西给到…