昇思25天学习打卡营第9天|MindSpore-Vision Transformer图像分类

news2024/12/25 0:35:25

Vision Transformer图像分类

Vision Transformer(ViT)简介

近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前规模的模型。

ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

模型结构

ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:

模型特点

ViT模型主要应用于图像分类领域。因此,其模型结构相较于传统的Transformer有以下几个特点:

  1. 数据集的原图像被划分为多个patch(图像块)后,将二维patch(不考虑channel)转换为一维向量,再加上类别向量与位置向量作为模型输入。
  2. 模型主体的Block结构是基于Transformer的Encoder结构,但是调整了Normalization的位置,其中,最主要的结构依然是Multi-head Attention结构。
  3. 模型在Blocks堆叠后接全连接层,接受类别向量的输出作为输入并用于分类。通常情况下,我们将最后的全连接层称为Head,Transformer Encoder部分为backbone。

下面将通过代码实例来详细解释基于ViT实现ImageNet分类任务。

注意,本教程在CPU上运行时间过长,不建议使用CPU运行。

环境准备与数据读取¶

开始实验之前,请确保本地已经安装了Python环境并安装了MindSpore。

首先我们需要下载本案例的数据集,可通过http://image-net.org下载完整的ImageNet数据集,本案例应用的数据集是从ImageNet中筛选出来的子集。

运行第一段代码时会自动下载并解压,请确保你的数据集路径如以下结构。

.dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/

实践环境安装

Python环境为 3.9.19

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

# or 
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple mindspore==2.2.14

数据准备

from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"

path = download(dataset_url, path, kind="zip", replace=True)


import os

import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms


data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

模型解析

下面将通过代码来细致剖析ViT模型的内部结构。

Transformer基本原理

Transformer模型源于2017年的一篇文章[2]。在这篇文章中提出的基于Attention机制的编码器-解码器型结构在自然语言处理领域获得了巨大的成功。模型结构如下图所示:

其主要结构为多个Encoder和Decoder模块所组成,其中Encoder和Decoder的详细结构如下图[2]所示:

Encoder与Decoder由许多结构组成,如:多头注意力(Multi-Head Attention)层,Feed Forward层,Normaliztion层,甚至残差连接(Residual Connection,图中的“Add”)。不过,其中最重要的结构是多头注意力(Multi-Head Attention)结构,该结构基于自注意力(Self-Attention)机制,是多个Self-Attention的并行组成。

所以,理解了Self-Attention就抓住了Transformer的核心。

Attention模块

以下是Self-Attention的解释,其核心内容是为输入向量的每个单词学习一个权重。通过给定一个任务相关的查询向量Query向量,计算Query和各个Key的相似性或者相关性得到注意力分布,即得到每个Key对应Value的权重系数,然后对Value进行加权求和得到最终的Attention数值。

在Self-Attention中:

1. 最初的输入向量首先会经过Embedding层映射成Q(Query),K(Key),V(Value)三个向量,由于是并行操作,所以代码中是映射成为dim x 3的向量然后进行分割,换言之,如果你的输入向量为一个向量序列(𝑥1,𝑥2,𝑥3),其中的𝑥1,𝑥2,𝑥3都是一维向量,那么每一个一维向量都会经过Embedding层映射出Q,K,V三个向量,只是Embedding矩阵不同,矩阵参数也是通过学习得到的。这里大家可以认为,Q,K,V三个矩阵是发现向量之间关联信息的一种手段,需要经过学习得到,至于为什么是Q,K,V三个,主要是因为需要两个向量点乘以获得权重,又需要另一个向量来承载权重向加的结果,所以,最少需要3个矩阵。

2. 自注意力机制的自注意主要体现在它的Q,K,V都来源于其自身,也就是该过程是在提取输入的不同顺序的向量的联系与特征,最终通过不同顺序向量之间的联系紧密性(Q与K乘积经过Softmax的结果)来表现出来。Q,K,V得到后就需要获取向量间权重,需要对Q和K进行点乘并除以维度的平方根,对所有向量的结果进行Softmax处理,通过公式(2)的操作,我们获得了向量之间的关系权重。

3. 其最终输出则是通过V这个映射后的向量与Q,K经过Softmax结果进行weight sum获得,这个过程可以理解为在全局上进行自注意表示。每一组Q,K,V最后都有一个V输出,这是Self-Attention得到的最终结果,是当前向量在结合了它与其他向量关联权重后得到的结果。

Self-Attention的全部过程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1888147.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2.2.5 C#中显示控件BDPictureBox 的实现----ROI交互续2

2.2.5 C#中显示控件BDPictureBox 的实现----ROI交互续2 1 ROI数组作用说明 变量:m_ROIs[5] ROI 使用效果图 ROI数组说明 2 ROI显示逻辑图 ROI 交互主要是在设定状态下, runmode下只要普通显示即可 3 主要ROI显示函数函数 判断当前鼠标是否获取…

20240702在vmware17.5虚拟机中让ubuntu22.04使用主机的代理上网

20240702在vmware17.5虚拟机中让ubuntu22.04使用主机的代理上网 2024/7/2 14:41 百度:vmware 虚拟机 使用主机代理 上网 https://blog.csdn.net/nomoremorphine/article/details/138738065?utm_mediumdistribute.pc_relevant.none-task-blog-2~default~baidujs_ba…

周下载量20万的npm包---store

https://www.npmjs.com/package/store <script setup> import { onMounted } from vue import store from storeonMounted(() > {store.set(user, { name: xutongbao })let user store.get(user)console.log(user) //对象console.log(localStorage.getItem(user)) //…

[SIEMENS/S7-300] 接线图分析

SIEMENS-S7-300 系列 接线图 (整体) (PLC) 操作员面板 (操作员控制面板) (24v在第11页面第九栏,12页面第一栏。PE在17页第九栏) 跟进地址,到24V电源适配器 操作面板来自X4.0端子块,并且带有WC4.0标签的电缆。 PLC 操作面板与PLC之间通信使用Profibus电缆,直接通过CPU的MPI端…

Kotlin中的数据类型

人不走空 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌赋&#xff1a;斯是陋室&#xff0c;惟吾德馨 目录 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌…

OBS 26.1.1 64-bit (win10)

OBS 26.1.1 64-bit &#xff08;win10&#xff09;采集窗口问题 出现采集界面错误 点击【属性】 修改【采集方式】 完美解决

Vue3 sortablejs 表格拖拽后,表格无法更新的问题处理

实用sortablejs在vue项目中实现表格行拖拽排序 你可能会发现&#xff0c;表格排序是可以实现&#xff0c;但是我们基于数据驱动的vue中关联的数据并没有发生变化&#xff0c; 如果你的表格带有列固定(固定列实际上在dom中有两个表格&#xff0c;其中固定的列在一个表格中&…

深入理解C++中的锁

目录 1.基本互斥锁&#xff08;std::mutex&#xff09; 2.递归互斥锁&#xff08;std::recursive_mutex&#xff09; 3.带超时机制的互斥锁&#xff08;std::timed_mutex&#xff09; 4.带超时机制的递归互斥锁&#xff08;std::recursive_timed_mutex&#xff09; 5.共享…

2.2.4 C#中显示控件BDPictureBox 的实现----ROI交互

2.2.4 C#中显示控件BDPictureBox 的实现----ROI交互 1 界面效果 在设定模式下&#xff0c;可以进行ROI 框的拖动&#xff0c;这里以Rect1举例说明 2 增加ROI类定义 /// <summary> /// ROI_single /// 用于描述图片感兴趣区域 /// type: 0:Rect1;1:Rect2;2:Circle ;3:…

压缩pdf文件大小,压缩pdf文件大小软件哪个好

在数字化时代&#xff0c;PDF文件因其卓越的跨平台兼容性和稳定性而成为工作与学习的好帮手。然而&#xff0c;当PDF文件体积过大时&#xff0c;传输和存储便成了一项挑战。别担心&#xff0c;本文将为你揭秘如何快速压缩PDF文件&#xff0c;让你的文档轻装上路&#xff01; 压…

【Python实战因果推断】14_线性回归的不合理效果4

目录 Debiasing Step Denoising Step Standard Error of the Regression Estimator Debiasing Step 回想一下&#xff0c;最初由于混杂偏差&#xff0c;您的数据看起来是这样的、 随着信贷额度的增加&#xff0c;违约率呈下降趋势&#xff1a; 根据 FWL 定理&#xff0c;您可…

力扣习题--哈沙德数

一、前言 本系列主要讲解和分析力扣习题&#xff0c;所以的习题均来自于力扣官网题库 - 力扣 (LeetCode) 全球极客挚爱的技术成长平台 二、哈沙德数 1. 哈沙德数 如果一个整数能够被其各个数位上的数字之和整除&#xff0c;则称之为 哈沙德数&#xff08;Harshad number&…

JAVA学习笔记-JAVA基础语法-DAY21-缓冲流、转换流、序列化流

第一章 缓冲流 昨天学习了基本的一些流&#xff0c;作为IO流的入门&#xff0c;今天我们要见识一些更强大的流。比如能够高效读写的缓冲流&#xff0c;能够转换编码的转换流&#xff0c;能够持久化存储对象的序列化流等等。这些功能更为强大的流&#xff0c;都是在基本的流对象…

【echarts】拖拽滑块dataZoom-slider自定义样式,简单适配移动端

电脑端 移动端 代码片段 dataZoom: [{type: inside,start: 0,end: 100},{type: slider,backgroundColor: #F2F5F9,fillerColor: #BFCCE3,height: 13, // 设置slider的高度为15start: 0,end: 100,right: 60,left: 60,bottom: 15,handleIcon:path://M30.9,53.2C16.8,53.2,5.3,41.…

【Git 学习笔记】1.3 Git 的三个阶段

1.3 Git 的三个阶段 由于远程代码库后续存在新的提交&#xff0c;因此实操过程中的结果与书中并不完全一致。根据书中 HEAD 指向的 SHA-1&#xff1a;34acc370b4d6ae53f051255680feaefaf7f7850d&#xff0c;可通过以下命令切换到对应版本&#xff0c;并新建一个 newdemo 分支来…

基于LSTM、GRU和RNN的交通时间序列预测(Python)

近年来&#xff0c;人工智能技术的发展推动了智慧交通领域的进步&#xff0c;交通流预测日益成为研究热点之一。交通流预测是基于历史的交通数据对未来时段的交通流状态参数进行预测。作为交通流状态的直接反映&#xff0c;交通流参数的预测结果可以直接应用于 ATIS 和ATMS 中&…

QT Designer中的qrc文件如何创建,将图片添加进qrc文件

创建qrc文件可以在qt中给空间添加个性化属性 一、创建qrc文件的方式 1、将以下代码复制到txt文件文件中 <!DOCTYPE RCC> <RCC version"1.0"> <qresource prefix"/"><file>background_img.png</file><file>backgrou…

第二证券:可转债基础知识?想玩可转债一定要搞懂的交易规则!

可转债&#xff0c;全称是“可转化公司债券”&#xff0c;是上市公司为了融资&#xff0c;向社会公众所发行的一种债券&#xff0c;具有股票和债券的双重特点&#xff0c;投资者可以选择按照发行时约定的价格将债券转化成公司一般股票&#xff0c;也可作为债券持有到期后收取本…

计算机网络网络层复习题1

一. 单选题&#xff08;共27题&#xff09; 1. (单选题)以太网 MAC 地址、IPv4 地址、IPv6 地址的地址空间大小分别是&#xff08; &#xff09;。 A. 2^48&#xff0c;2^32&#xff0c;2^128B. 2^32&#xff0c;2^32&#xff0c;2^96C. 2^16&#xff0c;2^56&#xff0c;2^6…

【51单片机入门】矩阵键盘

文章目录 前言矩阵键盘介绍与检测原理原理图代码讲解总结 前言 在嵌入式系统设计中&#xff0c;键盘输入是一种常见的人机交互方式。其中&#xff0c;矩阵键盘因其简单、方便和易于扩展的特性&#xff0c;被广泛应用于各种设备中。本文将介绍如何使用51单片机来实现矩阵键盘的…