Vision Transformer(VIT)模型介绍

news2024/12/26 0:24:01

计算机视觉


文章目录

  • 计算机视觉
  • Vision Transformer(VIT)
  • Patch Embeddings
  • Hybrid Architecture
  • Fine-tuning and higher resolution
  • PyTorch实现Vision Transformer


Vision Transformer(VIT)

Vision Transformer(ViT)是一种新兴的图像分类模型,它使用了类似于自然语言处理中的Transformer的结构来处理图像。这种方法通过将输入图像分解成一组图像块,并将这些块变换为一组向量来处理图像。然后,这些向量被输入到Transformer编码器中,以便对它们进行进一步的处理。ViT在许多计算机视觉任务中取得了与传统卷积神经网络相当的性能,但其在处理大尺寸图像和长序列数据方面具有优势。与自然语言处理(NLP)中的Transformer模型类似,ViT模型也可以通过预训练来学习图像的通用特征表示。在预训练过程中,ViT模型通常使用自监督任务,如图像补全、颜色化、旋转预测等,以无需人工标注的方式对图像进行训练。这些任务可以帮助ViT模型学习到更具有判别性和泛化能力的特征表示,并为下游的计算机视觉任务提供更好的初始化权重。
在这里插入图片描述

Patch Embeddings

Patch embedding是Vision Transformer(ViT)模型中的一个重要组成部分,它将输入图像的块转换为向量,以便输入到Transformer编码器中进行处理。

Patch embedding的过程通常由以下几个步骤组成:

图像切片:输入图像首先被切成大小相同的小块,通常是16x16、32x32或64x64像素大小。这些块可以重叠或不重叠,取决于具体的实现方式。
展平像素:每个小块内的像素被展平成一个向量,以便能够用于后续的矩阵计算。展平的像素向量的长度通常是固定的,与ViT的超参数有关。
投影:接下来,每个像素向量通过一个可学习的线性变换(通常是一个全连接层)进行投影,以便将其嵌入到一个低维的向量空间中。
拼接:最后,所有投影向量被沿着一个维度拼接在一起,形成一个大的二维张量。这个张量可以被看作是输入序列的一个矩阵表示,其中每一行表示一个图像块的嵌入向量。
通过这些步骤,Patch embedding将输入的图像块转换为一组嵌入向量,这些向量可以被输入到Transformer编码器中进行进一步的处理。Patch embedding的设计使得ViT能够将输入图像的局部特征信息编码成全局特征,从而实现了对图像的整体理解和分类。
在这里插入图片描述

Inductive bias
在Vision Transformer(ViT)模型中,也存在着Inductive bias,它指的是ViT模型的设计中所假定的先验知识和偏见,这些知识和偏见可以帮助模型更好地学习和理解输入图像。

ViT的Inductive bias主要包括以下几个方面:

图像切片:ViT将输入图像划分为多个大小相同的块,每个块都是一个向量。这种切片方式的假设是,输入图像中的相邻区域之间存在着相关性,块内像素的信息可以被整合到一个向量中。
线性投影:在Patch embedding阶段,ViT将每个块的像素向量通过线性投影映射到一个较低维度的向量空间中。这种映射方式的假设是,输入图像的特征可以被表示为低维空间中的点,这些点之间的距离可以捕捉到图像的局部和全局结构。
Transformer编码器:ViT的编码器部分采用了Transformer结构,这种结构能够对序列中的不同位置之间的依赖关系进行建模。这种建模方式的假设是,输入图像块之间存在着依赖关系,这些依赖关系可以被利用来提高模型的性能。
通过这些Inductive bias,ViT模型能够对输入图像进行有效的表示和学习。这些假设和先验知识虽然有一定的局限性,但它们可以帮助ViT更好地处理图像数据,并在各种计算机视觉任务中表现出色。

Hybrid Architecture

在ViT中,Hybrid Architecture是指将卷积神经网络(CNN)和Transformer结合起来,用于处理图像数据。Hybrid Architecture使用一个小的CNN作为特征提取器,将图像数据转换为一组特征向量,然后将这些特征向量输入Transformer中进行处理。

CNN通常用于处理图像数据,因为它们可以很好地捕捉图像中的局部和平移不变性特征。但是,CNN对于图像中的全局特征处理却有一定的局限性。而Transformer可以很好地处理序列数据,包括文本数据中的全局依赖关系。因此,将CNN和Transformer结合起来可以克服各自的局限性,同时获得更好的图像特征表示和处理能力。

在Hybrid Architecture中,CNN通常被用来提取局部特征,例如边缘、纹理等,而Transformer则用来处理全局特征,例如物体的位置、大小等。具体来说,Hybrid Architecture中的CNN通常只包括几层卷积层,以提取一组局部特征向量。然后,这些特征向量被传递到Transformer中,以捕捉它们之间的全局依赖关系,并输出最终的分类或回归结果。

相对于仅使用Transformer或CNN来处理图像数据,Hybrid Architecture在一些图像任务中可以取得更好的结果,例如图像分类、物体检测等。

Fine-tuning and higher resolution

在ViT模型中,我们通常使用一个较小的分辨率的输入图像(例如224x224),并在预训练阶段将其分成多个固定大小的图像块进行处理。然而,当我们将ViT模型应用于实际任务时,我们通常需要处理更高分辨率的图像,例如512x512或1024x1024。

为了适应更高分辨率的图像,我们可以使用两种方法之一或两种方法的组合来提高ViT模型的性能:

Fine-tuning: 我们可以使用预训练的ViT模型来初始化网络权重,然后在目标任务的数据集上进行微调。这将使模型能够在目标任务中进行特定的调整和优化,并提高其性能。
Higher resolution: 我们可以增加输入图像的分辨率来提高模型的性能。通过处理更高分辨率的图像,模型可以更好地捕捉细节信息和更全面的视觉上下文信息,从而提高模型的准确性和泛化能力。
通过Fine-tuning和Higher resolution这两种方法的组合,我们可以有效地提高ViT模型在计算机视觉任务中的表现。这种方法已经在许多任务中取得了良好的结果,如图像分类、目标检测和语义分割等。

PyTorch实现Vision Transformer

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets

# 定义ViT模型
class ViT(nn.Module):
    def __init__(self, image_size=224, patch_size=16, num_classes=1000, dim=768, depth=12, heads=12, mlp_dim=3072):
        super(ViT, self).__init__()
        
        # 输入图像分块
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_dim = 3 * patch_size ** 2
        self.proj = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        
        # Transformer Encoder
        self.transformer_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim), num_layers=depth)
        
        # MLP head
        self.layer_norm = nn.LayerNorm(dim)
        self.fc = nn.Linear(dim, num_classes)
        
    def forward(self, x):
        # 输入图像分块
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        
        # Transformer Encoder
        x = self.transformer_encoder(x)
        
        # MLP head
        x = self.layer_norm(x.mean(1))
        x = self.fc(x)
        
        return x

# 加载CIFAR-10数据集
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

# 实例化ViT模型
model = ViT()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# 训练模型
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    # 训练模式
    model.train()
    train_loss = 0.0
    train_acc = 0.0
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 统计训练损失和准确率
        train_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        train_acc += torch.sum(preds == labels.data)
        
    train_loss = train_loss / len(train_loader.dataset)
    train_acc = train_acc

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1395339.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

算法笔记 #3

2024年1月18日 Q1:搜索二叉树 A:查找,左子树比根几点小,右子树比根大。 删除:1)搜索删除的目标节点,记录其父节点; 2)左右孩子都为空,直接删掉;…

【2023地理设计组一等奖】城市业态基因图谱与城市风格感知

作品介绍 1 设计思路 1.1 作品背景 随着城市化进程不断推进,城市的餐饮、购物、休闲和文化业态成为城市发展中的重要组成部分,深刻影响着城市经济社会发展和居民生活,塑造了城市的外在形象和内在风格,形成了各具特色的城市发展特征。因此,在区域和全国尺度上研究城市业态…

TDengine 创始人陶建辉在汽车 CIOCDO 论坛发表演讲,助力车企数字化转型

当前,汽车行业的数字化转型如火如荼。借助数字技术的充分利用,越来越多的车企进一步提升了成本优化、应用敏捷性、高度弹性和效率。这一转型使得业务应用的开发和管理模式发生了颠覆性的创新,赋予了汽车软件快速响应变化和动态调度资源的能力…

深度学习(1)--基础概念

一.计算机视觉(CV) (1).计算机视觉中图像表示为三位数组,其中三维数组中像素的值为0~255,像素的值越低表示该点越暗,像素的值越高表示该点越亮。 (2).图像表示 A*B*C,其中A,B分别为图像的长和宽,C则表示图像的颜色通道…

Spring Boot 单体应用升级 Spring Cloud 微服务

作者:刘军 Spring Cloud 是在 Spring Boot 之上构建的一套微服务生态体系,包括服务发现、配置中心、限流降级、分布式事务、异步消息等,因此通过增加依赖、注解等简单的四步即可完成 Spring Boot 应用到 Spring Cloud 升级。 *Spring Cloud …

vim 编辑器如何同时注释多行以及将多行进行空格

一、场景 YAML文件对空格的要求非常严格,因此在修改YAML时,我们可能需要批量添加空格。 二、操作步骤 请注意:您的所有操作都将以第一行为基准。也就是说,第一行有多少个空格,下面的行就会模仿添加相同数量的空格。…

一款开源且不限制大小可以设置过期时间的支持分享的的开源文件共享系统picoshare 部署教程

1.拉取镜像 2.部署 创建目录 mkdir -p /opt/picoshare/data 部署 其中:"somesecretpass"是密码 docker run \--env "PORT4001" \--env "PS_SHARED_SECRETsomesecretpass" \--publish 10005:4001/tcp \--volume "/opt/picoshare/data:…

[Android] Android架构体系(2)

文章目录 Bionic精简对系统调用的支持:不支持 System V IPC:有限的 Pthread 功能:有限支持C:不再支持本地化和/或宽字符:Bionic新增的特性系统属性硬编码写死的UID/GID内置了DNS解析硬编码写死的服务和协议 硬件抽象层Linux内核匿名共享内存(ASHMem)Binder-BinderLoggerION 内存…

找不到满意的机器视觉工程师工作,想找到满意的工作很难,人太多了,只能先就业后择业

找不到满意的机器视觉工程师工作,想找到满意的工作很难,人太多了,只能先就业后择业

ICCV2023 | PTUnifier+:通过Soft Prompts(软提示)统一医学视觉语言预训练

论文标题:Towards Unifying Medical Vision-and-Language Pre-training via Soft Prompts 代码:https://github.com/zhjohnchan/ptunifier Fusion-encoder type和Dual-encoder type。前者在多模态任务中具有优势,因为模态之间有充分的相互…

Docker(二)安装指南

作者主页: 正函数的个人主页 文章收录专栏: Docker 欢迎大家点赞 👍 收藏 ⭐ 加关注哦! 安装 Docker Docker 分为 stable test 和 nightly 三个更新频道。 官方网站上有各种环境下的 安装指南,这里主要介绍 Docker 在…

Vue学习笔记9--vuex(专门在Vue中实现集中式状态(数据)管理的一个Vue插件)

一、vuex是什么? 概念:专门在Vue中实现集中式状态(数据)管理的一个Vue插件,对vue应用中多个组件的共享状态进行集中式的管理(读/写),也是一种组件间通信的方式,且适用于…

LeetCode、2462. 雇佣 K 位工人的总代价【中等,最小堆+双指针】

文章目录 前言LeetCode、2462. 雇佣 K 位工人的总代价【中等,最小堆双指针】题目及类型思路及代码实现 资料获取 前言 博主介绍:✌目前全网粉丝2W,csdn博客专家、Java领域优质创作者,博客之星、阿里云平台优质作者、专注于Java后…

rust获取本地外网ip地址的方法

大家好,我是get_local_info作者带剑书生,这里用一篇文章讲解get_local_info的使用。 get_local_info是什么? get_local_info是一个获取linux系统信息的rust三方库,并提供一些常用功能,目前版本0.2.4。详细介绍地址&a…

【踩坑日志】SpringBoot读取nacos配置信息并提取信息中的IP地址(配置属性解析异常+排错记录)

缘起 :项目需读取nacos中动态的TDengine数据库连接信息并提取IP,一个并不复杂的操作,但作为一个nacos知识浅薄的菜鸡,我愣是捯饬了几个小时……惭愧惭愧…… 异常代码 Data Component public class TaosLink { // Value("…

在PyCharm中创建Flask项目

在 PyCharm 中创建 Flask 项目的步骤如下: 打开 PyCharm,并选择 "Create New Project"(新建项目)。在弹出的窗口中,选择左侧的 "Python" 选项,然后选择右侧的 "Flask" 项目…

【JavaEE】_网络通信原理

目录 1. 网络发展史 2. 网络通信基础 1.1 IP地址 1.2 端口号 1.3 协议 1.3.1 概念 1.3.2 五元组 1.4 协议分层 1.4.1 协议分层的优点 1.4.2 协议分层的分类 1.4.3网络设备所在分层 1.4.4 两台主机通过TCP/IP协议通讯过程 1.5 封装与分用 1.5.1 封装 1.5.2 分用…

【EI会议征稿通知】第四届工业制造与结构材料国际学术会议(IMSM 2024)

第四届工业制造与结构材料国际学术会议(IMSM 2024) 2024 4th International Conference on Industrial Manufacturing and Structural Materials(IMSM 2024) 第四届工业制造与结构材料国际学术会议(IMSM 2024&#x…

TypeScript依赖注入框架Typedi的使用、原理、源码解读

简介 typedi是一个基于TS的装饰器和reflect-metadata的依赖注入轻量级框架,使用简单易懂,方便拓展。 使用typedi的前提是安装reflect-metadata,并在项目的入口文件的第一行中声明import ‘reflect-metadata’,这样就会在原生的R…

【图解数据结构】深度解析时间复杂度与空间复杂度的典型问题

🌈个人主页:聆风吟 🔥系列专栏:图解数据结构、算法模板 🔖少年有梦不应止于心动,更要付诸行动。 文章目录 一. ⛳️上期回顾二. ⛳️常见时间复杂度计算举例1️⃣实例一2️⃣实例二3️⃣实例三4️⃣实例四5…