深度学习:基于MindSpore实现ResNet50中药分拣

news2024/12/24 9:02:41

ResNet基本介绍

ResNet(Residual Network)是一种深度神经网络架构,由微软研究院的Kaiming He等人在2015年提出,并且在ILSVRC 2015竞赛中取得了很好的成绩。ResNet主要解决了随着网络深度增加而出现的退化问题,即当网络变得非常深时,训练误差和验证误差可能会开始上升,这并不是因为过拟合,而是由于深层网络难以优化。

ResNet的核心思想是引入了残差学习框架来简化许多层网络的训练。通过构建“跳跃连接”或称“捷径连接”,允许一层直接与更深层相连接。这种设计让网络可以学习到一个残差函数,相对于传统的学习未参考输入x的目标映射H(x),残差块学习的是F(x) = H(x) - x。这样的结构使得信息能够跨过多层流动,从而缓解了梯度消失/爆炸的问题,也使得训练更深的网络成为可能。

残差块

残差块(Residual Block)是构成ResNet(Residual Network)的核心组件,它通过引入所谓的“跳跃连接”或“捷径连接”来解决深层网络训练中的梯度消失/爆炸问题。这种设计允许信息在多层之间直接传递,从而帮助优化非常深的神经网络。

残差块的核心思想是让网络去学习输入和输出之间的差异,也就是所谓的“残差”,而不是直接学习原始的输入到输出的映射。用数学语言来说,就是:

  • 传统方法:网络尝试学习H(x),即从输入x直接得到输出H(x)。
  • 残差方法:网络学习的是F(x) = H(x) - x,即输入x与期望输出H(x)之间的差值。然后,实际的输出是由原输入加上这个差值得到,即H(x) = F(x) + x。

为什么这样做有效?

  1. 简化学习任务:相比于学习复杂的非线性映射H(x),学习残差F(x)可能要简单得多。特别是当H(x)接近于x时(例如,对于一些层来说不需要对输入进行太多改变),此时F(x)可以趋近于0,表示这些层只需学会恒等映射即可,这比学习复杂的变换要容易很多。

  2. 缓解梯度问题:由于信息可以通过跳跃连接直接传递给后面的层,因此即使在深层网络中,早期层的信息也不会因为经过多层而完全丢失,从而减轻了梯度消失/爆炸的影响。

  3. 提高模型性能:通过允许网络更深,同时保持良好的可训练性,ResNet能够实现更高的准确率,并且在多个视觉任务上取得了当时最佳的结果。

一个基本的残差块结构如图所示:

(图源:7.6. 残差网络(ResNet) — 动手学深度学习 2.0.0 documentation)

一个典型的残差块包含以下几个部分:

  1. 卷积层:通常是一个或多个3x3的卷积层。这些层执行常规的特征提取任务。
  2. 批量归一化层(Batch Normalization, BN):每个卷积层后紧跟着一个批量归一化层,以加速收敛过程并提高模型的泛化能力。
  3. ReLU激活函数:除了最后一个卷积层之后,其余卷积层之后都会应用ReLU激活函数来引入非线性。
  4. 跳跃连接:这是残差块最独特的部分。它将输入直接加到经过上述处理后的输出上。如果输入和输出具有相同的尺寸(即相同数量的通道数和空间维度),则可以直接相加;如果它们不同,则需要使用额外的1x1卷积层对输入进行调整,以便能够正确地与输出相加。

当残差块的输入和输出特征图的通道数不同时,直接相加是不可行的。为了在这种情况下实现跳跃连接,需要对输入进行适当的变换,使得它与输出具有相同的维度。通常的做法是使用1x1卷积层来调整输入的通道数,将channels_in转换成channels_out。这个1x1卷积层不会改变输入的高度和宽度,只改变通道数。如右图所示。 

(图源:7.6. 残差网络(ResNet) — 动手学深度学习 2.0.0 documentation)

BuildingBlock和Bottleneck

Building Block和Bottleneck是两种不同类型的残差块结构,它们的主要区别在于内部的卷积层配置以及计算效率。

BuildingBlock

Building Block是最简单的残差块形式,通常用于较浅的ResNet模型中,比如ResNet-18和ResNet-34。它的基本结构包括:

  • 两个3x3卷积层。
  • 每个卷积层之后接一个批量归一化(Batch Normalization, BN)层。
  • 第一个卷积层之后有一个ReLU激活函数。
  • 跳跃连接(skip connection),将输入直接加到第二个卷积层的输出上。
  • 最后一个ReLU激活函数位于跳跃连接之后。

(图源:基于MindSpore实现ResNet50中药分拣_哔哩哔哩_bilibili ) 

这样的结构简单且有效,但随着网络深度的增加,计算成本也会显著上升,因为每个3x3卷积层都会对大量的特征图进行操作。 

Bottleneck

Bottleneck结构被设计用于更深的ResNet模型,例如ResNet-50、ResNet-101和ResNet-152。它通过引入瓶颈(bottleneck)机制来减少计算量,同时保持或增强模型性能。Bottleneck的基本结构如下:

  • 一个1x1卷积层,用于减少通道数(降维),这被称为瓶颈。
  • 一个3x3卷积层,这个卷积层在较少的通道上执行卷积运算。
  • 另一个1x1卷积层,用于恢复原来的通道数(升维)。
  • 每个卷积层后面都跟着BN层。
  • 在第一个1x1卷积层和3x3卷积层之间有一个ReLU激活函数。
  • 跳跃连接,将原始输入直接加到最后一个1x1卷积层的输出上。
  • 最后一个ReLU激活函数位于跳跃连接之后。

(图源:基于MindSpore实现ResNet50中药分拣_哔哩哔哩_bilibili ) 

Bottleneck使得可以在不大幅增加计算负担的情况下堆叠更多的层,从而构建更深的网络。 

各类ResNet模型架构如下所示:

(图源:基于MindSpore实现ResNet50中药分拣_哔哩哔哩_bilibili )

数据集

本案例使用“中药炮制饮片”数据集,该数据集由程度中医药大学提供,共包含中药炮制饮片的3个品种,分为:蒲黄、山楂、王不留行,每个品种有4个炮制状态:生品、不及、适中、太过。其中每类包含500张图片共12类,图片尺寸为4K,图片格式为jpg。

数据集可视化如下:

(图源:基于MindSpore实现ResNet50中药分拣_哔哩哔哩_bilibili )

基于MindSpore实现ResNet50中药分拣

数据下载及预处理

from download import download
import os

url = "https://mindspore-courses.obs.cn-north-4.myhuaweicloud.com/deep%20learning/AI%2BX/data/zhongyiyao.zip"
# 创建的是调试任务,url修改为数据集上传生成的url链接
if not os.path.exists("dataset"):
    download(url, "dataset", kind="zip")

因图片原尺寸为4K过大,因此需要将其Resize至指定尺寸

from PIL import Image
import numpy as np
data_dir = "dataset/zhongyiyao/zhongyiyao"
new_data_path = "dataset1/zhongyiyao"
if not os.path.exists(new_data_path):
    for path in ['train','test']:
        data_path = data_dir + "/" + path
        classes = os.listdir(data_path)
        for (i,class_name) in enumerate(classes):
            floder_path =  data_path+"/"+class_name
            print(f"正在处理{floder_path}...")
            for image_name in os.listdir(floder_path):
                try:
                    image = Image.open(floder_path + "/" + image_name)
                    image = image.resize((1000,1000))
                    target_dir = new_data_path+"/"+path+"/"+class_name
                    if not os.path.exists(target_dir):
                        os.makedirs(target_dir)
                    if not os.path.exists(target_dir+"/"+image_name):
                        image.save(target_dir+"/"+image_name)
                except:
                    pass     

将数据集划分为训练集、测试集、验证集

from sklearn.model_selection import train_test_split
import shutil

def split_data(X, y, test_size=0.2, val_size=0.2, random_state=42):
    """
    This function splits the data into training, validation, and test sets.
    """
    # Split the data into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # Split the training data into training and validation sets
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=val_size/(1-test_size), random_state=random_state)

    return X_train, X_val, X_test, y_train, y_val, y_test


data_dir = "dataset1/zhongyiyao"
floders = os.listdir(data_dir)
target = ['train','test','valid']
if set(floders) == set(target):
    # 如果已经划分则跳过
    pass
elif 'train' in floders:
    # 如果已经划分了train,test,那么只需要从train里边划分出valid
    floders = os.listdir(data_dir)
    new_data_dir = os.path.join(data_dir,'train')
    classes = os.listdir(new_data_dir)
    if '.ipynb_checkpoints' in classes:
        classes.remove('.ipynb_checkpoints')
    imgs = []
    labels = []
    for (i,class_name) in enumerate(classes):
        new_path =  new_data_dir+"/"+class_name
        for image_name in os.listdir(new_path):
            imgs.append(image_name)
            labels.append(class_name)
    imgs_train,imgs_val,labels_train,labels_val = X_train, X_test, y_train, y_test = train_test_split(imgs, labels, test_size=0.2, random_state=42)
    print("划分训练集图片数:",len(imgs_train))
    print("划分验证集图片数:",len(imgs_val))
    target_data_dir = os.path.join(data_dir,'valid')
    if not os.path.exists(target_data_dir):
        os.mkdir(target_data_dir)
    for (img,label) in zip(imgs_val,labels_val):
        source_path = os.path.join(data_dir,'train',label)
        target_path = os.path.join(data_dir,'valid',label)
        if not os.path.exists(target_path):
            os.mkdir(target_path)
        source_img = os.path.join(source_path,img)
        target_img = os.path.join(target_path,img)
        shutil.move(source_img,target_img)
else:
    phones = os.listdir(data_dir)
    imgs = []
    labels = []
    for phone in phones:
        phone_data_dir = os.path.join(data_dir,phone)
        yaowu_list = os.listdir(phone_data_dir)
        for yaowu in yaowu_list:
            yaowu_data_dir = os.path.join(phone_data_dir,yaowu)
            chengdu_list = os.listdir(yaowu_data_dir)
            for chengdu in chengdu_list:
                chengdu_data_dir = os.path.join(yaowu_data_dir,chengdu)
                for img in os.listdir(chengdu_data_dir):
                    imgs.append(img)
                    label = ' '.join([phone,yaowu,chengdu])
                    labels.append(label)
    imgs_train, imgs_val, imgs_test, labels_train, labels_val, labels_test = split_data(imgs, labels, test_size=0.2, val_size=0.2, random_state=42)
    img_label_tuple_list = [(imgs_train,labels_train),(imgs_val,labels_val),(imgs_test,labels_test)]
    for (i,split) in enumerate(spilits):
        target_data_dir = os.path.join(data_dir,split)
        if not os.path.exists(target_data_dir):
            os.mkdir(target_data_dir)
        imgs_list,labels_list = img_label_tuple_list[i]
        for (img,label) in zip(imgs_list,labels_list):
            label_split = label.split(' ')
            source_img = os.path.join(data_dir,label_split[0],label_split[1],label_split[2],img)
            target_img_dir = os.path.join(target_data_dir,label_split[1]+"_"+label_split[2])
            if not os.path.exists(target_img_dir):
                os.mkdir(target_img_dir)
            target_img = os.path.join(target_img_dir,img)
            shutil.move(source_img,target_img)
    

定义数据加载方式

from mindspore.dataset import GeneratorDataset
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore import dtype as mstype
# 注意没有使用Mindspore提供的ImageFloder进行加载,原因是调试任务中'.ipynb_checkpoints'缓存文件夹会被当作类文件夹进行识别,导致数据集加载错误
class Iterable:
    def __init__(self,data_path):
        self._data = []
        self._label = []
        self._error_list = []
        if data_path.endswith(('JPG','jpg','png','PNG')):
            # 用作推理,所以没有label
            image = Image.open(data_path)
            self._data.append(image)
            self._label.append(0)
        else:
            classes = os.listdir(data_path)
            if '.ipynb_checkpoints' in classes:
                classes.remove('.ipynb_checkpoints')
            for (i,class_name) in enumerate(classes):
                new_path =  data_path+"/"+class_name
                for image_name in os.listdir(new_path):
                    try:
                        image = Image.open(new_path + "/" + image_name)
                        self._data.append(image)
                        self._label.append(i)
                    except:
                        pass
                

    def __getitem__(self, index):
        return self._data[index], self._label[index]

    def __len__(self):
        return len(self._data)
    
    def get_error_list(self,):
        return self._error_list
    
def create_dataset_zhongyao(dataset_dir,usage,resize,batch_size,workers):
    data = Iterable(dataset_dir)
    data_set = GeneratorDataset(data,column_names=['image','label'])
    trans = []
    if usage == "train":
        trans += [
            vision.RandomCrop(700, (4, 4, 4, 4)),
            # 这里随机裁剪尺度可以设置
            vision.RandomHorizontalFlip(prob=0.5)
        ]

    trans += [
        vision.Resize((resize,resize)),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW()
    ]

    target_trans = transforms.TypeCast(mstype.int32)
    # 数据映射操作
    data_set = data_set.map(
        operations=trans,
        input_columns='image',
        num_parallel_workers=workers)

    data_set = data_set.map(
        operations=target_trans,
        input_columns='label',
        num_parallel_workers=workers)

    # 批量操作
    data_set = data_set.batch(batch_size,drop_remainder=True)

    return data_set

加载数据

import mindspore as ms
import random
data_dir = "dataset1/zhongyiyao"
train_dir = data_dir+"/"+"train"
valid_dir = data_dir+"/"+"valid"
test_dir = data_dir+"/"+"test"
batch_size = 32 # 批量大小
image_size = 224 # 训练图像空间大小
workers = 4 # 并行线程个数
num_classes = 12 # 分类数量


# 设置随机种子,使得模型结果复现
seed = 42
ms.set_seed(seed)
np.random.seed(seed)
random.seed(seed)

dataset_train = create_dataset_zhongyao(dataset_dir=train_dir,
                                       usage="train",
                                       resize=image_size,
                                       batch_size=batch_size,
                                       workers=workers)
step_size_train = dataset_train.get_dataset_size()

dataset_val = create_dataset_zhongyao(dataset_dir=valid_dir,
                                     usage="valid",
                                     resize=image_size,
                                     batch_size=batch_size,
                                     workers=workers)
dataset_test = create_dataset_zhongyao(dataset_dir=test_dir,
                                     usage="test",
                                     resize=image_size,
                                     batch_size=batch_size,
                                     workers=workers)
step_size_val = dataset_val.get_dataset_size()

print(f'训练集数据:{dataset_train.get_dataset_size()*batch_size}\n')
print(f'验证集数据:{dataset_val.get_dataset_size()*batch_size}\n')
print(f'测试集数据:{dataset_test.get_dataset_size()*batch_size}\n')

实现标签映射

#index_label的映射
index_label_dict = {}
classes = os.listdir(train_dir)
if '.ipynb_checkpoints' in classes:
    classes.remove('.ipynb_checkpoints')
for i,label in enumerate(classes):
    index_label_dict[i] = label
label2chin = {'ph_sp':'蒲黄-生品',  'ph_bj':'蒲黄-不及', 'ph_sz':'蒲黄-适中', 'ph_tg':'蒲黄-太过', 'sz_sp':'山楂-生品',
              'sz_bj':'山楂-不及', 'sz_sz':'山楂-适中', 'sz_tg':'山楂-太过', 'wblx_sp':'王不留行-生品', 'wblx_bj':'王不留行-不及',
              'wblx_sz':'王不留行-适中', 'wblx_tg':'王不留行-太过'}
index_label_dict

定义ResNet50

残差块定义

class ResidualBlock(nn.Cell):
    expansion = 4  # 最后一个卷积核的数量是第一个卷积核数量的4倍

    def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=1, weight_init=weight_init)
        self.norm1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.norm2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,
                               kernel_size=1, weight_init=weight_init)
        self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)

        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):

        identity = x  # shortscuts分支

        out = self.conv1(x)  # 主分支第一层:1*1卷积层
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)  # 主分支第二层:3*3卷积层
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv3(out)  # 主分支第三层:1*1卷积层
        out = self.norm3(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)

        out += identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)

        return out

创建残差块的函数 

def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],
               channel: int, block_nums: int, stride: int = 1):
    down_sample = None  # shortcuts分支


    if stride != 1 or last_out_channel != channel * block.expansion:

        down_sample = nn.SequentialCell([
            nn.Conv2d(last_out_channel, channel * block.expansion,
                      kernel_size=1, stride=stride, weight_init=weight_init),
            nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
        ])

    layers = []
    layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))

    in_channel = channel * block.expansion
    # 堆叠残差网络
    for _ in range(1, block_nums):

        layers.append(block(in_channel, channel))

    return nn.SequentialCell(layers)

ResNet定义

from mindspore import load_checkpoint, load_param_into_net
from mindspore import ops


class ResNet(nn.Cell):
    def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
                 layer_nums: List[int], num_classes: int, input_channel: int) -> None:
        super(ResNet, self).__init__()

        self.relu = nn.ReLU()
        # 第一个卷积层,输入channel为3(彩色图像),输出channel为64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
        self.norm = nn.BatchNorm2d(64)
        # 最大池化层,缩小图片的尺寸
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        # 各个残差网络结构块定义
        self.layer1 = make_layer(64, block, 64, layer_nums[0])
        self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
        self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
        self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)
        # 平均池化层
        self.avg_pool = ops.ReduceMean(keep_dims=True)
        # self.avg_pool = nn.AvgPool2d()
        
        # flattern层
        self.flatten = nn.Flatten()
        # 全连接层
        self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)

    def construct(self, x):
        
        x = self.conv1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avg_pool(x,(2,3))
        
        x = self.flatten(x)
        x = self.fc(x)

        return x

使用预训练的ResNet微调进行预测

def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)

    if pretrained:
        # 加载预训练模型
        download(url=model_url, path=pretrained_ckpt)
        param_dict = load_checkpoint(pretrained_ckpt)
        load_param_into_net(model, param_dict)

    return model


def resnet50(num_classes: int = 1000, pretrained: bool = False):
    resnet50_url = "https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/models/application/resnet50_224_new.ckpt"
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)
network = resnet50(pretrained=True)
num_class = 12
# 全连接层输入层的大小
in_channel = network.fc.in_channels
fc = nn.Dense(in_channels=in_channel, out_channels=num_class)
# 重置全连接层
network.fc = fc

for param in network.get_parameters():
    param.requires_grad = True

模型训练与推理

num_epochs = 50
# early stopping
patience = 5
lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs,
                        step_per_epoch=step_size_train, decay_epoch=num_epochs)
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = ms.Model(network, loss_fn, opt, metrics={'acc'})

# 最佳模型存储路径
best_acc = 0
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best.ckpt"

def train_loop(model, dataset, loss_fn, optimizer):
    # Define forward function
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    # Get gradient function
    grad_fn = ms.ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    # Define function of one-step training
    def train_step(data, label):
        (loss, _), grads = grad_fn(data, label)
        loss = ops.depend(loss, optimizer(grads))
        return loss
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0 or batch == step_size_train - 1:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

from sklearn.metrics import classification_report

def test_loop(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    y_true = []
    y_pred = []
    for data, label in dataset.create_tuple_iterator():
        y_true.extend(label.asnumpy().tolist())
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        y_pred.extend(pred.argmax(1).asnumpy().tolist())
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    print(classification_report(y_true,y_pred,target_names= list(index_label_dict.values()),digits=3))
    return correct,test_loss

开始训练 

no_improvement_count = 0
acc_list = []
loss_list = []
stop_epoch = num_epochs
for t in range(num_epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(network, dataset_train, loss_fn, opt)
    acc,loss = test_loop(network, dataset_val, loss_fn)
    acc_list.append(acc)
    loss_list.append(loss)
    if acc > best_acc:
        best_acc = acc
        if not os.path.exists(best_ckpt_dir):
            os.mkdir(best_ckpt_dir)
        ms.save_checkpoint(network, best_ckpt_path)
        no_improvement_count = 0
    else:
        no_improvement_count += 1
        if no_improvement_count > patience:
            print('Early stopping triggered. Restoring best weights...')
            stop_epoch = t
            break 
print("Done!")

模型推理

import matplotlib.pyplot as plt

num_class = 12  # 
net = resnet50(num_class)
best_ckpt_path = 'BestCheckpoint/resnet50-best.ckpt'
# 加载模型参数
param_dict = ms.load_checkpoint(best_ckpt_path)
ms.load_param_into_net(net, param_dict)
model = ms.Model(net)
image_size = 224
workers = 1

def visualize_model(dataset_test):
    # 加载验证集的数据进行验证
    data = next(dataset_test.create_tuple_iterator())
    # print(data)
    images = data[0].asnumpy()
    labels = data[1].asnumpy()
    # 预测图像类别
    output = model.predict(ms.Tensor(data[0]))
    pred = np.argmax(output.asnumpy(), axis=1)

    # 显示图像及图像的预测值
    plt.figure(figsize=(10, 6))
    for i in range(6):
        plt.subplot(2, 3, i+1)
        color = 'blue' if pred[i] == labels[i] else 'red'
        plt.title('predict:{}  actual:{}'.format(index_label_dict[pred[i]],index_label_dict[labels[i]]), color=color)
        picture_show = np.transpose(images[i], (1, 2, 0))
        mean = np.array([0.4914, 0.4822, 0.4465])
        std = np.array([0.2023, 0.1994, 0.2010])
        picture_show = std * picture_show + mean
        picture_show = np.clip(picture_show, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')

    plt.show()

visualize_model(dataset_val)

更多内容可参考MindSpore官方教程:

基于MindSpore实现ResNet50中药分拣_哔哩哔哩_bilibili 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2194494.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

域名劫持怎么处理?如何判断dns是否被劫持

随着网络环境的日益复杂,网站安全问题也日益凸显。域名劫持怎么处理?域名劫持是网站运营中不容忽视的安全威胁,在遇到域名劫持的时候应该学会应急响应、加强安全防护措施以及持续的安全维护,我们可以有效降低其带来的风险。 域名劫…

AOP 能够取代依赖注入吗?

AOP(面向方面编程)和依赖注入(DI)都是面向对象编程中非常重要的设计概念,它们在软件开发中扮演着不同的角色,但常常被用于解决相似的问题,如解耦、提高代码的可维护性和灵活性等。那么&#xff…

双碳平台-企业EMS -能源管理系统-能源在线监测平台

一、介绍 基于SpringCloud的能管管理系统-能源管理平台源码-能源在线监测平台-双碳平台源码-SpringCloud全家桶-能管管理系统源码 二、软件架构 二、功能介绍 三、数字大屏展示 四、数据采集原理 五、软件截图

面试问我LLM中的RAG,秒过!!!

本篇文章涉及了 RAG 流程中的数据拆分、向量化、查询重写、查询路由等等,在做 RAG 的小伙伴一定知道这些技巧的重要性。推荐仔细阅读,建议收藏,多读几遍,好好实践。 本文是对检索增强生成(Retrieval Augmented Genera…

matlab碳交易机制下考虑需求响应的综合能源系统优化运行

目录 1 主要内容 架构模型: 需求响应模型: 目标函数: 对比算例设计: 2 部分程序 3 程序结果 4 下载链接 1 主要内容 该程序复现文献《碳交易机制下考虑需求响应的综合能源系统优化运行》,解决碳交易机制下考虑…

大数据新视界 --大数据大厂之 Alluxio 数据缓存系统在大数据中的应用与配置

💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

时间序列顶会一网打尽!时间序列基础模型的最新进展!

前言 最近时间序列基础模型领域,迎来了里程碑式的突破。 TimeGPT作为首个原生基础模型,于去年八月问世,一发布就震撼了预测领域。 众多其他基础模型也相继发布,包括但不限于: TimesFM MOIRAI Tiny Time Mixers&am…

Vue83 引入elementUI

笔记 安装插件 安装按需引入插件 代码 ### App.vue <template><div><button>原生的按钮</button><input type"text"><atguigu-row><atguigu-button>默认按钮</atguigu-button><atguigu-button type"pr…

Pikachu-Sql-Inject -基于boolian的盲注

基于boolean的盲注: 1、没有报错信息显示&#xff1b; 2、不管是正确的输入&#xff0c;还是错误的输入&#xff0c;都只显示两种情况&#xff0c;true or false&#xff1b; 3、在正确的输入下&#xff0c;输入and 1 1/and 1 2发现可以判断&#xff1b; 布尔盲注常用函数&…

MySQL连接查询:外连接

先看我的表结构 dept表 emp表 外连接分为 1.左外连接 2.右外连接 1.左外连接 基本语法 select 字段列表 FORM 表1 LEFT [OUTER] JOIN 表2 ON 条件;例子&#xff1a;查询emp表的所有数据&#xff0c;和对应部门的员工信息&#xff08;左外连接&#xff09; select e.*, d.n…

全网最详细大语言模型(LLM)入门学习路线图

Github项目上有一个大语言模型学习路线笔记&#xff0c;它全面涵盖了大语言模型的所需的基础知识学习&#xff0c;LLM前沿算法和架构&#xff0c;以及如何将大语言模型进行工程化实践。这份资料是初学者或有一定基础的开发/算法人员入门活深入大型语言模型学习的优秀参考。这份…

假期顺便测试了一下高德POI的准确度及对景区地图的一些感想

所使用的测试工具: GIS 移动端工具 1.山西大同乌龙峡 2.山西大同昊天寺 3.山西大同火山地质公园 4.山西大同忘忧农场 总的来说高德精度还是不错的&#xff0c;测试的几个位置都比较准确&#xff01;但景区内部的目标不是很全&#xff0c;内部小的位置完全没有标记&#xff01…

C语言 | Leetcode C语言题解之第461题汉明距离

题目&#xff1a; 题解&#xff1a; int hammingDistance(int x, int y) {int s x ^ y, ret 0;while (s) {s & s - 1;ret;}return ret; }

HDLBits中文版,标准参考答案 |2.5 More Verilog Features | 更多Verilog 要点

关注 望森FPGA 查看更多FPGA资讯 这是望森的第 7 期分享 作者 | 望森 来源 | 望森FPGA 目录 1 Conditional ternary operator | 条件三目运算符 2 Reduction operators | 归约运算器 3 Reduction: Even wider gates | 归约&#xff1a;更宽的门电路 4 Combinational fo…

时域交织ADC建模文档

时域交织ADC建模文档 Time-interleaved SAR ADC modeling 32-way 6-bit TI SAR ADC 发货内容 仅有19页PDF&#xff0c;内有MATLAB代码&#xff08;3页&#xff09; MATLAB建模&#xff1b;TI SAR ADC;

微博创作平台:编辑技巧

文章目录 I 编辑技巧II 变形工具微博个人认证升级体系(橙V、金V体系规则)广告共享计划V+粉丝订阅I 编辑技巧 图片和视频一起发的时候,要求图+视频的总数不能大于9.微博app编辑文字时,图N可自动链接图片,例如图1可自动关联第一张图片,点击文字可直接打开第一张图片 II 变形…

GPU Puzzles讲解(二)

GPU-Puzzles项目是一个很棒的学习cuda编程的项目&#xff0c;可以让你学习到GPU编程和cuda核心并行编程的概念&#xff0c;通过一个个小问题让你理解cuda的编程和调用&#xff0c;创建共享显存空间&#xff0c;实现卷积和矩阵乘法等 https://github.com/srush/GPU-Puzzleshttp…

羚羊种类检测系统源码分享

羚羊种类检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vision …

项目——超级马里奥——Day(3)

一、游戏开发思路&#xff1a; 1.Frame--->BackGround--->Obstacle---->BufferedImage&#xff0c;人物等 2.BackGround的构造函数&#xff1a; 只要记住窗口里边的每一个场景&#xff0c;只要游戏一开始就已经出现在屏幕里边的&#xff0c;都是在构造函数里边 3.绘…

就业市场需求分析:基于前程无忧岗位数据分析

背景介绍&#xff1a;在前程无忧网站&#xff0c;以"数据分析师""武汉"作为搜索关键词&#xff0c;爬取50页岗位数据合计980条。以该数据为基础&#xff0c;从岗位搜索匹配度、HR活跃度、不同区域/行业/企业的岗位数量和薪资分布等角度进行分析。 1、原始数…