Visual Transformer (ViT)模型详解

news2025/1/21 1:04:05

1 Vit简介

1.1 Vit的由来

ViT是2020年Google团队提出的将Transformer应用在图像分类的模型,虽然不是第一篇将transformer应用在视觉任务的论文,但是因为其模型“简单”且效果好,可扩展性强(scalable,模型越大效果越好),成为了transformer在CV领域应用的里程碑著作,也引爆了后续相关研究。

论文地址:https://arxiv.org/pdf/2010.11929.pdf

Visual Transformer (ViT) 出自于论文《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》,是基于Transformer的模型在视觉领域的开篇之作。ViT模型是基于Transformer Encoder模型的。

1.2 Vit如何工作

我们知道Transformer模型最开始是用于自然语言处理(NLP)领域的,NLP主要处理的是文本、句子、段落等,即序列数据。但是视觉领域处理的是图像数据,因此将Transformer模型应用到图像数据上面临着诸多挑战,理由如下:

  • 与单词、句子、段落等文本数据不同,图像中包含更多的信息,并且是以像素值的形式呈现。

  • 如果按照处理文本的方式来处理图像,即逐像素处理的话,即使是目前的硬件条件也很难。

  • Transformer缺少CNNs的归纳偏差,比如平移不变性和局部受限感受野。

  • CNNs是通过相似的卷积操作来提取特征,随着模型层数的加深,感受野也会逐步增加。但是由于Transformer的本质,其在计算量上会比CNNs更大。

  • Transformer无法直接用于处理基于网格的数据,比如图像数据。

为了解决上述问题,Google的研究团队提出了ViT模型,它的本质其实也很简单,既然Transformer只能处理序列数据,那么我们就把图像数据转换成序列数据就可以了呗。下面来看下ViT是如何做的。

1.3 ViT模型架构

我们先结合下面的动图来粗略地分析一下ViT的工作流程,如下:

  • 将一张图片分成patches

  • 将patches铺平

  • 将铺平后的patches的线性映射到更低维的空间

  • 添加位置embedding编码信息

  • 将图像序列数据送入标准Transformer encoder中去

  • 在较大的数据集上预训练

  • 在下游数据集上微调用于图像分类

 模型由三个模块组成:

  • Linear Projection of Flattened Patches(Embedding层)
  • Transformer Encoder
  • MLP Head(最终用于分类的层结构)

Embedding层

对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵[num_token, token_dim],如下图,token0-9对应的都是向量,以ViT-B/16为例,每个token向量长度为768。

对于图像数据而言,其数据格式为[H, W, C]是三维矩阵明显不是Transformer想要的。所以需要先通过一个Embedding层来对数据做个变换。如下图所示,首先将一张图片按给定大小分成一堆Patches。以ViT-B/16为例,将输入图片(224x224)按照16x16大小的Patch进行划分,划分后会得到196个Patches。接着通过线性映射将每个Patch映射到一维向量中,以ViT-B/16为例,每个Patche数据shape为[16, 16, 3]通过映射得到一个长度为768的向量(后面都直接称为token)。[16, 16, 3] -> [768]

在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。

Transformer Encoder

Transformer Encoder其实就是重复堆叠Encoder Block L次,主要由Layer Norm、Multi-Head Attention、Dropout和MLP Block几部分组成。

​​
MLP Head

上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。


2 ViT工作原理

我们将上图展示的过程近一步分解为6步,接下来一步一步地来解析它的原理。如下图:

2.1 步骤1、将图片转换成patches序列

这一步很关键,为了让Transformer能够处理图像数据,第一步必须先将图像数据转换成序列数据,但是怎么做呢?假如我们有一张图片x\epsilon R^{H*W*C},patch大小为p,那么我们可以创建N个图像patches,可以表示为x_{p}\epsilon R^{N*(p^{2}c)},其中N=\frac{HW}{P^{2}}N就是序列的长度,类似一个句子中单词的个数。在上面的图中,可以看到图片被分为了9个patches。

2.2 步骤2、将Patches铺平

在原论文中,作者选用的patch大小为16,那么一个patch的shape为(3,16,16),维度为3,将它铺平之后大小为3x16x16=768。即一个patch变为长度为768的向量。不过这看起来还是有点大,此时可以使用加一个Linear transformation,即添加一个线性映射层,将patch的维度映射到我们指定的embedding的维度,这样就和NLP中的词向量类似了。

2.3 步骤3、添加Position embedding

与CNNs不同,此时模型并不知道序列数据中的patches的位置信息。所以这些patches必须先追加一个位置信息,也就是图中的带数字的向量。实验表明,不同的位置编码embedding对最终的结果影响不大,在Transformer原论文中使用的是固定位置编码,在ViT中使用的可学习的位置embedding 向量,将它们加到对应的输出patch embeddings上。

2.4 步骤4、添加class token

在输入到Transformer Encoder之前,还需要添加一个特殊的class token,这一点主要是借鉴了BERT模型。添加这个class token的目的是因为,ViT模型将这个class token在Transformer Encoder的输出当做是模型对输入图片的编码特征,用于后续输入MLP模块中与图片label进行loss计算。

2.5 步骤5、输入Transformer Encoder

将patch embedding和class token拼接起来输入标准的Transformer Encoder中。 Transformer Encoder其实就是重复堆叠Encoder Block L次,主要由Layer Norm、Multi-Head Attention、Dropout和MLP Block几部分组成。

2.6 步骤6、分类

注意Transformer Encoder的输出其实也是一个序列,但是在ViT模型中只使用了class token的输出,将其送入MLP模块中,去输出最终的分类结果。

3 模型搭建参数

在论文的Table1中有给出三个模型(Base/ Large/ Huge)的参数,在源码中除了有Patch Size为16x16的外还有32x32的。

其中:

Layers就是Transformer Encoder中重复堆叠Encoder Block的次数 L。
Hidden Size就是对应通过Embedding层(Patch Embedding + Class Embedding + Position Embedding)后每个token的dim(序列向量的长度)
MLP Size是Transformer Encoder中MLP Block第一个全连接的节点个数(是token长度的4倍)
Heads代表Transformer中Multi-Head Attention的heads数。
 

4 结果分析

上表是论文用来对比ViT,Resnet(和刚刚讲的一样,使用的卷积层和Norm层都进行了修改)以及Hybrid模型的效果。通过对比可得出结论:

  • 在训练epoch较少时Hybrid优于ViT -> Epoch小选Hybrid
  • 当epoch增大后ViT优于Hybrid -> Epoch大选ViT
     

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1347869.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LLM】人工智能应用构建的十大预训练NLP语言模型

在人工智能领域,自然语言处理(NLP)被广泛认为是阅读、破译、理解和理解人类语言的最重要工具。有了NLP,机器可以令人印象深刻地模仿人类的智力和能力,从文本预测到情感分析再到语音识别。 什么是自然语言处理&#xf…

C语言实验5:结构体

目录 一、实验要求 二、实验原理 1. 普通结构体 1.1 显示声明结构体变量 1.2 直接声明结构体变量 ​编辑 1.3 typedef在结构体中的作用 2. 结构体的嵌套 3. 结构体数组 4. 指向结构体的指针 4.1 静态分配 4.2 动态分配 三、实验内容 1. 学生数据库 代码 截图 …

《数据库开发实践》之触发器【知识点罗列+例题演练】

一、什么是触发器? 1.概念: 简单来说触发器就是一种特殊的存储过程,在数据库服务器触发事件的时候会自动执行其SQL语句集。 2.构成四要素: (1)名称:要符合标识符命名规则 (2&am…

跳跃表原理及实现

一、跳表数据结构 跳表是有序表的一种,其底层是通过链表实现的。链表的特点是插入删除效率高,但是查找节点效率很低,最坏的时间复杂度是O(N),那么跳表就是解决这一痛点而生的。 为了提高查询效率,我们可以给链表加上索…

ORACLE Primavera Unifier v23.12 最新虚拟机(VM)分享下载

引言 根据上周的计划,我近日简单制作了一个基于ORACLE Primavera Unifier 最新版23.12的虚拟机演示环境,里面包括了unifier的全套系统服务 此虚拟系统环境仅用于演示、培训和测试目的。如要在生产环境中使用此虚拟机,请您与Oracle 销售代表联…

pngPackerGUI_V2.0是什么工具?

pngPackerGUI_V2.0是什么工具? png图片打包plist工具,手把手教你使用pngPackerGUI_V2.0此软件是在pngpacker_V1.1软件基础之后,开发的界面化操作软件,方便不太懂命令行的小白快捷上手使用。1.下载并解压缩软件,得到如…

关于“Python”的核心知识点整理大全53

目录 18.2.7 Django shell 注意 18.3 创建网页:学习笔记主页 18.3.1 映射 URL urls.py urls.py 注意 18.3.2 编写视图 views.py 18.3.3 编写模板 index.html 往期快速传送门👆(在文章最后): 感谢大家的支…

VitulBox中Ubuntu虚拟机安装JAVA环境——备赛笔记——2024全国职业院校技能大赛“大数据应用开发”赛项

前言 在进行之后操作是请下载好JDK,之后的内容是以Ubuntu虚拟机中安装java环境续写。 提示:以下操作是在虚拟机hadoop用户下操作的,并为安装java环境作准备 一、更新APT 为了确保Hadoop安装过程顺利进行,建议用hadoop用户登录…

MCS接口技术----定时/计数,中断

目录 一.中断系统相关寄存器 1.51单片机中断系统的总体结构: 2.中断源的中断级别(由高到低): 3.与中断有关的四个寄存器: (1)TCON---定时控制寄存器 (2)IE---中断允…

一二三应用开发平台文件处理设计与实现系列之3——后端统一封装设计与实现

背景 前面介绍了前端通过集成vue-simple-uploader实现了文件的上传,今天重点说一下后端的设计与实现。 功能需求梳理 从功能角度而言,实际主要就两项,一是上传,二是下载。其中上传在文件体积较大的情况下,为了加快上…

2013年第二届数学建模国际赛小美赛B题寄居蟹进化出人类的就业模式解题全过程文档及程序

2013年第二届数学建模国际赛小美赛 B题 寄居蟹进化出人类的就业模式 原题再现: 寄居蟹是美国最受欢迎的宠物品种,依靠其他动物的壳来保护。剥去寄居蟹的壳,你会看到它柔软、粉红色的腹部卷曲在头状的蕨类叶子后面。大多数寄居蟹喜欢蜗牛壳&…

Java集合/泛型篇----第五篇

系列文章目录 文章目录 系列文章目录前言一、说说LinkHashSet( HashSet+LinkedHashMap)二、HashMap(数组+链表+红黑树)三、说说ConcurrentHashMap前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通…

加强->servlet->tomcat

0什么是servlet jsp也是servlet 细细体会 Servlet 是 JavaEE 的规范之一,通俗的来说就是 Java 接口,将来我们可以定义 Java 类来实现这个接口,并由 Web 服务器运行 Servlet ,所以 TomCat 又被称作 Servlet 容器。 Servlet 提供了…

数据结构: 位图

位图 概念 用一个bit为来标识数据在不在 功能 节省空间快速查找一个数在不在一个集合中排序 去重求两个集合的交集,并集操作系统中的磁盘标记 简单实现 1.设计思想:一个bit位标识一个数据, 使用char(8bit位)集合来模拟 2.预备工作:a.计算这个数在第几个char b.是这个ch…

「实验记录」CS144 Lab1 StreamReassembler

目录 一、Motivation二、SolutionsS1 - StreamReassembler的对外接口S2 - push_substring序列写入ByteStream 三、Result四、My Code五、Reference 一、Motivation 我们都知道 TCP 是基于字节流的传输方式,即 Receiver 收到的数据应该和 Sender 发送的数据是一样的…

C#-CSC编译环境搭建

一.Microsoft .NET Framework 确保系统中安装Microsoft .NET Framework相关版本下载 .NET Framework 4.7 | 免费官方下载 (microsoft.com)https://dotnet.microsoft.com/zh-cn/download/dotnet-framework/net47 二.编译环境搭建 已经集成编译工具csc.exe,归档至gitcode,实现us…

L1-076:降价提醒机器人

题目描述 小 T 想买一个玩具很久了,但价格有些高,他打算等便宜些再买。但天天盯着购物网站很麻烦,请你帮小 T 写一个降价提醒机器人,当玩具的当前价格比他设定的价格便宜时发出提醒。 输入格式: 输入第一行是两个正整数…

数据隐私:技术和法律的双重挑战

当前,数据已成为企业和个人最宝贵的资产之一。然而,随着数据的广泛收集和共享,数据隐私问题也日益突出。保护个人信息的隐私不仅是法律规定的义务,也是维护社会公正、保护个人权益的必要措施。本文将从数据隐私的概念、重要性、面…

Linux学习第48天:Linux USB驱动试验:保持热情,保持节奏,持续学习是作为一个技术人员应有的基本素质和要求

Linux版本号4.1.15 芯片I.MX6ULL 大叔学Linux 品人间百味 思文短情长 最近更新的速度和频率大不如以前,主要原因还是自己有些懈怠了。学习是一个持续努力的过程,一旦中断,再想保持以往的状态可能要…

《MySQL系列-InnoDB引擎01》MySQL体系结构和存储引擎

文章目录 第一章 MySQL体系结构和存储引擎1 数据库和实例2 MySQL配置文件3 MySQL数据库路径4 MySQL体系结构5 MySQL存储引擎5.1 InnoDB存储引擎5.2 MyISAM存储引擎5.3 NDB存储引擎5.4 Memory存储引擎5.5 Archive存储引擎5.6 Federated存储引擎 6 连接MySQL6.1 TCP/IP6.2 命名管…