MindSpore Study Notes 6-06: Computer Vision -- Vision Transformer Image Classification


Abstract:

        These notes record the process and methods of training, validating, and running inference with a ViT model on ImageNet image classification using the MindSpore AI framework, covering environment setup, dataset download and loading, model analysis and construction, and model training and inference.

I. Overview

1. The ViT Model

ViT (Vision Transformer) is a model built on the self-attention (Self-Attention) structure. The underlying Transformer architecture can train models with more than 100B parameters and is applied in both natural language processing and computer vision. Unlike CNN classifiers, ViT does not rely on convolution operations to model the image.

2. Model Structure

The main body of the ViT model, read from bottom to top in the architecture figure (not included in these notes):

At the bottom, the input: the original image is divided into multiple patches (image blocks), and each two-dimensional patch (channel dimension aside) is flattened into a one-dimensional vector.

In the middle, the backbone: based on the Encoder part of the Transformer, including the Multi-head Attention structure, with the order of some sub-structures adjusted (the position of Normalization differs from the original Transformer).

At the top, the stacked Blocks are followed by a fully connected Head; together with a class vector appended to the input, it outputs the recognized classification result. A quick arithmetic check of the patch split follows below.
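
As a quick sanity check of the patch arithmetic described above, here is a minimal sketch in plain Python (the sizes follow the ViT-B/16 defaults used later in these notes; nothing here is MindSpore-specific):

# Sketch: patch arithmetic for the default 224x224 / 16x16 configuration
image_size = 224      # input image height and width
patch_size = 16       # patch height and width
channels = 3          # RGB input

num_patches = (image_size // patch_size) ** 2        # 14 * 14 = 196 patches
flat_patch_dim = patch_size * patch_size * channels  # 16 * 16 * 3 = 768 values per patch

print(num_patches, flat_patch_dim)  # 196 768

Note that the flattened patch length (768) coincides with the embed_dim used by the PatchEmbedding layer later.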

II. Environment Setup

Make sure a Python environment and MindSpore are installed.

%%capture captured_output
# mindspore==2.2.14 is pre-installed in the lab environment; change the version number below to switch versions
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# check the current mindspore version
!pip show mindspore

Output:

Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 

III. Data Preparation

1. Download and Extract the Dataset

Download source: the ImageNet dataset

http://image-net.org

The dataset used in this case is a subset filtered from ImageNet.

from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"

path = download(dataset_url, path, kind="zip", replace=True)

Output:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip (489.1 MB)

file_sizes: 100%|█████████████████████████████| 513M/513M [00:02<00:00, 228MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

2. Dataset Directory Structure

./dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/

3. Load the Dataset

import os

import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms


data_path = './dataset/'
# Normalize runs on raw 0-255 pixel values here, so the standard ImageNet
# mean/std are scaled by 255
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)
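
A minimal sanity check of the pipeline (a sketch, not part of the original notes: it pulls one batch through the augmentations, which is harmless for a quick inspection):

# Sketch: inspect one augmented training batch
images, labels = next(dataset_train.create_tuple_iterator())
print(images.shape)  # expected: (16, 3, 224, 224) after HWC2CHW and batching
print(labels.shape)  # expected: (16,)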

IV. Model Analysis

1. Transformer Fundamentals

The Transformer is an encoder-decoder architecture based on the Attention mechanism. Its overall structure (model structure figure not included in these notes) is composed of multiple Encoder and Decoder modules.

Each Encoder and Decoder (detailed structure figure not included) is in turn composed of:

a Multi-Head Attention layer

        based on the Self-Attention mechanism

        made up of several Self-Attention heads running in parallel

a Feed Forward layer

Normalization layers

residual connections (Residual Connection), the "Add" blocks in the figure

2. The Attention Module

The core of Self-Attention is learning a weight for each element of the input sequence:

        given a query vector Query,

        compute the similarity (relevance) between the Query and each Key,

                which gives the attention distribution,

                i.e. a weight coefficient for the Value attached to each Key;

        the final Attention value is the weighted sum of the Values.

How the Self-Attention mechanism works:

(1) The initial input vectors

After the Embedding layer, each input is mapped to a vector of size dim x 3 and split into three vectors: Q (Query), K (Key), and V (Value). For an input sequence of one-dimensional vectors (x1, x2, x3), every vector is mapped by the Embedding layer into its own Q, K, and V; only the Embedding matrices differ, and the matrix parameters are learned during training. The relations between vectors can then be computed through the Q, K, and V matrices: the dot product of two of them (Q and K) produces the weights, while the third (V) carries the weighted-sum result.

(2) Where the "self" in self-attention shows up

Q, K, and V are all derived from the input itself. The self-attention process extracts the relations and features between input vectors at different positions, and how strongly two positions are related is expressed by the result of the Q-K product passed through Softmax. The weights over the Q, K, V vectors are obtained by taking the dot product of Q and K, dividing by the square root of the dimension, and applying Softmax across the results for all positions; in formula form, Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V.

(3) Global self-attention

The Softmax result of Q and K is combined with the vector V as a weighted sum. Every group of Q, K, V produces one output vector, so the result for the current position already folds in its learned relevance weights to every other position. (The full Self-Attention process is shown as a figure in the original notes, not included here.)
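
The three steps above map directly to code. The following is a minimal NumPy sketch of scaled dot-product attention (the shapes and random values are made up purely for illustration):

import numpy as np

def scaled_dot_product_attention(q, k, v):
    """softmax(Q K^T / sqrt(d)) V for a single sequence."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                            # pairwise Q-K similarity
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v                                       # weighted sum of the Values

# three tokens of dimension 4, random purely for illustration
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((3, 4)) for _ in range(3))
print(scaled_dot_product_attention(q, k, v).shape)  # (3, 4): one output per token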

Multi-Head Attention

Multi-head attention splits the vectors handled by self-attention into multiple Heads that are processed in parallel:

        parallel computation for speed,

        while the total parameter count stays unchanged.

The same query, key, and value are mapped into the high-dimensional space (Q, K, V), but into different subspaces (Q_0, K_0, V_0), where self-attention is computed separately; the attention information from the different subspaces is merged at the end. For the same input vector, the multiple attention heads run in parallel, and the vector's features are analyzed and exploited more thoroughly. In the figure (not included), ai and aj are obtained by splitting the same vector.

Below is the Multi-Head Attention code:

from mindspore import nn, ops

class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)

        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        """Attention construct."""
        b, n, c = x.shape
        qkv = self.qkv(x)                          # (b, n, 3*c)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))  # (3, b, heads, n, head_dim)
        q, k, v = ops.unstack(qkv, axis=0)
        attn = self.q_matmul_k(q, k)               # (b, heads, n, n)
        attn = ops.mul(attn, self.scale)           # scale by 1/sqrt(head_dim)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        out = self.attn_matmul_v(attn, v)          # weighted sum of the values
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))          # merge heads back to (b, n, c)
        out = self.out(out)
        out = self.out_drop(out)

        return out
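
A quick shape check of the block above, as a sketch assuming the Attention class just defined (768 = embed_dim, 197 = 196 patches + 1 class token, chosen to match the later ViT configuration):

# Sketch: Attention preserves the (batch, tokens, dim) shape
attn_block = Attention(dim=768, num_heads=8)
dummy = ops.ones((2, 197, 768), ms.float32)
print(attn_block(dummy).shape)  # expected: (2, 197, 768)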

Transformer Encoder

The basic Transformer structure is assembled by joining several sub-structures:

Self-Attention

Feed Forward

Residual Connection

Code for the Feed Forward and Residual Connection structures:

from typing import Optional, Dict

class FeedForward(nn.Cell):
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)

    def construct(self, x):
        """Feed Forward construct."""
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)

        return x


class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        """ResidualCell construct."""
        return self.cell(x) + x

The TransformerEncoder part of the ViT model is built from Self-Attention.

ViT differs from the original Transformer in one respect: Normalization is placed before Self-Attention and Feed Forward (pre-norm); the other structures are unchanged. (Transformer structure figure not included.)

Multiple sub-encoders are stacked to build the model's encoder; the ViT hyperparameter num_layers determines the number of stacked layers.

The Residual Connection and Normalization structures keep information from degrading as it passes through the deep layers and strengthen the model's ability to generalize.

The TransformerEncoder structure, combined with a multilayer perceptron (MLP), forms the backbone of the ViT model.

class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []

        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)

            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)

            # pre-norm residual blocks: x + attn(norm(x)), then x + ffn(norm(x))
            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        """Transformer construct."""
        return self.layers(x)
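
Because every sub-layer is wrapped in a shape-preserving residual, the stacked encoder maps (batch, tokens, dim) to the same shape. A sketch assuming the classes defined above (only 2 layers, to keep it cheap):

# Sketch: a small encoder leaves the sequence shape unchanged
encoder = TransformerEncoder(dim=768, num_layers=2, num_heads=8, mlp_dim=3072)
x = ops.ones((2, 197, 768), ms.float32)
print(encoder(x).shape)  # expected: (2, 197, 768)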

Input to the ViT Model

The traditional Transformer structure processes word vectors from natural language (Word Embedding or Word Vector). Word vectors stack into a one-dimensional sequence, whereas an image is a stack of two-dimensional matrices; when multi-head attention processes a stacked sequence of word vectors, it extracts the relations between them, i.e. the contextual semantics, so ViT needs an equivalent sequence form for images.

In the ViT model, each channel of the input image is divided into patches by a convolution with a 16 x 16 kernel and stride 16: a 224 x 224 input image yields a 14 x 14 output grid, i.e. (224 / 16)^2 = 196 patches, each covering a 16 x 16 region of the image. Flattening this grid, each patch becomes a one-dimensional vector, which approximates a stack of word vectors: the 14 x 14 grid is converted into a sequence of length 196. This is the first processing step an image goes through when it enters the network.

The Patch Embedding code:

class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4

    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):
        """Patch Embedding construct."""
        x = self.conv(x)                   # (b, embed_dim, 14, 14)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))  # flatten the 14x14 grid
        x = ops.transpose(x, (0, 2, 1))    # (b, 196, embed_dim)

        return x
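
A sketch of what PatchEmbedding does to a full-sized input (assumes the class above; the dummy tensor stands in for a preprocessed image):

# Sketch: a 224x224 RGB image becomes a sequence of 196 patch embeddings
patcher = PatchEmbedding(image_size=224, patch_size=16, embed_dim=768, input_channels=3)
img = ops.ones((1, 3, 224, 224), ms.float32)
print(patcher(img).shape)  # expected: (1, 196, 768)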

After the input image has been divided into patches, it goes through two further steps: pos_embedding and class_embedding.

class_embedding borrows from the BERT model's approach to text classification: one class token is prepended to the sequence of patch vectors, so the sequence of 196 vectors becomes 197. class_embedding is a learnable parameter; as the network trains, the output at the first position of the sequence comes to determine the final class, so classification uses only the class token's output rather than the 196 patch outputs.

pos_embedding is likewise a set of learnable parameters, added onto the patch matrix. Among the four candidate schemes for pos_embedding, ViT adopts the one-dimensional variant; and because class_embedding is prepended before pos_embedding is added, the pos_embedding length is the flattened patch count plus 1 (197). A shape sketch of these two steps follows below.
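
The two steps can be sketched with plain tensors (standalone illustration only; in the real model cls_token and pos_embedding are learnable Parameters, as the ViT class below shows):

# Sketch: prepend the class token, then add position embeddings
patches = ops.ones((1, 196, 768), ms.float32)          # output of PatchEmbedding
cls_token = ops.zeros((1, 1, 768), ms.float32)         # learnable in the real model
tokens = ops.concat((cls_token, patches), axis=1)      # (1, 197, 768)
pos_embedding = ops.zeros((1, 197, 768), ms.float32)   # num_patches + 1 positions
tokens = tokens + pos_embedding                        # shape unchanged
print(tokens.shape)  # (1, 197, 768)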

V. Building the Full ViT

Code to build the ViT model:

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameter


def init(init_type, shape, dtype, name, requires_grad):
    """Init."""
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)


class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls',
                 num_classes: int = 1000) -> None:
        # num_classes was read from a global defined later in the original notes;
        # it is made an explicit parameter here so the class is self-contained
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        """ViT construct."""
        x = self.patch_embedding(x)              # (b, 196, 768)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)  # prepend class token -> (b, 197, 768)
        x += self.pos_embedding                  # add position information

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]                              # keep only the class token's output
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)

        return x

(The overall pipeline figure from the original notes is not included here.)
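
Before training, the assembled model can be smoke-tested with a dummy batch, as a sketch assuming the ViT class above (the network is untrained, so the logits are meaningless; only the shape matters):

# Sketch: one forward pass through an untrained ViT
net_sketch = ViT()
dummy_batch = ops.ones((1, 3, 224, 224), ms.float32)
logits = net_sketch(dummy_batch)
print(logits.shape)  # expected: (1, 1000) -- one score per ImageNet class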

VI. Model Training and Inference

1. Model Training

Before training starts, set:

        the loss function,

        the optimizer,

        the callback functions,

and adjust epoch_size to fit the available time:

from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train

# define hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()

# construct model
network = ViT()

# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"

vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)

# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)


# define loss function
class CrossEntropySmooth(LossBase):
    """CrossEntropy with label smoothing."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss


network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")

# train model
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False,)

Output:

Downloading data from https://download-mindspore.osinfra.cn/vision/classification/vit_b_16_224.ckpt (330.2 MB)

file_sizes: 100%|████████████████████████████| 346M/346M [00:26<00:00, 13.2MB/s]
Successfully downloaded file to ./ckpt/vit_b_16_224.ckpt
epoch: 1 step: 125, loss is 1.4842896
Train epoch time: 275011.631 ms, per step time: 2200.093 ms
epoch: 2 step: 125, loss is 1.3481578
Train epoch time: 23961.255 ms, per step time: 191.690 ms
epoch: 3 step: 125, loss is 1.3990085
Train epoch time: 24217.701 ms, per step time: 193.742 ms
epoch: 4 step: 125, loss is 1.1687485
Train epoch time: 23769.989 ms, per step time: 190.160 ms
epoch: 5 step: 125, loss is 1.209775
Train epoch time: 23603.390 ms, per step time: 188.827 ms
epoch: 6 step: 125, loss is 1.3151006
Train epoch time: 23977.132 ms, per step time: 191.817 ms
epoch: 7 step: 125, loss is 1.4682239
Train epoch time: 23898.189 ms, per step time: 191.186 ms
epoch: 8 step: 125, loss is 1.2927357
Train epoch time: 23681.583 ms, per step time: 189.453 ms
epoch: 9 step: 125, loss is 1.5348746
Train epoch time: 23521.045 ms, per step time: 188.168 ms
epoch: 10 step: 125, loss is 1.3726548
Train epoch time: 23719.398 ms, per step time: 189.755 ms

2. Model Validation

Model validation uses:

the ImageFolderDataset interface to read the dataset,

the CrossEntropySmooth interface to instantiate the loss function,

the Model and related interfaces to compile the model.

Steps:

data augmentation

define the ViT network structure

load the pre-trained model parameters

set the loss function

set the evaluation metrics

        Top_1_Accuracy takes the class with the largest output as the prediction,

        Top_5_Accuracy takes the five largest outputs as candidate predictions;

        the larger these two values, the higher the model's accuracy

compile the model

validate

dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)

trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)

# construct model
network = ViT()

# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# define metric
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}

if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")

# evaluate model
result = model.eval(dataset_val)
print(result)

Output:

{'Top_1_Accuracy': 0.7495, 'Top_5_Accuracy': 0.928}
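
As a cross-check of what these two metrics measure, top-k accuracy can be computed by hand from raw logits. A NumPy sketch with made-up numbers (k=2 instead of 5 to keep the example tiny):

import numpy as np

# Sketch: top-1 and top-k accuracy for 2 samples over 5 classes
logits = np.array([[0.1, 2.0, 0.3, 0.2, 0.1],   # top-1 prediction: class 1
                   [1.5, 0.2, 0.1, 1.0, 0.4]])  # top-1 prediction: class 0
labels = np.array([1, 3])

top1 = np.mean(logits.argmax(axis=1) == labels)  # 0.5: only the first sample is right
topk_idx = np.argsort(logits, axis=1)[:, -2:]    # indices of the 2 largest logits per row
topk = np.mean([lbl in row for row, lbl in zip(topk_idx, labels)])  # 1.0
print(top1, topk)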

3. Model Inference

Preprocess the images to be inferred:

resize

normalize

so that they match the training inputs.

dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)

trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

Run inference:

call the model's predict method,

use index2label to look up the label for each prediction,

and use the custom show_result interface to write the result onto the corresponding image.

import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io


class Color(Enum):
    """define enum color."""
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)


def check_file_exist(file_name: str):
    """check_file_exist."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")


def color_val(color):
    """color_val."""
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    if isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    if isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    raise TypeError(f'Invalid type for color: {type(color)}')


def imread(image, mode=None):
    """imread."""
    if isinstance(image, pathlib.Path):
        image = str(image)

    if isinstance(image, np.ndarray):
        pass
    elif isinstance(image, str):
        check_file_exist(image)
        image = Image.open(image)
        if mode:
            image = np.array(image.convert(mode))
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")

    return image


def imwrite(image, image_path, auto_mkdir=True):
    """imwrite."""
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(image_path))
        if dir_name != '':
            dir_name = os.path.expanduser(dir_name)
            os.makedirs(dir_name, mode=777, exist_ok=True)

    image = Image.fromarray(image)
    image.save(image_path)


def imshow(img, win_name='', wait_time=0):
    """imshow"""
    cv2.imshow(win_name, imread(img))
    if wait_time == 0:  # prevent from hanging if windows was closed
        while True:
            ret = cv2.waitKey(1)

            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            # if user closed window or if some key pressed
            if closed or ret != -1:
                break
    else:
        ret = cv2.waitKey(wait_time)


def show_result(img: str,
                result: Dict[int, float],
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """Mark the prediction results on the picture."""
    img = imread(img, mode="RGB")
    img = img.copy()
    x, y = 0, row_width
    text_color = color_val(text_color)
    for k, v in result.items():
        if isinstance(v, float):
            v = f'{v:.2f}'
        label_text = f'{k}: {v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)
        y += row_width
    if out_file:
        show = False
        imwrite(img, out_file)

    if show:
        imshow(img, win_name, wait_time)


def index2label():
    """Dictionary output for image numbers and categories of the ImageNet dataset."""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']

    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]

    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])

    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping


# Read data for inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    mapping = index2label()
    output = {int(label): mapping[int(label)]}
    print(output)
    show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
                result=output,
                out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")

Output:

{236: 'Doberman'}

After the inference process completes, the annotated result image can be found in the inference folder.
