昇思MindSpore 应用学习-Vision Transformer图像分类

news2024/12/27 9:53:22

昇思MindSpore 应用学习-Vision Transformer图像分类(AI 代码解析)

Vision Transformer图像分类

Vision Transformer(ViT)简介

近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformer的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前规模的模型。
ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

模型结构

ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:

模型特点

ViT模型主要应用于图像分类领域。因此,其模型结构相较于传统的Transformer有以下几个特点:

  1. 数据集的原图像被划分为多个patch(图像块)后,将二维patch(不考虑channel)转换为一维向量,再加上类别向量与位置向量作为模型输入。
  2. 模型主体的Block结构是基于Transformer的Encoder结构,但是调整了Normalization的位置,其中,最主要的结构依然是Multi-head Attention结构。
  3. 模型在Blocks堆叠后接全连接层,接受类别向量的输出作为输入并用于分类。通常情况下,我们将最后的全连接层称为Head,Transformer Encoder部分为backbone。

下面将通过代码实例来详细解释基于ViT实现ImageNet分类任务。
注意,本教程在CPU上运行时间过长,不建议使用CPU运行。

环境准备与数据读取

开始实验之前,请确保本地已经安装了Python环境并安装了MindSpore。
首先我们需要下载本案例的数据集,可通过http://image-net.org下载完整的ImageNet数据集,本案例应用的数据集是从ImageNet中筛选出来的子集。
运行第一段代码时会自动下载并解压,请确保你的数据集路径如以下结构。

.dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/
from download import download  # 导入download模块中的download函数

# 定义数据集的URL
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"

# 定义下载的保存路径
path = "./"  # 当前目录

# 调用download函数下载数据集
path = download(dataset_url, path, kind="zip", replace=True)  # 下载数据集并指定为zip文件,若已存在则替换
  1. from download import download:
    • 这一行代码从download模块中导入了download函数。这个函数通常用于从指定的URL下载文件。
  2. dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip":
    • 这里定义了一个字符串dataset_url,它包含了要下载的数据集的链接。这个链接指向一个zip文件,包含了ImageNet数据集的一部分。
  3. path = "./":
    • 这行代码定义了一个变量path,其值为当前目录"./"。这意味着下载的文件将保存在当前执行代码的目录中。
  4. path = download(dataset_url, path, kind="zip", replace=True):
    • 这一行调用了download函数:
      • dataset_url: 指定要下载的文件的URL。
      • path: 指定下载后文件的保存路径。
      • kind="zip": 指定下载文件的类型为zip格式。
      • replace=True: 如果文件已存在,则会被替换。
  • download(url, path, kind, replace):
    • url: 要下载文件的链接。
    • path: 文件下载后保存的路径。
    • kind: 指定下载文件的类型,通常为"zip"、"tar"等,适用于不同格式的压缩包。
    • replace: 布尔值,决定是否覆盖已存在的同名文件。设置为True时,若目标路径已有同名文件,则将其替换。
import os  # 导入os模块,用于处理文件和目录路径

import mindspore as ms  # 导入MindSpore库
from mindspore.dataset import ImageFolderDataset  # 从MindSpore中导入ImageFolderDataset类,用于加载图像文件夹数据集
import mindspore.dataset.vision as transforms  # 导入MindSpore的图像处理变换模块

# 定义数据集路径
data_path = './dataset/'  # 存放数据集的路径

# 定义图像的均值和标准差,用于归一化
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]  # RGB通道的均值
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]    # RGB通道的标准差

# 创建训练数据集
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)  # 从指定路径加载训练集,并随机打乱数据

# 定义数据增强与预处理的操作
trans_train = [
    transforms.RandomCropDecodeResize(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),  # 随机裁剪并调整大小
    transforms.RandomHorizontalFlip(prob=0.5),  # 以50%概率进行随机水平翻转
    transforms.Normalize(mean=mean, std=std),  # 对图像进行归一化处理
    transforms.HWC2CHW()  # 将图像从HWC格式转换为CHW格式
]

# 将变换应用于数据集
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])  # 对图像列应用预处理操作

# 将数据集分批处理
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)  # 将数据集按批次进行处理,每批次16张图像,若不足16张则丢弃
  1. import os:
    • 导入os模块,以便于处理文件路径和目录操作。
  2. import mindspore as ms:
    • 导入MindSpore库,主要用于深度学习相关任务的计算和模型构建。
  3. from mindspore.dataset import ImageFolderDataset:
    • 从MindSpore的dataset模块中导入ImageFolderDataset类,这个类用于加载存储在文件夹中的图像数据集。
  4. import mindspore.dataset.vision as transforms:
    • 导入MindSpore的视觉处理模块,提供了多种图像预处理和增强的操作。
  5. data_path = './dataset/':
    • 定义数据集的基本路径,假设数据集位于当前目录的dataset文件夹中。
  6. mean** 和 **std:
    • 定义图像在训练过程中使用的均值和标准差,用于归一化处理,有助于改善模型的训练效果。
  7. dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True):
    • 创建一个训练数据集对象,从data_path下的train目录加载图像数据,并在每个epoch结束时随机打乱数据顺序。
  8. trans_train:
    • 定义一个包含多个图像预处理操作的列表:
      • RandomCropDecodeResize: 随机裁剪并调整图像大小到224x224。
      • RandomHorizontalFlip: 随机水平翻转图像,概率为0.5。
      • Normalize: 使用指定的均值和标准差对图像进行归一化处理。
      • HWC2CHW: 将图像格式从高度-宽度-通道(HWC)转换为通道-高度-宽度(CHW)。
  9. dataset_train.map(operations=trans_train, input_columns=["image"]):
    • 将定义的转换操作应用到数据集的“image”列上,对每张图像进行预处理。
  10. dataset_train.batch(batch_size=16, drop_remainder=True):
    • 将数据集分成批次进行训练,每个批次包含16张图像,drop_remainder=True表示如果最后一批数据不足16张,则丢弃该批次。
  • ImageFolderDataset:
    • 用于从指定目录加载图像数据集,接收目录路径和是否打乱数据的参数。
  • transforms.RandomCropDecodeResize:
    • 随机裁剪和调整图像尺寸的变换,允许指定目标尺寸、缩放范围和宽高比范围。
  • transforms.RandomHorizontalFlip:
    • 随机水平翻转图像,接受翻转的概率参数。
  • transforms.Normalize:
    • 对图像进行归一化处理,接受均值和标准差。
  • transforms.HWC2CHW:
    • 转换图像的存储格式,从高度-宽度-通道(HWC)转换为通道-高度-宽度(CHW)。
  • dataset.map():
    • 将指定的操作应用于数据集的指定列。
  • dataset.batch():
    • 将数据集分批处理,指定每个批次的大小以及处理不足一个批次的方式。

模型解析

下面将通过代码来细致剖析ViT模型的内部结构。

Transformer基本原理

Transformer模型源于2017年的一篇文章[2]。在这篇文章中提出的基于Attention机制的编码器-解码器型结构在自然语言处理领域获得了巨大的成功。模型结构如下图所示:

其主要结构为多个Encoder和Decoder模块所组成,其中Encoder和Decoder的详细结构如下图[2]所示:

Encoder与Decoder由许多结构组成,如:多头注意力(Multi-Head Attention)层,Feed Forward层,Normaliztion层,甚至残差连接(Residual Connection,图中的“Add”)。不过,其中最重要的结构是多头注意力(Multi-Head Attention)结构,该结构基于自注意力(Self-Attention)机制,是多个Self-Attention的并行组成。
所以,理解了Self-Attention就抓住了Transformer的核心。

Attention模块

以下是Self-Attention的解释,其核心内容是为输入向量的每个单词学习一个权重。通过给定一个任务相关的查询向量Query向量,计算Query和各个Key的相似性或者相关性得到注意力分布,即得到每个Key对应Value的权重系数,然后对Value进行加权求和得到最终的Attention数值。
在Self-Attention中:

  1. 最初的输入向量首先会经过Embedding层映射成Q(Query),K(Key),V(Value)三个向量,由于是并行操作,所以代码中是映射成为dim x 3的向量然后进行分割,换言之,如果你的输入向量为一个向量序列(𝑥1,𝑥2,𝑥3),其中的𝑥1,𝑥2,𝑥3都是一维向量,那么每一个一维向量都会经过Embedding层映射出Q,K,V三个向量,只是Embedding矩阵不同,矩阵参数也是通过学习得到的。这里大家可以认为,Q,K,V三个矩阵是发现向量之间关联信息的一种手段,需要经过学习得到,至于为什么是Q,K,V三个,主要是因为需要两个向量点乘以获得权重,又需要另一个向量来承载权重向加的结果,所以,最少需要3个矩阵。

自注意力机制的整体过程可以通过以下几个步骤和公式来理解:

  1. 输入向量生成
    image.png
    这里,Q、K、V分别是从输入向量 ( x_i ) 通过不同的权重矩阵 ( W_q )、( W_k )、( W_v ) 得到的。
  2. 计算注意力权重
    image.png
    通过对Q和K的点乘并除以维度的平方根,计算出注意力权重。
  3. 应用Softmax函数
    image.png
    经过Softmax处理后,得到归一化的权重。
  4. 最终输出
    image.png
    最终输出是通过将每个V向量与对应的权重相乘并求和得到的结果。

通过以上步骤,自注意力机制能够在全局范围内建模输入向量之间的关系,提取其特征和联系。

多头注意力机制就是将原本self-Attention处理的向量分割为多个Head进行处理,这一点也可以从代码中体现,这也是attention结构可以进行并行加速的一个方面。
总结来说,多头注意力机制在保持参数总量不变的情况下,将同样的query, key和value映射到原来的高维空间(Q,K,V)的不同子空间(Q_0,K_0,V_0)中进行自注意力的计算,最后再合并不同子空间中的注意力信息。
所以,对于同一个输入向量,多个注意力机制可以同时对其进行处理,即利用并行计算加速处理过程,又在处理的时候更充分的分析和利用了向量特征。下图展示了多头注意力机制,其并行能力的主要体现在下图中的a1𝑎1和a2𝑎2是同一个向量进行分割获得的。

以下是Multi-Head Attention代码,结合上文的解释,代码清晰的展现了这一过程。

from mindspore import nn, ops  # 从MindSpore导入神经网络模块nn和操作模块ops

class Attention(nn.Cell):  # 定义Attention类,继承自nn.Cell
    def __init__(self,
                 dim: int,  # 输入特征维度
                 num_heads: int = 8,  # 注意力头的数量
                 keep_prob: float = 1.0,  # 输出的保留概率
                 attention_keep_prob: float = 1.0):  # 注意力矩阵的保留概率
        super(Attention, self).__init__()  # 调用父类构造函数

        self.num_heads = num_heads  # 设置注意力头的数量
        head_dim = dim // num_heads  # 每个头的维度
        self.scale = ms.Tensor(head_dim ** -0.5)  # 缩放因子

        self.qkv = nn.Dense(dim, dim * 3)  # 创建全连接层,用于生成Q、K、V矩阵
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)  # 定义注意力 dropout 层
        self.out = nn.Dense(dim, dim)  # 创建输出的全连接层
        self.out_drop = nn.Dropout(p=1.0-keep_prob)  # 定义输出 dropout 层
        self.attn_matmul_v = ops.BatchMatMul()  # 定义批量矩阵乘法操作
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)  # 定义批量矩阵乘法操作,K矩阵转置
        self.softmax = nn.Softmax(axis=-1)  # 定义softmax操作

    def construct(self, x):  # 定义前向传播
        """Attention construct."""
        b, n, c = x.shape  # 获取输入的形状:batch_size, sequence_length, channels
        qkv = self.qkv(x)  # 通过全连接层生成QKV
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))  # 重新调整形状
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))  # 转置,使得Q、K和V分开
        q, k, v = ops.unstack(qkv, axis=0)  # 分离Q、K、V

        attn = self.q_matmul_k(q, k)  # 计算Q和K的点积
        attn = ops.mul(attn, self.scale)  # 缩放点积结果
        attn = self.softmax(attn)  # 计算softmax以获得注意力权重
        attn = self.attn_drop(attn)  # 应用注意力dropout

        out = self.attn_matmul_v(attn, v)  # 计算注意力加权的值
        out = ops.transpose(out, (0, 2, 1, 3))  # 转置输出
        out = ops.reshape(out, (b, n, c))  # 重新调整输出形状
        out = self.out(out)  # 通过输出全连接层
        out = self.out_drop(out)  # 应用输出dropout

        return out  # 返回最终结果
  1. from mindspore import nn, ops:
    • 导入MindSpore的神经网络模块和操作模块,分别用于构建神经网络和执行各种操作。
  2. class Attention(nn.Cell)::
    • 定义一个名为Attention的类,继承自nn.Cell,表示这是一个神经网络单元。
  3. __init__** 方法**:
    • dim: 输入特征的维度。
    • num_heads: 注意力头的数量,默认为8。
    • keep_probattention_keep_prob: 用于控制dropout的保留概率。
    • self.qkv: 创建一个全连接层,将输入特征转换为Q、K、V三部分。
    • self.attn_dropself.out_drop: 定义dropout层,用于在训练中减少过拟合。
    • self.attn_matmul_vself.q_matmul_k: 定义批量矩阵乘法操作,其中K矩阵在计算时会转置。
    • self.softmax: 定义softmax操作,用于计算注意力权重。
  4. construct** 方法**:
    • x: 输入张量,包含形状为(batch_size, sequence_length, channels)的特征。
    • 使用self.qkv全连接层生成Q、K、V矩阵。
    • 通过ops.reshapeops.transpose调整Q、K、V的形状和顺序。
    • 使用点积计算Q和K的相似度,并通过self.scale进行缩放。
    • 通过softmax计算注意力权重,并应用dropout。
    • 使用注意力权重对V进行加权,并对输出进行形状调整,最后通过全连接层和dropout。
  • nn.Cell:
    • MindSpore中所有模块的基类,定义了前向传播的方法。
  • nn.Dense:
    • 全连接层,用于线性变换,通常用于构建神经网络中的层。
  • nn.Dropout:
    • 随机丢弃一部分神经元的输出,防止过拟合。
  • ops.BatchMatMul:
    • 批量矩阵乘法操作,处理多个矩阵相乘。
  • ops.reshape:
    • 重新调整张量的形状。
  • ops.transpose:
    • 转置张量的维度。
  • ops.unstack:
    • 从一个维度上分离张量的多个部分。
  • nn.Softmax:
    • 计算softmax函数,将输入变换为概率分布。

Transformer Encoder

在了解了Self-Attention结构之后,通过与Feed Forward,Residual Connection等结构的拼接就可以形成Transformer的基础结构,下面代码实现了Feed Forward,Residual Connection结构。

from typing import Optional, Dict  # 导入Optional和Dict类型注解

class FeedForward(nn.Cell):  # 定义FeedForward类,继承自nn.Cell
    def __init__(self,
                 in_features: int,  # 输入特征的维度
                 hidden_features: Optional[int] = None,  # 隐藏层特征的维度
                 out_features: Optional[int] = None,  # 输出特征的维度
                 activation: nn.Cell = nn.GELU,  # 激活函数,默认为GELU
                 keep_prob: float = 1.0):  # dropout的保留概率
        super(FeedForward, self).__init__()  # 调用父类构造函数
        
        out_features = out_features or in_features  # 输出特征维度未指定时,默认为输入特征维度
        hidden_features = hidden_features or in_features  # 隐藏特征维度未指定时,默认为输入特征维度
        
        self.dense1 = nn.Dense(in_features, hidden_features)  # 第一层全连接层
        self.activation = activation()  # 激活函数实例化
        self.dense2 = nn.Dense(hidden_features, out_features)  # 第二层全连接层
        self.dropout = nn.Dropout(p=1.0-keep_prob)  # dropout层

    def construct(self, x):  # 定义前向传播
        """Feed Forward construct."""
        x = self.dense1(x)  # 通过第一层全连接层
        x = self.activation(x)  # 应用激活函数
        x = self.dropout(x)  # 应用dropout
        x = self.dense2(x)  # 通过第二层全连接层
        x = self.dropout(x)  # 再次应用dropout

        return x  # 返回最终输出


class ResidualCell(nn.Cell):  # 定义ResidualCell类,继承自nn.Cell
    def __init__(self, cell):  # cell是一个神经网络单元
        super(ResidualCell, self).__init__()  # 调用父类构造函数
        self.cell = cell  # 存储传入的神经网络单元

    def construct(self, x):  # 定义前向传播
        """ResidualCell construct."""
        return self.cell(x) + x  # 输出为神经网络单元的输出与输入相加,形成残差连接
  1. 导入部分:
    • from typing import Optional, Dict: 导入类型注解OptionalDict,用于类型提示。
  2. FeedForward** 类**:
    • **构造函数 **__init__:
      • in_features: 输入特征的维度。
      • hidden_features: 隐藏层的特征维度,默认与输入特征相同。
      • out_features: 输出特征的维度,默认与输入特征相同。
      • activation: 激活函数,默认为GELU(高斯误差线性单元)。
      • keep_prob: 控制dropout的保留概率。
      • self.dense1self.dense2: 创建两个全连接层,用于特征转换。
      • self.activation: 激活函数实例化。
      • self.dropout: 定义dropout层,用于减少过拟合。
  3. construct** 方法**:
    • 定义前向传播过程:
      • 先通过第一层全连接层dense1
      • 然后应用激活函数。
      • 接着应用dropout。
      • 再通过第二层全连接层dense2
      • 最后再次应用dropout。
    • 输出最终的结果。
  4. ResidualCell** 类**:
    • **构造函数 **__init__:
      • cell: 传入一个神经网络单元,可以是任何实现了前向传播的网络结构。
    • construct** 方法**:
      • 定义前向传播过程,将输入xcell(x)的输出相加,形成残差连接。这种结构有助于深层网络的训练,提高模型性能。
  • nn.Cell:
    • MindSpore中所有模型的基类,提供了构建神经网络所需的基础功能。
  • nn.Dense:
    • 全连接层,用于执行线性变换,通常用于神经网络中的层。
  • nn.Dropout:
    • 随机丢弃一定比例的神经元,以减少过拟合。
  • activation():
    • 动态实例化传入的激活函数,例如GELU。
  • x + self.cell(x):
    • 实现残差连接,帮助网络学习到更深层的特征,同时缓解梯度消失问题。

接下来就利用Self-Attention来构建ViT模型中的TransformerEncoder部分,类似于构建了一个Transformer的编码器部分,如下图[1]所示:

  1. ViT模型中的基础结构与标准Transformer有所不同,主要在于Normalization的位置是放在Self-Attention和Feed Forward之前,其他结构如Residual Connection,Feed Forward,Normalization都如Transformer中所设计。
  2. 从Transformer结构的图片可以发现,多个子encoder的堆叠就完成了模型编码器的构建,在ViT模型中,依然沿用这个思路,通过配置超参数num_layers,就可以确定堆叠层数。
  3. Residual Connection,Normalization的结构可以保证模型有很强的扩展性(保证信息经过深层处理不会出现退化的现象,这是Residual Connection的作用),Normalization和dropout的应用可以增强模型泛化能力。

从以下源码中就可以清晰看到Transformer的结构。将TransformerEncoder结构和一个多层感知器(MLP)结合,就构成了ViT模型的backbone部分。

class TransformerEncoder(nn.Cell):  # 定义TransformerEncoder类,继承自nn.Cell
    def __init__(self,
                 dim: int,  # 输入特征的维度
                 num_layers: int,  # 编码器层的数量
                 num_heads: int,  # 注意力头的数量
                 mlp_dim: int,  # MLP隐藏层的特征维度
                 keep_prob: float = 1.,  # dropout的保留概率
                 attention_keep_prob: float = 1.0,  # 注意力dropout的保留概率
                 drop_path_keep_prob: float = 1.0,  # drop path的保留概率
                 activation: nn.Cell = nn.GELU,  # 激活函数,默认为GELU
                 norm: nn.Cell = nn.LayerNorm):  # 归一化层,默认为LayerNorm
        super(TransformerEncoder, self).__init__()  # 调用父类构造函数
        
        layers = []  # 初始化层列表

        for _ in range(num_layers):  # 对于每一层
            normalization1 = norm((dim,))  # 创建第一个归一化层
            normalization2 = norm((dim,))  # 创建第二个归一化层
            
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)  # 创建注意力层

            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)  # 创建前馈层

            layers.append(  # 将残差连接的序列单元添加到层列表中
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),  # 归一化 + 注意力层
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))  # 归一化 + 前馈层
                ])
            )
        self.layers = nn.SequentialCell(layers)  # 将所有层组合成一个序列单元

    def construct(self, x):  # 定义前向传播
        """Transformer construct."""
        return self.layers(x)  # 将输入通过所有层进行处理并返回
  1. TransformerEncoder** 类**:
    • **构造函数 **__init__:
      • dim: 输入特征的维度。
      • num_layers: 编码器中堆叠的层数量。
      • num_heads: 注意力机制中注意力头的数量。
      • mlp_dim: MLP(多层感知机)层中的隐藏特征维度。
      • keep_probattention_keep_probdrop_path_keep_prob: dropout和路径保持的概率,用于控制模型的正则化。
      • activation: 激活函数,默认为GELU。
      • norm: 归一化层,默认为LayerNorm。
  2. 层的构建:
    • layers: 初始化一个空列表用于存储各层。
    • 使用循环创建指定数量的编码器层,每个编码器层包含两个部分:
      • normalization1: 第一个归一化层。
      • normalization2: 第二个归一化层。
      • attention: 创建注意力层实例。
      • feedforward: 创建前馈层实例。
    • 每个编码器层通过ResidualCell进行封装,形成残差连接,并将组合的层添加到layers列表中。
  3. self.layers:
    • 使用nn.SequentialCell(layers)将所有编码器层组合在一起,形成一个顺序层。
  4. construct** 方法**:
    • 定义前向传播过程,将输入x通过所有编码器层进行处理,返回最终结果。
  • nn.Cell:
    • MindSpore中所有模块的基类,提供了构建神经网络所需的基础功能。
  • nn.SequentialCell:
    • 将多个层组合成一个顺序执行的神经网络模块,按顺序调用每个模块。
  • ResidualCell:
    • 实现残差连接,帮助网络学习深层特征。
  • Attention:
    • 定义注意力机制,通常用于加强模型对输入的关注能力。
  • FeedForward:
    • 定义前馈神经网络层,通常用于进一步处理注意力层的输出。
  • nn.LayerNorm:
    • 层归一化,通常用于加速训练并提高模型稳定性。

ViT模型的输入

传统的Transformer结构主要用于处理自然语言领域的词向量(Word Embedding or Word Vector),词向量与传统图像数据的主要区别在于,词向量通常是一维向量进行堆叠,而图片则是二维矩阵的堆叠,多头注意力机制在处理一维词向量的堆叠时会提取词向量之间的联系也就是上下文语义,这使得Transformer在自然语言处理领域非常好用,而二维图片矩阵如何与一维词向量进行转化就成为了Transformer进军图像处理领域的一个小门槛。
在ViT模型中:

  1. 通过将输入图像在每个channel上划分为16*16个patch,这一步是通过卷积操作来完成的,当然也可以人工进行划分,但卷积操作也可以达到目的同时还可以进行一次而外的数据处理;例如一幅输入224 x 224的图像,首先经过卷积处理得到16 x 16个patch,那么每一个patch的大小就是14 x 14。
  2. 再将每一个patch的矩阵拉伸成为一个一维向量,从而获得了近似词向量堆叠的效果。上一步得到的14 x 14的patch就转换为长度为196的向量。

这是图像输入网络经过的第一步处理。具体Patch Embedding的代码如下所示:

class PatchEmbedding(nn.Cell):  # 定义PatchEmbedding类,继承自nn.Cell
    MIN_NUM_PATCHES = 4  # 定义最小补丁数量常量

    def __init__(self,
                 image_size: int = 224,  # 输入图像的大小
                 patch_size: int = 16,  # 每个补丁的大小
                 embed_dim: int = 768,  # 嵌入特征的维度
                 input_channels: int = 3):  # 输入图像的通道数(例如RGB图像)
        super(PatchEmbedding, self).__init__()  # 调用父类构造函数

        self.image_size = image_size  # 存储图像大小
        self.patch_size = patch_size  # 存储补丁大小
        self.num_patches = (image_size // patch_size) ** 2  # 计算补丁的数量
        # 创建卷积层,从输入通道到嵌入维度,使用补丁大小作为卷积核大小,步幅也为补丁大小
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):  # 定义前向传播
        """Path Embedding construct."""
        x = self.conv(x)  # 通过卷积层处理输入
        b, c, h, w = x.shape  # 获取输出的形状:batch_size, channels, height, width
        x = ops.reshape(x, (b, c, h * w))  # 将输出重新调整形状为(batch_size, channels, num_patches)
        x = ops.transpose(x, (0, 2, 1))  # 转置输出,使得形状为(batch_size, num_patches, channels)

        return x  # 返回最终的嵌入表示
  1. PatchEmbedding** 类**:
    • **常量 **MIN_NUM_PATCHES: 定义最小补丁数量的常量,用于后续的逻辑判断(如果需要的话)。
    • **构造函数 **__init__:
      • image_size: 输入图像的大小,默认是224。
      • patch_size: 每个补丁的大小,默认是16。
      • embed_dim: 嵌入特征的维度,默认是768。
      • input_channels: 输入图像的通道数,默认是3(适用于RGB图像)。
      • self.num_patches: 计算补丁的数量,通过将图像大小除以补丁大小并平方来得到。
      • self.conv: 使用卷积层将输入特征映射到嵌入维度,卷积核大小和步幅都设置为补丁大小。
  2. construct** 方法**:
    • **输入 **x: 表示输入的图像张量。
    • x = self.conv(x): 将输入通过卷积层处理,得到具有嵌入维度的输出特征。
    • b, c, h, w = x.shape: 获取输出张量的形状,分别为批量大小、通道数、高度和宽度。
    • x = ops.reshape(x, (b, c, h * w)): 重新调整输出张量的形状,使其变为(batch_size, channels, num_patches),这里的num_patchesh * w
    • x = ops.transpose(x, (0, 2, 1)): 转置张量的维度,将其形状调整为(batch_size, num_patches, channels),以便后续处理。
  • nn.Cell:
    • MindSpore中所有模型的基类,提供构建神经网络所需的基础功能。
  • nn.Conv2d:
    • 2D卷积层,通常用于处理图像数据。参数包括输入通道数、输出通道数(嵌入维度)、卷积核大小和步幅。
  • ops.reshape:
    • 用于重新调整张量的形状。
  • ops.transpose:
    • 转置张量的维度,用于调整张量的排列顺序,以便于下游的网络结构处理。

通过这样的设计,PatchEmbedding类将输入图像切分为多个补丁,并对这些补丁进行特征提取,形成适合后续Transformer模型处理的嵌入表示。

输入图像在划分为patch之后,会经过pos_embedding 和 class_embedding两个过程。

  1. class_embedding主要借鉴了BERT模型的用于文本分类时的思想,在每一个word vector之前增加一个类别值,通常是加在向量的第一位,上一步得到的196维的向量加上class_embedding后变为197维。
  2. 增加的class_embedding是一个可以学习的参数,经过网络的不断训练,最终以输出向量的第一个维度的输出来决定最后的输出类别;由于输入是16 x 16个patch,所以输出进行分类时是取 16 x 16个class_embedding进行分类。
  3. pos_embedding也是一组可以学习的参数,会被加入到经过处理的patch矩阵中。
  4. 由于pos_embedding也是可以学习的参数,所以它的加入类似于全链接网络和卷积的bias。这一步就是创造一个长度维197的可训练向量加入到经过class_embedding的向量中。

实际上,pos_embedding总共有4种方案。但是经过作者的论证,只有加上pos_embedding和不加pos_embedding有明显影响,至于pos_embedding是一维还是二维对分类结果影响不大,所以,在我们的代码中,也是采用了一维的pos_embedding,由于class_embedding是加在pos_embedding之前,所以pos_embedding的维度会比patch拉伸后的维度加1。
总的而言,ViT模型还是利用了Transformer模型在处理上下文语义时的优势,将图像转换为一种“变种词向量”然后进行处理,而这样转换的意义在于,多个patch之间本身具有空间联系,这类似于一种“空间语义”,从而获得了比较好的处理效果。

整体构建ViT

以下代码构建了一个完整的ViT模型。

from mindspore.common.initializer import Normal  # 导入正态分布初始化器
from mindspore.common.initializer import initializer  # 导入初始化器方法
from mindspore import Parameter  # 导入Parameter类

def init(init_type, shape, dtype, name, requires_grad):
    """Init."""
    initial = initializer(init_type, shape, dtype).init_data()  # 使用指定的初始化类型生成初始值
    return Parameter(initial, name=name, requires_grad=requires_grad)  # 返回一个Parameter对象

class ViT(nn.Cell):  # 定义ViT类,继承自nn.Cell
    def __init__(self,
                 image_size: int = 224,  # 输入图像大小
                 input_channels: int = 3,  # 输入图像的通道数
                 patch_size: int = 16,  # 每个补丁的大小
                 embed_dim: int = 768,  # 嵌入特征的维度
                 num_layers: int = 12,  # Transformer层的数量
                 num_heads: int = 12,  # 注意力头的数量
                 mlp_dim: int = 3072,  # MLP隐藏层的维度
                 keep_prob: float = 1.0,  # dropout的保留概率
                 attention_keep_prob: float = 1.0,  # 注意力dropout的保留概率
                 drop_path_keep_prob: float = 1.0,  # drop path的保留概率
                 activation: nn.Cell = nn.GELU,  # 激活函数,默认为GELU
                 norm: Optional[nn.Cell] = nn.LayerNorm,  # 归一化层,默认为LayerNorm
                 pool: str = 'cls') -> None:  # 池化方式,默认为'cls'
        super(ViT, self).__init__()  # 调用父类构造函数

        # 初始化各种层
        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)  # 创建PatchEmbedding层
        num_patches = self.patch_embedding.num_patches  # 获取补丁数量

        # 初始化分类token
        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        # 初始化位置嵌入
        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool  # 设置池化方式
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)  # 定义位置嵌入的dropout层
        self.norm = norm((embed_dim,))  # 实例化归一化层
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)  # 创建Transformer编码器
        self.dropout = nn.Dropout(p=1.0-keep_prob)  # 定义最后的dropout层
        self.dense = nn.Dense(embed_dim, num_classes)  # 定义全连接层,映射到类的数量

    def construct(self, x):  # 定义前向传播
        """ViT construct."""
        x = self.patch_embedding(x)  # 通过PatchEmbedding处理输入
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))  # 复制分类token以匹配批次大小
        x = ops.concat((cls_tokens, x), axis=1)  # 将分类token与补丁嵌入拼接在一起
        x += self.pos_embedding  # 添加位置嵌入

        x = self.pos_dropout(x)  # 应用位置嵌入的dropout
        x = self.transformer(x)  # 通过Transformer编码器
        x = self.norm(x)  # 应用归一化层
        x = x[:, 0]  # 选择分类token的输出(第一个token)

        if self.training:  # 如果处于训练模式
            x = self.dropout(x)  # 应用最后的dropout

        x = self.dense(x)  # 通过全连接层映射到输出类

        return x  # 返回最终输出
  1. **初始化函数 **init:
    • initializer(init_type, shape, dtype): 根据指定类型生成权重的初始值。
    • Parameter: 将生成的初始值包装为可训练的参数,并设置参数名和梯度要求。
  2. ViT** 类**:
    • **构造函数 **__init__:
      • image_size: 输入图像的大小。
      • input_channels: 输入图像的通道数,适用于RGB图像。
      • patch_size: 每个补丁的大小。
      • embed_dim: 嵌入特征的维度。
      • num_layers: Transformer编码器的层数。
      • num_heads: 注意力头的数量。
      • mlp_dim: MLP中隐藏层的维度。
      • keep_probattention_keep_probdrop_path_keep_prob: 控制dropout的概率。
      • activation: 激活函数,默认为GELU。
      • norm: 归一化层,默认为LayerNorm。
      • pool: 池化策略,默认为使用分类token。
  3. 层的初始化:
    • self.patch_embedding: 实例化PatchEmbedding类,用于图像分块嵌入。
    • self.cls_tokenself.pos_embedding: 使用正态分布初始化分类token和位置嵌入。
    • self.transformer: 实例化TransformerEncoder,用于堆叠多个Transformer层。
    • self.dropoutself.dense: 定义用于最终分类的dropout和全连接层。
  4. construct** 方法**:
    • x = self.patch_embedding(x): 将输入图像通过补丁嵌入层处理。
    • cls_tokens: 复制分类token以适应输入的批量大小。
    • x = ops.concat((cls_tokens, x), axis=1): 将分类token添加到补丁嵌入中。
    • x += self.pos_embedding: 将位置嵌入添加到输入中。
    • x = self.transformer(x): 通过Transformer编码器处理输入。
    • x = x[:, 0]: 选择分类token的输出作为最终的特征表示。
    • x = self.dense(x): 通过全连接层得到最终分类结果。
  • ms.float32:
    • MindSpore中的数据类型,用于指定浮点数类型。
  • nn.Dense:
    • 全连接层,用于将输入特征映射到输出类别数。
  • ops.tile:
    • 用于在指定维度上复制数组,适合批处理场景。
  • ops.concat:
    • 用于在指定维度上连接多个张量。
  • nn.Dropout:
    • 随机丢弃一定比例的神经元,以减少过拟合。

通过这样的设计,ViT类构建了一个Vision Transformer模型,将输入图像通过多个层次的处理后,输出最终的分类结果。

整体流程图如下所示:

模型训练与推理

模型训练

模型开始训练前,需要设定损失函数,优化器,回调函数等。
完整训练ViT模型需要很长的时间,实际应用时建议根据项目需要调整epoch_size,当正常输出每个Epoch的step信息时,意味着训练正在进行,通过模型输出可以查看当前训练的loss值和时间等指标。

from mindspore.nn import LossBase  # 导入LossBase类,用于定义损失函数
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint  # 导入训练监控和检查点相关的类
from mindspore import train  # 导入训练模块

# 定义超参数
epoch_size = 10  # 训练的轮数
momentum = 0.9  # 动量系数
num_classes = 1000  # 类别数量
resize = 224  # 输入图像的大小
step_size = dataset_train.get_dataset_size()  # 获取训练集每个epoch的步数

# 构造模型
network = ViT()  # 实例化ViT模型

# 加载预训练模型权重
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"  # 预训练模型的URL
path = "./ckpt/vit_b_16_224.ckpt"  # 本地保存路径

vit_path = download(vit_url, path, replace=True)  # 下载预训练模型
param_dict = ms.load_checkpoint(vit_path)  # 加载检查点
ms.load_param_into_net(network, param_dict)  # 将检查点参数加载到模型中

# 定义学习率
lr = nn.cosine_decay_lr(min_lr=float(0),  # 最小学习率
                        max_lr=0.00005,  # 最大学习率
                        total_step=epoch_size * step_size,  # 总步数
                        step_per_epoch=step_size,  # 每个epoch的步数
                        decay_epoch=10)  # 衰减周期

# 定义优化器
network_opt = nn.Adam(network.trainable_params(), lr, momentum)  # Adam优化器

# 定义损失函数
class CrossEntropySmooth(LossBase):
    """CrossEntropy with Label Smoothing."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()  # 调用父类构造函数
        self.onehot = ops.OneHot()  # 实例化OneHot操作
        self.sparse = sparse  # 是否使用稀疏标签
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)  # 标签平滑后的正值
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)  # 标签平滑后的负值
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)  # 交叉熵损失

    def construct(self, logit, label):
        # 计算损失
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)  # 将标签转为one-hot格式
        loss = self.ce(logit, label)  # 计算交叉熵损失
        return loss  # 返回损失值


network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)  # 实例化平滑交叉熵损失

# 设置检查点
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)  # 检查点配置
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)  # 检查点回调

# 初始化模型
# "Ascend + mixed precision" 可以提高性能
ascend_target = (ms.get_context("device_target") == "Ascend")  # 判断是否在Ascend设备上
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")  # 使用混合精度
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")  # 不使用混合精度

# 训练模型
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],  # 添加监控回调
            dataset_sink_mode=False)  # 不使用数据集沉浸模式
  1. 导入模块:
    • 从MindSpore导入所需的库和模块,包含损失函数基类、训练监控、检查点配置等功能。
  2. 超参数设置:
    • 定义训练的基本参数,包括训练轮数、动量、类别数、图像大小和每个epoch的步数。
  3. 模型构建:
    • 实例化ViT模型。
  4. 加载预训练模型权重:
    • 指定预训练模型的URL和本地保存路径,下载并加载检查点参数。
  5. 学习率设置:
    • 使用余弦衰减策略定义学习率,设置最小和最大学习率、总步数和衰减周期。
  6. 优化器定义:
    • 使用Adam优化器,传入模型的可训练参数和学习率。
  7. 损失函数CrossEntropySmooth:
    • 定义带标签平滑的交叉熵损失。根据参数选择稀疏标签或one-hot标签,计算损失值。
  8. 检查点配置:
    • 设置检查点的保存步骤和最大保存数量,实例化检查点回调。
  9. 模型初始化:
    • 根据设备类型选择是否使用混合精度。创建train.Model对象。
  10. 训练模型:
    • 调用模型的train方法,进行训练,设置监控回调和数据集沉浸模式。
  • LossBase:
    • MindSpore中所有损失函数的基类,提供了构建自定义损失函数的基础。
  • nn.Adam:
    • Adam优化器,用于模型参数的优化。
  • ops.OneHot:
    • 操作用于将标签转换为one-hot编码格式。
  • nn.SoftmaxCrossEntropyWithLogits:
    • 计算softmax交叉熵损失的函数。
  • CheckpointConfig:
    • 检查点配置类,用于定义保存检查点的策略。
  • ModelCheckpoint:
    • 用于创建模型检查点保存的回调。
  • train.Model:
    • 用于将模型、损失函数和优化器封装在一起,方便进行训练、评估等操作。

通过这样的设计,代码实现了一个完整的Vision Transformer(ViT)模型的训练流程,包括模型构建、权重加载、优化器设置、损失计算与监控。

模型验证

模型验证过程主要应用了ImageFolderDataset,CrossEntropySmooth和Model等接口。
ImageFolderDataset主要用于读取数据集。
CrossEntropySmooth是损失函数实例化接口。
Model主要用于编译模型。
与训练过程相似,首先进行数据增强,然后定义ViT网络结构,加载预训练模型参数。随后设置损失函数,评价指标等,编译模型后进行验证。本案例采用了业界通用的评价标准Top_1_Accuracy和Top_5_Accuracy评价指标来评价模型表现。
在本案例中,这两个指标代表了在输出的1000维向量中,以最大值或前5的输出值所代表的类别为预测结果时,模型预测的准确率。这两个指标的值越大,代表模型准确率越高。

# 导入必要的库
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)  # 加载验证集,随机打乱数据

# 定义验证集的预处理转换
trans_val = [
    transforms.Decode(),  # 解码图像
    transforms.Resize(224 + 32),  # 将图像调整为尺寸为256(224 + 32)
    transforms.CenterCrop(224),  # 进行中心裁剪,得到224x224的图像
    transforms.Normalize(mean=mean, std=std),  # 进行归一化处理,使用指定的均值和标准差
    transforms.HWC2CHW()  # 将图像的形状从HWC转换为CHW(通道优先)
]

# 应用转换到数据集
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])  # 对图像列进行转换
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)  # 将数据集按批处理,批大小为16

# 构造模型
network = ViT()  # 实例化ViT模型

# 加载检查点
param_dict = ms.load_checkpoint(vit_path)  # 加载预训练模型的参数
ms.load_param_into_net(network, param_dict)  # 将参数加载到网络中

# 定义损失函数
network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)  # 实例化平滑交叉熵损失

# 定义评估指标
eval_metrics = {
    'Top_1_Accuracy': train.Top1CategoricalAccuracy(),  # 计算Top-1准确率
    'Top_5_Accuracy': train.Top5CategoricalAccuracy()   # 计算Top-5准确率
}

# 初始化模型
if ascend_target:  # 判断是否在Ascend设备上
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")  # 使用混合精度
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")  # 不使用混合精度

# 评估模型
result = model.eval(dataset_val)  # 在验证集上评估模型
print(result)  # 打印评估结果
  1. 加载验证集:
    • 通过ImageFolderDataset加载验证集数据,指定路径并设置为随机打乱,便于后续训练或评估。
  2. 数据预处理:
    • transforms.Decode(): 解码图像数据。
    • transforms.Resize(224 + 32): 将图像大小调整为256,以保证中心裁剪的效果。
    • transforms.CenterCrop(224): 从中心裁剪出224x224的图像,以适应模型的输入尺寸。
    • transforms.Normalize(mean=mean, std=std): 对图像进行标准化,使用指定的均值和标准差进行归一化处理。
    • transforms.HWC2CHW(): 将图像的通道顺序从高度-宽度-通道(HWC)转换为通道-高度-宽度(CHW),以匹配模型输入格式。
  3. 应用转换:
    • 使用map方法将定义好的转换操作应用到数据集的图像列,并使用batch方法将数据集按批次处理,设置批大小为16。
  4. 模型构建:
    • 实例化ViT模型。
  5. 加载预训练权重:
    • 使用load_checkpoint方法加载预训练模型的参数,并将其导入到网络中。
  6. 定义损失函数:
    • 实例化CrossEntropySmooth,设置稀疏标签、损失计算的方式、平滑因子和类别数量。
  7. 评估指标定义:
    • 使用MindSpore内置的Top-1和Top-5分类准确率计算指标,便于后续评估模型性能。
  8. 模型初始化:
    • 根据设备类型选择是否使用混合精度,构造train.Model对象,准备进行评估。
  9. 模型评估:
    • 使用eval方法在验证集上评估模型性能,并将结果打印出来。
  • ImageFolderDataset:
    • 通过指定的文件夹路径加载图像数据集,适用于图像分类任务。
  • transforms:
    • 数据预处理模块,包含多种图像处理操作,如解码、调整大小、裁剪和归一化。
  • dataset.map:
    • 对数据集中的指定列应用转换操作。
  • dataset.batch:
    • 将数据集按指定的批量大小分割成多个小批次。
  • train.Model:
    • MindSpore中用于训练和评估模型的高层接口。
  • model.eval:
    • 在指定的数据集上评估模型的性能,返回评估结果。

通过这样的设计,代码实现了在验证集上评估Vision Transformer(ViT)模型的功能,包括数据加载、预处理、模型构建、权重加载和性能评估。

{'Top_1_Accuracy': 0.75, 'Top_5_Accuracy': 0.928}

从结果可以看出,由于我们加载了预训练模型参数,模型的Top_1_Accuracy和Top_5_Accuracy达到了很高的水平,实际项目中也可以以此准确率为标准。如果未使用预训练模型参数,则需要更多的epoch来训练。

模型推理

在进行模型推理之前,首先要定义一个对推理图片进行数据预处理的方法。该方法可以对我们的推理图片进行resize和normalize处理,这样才能与我们训练时的输入数据匹配。
本案例采用了一张Doberman的图片作为推理图片来测试模型表现,期望模型可以给出正确的预测结果。

# 导入必要的库
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)  # 加载推理数据集,随机打乱数据

# 定义推理数据集的预处理转换
trans_infer = [
    transforms.Decode(),  # 解码图像
    transforms.Resize([224, 224]),  # 调整图像大小为224x224
    transforms.Normalize(mean=mean, std=std),  # 进行归一化处理,使用指定的均值和标准差
    transforms.HWC2CHW()  # 将图像的形状从HWC转换为CHW(通道优先)
]

# 应用转换到推理数据集
dataset_infer = dataset_infer.map(operations=trans_infer,  # 对图像列进行转换
                                   input_columns=["image"],
                                   num_parallel_workers=1)  # 设置并行工作者数量为1
dataset_infer = dataset_infer.batch(1)  # 将推理数据集按批处理,批大小为1
  1. 加载推理数据集:
    • 使用ImageFolderDataset加载推理集数据,指定路径并设置为随机打乱,以增加数据的多样性。
  2. 数据预处理:
    • transforms.Decode(): 解码图像,准备进行后续处理。
    • transforms.Resize([224, 224]): 将图像调整为224x224的大小,以符合模型输入要求。
    • transforms.Normalize(mean=mean, std=std): 对图像进行归一化处理,使用给定的均值和标准差,以适应模型的训练标准。
    • transforms.HWC2CHW(): 将图像的维度从高度-宽度-通道(HWC)转换为通道-高度-宽度(CHW),以符合模型输入格式。
  3. 应用转换:
    • 使用map方法将定义好的转换操作应用到推理数据集的图像列,指定使用一个工作进程进行处理(num_parallel_workers=1)。
  4. 批处理:
    • 使用batch(1)方法将推理数据集按批次处理,设置每个批次大小为1,以方便逐个图像进行推理。
  • ImageFolderDataset:
    • 通过指定的文件夹路径加载图像数据集,适用于图像分类、推理等任务。
  • transforms:
    • 数据预处理模块,包含多种图像处理操作,如解码、调整大小、归一化和通道转换。
  • dataset.map:
    • 对数据集中的指定列应用转换操作,可以通过设置num_parallel_workers控制并行处理的工作者数量。
  • dataset.batch:
    • 将数据集按指定的批量大小分割成多个小批次,以便于后续模型推理或训练。

通过这样的设计,代码实现了对推理数据集的准备,包括加载、预处理和批处理,适用于后续的模型推理操作。

接下来,我们将调用模型的predict方法进行模型。
在推理过程中,通过index2label就可以获取对应标签,再通过自定义的show_result接口将结果写在对应图片上。

import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io


class Color(Enum):
    """定义颜色枚举."""
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)


def check_file_exist(file_name: str):
    """检查文件是否存在."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")


def color_val(color):
    """获取颜色值."""
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    if isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    if isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    raise TypeError(f'Invalid type for color: {type(color)}')


def imread(image, mode=None):
    """读取图像."""
    if isinstance(image, pathlib.Path):
        image = str(image)

    if isinstance(image, np.ndarray):
        pass  # 如果是ndarray,直接使用
    elif isinstance(image, str):
        check_file_exist(image)  # 检查文件是否存在
        image = Image.open(image)  # 打开图像文件
        if mode:
            image = np.array(image.convert(mode))  # 转换为指定模式
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")

    return image


def imwrite(image, image_path, auto_mkdir=True):
    """保存图像."""
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(image_path))
        if dir_name != '':
            dir_name = os.path.expanduser(dir_name)
            os.makedirs(dir_name, mode=0o777, exist_ok=True)  # 创建目录

    image = Image.fromarray(image)  # 将numpy数组转换为Image对象
    image.save(image_path)  # 保存图像


def imshow(img, win_name='', wait_time=0):
    """显示图像."""
    cv2.imshow(win_name, imread(img))  # 使用OpenCV显示图像
    if wait_time == 0:  # 防止窗口关闭时程序挂起
        while True:
            ret = cv2.waitKey(1)

            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            if closed or ret != -1:  # 如果窗口被关闭或按下任意键
                break
    else:
        ret = cv2.waitKey(wait_time)  # 等待指定时间


def show_result(img: str,
                result: dict,
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """在图像上标记预测结果."""
    img = imread(img, mode="RGB")
    img = img.copy()
    x, y = 0, row_width
    text_color = color_val(text_color)
    for k, v in result.items():
        if isinstance(v, float):
            v = f'{v:.2f}'
        label_text = f'{k}: {v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)
        y += row_width
    if out_file:
        show = False
        imwrite(img, out_file)  # 保存结果图像

    if show:
        imshow(img, win_name, wait_time)  # 显示结果图像


def index2label():
    """返回ImageNet数据集的类别映射字典."""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']

    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]

    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])

    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping


# 读取推理数据
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]  # 提取图像数据
    image = ms.Tensor(image)  # 转换为Tensor
    prob = model.predict(image)  # 进行预测
    label = np.argmax(prob.asnumpy(), axis=1)  # 获取预测标签
    mapping = index2label()  # 获取类别映射
    output = {int(label): mapping[int(label)]}  # 构建输出信息
    print(output)  # 打印输出

    show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
                result=output,
                out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")  # 显示并保存结果图像
  1. 导入模块:
    • 导入所需的库,包括文件操作、图像处理、数组操作、枚举类和MATLAB数据读取。
  2. 定义颜色枚举:
    • 使用Enum定义常用颜色的枚举,方便在代码中引用。
  3. 文件存在性检查:
    • check_file_exist函数检查指定文件是否存在,如果不存在则抛出异常。
  4. 颜色值转换:
    • color_val函数根据输入类型返回对应的颜色值,支持字符串、枚举、元组、整数和NumPy数组。
  5. 图像读取:
    • imread函数支持从文件路径、NumPy数组或Path对象中读取图像,并可选择性地转换图像模式。
  6. 图像保存:
    • imwrite函数将NumPy数组转换为图像并保存,支持自动创建保存目录。
  7. 图像显示:
    • imshow函数使用OpenCV展示图像,处理窗口关闭的情况。
  8. 结果显示:
    • show_result函数在图像上标记预测结果,并可选择保存或显示结果图像。
  9. 类别映射:
    • index2label函数读取ImageNet的元数据,构建类别索引与名称的映射字典。
  10. 推理过程:
    • 遍历推理数据集,提取图像数据,进行预测,获取预测标签,并使用映射字典生成输出结果。最后,调用show_result显示并保存结果图像。
  • os:
    • 提供与操作系统交互的功能,如文件和目录的操作。
  • pathlib:
    • 提供面向对象的文件路径处理方式。
  • cv2:
    • OpenCV库,用于图像处理和计算机视觉任务。
  • numpy:
    • 支持大规模多维数组与矩阵运算的库,提供许多高级数学函数。
  • PIL.Image:
    • Python Imaging Library,用于图像创建、打开、处理和保存。
  • Enum:
    • 枚举类,用于定义一组命名的常量。
  • scipy.io:
    • SciPy库中的输入输出模块,用于加载MATLAB文件。

通过这样的设计,代码实现了图像处理、推理和结果标记的完整流程,适用于图像分类和可视化任务。

{236: 'Doberman'}

推理过程完成后,在推理文件夹下可以找到图片的推理结果,可以看出预测结果是Doberman,与期望结果相同,验证了模型的准确性。

总结

本案例完成了一个ViT模型在ImageNet数据上进行训练,验证和推理的过程,其中,对关键的ViT模型结构和原理作了讲解。通过学习本案例,理解源码可以帮助用户掌握Multi-Head Attention,TransformerEncoder,pos_embedding等关键概念,如果要详细理解ViT的模型原理,建议基于源码更深层次的详细阅读。

整体代码

#!/usr/bin/env python
# coding: utf-8

# # Vision Transformer图像分类

# ## Vision Transformer(ViT)简介
# 
# ViT是自然语言处理和计算机视觉两个领域的融合结晶,在不依赖卷积操作的情况下,能在图像分类任务上取得优异效果。

# In[1]:
from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"

path = download(dataset_url, path, kind="zip", replace=True)

# In[ ]:
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms

data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

# 加载训练集
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

# 定义训练集的数据增强转换
trans_train = [
    transforms.RandomCropDecodeResize(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

# 应用转换并批处理
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

# ## 模型解析
# Transformer模型的基本原理是多头注意力(Multi-Head Attention)机制。

# In[ ]:
from mindspore import nn, ops

class Attention(nn.Cell):
    def __init__(self, dim: int, num_heads: int = 8, keep_prob: float = 1.0, attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)

        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0 - attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0 - keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        """Attention构造函数."""
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)

        return out

# ### Transformer Encoder
# 将Self-Attention与Feed Forward,Residual Connection结合形成Transformer的基础结构。

# In[ ]:
from typing import Optional

class FeedForward(nn.Cell):
    def __init__(self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, activation: nn.Cell = nn.GELU, keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0 - keep_prob)

    def construct(self, x):
        """Feed Forward构造函数."""
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)

        return x

class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        """ResidualCell构造函数."""
        return self.cell(x) + x

class TransformerEncoder(nn.Cell):
    def __init__(self, dim: int, num_layers: int, num_heads: int, mlp_dim: int, keep_prob: float = 1., attention_keep_prob: float = 1.0, activation: nn.Cell = nn.GELU, norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []

        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim, num_heads=num_heads, keep_prob=keep_prob, attention_keep_prob=attention_keep_prob)

            feedforward = FeedForward(in_features=dim, hidden_features=mlp_dim, activation=activation, keep_prob=keep_prob)

            layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])), ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))

        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        """Transformer构造函数."""
        return self.layers(x)

# ### ViT模型的输入
# 将输入图像划分为patch并进行嵌入。

# In[ ]:
class PatchEmbedding(nn.Cell):
    def __init__(self, image_size: int = 224, patch_size: int = 16, embed_dim: int = 768, input_channels: int = 3):
        super(PatchEmbedding, self).__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):
        """Patch Embedding构造函数."""
        x = self.conv(x)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))

        return x

# ### 整体构建ViT
# 以下代码构建了一个完整的ViT模型。

# In[ ]:
from mindspore.common.initializer import Normal, initializer
from mindspore import Parameter

def init(init_type, shape, dtype, name, requires_grad):
    """初始化函数."""
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)

class ViT(nn.Cell):
    def __init__(self, image_size: int = 224, input_channels: int = 3, patch_size: int = 16, embed_dim: int = 768, num_layers: int = 12, num_heads: int = 12, mlp_dim: int = 3072, keep_prob: float = 1.0, attention_keep_prob: float = 1.0, activation: nn.Cell = nn.GELU, norm: Optional[nn.Cell] = nn.LayerNorm, pool: str = 'cls') -> None:
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size, patch_size=patch_size, embed_dim=embed_dim, input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        self.cls_token = init(init_type=Normal(sigma=1.0), shape=(1, 1, embed_dim), dtype=ms.float32, name='cls', requires_grad=True)

        self.pos_embedding = init(init_type=Normal(sigma=1.0), shape=(1, num_patches + 1, embed_dim), dtype=ms.float32, name='pos_embedding', requires_grad=True)

        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0 - keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim, num_layers=num_layers, num_heads=num_heads, mlp_dim=mlp_dim, keep_prob=keep_prob, attention_keep_prob=attention_keep_prob, activation=activation, norm=norm)
        self.dropout = nn.Dropout(p=1.0 - keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        """ViT构造函数."""
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)

        return x

# ## 模型训练与推理

# ### 模型训练
# 首先设定损失函数,优化器等。

# In[ ]:
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train

# 定义超参数
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()

# 构造模型
network = ViT()

# 加载检查点
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"

vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# 定义学习率
lr = nn.cosine_decay_lr(min_lr=float(0), max_lr=0.00005, total_step=epoch_size * step_size, step_per_epoch=step_size, decay_epoch=10)

# 定义优化器
network_opt = nn.Adam(network.trainable_params(), lr, momentum)

# 定义损失函数
class CrossEntropySmooth(LossBase):
    """平滑交叉熵损失."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss

network_loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=0.1, num_classes=num_classes)

# 设置检查点
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# 初始化模型
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")

# 训练模型
model.train(epoch_size, dataset_train, callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)], dataset_sink_mode=False)

# ### 模型验证
# 验证过程主要应用了ImageFolderDataset,CrossEntropySmooth和Model等接口。

# In[ ]:
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)

trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)

# 构造模型
network = ViT()

# 加载检查点
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

network_loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=0.1, num_classes=num_classes)

# 定义评价指标
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}

if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")

# 评估模型
result = model.eval(dataset_val)
print(result)

# ### 模型推理
# 定义推理图片的数据预处理方法。

# In[ ]:
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)

trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_infer = dataset_infer.map(operations=trans_infer, input_columns=["image"], num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

# 自定义推理结果显示函数

import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io


class Color(Enum):
    """定义颜色枚举."""
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)


def check_file_exist(file_name: str):
    """检查文件是否存在."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")


def color_val(color):
    """获取颜色值."""
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    if isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    if isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    raise TypeError(f'Invalid type for color: {type(color)}')


def imread(image, mode=None):
    """读取图像."""
    if isinstance(image, pathlib.Path):
        image = str(image)

    if isinstance(image, np.ndarray):
        pass
    elif isinstance(image, str):
        check_file_exist(image)
        image = Image.open(image)
        if mode:
            image = np.array(image.convert(mode))
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")

    return image


def imwrite(image, image_path, auto_mkdir=True):
    """保存图像."""
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(image_path))
        if dir_name != '':
            dir_name = os.path.expanduser(dir_name)
            os.makedirs(dir_name, mode=0o777, exist_ok=True)

    image = Image.fromarray(image)
    image.save(image_path)


def imshow(img, win_name='', wait_time=0):
    """显示图像."""
    cv2.imshow(win_name, imread(img))
    if wait_time == 0:  # 防止窗口关闭时程序挂起
        while True:
            ret = cv2.waitKey(1)

            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            if closed or ret != -1:  # 如果窗口被关闭或按下任意键
                break
    else:
        ret = cv2.waitKey(wait_time)  # 等待指定时间


def show_result(img: str,
                result: dict,
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """在图像上标记预测结果."""
    img = imread(img, mode="RGB")
    img = img.copy()
    x, y = 0, row_width
    text_color = color_val(text_color)
    for k, v in result.items():
        if isinstance(v, float):
            v = f'{v:.2f}'
        label_text = f'{k}: {v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)
        y += row_width
    if out_file:
        show = False
        imwrite
(img, out_file)

    if show:
        imshow(img, win_name, wait_time)

def index2label():
    """返回ImageNet数据集的类别映射字典."""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']

    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]

    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])

    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping

# 读取推理数据
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]
    image = ms.Tensor(image)  # 转换为Tensor
    prob = model.predict(image)  # 进行模型预测
    label = np.argmax(prob.asnumpy(), axis=1)  # 获取预测标签
    mapping = index2label()  # 获取类别映射
    output = {int(label): mapping[int(label)]}  # 构建输出信息
    print(output)  # 打印输出结果

    # 显示推理结果
    show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
                result=output,
                out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")
  1. 显示推理结果的函数:
    • show_result函数在指定的图片上标记出预测结果,并可以选择保存结果图像。
  2. 类别映射的获取:
    • index2label函数从存储的meta.mat文件中提取类别名称与索引的映射关系,方便在推理时将模型输出的标签转换为具体的类别名称。
  3. 推理过程:
    • 通过遍历推理数据集,提取每张图像,转换为适合模型输入的Tensor格式。
    • 使用模型的predict方法进行预测,得到的概率分布通过argmax获取预测的类别标签。
    • 通过index2label函数将标签转换为类别名称,并打印输出结果。
    • 调用show_result函数在图像上显示预测结果,并选择是否保存结果图像。
  • ImageFolderDataset:
    • 用于加载图像数据集,支持从文件夹中读取图像及其标签。
  • opsnn:
    • MindSpore中的基本运算和神经网络层构建模块。
  • Tensor:
    • 用于表示和操作多维数组的数据结构,支持GPU加速。
  • argmax:
    • NumPy函数,用于返回最大值的索引,常用于分类模型的结果处理。
  • imshowimreadimwrite:
    • 用于图像的读取、显示和保存,常用的图像处理函数。

通过这个完整的流程,展示了ViT模型在图像分类任务上的应用,包括数据准备、模型训练和推理等过程,帮助用户深入理解ViT的实现和使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1939284.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

渗透测试靶机---Kioptrix5

渗透测试靶机—Kioptrix5 启动靶机&#xff0c;扫描ip&#xff0c;平平无奇 扫描 惯例&#xff0c;访问80&#xff0c;先看看 好像是没什么内容&#xff0c;查看页面源代码 搜素这个页面的框架&#xff1a; 直接拉下来查看就行 这里存在一个路径穿越 这里就暴露出来了更…

单片机原理及技术(五)—— 单片机与开关、键盘以及显示器件的接口设计(C51编程)

目录 一、单片机控制发光二极管显示 1.1 单片机与发光二极管的连接 1.2 拉电流和灌电流 1.3 I/O端口的编程控制 二、开关状态检测 2.1 开关控制单个LED灯亮灭 三、单片机控制LED数码管的显示 3.1 LED数码管的显示原理 3.2 LED数码管的静态显示与动态显示 3.2.1 静态显…

用不同的url头利用Python访问一个网站,把返回的东西保存为txt文件

这个需要调用requests模块&#xff08;相当于c的头文件&#xff09; import requests 还需要一个User-Agent头&#xff08;这个意思就是告诉python用的什么系统和浏览器&#xff09; Google Chrome&#xff08;Windows&#xff09;: Mozilla/5.0 (Windows NT 10.0; Win64; x64…

软件质量模型、生命周期模型、测试过程模型

目录 测试用例 定义 常见测试用例的核心8要素 软件质量模型 质量模型标准 软件开发过程模型&#xff08;软件生命周期模型&#xff09; 瀑布模型 软件测试过程模型 V模型 W模型 测试用例 定义 测试用例&#xff0c;也叫Test Case&#xff0c;为了特定的目的而设计…

【Apache Doris】数据副本问题排查指南

【Apache Doris】数据副本问题排查指南 一、问题现象二、问题定位三、问题处理 本文主要分享Doris中数据副本异常的问题现象、问题定位以及如何处理此类问题。 一、问题现象 问题日志 查询报错 Failed to initialize storage reader, tablet{tablet_id}.xxx.xxx问题说明 查…

使用Python的Turtle库绘制草莓熊

引言 Turtle库是Python标准库中一个非常有趣且实用的模块&#xff0c;它主要用于绘制图形和动画。Turtle图形学源于Logo语言&#xff0c;是一种基于命令的绘图方式。通过控制一个名为“海龟”的虚拟角色&#xff0c;在屏幕上移动和绘制&#xff0c;Turtle库可以轻松地教授基础…

IDEA工具中Java语言写小工具遇到的问题

一&#xff1a;读取excel时遇到 org/apache/poi/ss/usermodel/WorkbookProvider 解决办法&#xff1a; 在pom.xml中把poi的引文包放在最前面即可&#xff08;目前就算放在最后面也不报错了&#xff0c;不知道为啥&#xff09; 二&#xff1a;本地maven打包时&#xff0c;没有…

React基础学习-Day08

React基础学习-Day08 React生命周期&#xff08;旧&#xff09;&#xff08;新&#xff09;&#xff08;函数组件&#xff09; &#xff08;旧&#xff09; 在 React 16 版本之前&#xff0c;React 使用了一套不同的生命周期方法。这些生命周期方法在 React 16 中仍然可以使用…

【人工智能】Python实现文本转换为语音:使用gTTS库实现

&#x1f525; 个人主页&#xff1a;空白诗 文章目录 一、引言二、准备工作三、使用gTTS实现文本转换为语音详细步骤 四、人工智能与TTS技术五、总结 一、引言 文本转换为语音&#xff08;Text-to-Speech&#xff0c;简称TTS&#xff09;技术是人工智能的重要组成部分&#xf…

2024年7月萤火虫航天为NASA发射8颗立方体卫星

作为美国宇航局立方体卫星发射计划的一部分&#xff0c;萤火虫航空航天公司于7月3日在该公司的阿尔法火箭上发射了八颗小型卫星。这枚名为“夏日噪音”的火箭于太平洋夏令时&#xff08;PDT&#xff09;晚上9点04分从加利福尼亚州范登堡空军基地的2号航天发射场成功升空。 立方…

SpringBoot整合SSE,实现后端主动推送DEMO

前言 说起服务端主动推送&#xff0c;大家第一个想到的一定是WEBSOCKET 。 作为软件工程师&#xff0c;不能无脑使用一种技术&#xff0c;要结合实际情况&#xff0c;择优选取。 SSE&#xff08;Server-Sent Events&#xff09;相比于WEBSOCKET 1、轻量化、兼容性 基于传统…

Mac装虚拟机占内存吗 Mac用虚拟机装Windows流畅吗

如今&#xff0c;越来越多的Mac用户选择在他们的设备上安装虚拟机来运行不同的操作系统。其中&#xff0c;最常见的是使用虚拟机在Mac上运行Windows。然而&#xff0c;许多人担心在Mac上装虚拟机会占用大量内存&#xff0c;影响电脑系统性能。此外&#xff0c;有些用户还关心在…

抖音火爆 百度地图导航高阶定制茉莉13个语音包附带安装教程,开车再也不会犯困了

慎用&#xff0c;慎用&#xff01; 1、工具下载&#xff1a; 百度导航高阶定制茉莉13个语音包https://pan.quark.cn/s/8669c1dad02a下载 | MT管理器&#xff1a;https://pan.quark.cn/s/b7b8e8f16326 2、语音包路径&#xff1a; 百度导航路径&#xff1a; /storage/emulate…

【LLM】-04-提示工程 - 文本转换

目录 1、文本翻译 1.1、翻译为德语 1.2、识别语种 1.3、多语言翻译 1.4、同时进行语气转换 1.5、通用翻译器 2、语气与写作风格调整 3、数据格式转换 4、拼写及语法纠正 5、综合样例 大语言模型具有强大的文本转换能力&#xff0c;可以实现多语言翻译、拼写纠正、语法…

数据结构 - 栈(精简介绍)

文章目录 普通栈Stack用法Q 最长有效括号 单调栈Q 接雨水 普通栈 栈就是一个先进后出的结构 想象一个容器&#xff0c;往里面一层一层放东西&#xff0c;最早放进去的东西被压在下面&#xff08;所以放元素也叫压栈&#xff09;&#xff0c;要拿到这个最低层的东西需要先把上面…

异步电机矢量控制matlab simulink

1、内容简介 略 86-可以交流、咨询、答疑 异步电机、矢量控制 2、内容说明 略 3、仿真分析 略 4、参考论文 略

[Python库](3) Arrow库

目录 1.简介 2.安装 3.函数 3.1.获取当前UTC时间( 世界协调时时间 ) 3.2.格式化日期 3.3.创建Arrow对象 3.4.时间改变 3.5.获取时间戳 3.6.时区改变 4.小结 1.简介 Arrow库是一个Python库&#xff0c;提供了一套用于处理日期和时间的API。Arrow库特别适合在需要进行大…

C++搜索算法(dfs)

目录 一.dfs简介 二.dfs的运用 1.迷宫问题 经典题型&#xff1a;最快走出迷宫 题目描述&#xff1a; 数据范围&#xff1a; 题目分析&#xff1a; 正确代码 2.棋盘问题&#xff1a; 经典题型&#xff1a;八皇后问题 题目描述&#xff1a; 题目分析&#xff1a; 正…

微服务实战系列之玩转Docker(五)

前言 在我们日常的工作生活中&#xff0c;经常听到的一句话&#xff1a;“是骡子是马拉出来遛遛”。目的是看一个人/物是不是名副其实。我们在使用docker时&#xff0c;也要看看它究竟是如何RUN起来的。当面试官问你的时候&#xff0c;可以如是回答&#xff0c;保你“一文通关…

SQUID - 形状条件下的基于分子片段的3D分子生成等变模型 评测

SQUID 是一个形状条件下基于片段的3D分子生成模型&#xff0c;给一个3D参考分子&#xff0c;SQUID 可以根据参考分子的形状&#xff0c;基于片段库&#xff0c;生成与参考分子形状非常相似的分子。 SQUID 模型来自于 ICLR 2023 文章&#xff08;2022年10月6日提交&#xff09;&…