Vision Transform—用于大规模图像分类的Transformers架构

news2025/1/12 3:01:04

VIT — 用于大规模图像识别的 Transformer

论文题目:AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE。
官方代码:https://github.com/google-research/vision_transformer

引言与概述

Vision Transformer(ViT)是一种基于注意力机制的深度学习模型,它是由 Google 提出的,旨在将Transformer架构应用到计算机视觉任务中。它的提出证明了: Transformer 在CV领域的可行性。

在这里插入图片描述

Transformers在之前可以和CNN结合起来用增强特征图的方式用于分类和检测。

  1. 而在vision Transform中没有使用到深度的2d卷积,完全使用Transform中的多头自注意力机制,与Encode结构进行实现的

  2. 输入的patches的结构经过一个vit块后的结构保持不变,因此可以用来直接对vit块进行N次的堆叠操作。

Vision Transformer(ViT)模型架构是一种基于 Transformer 架构的深度学习模型,用于处理计算机视觉任务(最主要的是会应用到分类的任务中)。下面是 Vision Transformer 的模型框架。

在这里插入图片描述

  • Linear Projection of Flattened Patches模块(Embedding层):将输入的图像数据转换为可以输入到 Transformer 编码器中的序列化表示,其中包括了patch+position+learnable embedding。

  • Transformer encoder (Transformer编码层):向量表示被输入到 Transformer 编码器中。每个 Transformer 编码器包含多头自注意力机制和前馈神经网络,用于捕捉全局信息和学习特征表示。这一部分是 Vision Transformer 中最关键的组件。

  • MLP Head(用于分类的全连接层):在经过一系列 Transformer 编码器之后,模型的输出会被送入一个包含多层感知机(MLP)的输出层中,用于最终的分类或其他计算机视觉任务。

处理流程

  1. 输入阶段: 首先将输入的原始图像按照给定大小切分成固定大小的图像块(patches),每个图像块包含图像中的局部信息。 每个图像块通过一个线性变换(通常是一个卷积层)映射到一个低维的特征空间,得到Patch Embeddings。同时,为每个Patch Embedding加入位置编码(Positional Embedding)和可学习嵌入(learnable Embedding) 以综合考虑图像的空间位置和全局信息。

  2. 输入到Transformer: 将经过嵌入层处理的序列化表示作为输入,传入到多层Transformer Encoder来对序列化表示进行处理。

  3. 输出分类结果: 经过一系列Transformer编码器的处理后,模型的最后一层输出向量经过全连接层或其他分类层进行分类任务。

图像处理与前向传播

在论文中提到过,在Vit Transform提出之前,其实就有人使用过将图像中的每一个像素看作是一个序列数据,本质上就是将图像的像素数据进行一个展平操作=WxHx3个像素,就会导致图像得到的序列数据过长从而影响整个的运算。

作者就针对这一个问题提出了使用Transformer receives as input a 1D
sequence of token embeddings. To handle 2D images, we reshape the image x ∈ RH×W×C into asequence of flattened 2D patches xp ∈ RN×(P2·C)(使用patch方法将2d的数据转换为1d的数据,但是是切分为多个小的patch块进行的。

Patch Embedding

Patch Embedding 是指将输入的图像划分为固定大小的图像块(patches)后,将每个图像块映射成一个向量表示,最终所有的图像块被变换成满足transformer输入的一维表示

将输入为224x224的图像,切分为196个16x16的Patch图像块。
在这里插入图片描述
我们从50176个token转变为了196个token从而降低了序列的长度。

  1. 图像划分为图像块: 输入的原始图像被切分成大小相同的图像块,每个图像块通常是不重叠的固定大小的方形区域。这样做的目的是为了将图像信息分割为局部区域,使得模型能够处理不同尺寸的图像。 例:以ViT-B/16为例,将输入图片(224x224)按照16x16大小的Patch进行划分,划分后会得到 ( 224 / 16 ) 2 = 196 ( 224 / 16 )^2 = 196 (224/16)2=196个Patches,每个Patche数据shape为[16, 16, 3]。

  2. 映射为向量表示: 对于每个图像块,通过一个线性变换(一个卷积层)将其映射成一个一维特征向量,也称为 Patch Embedding。这个过程可以理解为将图像块中的像素信息转换为一个固定维度的向量,以表示该图像块的特征。 例:每个Patche数据通过映射得到一个长度为768的向量,即[16, 16, 3] -> [768]

在这里插入图片描述

  1. 串联所有 Patch Embedding: 将所有图像块经过 Patch Embedding 后得到的向量表示串联在一起,形成一个序列化的特征矩阵。这个矩阵作为Transformer的输入,传入Transformer编码器进行处理。 例:将196个Patchs串联起来,最终组成[196,768]的二维token向量,token的个数是196,token维度是768

Class token + Positional Embedding

Class token
在ViT模型的原论文中,作者模仿BERT模型,为ViT 模型中引入一个专门用于分类的 [class] token。这个 [class] 是一个可训练的参数,其数据格式和其他 token 一样都是一个向量,例如在 ViT-B/16 模型中,这个向量的长度是768

其中引入的class token代表的是一个可以学习的张量。

引入了class token之后就需要和经过Linear Projection of Flattened Patches展平之后得到的向量进行一个cat的操作。

在这里插入图片描述

下面给出运算步骤的一个示意图。‘

在这里插入图片描述
我们最后得到的信息就是,批次数 + (196+1)[经过cat连接操作] +768(论文中D的长度)

向量拼接: 将 [class] token 的向量与之前从图片中生成的 tokens 拼接在一起。例如,原先有 196 个 tokens,每个 token 长度为 768,则token embeddding变为Cat([1, 768], [196, 768]) -> [197, 768],注意, [ c l a s s ] [class] [class] token插在开头。

之后我们还需要加入位置编码就可以对位置信息进行编码从而输入序列数据进入Transform的多头自注意力机制模块。

在这里插入图片描述

ViT的Position Embedding采用的是一个可学习/训练的 1-D 位置编码嵌入,是直接叠加在tokens上的(add), 例: position embedding的shape应该是[197,768],直接position embedding + [ c l a s s ] [class] [class] token embedding=[197,768]

相加的操作不改变整体的一个维度信息。

在这里插入图片描述

使用什么样的位置编码对结果的影响不是特别的大。

Transformer encode

Transformer encoder在代码中有时是打包在Block里面的。

Transformer Encoder 是用来处理输入序列的部分,它由多个 Transformer Block 组成。每个 Transformer Block 都包括两个主要组件:多头自注意力层和前馈神经网络。

在这里插入图片描述

包括了两个残差连接的部分所组成。

  • Multi-head Self-Attention Layer: 在每个 Transformer Block 中,输入特征首先被送入一个多头自注意力层。这个层用来计算输入序列中每个位置对应的注意力权重,以捕捉不同位置之间的关系。

  • Feed-Forward Neural Network: 在经过多头自注意力层后,每个位置的特征会通过一个前馈神经网络进行处理。这个前馈神经网络通常由两个全连接层和激活函数组成,用来对每个位置的特征进行非线性变换和映射

  • Residual Connection & Layer Normalization: 在每个 Transformer Block 的多头自注意力层和前馈神经网络中都会包含残差连接和层归一化操作,这有助于缓解梯度消失问题和加速训练过程

多个Transformer encode模块进行堆叠,ViT 模型能够有效地学习输入图像的复杂特征和结构信息。(代码中好像是连续堆叠三次)多头注意力部分使用之前的类进行单独的定义。

论文中给出的一个前向传播的基本的公式形式。

在这里插入图片描述

等式 1:由 图像块嵌入 x p i E x_p^iE xpiE,类别向量 x c l a s s x_{class} xclass,位置编码 E p o s E_{pos} Epos构成 输入向量 z 0 z_0 z0
等式 2:由 多头注意力机制、层归一化和跳跃连接构成的 MSA Block,可重复L个,最后一个输出为 z l ‘ z_l^{‘} zl
等式 3:由 前馈网络、层归一化 和 跳跃连接构成的 MLP Block,可重复L个,最后一个输出为 z l z_l zl
等式 4:由 层归一化 和 分类头 (MLP or FC) 输出 图像表示 y y y

MLP Head

在这里插入图片描述

先进行一个4倍的扩增,之后在通过一个线性层进行下采样的一个操作,使得输入输出的大小保持不变。(隐藏层拓为4倍)

注意和之前还有做一次残差连接的操作。

MLP Head 是指位于模型顶部的全连接前馈神经网络模块,用于将提取的图像特征表示转换为最终的分类结果或其他预测任务输出。MLP Head 通常跟在 Transformer Encoder 的输出之后,作为整个模型的最后一层。

具体来说,当我们只需要分类信息时,只需要提取出 [ c l a s s ] [class] [class]token生成的对应结果就行,即[197, 768]中抽取出[class] token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。

在这里插入图片描述
取出class token对应的那一部分代码信息,(197-1)

维度变换总结

在这里插入图片描述

总结:

  1. 输入图像的input shape=[1,3,224,224],1是batch_size,3是通道数,224是高和宽输入图像经过patch Embedding,其中Patch大小是14,卷积核是768,则经过分块后,获得的块数量是 196,每个块的维度被转换为768,即得到的patch embedding的shape=[1,196,768]

  2. 将可学习的[class] token embedding拼接到patch embedding前,得到shape=[1,197,768]

  3. 将position embedding加入到拼接后的embedding中,组成最终的输入嵌入,最终的输入shape=[1,197,768]

  4. 输入嵌入送入到Transformer encoder中,shape并不发生变化

  5. 最后transformer的输出被送入到MLP或FC中执行分类预测,选取[class] token作为分类器的输入,以表示整个图像的全局信息,假设分类的类目为K,最终的shape=[1,768]*[768,K]=[1,K] K也就是分类的类别数目。

最后总结的VIT-b/16模型所对应的结构示意图如图所示。

在这里插入图片描述

论文中还给出了通过训练之后得到的位置编码直接的余弦相似度的示意图。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2150549.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

虚拟机vaware中cpu设置跑满大核

首先,大核速度快,并且在资源紧张时大核优先,小核甚至是闲着围观大核跑满。其次,遇到经常切换操作虚拟机和win11的使用场景,切换核心本身也会造成一点卡顿,降低虚拟机里操作流畅度。另外,13代在你…

【linux】4张卡,坏了1张,怎么办?

先禁用这张卡 grub 禁用,防止加载驱动 禁用这张卡的 PCI # 禁用 PCI 设备 0000:b1:00.0 (NVIDIA GPU) ACTION"add", SUBSYSTEM"pci", ATTR{vendor}"0x10de", KERNELS"0000:b1:00.0", RUN"/bin/sh -c echo 0000:b1:00…

vue part 10

vue-resource 在vue1.0时代讲的比较多,是vue.插件库, import vueResource from vue-resourceVue.use(vueResource) 在vc和vm中会多出如下F12代码即,$http:() 他的用法和返回值和axios一模一样,但是不常维护了 插槽 默认插槽 …

11年计算机考研408-数据结构

设执行了k次。 解析: d要第一个出,那么abc先入栈,d入栈然后再出栈,这前面是一个固定的流程,后面就很灵活了,可以ecba,ceba,cbea,cbae。 答案是4个序列。 解析&#xff1a…

解决redis缓存击穿问题之布隆过滤器

布隆过滤器 1. 什么是布隆过滤器 布隆过滤器(Bloom Filter)是一个空间效率很高的数据结构,用于判断一个元素是否在一个集合中。布隆过滤器的核心思想是利用位数组和一系列随机映射函数(哈希函数)来快速判断某个元素是…

基于SpringBoot+Vue+MySQL的网上租赁系统

系统展示 用户前台界面 管理员后台界面 系统背景 在当前共享经济蓬勃发展的背景下,网上租赁系统作为连接租赁双方的重要平台,正逐步改变着人们的消费观念和生活方式。通过构建一个基于SpringBoot、Vue.js与MySQL的网上租赁系统,我们旨在为用户…

LangChain 和 Elasticsearch 加速构建 AI 检索代理

作者:来自 Elastic Joe McElroy, Aditya Tripathi, Serena Chou Elastic 和 LangChain 很高兴地宣布发布新的 LangGraph 检索代理模板,旨在简化需要代理使用 Elasticsearch 进行代理检索的生成式人工智能 (GenAI) 代理应用程序的开发。此模板预先配置为使…

基于机器学习的癌症数据分析与预测系统实现,有三种算法,bootstrap前端+flask

研究背景 癌症作为全球范围内最主要的死亡原因之一,已成为当代医学研究和公共健康的重大挑战。据世界卫生组织(WHO)的统计,癌症每年导致全球数百万人的死亡。随着人口老龄化、环境污染和生活方式的改变,癌症的发病率逐…

Pytorch学习---基于经典网络架构ResNet训练花卉图像分类模型

基于经典网络架构训练图像分类模型 导包 import copy import json import time import torch from torch import nn import torch.optim as optim import torchvision import os from torchvision import transforms, models, datasets import numpy as np import matplotlib.…

【使用Hey对vllm接口压测】模型并发能力

使用Hey对vllm进行模型并发压测 docker run --rm --networkknowledge_network \registry.cn-shanghai.aliyuncs.com/zhph-server/hey:latest \-n 200 -c 200 -m POST -H "Content-Type: application/json" \-H "Authorization: xxx" \-d {"model"…

【类型黑市】指针

大家好我是#Y清墨,今天我要介绍的是指针。 意义 指针就是存放内存地址的变量。 分类 因为变量本身是分类型的,我们学过的变量类型有 int, long long, char, double, string, 甚至还有结构体变量。 同样,指针也分类型,如果指针指向…

云韧性,现代云服务不可或缺的组成部分

韧性,一个物理学概念,表示材料在变形或者破裂过程中吸收能量的能力。韧性越好,则发生脆性断裂的可能性越小。 如今,韧性也延伸到企业特质、产品特征等之中,用于形容企业、产品乃至服务的优劣。同样,随着云…

3. Internet 协议的安全性

3. Internet 协议的安全性 (1) 常用网络协议的功能、使用的端口及安全性 HTTP协议 功能:用于从服务器传输超文本到本地浏览器。端口:默认是80端口。安全性:不提供数据加密,存在数据泄露和中间人攻击风险。使用HTTPS协议(443端口)可以增强安全性。FTP协议 功能:实现文件的…

电脑录课软件哪个好用,提高教学效率?电脑微课录屏软件推荐

在当今这个数字化时代,教育领域也迎来了翻天覆地的变化。随着远程教学和在线学习的普及,教师们开始寻求更高效、更便捷的教学工具来提升教学质量和学生的学习体验。电脑录课软件,作为现代教育技术的重要组成部分,能够帮助教师轻松…

【CPP】类与继承

14 类与继承 在前面我们提到过继承的一些概念,现在我们来回顾一下 打个比方:在CS2中我们把玩家定义为一个类 class 玩家: 血量:100阵营(未分配)服饰(未分配)位置(未分配)武器(未分配)是否允许携带C4(未分配)是否拥有C4(未分配) 当对局创建时,会新生成两个类,这两个类继承自&qu…

【Linux庖丁解牛】—Linux基本指令(上)!

🌈个人主页:秋风起,再归来~🔥系列专栏: Linux庖丁解牛 🔖克心守己,律己则安 目录 1、 pwd命令 2、ls 指令 3、cd 指令 4、Linux下的根目录 5、touch指令 6、 stat指令 7、mkdi…

LabVIEW提高开发效率技巧----采用并行任务提高性能

在复杂的LabVIEW开发项目中,合理利用并行任务可以显著提高系统的整体性能和响应速度。并行编程是一种强大的技术手段,尤其适用于实时控制、数据采集以及多任务处理等场景。LabVIEW的数据流编程模型天然支持并行任务的执行,结合多核处理器的硬…

OrCAD使用,快捷键,全选更改封装,导出PCB网表

1 模块名称 2 快捷键使用 H: 镜像水平 V:镜像垂直 R: 旋转 I: 放大 O: 放小 P:放置元器件 W: 步线 B: 总线(无电气属性) E: 总线连接符(和BUS一起用&#xff09…

【网络通信基础与实践第四讲】用户数据报协议UDP和传输控制协议TCP

一、UDP的主要特点 1、UDP是无连接的,减少了开销和发送数据之前的时延 2、UDP使用尽最大努力交付,但是不保证可靠交付 3、UDP是面向报文的。从应用层到运输层再到IP层都只是添加一个相应的首部即可 4、UDP没有拥塞机制,源主机以恒定的速率…

基于JAVA+SpringBoot+Vue的学生干部管理系统

基于JAVASpringBootVue的学生干部管理系统 前言 ✌全网粉丝20W,csdn特邀作者、博客专家、CSDN[新星计划]导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末附源码下载链接🍅 哈…