『论文精读』Data-efficient image Transformers(DeiT)论文解读

news2024/9/23 17:20:49
『论文精读』Data-efficient image Transformers(DeiT)论文解读

文章目录

  • 一. DeiT简介
  • 二. 知识蒸馏(knowledge distillation)
  • 三. better hyperparameter
  • 四. data augmentation
  • 五. label smoothing
  • 参考文献

  • 论文下载链接:https://arxiv.org/pdf/2012.12877.pdf
  • 论文代码链接:https://github.com/facebookresearch/deit
  • 关于VIT论文的解读可以关注我之前的文章:『论文精读』Vision Transformer(VIT)论文解读

一. DeiT简介

  • 现有的基于Transformer的分类模型ViT需要在海量数据上(JFT-300M,3亿张图片)进行预训练,再在ImageNet数据集上进行fune-tuning,才能达到与CNN方法相当的性能,这需要非常大量的计算资源,这限制了ViT方法的进一步应用。

在这里插入图片描述

  • DeiT的模型和VIT的模型几乎是相同的,可以理解为本质上是在训一个VIT。
  • better hyperparameter:指的是模型初始化、learning-rate等设置。
  • data augmentation:在只有120万张图片的Imagenet,使用数据增广模拟更多数据。
  • Distillation:知识蒸馏。
  • 三部分的作用分别为:保证模型更好的收敛、可以使用小的数据训练、进一步提升性能。还有一些其他的方式,如:warmup、label smoothing、droppath等。

在这里插入图片描述

  • Data-efficient image transformers (DeiT) 无需海量预训练数据,只依靠ImageNet数据,便可以达到SOTA的结果,同时依赖的训练资源更少(4 GPUs in three days)。

在这里插入图片描述

  • 文章贡献如下:
  • 仅使用Transformer,不引入Conv的情况下也能达到SOTA效果。
  • 提出了基于token蒸馏的策略,这种针对transformer的蒸馏方法可以超越原始的蒸馏方法。
  • Deit发现使用Convnet作为教师网络能够比使用Transformer架构取得更好的效果。

在这里插入图片描述

二. 知识蒸馏(knowledge distillation)

  • Knowledge Distillation(KD)最初被Hinton提出,与Label smoothing动机类似,但是KD生成soft label的方式是通过教师网络得到的。
  • KD可以视为将教师网络学到的信息压缩到学生网络中。还有一些工作“Circumventing outlier of autoaugment with knowledge distillation”则将KD视为数据增强方法的一种。
  • KD能够以soft的方式将归纳偏置传递给学生模型,Deit中使用Conv-Based架构作为教师网络,将局部性的假设通过蒸馏方式引入Transformer中,取得了不错的效果。
  • 简单来说就是用teacher模型去训练student模型,通常teacher模型更大而且已经训练好了,student模型是我们当前需要训练的模型。在这个过程中,teacher模型是不训练的。
  • 当teacher模型和student模型拿到相同的图片时,都进行各自的前向,这时teacher模型就拿到了具有分类信息的feature,在进行softmax之前先除以一个参数 τ \tau τ,叫做temperature(蒸馏温度),然后softmax得到soft labels(区别于one-hot形式的hard-label)。
  • student模型也是除以同一个 τ \tau τ,然后softmax得到一个soft-prediction,我们希望student模型的soft-prediction和teacher模型的soft labels尽量接近,使用KLDivLoss进行两者之间的差距度量,计算一个对应的损失teacher loss
  • 在训练的时候,我们是可以拿的到训练图片的真实的ground truth(hard label)的,可以看到上面图中student模型下面一路,就是预测结果和真是标签之间计算交叉熵crossentropy。
  • 链接:损失函数|交叉熵损失函数
  • 然后两路计算的损失:KLDivLoss和CELoss,按照一个加权关系计算得到一个总损失total loss,反向修改参数的时候这个teacher模型是不做训练的,只依据total loss训练student模型。
  • 还可以使用硬蒸馏,对比上面的结构图,哪种更好没有定论。

2.1. KLDivloss

  • 这里可以参考下我之前的文章:〖ML笔记〗信息量、信息熵、交叉熵、KL散度(相对熵)、JS散度以及逻辑损失+面试知识点!
  • KL divergence(KL散度又叫相对熵): 它表示用分布 q ( x ) q(x) q(x) 模拟真实分布 p ( x ) p(x) p(x) 所需要的额外信息。同时也叫KL距离,就是是两个随机分布间距离的度量。
  • 取值范围: [ 0 , + ∞ ] [0, +\infty ] [0,+]当两个分布接近相同的时候KL散度取值为0,当两个分布差异越来越大的时候KL散度值就会越来越大。
    D K L ( p ∣ q ) = H ( p , q ) ⏟ 交叉熵 − H ( p ) ⏟ 信息熵 = − ∑ i = 1 n p ( x i ) log ⁡ q ( x i ) + ∑ i = 1 n p ( x i ) log ⁡ p ( x i ) = ∑ i = 1 n p ( x i ) log ⁡ p ( x i ) q ( x i ) (1) \begin{aligned} {D}_{K L}({p} | {q})&=\underbrace{H(p, q)}_{\text {交叉熵}}-\underbrace{H(p)}_{\text {信息熵}}\\&=-\sum_{i=1}^{n}{p}(x_i) \log {q}(x_i)+\sum_{i=1}^{n} {p}(x_i) \log {p}(x_i) \\ &=\sum_{i=1}^{n} {p}(x_i) \log \frac{{p}(x_i)}{{q}(x_i)}\tag{1} \end{aligned} DKL(pq)=交叉熵 H(p,q)信息熵 H(p)=i=1np(xi)logq(xi)+i=1np(xi)logp(xi)=i=1np(xi)logq(xi)p(xi)(1) 注意: 直观来说,由于 p ( x ) p(x) p(x) 是已知的分布(真实分布), H ( p ) H(p) H(p) 是个常数,交叉熵和KL散度之间相差一个这样的常数(信息熵)
  • 当两个分布完全一致时候,KL散度就等于0。KLDivloss定义和使用方式为:

2.2. 蒸馏温度 τ \tau τ

  • 蒸馏温度 τ \tau τ 的作用,回想之前VIT中在self-attention里面计算 q , k \mathbf {q,k} q,k间的加权因子的时候,计算完了要scale(除以 k k k 的维度),然后再做softmax,然后用它们对 v \mathbf v v 加权相加得到对应的表示向量。
  • 如果是[1.0,20.0,400.0]直接做softamx,那结果是[0.0,0.0,1.0],可见结果完全借鉴第三个引子。而先进行处理(比如除以1000)后变为[0.001,0.02,0.4]时,在做softamx结果为[0.28,0.29,0.42]结果总综合考虑了三部分,这显然是更合理的结果。实际中,看我是更希望结果偏向于更大的值,还是偏向于综合考虑来决定是否使用softmax前输入的预处理。

2.3. distillation in transformer

这一节主要弄清楚,如何在transformer中进行蒸馏操作。

在这里插入图片描述

  • 先说一下,在这DeiT篇论文出来的时候,teacher model使用的是Regnet(一个CNN)
  • 在VIT中时使用class tokens去做分类的,相当于是一个额外的patch,这个patch去学习和别的patch之间的关系,然后连classifier,计算CELoss。在DeiT中为了做蒸馏,又额外加一个distill token,这个distill token也是去学和其他tokens之间的关系,然后连接teacher model计算KLDivLoss,那CELoss和KLDivLoss共同加权组合成一个新的loss取指导student model训练(知识蒸馏中teacher model不训练)。
  • 在预测阶段,class token和distill token分别产生一个结果,然后将其加权(分别0.5),再加在一起,得到最终的结果做预测。

L global  = ( 1 − λ ) L C E ( ψ ( Z s ) , y ) + λ τ 2 K L ( ψ ( Z s / τ ) , ψ ( Z t / τ ) ) (2) \mathcal{L}_{\text {global }}=(1-\lambda) \mathcal{L}_{\mathrm{CE}}\left(\psi\left(Z_{\mathrm{s}}\right), y\right)+\lambda \tau^2 \mathrm{KL}\left(\psi\left(Z_{\mathrm{s}} / \tau\right), \psi\left(Z_{\mathrm{t}} / \tau\right)\right)\tag{2} Lglobal =(1λ)LCE(ψ(Zs),y)+λτ2KL(ψ(Zs/τ),ψ(Zt/τ))(2)

L global  hardill  = 1 2 L C E ( ψ ( Z s ) , y ) + 1 2 L C E ( ψ ( Z s ) , y t ) (3) \mathcal{L}_{\text {global }}^{\text {hardill }}=\frac{1}{2} \mathcal{L}_{\mathrm{CE}}\left(\psi\left(Z_s\right), y\right)+\frac{1}{2} \mathcal{L}_{\mathrm{CE}}\left(\psi\left(Z_s\right), y_{\mathrm{t}}\right)\tag{3} Lglobal hardill =21LCE(ψ(Zs),y)+21LCE(ψ(Zs),yt)(3)

三. better hyperparameter

  • DeiT中第二个优化点在于better hyperparameter,也就是更好的参数配置,看看其都包含哪些部分。

在这里插入图片描述

  • 参数初始化方式:truncated normal distribution(截断标准分布)
  • learning-rate:CNN中的结论:当batch size越大的时候,learning rate设置的越大。
  • learning rate decay:cosine,在warm-up阶段lr先线性升上去,然后通过余弦方式lr降下来

四. data augmentation

在这里插入图片描述

  • mixup之后的图片的label不再是单一的label,而是soft-label,比如[cat,dog]=[0.5,0.5]
  • cutmix之后的图片label是按所占据的比例给的,比如[cat,dog]=[0.3,0.7]

在这里插入图片描述

  • randomaug其实是由autoaug来的,autoaug是选取了25中增强策略,每种策略中有两个操作,这两种操作都要被执行。每次为一张图随机从25中策略中选取一种,将这两种操作对该图执行。至于这25中策略是怎么组成的,每种里面的操作的概率是如何确立的,这些是由搜索算法的实现的,总之认为这么搭配有效就行了。对于randomaug,相当于对于autoaug的简化,它是13种增强策略,然后从中一次选取6种策略依次对图片进行操作,完成增强操作。
  • model EMA(Exponential Moving Average)指数滑动平均,使得模型权重更新与一段时间内的历史取值有关。 m t m_{t} mt 是当前的模型权重, m t − 1 m_{t-1} mt1 是上一轮模型权重, θ t \theta_{t} θt为模型当前权重的值,举一个例子:

在这里插入图片描述
在这里插入图片描述

  • 三种更新参数方式的更新参数结果曲线:

在这里插入图片描述

  • 实际使用的时候,设置上面例子中的 β \beta β 值例如为0.99996,保证模型的参数值不会乱动。

五. label smoothing

  • label smoothing:原本hard-label变成soft-label,设置参数,给其余非标签平均一些label概率。
     Label  one hot  = [ 1 , 0 , 0 , 0 , 0 , 0 ]  Label  smoothing  = [ 0.9 , 0.02 , 0.02 , 0.02 , 0.02 , 0.02 ] , α = 0.1 \begin{aligned} & \text { Label }_{\text {one hot }}=[1,0,0,0,0,0] \\ & \text { Label }_{\text {smoothing }}=[0.9,0.02,0.02,0.02,0.02,0.02], \alpha=0.1 \end{aligned}  Label one hot =[1,0,0,0,0,0] Label smoothing =[0.9,0.02,0.02,0.02,0.02,0.02],α=0.1

在这里插入图片描述
在这里插入图片描述

  • 上图来自论文:When Does Label Smoothing Help

参考文献

  • 以上内容主要参考自大神Transformer学习(四)—DeiT
  • DeiT | Training data-efficient image transformers & distillation through attention
  • DeiT:使用Attention蒸馏Transformer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/890733.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

随机微分方程

应用随机过程|第7章 随机微分方程 见知乎:https://zhuanlan.zhihu.com/p/348366892?utm_sourceqq&utm_mediumsocial&utm_oi1315073218793488384

2023年4大收银系统软件排名(真实测评)

现在满大街的各种服装店、便利店、百货店、母婴店...... 每天都要处理大量的订单。 使用传统的人工开单记账,效率低下、客户体验差、而且容易出错,需要耗费很多时间来回对账; 聪明的零售店老板都已经开始使用收银系统软件,通过手…

线程同步条件变量

为何要线程同步 在线程互斥中外面解决了多线程访问共享资源所会造成的问题。 这篇文章主要是解决当多线程互斥后引发的新的问题:线程饥饿的问题。 什么是线程饥饿?互斥导致了多线程对临界区访问只能改变为串行,这样访问临界资源的代码只能…

手把手教你Element Plus前端导出Excel表格

目录 需求背景: 项目环境: 最终效果: 具体实现: 1、下载第三方依赖包: pnpm下载命令: npm下载命令: 2、查看是否下载成功: 3、引入需要使用的页面js中 4、编写导出表格函数…

【0基础入门Python笔记】一、python 之基础语法、基础数据类型、复合数据类型及基本操作

一、python 之基础语法、基础数据类型、复合数据类型及基本操作 基础语法规则基础数据类型数字类型(Numbers)字符串类型(String)布尔类型(Boolean) 复合数据类型List(列表)Tuple&…

【傅里叶级数与傅里叶变换】数学推导——2、[Part2:T = 2 π的周期函数的傅里叶级数展开] 及 [Part3:周期为2L的函数展开]

文章内容来自DR_CAN关于傅里叶变换的视频,本篇文章提供了一些基础知识点,比如三角函数常用的导数、三角函数换算公式等。 文章全部链接: 基础知识点 Part1:三角函数系的正交性 Part2:T2π的周期函数的傅里叶级数展开 P…

顺序程序设计

#include <iostream> #include <stdio.h>int main() {float f, c;f 64.0f;c (5.0 / 9) * (f - 32);printf("f%f \n \\c%f\n", f, c);return 0; }

poste邮件服务器搭建

关于poste poste是一款开源邮件服务软件&#xff0c;可以很方便的搭建&#xff1a;SMTP IMAP POP3 反垃圾邮件 防病毒 Web 管理 Web 电子邮件&#xff0c;支持以下特性。 SPF、DKIM、DMARC、SRS 的原生实现&#xff0c;带有简单的向导用于检测木马、病毒、恶意软件的防…

【Hyper-V】Windows11 家庭版怎么启用虚拟机Hyper-V

在电脑Windows11系统上启用虚拟机Hyper-V&#xff0c;打开 启用和关闭WIndows功能&#xff0c;找到其中一项Hyper-V&#xff0c;对于家庭版的系统用户来说&#xff0c;这个选项是没有的&#xff0c;接下来讲一讲怎么开启。 安装Hyper-V 创建一个文件名为Hyper-v.bat&#xff…

人工智能原理(9)

目录 一、人工神经元模型 1、概念 2、分类 二、感知器的结构 三、反向传播网络 四、自组织映射神经网络 五、离散HOPFIELD网络 1、离散Hopfield网络结构 2、离散Hopfield网络的稳定性 3、离散Hopfield网络学习算法 六、脉冲耦合神经网络 一、人工神经元模型 1、概念…

Centos7安装Docker及配置加速器地址

一、安装docker #1.yum 包更新到最新 yum update #2.安装需要的软件包&#xff0c;yum-util 提供yum-config-manager功能&#xff0c;另外两个是devicemapper驱动依赖的 yum install -y yum-utils device-mapper-persistent-data lvm2 #3.设置yum源 yum-config-manager --add…

【推荐】深入浅出学习Spring框架【上】

​ 目录 1.spring简介 1.1含义 1.2优点 2&#xff0c;Spring之IOC详解 2.1&#xff0c;控制反转是什么 2.2&#xff0c;控制反转案例 2.1.3案例前台测试 3、IoC的三种注入方式 3.1 构造方法注入 3.2 setter方法注入 3.3 接口注入&#xff08;自动装配&#xff09;用的最多&…

考研408 | 【操作系统】操作系统的概述

操作系统的概念和功能 导图 操作系统的功能和目标 1.作为系统资源的管理者 2.向上层提供方便易用的服务 3.作为最接近硬件的层次 操作系统的特征 导图 并发 并发VS并行 共享 并发和共享的关系 虚拟 异步 操作系统的发展和分类 导图 1.手工操作 2.批处理阶段--单道批处理系统…

JavaWeb 速通 EL 和 JSTL

目录 一、EL表达式 1.快速入门 : 1.1 基本介绍 1.2 入门案例 2.常用输出形式 : 2.1 创建JavaBean类 2.2 创建JSP文件 3.empty运算符 : 3.1 介绍 3.2 实例 4.EL对象 : 4.1 EL11个内置对象 4.2 域对象演示 4.3获取HTTP信息 二、JSTL标签库 1.基本介绍 : 2.core核心库常用标…

java 一个注解实现限流

在Java中&#xff0c;可以使用自定义注解来实现限流功能。 java import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; Retention(RetentionPolicy.RUNTIME…

【数据结构】 ArrayList简介与实战

文章目录 什么是ArrayListArrayList相关说明 ArrayList使用ArrayList的构造无参构造指定顺序表初始容量利用其他 Collection 构建 ArrayListArrayList常见操作获取list有效元素个数获取和设置index位置上的元素在list的index位置插入指定元素删除指定元素删除list中index位置上…

redis-集群(基础了解)

前言 为什么要做集群&#xff1f;解决什么问题&#xff1f; 1、避免单点故障&#xff0c;实现高可用&#xff1b;就需要数据沉余&#xff0c;通过沉余副本也是slave。 三种集群区别&#xff1f; 1、主从复制 复制策略 --> 全量复制 第一次连接到master&#xff0c;mast…

Mysql整理一 基础知识/常见面试题

一、基础概念 1. 索引 之前的文章已经写过了&#xff0c;比较细 数据库索引含义,类别,用法,创建方式_表结构加树形id和索引是为什么_马丁•路德•王的博客-CSDN博客 简单概括就是在表的某个列或者多个列或者联合表的时候加个索引&#xff0c;类似图书馆书本的索引编号&…

git环境超详细配置说明

一&#xff0c;简介 在git工具安装完成之后&#xff0c;需要设置一下常用的配置&#xff0c;如邮箱&#xff0c;缩写&#xff0c;以及git commit模板等等。本文就来详细介绍些各个配置如何操作&#xff0c;供参考。 二&#xff0c;配置步骤 2.1 查看当前git的配置 git conf…

pyqt5 第一个程序

import sys from PyQt5.QtWidgets import QApplication, QWidgetif __name__ __main__:# 创建 QApplication 实例app QApplication(sys.argv)# 创建一个主窗口w QWidget()# 设置大小w.resize(400, 200)# 设置窗口标题w.setWindowTitle(第一个程序)# 显示窗口w.show()# 固定写…