Deit:知识蒸馏与vit的结合 学习笔记(附代码)

news2024/11/15 17:42:49

 论文地址:https://arxiv.org/abs/2012.12877

代码地址:GitHub - facebookresearch/deit: Official DeiT repository

1.是什么?

DeiT(Data-efficient Image Transformer)是一种用于图像分类任务的神经网络模型,它基于Transformer架构。这个模型的主要目标是在参数较少的情况下实现高效的图像分类。相比于传统的卷积神经网络(CNN),DeiT采用了Transformer的注意力机制,使其能够更好地捕捉图像中的全局关系。
以下是DeiT模型的一些关键特点和组成部分:

  1. 1.Transformer 架构: DeiT采用了Transformer的架构,这是一种自注意力机制的模型。这种架构在自然语言处理任务中取得了显著的成功,DeiT将其成功地应用于图像分类领域。
  2. 2.小模型参数: 为了提高数据效率,DeiT设计为具有相对较少的参数。这使得模型在训练和推理时需要更少的计算资源。
  3. 3.Knowledge Distillation: DeiT使用知识蒸馏(Knowledge Distillation)的方法进行训练。这意味着它通过从一个大型预训练模型中传递知识来训练,而不是从头开始训练。这有助于在资源受限的情况下实现更好的性能。
  4. 4.Patch Embedding: 与传统的卷积层不同,DeiT使用了补丁嵌入(Patch Embedding)来将图像分割成小块,然后对这些块进行变换。
  5. 5.Positional Embeddings: 由于Transformer不涉及卷积层,它需要一种处理输入序列的方式。在DeiT中,位置嵌入(Positional Embeddings)用于为模型提供输入中元素的相对位置信息。

总体而言,DeiT是一个旨在通过Transformer的优势实现图像分类的轻量级模型,适用于数据受限的情况。通过知识蒸馏和小模型参数,它在参数较少的情况下达到了令人满意的性能。

2.为什么?

Transformer的输入是一个序列(Sequence),ViT 所采用的思路是把图像分块(patches),然后把每一块视为一个向量(vector),所有的向量并在一起就成为了一个序列(Sequence),ViT 使用的数据集包括了一个巨大的包含了 300 million images的 JFT-300,这个数据集是私有的,即外部研究者无法复现实验。而且在ViT的实验中作者明确地提到:

意思是当不使用 JFT-300 大数据集时,效果不如CNN模型。也就反映出Transformer结构若想取得理想的性能和泛化能力就需要这样大的数据集。DeiT 作者通过所提出的蒸馏的训练方案,只在 Imagenet 上进行训练,就产生了一个有竞争力的无卷积 Transformer。

3.怎么样?

在 DeiT 模型中,首先需要一个强力的图像分类模型作为teacher model。然后,引入了一个 Distillation Token,然后在 self-attention layers 中跟 class token,patch token 在 Transformer 结构中不断学习。Class token的目标是跟真实的label一致,而Distillation Token是要跟teacher model预测的label一致。蒸馏过程如下图所示。

3.1知识蒸馏

知识蒸馏(Knowledge Distillation)是一种模型训练的技术,旨在通过传递一个大型教师模型的知识来训练一个小型学生模型。这个方法的目标是使得学生模型能够获得与教师模型相似的性能,同时减少学生模型的复杂性和计算成本。
以下是知识蒸馏的关键思想和步骤:

  1. 1.教师模型: 首先,有一个在任务上表现良好的大型教师模型。这个模型通常拥有更多的参数和计算能力,以便更好地捕捉任务的复杂性和结构。
  2. 2.软目标(Soft Targets): 在传统的监督学习中,模型通常以硬标签(one-hot编码的标签)作为目标进行训练。而在知识蒸馏中,使用了软目标,这是由教师模型输出的概率分布。这样的软目标包含了关于样本的更丰富信息,使得学生模型可以学到更多的任务相关知识。
  3. 3.温度参数: 软目标的概率分布可以通过温度参数进行调节。较高的温度使概率分布更平滑,有助于学生模型更好地学到教师模型的知识。
  4. 4.学生模型: 有了教师模型和软目标,接下来就是训练学生模型。学生模型通常是一个比教师模型简化的小型模型,可以在资源受限的环境中更轻松地进行推理。
  5. 5.蒸馏损失: 为了引导学生模型学习教师模型的知识,引入了蒸馏损失。这个损失函数用于比较学生模型的输出概率分布和教师模型的输出概率分布,促使学生模型模仿教师模型的行为。

知识蒸馏的优势在于,通过传递教师模型的知识,可以在小型模型上实现接近教师模型性能的效果。这对于移动设备、嵌入式系统或其他计算资源受限的环境中的部署非常有用。

具体方法:

第一步是训练Net-T;第二步是在高温 T 下,蒸馏 Net-T 的知识到 Net-S。

训练 Net-T 的过程很简单,而高温蒸馏过程的目标函数由distill loss(对应soft target)和student loss(对应hard target)加权得到:

Deit 中使用 Conv-Based 架构作为教师网络,以 soft 的方式将归纳偏置传递给学生模型,将局部性的假设通过蒸馏方式引入 Transformer 中,取得了不错的效果。

3.2Distillation Token

Distillation Token 和 ViT 中的 class token 一起加入 Transformer 中,和class token 一样通过 self-attention 与其它的 embedding 一起计算,并且在最后一层之后由网络输出。

而 Distillation Token 对应的这个输出的目标函数就是蒸馏损失。Distillation Token 允许模型从教师网络的输出中学习,就像在常规的蒸馏中一样,同时也作为一种对class token的补充。

3.3代码实现

class DistilledVisionTransformer(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2

参考:ViT、Deit这类视觉transformer是如何处理变长序列输入的?

DeiT:使用Attention蒸馏Transformer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1380767.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件测试|教你使用Python绘制正多边形

简介 绘制正多边形是Python图形编程的基本任务之一。在本文中,我将为你提供一个使用Python绘制正多边形的详细教程,并提供一个示例代码。我们将使用Python的Turtle库来进行绘制。 步骤1:导入Turtle库 我们需要先安装好Python环境&#xff…

PyTorch Tutorial

本文作为博客“Transformer - Attention is all you need 论文阅读”的补充内容,阅读的内容来自于 https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#recommended-preparation 建议的准备流程。 Deep Learning with PyTorch: …

Linux第21步_取消鼠标中键的复制粘贴功能

在ubuntu18.04操作系统中,选中文本后,若按下鼠标中键,就可以执行复制粘贴,相当于 CtrlshiftC 后又按了 CtrlshiftV。在Linux系统中,基本上都是这么配置的。在windows系统中,我们习惯用Ctrl-C复制&#xff0…

POSTGRESQL中ETL、fdw的平行替换

POSTGRESQL中ETL、fdw的平行替换 01、简介 “ 在我前两次的文章中,说到postgresql对于python的支持,其实很多功能也就可以封装进入的postgresql数据库中去。比如fdw、etl等,本文将以此为叙述点,进行演示展示” 在postgresql数据…

好用的便签有哪些?windows便签工具在哪打开?

每当我8点准时上班,在等待电脑开机的过程,我都会习惯性地思考整理今天要晚上的任务,列出所要完成的待办事项。随着每一项任务的清晰呈现,我的心情也逐渐明朗起来。当然了,这个时候,我迫切需要一款好用的便签…

大数据赋能电竞出海企业发展

近几年电竞行业发展迅速,我国单2022年新增近4万家电竞相关企业,竞争十分激烈。中国电竞市场规模在全球占比19%左右,海外有巨大的增量市场,特别是东南亚、中南亚和拉丁美洲是电竞市场增长最快的地区,在2020至2025年期间…

【微信小程序独立开发2】授权登录 上

前言:这一节设想完成的功能为进入小程序后请求授权信息,用户授权登录后,弹出宠物登记页面,并根据宠物类型播放背景音乐 小程序昵称头像在之前的版本获取规则为触发后弹出用户授权弹窗,授权后可直接获取用户头像和昵称&…

DCP文件传输的重要性与应用

在数字时代,文件传输已成为商业运作中不可或缺的一环。随着企业越来越多地采用云基础设施和服务,有效地在云和团队之间传输大文件和数据集变得至关重要。在这一背景下,数据复制协议(DCP)文件传输应运而生,引…

Web实战丨基于django+html+css+js的电子商务网站

文章目录 写在前面实验目标需求分析实验内容安装依赖库1.登陆界面2.注册界面3.电子商城界面4.其他界面 运行结果写在后面 写在前面 本期内容:基于DjangoHTMLCSSJS的电子商务网站 实验环境: vscode或pycharmpython(3.11.4)django 代码下载地址&#x…

【分布式技术】监控平台zabbix介绍与部署

目录 一、为什么要做监控? 二、zabbix是什么? 三、zabbix有哪些组件? ​编辑Zabbix 6.0 功能组件: ●Zabbix Server ●数据库 ●Web 界面 ●Zabbix Agent ●Zabbix Proxy ●Java Gateway 四、zabbix的工作原理&#xf…

GNSS差分码偏差(DCB)原理学习与数据下载地址

一、DCB原理 GNSS差分码偏差(DCB,Differential Code Bias)是由不同类型的GNSS信号在卫星和接收机不同通道产生的时间延迟(硬件延迟/码偏差)差异,按照频率相同或者不同又可以细分为频内偏差(例如…

PADS9.5 : 元件库绘制

元件库绘制 1、打开PADS LOGIC 软件 2、先开始元件的电参数 这理面我们只需要先关注: 门 ,就是当前画的元件有几个部分 示例:两个门:A、B 3、再开始编辑图形 选择创建2D线,绘制PARTA 外框 添加端点,就是接…

生态茶园建设方案——福建蜂窝物联

一、项目背景 为了进一步提高茶产业集约化、产业化发展水平,充分运用物联网、互联网等高新技术为产业赋能,加速推动安溪茶产业转型升级,县政府决定在安溪县推进“安溪智慧生态茶园项目”,并以茶叶重镇感德镇实施“安溪智慧生态茶园…

CRM-如何做好客户管理

客户是企业最重要的资源,也是客户360视图管理的主数据,企业的运转都是围绕客户来开展的,如何做好客户数据的管理是一门学问,也需要企业动态的调整战略。 客户分为企业客户(Account)与个人客户(…

图解智慧:数据可视化如何助你高效洞悉信息?

在信息爆炸的时代,数据扮演着越来越重要的角色,而数据可视化则成为解读和理解海量数据的得力工具。那么,数据可视化是如何帮助我们高效了解数据的呢?下面我就以可视化从业者的角度来简单聊聊这个话题。 无需深奥的专业知识&#x…

环信服务端下载消息文件---菜鸟教程

前言 在服务端,下载消息文件是一个重要的功能。它允许您从服务器端获取并保存聊天消息、文件等数据,以便在本地进行进一步的处理和分析。本指南将指导您完成环信服务端下载消息文件的步骤。 环信服务端下载消息文件是指在环信服务端上,通过调…

实用编程调试技巧

目录 一、调试的基本步骤 二、Debug和Release的介绍 三、Windows环境调试介绍 1.调试环境的准备 2.学会快捷键 最常用的几个快捷键: 断点应用举例: 3.调试的时候查看程序当前信息 (1&#xff09…

橘子学Spring01之spring的那些工厂和门面使用

一、Spring的工厂体系 我们先来说一下spring的工厂体系(也称之为容器),得益于大佬们对于单一职责模式的坚决贯彻,在十几年以来spring的发展路上,扩展出来大量的工厂类,每一个工厂类都承担着自己的功能(其实就是有对应的方法实现)…

redis高级篇之单线程和多线程

目录 1、redis的发展史 2、redis为什么选择单线程? 3、主线程和Io线程是怎么协作完成请求处理的? 4、IO多路复用 5、开启redis多线程 1、redis的发展史 Redis4.0之前是用的单线程,4.0以后逐渐支持多线程 Redis4.0之前一直采用单线程的主…

智慧农业大棚建设方案——福建蜂窝物联

一、项目背景 温室大棚在不适宜植物生长的季节,能提供生育期和增加产量,多用于低温季节喜温蔬菜、花卉、林木等植物栽培或育苗等。因此对种植作物生长环境的要求要精确的多。 大多数农户加温、浇水、通风等,全凭感觉。人感觉冷了就加温&#…