基于注意力的知识蒸馏Attention Transfer原理与代码解析

news2025/1/13 2:52:55

paper:Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer

code:https://github.com/megvii-research/mdistiller/blob/master/mdistiller/distillers/AT.py

背景

一个流行的假设是存在非注意non-attentional和注意attentional感知过程,非注意感知有助于从整体观察一个场景并获取high-level信息,注意力感知过程会将我们导向并更关注某个局部。不同的观察者有不同的知识,不同的目标,因此有不同的注意力策略,从而以不同方式看待同一场景。本文研究的主题是一个教师网络能否通过向学生网络传递它的注意力信息(即它更关注哪些区域)来提升学生网络的性能。

本文的创新点

本文提出将注意力作为一种将知识从教师网络传递给学生网络的机制。对于一个给定的卷积神经网络,首先需要合适地定义注意力,作者将注意力看做是一组spatial maps,给定输入spatial map上编码了网络最关注的空间区域。本文还提出了同时使用基于激活的和基于梯度的空间注意力图,并通过实验表明了基于注意力的知识蒸馏的有效性,同时表明了基于激活的注意力传递比单纯基于激活的知识传递效果更好。

方法介绍

给定CNN某层的输出激活张量 \(A\in R^{C\times H\times W}\),输入到一个基于激活的映射函数 \(\mathcal{F}\) 得到得到一个spatial attention map

这里隐含的假设是一个隐含神经元激活值的绝对值可以用来表明该神经元对某个特定输入的重要程度,比如注意力图上某个位置的值越大说明网络越关注该位置。基于该假设,我们可以沿通道维度上计算这些值的统计数据来构建一个空间注意力图。如下图所示

本文主要考虑了以下三种统计方法

我们假设教师网络和学生网络之间的注意力传递发生在相同分辨率的注意力图上,当分辨率不一致时也可以通过插值来进行匹配。一个示例如下图所示,其中教师和学生网络都是残差网络,在每个stage的最后进行注意力知识的传递,即计算对应attention map之间的损失。

定义 \(S,T\) 和 \(\mathbf{W_{S}},\mathbf{W_{T}}\) 分别表示学生和教师网络以及对应的模型权重,\(\mathcal{L}\left ( \mathbf{W},x \right ) \) 是交叉熵损失,\(\mathcal{I}\) 是所有教师和学生网络对应注意力图的索引。完整损失函数如下

其中 \(Q_{S}^{j}=vec(F(A^{j}_{S}))\) 和 \(Q_{T}^{j}=vec(F(A^{j}_{T}))\) 别是学生网络和教师网络第 \(j\) 层的向量形式的注意力图,\(p\) 是范数类型本文默认 \(p=2\)。注意力的知识传递还可以和知识蒸馏结合使用,只需要在式(2)中加一项教师和学生软化标签分布之间的交叉熵损失项即可。

代码解析

其中函数at_loss从教师和学生网络中一一取出对应层的特征图f_sf_t,函数single_stage_at_loss计算对应单层之间的注意力损失。

import torch
import torch.nn as nn
import torch.nn.functional as F

from ._base import Distiller


def single_stage_at_loss(f_s, f_t, p):
    def _at(feat, p):
        # (64,64,32,32)->(64,64,32,32)->(64,32,32)->(64,1024)->(64,1024)
        return F.normalize(feat.pow(p).mean(1).reshape(feat.size(0), -1))  # 沿通道取mean

    s_H, t_H = f_s.shape[2], f_t.shape[2]
    if s_H > t_H:
        f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
    elif s_H < t_H:
        f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
    # (64,1024)-(64,1024)->(64,1024)->()
    return (_at(f_s, p) - _at(f_t, p)).pow(2).mean()


def at_loss(g_s, g_t, p):
    return sum([single_stage_at_loss(f_s, f_t, p) for f_s, f_t in zip(g_s, g_t)])


class AT(Distiller):
    """
    Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer
    src code: https://github.com/szagoruyko/attention-transfer
    """

    def __init__(self, student, teacher, cfg):
        super(AT, self).__init__(student, teacher)
        self.p = cfg.AT.P
        self.ce_loss_weight = cfg.AT.LOSS.CE_WEIGHT
        self.feat_loss_weight = cfg.AT.LOSS.FEAT_WEIGHT

    def forward_train(self, image, target, **kwargs):
        logits_student, feature_student = self.student(image)
        with torch.no_grad():
            _, feature_teacher = self.teacher(image)

        # losses
        loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
        loss_feat = self.feat_loss_weight * at_loss(
            feature_student["feats"][1:], feature_teacher["feats"][1:], self.p
        )
        losses_dict = {
            "loss_ce": loss_ce,
            "loss_kd": loss_feat,
        }
        return logits_student, losses_dict

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/354838.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringCloudAlibaba-Sentinel

一、介绍官网&#xff1a;https://github.com/alibaba/Sentinel/下载jar包,启动,访问http://localhost:8080/创建module添加如下依赖<!--SpringCloud ailibaba sentinel --><dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring…

内网渗透(四十)之横向移动篇-ms14-068传递获取域管横向移动

系列文章第一章节之基础知识篇 内网渗透(一)之基础知识-内网渗透介绍和概述 内网渗透(二)之基础知识-工作组介绍 内网渗透(三)之基础知识-域环境的介绍和优点 内网渗透(四)之基础知识-搭建域环境 内网渗透(五)之基础知识-Active Directory活动目录介绍和使用 内网渗透(六)之基…

[软件工程导论(第六版)]第2章 可行性研究(复习笔记)

文章目录2.1 可行性研究的任务2.2 可行性研究过程2.3 系统流程图2.4 数据流图概念2.5 数据字典2.6 成本/效益分析2.1 可行性研究的任务 可行性研究的目的 用最小的代价在尽可能短的时间内确定问题是否能够解决。 可行性研究的3个方面 &#xff08;1&#xff09;技术可行性&…

宝塔搭建实战人才求职管理系统adminm前端vue源码(三)

大家好啊&#xff0c;我是测评君&#xff0c;欢迎来到web测评。 上一期给大家分享骑士cms后台admin前端vue在本地运行打包、宝塔发布部署的方式&#xff0c;本期给大家分享&#xff0c;后台adminm移动端后台vue前端怎么在本地运行&#xff0c;打包&#xff0c;实现线上功能更新…

Ubuntu下使用Wine运行HBuilderX

安装完wine后&#xff0c;在HbuilderX的目录中打开终端&#xff0c;直接输入wine HBuilderX.exe命令&#xff0c;启动过程中会提示安装wine-mono组件&#xff0c;点击安装按钮下载安装该组件&#xff0c;该组件下载速度慢&#xff0c;需要等待特别长时间。   安装完毕后&…

金三银四软件测试工程师面试题(含答案)

前言&#xff1a;此文专门记载本人平时面试以及收藏的面试题目&#xff0c;如果有错误之处请及时指正&#xff0c;谢谢&#xff01; 1、python的数据类型有哪些 答&#xff1a;Python基本数据类型一般分为&#xff1a;数字、字符串、列表、元组、字典、集合这六种基本数据类…

pytorch配置—什么是CUDA,什么是CUDNN、在配置pytorch虚拟环境中遇到的问题、在安装gpu—pytorch中遇到的问题

1.什么是CUDA&#xff0c;什么是CUDNN &#xff08;1&#xff09;什么是CUDA CUDA(ComputeUnified Device Architecture)&#xff0c;是显卡厂商NVIDIA推出的运算平台。 CUDA是一种由NVIDIA推出的通用并行计算架构&#xff0c;该架构使GPU能够解决复杂的计算问题。 &#xff0…

RuoYi-Vue-Plus搭建(若依)

项目简介 1.RuoYi-Vue-Plus 是重写 RuoYi-Vue 针对 分布式集群 场景全方位升级(不兼容原框架)2.环境安装参考&#xff1a;https://blog.csdn.net/tongxin_tongmeng/article/details/128167926 JDK 11、MySQL 8、Redis 6.X、Maven 3.8.X、Nodejs > 12、Npm 8.X3.IDEA环境配置…

建造《流浪地球2》中要毁灭人类的超级量子计算机MOSS的核心量子技术是什么?

1.《流浪地球2》中的量子计算机 2023年中国最火的电影非《流浪地球2》莫属&#xff0c;在《流浪地球2》中有一个人工智能机器人MOSS &#xff0c;它的前身是“550W”超级量子计算机&#xff0c;“MOSS”是它给自己起的名字&#xff08;“550W”倒转180度就是“MOSS”&#xff…

力扣38.外观数列

文章目录力扣38.外观数列题目描述方法1&#xff1a;按规则生成&#xff08;顺序暴力法&#xff09;力扣38.外观数列 题目描述 给定一个正整数 n &#xff0c;输出外观数列的第 n 项。 「外观数列」是一个整数序列&#xff0c;从数字 1 开始&#xff0c;序列中的每一项都是对…

这才是计算机科学_人工智能

人工智能一、前言二、ML2.1 分类2.1.1 决策树2.2.2 支持向量机2.2.3 人工神经网络三、计算机视觉3.1 Prewitt算子3.2 Viola-Jones 人脸检测算法3.3 卷积神经网络四、自然语言处理4.1 知识图谱4.2 语音识别一、前言 之前讲了计算机从发展到现在的过程&#xff0c;计算机很适合做…

[软件工程导论(第六版)]第2章 可行性研究(课后习题详解)

文章目录1. 在软件开发的早期阶段为什么要进行可行性研究&#xff1f;应该从哪些方面研究目标系统的可行性&#xff1f;2. 为方便储户&#xff0c;某银行拟开发计算机储蓄系统。储户填写的存款单或取款单由业务员输入系统&#xff0c;如果是存款&#xff0c;系统记录存款人姓名…

ThreadLocal 内存泄漏问题

1. 认识ThreadLocal java中提高了threadlocal&#xff0c;为每个线程保存其独有的变量&#xff0c;threadlocal使用的一个小例子是&#xff1a; public class ThreadLocalTest {public static void main(String[] args) {ThreadLocal<String> threadIds new ThreadLoc…

FreeRTOS系统延时函数分析

FreeRTOS提供了两个系统延时函数&#xff0c;相对延时函数vTaskDelay()和绝对延时函数vTaskDelayUntil()。相对延时是指每次延时都是从任务执行函数vTaskDelay()开始&#xff0c;延时指定的时间结束&#xff0c;绝对延时是指每隔指定的时间&#xff0c;执行一次调用vTaskDealyU…

操作系统闲谈07——系统启动

操作系统闲谈07——系统启动 一、BIOS启动 BIOS程序不需要由谁加载&#xff0c;本身便固化在ROM只读存储器中。 开机的一瞬间 cs:ip 便被初始化为0xf000 : 0xfff0。开机的时候处于实模式&#xff0c;其等效地址为0xffff0&#xff0c;如上图所示此地址为BIOS的入口地址。 建立…

C/C++每日一练(20230218)

目录 1. 整数转罗马数字 2. 跳跃游戏 II 3. 买卖股票的最佳时机 IV 1. 整数转罗马数字 罗马数字包含以下七种字符&#xff1a; I&#xff0c; V&#xff0c; X&#xff0c; L&#xff0c;C&#xff0c;D 和 M。 字符 数值 I 1 V 5 X …

物联网中RocketMQ的使用

物联网中RocketMQ的使用 1. 背景 随着物联网行业的发展、智能设备数量越来越多&#xff0c;很多常见的智能设备都进入了千家万户&#xff1b;随着设备数量的增加&#xff0c;也对后台系统的性能提出新的挑战。 在日常中&#xff0c;存在一些特定的场景&#xff0c;属于高并发请…

Java数据类型、基本与引用数据类型区别、装箱与拆箱、a=a+b与a+=b区别

文章目录1.Java有哪些数据类型2.Java中引用数据类型有哪些&#xff0c;它们与基本数据类型有什么区别&#xff1f;3.Java中的自动装箱与拆箱4.为什么要有包装类型&#xff1f;5.aab与ab有什么区别吗?1.Java有哪些数据类型 8种基本数据类型&#xff1a; 6种数字类型(4个整数型…

java ssm志愿者信息服务平台springmvc

志愿者服务平台&#xff0c;采用ssm框架技术&#xff0c;java语言&#xff0c;jsp前端页面以及js&#xff0c;DIVCSS&#xff0c;html5等技术综合应用开发实现。志愿者服务平台&#xff0c;采用了前台后台的模式开发&#xff0c;前台用于志愿者在线项目培训&#xff0c;视频学习…

CUDA编程接口

编程接口 文章目录编程接口3.1利用NVCC编译3.1.1编译流程3.1.1.1 离线编译3.1.1.2 即时编译3.1.2 Binary 兼容性注意&#xff1a;仅桌面支持二进制兼容性。 Tegra 不支持它。 此外&#xff0c;不支持桌面和 Tegra 之间的二进制兼容性。3.1.3 PTX 兼容性3.1.4 应用程序兼容性3.1…