基于Mindspore,通过Resnet50迁移学习实现猫十二分类

news2025/3/10 21:16:26

使用平台介绍

使用平台:启智AI协作平台
使用数据集:百度猫十二分类

数据集介绍

有cat_12_train和cat_12_test和train_list.txt
train_list.txt内有每张图片所对应的标签

Minspore部分操作科普

数据集加载

Mindspore加载图片数据集就直接调整成这种格式就行,然后可以用这个函数加载,自动生成两个列,一列是图片,一列是标签;ImageFolderDataset函数会自动读取和处理数据集,标签就是文件夹的名称
在这里插入图片描述
在这里插入图片描述

数据处理

map函数里可以一键进行处理和映射,定义好数据处理函数,直接把路径和标签Map处理,后面可以带上是训练集还是测试集的标签;,要处理图片就指定 input_columns 参数为image,这个是前面数据读取形成的;
本项目这里用的是ImageFolderDataset,可以自动生成图片对应数据的两个列,要是相对数据处理就设置为数据列名即可
在这里插入图片描述

数据批处理和重复

batch就是批处理,把数据分成指定数量的一个个批次
在这里插入图片描述
最后repeat对数据进行重复
在这里插入图片描述

整体数据处理流程

就是读取数据形成数据和标签对应的列(读取数据函数有很多),然后定义数据预处理函数,在map函数里一键映射,指定要处理的列一键处理,最后对数据进行批次划分,就拿到可以放进训练网络函数的规范数据集了。当然使用时候还要用create_tuple_iterator或者create_dict_iterator函数形成可迭代的数据集。

迁移学习

在迁移学习中,固定特征训练和模型微调都是常用的技术

固定特征训练

在源任务上训练一个模型,并将其应用到目标任务上。在这个过程中,模型的特征提取器是固定的,只对输出层进行调整。这种方法可以利用源任务中已经学习到的特征,从而减少目标任务的训练时间和数据需求。固定特征训练通常适用于目标任务和源任务具有相似的特征空间,并且目标任务的数据量较小的情况。

模型微调

使用预训练模型作为初始模型,并在目标任务的数据集上进行进步训练。在微调阶段,可以根据目标任务的数据和特定要求调整模型的参数使其适应目标任务。模型微调的主要目的是通过在目标任务上的有限训练来调整预训练模型,以取得更好的性能。

举例

以训练一个猫、狗分类器为例,固定特征训练是指在一个大型的猫狗图片数据集上训练一个通用的图像识别模型,然后将该模型应用于特定的猫、狗分类任务。在这个过程中,我们只需要调整模型的输出层,使其能够正确地对猫和狗进行分类。而模型微调则是指使用一个已经在大型数据集上训练好的通用图像识别模型,然后在特定的猫、狗图片数据集上进行进一步训练,以优化模型的性能。在这个过程中我们可以根据猫、狗图片的特点来调整模型的参数,使其能够更好地识别猫和狗。

数据处理过程

由于本项目采用的数据集是百度所提供的猫十二分类,要使用ImageFolderDataset的话,形式不太匹配

现有形式

cat_12_train和cat_12_test里面都是一张一张的图片
train_list.txt内有每张图片所对应的标签
在这里插入图片描述

处理后形式

train和val文件夹内分别有十二个子文件夹,代表12类猫,每个子文件夹内又有一张张的图片
在这里插入图片描述在这里插入图片描述

处理代码

这里有相关代码进行自动划分,但是对于训练集和测试集的划分,我直接采用了手动操作,也可以用代码来实现的;

# 处理异常图片
dir_lit = os.listdir('./work/cat_12_train/')
# dir_lit为一个列表,里面是一张张图片的名称
for list in dir_lit:
# list是图片名称,这里的操作是把这个图片形成一个个的路径
    img_path=os.path.join('./work/cat_12_train/',list)
    print(img_path)
    # 如果不是RGB那就转换为RGB
    img=Image.open(img_path)
    if img.mode != 'RGB':
        img = img.convert('RGB')
        img.save(img_path)

dir_lit = os.listdir('./work/cat_12_test/')
for list in dir_lit:
    img_path=os.path.join('./work/cat_12_test/',list)
    img=Image.open(img_path)
    if img.mode != 'RGB':
        img = img.convert('RGB')
        img.save(img_path)
# 整理数据格式
# 创建12个文件夹分别对应标签
path='./work/MyDataset/'
for i in range(12):
    if not os.path.exists(path+str(i)):
        os.mkdir(path+str(i))
    else:
        continue
#读取每一行
with open(f'./work/train_list.txt','r')as f:
    img_path=f.readlines()
    print(img_path)
    # 里面是一个个的'cat_12_train/8GOkTtqw7E6IHZx4olYnhzvXLCiRsUfM.jpg\t0\n'

# 把对应文件放到对应标签文件夹下
for img in img_path:
    # 拿取每一张图片路径
    # print(img)
    img_src= img.split('\t')[0]
    # img_src为一个个图片路径
    rel_src= img_src.split('cat_12_train/')[1]
    # rel_src为图片名称
    img_label = img.split('\t')[1]
    img_label = img_label.split('\n')[0]
    # img_label为图片标签
    print(img_src)
    print(rel_src)
    print(img_label)
    # os.system(f'cp ./work/{img_src} ./work/MyDataset/{img_label}/{rel_src}')
    shutil.copy(f'./work/{img_src}',f'./work/MyDataset/{img_label}/{rel_src}')
print('图片处理完毕')

整体代码

# 解压上传的数据集压缩包并查看数据集结构
!unzip MyDataset.zip -d data/
import os
print(os.listdir("data"))
print(os.listdir("data/train"))
print(os.listdir('data/train/1'))

输出

['val', 'train']
['9', '10', '5', '8', '11', '2', '7', '3', '0', '1', '6', '4']
['DKkQylbgdrWRjYap63MCJe0UBLhcHXPm.jpg', 'k9HWNaG2Z1wUKAOYdSDu7vRr4xBqmTCV.jpg', 'LfIoOrSNvKHQzsGtm5eMZc0lRuBXhTP6.jpg', '0esFjXNqc5xbMmUaJkRVPwQorWlu3LvA.jpg', 'PMoFIabq0W9U2wETZr7yf4JLYdBxv6hQ.jpg', 'QlUX4zHfPZ3LxRDswqm5FeMbnTWNaj6g.jpg', '7E9oOUcQjkLMvpAtNymHCRSqFfdGVDK4.jpg', 'pznq7EivBH9LwrNysIWxgGeomTlOP8cZ.jpg', '3Ndv9X6uTgzFtnoA01VECIBPj7xqlewG.jpg', 'o1g6adKmS4lBDw2F5buAYnetUWh7xXGz.jpg', '7QZTYlspK2fqdJUwjC0HDmOFrM5W4PX9.jpg', 'qgimvDE8Zaf4PJ32dkNhwVy5nxATOtrX.jpg', 'RJnThakSOGUzeFBdigXAm2NsL8jyYvu3.jpg', 'Ig61xq3ME78fdCRTDWhaKkcyuOQj24PG.jpg', 'tAdqSefI0DohNuU6wgVyPca7Qz5lYTOH.jpg', 'pTe31kYFqwyOGmV50sbhgoLQ9KcjJaxd.jpg', 'HUuwb4gRqoPWD3Lvrsa9hVcQ7FSfOT8t.jpg', 'cqkJDEpWiwS69UxFtKMPRgb4mXhj1LAs.jpg', '2pjBVbqF30cUTvRIYtsCGfgwPKOJz4ua.jpg', 'e7f1iucltpVQTXroFR96xawm2BDZYnNG.jpg', 'tWcMpXTe78zo2ikhbUOqPud6VJ5RfSEw.jpg', 'gcTxA5NwLztvWr7YPCMnDjdFyfqoa2uK.jpg', 'qbKjsR05lrFVYfLChtMGD7im36cUgAnE.jpg', 'S4hfUR5kOj7CXr826Gxa9t3bEBPioJq1.jpg', 'A8PMtHzoyFE0WgjpZ2qUYbduL4T9arxN.jpg', '8E0bSi3h1aVy2cNWpgOsKvxZCQtzqkLU.jpg', 'oDc9XxipzfBjUAEl0hOmyd4PNr5v8IsG.jpg', 'AVGJoCPsX3L6I5Y2M94kEO7vNHmt1Ure.jpg', 'SqyF8c0Rak1NedXpYvlI9TsVwzOhtGZJ.jpg', '04Iv3QNKtu2DAfRTgs9XZwBMb1Cm7l6P.jpg', '94NwrzYLo8iMtagTR3SfPGHmWvZXbyUx.jpg', 'oiU4YjnhNpI3JWagx8SuTCktA6qZXGRH.jpg', 'e416wAERYOQ7NutUJDcIVFk2oPWpC3q8.jpg', '9I3enpUrZ5xD2TvRAOFt4S7lBfVMdqsJ.jpg', 'oLNGFnUPmQhxOkdbv37HwSj8uql4z1sp.jpg', 'nHfDoId8SXKzMt1weR7bJlaWPcNx32yv.jpg', 'dzV1Psxncp6H8g4KWhX3mbrTfqwuLaNv.jpg', 'OW6e1GbpNsfmxFvLQKMnIByX4hDcS2io.jpg', 'MOGw0PDqjmnLdViez26b7WY3hU85vatH.jpg', '3R5BWakTdG2hKjJoiNxg0pr61LsDqSuM.jpg', '41OaVziAEuenpKqv7LCYMPsGH6BkQotD.jpg', 'fRbdkW0GDAhBjpTVeonPItycEia45Ns9.jpg', 'RsYG3VJi7NTXptoPQvWKhcFaDqIe4EkC.jpg', 'jWPXtA07yYrcRxNBUkwC9dpuSs2M463e.jpg', 'bqRVATEuI4x3kJSO9DitWjYms8KoG2Lz.jpg', 'OCXPGzodQsZHRnfMFaBkqW9hKYxA4glr.jpg', 'hAPzcCeE04sSadDIFB627ipyOWgjX8mq.jpg', 'PIbktpOd2DqHwVceLzUyE7CmohMjNA91.jpg', 'adAjP14SXL6vVJ95TRrMIYDiHl7BUqbF.jpg', 'hTxYnXrQ35vwKL1NEMSIot02djHy8i6D.jpg', 'kwuLVmg7n9I4iEOzMQC1NxJfvX6Bhoqs.jpg', '78WTn5auMmQshIZi2qDAYdR1oKcwEzfk.jpg', 'hOEPm8o26CBptkD5yT3fsgbM4dRuVZzw.jpg', 'PIpXbRiu68dm2s1DfxHJGAYLegOUMzoB.jpg', '8NxvitwMaCpsuEQT17nDXzFR5gAZ4rfc.jpg', 'JgxcdpvW7f6lKMShPjHFeZ2RDX13UCiN.jpg', 'BhHLRN0QTWOwl3UEG9J17XScni2P5gVe.jpg', '7Gw2o8LJTF4ecZI6nl1WuDrsAOSfQPaq.jpg', 't6xZhQkD2jWCOi1r3fK7T9slGHVbwgNd.jpg', 'gbfyYtlWaAO4iUCPK2cFVkoQX9MmwJTI.jpg', '9PQic6o4VyZ13pLAYu2avFWSbJRz57fC.jpg', 'Cf2Z3j6hYliVOduEvK5NJp4yba0wSGcA.jpg', '2QvYgMIzELXH4Fy8GNDBaPS3W0tVZ5xq.jpg', 'oFXrWl80gRMenqPbG2uZv5wk4KmHjaNd.jpg', 'l1NsjeJKdvFimRgh6IEZuqfxCw5An8o4.jpg', 'tvWgSwN9m3BZ5qOXjK7LexVCIn6F4AHk.jpg', 'ahcTZUbOmJsloVt8vGMjwPqIXd0x9iy1.jpg', 'hp0nNWXar81lB3eSYE9kGcdDL4tJfuwC.jpg', 'HNehmorRIS6M2iDj48gZL9OKva7Xck1n.jpg', '0GX4YKdcwBi15lTpR7ExWO2ZagseoVNI.jpg', 'b57SiGKYPaE1DrfxJtVeQdlOAUojLZhR.jpg', 'f4gLhHjyKdxTumna7F5pGWPqVIRY01et.jpg', 'AHxQ1GFgRLs802diT3VIlwOoqkW9Sar4.jpg', '09i4DcyrWktZb1naHFEpL5elhG3CvYxu.jpg', 'AieqOGKQC7fgDayW9kLuJ5mUHx4XP2I6.jpg', 'MwZIekE7oxPtRpTHQVf4l6qA2iC1zWLh.jpg', '6M8xAZBdQLkuDcF14HEz53J0IboiPfUa.jpg', 'eD4gfaQTFdoWUCnhPj6YmIZBl5AMxNik.jpg', 'PwBJK7rZHDhq46ynYoj9Saxip0IldMV2.jpg', 'neHaTbwPkdVmoOA3JyWxR7Lh92NDpzEf.jpg', 'E38k6xhQFYKALn4tDlwiPBfpdCygeSNs.jpg', 'N1VpzqjoRmPZCQ5KasAv72TwtMDFrG0d.jpg', 'RvXKbfDuF4W6exgVcInE3SktJa9LzBj8.jpg', 'aMoxSymjdiUwbJ6k5NzGR09uILQ4sEc3.jpg', 'l1WfIcvOZk9jAn2xQwtSCEgY5XhRyoFB.jpg', '0IWfLUGk53iHt9SElNzKsBCDwuMjpPbR.jpg', '80vskwDtCRAz9iWYjnrhIGfeXdUZx16b.jpg', '0WglfKCD5Gu2LqI1msTSZa6orBO7XAz9.jpg', 'os5kaDubPM1hY7f6gRrSOZqNQFEAU89v.jpg', 'm2azqs0NGPDdjR8rUTxWF4covLE5ikQZ.jpg', 'siKAzUrV9eykjlCQ0odZEhnW7FIgTuLm.jpg', 'jZWaf0ne2R1pDo9hTBkCbA8YOq3LlQ5x.jpg', 'Kny8zFiIt4vxNSO6g7Lu9kGfdVJoqPC3.jpg', 'LTMkHx9w2nfsRiZec3bEVtmujpv7qS1y.jpg', 'oZin4PuwTet39xWCYhUBfvlzGyISb5DV.jpg', 'oJ4HWQkZDvta8rUyinRu9fVNs3BX1Kj7.jpg', 'mUp6082yMQghXY13OtvxabTrNSEeiu4B.jpg', 'RZWpn9jGxcKSUb3Y56fVMQHlJhEIeNiA.jpg', 'L86JQlekn9Ko01TbXHYMFImdZ2upxg5h.jpg', 'I0YcgXB97QL6MtHlU3p1znqWdCGOD4mo.jpg', 'IG9NFCfMybKYiQhquOd5H3DjwlSakW6U.jpg', 'SZYosxl4cHRWyT3h5JFqNjGdnIag6907.jpg', '4dMVtGvRJbrjK6X17STZ9Lx3kgeEioqp.jpg', 'q9YrDFK73Jfv15SHpTWelGAIwnBxt0iE.jpg', 'spNU7J8uk6BXiAyQErHegYMzjOaFR2qV.jpg', 'p1ji352o6vhd8l0Q4uNVRZrIgkSLnfBq.jpg', '7WQ0ByMPtJAdZ8h5OkveLi3ScuU6bIY1.jpg', 'hajCi0GDlVP2ONg6FeSWrvubQ34ozwkx.jpg', 'RKLDkUwmFg0Oj5tPeIs31y7J8AQZ9dni.jpg', 'hWOAp1EV6nJzYxaHt03T8GPNe7ujUiF9.jpg', 'B0a2VHnwQv1byMDTlEiOJXxI7Scs9Zjf.jpg', 'gMzOoyTrGniBj1vxN0AeD9VQsFHU7aKR.jpg', 'CiBq0GVawv1rdYyLDjcWoIXP6SKbzH8F.jpg', 'G71cYNEBD6shJLkgVzwb52m8oRluKUHS.jpg', '7IdLnFCb3a25cKNV6tXuYi1fe0hJQMOU.jpg', 'jsThJuVYQxUKSz3btXdA5q8M1O9Cioaw.jpg', '2OpyK1cm85obujwEMqGWNv9V7PnQfJ3U.jpg', 'DLIZr5TjPepd8csioJXMbYHk94RmKx6v.jpg', 'Q80DEFkGlxJj2qR37t4ZKpY6zMdvuIyr.jpg', 'EFxXsVJ0qHkomcBhnLfY96W5U4yOliQG.jpg', 'okw9N05dAnsxgW2IuQy7eGhz1iLOqVrJ.jpg', 'obRL95fxtP1uCNBwQiTjWsdqUvgp2Z43.jpg', '6jTZ5sfCpGwJWIK3DaYQvixLbNt48nHr.jpg', 'gLqBoG3ah0AHXIYWS7dFTt6pxDw8snQv.jpg', '1lzs3kM8NiILvcgDYtn6fdCoSeXauJ5P.jpg', 'FEuyDnKSIJ1a6UtY5LB2rGpRqOm3xP9Z.jpg', 'OKvn2uJmWQi4R9Cs8B7fxbkZtoczq31Y.jpg', 'ByYKkZHb8omRPcvfe5GzXsxOQ3DlLuUq.jpg', 'SKoiaj8C3UGyvJQXh5zWwrxNmYkdEHqn.jpg', '2cKUvXCjm5HNWksY1b4ioIgdSFqyMtEJ.jpg', 'kJK9OA3hXpMWeUY7cifvrz0BItn1VS2T.jpg', 't1DnLxSZXwWTgeJsyE02lrjHfdM35po8.jpg', '5C76eISyb3vmPZuMYcARHU8aFQrBWf1k.jpg', 'puBcg8Fh6tXs27doz1aAIl4L0iVYC3wE.jpg', 'MV5C7YmuzG1LyZplFXvqOQkW4JStjcNP.jpg', 'ruleKNQvzwqmy5sn9MDd7I2RUJjVCWh8.jpg', 'l9Z3gPwjC5HbhINcfVO8dnz1qAxBrJkU.jpg', 'H9BcFOo8UI3jX2CyW0mzxn7agJNAsZQS.jpg', 'jHUJE37YZOGAXInPmyCSp9f0o4uvRe5W.jpg', 'I8jNkAVgZ1yqDw5K9b0Wm4rETfiGBcUF.jpg', 'kKzQrE6GjfpeFhsXx2Ddu9YaTHc3PUbB.jpg', 'jla5O2TkVhefr07XDLpMEonuG6yJWgYd.jpg', 'Km9BZsaSUoxQ4VArcXYyHThIDRbq2t7l.jpg', 'fBp0Yor4EQtWkM7I3TsnNHLXuvCFacjS.jpg']

超参数设置

batch_size = 18                             # 批量大小
image_size = 224                            # 训练图像空间大小
num_epochs = 10                             # 训练周期数
lr = 0.001                                  # 学习率
momentum = 0.9                              # 动量
workers = 4                                 # 并行线程个数

数据预处理

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

# 数据集目录路径
data_path_train = "data/train/"
data_path_val = "data/val/"

# 创建训练数据集

def create_dataset_canidae(dataset_path, usage):
    """数据加载"""
    data_set = ds.ImageFolderDataset(dataset_path,
                                     num_parallel_workers=workers,
                                     shuffle=True,)
    # 数据增强操作
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    scale = 32
    if usage == "train":
        # Define map operations for training dataset
        trans = [
            vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]
    else:
        # Define map operations for inference dataset
        trans = [
            vision.Decode(),
            vision.Resize(image_size + scale),
            vision.CenterCrop(image_size),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]

    # 数据映射操作
    data_set = data_set.map(
        operations=trans,
        input_columns='image',
        num_parallel_workers=workers)
    # 批量操作
    data_set = data_set.batch(batch_size)
    return data_set


dataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()
dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()
print(step_size_train)
print(step_size_val)
data = next(dataset_val.create_dict_iterator())
images = data["image"]
labels = data["label"]
print("Tensor of image", images.shape)
print("Labels:", labels)

输出

96
24
Tensor of image (18, 3, 224, 224)
Labels: [ 1  2  4  3  0 10  5  4 11  9 11  6  7  1 11  5  1  3]

数据集可视化查看

import matplotlib.pyplot as plt
import numpy as np

# class_name对应label,按文件夹字符串从小到大的顺序标记label
class_name = {0: "0", 1: "1",2: "2", 3: "3",4: "4", 5: "5",6: "6", 7: "7",8: "8", 9: "9",10: "10", 11: "11",12: "12"}

plt.figure(figsize=(5, 5))
for i in range(4):
    # 获取图像及其对应的label
    data_image = images[i].asnumpy()
    data_label = labels[i]
    # 处理图像供展示使用
    data_image = np.transpose(data_image, (1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    data_image = std * data_image + mean
    data_image = np.clip(data_image, 0, 1)
    # 显示图像
    plt.subplot(2, 2, i+1)
    plt.imshow(data_image)
    plt.title(class_name[int(labels[i].asnumpy())])
    plt.axis("off")

plt.show()

在这里插入图片描述

网络结构搭建

from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normal


weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)
class ResidualBlockBase(nn.Cell):
    expansion: int = 1  # 最后一个卷积核数量与第一个卷积核数量相等

    def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, norm: Optional[nn.Cell] = None,
                 down_sample: Optional[nn.Cell] = None) -> None:
        super(ResidualBlockBase, self).__init__()
        if not norm:
            self.norm = nn.BatchNorm2d(out_channel)
        else:
            self.norm = norm

        self.conv1 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.conv2 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=3, weight_init=weight_init)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):
        """ResidualBlockBase construct."""
        identity = x  # shortcuts分支

        out = self.conv1(x)  # 主分支第一层:3*3卷积层
        out = self.norm(out)
        out = self.relu(out)
        out = self.conv2(out)  # 主分支第二层:3*3卷积层
        out = self.norm(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)
        out += identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)

        return out
class ResidualBlock(nn.Cell):
    expansion = 4  # 最后一个卷积核的数量是第一个卷积核数量的4倍

    def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=1, weight_init=weight_init)
        self.norm1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.norm2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,
                               kernel_size=1, weight_init=weight_init)
        self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)

        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):

        identity = x  # shortscuts分支

        out = self.conv1(x)  # 主分支第一层:1*1卷积层
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)  # 主分支第二层:3*3卷积层
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv3(out)  # 主分支第三层:1*1卷积层
        out = self.norm3(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)

        out += identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)

        return out
def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],
               channel: int, block_nums: int, stride: int = 1):
    down_sample = None  # shortcuts分支


    if stride != 1 or last_out_channel != channel * block.expansion:

        down_sample = nn.SequentialCell([
            nn.Conv2d(last_out_channel, channel * block.expansion,
                      kernel_size=1, stride=stride, weight_init=weight_init),
            nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
        ])

    layers = []
    layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))

    in_channel = channel * block.expansion
    # 堆叠残差网络
    for _ in range(1, block_nums):

        layers.append(block(in_channel, channel))

    return nn.SequentialCell(layers)
from mindspore import load_checkpoint, load_param_into_net


class ResNet(nn.Cell):
    def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
                 layer_nums: List[int], num_classes: int, input_channel: int) -> None:
        super(ResNet, self).__init__()

        self.relu = nn.ReLU()
        # 第一个卷积层,输入channel为3(彩色图像),输出channel为64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
        self.norm = nn.BatchNorm2d(64)
        # 最大池化层,缩小图片的尺寸
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        # 各个残差网络结构块定义,
        self.layer1 = make_layer(64, block, 64, layer_nums[0])
        self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
        self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
        self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)
        # 平均池化层
        self.avg_pool = nn.AvgPool2d()
        # flattern层
        self.flatten = nn.Flatten()
        # 全连接层
        self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)

    def construct(self, x):

        x = self.conv1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x


def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrianed_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)

    if pretrained:
        # 加载预训练模型
        # download(url=model_url, path=pretrianed_ckpt, replace=True)
        param_dict = load_checkpoint(pretrianed_ckpt)
        load_param_into_net(model, param_dict)

    return model


def resnet50(num_classes: int = 1000, pretrained: bool = False):
    "ResNet50模型"
    resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)

形式一:模型微调

模型训练
from mindspore import nn, train
from mindspore.nn import Loss, Accuracy
!pip install download
import mindspore as ms
from download import download
network = resnet50(pretrained=True)

# 全连接层输入层的大小
in_channels = network.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 12)
# 重置全连接层
network.fc = head

# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
network.avg_pool = avg_pool

import mindspore as ms
import mindspore

# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=momentum)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

# 实例化模型
model = train.Model(network, loss_fn, opt, metrics={"Accuracy": Accuracy()})

def forward_fn(inputs, targets):

    logits = network(inputs)
    loss = loss_fn(logits, targets)

    return loss

grad_fn = mindspore.ops.value_and_grad(forward_fn, None, opt.parameters)

def train_step(inputs, targets):

    loss, grads = grad_fn(inputs, targets)
    opt(grads)

    return loss
    
# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
# 最佳模型保存路径
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best.ckpt"
import os
import time

# 开始循环训练
print("Start Training Loop ...")

best_acc = 0

for epoch in range(num_epochs):
    losses = []
    network.set_train()

    epoch_start = time.time()

    # 为每轮训练读入数据
    for i, (images, labels) in enumerate(data_loader_train):
        labels = labels.astype(ms.int32)
        loss = train_step(images, labels)
        losses.append(loss)

    # 每个epoch结束后,验证准确率

    acc = model.eval(dataset_val)['Accuracy']

    epoch_end = time.time()
    epoch_seconds = (epoch_end - epoch_start) * 1000
    step_seconds = epoch_seconds/step_size_train

    print("-" * 20)
    print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (
        epoch+1, num_epochs, sum(losses)/len(losses), acc
    ))
    print("epoch time: %5.3f ms, per step time: %5.3f ms" % (
        epoch_seconds, step_seconds
    ))

    if acc > best_acc:
        best_acc = acc
        if not os.path.exists(best_ckpt_dir):
            os.mkdir(best_ckpt_dir)
        ms.save_checkpoint(network, best_ckpt_path)

print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "
      f"save the best ckpt file in {best_ckpt_path}", flush=True)

输出

Start Training Loop ...
--------------------
Epoch: [  1/ 10], Average Train Loss: [1.774], Accuracy: [0.838]
epoch time: 60892.337 ms, per step time: 634.295 ms
--------------------
Epoch: [  2/ 10], Average Train Loss: [0.762], Accuracy: [0.905]
epoch time: 8745.406 ms, per step time: 91.098 ms
--------------------
Epoch: [  3/ 10], Average Train Loss: [0.568], Accuracy: [0.921]
epoch time: 8449.129 ms, per step time: 88.012 ms
--------------------
Epoch: [  4/ 10], Average Train Loss: [0.508], Accuracy: [0.910]
epoch time: 8199.763 ms, per step time: 85.414 ms
--------------------
Epoch: [  5/ 10], Average Train Loss: [0.459], Accuracy: [0.900]
epoch time: 7856.060 ms, per step time: 81.834 ms
--------------------
Epoch: [  6/ 10], Average Train Loss: [0.405], Accuracy: [0.931]
epoch time: 8138.927 ms, per step time: 84.780 ms
--------------------
Epoch: [  7/ 10], Average Train Loss: [0.368], Accuracy: [0.919]
epoch time: 8333.523 ms, per step time: 86.808 ms
--------------------
Epoch: [  8/ 10], Average Train Loss: [0.354], Accuracy: [0.912]
epoch time: 8271.008 ms, per step time: 86.156 ms
--------------------
Epoch: [  9/ 10], Average Train Loss: [0.338], Accuracy: [0.928]
epoch time: 8457.969 ms, per step time: 88.104 ms
--------------------
Epoch: [ 10/ 10], Average Train Loss: [0.338], Accuracy: [0.907]
epoch time: 8183.743 ms, per step time: 85.247 ms
================================================================================
End of validation the best Accuracy is:  0.931, save the best ckpt file in ./BestCheckpoint/resnet50-best.ckpt
模型评估
import matplotlib.pyplot as plt
import mindspore as ms

def visualize_model(best_ckpt_path, val_ds):
    net = resnet50()
    # 全连接层输入层的大小
    in_channels = net.fc.in_channels
    # 输出通道数大小为分类数12
    head = nn.Dense(in_channels, 12)
    # 重置全连接层
    net.fc = head
    # 平均池化层kernel size为7
    avg_pool = nn.AvgPool2d(kernel_size=7)
    # 重置平均池化层
    net.avg_pool = avg_pool
    # 加载模型参数
    param_dict = ms.load_checkpoint(best_ckpt_path)
    ms.load_param_into_net(net, param_dict)
    model = train.Model(net)
    #print(net)
    # 加载验证集的数据进行验证
    data = next(val_ds.create_dict_iterator())
    images = data["image"].asnumpy()
    print(type(images))
    print(images.shape)
    #print(images)
    labels = data["label"].asnumpy()
    #print(labels)
    class_name = {0: "0", 1: "1",2: "2", 3: "3",4: "4", 5: "5",6: "6", 7: "7",8: "8", 9: "9",10: "10", 11: "11",12: "12"}
    # 预测图像类别
    data_pre=ms.Tensor(data["image"])
    print(data_pre.shape)
    print(type(data_pre))
    output = model.predict(data_pre)
    
    #print(output)
    pred = np.argmax(output.asnumpy(), axis=1)

    # 显示图像及图像的预测值
    plt.figure(figsize=(5, 5))
    for i in range(4):
        plt.subplot(2, 2, i + 1)
        # 若预测正确,显示为蓝色;若预测错误,显示为红色
        color = 'blue' if pred[i] == labels[i] else 'red'
        plt.title('predict:{}'.format(class_name[pred[i]]), color=color)
        picture_show = np.transpose(images[i], (1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        picture_show = std * picture_show + mean
        picture_show = np.clip(picture_show, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')

    plt.show()
visualize_model('BestCheckpoint/resnet50-best.ckpt', dataset_val)

输出
在这里插入图片描述

形式二:固定特征训练

模型训练
net_work = resnet50(pretrained=True)
# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为分类数12
head = nn.Dense(in_channels, 12)
# 重置全连接层
net_work.fc = head
# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool
# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():
    if param.name not in ["fc.weight", "fc.bias"]:
        param.requires_grad = False
# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
def forward_fn(inputs, targets):
    logits = net_work(inputs)
    loss = loss_fn(logits, targets)
    return loss
grad_fn = ms.ops.value_and_grad(forward_fn, None, opt.parameters)
def train_step(inputs, targets):
    loss, grads = grad_fn(inputs, targets)
    opt(grads)
    return loss
# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": Accuracy()})
dataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()
dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()
num_epochs = 10
# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs)
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best-freezing-param.ckpt"
# 开始循环训练
print("Start Training Loop ...")
best_acc = 0
for epoch in range(num_epochs):
    losses = []
    net_work.set_train()
    epoch_start = time.time()
    # 为每轮训练读入数据
    for i, (images, labels) in enumerate(data_loader_train):
        labels = labels.astype(ms.int32)
        loss = train_step(images, labels)
        losses.append(loss)
    # 每个epoch结束后,验证准确率
    acc = model1.eval(dataset_val)['Accuracy']
    epoch_end = time.time()
    epoch_seconds = (epoch_end - epoch_start) * 1000
    step_seconds = epoch_seconds/step_size_train
    print("-" * 20)
    print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (
        epoch+1, num_epochs, sum(losses)/len(losses), acc
    ))
    print("epoch time: %5.3f ms, per step time: %5.3f ms" % (
        epoch_seconds, step_seconds
    ))
    if acc > best_acc:
        best_acc = acc
        if not os.path.exists(best_ckpt_dir):
            os.mkdir(best_ckpt_dir)
        ms.save_checkpoint(net_work, best_ckpt_path)
print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "
      f"save the best ckpt file in {best_ckpt_path}", flush=True)
模型评估
visualize_model(best_ckpt_path, dataset_val)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1495618.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

6020一拖二快充线:手机充电的革命性创新

在快节奏的现代生活中,手机已不仅仅是一个通讯工具,更是我们工作、学习和娱乐的得力助手。然而,手机的电量问题一直是困扰着我们的难题。为了解决这个问题,市场上出现了一种名为“一拖二快充线”的充电设备,它不仅具备…

事件对象与环境对象

一、事件对象 事件对象里有事件触发时的相关信息&#xff0c;包括属性和方法。注册时间中&#xff0c;回调函数的第一个参数就是事件对象&#xff0c;一般命名为event、ev、e。其中e.target是主要对象&#xff0c;target是事件源对象&#xff0c;即当前触发元素。 <button …

Midjourney入门:AI绘画真的能替代人类的丹青妙笔吗?

名人说&#xff1a;一花独放不是春&#xff0c;百花齐放花满园。——《增广贤文》 作者&#xff1a;Code_流苏(CSDN)&#xff08;一个喜欢古诗词和编程的Coder&#x1f60a;&#xff09; 目录 一、简要介绍1、Midjourney2、使用方法 二、绘画1、动物类2、风景类3、动漫类4、艺…

Unity:Animation 三 Playable、ImportModel

目录​​​​​​​ 1. Playables API 1.1 Playable vs Animation 1.2 Advantages of using the Playables API 1.3 PlayableGraph Visualizer 2. Creating models outside of Unity 2.1 Preparing your model files for export 2.1.1 Scaling factors 2.1.2 优化模型文…

腾讯云2核2G3M轻量应用服务器优惠价格61元一年

腾讯云2核2G3M轻量应用服务器优惠价格61元一年&#xff0c;200GB月流量、40GB SSD云硬盘。 一张表看懂腾讯云服务器租用优惠价格表&#xff0c;一目了然&#xff0c;腾讯云服务器分为轻量应用服务器和云服务器CVM&#xff0c;CPU内存配置从2核2G、2核4G、4核8G、8核16G、4核16…

网络编程:数据库实现增删改

1.数据库实现增删改 程序代码&#xff1a; 1 #include<myhead.h>2 //定义添加数据函数3 int do_add(sqlite3*ppDb)4 {5 //准备sql语句6 int add_numb;//工号7 char add_name[20];//姓名8 double add_salary;9 printf("请输入要添加的工号:&quo…

网康科技 NS-ASG 应用安全网关 SQL注入漏洞复现(CVE-2024-2022)

0x01 产品简介 网康科技的NS-ASG应用安全网关是一款软硬件一体化的产品,集成了SSL和IPSec,旨在保障业务访问的安全性,适配所有移动终端,提供多种链路均衡和选择技术,支持多种认证方式灵活组合,以及内置短信认证、LDAP令牌、USB KEY等多达13种认证方式。 0x02 漏洞概述 …

ArmSoM规划开发基于RK3576的开发套件

ArmSoM正计划推出一款新的产品&#xff0c;这款产品将采用强大的RK3576芯片。 本文将为您介绍我们的新产品搭载的RK3576性能参数&#xff0c;以及它如何为您提供卓越的性能和功能。 RK3576处理器 RK3576处理器是一款强大的处理器&#xff0c;具备出色的性能和多样化的功能&a…

机器学习的边界与实际应用

目录 前言1 机器学习的广泛适用性1.1. 利用输入输出映射1.2. 大量的可用数据 2 机器学习能做的事情举例2.1 自动驾驶2.2 用户请求处理2.3 有大量数据的医学影像诊断 3 机器学习不能做的事情举例3.1 市场分析报告3.2 感同身受的邮件回复3.3 手势意图判断3.4 少量数据的医学影像诊…

NotePad++下载

notepad官网地址是https://notepad-plus-plus.org/。提供了许多强大的功能和工具&#xff0c;使其成为许多程序员和开发人员的首选工具。 Notepad 是一款广受欢迎的开源文本编辑器&#xff0c;它提供了许多强大的功能和工具&#xff0c;使其成为许多程序员和开发人员的首选工具…

Hadoop配置日志的聚集——jobhistory不显示任务问题

问题&#xff1a; 一开始job history是正常的&#xff0c;配置了日志的聚集以后不管做什么任务都不显示任务&#xff0c;hdfs是正常运行&#xff0c;而且根据配置步骤都重启过了。 下面先po出日志聚集的操作步骤&#xff0c;再讲问题 1.配置yarn-site.xml cd $HADOOP_HOME/e…

URL?后参数有特殊字符问题

前端对于URL的参数不做处理 不处理、用URLDecoder.decode()处理、用URLEncoder.encode()处理、用URLEncoder.encode()处理后再用URLDecoder.decode()处理 结果 前端对于URL的参数用encodeURIComponent(‘XF-OPPZZD-26*316’)处理 结果 前端不处理有&字符时 结果会把后…

阿里云轻量应用服务器2核4G5M优惠价格165元一年,252元15个月、756元3年

阿里云轻量应用服务器2核4G5M优惠价格165元一年&#xff0c;252元15个月、756元3年&#xff0c;500GB月流量&#xff0c;60GB SSD云硬盘。 一张表看懂腾讯云服务器租用优惠价格表&#xff0c;一目了然&#xff0c;腾讯云服务器分为轻量应用服务器和云服务器CVM&#xff0c;CPU…

借助 Terraform 功能协调部署 CI/CD 流水线-Part 1

在当今快节奏的开发环境中&#xff0c;实现无缝、稳健的 CI/CD 流水线对于交付高质量软件至关重要。在本文中&#xff0c;我们将向您介绍使用 Bitbucket Pipeline、ArgoCD GitOps 和 AWS EKS 设置部署的步骤&#xff0c;所有步骤都将利用 Terraform 的强大功能进行编排。在Part…

互联网智慧工地源码,“互联网+建筑大数据”SaaS微服务架构,支持PC端、手机端、数据大屏端

智慧工地源码&#xff0c;支持多端展示&#xff08;PC端、手机端、平板端&#xff09;SaaS微服务架构&#xff0c;项目监管端&#xff0c;工地管理端源码 智能时代的风暴已经融入了我们生活的每个方面&#xff0c;智能手机、iPad等移动终端智能设备已经成为我们生活的必需品。智…

[C语言]——分支和循环(3)

目录 一.while循环 1.if 和 while的对比 2.while语句的执行流程&#xff1a; 3.while循环的实践 4.练习 二.for循环 1.语法形式 2.for循环的执行流程 3.for循环的实践 4.while循环和for循环的对比 5.练习 三.do-while循环 1.语法形式 2.do while循环的执行流程 3…

ES单节点部署

ES 拉取镜像 docker pull elasticsearch:7.10.1启动容器 docker run -d -p 9200:9200 -p 9300:9300 -e "discovery.typesingle-node" -e "ES_JAVA_OPTS-Xms1g -Xmx1g" -v /es_data:/usr/share/elasticsearch/data --name es 558380375f1a注&#xff1a…

覆盖element-ui的el-menu样式记录:背景图片、菜单图标、菜单高亮与鼠标悬浮高亮、调整子菜单等样式

页面中修改el-menu 设置background-color"transparent"&#xff0c;menu菜单下的背景图片则能正常显示了 <el-menuclass"el-menu-demo"mode"horizontal"background-color"transparent"><el-menu-item index"1">…

Java基于微信小程序的校园失物招领小程序

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

app逆向-ratel框架-AES,DES,MD5,SHA1加密算法java hook程序

一、前言 AES&#xff08;高级加密标准&#xff09;、DES&#xff08;数据加密标准&#xff09;、MD5&#xff08;消息摘要算法5&#xff09;和SHA-1&#xff08;安全哈希算法1&#xff09;都是常见的加密算法&#xff0c;用于数据加密和哈希计算。 二、加密算法实现 1、创建…