【MindSpore学习打卡】应用实践-计算机视觉-深入解析 Vision Transformer(ViT):从原理到实践

news2024/12/26 21:19:45

在近年来的深度学习领域,Transformer模型凭借其在自然语言处理(NLP)中的卓越表现,迅速成为研究热点。尤其是基于自注意力(Self-Attention)机制的模型,更是推动了NLP的飞速发展。然而,随着研究的深入,Transformer模型不仅在NLP领域大放异彩,还被引入到计算机视觉领域,形成了Vision Transformer(ViT)。ViT模型在不依赖传统卷积神经网络(CNN)的情况下,依然能够在图像分类任务中取得优异的效果。本文将深入解析ViT模型的结构、特点,并通过代码示例展示如何使用MindSpore框架实现ViT模型的训练、验证和推理。

ViT模型结构

ViT模型的主体结构基于Transformer模型的编码器(Encoder)部分,其整体结构如下图所示:

vit-architecture

模型特点

为什么要使用Patch Embedding?

在传统的Transformer模型中,输入通常是一维的词向量序列,而图像数据是二维的像素矩阵。为了将图像数据转换为Transformer可以处理的形式,我们需要将图像划分为多个小块(patch),并将每个patch转换为一维向量。这一过程称为Patch Embedding。通过这种方式,我们可以将图像数据转换为类似于词向量的形式,从而利用Transformer模型处理图像数据。
为什么要使用位置编码(Position Embedding)?

由于Transformer模型在处理输入序列时不考虑顺序信息,因此在图像数据中,patch之间的空间关系可能会丢失。为了解决这个问题,我们引入了位置编码(Position Embedding),它为每个patch增加了位置信息,使得模型能够识别不同patch之间的空间关系。这对于保留图像的空间结构信息非常重要。

  1. Patch Embedding:输入图像被划分为多个patch(图像块),然后将每个二维patch转换为一维向量,并加上类别向量和位置向量作为模型输入。
  2. Transformer Encoder:模型主体的Block结构基于Transformer的Encoder部分,主要结构是多头注意力(Multi-Head Attention)和前馈神经网络(Feed Forward)。
  3. 分类头(Head):在Transformer Encoder堆叠后接一个全连接层,用于分类。

环境准备与数据读取

开始实验之前,请确保本地已经安装了Python环境和MindSpore。

首先下载本案例的数据集,该数据集是从ImageNet中筛选出来的子集。数据集路径结构如下:

.dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/
from download import download
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms

# 下载数据集
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
path = download(dataset_url, path, kind="zip", replace=True)

data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

trans_train = [
    transforms.RandomCropDecodeResize(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

Transformer基本原理

Transformer模型源于2017年的一篇文章,其主要结构为多个编码器和解码器模块。编码器和解码器由多头注意力(Multi-Head Attention)、前馈神经网络(Feed Forward)、归一化层(Normalization)和残差连接(Residual Connection)组成。

Self-Attention机制

Self-Attention机制是Transformer的核心,其主要步骤如下:

  1. 输入向量映射:将输入向量映射成Query(Q)、Key(K)、Value(V)三个向量。
  2. 计算注意力权重:通过点乘计算Query和Key的相似性,并通过Softmax函数归一化。
  3. 加权求和:使用注意力权重对Value进行加权求和,得到最终的Attention输出。

以下是Self-Attention的代码实现:

from mindspore import nn, ops

class Attention(nn.Cell):
    def __init__(self, dim: int, num_heads: int = 8, keep_prob: float = 1.0, attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)
        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)
        return out

Transformer Encoder

为什么要使用残差连接(Residual Connection)和归一化层(Normalization Layer)?

在深层神经网络中,随着层数的增加,梯度消失和梯度爆炸的问题变得越来越严重。残差连接通过在每一层加上输入的跳跃连接,可以有效缓解这些问题,确保信息能够顺利传递。此外,归一化层(如LayerNorm)可以加速模型的训练,并提高模型的稳定性和泛化能力。这些技术的结合,使得Transformer模型能够在更深的层次上进行有效的训练。

Transformer Encoder由多层Self-Attention和前馈神经网络(Feed Forward)组成,通过残差连接和归一化层增强模型的训练效果和泛化能力。

class FeedForward(nn.Cell):
    def __init__(self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, activation: nn.Cell = nn.GELU, keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)

    def construct(self, x):
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)
        return x

class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        return self.cell(x) + x

class TransformerEncoder(nn.Cell):
    def __init__(self, dim: int, num_layers: int, num_heads: int, mlp_dim: int, keep_prob: float = 1., attention_keep_prob: float = 1.0, drop_path_keep_prob: float = 1.0, activation: nn.Cell = nn.GELU, norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []
        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim, num_heads=num_heads, keep_prob=keep_prob, attention_keep_prob=attention_keep_prob)
            feedforward = FeedForward(in_features=dim, hidden_features=mlp_dim, activation=activation, keep_prob=keep_prob)
            layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])), ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        return self.layers(x)

ViT模型的输入

ViT模型通过将输入图像划分为多个patch,将每个patch转换为一维向量,并加上类别向量和位置向量作为模型输入。以下是Patch Embedding的代码实现:

class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4

    def __init__(self, image_size: int = 224, patch_size: int = 16, embed_dim: int = 768, input_channels: int = 3):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):
        x = self.conv(x)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))
        return x

整体构建ViT

以下代码构建了一个完整的ViT模型:

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameter

def init(init_type, shape, dtype, name, requires_grad):
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)

class ViT(nn.Cell):
    def __init__(self, image_size: int = 224, input_channels: int = 3, patch_size: int = 16, embed_dim: int = 768, num_layers: int = 12, num_heads: int = 12, mlp_dim: int = 3072, keep_prob: float = 1.0, attention_keep_prob: float = 1.0, drop_path_keep_prob: float = 1.0, activation: nn.Cell = nn.GELU, norm: Optional[nn.Cell] = nn.LayerNorm, pool: str = 'cls') -> None:
        super(ViT, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size=image_size, patch_size=patch_size, embed_dim=embed_dim, input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches
        self.cls_token = init(init_type=Normal(sigma=1.0), shape=(1, 1, embed_dim), dtype=ms.float32, name='cls', requires_grad=True)
        self.pos_embedding = init(init_type=Normal(sigma=1.0), shape=(1, num_patches + 1, embed_dim), dtype=ms.float32, name='pos_embedding', requires_grad=True)
        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim, num_layers=num_layers, num_heads=num_heads, mlp_dim=mlp_dim, keep_prob=keep_prob, attention_keep_prob=attention_keep_prob, drop_path_keep_prob=drop_path_keep_prob, activation=activation, norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding
        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)
        return x

模型训练与推理

模型训练

模型训练前,需要设定损失函数、优化器和回调函数。以下是训练ViT模型的代码:

from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train

# 定义超参数
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()

# 构建模型
network = ViT()

# 加载预训练模型参数
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# 定义学习率
lr = nn.cosine_decay_lr(min_lr=float(0), max_lr=0.00005, total_step=epoch_size * step_size, step_per_epoch=step_size, decay_epoch=10)

# 定义优化器
network_opt = nn.Adam(network.trainable_params(), lr, momentum)

# 定义损失函数
class CrossEntropySmooth(LossBase):
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss

network_loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=0.1, num_classes=num_classes)

# 设置检查点
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# 初始化模型
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")

# 训练模型
model.train(epoch_size, dataset_train, callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)], dataset_sink_mode=False)

在这里插入图片描述

模型验证

模型验证过程主要应用了ImageFolderDataset,CrossEntropySmooth和Model等接口。以下是验证ViT模型的代码:

dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)

trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)

# 构建模型
network = ViT()

# 加载预训练模型参数
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

network_loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=0.1, num_classes=num_classes)

# 定义评价指标
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(), 'Top_5_Accuracy': train.Top5CategoricalAccuracy()}

if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")

# 验证模型
result = model.eval(dataset_val)
print(result)

模型推理

在进行模型推理之前,首先要定义一个对推理图片进行数据预处理的方法。以下是推理ViT模型的代码:

dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)

trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_infer = dataset_infer.map(operations=trans_infer, input_columns=["image"], num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

# 读取推理数据
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    mapping = index2label()
    output = {int(label): mapping[int(label)]}
    print(output)
    show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG", result=output, out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1898540.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

嵌入式UI开发-lvgl+wsl2+vscode系列:6、布局(Layouts)

一、前言 这节总结一下整体页面的布局方式,lvgl的布局方式比较少,目前只有flex和grid两大类布局,即弹性布局和网格布局,弹性布局一般就是指定相对位置,网格布局就是将整个页面划分为网格状,我们做其它的UI…

【Python机器学习】处理文本数据——用tf-idf缩放数据

为了按照我们预计的特征信息量大小来缩放特征,而不是舍弃那些认为不重要的特征,最常见的一种做法就是使用词频-逆向文档频率(tf-idf)。这一方法对某个特定文档中经常出现的术语给与很高的权重,但是堆在语料库的许多文档…

pandas,dataframe使用笔记

目录 新建一个dataframe不带列名带列名 dataframe添加一行内容查看dataframe某列的数据类型新建dataframe时设置了列名,则数据类型为object dataframe的保存保存为csv文件保存为excel文件 dataframe属于pandas 新建一个dataframe 不带列名 df pd.DataFrame() 带…

【Linux开发】基于ALSA库实现音量调节

基于ALSA库实现音量调节 ALSA库实现音量调节1、使用alsamixer工具查看音频接口2、完整代码2.1、snd_mixer_open2.2、snd_mixer_attach、2.3、snd_mixer_selem_register2.4、snd_mixer_load2.5、snd_mixer_first_elem/snd_mixer_elem_next2.6、snd_mixer_selem_get_playback_vol…

江汉大学刘春萌同学整理的wifi模块 上传mqtt实验步骤

一.固件烧录 1.打开安信可官网 2.点击wifi模组系列的ESP8266 3.点击各类固件后选择固件号1471下载 4.打开烧录工具将下载的二进制文件导入并将后面的起始地址写为0x00000,下面勾选40mhz QIO 8Mbit点击start下载即可 二.本地部署mqtt服务器(windows) 1.下载mosquitto后有一个m…

数据驱动下的SaaS渠道精细化运营:提升ROI的实战指南

在当今数字化转型的大潮中,SaaS(Software as a Service)企业面临着日益激烈的市场竞争。为了在市场中脱颖而出,实现可持续增长,SaaS企业必须转向更为精细化的运营模式,而数据驱动则是实现这一目标的关键。本…

NoSQL 非关系型数据库 Redis 的使用:

redis是基于内存型的NoSQL 非关系型数据库,本内容只针对有基础的小伙伴, 因为楼主不会做更多的解释,而是记录更多的技术接口使用,毕竟楼主不是做教学的,没有教学经验。 关于redis的介绍请自行搜索查阅。 使用redis数据…

Java后端每日面试题(day3)

目录 Spring中Bean的作用域有哪些?Spring中Bean的生命周期Bean 是线程安全的吗?了解Spring Boot中的日志组件吗? Spring中Bean的作用域有哪些? Bean的作用域: singleton:单例,Spring中的bean默…

一种频偏估计与补偿方法

一种简易的频偏估计补偿方法,使用QAM等信号。估计精度受FFT长度限制,可以作为粗频偏估计。 Nfft 1024; % FFT长度 N 10*Nfft; % 仿真符号数 M 16; % 调制QAM16 freq 1e…

PDF合并怎么做?分享几种简单好用的PDF合并方法

PDF文件以其良好的兼容性和稳定的格式,成为了我们日常办公、学习不可或缺的一部分。然而,随着PDF文件的不断增多,如何高效管理这些文件,特别是如何将多个PDF文件合并成一个,成为了许多人头疼的问题。下面给大家分享几款…

超参数优化方法之贝叶斯优化实现流程及代码

超参数优化方法之贝叶斯优化实现流程及代码 在机器学习模型的训练过程中,超参数的选择往往对模型性能有着决定性的影响。贝叶斯优化作为一种高效的超参数调优方法,以其在高维空间中的搜索效率和对最优化问题的独特见解而受到关注。本文将深入探讨贝叶斯…

CTF常用sql注入(三)无列名注入

0x06 无列名 适用于无法正确的查出结果,比如把information_schema给过滤了 join 联合 select * from users;select 1,2,3 union select * from users;列名被替换成了1,2,3, 我们再利用子查询和别名查 select 2 from (select 1,2,3 union select * f…

QT 布局演示例子

效果 源码 #include <QApplication> #include <QWidget> #include <QSplitter> #include <QVBoxLayout> #include <QLabel>int main(int argc, char *argv[]) {QApplication app(argc, argv);QWidget mainWidget;mainWidget.setWindowTitle(&qu…

适合金融行业的国产传输软件应该是怎样的?

对于金融行业来说&#xff0c;正常业务开展离不开文件传输场景&#xff0c;一般来说&#xff0c;金融行业常用的文件传输工具有IM通讯、邮件、自建文件传输系统、FTP应用、U盘等&#xff0c;这些传输工具可以基础实现金融机构的文件传输需求&#xff0c;但也存在如下问题&#…

价值499的从Emlog主题模板PandaPRO移植到wordpress的主题

Panda PRO 主题&#xff0c;一款精致wordpress博客主题&#xff0c;令人惊叹的昼夜双版设计&#xff0c;精心打磨的一处处细节&#xff0c;一切从心出发&#xff0c;从零开始&#xff0c;只为让您的站点拥有速度与优雅兼具的极致体验。 从Emlog主题模板PandaPRO移植到wordpres…

VCL界面组件DevExpress VCL v24.1 - 发布全新的矢量主题

DevExpress VCL是DevExpress公司旗下最老牌的用户界面套包&#xff0c;所包含的控件有&#xff1a;数据录入、图表、数据分析、导航、布局等。该控件能帮助您创建优异的用户体验&#xff0c;提供高影响力的业务解决方案&#xff0c;并利用您现有的VCL技能为未来构建下一代应用程…

CNN文献综述

卷积神经网络&#xff08;Convolutional Neural Networks&#xff0c;简称CNN&#xff09;是深度学习领域中的一种重要模型&#xff0c;主要用于图像识别和计算机视觉任务。其设计灵感来自于生物学中视觉皮层的工作原理&#xff0c;能够高效地处理图像和语音等数据。 基本原理…

Vue 邮箱登录界面

功能 模拟了纯前端的邮箱登录逻辑 还没有连接后端的发送邮件的服务 后续计划&#xff0c;再做一个邮箱、密码登录的界面 然后把这两个一块连接上后端 技术介绍 主要介绍绘制图形人机验证乃个 使用的是canvas&#xff0c;在源码里就有 界面控制主要就是用 表格、表单&#x…

哏号分治,CF103D - Time to Raid Cowavans

一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 103D - Time to Raid Cowavans 二、解题报告 1、思路分析 想了半天数据结构最终选择根号分治 我们考虑 大于 550 的公差直接暴力 小于550 的公差的所有询问&#xff0c;我们直接计算该公差后缀和&#xf…

Ubuntu 22.04.4 LTS 安装 php apache LAMP 环境nginx

1 安装php-fpm apt update apt-get install php-fpm #配置php-fpm服务启动 systemctl enable php8.1-fpm systemctl start php8.1-fpm #查看服务 systemctl status php8.1-fpm #查看版本 rootiZbp1g7fmjea77vsqc5hmmZ:~# php -v PHP 8.1.2-1ubuntu2.18 (cli) (built: J…