DeiT详解:知识蒸馏的Transformer

news2024/11/23 9:40:07

DeiT详解:知识蒸馏的Transformer

  • 0. 引言
  • 1. ViT
  • 2. DeiT
    • 2.1 知识蒸馏
      • 2.1.1 提出背景
      • 2.1.2 理论原理
    • 2.2 DeiT模型
  • 3. 总结

0. 引言

针对 ViT 需求数据量大、运算速度慢的问题,Facebook 与索邦大学 Matthieu Cord 教授合作发表 Training data-efficient image transformers(DeiT) & distillation through attentionDeiT知识蒸馏的策略与 ViT 相结合,性能与最先进的卷积神经网络(CNN)可以抗衡。

论文名称:Training data-efficient image transformers & distillation through attention
论文地址:https://arxiv.org/abs/2012.12877
代码地址:https://github.com/facebookresearch/deit

1. ViT

提到 DeiT ,就不提不提及 ViT 。这里对 ViT 进行简要介绍来帮助大家初步了解 ViT
ViT 模型将 Transformer 模型应用在了 CV 领域,并取得了突出的成果。在这里插入图片描述
在标准的 Transformer 中,模型仅能处理 1D 数据。为了处理 2D 图像,作者首先将图片数据 X ∈ R H × W × C X\in R^{H\times W \times C} XRH×W×C 按照 patch_size 进行切分并进行一维展平,得到数据 X ∈ R N × ( P 2 × C ) X\in R^{N\times (P^2\times C)} XRN×(P2×C) 。其中, P P P 表示 patch_size N N N 表示图片被切分为多少块,即 N = H × W P 2 N=\frac{H\times W}{P^2} N=P2H×W 。然后,这批数据经过线性变换后与原始图像的位置编码进行合并(并在首部添加类别编码 class embedding)。随后,合并后的数据输入到Transformer Encoder模块。最后经过MLP模型得到输出的类别(MLP模型包含两个具有GELU非线性的层)。总结为公式:
z 0 = [ x c l a s s ; x p 1 E ; x p 2 E ; ⋅ ⋅ ⋅ ; x p N E ] + E p o s ;          E ∈ R ( P 2 ⋅ C ) × D ; E p o s ∈ R ( N + 1 ) × D ( 1 )   z ℓ ′ = M S A ( L N ( z ℓ − 1 ) ) + z ℓ − 1 ;                               ℓ = 1... L ( 2 )          z ℓ = M L P ( L N ( z ℓ ′ ) ) + z ℓ ′ ;                              ℓ = 1... L ( 3 )             y = L N ( z L 0 )                                                            ( 4 ) z_0 = [x_{class}; x^1_pE; x^2_pE; · · · ; x^N_pE] + E_{pos}; \ \ \ \ \ \ \ \ E\in R^{(P^2·C)×D}; E_pos \in R^{(N+1)×D} (1) \\\ z^′_ℓ = MSA(LN(z_{ℓ−1})) + z_{ℓ−1}; \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ ℓ= 1 ... L (2) \ \ \ \ \ \ \\\ z_ℓ = MLP(LN(z^′_ℓ )) + z^′_ℓ ; \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ ℓ= 1 ... L (3) \\\ \ \ \ \ \ \ \ \ \ \ y = LN(z^0_L) \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (4) z0=[xclass;xp1E;xp2E;⋅⋅⋅;xpNE]+Epos;        ER(P2C)×D;EposR(N+1)×D(1) z=MSA(LN(z1))+z1;                             =1...L(2)       z=MLP(LN(z))+z;                            =1...L(3)           y=LN(zL0)                                                          (4)
其中, x c l a s s x_{class} xclass 就是上文说的class embedding,即 x c l a s s = z 0 0 x_{class}=z^0_0 xclass=z00 ,其在Transformer编码器输出 ( z L 0 ) (z^0_L) (zL0) 的状态作为图像表示 y y y D D D 表示线性映射维度; L L L 表示Transformer输出维度。

注意:这里class embedding的作用是什么呢?
根本原因Transformer输入为一系列的patch embedding,输出也是同样长的序列patch feature,但是最后要总结为一个类别的判断。而class embedding作为一个Transformer的判断,在训练过程中不断汇总被分割图片的特性,进而得到一个最终分类结果。
具体而言:训练的时候,class tokenembedding随机初始化并与pos embedding相加。在训练过程中,随着网络的训练不断更新,它能够编码整个数据集的统计特性;同时,该token对所有其他token上的信息做汇聚(全局特征聚合),并且由于它本身不基于图像内容,因此可以避免对sequence中某个特定token的偏向性;最后,对该token使用固定的位置编码能够避免输出受到位置编码的干扰

2. DeiT

为了方便理解 DeiT 模型,首先介绍一下知识蒸馏的概念。

2.1 知识蒸馏

知识蒸馏整体性而言就是当模型训练完成后,可以将教师网络学习到的信息压缩到学生网络中,从而达到降低模型规模的目的。

2.1.1 提出背景

虽然在一般情况下,我们不会去区分训练和部署使用的模型,但是训练和部署之间存在着一定的不一致性。在训练过程中,我们需要使用复杂的模型,大量的计算资源,以便从非常大、高度冗余的数据集中提取出信息。在实验中,效果最好的模型往往规模很大,甚至由多个模型集成得到。而大模型不方便部署到服务中去,常见的瓶颈如下:

  • 推理速度和性能慢
  • 对部署资源要求高(内存,显存等)

在部署时,对延迟以及计算资源都有着严格的限制。因此,模型压缩(在保证性能的前提下减少模型的参数量)成为了一个重要的问题,而“模型蒸馏”属于模型压缩的一种方法。

2.1.2 理论原理

知识蒸馏使用的是 Teacher—Student 模型,其中 Teacher 是“知识”的输出者,Student 是“知识”的接受者。知识蒸馏的过程分为2个阶段:

  1. 原始模型训练:训练 “Teacher模型”,简称为Net-T,它的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。我们对"Teacher模型"不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是,对于输入 X X X, 其都能输出 Y Y Y,其中 Y Y Y经过softmax的映射,输出值对应相应类别的概率值。
  2. 精简模型训练: 训练"Student模型",简称为Net-S,它是参数量较小模型结构相对简单的单模型。同样的,对于输入 X X X,其都能输出 Y Y Y Y Y Y 经过softmax映射后同样能输出对应相应类别的概率值。

2.2 DeiT模型

DeiT 模型中,首先需要一个强力的图像分类模型作为teacher model。然后,引入了一个 Distillation Token,然后在 self-attention layers 中跟 class tokenpatch tokenTransformer 结构中不断学习。Class token的目标是跟真实的label一致,而Distillation Token是要跟teacher model预测的label一致。蒸馏过程如下图所示。
在这里插入图片描述
在蒸馏过程中,不同的蒸馏方案会得到不同的结果。DeiT 模型主要的蒸馏方案包括以下两种:

  1. 软蒸馏(Soft distillation):使教师模型的softmax学生模型的softmax之间的Kullback-Leibler分歧最小化
    L g l o b a l = ( 1 − λ ) L C E ( ψ ( Z s ) , y ) + λ τ 2 K L ( ψ ( Z s / τ ) , ψ ( Z t / τ ) ) L_{global} =(1−λ)L_{CE} (ψ(Z_s ),y)+λτ^2 KL(ψ(Z_s /τ),ψ(Z_t /τ)) Lglobal=(1λ)LCE(ψ(Zs),y)+λτ2KL(ψ(Zs/τ),ψ(Zt/τ))其中, Z s Z_s Zs Z t Z_t Zt 分别是 student modelteacher model 的对数, τ τ τ 表示蒸馏温度, λ λ λ 表示 K L KL KLKullback-Leibler散度损失)损失和交叉熵( L C E L_{CE} LCE )的系数 , y y y 表示真实值标签, ψ ψ ψ 表示 Softmax函数。
  2. 硬蒸馏(Hard-label distillation):
    L g l o b a l h a r d D i s t i l l = 1 2 L C E ( ψ ( Z s ) , y ) + 1 2 L C E ( ψ ( Z s ) , y t ) L_{global}^{hardDistill} = \frac{1}{2} L_{CE}(ψ(Z_s),y)+\frac{1}{2}L_{CE}(ψ(Z_s),y_t) LglobalhardDistill=21LCE(ψ(Zs),y)+21LCE(ψ(Zs),yt)值得注意的是,Hard Label 也可以通过标签平滑技术 (Label smoothing) 转换成Soft Label,其中真值对应的标签被认为具有 1- esilon 的概率,剩余的 esilon 由剩余的类别共享。

3. 总结

DeiT 模型(8600万参数)仅用一台 GPU 服务器在 53 hours train,20 hours finetune,仅使用 ImageNet 就达到了 84.2 top-1 准确性,而无需使用任何外部数据进行训练,性能与最先进的卷积神经网络(CNN)可以抗衡。其核心是提出了针对 ViT 的教师-学生蒸馏训练策略,并提出了 token-based distillation 方法,使得 Transformer 在视觉领域训练得又快又好。
如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/599924.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

工控机设备安全-系统加固分析

工控设备安全现状 工业控制系统是支撑国民经济的重要设施,是工业领域的神经中枢。现在工业控制系统已经广泛应用于电力、通信、化工、交通、航天等工业领域,支撑起国计民生的关键基础设施。 随着传统的工业转型,数字化、网络化和智能化的工…

倾斜摄影三维模型数据的高程偏差修正的几何纠正技术方法探讨

倾斜摄影三维模型数据的高程偏差修正的几何纠正技术方法探讨 倾斜摄影是一种先进的数字摄影技术,可以生成高分辨率、高精度的三维模型数据。然而,在倾斜摄影中,由于相机的倾斜角度和地形的高程差异,可能会出现高程偏差问题。为了…

Java性能权威指南-总结4

Java性能权威指南-总结4 Java性能调优工具箱操作系统的工具和分析CPU运行队列磁盘使用率网络使用率 Java监控工具基本的VM信息 Java性能调优工具箱 操作系统的工具和分析 CPU运行队列 快速小结 检查应用性能时,首先应该审查CPU时间。优化代码的目的是提升而不是…

树莓派初体验:开机啦

感谢大佬的赞助,这玩意是真的贵哇,呜呜呜呜呜呜,根本买不起 一、烧录系统 需要:SD卡(推荐16G)、读卡器(推荐高速读卡器) 进入官网:https://www.raspberrypi.com/softwa…

《MYSQL必知必会》读书笔记2

哈夫曼树的学习: http://t.csdn.cn/XJhUI 创建计算字段 字段:基本上与列的意思相同(数据库列一般称为列,而字段通常用于计算字段连接上) 拼接字段 拼接:将值联结到一起构成单个值 把两个结拼接起来&a…

【2023最叼教程】Appium自动化环境搭建保姆级教程

APP自动化测试运行环境比较复杂,稍微不注意安装就会失败。我见过不少朋友,装了1个星期,Appium 的运行环境还没有搭好的。 搭建环境本身不是一个有难度的工作,但是 Appium 安装过程中确实存在不少隐藏的比较深的坑,如果…

编程(38)----------计算机的部分原理

本篇主要总结一些计算机的理论部分. 计算机在发展历程中,无论是最早的巨无霸机器,还是现在小到可以拿在手中的掌机.只要其本质上是计算机,在最基础的结构上,都是以冯诺依曼体系所构建的. 冯诺依曼体系大致将计算机分为几个最重要的部分:输入,输出,中央处理器,存储设备.也就是…

Meta Quest 3发布:超越虚拟现实全新境界

2023年6月2日凌晨,全球领先的虚拟现实技术公司Meta隆重推出了Meta Quest 3无线头戴式显示器。这款全新设计的头戴设备从内到外焕然一新,为用户提供了全方位的体验。 借助全新一代骁龙芯片,Meta Quest 3拥有比Quest 2更高两倍的GPU处理能力&am…

VanillaNet详解:极简的网络模型

VanillaNet详解:极简的网络模型 0. 引言1. 网络结构2. 如何提高简单网络的非线性2.1 深度训练策略2.2 基于级数启发的激活函数3. 实验4. 代码解析总结 0. 引言 深度学习模型架构越复杂越好吗? 自过去的几十年里,人工神经网络取得了显著的进…

chatgpt赋能python:Python在硬件开发中的作用

Python在硬件开发中的作用 随着物联网的快速发展,越来越多的硬件设备需要与互联网连接。Python在硬件开发过程中扮演着重要的角色。 Python的优势 作为一种高级编程语言,Python有以下几个优势: 简单易学:Python的语法简洁清晰…

chatgpt赋能python:Python做网页可以直接访问吗?

Python做网页可以直接访问吗? Python作为一门功能强大的编程语言,近年来在Web开发中也越来越受欢迎。很多人或企业都采用Python来开发网站和网页,那么问题来了,Python做的网页能否直接被搜索引擎访问和索引呢? Pytho…

MySQL5-事务隔离级别和锁机制

❤️ 个人主页:程序员句号 🚀 支持水滴:点赞👍 收藏⭐ 留言💬关注 🌸 订阅专栏:MySQL性能调优 原创博文、基础知识点讲解、有一定指导意义的中高级实践文章。 认真或有趣的技术分享。 该专栏陆…

【数据结构】数据结构与算法基础 课程笔记 第七章 查找

🚀Write In Front🚀 📝个人主页:令夏二十三 🎁欢迎各位→点赞👍 收藏⭐️ 留言📝 📣系列专栏:【数据结构】 💬总结:希望你看完之后,…

Emacs之解决gtags -i --single-update占用率100%卡死问题(一百零六)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

python --自动化测试UiAutomator2

安装adb 安装adb后使用命令 adb devices 出现下图即可; 安装python依赖(uiautomator2,weditor) pip install uiautomator22.16.23 weditor0.6.8 -i https://pypi.doubanio.com/simple# 在手机上安装 atx-agent 应用 # 安装apk服务到手机上 python -m uiautomator2 init脚本…

基于 Docker 部署 Mysql8.0.27_单机_主从复制

文章目录 单机部署集群部署master 部署slave 部署错误记录 单机部署 通过 dockerhub 或 docker search 查找镜像。拉取 mysql 镜像。 docker pull mysql:8.0.27创建挂载目录,并赋予权限。 mkdir -p /var/docker_data/mysql/data mkdir -p /var/docker_data/mysql/co…

一些关于c++的琐碎知识点

目录 bool强转 const构成重载:const修饰*p 移动构造 new int (10)所做的四件事 this指针---为什么函数里面需要this指针? .和->的区别 new创建对象 仿函数 new和malloc的区别 c系统自动给出的函数有 delete和delete[ ]区别何在 检查有没有析构函数 e…

六一,用前端做个小游戏回味童年

#【六一】让代码创造童话,共建快乐世界# 文章目录 📋前言🎯简简单单的弹球游戏🎯代码实现📝最后 📋前言 六一儿童节。这是属于孩子们的节日,也是属于我们大人的节日(过期儿童&…

chatgpt赋能python:**Python免费编辑器:提高开发效率和便捷性**

Python 免费编辑器:提高开发效率和便捷性 Python 编程语言已经成为了越来越多开发者的首选。这是因为 Python 语言非常直观易懂,同时也拥有庞大的第三方开源库,方便开发人员快速实现项目功能。Python 编程之所以如此受欢迎,除了这…

Java基础编程

Java入门 1. JDK的安装目录介绍 目录名称说明bin该路径下存放了JDK的各种工具命令。javac和java就放在这个目录。conf该路径下存放了JDK的相关配置文件。include该路径下存放了一些平台特定的头文件。jmods该路径下存放了JDK的各种模块。legal该路径下存放了JDK各模块的授权文…