TinyViT: 一种高效的蒸馏方法

news2024/11/24 16:57:34

目录

    • 背景
    • 方法大意
      • 快速预训练蒸馏(Fast Pretraining Distillation, FPD)
        • 如何实现快速
        • 三个细节深入理解FPD
      • 模型架构
      • 训练trick
        • 预训练参数配置(Imagenet21k-pretraining)
        • finetuning 参数配置(Imagenet-1k)
    • 消融实验
      • **Q: 数据是否越多越好?**
      • **Q: logitK的数量是否越多越好?**
      • **Q: distill pretrain + finetune架构真的有效吗?**
      • Q: 预训练的Tearcher 模型对student的影响大吗?
    • 扩展思考
        • What are the underlying factors limiting small models to fit large data?
        • Why can distillation improve the performance of small models on large datasets?
    • 参考文献

背景

最近,视觉Transformer(ViT)由于其出色的模型能力而在计算机视觉领域引起了极大的关注。然而,大多数流行的ViT模型存在大量参数的问题,这限制了它们在资源有限的设备上的应用。

方法大意

在这里插入图片描述

快速预训练蒸馏(Fast Pretraining Distillation, FPD)

如何实现快速

常规的pretrain with distillation 非常慢并且成本高。因为teacher网络的每次推理都占用了大量的计算资源(GPU),并且也需要推理时间。为了解决efficient和costly的问题,作者提出了一个fast pretraining distillation 框架。这个框架是如何做的呢?

他在teacher的预训练阶段存储两个信息:一者输入图片的增强 A A A,二者tercher对该图片的预测概率向量 y ^ = T ( A ) \hat{y}=T(A) y^=T(A),记做 ( A , T ( A ) ) (A, T(A)) (A,T(A))。由于数据增强中有随机数,这会导致即使采用同样的增强参数,所获得得增强图片也不一致。因此 ( A , T ( A ) ) (A, T(A)) (A,T(A))需要在不同的迭代位置都保存。

在训练阶段,学生网络会读取teacher网络对同一图片的增强参数,对图片进行增强,优化目标为:
L = C E ( y ^ , S ( A ( x ) ) ) \mathcal{L} = CE(\hat{y}, S(\mathcal{A}(x))) L=CE(y^,S(A(x)))
其中 A ( x ) \mathcal{A}(x) A(x)是增强后的图片, S ( A ( x ) ) S(\mathcal{A}(x)) S(A(x))是学生模型的预测概率分布, y ^ \hat{y} y^是teacher预测的概率分布。 C E CE CE为交叉熵损失。可见这个框架是label-free的,学生网络的训练不依赖标签。因此用该方法可利用大量互联网无标注的图片。

三个细节深入理解FPD

  • 蒸馏阶段没有用GT的标签信息

作者发现,distillation with GT会导致性能下降。作者认为主要的原因可能是imagenet21k的有些标签间的类间差异很小,例如椅子和家具,马和动物,因此基于one-hot的GT标签不能很好的表征物体的类别信息。

  • y ^ \hat{y} y^进行了稀疏编码节约存储空间

对于imgnet21k来说总计有21841个标签,每个向量有21841维,非常大。作者的处理方式是,只存储向量中topk的元素的数值和位置,这大大降低了存储内存。在训练阶段,其它位置基于label smoothing的方式进行补充。

  • 优化数据增强的编码方式

比如一次的数据增强中包含,crop的坐标,旋转的角度等,每一次迭代中对同一图片的增强可能都不一样。直接存储是memory-inefficient的。作者采用了一种编码函数来解决这个问题。比如数据增强参数为d, 为编码的参数。训练过程对该参数进行解码 d = ϵ ′ ( d 0 ) d = \epsilon'(d_0) d=ϵ(d0)

模型架构

作者采用一种渐进式模型压缩方法(processive model contraction approach)实现从一个大模型中剪枝成小模型[1,2]。收缩因子有6个:embedding的维度、每个stage中block的个数、最后3个stage 的window size、MbConv block的通道扩展率、transformer中MLP的通道扩展率、多头attention,每个头的维度。

模型架构简要描述:

  • 类似swin-transformer同样有4个stage,每个stage都会下采样
  • patch embedding 采用了两个kernle为3补偿为2的卷积。
  • stage1 采用MBConv[3],剩下三个stage都是transformer with window attention.
  • 各个stage都用了残差连接。
  • 激活函数都用GELU。
  • 卷积的采用BN,线性层采用LN[4]
    在这里插入图片描述

训练trick

预训练参数配置(Imagenet21k-pretraining)

epoch90
optimizerAdamW(weight-decay 0,01)
lr0.002, cosine scheduler
Warm-up5-epoch
Batch-size4096
Gradient-clipMax-norm of 5
Stochastic depth ratio0 for TinyViT-5/11M, 0,1 for TinyViT 21M
Data-augRandom resize, crop, horizontal-flip, color jitter, random erasing, RandomAugment, Mixup, CutMix

finetuning 参数配置(Imagenet-1k)

epoch30
optimizerAdamW(weight-decay 10^-8)
lr0.0005, for each layer is decayed by the rate 0.8 form output to input
Warm-up5-epoch, cosine learning rate
bnfrozon
Batch-size1024
Gradient-clipMax-norm of 5
Stochastic depth ratio0 for TinyViT-5/11M, 0,1 for TinyViT 21M
Data-augRandom resize, crop, horizontal-flip, color jitter, random erasing, RandomAugment, Mixup, CutMix

消融实验

Q: 数据是否越多越好?

A: 模型的性能随着数据量的增加而呈现加速度不断降低的增大,同样的数据量,最终的性能受限于模型的大小。

Q: logitK的数量是否越多越好?

A: 保存的logitK的的数量不是越多越好,因为teacher模型的logit也可能有部分噪声,选取topk的策略不仅可以降低存储成本,也能起到一定的降噪作用。(作者在imagenet1k取得是10, imagenet21k取的是100)

Q: distill pretrain + finetune架构真的有效吗?

A: 从实验来看是有效的,不同的数据规模、不同的基础模型均能得到一定的提升。因此distill pretrain + fintuning可以作为一种较为通用的范式。

Q: 预训练的Tearcher 模型对student的影响大吗?

A: 更好的teacher模型能训练得到更强student模型,但好的teacher模型往往很大,会带来较大的时间消耗。
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

扩展思考

What are the underlying factors limiting small models to fit large data?

主要是由于数据集中的Hardsample导致的

标签错误、由于一张图片中有多个目标导致相似的图片有不同的标签。Imagenet21k大约有10%的困难样本。对于小模型来说,拟合这些困难样本较为吃力以至于训练准确率比起大模型低得多。

作者提出两个方法解决这个问题:1. 采用大规模数据集训练的预训练模型(Florence)在imagenet21k微调,找出哪些大模型在top5都识别错误的样本(这个操作移除了2M个图片)。2. 以大模型作为teacher,采用文中提出的蒸馏方法在imagenet-21k训练小模型。

上述两个方法的收益:1. 方法一能够提升0.7%的性能. 2. 方法2能提升1.7%的性能。

Why can distillation improve the performance of small models on large datasets?

作者认为核心原因是teacher模型能够将类别间的关系注入给学生模型。对于常规的分类任务,一张图片只对应一个类别,但忽视了类别与类别之间联系,而论文提出的distillation是根据概率向量进行优化,概率向量反映了该图片在各个类别上的分布。

参考文献

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/703054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mysql ——基本约束以及语法 以及 Dbeaver基本使用

1. 规约 说到约束,就不得不想到命名规范,跟java一样,mysql也有一套自己的命名要求 库名尽量与业务名称一致,比如这是一个办公系统,你可以命名 将数据库命名为office, 多个单词组成全小写 例如:officeoa 表…

《Linux操作系统编程》第一章 操作系统引论:了解操作系统的发展、特征、功能以及操作系统结构

🌷🍁 博主 libin9iOak带您 Go to New World.✨🍁 🦄 个人主页——libin9iOak的博客🎐 🐳 《面试题大全》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~&#x1f33…

如何查看Navicat已保存数据库连接的密码?

此教程的背景:数据库密码忘记了,但是在Navicat连接过且目前能连接上的状态! 1.导出数据库连接 connections.ncx 文件 选择你要导出密码的数据库连接,切记要勾上导出密码 2.使用文本编辑工具打开导出的connections.ncx 文件 找到…

Android SDK安全加固问题与分析

作者 | 百度APP技术平台 导读 在移动互联网快速发展的背景下,保护Android应用程序的安全性和知识产权变得尤为重要。为了防止恶意攻击和未授权访问,通常采用对dex文件进行代码加固来保护应用程序。随着Android加固技术经过动态加载、不落地加载、指令抽取…

SSM整合 配置文件

<?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://ma…

宇宙是一位高位截瘫的病人

【一点小爱好&#xff0c;喜欢了解一些天文和物理】 宇宙中最快的速度——光速。 在真空中可以达到每秒30万千米&#xff0c;这个速度是什么概念呢&#xff1f;光一秒钟就可以绕地球7.5圈&#xff0c;一秒钟就可以从地球到月球。 但这个速度还是太慢了。太阳发出的光要整整走…

vue中设置花样字体

首先在assets中新建一个文件夹 font 然后再在字体网中选择想要的字体下载放入font文件夹中 字体网&#xff1a;字体_中文字体 | 英文字体 | 书法字体 免费下载 - 爱给网 Fonts2u.com 然后再在style文件夹中创建一个 fontStyle.scss文件 再在main.js文件中注册就可以全局使用…

市场监管总局:7月1日起加大合同范本应用,契约锁助力规范签

近日&#xff0c;国家市场监管总局令第77号公布&#xff0c;并于7月1日起正式施行。总局结合近年来合同行政监管的新形势、新情况、新问题修订出台了《合同行政监督管理办法》&#xff08;以下简称“办法”&#xff09;&#xff0c;明确&#xff1a;加强合同行政监管执法&#…

selenium入门超详细教程——网页自动化操作

selenium入门超详细教程——网页自动化操作 使用 Selenium 通过 Python 自动发布 Facebook 帖子 Selenium基础 — Selenium中的expected_conditions模块&#xff08;一&#xff09; Upload image on Facebook Marketplace with selenium (python)

SpringBoot整合OSS存储

Spring Boot整合OSS存储 一、OSS存储介绍二、准备工作二、添加依赖配置OSS连接信息properties文件yml文件 创建OSS客户端实现文件上传实现文件下载控制器实现文件上传和下载接口 三、Demo 本文介绍如何在Spring Boot应用程序中整合OSS&#xff08;对象存储服务&#xff09;来实…

树莓派使用非树莓派官方的IMX219和IMX477 摄像头配置

问题&#xff1a; sudo libcamera-hello -t 0 ERROR: the system appears to be configured for the legacy camera stack解决办法&#xff1a; 树莓派4B 查询系统型号&#xff1a; cat /etc/os-release 结果&#xff1a; RETTY_NAME"Debian GNU/Linux 11 (bullseye)…

docker部署rabbitmq

拉取镜像 我部署的是3.8版本的 docker pull rabbitmq:3.8 启动容器 docker run -d --hostname my-rabbit --name rabbitmq --restart always -e RABBITMQ_DEFAULT_USERadmin -e RABBITMQ_DEFAULT_PASSadmin -p 15672:15672 -p 5672:5672 --privilegedtrue rabbitmq:3.8 启…

(一)Qt 将某控件、图案绘制在最前面的方法,通过QGraphicsScene模块实现

系列文章目录 通过Qt实现手势识别控制软件操作相关系列技术方案 &#xff08;一&#xff09;Qt 将某控件、图案绘制在最前面的方法&#xff0c;通过QGraphicsScene模块实现 &#xff08;二&#xff09;Qt QGraphicsScene模块实现圆点绘制在所有窗体的最前方&#xff0c;实现圆…

深度学习100例 | 第37天:表情识别(K同学啊原创出品)

&#x1f3e1; 我的环境&#xff1a; 语言环境&#xff1a;Python3.10.11编译器&#xff1a;Jupyter Notebook深度学习框架&#xff1a;TensorFlow2.4.1显卡&#xff08;GPU&#xff09;&#xff1a;NVIDIA GeForce RTX 4070 &#x1f942; 相关教程&#xff1a; 编译器教程&…

「2024」预备研究生mem-比与比例(下)

一、比与比例&#xff08;下&#xff09; 好方法&#xff1a; 不错 二、课后题 三、每日一练

10 月发布,Ubuntu 23.10 已升级到 Linux Kernel 6.3 内核

导读Canonical 于近日宣布&#xff0c;代号为 Mantic Minotaur 的 Ubuntu 23.10 发行版本已升级基于 Linux Kernel 6.3 内核。 Canonical宣布&#xff0c;代号为 Mantic Minotaur 的 Ubuntu 23.10 发行版本已升级基于 Linux Kernel 6.3 内核。 Ubuntu 23.10 于今年 4 月下旬进入…

6、Redis事务、管道、发布订阅(了解)

1、Redis事务 是什么&#xff1f; 可以一次执行多个命令&#xff0c;本质是一组命令的集合。一个事务中的所有命令都会序列化&#xff0c;按顺序地串行化执行而不会被其它命令插入&#xff0c;不许加塞 一个队列中&#xff0c;一次性、顺序性、排他性的执行一系列命令 Redis…

【C51】基于51单片机无线遥控门铃电路的设计与实现

摘 要 20世纪以来&#xff0c;科技发展步入了信息时代&#xff0c;科技发展的目的就是为了服务人民&#xff0c;让我们可以拥有更好的生活。居住环境和质量也愈加重要&#xff0c;智能家居就是一次革新&#xff0c;给生活方面带来了巨大的改善&#xff0c;本课题研究的无线遥控…

C#核心知识回顾——4.object中的方法、String、StringBuilder

1.object中的方法 object中的静态方法&#xff1a; 静态方法Equals判断两个对象是否相等&#xff1a; 最终的判断权&#xff0c;交给左侧对象的Equals方法&#xff0c; 不管值类型引用类型都会按照左侧对象Equals方法的规则来进行比较 静态方法Reference Equals&#xf…

分布式操作系统期末复习(辽宁大学王龙主讲)

目录 一、题目 1.1 简答题 1.2 综合题 二、题目答案 2.1 简答题目答案 2.2 综合题目答案 三、期末题型分值分布 3.2 题型和分值 一、题目 1.1 简答题 1什么是中间件 22.1&#xff08;22年期末考试第一题&#xff09; 2 什么是名称解析 3 描述一下客户和服务器之间使…