Vision Transformer模型架构详解

news2024/12/23 15:09:44

🎀个人主页: https://zhangxiaoshu.blog.csdn.net
📢欢迎大家:关注🔍+点赞👍+评论📝+收藏⭐️,如有错误敬请指正!
💕未来很长,值得我们全力奔赴更美好的生活!

前言

2019年开始,自然语言处理(NLP)领域抛弃了循环神经网络(RNN)序列依赖的问题,开始采用Attention is All you need的Transformer结构[1],其中的Attention是一种可以让模型专注于重要的信息并能够充分学习和吸收的技术。在NLP领域中,伴随着各种语言Transformer模型的提出使得多项语言处理任务的精度和模型深度开始飞速提升。由于基于Transformer的预训练语言模型非常成功,研究者开始探索其在视觉领域的应用。2020年10月,Google创新性的设计了用于分类的Vision Transformer模型[2]—ViT。此后视觉Transformer模型的研究进入了快车道,本文主要对Vision Transformer模型架构进行详细介绍以及在pytorch中的使用方法进行介绍。


文章目录

  • 前言
  • 一、Vision Transformer模型架构
    • 1. Embedding层结构
    • 2. Transformer Encoder结构
      • (1)层归一化(Layer Norm)
      • (2)多头注意力机制(Multi-Head Attention)
      • (3)Dropout/DropPath
      • (4)MLP Block
    • 3. MLP Head结构
  • 二、PyTorch实现
    • 1. 首先安装vit-pytorch库:
    • 2.导入进行调用:
  • 总结


一、Vision Transformer模型架构

下图是原论文中作者给出的关于Vision Transformer的模型总体框架图:

在这里插入图片描述
从图中可以看出,Vision Transformer模型主要由三部分组成:第一部分为Linear Projection of Flattened Patches,也被称为Embedding层,主要用于将输入的图片数据转化为适合Transformer结构处理的形式。第二部分为Transformer Encoder部分,它是整个ViT模型的核心板块,在图右侧给出了更加详细的结构,它主要由层归一化(Layer Norm)、多头注意力机制(Multi-Head Attention)、Dropout/DropPath、MLP Block四部分组成用于学习输入图像数据的特征。第三部分为MLP Head,它是最终用于分类的层结构。下面本设计将对每一个组成部分进行一个详细介绍。

1. Embedding层结构

在视觉Transform模型中,其Transformer Encoder模块的输入形式是一个向量(token)序列,即一个二维矩阵[num_token, token_dim]的形式,如上图所示,输入的粉色小块token0-9对应的都是向量序列。

但是,图像处理和语言处理不一样,它的数据格式和Transformer Encoder输入格式是不一样,而是一个三维矩阵[H, W, C]的形式。所以在视觉Transform模型中首先加入了一个Embedding层结构用于将数据变化为向量序列。其主要过程为:首先将输入的图片形式数据按照模型定义的切割大小切割成多个小块(Patches),然后将切割的小块通过维度变化映射成向量形式。以常见的ViT-B/16为例,它首先将输入图片( 224 × 224 224\times224 224×224)按照 16 × 16 16\times16 16×16的大小进行切分得到196个Patches,接着通过线性映射将每一个Patches(16, 16, 3)映射成一个长度为768的向量。

在具体实现代码时,可以通过一个卷积层和Flatten层来直接实现。以ViT-B/16为例,如图所示其卷积层的参数为:卷积核大小是16x16、步距是16、卷积核的个数是768。数据通过卷积层后维度从(224, 224, 3)变化为(14, 14, 768),接着,将H和W两个维度展平即Flatten操作即可变化为(196, 768)这样的二维矩阵形式,这正是Transformer Encoder的输入格式。
在这里插入图片描述
除了将输入数据的形式变化为Transformer Encoder的输入格式,模型还在输入Transformer Encoder之前加入了[class]token以及Position Embedding,如下图所示。[class]token是参考了BERT所设计的,它是一个可以学习的参数,用于拼接到tokens中专门用于图像数据的分类。以ViT-B/16为例,就是让一个768长度的向量,与从Flatten层输出的数据拼接在一起,即,Cat((1, 768),(196, 768))—>(197, 768)。Position Embedding也是一个可以学习的参数。它是直接叠加在tokens上的(Add),因为对于图像数据而言,每一块和每一块在都有一定的位置依赖关系,所以Position Embedding主要用于表达Patches之间的位置关系。以ViT-B/16为例,就是让一个(197, 768)的向量与之前得到的(197, 768)向量相加。
在这里插入图片描述

2. Transformer Encoder结构

Transformer Encoder其实就是将Encoder Block 重复堆叠L次, Encoder Block结构图如下图2.4所示,主要由层归一化(Layer Norm)多头注意力机制(Multi-Head Attention)Dropout/DropPathMLP Block四部分组成。

(1)层归一化(Layer Norm)

层归一化(Layer Norm):这是一种主要针对NLP领域提出的归一化方法,这里是对每个token进行归一化处理。目前的归一化层主要有BN、LN、IN、GN和SN五种方法,它解决了深度神经网络内部协方差偏移问题,是一种将深度神经网络之间的数据进行归一化的算法,使得深度学习的训练过程中梯度变化趋于稳定,从而使网络在训练时达到快速收敛的目的。将输入的图像shape记为[N, C, H, W],这些方法的主要不同之处是,BatchNorm是在Batch上进行的,对NHW做归一化,对于较小的Batch Size没有太大的作用;LayerNorm是在通道方向上进行的,对CHW归一化,对RNN有很大的作用;InstanceNorm是在图像的像素上进行的,对HW做归一化,主要用在风格化迁移等方面;GroupNorm首先将Channel进行分组,然后再做归一化;SwitchableNorm是将BN、LN、IN结合并给予权重,让网络自己去学习归一化层应当使用的方法。

*有关BN、LN、IN、GN归一化方法的详细介绍可以看我这篇文章:神经网络常用归一化和正则化方法解析(一);

在这里插入图片描述
Layer Norm即层归一化针对神经网络的某一层的所有输入按照以下公式进行归一化操作:

H H H是某一层中隐藏结点的数量, l l l表示层数,可以计算得到Layer Norm的归一化统计量 μ l \mu^l μl σ l \sigma^l σl,如下式:

μ l = 1 H ∑ i = 1 H a i l \mu^l=\frac{1}{H}\sum_{i=1}^{H}a_i^l μl=H1i=1Hail

σ l = 1 H ∑ i = 1 H ( a l − μ l ) 2 \sigma^l=\sqrt{\frac{1}{H}\sum_{i=1}^{H}\left(a^l-\mu^l\right)^2} σl=H1i=1H(alμl)2

其中 a l a^l al表示一个中间输出结果的总和。上面的统计量和样本数没有关系,而是和隐藏层的结点数有关,我们甚至可以使 Batch Size = 1。于是,我们可以根据约定的统计量进行归一化处理,

a ^ l = a l − μ l ( σ l ) 2 + ε {\hat{a}}^l=\frac{a^l-\mu^l}{\sqrt{\left(\sigma^l\right)^2+\varepsilon}} a^l=(σl)2+ε alμl

同样,在Layer Norm中常使用参数增益(gain)和偏置(bias)这两个参数来保障归一化操作不会破坏之前的信息,同BatchNorm中的 γ \gamma γ β \beta β

y i = γ a ^ l + β y_i=\gamma{\hat{a}}^l+\beta yi=γa^l+β

从以上公式可以看到, LN中同层神经元输入拥有相同的均值和方差,不同的输入样本有不同的均值和方差。所以,LN与Batch的大小无关,也不取决于输入Sequence的深度,所以可以在batchsize为1和RNN中对边长的输入Sequence进行Normalize操作。

(2)多头注意力机制(Multi-Head Attention)

多头注意力机制(Multi-Head Attention):通过多个注意力机制的并行组合,将独立的注意力输出串联起来,预期维度得到线性地转化。直观看来,多个注意头允许对序列的不同部分进行注意力运算

对于Self-Attention来说,假设输入的token长度为 L L L,则输入为 [ x 1 , x 2 . . . x L , ] [x_1,x_2...x_L,] [x1,x2...xL,],然后分别将 x 1 x 2 . . . x L x_1x_2...x_L x1x2...xL分别通过三个变化矩阵 W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv(这三个参数是可训练的、共享的)得到对应的 q i , k i , v i , q^i,k^i,v^i, qi,ki,vi, 并将 q , k , v q,k,v q,k,v向量序列记为 Q , K , V Q,K,V Q,K,V。计算过程如下式所示,具体实现时使用全连接层实现。

( Q , K , V ) = ( q i , k i , v i ) = x i ∙ ( W q , W k , W v ) (Q,K,V)=\left(q^i,k^i,v^i\right)=x_i\bullet\left(W_q,W_k,W_v\right) (Q,K,V)=(qi,ki,vi)=xi(Wq,Wk,Wv)

其中 i = 1 , 2... L i=1,2...L i=1,2...L q q q表示query,后续会去和每一个k进行匹配, k k k代表key,后续会被每个 q q q匹配, v v v代表从 x x x中提取得到的信息value,后续 q q q k k k匹配的过程可以理解成计算两者的相关性,相关性越大对应 v v v的权重也就越大。

接着将 Q Q Q中的每一个 q i q^i qi去和 K K K中的每一个 k j k^j kj进行匹配,即点积操作。然后再除以 L \sqrt L L 得到对应的 α i , j \alpha_{i,j} αi,j,这样做的目的是进行点乘后的数值很大,导致通过Softmax后梯度变的很小,所以通过除以 L \sqrt L L 来进行缩放。具体计算过程如下式所示。

α i , j = q i ( k j ) T L \alpha_{i,j}=\frac{q^i\left(k^j\right)^T}{\sqrt L} αi,j=L qi(kj)T

α i , j \alpha_{i,j} αi,j表示 x i x_i xi x j x_j xj注意程度,然后对每一行分别进行Softmax处理得到 a ^ \hat{a} a^,相当于 x j x_j xj x i x_i xi权重,即对于 v v v的权重。具体计算过程如下式所示。

a ^ i , j = S o f t m a x ( α i , j ) {\hat{a}}_{i,j}=Softmax(α_{i,j}) a^i,j=Softmax(αij)

上面已经计算得到 a ^ i , j {\hat{a}}_{i,j} a^i,j,即针对每个 v v v的权重,接着进行加权得到最终结果,如下式所示。

b i = ∑ j = 1 L a ^ i , j × v j b^i=\sum_{j=1}^{L}{{\hat{a}}_{i,j}\times v^j} bi=j=1La^i,j×vj

其中 b i b^i bi表示 x i x_i xi经过Self-Attention后的结果。以上四式的过程习惯上用以下式来统一表示。

A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q ( K ) T L ) V Attention(Q,K,V)=Softmax\left(\frac{Q\left(K\right)^T}{\sqrt L}\right)V Attention(Q,K,V)=Softmax(L Q(K)T)V

对于Multi-Head Attention来说, 使用多头注意力机制能够联合来自不同head部分学习到的信息。首先根据使用的head的数目 h h h W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv权值矩阵均分成 h h h份,即 W i Q , W i K , W i V W_i^Q,W_i^K,W_i^V WiQ,WiK,WiV, 其中 i = 1 , 2... h i=1,2...h i=1,2...h,然后还是和Self-Attention模块一样将 x i x_i xi分别通过变化矩阵 W i Q , W i K , W i V W_i^Q,W_i^K,W_i^V WiQ,WiK,WiV得到对应head的 q i , k i , v i q^i,k^i,v^i qi,ki,vi, 接下来针对每个head使用和Self-Attention中相同的方法即可得到对应的结果。如下式所示。

h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) {head}_i=Attention\left(QW_i^Q,KW_i^K,VW_i^V\right) headi=Attention(QWiQ,KWiK,VWiV)

其中 Q W i Q QW_i^Q QWiQ同前式相比多了一个 W i Q W_i^Q WiQ,表示这里是根据划分的变化矩阵去计算每一个head的结果。即通过 W i Q , W i K , W i V W_i^Q,W_i^K,W_i^V WiQ,WiK,WiV映射得到每个head的 q i , k i , v i q^i,k^i,v^i qi,ki,vi,然后计算结果。
最后将每个head得到的结果进行concat拼接,接着将拼接后的结果通 过 W o 过W^o Wo(可学习的参数)进行融合,融合后得到最终的结果 b i b^i bi。如式(2-11)所示。

M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , h e a d 1 … h e a d h ) W o MultiHead\left(Q,K,V\right)=Concat\left({head}_1,{head}_1\ldots{head}_h\right)W^o MultiHead(Q,K,V)=Concat(head1,head1headh)Wo

(3)Dropout/DropPath

Dropout/DropPath:在学习深度学习模型时,由于模型的参数过多、样本数量过少,导致了模型的过度拟合。在神经网络的训练中,常常会碰到一些问题。该方法具有较低的训练数据损失,具有较高的训练准确率。但是,测试数据的损失函数比较大,导致预测的准确性不高。

Dropout能在一定程度上减轻过度拟合,并能在某种程度上实现正规化。其基本原理是:在前向传播前进的过程中,使一个神经元的激活值以 p的概率不能工作,这在下面的图中可以看到。停止工作的神经元用虚线表示,与该神经元相连的相应传播过程将不在存在。这使得模型更加一般化,因为它不会依赖于一些局部特征。

DropPath类似于Dropout,不同的是Dropout 是对神经元随机“失效”,而DropPath是随机“失效”模型中的多分支结构。例如如下图右图所示,若 x x x为输入的张量,其通道为[B,C,H,W],那么DropPath的含义为一个Batch_size中,在经过多分支结构时,随机有drop_prob的样本,不经过主干,而直接经过分支(图中虚线)进行恒等映射。这在一定程度上使模型泛化性更强。
在这里插入图片描述

(4)MLP Block

MLP Block:如前文中Transformer Encoder结构图右侧所示,MLP Block由全连接层、GELU激活函数、Dropout组成,以ViT-B/16为例,第一个全连接层会把输入节点个数翻4倍(197, 768)—> (197, 3072),第二个全连接层会还原回原节点个数(197, 3072)—> (197, 768)。

3. MLP Head结构

通过Transformer Encoder后输出的维度和输入的维度是保持不变的,以ViT-B/16为例,输入的是(197, 768)输出的还是(197, 768)。这里只需要从[class]token抽取生成的对应结果,即从(197, 768)中抽取出[class]token对应的(1, 768),即为需要的分类信息。然后就可以用 MLP Head进行最后的分类得到结果。原论文中提到,在训练ImageNet21K时MLP Head是由全连接层+tanh激活函数+全连接层组成。但是如果是在ImageNet1K或者自己的数据集上时,只需要使用一个全连接层(Linear)即可,其结构如下图所示。
在这里插入图片描述

二、PyTorch实现

ViT模型共有三个不同的规模,如下所示:
。

1. 首先安装vit-pytorch库:

$ pip install vit-pytorch

2.导入进行调用:

import torch
from vit_pytorch import ViT

model = ViT(
    image_size = 224,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
 
imgs = torch.randn(1, 3, 224, 224)
 
preds = model(imgs) # (1, 1000)

总结

以上就是对Vision Transformer模型架构的详细介绍及其适用,Vision Transformer模型作为第一个将Transformer结构应用到计算机视觉上的模型,对近年来计算机视觉的研究具有很大的意义,其常常与swin Transformer(可以理解为FPN结构的ViT)用作其他任务如检测、分割的backbone以及视觉特征提取器。

参考:
Attention is all you need
An image is worth 16x16 words: Transformers for image recognition at scale

文中图片大多来自论文和网络,如有侵权,联系删除,文中有不对的地方欢迎指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1304688.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

自动化测试 (一) 12306火车票网站自动登录工具

还记得2011年春运,12306火车票预订网站经常崩溃无法登录吗。 今天我们就开发一个12306网站自动登录软件。 帮助您轻松订票 Web的原理就是,浏览器发送一个Request给Web服务器,Web服务器处理完这个请求之后发送一个HTTP Response给浏览器。 …

【JAVA】黑马MybatisPlus 学习笔记【一】

1.快速入门 为了方便测试,我们先创建一个新的项目,并准备一些基础数据。 1.1 环境配置 导入项目 注意配置一下项目的JDK版本为JDK11。首先点击项目结构设置: 导入两张表,在课前资料中已经提供了SQL文件: 最后&am…

python用来干什么的,python用来做什么的

大家好,小编来为大家解答以下问题,python用来干什么的,python用来做什么的,今天让我们一起来看看吧! 随着互联网行业的发展,编程越来越受到人们的重视,但是始终很多人并不了解编程是什么&#x…

Linux——MySQL数据库系统

一、 MySQL的编译安装 1、准备工作 (1)为了避免发生端口冲突,程序冲突等现象,建议先查询MySQL软件的安装情况,确认没有使用以Rpm方式安装的mysql-server、mysql软件包,否则建议将其卸载 [rootlocalhost ~]…

13、RockerMQ消息类型之广播与集群消息

RocketMq中提供两种消费模式:集群模式和广播模式。 集群模式 集群模式表示同一个消息会被同一个消费组中的消费者消费一次,消息被负载均衡分配到同一个消费者上的多个实例上。 还有另外一种平均的算法是AllocateMessageQueueAveragelyByCircle&#xff…

element table表格内进行表单验证(简单例子,一看就会,亲测有用~)开箱即用!!

效果图&#xff1a; 代码&#xff1a; <div> <el-form ref"form" :model"form" ><el-table :data"form.tableData" align"center" border><el-table-column label"名称"><template slot-scope&…

国标GB28181安防视频云平台EasyCVR出现持续重启现象,是什么问题?该如何解决?

视频集中存储/云存储/磁盘阵列EasyCVR平台可拓展性强、视频能力灵活、部署轻快&#xff0c;可支持的主流标准协议有国标GB28181、RTSP/Onvif、RTMP等&#xff0c;以及支持厂家私有协议与SDK接入&#xff0c;包括海康Ehome、海大宇等设备的SDK等。平台既具备传统安防视频监控的能…

边缘计算系统设计与实践

随着科技的飞速发展&#xff0c;物联网和人工智能两大领域的不断突破&#xff0c;我们看到了一种新型的计算模型——边缘计算的崛起。这种计算模型在处理大规模数据、实现实时响应和降低延迟需求方面&#xff0c;展现出了巨大的潜力。本文将深入探讨边缘计算系统的设计原理和实…

MySQL5 和 MySQL8 的配置区别 一些注意事项

1、使用命令行查看MySQL的版本 先保证你的mysql正在运行&#xff0c;假如用户名是root&#xff0c;密码是123456&#xff0c;运行下边的代码可以查看mysql的版本号。 mysql -uroot -p123456这里我的版本是5.7.19。也就是5版本的。 2、不同版本对应的数据库驱动jar包&#x…

【docker 】 安装docker(centOS7)

官网 docker官网 github源码 官网 在CentOS上安装Docker引擎 官网 在Debian上安装Docker引擎 官网 在 Fedora上安装Docker引擎 官网 在ubuntu上安装Docker引擎 官网 在RHEL (s390x)上安装Docker引擎 官网 在SLES上安装Docker引擎 最完善的资料都在官网。 卸载旧版本 …

AOP跨模块捕获异常遭CGLIB拦截而继续向上抛出异常

其他系列文章导航 Java基础合集数据结构与算法合集 设计模式合集 多线程合集 分布式合集 ES合集 文章目录 其他系列文章导航 文章目录 前言 一、BUG详情 1.1 报错信息 1.2 接口响应信息 1.3 全局异常处理器的定义 二、排查过程 三、解决方案 四、总结 前言 最近&…

【SpringBoot】入门精简

目录 一、初识 SpringBoot 1.1 介绍 1.2 项目创建 1.3 目录结构 1.4 修改配置 二、SpringBoot 集成 2.1 集成 Mybatis框架 2.2 集成 Pagehepler分页插件 2.3 集成 Druid数据库连接池 2.4 集成 Log日志管理 一、初识 SpringBoot 1.1 介绍 Spring Boot是一个用于简化Sp…

无人零售柜:快捷舒适购物体验

无人零售柜&#xff1a;快捷舒适购物体验 通过无人零售柜和人工智能技术&#xff0c;消费者在购物过程中可以自由选择商品&#xff0c;根据个人需求和喜好查询商品清单。这种自主选择的购物环境能够为消费者提供更加舒适和满意的体验。此外&#xff0c;无人零售柜还具有节约时间…

Python手撕kmeans源码

参考了两篇文章 K-Means及K-Means算法Python源码实现-CSDN博客 使用K-means算法进行聚类分析_kmeans聚类分析结果怎么看-CSDN博客 # 定义kmeans类 from copy import deepcopy from sklearn.datasets import make_blobs import numpy as np import matplotlib.pyplot as pltc…

如何充分准备面试,迅速融入团队并在工作中取得卓越成就

首先&#xff0c;关于如何筹备面试&#xff0c;首先需要对所申请公司与职位进行深入的调查了解&#xff0c;并依据可能提出的面试问题预先准备相应的答案&#xff0c;并提前调试面试所需的仪器设备。同时&#xff0c;也要注重自身形象的塑造。更为关键的是 1. 在计算机领域的面…

搭建你的知识付费小程序平台:源码解析与技术实现

知识付费小程序平台在当今数字化时代扮演着越来越重要的角色&#xff0c;为教育者和学习者提供了一个灵活、便捷的学习环境。本文将以关键词“知识付费小程序源码”为基础&#xff0c;探讨如何搭建一个功能强大的知识付费小程序平台&#xff0c;并提供一些基础的技术代码示例。…

串口通信(1)-硬件知识

本文讲解串口通信的硬件知识。让读者快速了解硬件知识&#xff0c;为下一步编写代码做基础。 目录 一、概述 二、串口通信分类 2.1信息的传送方向进行分类 2.2同步通信和异步通信 三、串口协议 3.1 RS232 3.1.1 电气特性 3.1.2 连接器的机械特性 3.1.3 连接类型 3.1…

08.仿简道云公式函数实战-逻辑函数-IF

1. IF函数 IF 函数可用于判断一个条件能否满足&#xff1b;如果满足返回一个值&#xff0c;如果不满足则返回另外一个值。 2. 函数用法 IF(logical_test&#xff0c;value_if_true, value_if_false) 其中各参数的含义如下&#xff1a; logical_test&#xff1a;必需&#…

JVM虚拟机系统性学习-对象存活判断算法、对象引用类型和垃圾清除算法

垃圾回收 在 JVM 中需要对没有被引用的对象&#xff0c;也就是垃圾对象进行垃圾回收 对象存活判断算法 判断对象存活有两种方式&#xff1a;引用计数法、可达性分析算法 引用计数法 引用计数法通过记录每个对象被引用的次数&#xff0c;例如对象 A 被引用 1 次&#xff0c…

被迫搬家,宽带迁移怎么办?

广州一栋违建烂尾楼&#xff0c;13年里从未停止出租&#xff0c;年年住满人。这栋楼没有贴外墙&#xff0c;裸露的水泥表面都被雨水腐蚀&#xff0c;很多阳台没有建好&#xff0c;只是简单加装了护栏&#xff0c;存在巨大安全隐患。 为什么烂尾楼年年满人呢&#xff1f; 因为它…