计算机视觉之Vision Transformer图像分类

news2024/11/15 12:17:49

Vision Transformer(ViT)简介

自注意结构模型的发展,特别是Transformer模型的出现,极大推动了自然语言处理模型的发展。Transformers的计算效率和可扩展性使其能够训练具有超过100B参数的规模空前的模型。ViT是自然语言处理和计算机视觉的结合,能够在图像分类任务上取得良好效果,而不依赖卷积操作。

Vision Transformer(ViT)简介

近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前规模的模型。

ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

模型结构

ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:

vit-architecture

模型特点

ViT模型是一种用于图像分类的模型,将原图像划分为多个图像块,然后将这些图像块转换为一维向量,加上类别向量和位置向量作为模型输入。模型主体采用基于Transformer的Encoder结构,但调整了Normalization的位置,其中最主要的结构是Multi-head Attention。模型在Blocks堆叠后接全连接层,使用类别向量的输出进行分类,通常将全连接层称为Head,Transformer Encoder部分称为backbone。

Transformer基本原理

Transformer模型源于2017年的一篇文章[2]。在这篇文章中提出的基于Attention机制的编码器-解码器型结构在自然语言处理领域获得了巨大的成功。模型结构如下图所示:

transformer-architecture

模型训练

模型训练前需要设定损失函数、优化器、回调函数等,以及建议根据项目需要调整epoch_size。训练ViT模型需要很长时间,可以通过输出的信息查看训练的进度和指标。

from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train

# define super parameter
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()

# construct model
network = ViT()

# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"

vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)

# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)


# define loss function
class CrossEntropySmooth(LossBase):
    """CrossEntropy."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss


network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")

# train model
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False,)

总结

本案例演示了如何在ImageNet数据集上训练、验证和推断ViT模型。通过讲解ViT模型的关键结构和原理,帮助用户理解Multi-Head Attention、TransformerEncoder和pos_embedding等关键概念。建议用户基于源码深入学习,以更详细地理解ViT模型的原理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1925211.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第零章 HCIA复习

目录 HCIA复习 一、OSI七层模型 二、TCP/UDP协议 传输层协议 TCP/IP协议簇 封装/解封装 模型区别 协议号和类型字段 类型字段 协议号 UDP协议头部: TCP协议头部: IP报文参数: 三、DHCP协议 定义 PC端初次获取IP地址 交换机转…

Linux vim的使用(一键安装则好用的插件_forcpp),gcc的常见编译链接操作

vim 在Linux系统上vim是个功能还比较完善的软件。但是没装插件的vim用着还是挺难受的,所以我们直接上一款插件。 我们只需要在Linux上执行这个命令就能安装(bite提供的) curl -sLf https://gitee.com/HGtz2222/VimForCpp/raw/master/install.sh -o ./install.sh …

基于1bitDAC的MU-MIMO的非线性预编码算法matlab性能仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 4.1 基于1-bit DAC的非线性预编码背景 4.2 ZF(Zero-Forcing) 4.3 WF(Water-Filling) 4.3 MRT(Maximum Ratio Transmission&…

昇思25天学习打卡营第09天|保存与加载

在训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型推理与部署,本章节我们将介绍如何保存与加载模型。 import numpy as np import mindspore from mindspore import nn fr…

【云岚到家】-day05-6-项目迁移-门户-CMS

【云岚到家】-day05-6-项目迁移-门户-CMS 4 项目迁移-门户4.1 迁移目标4.2 能力基础4.2.1 缓存方案设计与应用能力4.2.2 静态化技术应用能力 4.3 需求分析4.3.1 界面原型 4.4 系统设计4.4.1 表设计4.4.2 接口与方案4.4.2.1 首页信息查询接口4.4.3.1 数据缓存方案4.4.3.2 页面静…

C++相关概念和易错语法(20)(赋值兼容转换、多继承、继承与组合)

1.赋值兼容转换 赋值兼容转换有一点易混&#xff0c;先看一下下面的代码&#xff0c;想想a、b、c对象里面存的什么&#xff0c;顺便结合以前的知识&#xff0c;对继承加深理解。 #include <iostream> using namespace std;class A { protected:A(int a):_a(a){}int _a; …

PostgreSQL 中如何解决因频繁的小事务导致的性能下降?

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01;&#x1f4da;领书&#xff1a;PostgreSQL 入门到精通.pdf 文章目录 PostgreSQL 中解决因频繁小事务导致性能下降的方法 PostgreSQL 中解决因频繁小事务导致性能下降的方法…

使用机器学习 最近邻算法(Nearest Neighbors)进行点云分析 (scikit-learn Open3D numpy)

使用 NearestNeighbors 进行点云分析 在数据分析和机器学习领域&#xff0c;最近邻算法&#xff08;Nearest Neighbors&#xff09;是一种常用的非参数方法。它广泛应用于分类、回归和聚类分析等任务。下面将介绍如何使用 scikit-learn 库中的 NearestNeighbors 类来进行点云数…

C++ | Leetcode C++题解之第233题数字1的个数

题目&#xff1a; 题解&#xff1a; class Solution { public:int countDigitOne(int n) {// mulk 表示 10^k// 在下面的代码中&#xff0c;可以发现 k 并没有被直接使用到&#xff08;都是使用 10^k&#xff09;// 但为了让代码看起来更加直观&#xff0c;这里保留了 klong l…

优化器算法

优化器算法 梯度下降算法 首先引用动手学深度学习中对梯度下降算法的直观理解与推导。说明了不断的迭代可能会使得f(x)的值不断下降&#xff0c;从直观上解释了梯度下降的可能性。 将损失函数在x点处一阶泰勒展开。 f ( x ϵ ) f ( x ) ϵ f ′ ( x ) O ( ϵ 2 ) . f(x\eps…

在InternStudio上创建一台GPU服务器

填写配置 创建完成 ssh连接&#xff0c;并测试常用指令 查看开发机信息 查看gpu信息 创建conda环境 跑个test

可重入锁深入学习(有码)

【摘要】 ​今天&#xff0c;梳理下java中的常用锁&#xff0c;但在搞清楚这些锁之前&#xff0c;先理解下 “临界区”。临界区在同步的程序设计中&#xff0c;临界区段活称为关键区块&#xff0c;指的是一个访问共享资源&#xff08;例如&#xff1a;共享设备或是共享存储器&a…

9. Python的魔法函数

Python中的魔法函数 在Python中魔法函数是在为类赋能&#xff0c;使得类能够有更多操作。通过重写类中的魔法函数&#xff0c;可以完成很多具体的任务 1. __str__ 通过str魔法函数&#xff0c;可以设置对类的实例的 print() 内容 2. __len__ 通过len魔法函数&#xff0c;可…

tessy 集成测试:小白入门指导手册

目录 1,创建集成测试模块且分析源文件 2,设置测试环境 3,TIE界面设置相关函数 4,SCE界面增加用例 5,编辑数据 6,用例所对应的测试函数序列 7,添加 work task 函数 8,为测试场景添加函数 9,为函数赋值 10,编辑时间序列的数值 11,执行用例 12,其他注意事项…

计算机毕设:服装购物管理系统(Java+Springboot+MySQL+Tomcat),完整源代码+数据库+毕设文档+部署说明

本文关键字&#xff1a;Java编程&#xff1b;Springboot框架&#xff1b;毕业设计&#xff1b;毕设项目&#xff1b;编程实战&#xff1b;医护人员管理系统&#xff1b;项目源代码&#xff1b;程序数据库&#xff1b;毕设文档&#xff1b;项目部署说明&#xff1b; 一、项目说…

Java中JUC包详解

文章目录 J.U.C.包LockReadWriteLockLockSupportAQSReentrantLock对比synchronized加锁原理释放锁原理 CountDownLatchCyclicBarrierSemaphore J.U.C.包 java.util.concurrent&#xff0c;简称 J.U.C.。是Java并发工具包&#xff0c;提供了在多线程编程中常用的工具类和框架&a…

实战检验:Orange Pi AIpro AI开发板的性能测试与使用体验

文章目录 前言Orange Pi AIpro 简介Orange Pi AIpro 体验将Linux镜像烧录到TF卡YOLO识别视频中物体肺部CT识别 Orange Pi AIpro 总结 前言 Orange Pi AIpro&#xff0c;作为首款基于昇腾技术的AI开发板&#xff0c;它集成了高性能图形处理器&#xff0c;配备8GB/16GB LPDDR4X内…

MySQL复合查询(重点)

前面我们讲解的mysql表的查询都是对一张表进行查询&#xff0c;在实际开发中这远远不够。 基本查询回顾 查询工资高于500或岗位为MANAGER的雇员&#xff0c;同时还要满足他们的姓名首字母为大写的J mysql> select * from emp where (sal>500 or jobMANAGER) and ename l…

强化学习:bellman方程求解state value例题

最近在学习强化学习相关知识&#xff0c;强烈推荐西湖大学赵世钰老师的课程&#xff0c;讲解的非常清晰流畅&#xff0c;一路学习下来令人身心大爽&#xff0c;感受数学抽丝剥茧&#xff0c;化繁为简的神奇魅力。 bellman方程还是比较容易理解的&#xff1a;当前状态下的state …

嵌入式linux系统中GDB调试器详解

前言 GDB全称GNU symbolic debugger,它是诞生于GNU开源组织的(同时诞生的还有 GCC、Emacs 等)UNIX及UNIX-like下的调试工具,是Linux下最常用的程序调试器,GDB 支持调试多种编程语言编写的程序,包括C、C++、Go、Objective-C、OpenCL、Ada 等。但是在实际应用中,GDB 更常…