论文解读:Masked Generative Distillation

news2024/11/25 4:42:00

文章汇总

话题

知识蒸馏

创新点

带掩盖的生成式蒸馏

方法旨在通过学生的遮罩特征来生成老师的特征(通过遮盖学生部分的特征来生成老师的特征),来帮助学生获得更好的表现

输入:老师:T,学生:S,输入:x,标签:y,超参数:\alpha,\lambda

1:使用S得到输入x的特征fea^S和输出\hat{y}

2:使用T得到输入x的特征fea^T

3:计算模型的原始损失:L_{original}(\hat{y},y)

4:计算公式5中的蒸馏损失:

其中:

G表示投影层,包括两个卷积层:W_{l1}W_{l2},一个激活层ReLU。在本文中,我们采用1×1卷积层为适配层f_{align}, 3×3为投影层W_{l1}W_{l2}的卷积层。

5:使用L_{all}=L_{original}+\alpha*L_{dis}更新S

输出:S

想改进的地方

摘要

知识蒸馏已成功地应用于各种任务中。目前的蒸馏算法通常通过模仿老师的输出来提高学生的表现。本文表明,教师还可以通过引导学生特征恢复来提高学生的代表性。从这个角度来看,我们提出了掩膜生成蒸馏(mask Generative Distillation, MGD),它很简单:我们掩膜学生特征的随机像素,并通过一个简单的块强制其生成教师的完整特征。MGD是一种真正通用的基于特征的蒸馏方法,可用于各种任务,包括图像分类、目标检测、语义分割和实例分割。我们用大量的数据集对不同的模型进行了实验,结果表明所有的学生都取得了很好的进步。值得注意的是,我们将ResNet-18的ImageNet顶级1精度从69.90%提高到71.69%,将ResNet-50主干的RetinaNet从37.4提高到41.0 Boundingbox mAP,基于ResNet-50的SOLO从33.1提高到36.2 Mask mAP,以及基于ResNet-18的DeepLabV3从73.20提高到76.02 mIoU。我们的代码可在https://github.com/yzd-v/MGD上获得。

关键词:知识蒸馏,图像分类,目标检测,语义分割,实例分割

介绍

深度卷积神经网络(cnn)已广泛应用于各种计算机视觉任务中。一般来说,较大的模型具有较好的性能,但推理速度较低,难以在有限的源下部署。为了克服这一问题,知识蒸馏被提出[18]。按蒸馏类型可分为两种。第一种是专门为不同的任务而设计的,例如用于分类的基于logit的蒸馏[18,40]和用于检测的基于head的蒸馏[10,39]。第二种是基于特征的蒸馏[28,17,4]。由于各种网络之间只有头部或投影仪后的特征是不同的,从理论上讲,基于特征的蒸馏方法可以可用于各种任务。然而,为特定任务设计的蒸馏方法通常不适用于其他任务。例如,OFD[17]和KR[4]对探测器的改进有限。FKD[37]和FGD[35]是专门为探测器设计的,由于缺乏颈部,无法用于其他任务。

以往基于特征的提炼方法,由于教师的特征具有更强的表征能力,通常会让学生尽可能地模仿教师的输出。然而,我们认为没有必要直接模仿老师来提高学生特征的表征能力。用于蒸馏的特征通常是通过深度网络获得的高阶语义信息。特征像素已经在一定程度上包含了相邻像素的信息。因此,如果我们可以使用部分像素通过一个简单的块来还原教师的全部特征,那么这些使用的像素的代表性也可以得到提高。从这个角度出发,我们提出了一种简单有效的基于特征的蒸馏方法——掩膜生成蒸馏(mask Generative Distillation, MGD)。

如图2所示,我们首先对学生特征的随机像素进行遮罩,然后通过一个简单的块将遮罩后的特征生成教师的完整特征。由于每次迭代都使用随机像素,因此在整个训练过程中都会使用所有像素,这意味着特征将更加鲁棒,并且其表示能力将得到提高。在我们的方法中,老师只是引导学生还原特征,并不要求学生直接模仿

为了验证我们的假设,即在不直接模仿教师的情况下,掩蔽特征生成可以提高学生的特征表征能力,我们从学生和教师的颈部对特征注意力进行了可视化。如图1所示,学生和教师的特征有很大的不同。

与教师相比,学生特征的背景有更高的反应。教师的mAP也显著高于学生,为41.0比37.4。采用最先进的蒸馏方法FGD蒸馏后[35],迫使学生用心模仿老师的特征,学生的特征与老师的特征更加相似,mAP很大提高到40.7。而经过MGD培训后,学生与教师的特征仍有显著差异,但学生对背景的反应却大大降低。令我们惊讶的是,该学生的成绩超过了FGD,甚至达到了与老师相同的mAP。这也说明用MGD训练可以提高学生特征的表征能力。此外,我们还在图像分类和密集预测任务上做了大量的实验。结果表明,MGD对图像分类、目标检测、语义分割和实例分割等任务都有较大的改善。MGD还可以与其他基于logit或基于head的蒸馏方法相结合,以获得更大的性能收益。综上所述,本文的贡献有:

1. 我们提出了一种新的基于特征的知识提炼方法,使学生利用被掩盖的特征来生成教师的特征,而不是直接模仿教师的特征。

2. 本文提出了一种新的基于特征的蒸馏方法——掩膜生成蒸馏,该方法简单易用,只需要两个超参数。

3. 我们通过在不同数据集上的大量实验验证了我们的方法在各种模型上的有效性。对于图像分类和密集预测任务,学生在MGD的帮助下都取得了显著的进步。

相关工作

面向分类的知识蒸馏

知识蒸馏最早是由Hinton等人[18]提出的,其中学生受到来自教师最后一个线性层的标签和软标签的监督。

然而,除了logit之外,更多的蒸馏方法是基于特征映射的。FitNet[28]从中间层提取语义信息。AT[36]总结了跨渠道维度的价值,并将注意力知识转移给学生。OFD[17]提出了余量ReLU,并设计了一个测量蒸馏距离的新函数。CRD[30]利用对比学习将知识传递给学生。最近,KR[4]建立了一个审查机制,并利用多层次信息进行蒸馏。SRRL[33]将表示学习和分类解耦,利用老师的分类器来训练学生的倒数第二层特征。WSLD[40]从偏方差权衡的角度提出了加权的蒸馏软标签。

面向语义分割的知识蒸馏

Liu等人[23]提出了成对和整体蒸馏,在学生和教师的输出之间执行成对和高阶一致性。他等人[16]将教师网络的输出重新解释为一个重新表示的潜在域,并从教师网络中捕获长期依赖关系。CWD[29]最小化了概率图之间的Kullback-Leibler (KL)散度,该散度是通过对每个通道的激活图进行归一化计算得到的。

方法

对于不同的任务,模型的体系结构差别很大。此外,大多数蒸馏方法都是为特定任务而设计的。然而,基于特征的精馏可以同时应用于分类和密集预测。特征蒸馏的基本方法可表述为:

式中,F^TF^S分别表示教师和学生的特征,f_{align}是将学生的特征F^S与教师的特征F^T

对齐的适应层。C, H, W表示特征映射的形状。

这种方法有助于学生直接模仿老师的特征。然而,我们提出了掩蔽生成蒸馏(MGD),其目的是迫使学生产生教师的特征,而不是模仿它,给学生带来了分类和密集预测方面的显着改善。MGD的架构如图2所示,我们将在本节中专门介绍它。

带掩盖特征的生成

对于基于cnn的模型,更深层的特征具有更大的接受域和更好的原始输入图像表征。换句话说,特性的图像素在一定程度上已经包含了相邻像素的信息。

因此,我们可以使用部分像素来恢复完整的特征映射。我们的方法旨在通过学生的遮罩特征来生成老师的特征(通过遮盖学生部分的特征来生成老师的特征),这样可以帮助学生获得更好的表现。
我们用T^l \in R^{C \times H \times W},S^l \in R^{C \times H \times W}(l=1,...,L)表示分别为教师和学生的第l个特征图。首先我们设置第l个随机掩码来覆盖学生的第l个特征,可以表示为:

其中R_{i,j}^l为(0,1)中的随机数,i,j分别为特征图的横坐标和纵坐标。λ是表示掩码比的超参数。第
l个特征映射被第l个随机掩码覆盖。

然后我们使用相应的掩码覆盖学生的特征图,并尝试用左边的像素生成教师的特征图,可以表示为:

G表示投影层,包括两个卷积层:W_{l1}W_{l2},一个激活层ReLU。在本文中,我们采用1×1卷积层为适配层f_{align}, 3×3为投影层W_{l1}W_{l2}的卷积层。

根据该方法,我们设计了MGD的蒸馏损失L_{dis}:

其中L为蒸馏层数和,C、H、W为特征映射的形状。S和T分别表示学生和教师的特征。

总体损失

利用提出的MGD蒸馏损失L_{dis},我们用总损失训练所有模型如下:

其中L_{original}为所有任务中模型的原始损失,α为平衡损失的超参数。

MGD是一种简单有效的蒸馏方法,可方便地应用于各种任务。算法1总结了我们的方法的过程。

方法过程汇总

算法1:带掩盖的生成式蒸馏

输入:老师:T,学生:S,输入:x,标签:y,超参数:\alpha,\lambda

1:使用S得到输入x的特征fea^S和输出\hat{y}

2:使用T得到输入x的特征fea^T

3:计算模型的原始损失:L_{original}(\hat{y},y)

4:计算公式5中的蒸馏损失:L_{dis}(fea^S, fea^T)

5:使用L_{all}=L_{original}+\alpha*L_{dis}更新S

输出:S

主要实验

MGD是一种基于特征的蒸馏,可以很容易地应用于各种任务的不同模型。在本文中,我们对分类、目标检测、语义分割和实例分割等任务进行了实验。我们针对不同的任务使用不同的模型和数据集进行实验,所有模型都通过MGD获得了出色的改进。

分类

数据集

对于分类任务,我们在包含1000个对象类别的ImageNet[11]上评估我们的知识蒸馏方法。我们用120万张图片进行训练,用5万张图片进行测试,完成所有的分类实验。我们用准确性来评价模型。

实现细节

对于分类任务,我们计算来自主干的最后一个特征映射的蒸馏损失。有关消融的研究见5.5节。MGD使用超参数α来平衡方程6中的蒸馏损失。另一个超参数λ用于调整公式2中的屏蔽比。所有分类实验均采用超参数{α = 7 × 10^(−5),λ = 0.5}。我们使用SGD优化器训练所有模型100个epoch,其中动量为0.9,权重衰减为0.0001。我们初始化学习率为0.1,并每30次衰减一次。此设置基于8个gpu。实验采用基于Pytorch[26]的MMClassification[6]和MMRazor[7]进行。

分类结果

我们用两种常用的蒸馏设置进行实验,包括均相蒸馏和非均相蒸馏。

第一个蒸馏设置是从ResNet-34[15]到ResNet-18,另一个设置是从ResNet-50到MobileNet[19]。如表1所示,我们比较了各种知识蒸馏方法[18、36、17、25、30、4、40、33],包括基于特征的方法、基于逻辑的方法和结合的方法。使用我们的方法,学生ResNet-18和MobileNet的Top-1准确率分别提高了1.68和3.14。此外,如上所述,MGD只需要计算特征图上的蒸馏损失,并且可以与其他基于逻辑的图像分类方法相结合。因此,我们尝试在WSLD中加入基于logit的蒸馏损失[40]。这样,两位同学的Top-1准确率分别达到了71.80和72.59,分别提高了0.22和0.24。

表1。不同蒸馏方法在ImageNet数据集上的结果。T和S分别表示老师和学生。

目标检测和实例分割

数据集

我们在COCO2017数据集[22]上进行实验,该数据集包含80个对象类别。我们使用120k的训练图像进行训练,5k的val图像进行测试。用平均精度对模型的性能进行了评价。

表2。不同蒸馏方法在COCO上的目标检测结果。

实现细节

我们从颈部计算所有特征映射的蒸馏损失。我们采用超参数{α = 2 × 10^(−5),λ = 0.65}对所有的单阶段模型,{α = 5 × 10^(−7),λ = 0.45}对所有的两阶段模型。我们使用SGD优化器训练所有模型,其中动量为0.9,权重衰减为0.0001。除非特别说明,否则我们训练模型为24个epoch。我们使用继承策略[20,35],用教师的颈部和头部参数初始化学生,在头部结构相同的情况下训练学生。实验采用MMDetection进行[2]。

目标检测和实例分割结果

对于目标检测,我们在三种不同类型的检测器上进行了实验,包括两级检测器(Faster RCNN[27])、基于锚点的一级检测器(RetinaNet[21])和无锚点的一级检测器(RepPoints[34])。

我们将MGD与最近三种最先进的检测器蒸馏方法进行比较[37,29,35]。以分割为例,我们在SOLO[32]和Mask RCNN[14]两个模型上进行了实验。如表2和表3所示,我们的方法在两种目标检测和实例分割方面都优于其他最先进的方法。学生在MGD的帮助下获得了显著的AP改善,例如基于ResNet-50的retanet和SOLO在COCO数据集上分别获得了3.6个Boundingbox mAP和3.1个Mask mAP的改善。

表3。不同蒸馏方法对实例分割的结果。MS的意思是多尺度训练。这里的AP指掩模AP。

参考资料

论文下载(ECCV 2区 2022)

https://arxiv.org/pdf/2205.01529.pdf

📎Masked Generative Distillation.pdf

代码地址

GitHub - yzd-v/MGD: Masked Generative Distillation (ECCV 2022)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1451709.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CTFshow web(文件上传158-161)

web158 知识点: auto_append_file 是 PHP 配置选项之一,在 PHP 脚本执行结束后自动追加执行指定的文件。 当 auto_append_file 配置被设置为一个文件路径时,PHP 将在执行完脚本文件的所有代码后,自动加载并执行指定的文件。 这…

【springboot+vue项目(十四)】基于Oauth2的SSO单点登录(一)整体流程介绍

场景:现在有一个前后端分离的系统,前端框架使用vue-element-template,后端框架使用springbootspringSecurityJWTRedis(登录部分)现在需要接入到已经存在的第三方基于oauth2.0的非标准接口统一认证系统。 温馨提示&…

【STM32 CubeMX】I2C查询方式

文章目录 前言一、CubeMX配置IIC二、查询方式的使用2.1 分析一种情况2.2 Master模式2.3 Mem模式 总结 前言 在STM32 CubeMX环境中,I2C(Inter-Integrated Circuit)通信协议的查询方式是一种简单而常见的通信方式。通过查询方式,微…

代码随想录 Leetcode45. 跳跃游戏 II

题目&#xff1a; 代码(首刷看解析 2024年2月15日&#xff09;&#xff1a; class Solution { public:int jump(vector<int>& nums) {if (nums.size() 1) return 0;int res 0;int curDistance 0;int nextDistance 0;for (int i 0; i < nums.size(); i) {nex…

6、内网安全-横向移动WmiSmbCrackMapExecProxyChainsImpacket

用途&#xff1a;个人学习笔记&#xff0c;有所借鉴&#xff0c;欢迎指正&#xff01; 前言&#xff1a; 在内网环境中&#xff0c;主机192.168.3.31有外网网卡能出网&#xff0c;在取得该主机权限后上线&#xff0c;搭建web应用构造后门下载地址&#xff0c;利用该主机执行相…

Windows系统VMware创建多个CentOS7虚拟机 NAT网络配置 ssh连接

主要目标: 1.创建3个虚拟机, centos7系统 2.虚拟机之间互相访问 3.物理机访问各虚拟机, 通过xshell建立ssh连接 4.物理机网络变化时,仍能访问 用途: NoSQL课程使用, 课前环境搭建,个人备忘 基本信息&#xff1a; 物理机&#xff1a; windows 11 操作系统 虚拟机软件&#xff…

前端工程化面试题 | 10.精选前端工程化高频面试题

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

OpenCV Mat实例详解 三

OpenCV Mat实例详解 一、二介绍了&#xff0c;OpenCV Mat类构造函数及其公共属性。下面继续介绍OpenCV Mat类公有静态成员函数 OpenCV Mat类公有静态成员函数&#xff08;Static Public Member Functions&#xff09; static CV_NODISCARD_STD Mat diag (const Mat &d)&…

CSP-201903-2-二十四点

CSP-201903-2-二十四点 一、中缀表达式转后缀表达式 中缀表达式是一种常见的数学表达式书写方式&#xff0c;其中操作符位于相关的操作数之间&#xff0c;如 A B。而后缀表达式&#xff08;逆波兰表示法&#xff09;则是一种没有括号&#xff0c;操作符跟随操作数之后的表示…

TIM输出比较 P2

D触发器&#xff1f; 一、输出比较 二、PWM 1、简介 2、结构 三、外部设备 1.舵机 2.直流电机 我的理解是xO1 xIN1 & PWMx; xO2 xIN2 & PWMx;引入PWMx可以更方便的控制特定的电路。 四、函数学习 /*****单独设置输出比较极性*****/ void TIM_OC1PolarityConfig(…

CSS篇--transform

CSS篇–transform 使用transform属性实现元素的位移、旋转、缩放等效果 位移 // 语法 transform:translate(水平移动距离&#xff0c;垂直移动距离) translate() 如果只给一个值&#xff0c;表示x轴方法移动距离 单独设置某个方向的移动距离&#xff1a;translateX() transla…

Rust 基本环境安装

rust 基本介绍请看上一篇文章&#xff1a;rust 介绍 rustup 介绍 rustup 是 Rust 语言的安装器和版本管理工具。通过 rustup&#xff0c;可以轻松地安装 Rust 编译器&#xff08;rustc&#xff09;、标准库和文档。它也允许你切换不同的 Rust 版本或目标平台&#xff0c;以及…

Compose 自定义 - 数据转UI的三阶段(组合、布局、绘制)

一、概念 Compose 通过三个阶段把数据转化为UI&#xff1a;组合&#xff08;要显示什么&#xff09;、布局&#xff08;要显示在哪里&#xff09;、绘制&#xff08;如何渲染&#xff09;。 组合阶段 Compisition 界面首次渲染时会将可组合函数转化为一个个布局节点 Layout Nod…

【打工日常】使用docker部署linux-command解析搜索工具

一、linux-command介绍 linux-command工具是一个非盈利性的工具&#xff0c;里面记录了550 个 Linux 命令&#xff0c;内容包含 Linux 命令手册、详解、学习&#xff0c;是值得收藏的 Linux 命令速查手册。内容来自网络和网友的补充。 二、本次实践介绍 1. 本次实践简介 本次…

Flume(二)【Flume 进阶使用】

前言 学数仓的时候发现 flume 落了一点&#xff0c;赶紧补齐。 1、Flume 事务 Source 在往 Channel 发送数据之前会开启一个 Put 事务&#xff1a; doPut&#xff1a;将批量数据写入临时缓冲区 putList&#xff08;当 source 中的数据达到 batchsize 或者 超过特定的时间就会…

qt-C++笔记之捕获鼠标滚轮事件并输出滚轮角度增量

qt-C笔记之捕获鼠标滚轮事件并输出滚轮角度增量 code review! 文章目录 qt-C笔记之捕获鼠标滚轮事件并输出滚轮角度增量1.运行2.main.cpp3.main.pro 1.运行 2.main.cpp #include <QApplication> #include <QWidget> #include <QWheelEvent> #include <…

.NET Core MongoDB数据仓储和工作单元模式封装

前言 上一章我们把系统所需要的MongoDB集合设计好了&#xff0c;这一章我们的主要任务是使用.NET Core应用程序连接MongoDB并且封装MongoDB数据仓储和工作单元模式&#xff0c;因为本章内容涵盖的有点多关于仓储和工作单元的使用就放到下一章节中讲解了。仓储模式&#xff08;R…

java的面向对象编程(oop)——认识枚举

前言 打好基础&#xff0c;daydayup! 枚举 1&#xff0c;认识枚举&#xff1a; 枚举是一种特殊类&#xff0c;用enum语句修饰。与普通类不同的是&#xff1a;枚举类的第一行只能写一些合法的标识符&#xff08;名称&#xff09;&#xff0c;多个名称用逗号隔开。这些标识符&a…

相机图像质量研究(16)常见问题总结:光学结构对成像的影响--IRCUT

系列文章目录 相机图像质量研究(1)Camera成像流程介绍 相机图像质量研究(2)ISP专用平台调优介绍 相机图像质量研究(3)图像质量测试介绍 相机图像质量研究(4)常见问题总结&#xff1a;光学结构对成像的影响--焦距 相机图像质量研究(5)常见问题总结&#xff1a;光学结构对成…

原型模式-Prototype Pattern

原文地址:https://jaune162.blog/design-pattern/prototype-pattern/ 引言 在Java中如果我们想要拷贝一个对象应该怎么做?第一种方法是使用 getter和setter方法一个字段一个字段设置。或者使用 BeanUtils.copyProperties() 方法。这种方式不仅能实现相同类型之间对象的拷贝,…