Vision Transformer(ViT)

news2025/1/11 17:57:27

1. 概述

Transformer[1]是Google在2017年提出的一种Seq2Seq结构的语言模型,在Transformer中首次使用Self-Atttention机制完全代替了基于RNN的模型结构,使得模型可以并行化训练,同时解决了在基于RNN模型中出现了长距离依赖问题,因为在Self-Attention中能够对全局的信息建模。

Transformer结构是一个标准的Seq2Seq结构,包含了Encoder和Decoder两个部分。其中基于Encoder的Bert[2]模型和基于Decoder的GPT[3]模型刷新了NLP中多个任务的记录,在NLP多种应用中取得了巨大的成功。以BERT模型为例,在BERT模型中,首先在大规模数据上利用无监督学习训练语言模型,对于具体的下游任务,如文本分类,利用预训练模型在下游数据上Fine-tuning。

基于Transformer框架的模型在NLP领域大获成功,而在CV领域还是基于CNN模型的情况下,能否将Transformer引入到CV中呢?ViT(Vision Transformer)[4]作为一种尝试,希望能够通过尽可能少的模型改动,实现Transformer在CV中的应用。

2. 算法原理

2.1. Transformer的基本原理

Transformer框架是一个典型的Seq2Seq结构,包括了Encoder和Decoder两个部分,其框架结构如下图所示:

在这里插入图片描述

在Transformer框架结构中,Encoder部分如上图的左半部分,Decoder部分如上图的右半部分。由于在ViT中是以Encoder部分为主要部分,同时,BERT模型也是以Transformer中Encoder为原型的模型,因此在这里对Bert模型做简单介绍,对于完整的Transformer框架的介绍可见参考文献[5]。BERT是基于上下文的预训练模型,BERT模型的训练分为两步:第一,pre-training;第二,fine-tuning。其中,在pre-training阶段,首先会通过大量的文本对BERT模型进行预训练,然而,标注样本是非常珍贵的,在BERT中则是选用大量的未标注样本来预训练BERT模型。在fine-tuning阶段,会针对不同的下游任务适当改造模型结构,同时,通过具体任务的样本,重新调整模型中的参数。

2.1.1. BERT模型的网络结构

BERT模型是Transformer结构的Encoder部分,其基本的网络结构如下图所示:

在这里插入图片描述

这个结构与Transformer中的Encoder结构是完全一致的。

2.1.2. BERT模型的输入Embedding

为了使得BERT能够适配更多的应用,模型在pre-training阶段,使用了Masked Language Model(MLM)和Next Sentence Prediction(NSP)两种任务作为模型预训练的任务,其中MLM可以学习到词的Embedding,NSP可以学习到句子的Embedding。在Transformer中,输入中会将词向量与位置向量相加,而在BERT中,为了能适配上述的两个任务,即MLM和NSP,这里的Embedding包含了三种Embedding的和,如下图所示:

在这里插入图片描述

其中,Token Embeddings是词向量,第一个单词是CLS标志,可以用于之后的分类任,Segment Embeddings用来区别两种句子,这是在预训练阶段,针对NSP任务的输入,Position Embeddings是位置向量,但是和Transformer中不一样,与词向量一样,是通过学习出来的。此处包含了两种标记,一个是[CLS],可以理解为整个输入特征的向量表示;另一个是[SEP],用于区分不同的句子。

2.1.3. 重要的Multi-Head Attention

Multi-Head Attention结构是所以基于Transformer框架模型的灵魂,Multi-Head Attention结构是由多个Scaled Dot-Product Attention模块组合而成,如下图所示:

在这里插入图片描述

其过程可以表示为:

M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , ⋯   , h e a d h ) W o MultiHead\left ( Q,K,V \right ) =Concat\left ( head_1,\cdots, head_h \right ) W^o MultiHead(Q,K,V)=Concat(head1,,headh)Wo

其中,每一个 h e a d i head_i headi就是一个Scaled Dot-Product Attention。Multi-head Attention相当于多个不同的Scaled Dot-Product Attention的集成,引入Multi-head Attention可以扩大模型的表征能力,同时这里面的 h h h个Scaled Dot-Product Attention模块是可以并行的,没有层与层之间的依赖,相比于RNN,可以提升效率。而Scaled Dot-Product Attention的计算方法为:

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention\left ( Q,K,V \right )=softmax\left ( \frac{QK^T}{\sqrt{d_k} }\right )V Attention(Q,K,V)=softmax(dk QKT)V

其中 1 d k \frac{1}{\sqrt{d_k} } dk 1最主要的目的是对点积缩放。计算过程可由下图表示:

在这里插入图片描述

2.1.4. 下游任务的fine-tuning

在预训练阶段,BERT采用了Masked Language Model和Next Sentence Prediction两个训练任务作为其语言模型的训练,其中,Masked Language Model的原理是随机将一些词替换成[MASK],在训练的过程中,通过上下文信息来预测被mask的词;Next Sentence Prediction的目的是让模型理解两个橘子之间的关系,训练的输入是两个句子,BERT模型需要判断后一个句子是不是前一个句子的下一句。这两个任务最大的特点就是可以无监督学习,这样就可以避免模型对大规模标注数据依赖的问题。

在预训练模型完成后,就可以在具体的下游任务中应用BERT模型。这里以文本分类为例,句子对的分类任务,即输入是两个句子,输入如下图所示:

在这里插入图片描述

输出是BERT的第一个[CLS]的隐含层向量 C ∈ R H C\in \mathbb{R}^H CRH,在Fine-Tune阶段,加上一个权重矩阵 W ∈ R K × H W\in \mathbb{R}^{K\times H} WRK×H,其中, K K K为分类的类别数。最终通过Softmax函数得到最终的输出概率。

2.2. ViT的基本原理

ViT模型是希望能够尽可能少对Transformer模型修改,并将Transformer应用于图像分类任务的模型。ViT模型也是基于Transformer的Encoder部分,这一点与BERT较为相似,同时对Encoder部分尽可能少的修改。

2.2.1. ViT的网络结构

ViT的网络结构如下图所示:

在这里插入图片描述

ViT模型的网络结构如上图的右半部分所示,与原始的Transformer中的Encoder不同的是Norm所在的位置不同,类似BERT模型中[class]标记位的设置,ViT在Transformer输入序列前增加了一个额外可学习的[class]标记位,并且该位置的Transformer Encoder输出作为图像特征。

Vision Transformer(ViT)将输入图片拆分成 16 × 16 16\times 16 16×16个patches,每个patch做一次线性变换降维同时嵌入位置信息,然后送入Transformer。类似BERT[CLS]标记位的设计,在ViT中,在输入序列前增加了一个额外可学习的[class]标记位,并将其最终的输出作为图像特征,最后利用MLP做最后的分类,如上图中的左半部分所示,其中,[class]标记位为上图中Transformer Encoder的0*。那么现在的问题就是两个部分,第一,如何将图像转换成一维的序列数据,因为BERT处理的文本数据是一维的序列数据;第二,如何增加位置信息,因为在Transformer中是需要对位置信息编码的,在BERT中是通过学习出来,而在Transformer中是利用sin和cos这两个公式生成出来。

2.2.2. 图像到一维序列数据的转换

对于 x ∈ R H × W × C \mathbf{x}\in \mathbb{R}^{H\times W\times C} xRH×W×C的图像,首先需要将其变成 x p ∈ R N × ( P 2 ⋅ C ) \mathbf{x}_p\in \mathbb{R}^{N\times \left (P^2\cdot C \right )} xpRN×(P2C)的2D的patch的序列,这里面, ( H , W ) \left ( H,W \right ) (H,W)表示的是原图的分辨率, C C C表示的通道(channel)的数目, ( P , P ) \left ( P,P \right ) (P,P)表示的是每个patch的分辨率, N = H W / p 2 N=HW/p^2 N=HW/p2表示的是patch的个数,对于一个通道,上述的这个过程可以如下图所示:

在这里插入图片描述

假设输入图片大小是 256 × 256 256\times 256 256×256,每个patch的大小为 32 × 32 32\times 32 32×32,则最后的总的patch个数为64。对于每个patch,我们还需要将其转换成embeding的表示,ViT中使用到了线性变换,即:

z 0 = [ x c l a s s ; x p 1 E ; x p 2 E ; ⋯   ; x p N E ] + E p o s \mathbf{z}_0=\left [ \mathbf{x}_{class};\mathbf{x}_p^1\mathbf{E};\mathbf{x}_p^2\mathbf{E};\cdots ;\mathbf{x}_p^N\mathbf{E} \right ]+\mathbf{E}_{pos} z0=[xclass;xp1E;xp2E;;xpNE]+Epos

其中, E ∈ R ( P 2 ⋅ C ) × D \mathbf{E}\in \mathbb{R}^{\left ( P^2\cdot C \right )\times D} ER(P2C)×D E p o s ∈ R ( N + 1 ) × D \mathbf{E}_{pos}\in \mathbb{R}^{\left ( N+1 \right )\times D} EposR(N+1)×D。首先对于第 i i i个patch,我们看到 x p i E \mathbf{x}_p^i\mathbf{E} xpiE是将patch转换成 D D D维的向量,具体过程如下:

在这里插入图片描述

这里的卷积操作中卷积核大小为 P × P P\times P P×P,步长为 P P P。参考文献[6]给出了较为容易理解的代码,注释的代码如下:

class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size) # 图片原始大小
        patch_size = (patch_size, patch_size) # 每个patch的大小
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) # 拆分成每个patch后,每个维度的patch个数
        self.num_patches = self.grid_size[0] * self.grid_size[1] # 总共的patch个数

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size) # 对每个patch做线性变换
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() # 归一化

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2) # 这里C即为向量的维度,HW表示的是patch的个数
        x = self.norm(x)
        return x

除此之外还有两个向量,分别为 x c l a s s \mathbf{x}_{class} xclass E p o s \mathbf{E}_{pos} Epos x c l a s s \mathbf{x}_{class} xclass表示的给到一个用于最后图像表示的向量,用于最后的分类任务, E p o s \mathbf{E}_{pos} Epos表示的是位置向量,这两个向量都是通过随机初始化的,并在训练过程中得到的,在参考文献[6]中的代码如下:

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))

通过以上的过程后,便可以直接使用标准的BERT流程开始训练,这部分不再赘述,可参见参考文献[5]和参考文献[6]的具体实现。

2.2.3. 训练目标以及fine-tune

ViT的训练与BERT是不一样的,在BERT中采用的无监督的训练,而在ViT中使用的是监督训练,使用的数据集是有标签的分类数据集,如ILSVRC-2012 ImageNet数据集,该数据集是一个包含了1000个类别的带标签的数据集。[class]标记的向量最初为 z 0 0 = x c l a s s \mathbf{z}_0^0=\mathbf{x}_{class} z00=xclass,在训练过程中,通过Transformer Encoder得到[class]标记的最终向量为 z L 0 \mathbf{z}_L^0 zL0,对其进行归一化并以此作为图像的表示 y \mathbf{y} y

y = L N ( z L 0 ) \mathbf{y}=LN\left ( \mathbf{z}_L^0 \right ) y=LN(zL0)

在训练过程中,后接一个带一个隐含层的MLP,得到整个网络的结构。在Fine-tuning时,去掉最终的这部分,直接用一个线性曾代替这部分重新训练。

在参考文献[4]中,作者设计了不同大小的网络结构,如下图所示:

在这里插入图片描述

从最终的效果上来看,ViT模型的效果还是要优于传统的基于CNN的模型的:

在这里插入图片描述

2.2.4. 一个有意思的点

在上述的ViT的过程中,位置的向量 E p o s \mathbf{E}_{pos} Epos是随机初始化的,那么最终训练出来的这个向量的值能表示其在原始图像中的真实位置吗?在参考文献[4]中设计了这样一个方法,假设有 7 × 7 7\times 7 7×7个patch,每个patch的位置向量与其他patch的位置向量计算相似度,得到了如下的一张图,其中自身的相似度为 1 1 1

在这里插入图片描述

我们发现最终训练出来的位置向量已经具有了空间了信息,即与同行同列之间具有相对较高的相似度。

3. 总结

ViT模型将Transformer引入到图像的分类中,更准确的说是Transformer中的Encoder模块。为了能够尽可能少地对原始模型的修改,在ViT中将图像转换成一维的序列表示,以改成标准的文本形式,通过这种方式实现Transformer在CV中的应用。

参考文献

[1] Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[J]. Advances in neural information processing systems, 2017, 30.

[2] Devlin J , Chang M W , Lee K , et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding[J]. 2018.

[3] Radford A, Narasimhan K, Salimans T, et al. Improving language understanding by generative pre-training[J]. 2018.

[4] Dosovitskiy A, Beyer L, Kolesnikov A, et al. An image is worth 16x16 words: Transformers for image recognition at scale[J]. arXiv preprint arXiv:2010.11929, 2020.

[5] Transformer的基本原理

[6] vision_transformer 代码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/373797.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TDG code

部分 数据集 参数设置 def setup_args(args None):args.algorithm_name TDG# args.algorithm_name HDGargs.user_num 1000000args.attribute_num 6args.domain_size 64args.epsilon 0.2args.dimension_query_volume 0.5args.query_num 20args.query_dimension 3运行…

leetcode 41~50 学习经历

leetcode 41~50 学习经历41. 缺失的第一个正数42. 接雨水43. 字符串相乘44. 通配符匹配45. 跳跃游戏 II46. 全排列47. 全排列 II48. 旋转图像49. 字母异位词分组50. Pow(x, n)小结41. 缺失的第一个正数 给你一个未排序的整数数组 nums ,请你找出其中没有出现的最小的…

C语言数据结构(二)—— 受限线性表 【栈(Stack)、队列(Queue)】

在数据结构逻辑层次上细分,线性表可分为一般线性表和受限线性表。一般线性表也就是我们通常所说的“线性表”,可以自由的删除或添加结点。受限线性表主要包括栈和队列,受限表示对结点的操作受限制。一般线性表详解,请参考文章&…

数据结构基础之栈和队列

目录​​​​​​​ 前言 1、栈 2、队列 2.1、实现队列 2.2、循环队列 前言 上一篇中我们介绍了数据结构基础中的《动态数组》,本篇我们继续来学习两种基本的数据结构——栈和队列。 1、栈 特点:栈也是一种线性结构,相比数组&#xff…

(汇总记录)电机控制算法

1.S曲线应用电机加减速 电机控制 | S曲线加减速 - Tuple - 博客园 (cnblogs.com) 如要将S型曲线应用到电机的加减速控制上,需要将方程在X、Y坐标系进行平移,同时对曲线进行拉升变化:即 Y A B / ( 1 exp( -ax b ) ) ,则根据该…

Pandas怎么添加数据列删除列

Pandas怎么添加数据列 1、直接赋值 # 1、直接赋值df.loc[:, "最高气温"] df["最高气温"].str.replace("℃", "").astype("int32")df.loc[:, "最低气温"] df["最低气温"].str.replace("℃"…

Java异常架构与异常关键字

Java异常简介 Java异常是Java提供的一种识别及响应错误的一致性机制。 Java异常机制可以使程序中异常处理代码和正常业务代码分离,保证程序代码更加优雅,并提高程序健壮性。在有效使用异常的情况下,异常能清晰的回答what, where, why这3个问…

【编程入门】N种编程语言做个应用市场(appstore)

背景 前面已输出多个系列: 《十余种编程语言做个计算器》 《十余种编程语言写2048小游戏》 《17种编程语言10种排序算法》 《十余种编程语言写博客系统》 《十余种编程语言写云笔记》 《N种编程语言做个记事本》 本系列做了个应用市场,支持下载安装安卓…

Bootstrap系列之导航

Bootstrap导航 可以在 ul 元素上添加 .nav类&#xff0c;在每个 li 选项上添加 .nav-item 类&#xff0c;在每个链接上添加 .nav-link 类: 基本的导航 <div class"container mt-3"><h2>导航</h2><p>简单的水平导航:</p><ul class&…

基于yolov5与改进VGGNet的车辆多标签实时识别算法

摘 要 为了能快速、有效地识别视频中的车辆信息&#xff0c;文中结合YOLOv3算法和CNN算法的优点&#xff0c;设计了一种能实时识别车辆多标签信息的算法。首先&#xff0c;利用具有较高识别速度和准确率的YOLOv3实现对视频流中车辆的实时监测和定位。在获得车辆的位置信息后…

《亚马逊逆向工作法》读书笔记

文章目录书籍信息构件&#xff1a;领导力准则与机制亚马逊领导力准则机制&#xff1a;强化领导力准则年度计划&#xff1a;OP1与OP2S-Team目标亚马逊的薪酬制度&#xff1a;强化长期思维招聘&#xff1a;亚马逊独特的抬杆者流程抬杆者招聘流程组织&#xff1a;独立单线程领导模…

Redis-Java代码使用示例

在我之前的项目中&#xff0c;使用Redis是我们团队自己封装了一个Redis操作类&#xff0c;但是这只是在Spring提供的RedisTemplate上做了一层封装而已&#xff0c;当时使用不是很熟练&#xff0c;都是一边在网上查资料&#xff0c;一边使用&#xff1b;这篇文章会介绍两种使用方…

分布式一致性算法——Paxos 和 Raft 算法

写在前面 本文隶属于专栏《100个问题搞定大数据理论体系》&#xff0c;该专栏为笔者原创&#xff0c;引用请注明来源&#xff0c;不足和错误之处请在评论区帮忙指出&#xff0c;谢谢&#xff01; 本专栏目录结构和参考文献请见100个问题搞定大数据理论体系 I. 简介 介绍Paxos…

局域网实现PC、Pad、Android互联

文章目录局域网实现PC、Pad、Android互联一、网络邻居1、 Windows 配置1.1 开启共享功能1.2 设置用户1.3 共享文件夹2、 Pad 连接二、 FTP & HTTP1、 电脑配置1.1 HTTP 服务1.2 FTP 服务2、 连接3、 电脑连接 FTP三、 其他方式局域网实现PC、Pad、Android互联 在我们使用多…

【micropython】SPI触摸屏开发

背景&#xff1a;最近买了几块ESP32模块&#xff0c;看了下mircopython支持还不错&#xff0c;所以买了个SPI触摸屏试试水&#xff0c;记录一下使用过程。硬件相关&#xff1a;SPI触摸屏使用2.4寸屏幕&#xff0c;常见淘宝均可买到&#xff0c;驱动为ILI9341&#xff0c;具体参…

windows服务器实用(2)——搭建本地文档管理(gitbit的部署)

windows服务器实用——部署gitbit 在日常的项目管理中&#xff0c;无论是文档还是代码&#xff0c;一般都是存在本地。但是本地的文件存在一定的不确定性&#xff0c;尤其是当文档经常改动的时候&#xff0c;如果要找回之前改动的文件是很困难的。如果每次的改动都存在本地&am…

数据结构与算法之链表

目录单链表概念单链表操作循环链表概念循环链表操作双向循环链表概念双向循环链表操作单链表 概念 单链表也叫单向链表&#xff0c;是链表中最简单的一种形式&#xff0c;它的每个节点包含两个域&#xff0c;一个信息域&#xff08;元素域&#xff09;和一个链接域。这个链接…

微信投票-课后程序(JAVA基础案例教程-黑马程序员编著-第七章-课后作业)

【实验7-5】 微信投票 【任务介绍】 1.任务描述 如今微信聊天已经普及到几乎每一个人&#xff0c;在聊天中&#xff0c;经常会有人需要帮忙在某个APP中投票。本案例要求编写一个模拟微信投票的程序&#xff0c;通过在控制台输入指令&#xff0c;实现添加候选人、查看当前投票…

【C语言刷题】找单身狗、模拟实现atoi

目录 一、找单身狗 1.暴力循环法 2.分组异或法 二、模拟实现atoi 1.atoi函数的功能 2.模拟实现atoi 一、找单身狗 题目描述&#xff1a;给定一个数组中只有两个数字是出现一次&#xff0c;其他所有数字都出现了两次。 编写一个函数找出这两个只出现一次的数字。 比如&…

【Maven】(三)Maven仓库概念及私服安装与使用 附:Nexus安装包下载地址

文章目录1.前言2.Maven的仓库2.1.仓库类型3.私服Nexus3.1.Nexus的安装与配置3.1.1.使用安装包安装3.1.2.使用Docker安装3.2.Nexus配置3.2.1.仓库配置在这里插入图片描述4.私服的使用4.1.修改Maven配置4.2.从私服中下载构件4.3.推送构件到私服5.小结1.前言 本系列文章记录了 Ma…