mindspore实现自定义CNN图像分类模型

news2024/12/22 9:26:34

一、数据集定义

         使用mindspore.dataset中的ImageFolderDataset接口加载图像分类数据集,ImageFolderDataset接口传入数据集文件上层目录,每个子目录分别放入不同类别的图像。使用python定义一个create_dataset函数用于创建数据集,在函数中使用mindspore.dataset.vision接口中的Decode、Resize、Normalize、HWC2CHW对图像进行解码、调整尺寸、归一化和通道变换预处理,其中Resize根据模型需要的图像大小进行设置,归一化操作可以通过设置mean和std约束范围。如将mean设置为[127.5,127.5,127.5],std设置为[255,255,255],可以将数据归一化到[0.5~0.5]范围内。

数据集加载:

import mindspore
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import context, save_checkpoint, ops, Tensor
import mindspore.dataset as ds
import mindspore.dataset.vision as CV
import mindspore.dataset.transforms as C
from mindspore import dtype as mstype


def create_dataset(data_path, batch_size=24, repeat_num=1):
    """定义数据集"""
    data_set = ds.ImageFolderDataset(data_path, num_parallel_workers=8, shuffle=True)
    image_size = [100, 100]
    mean = [127.5, 127.5, 127.5]
    std = [255., 255., 255.]
    trans = [
        CV.Decode(),
        CV.Resize(image_size),
        CV.Normalize(mean=mean, std=std),
        CV.HWC2CHW()
    ]
    # 实现数据的map映射、批量处理和数据重复的操作
    type_cast_op = C.TypeCast(mstype.int32)
    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    data_set = data_set.batch(batch_size, drop_remainder=True)
    data_set = data_set.repeat(repeat_num)
    return data_set

二、定义网络结构

         定义神经网络需要使用mindspore.nn模块,使用python创建一个cnn_net类并继承nn.Cell,在init中初始化模型需要用到的各种算子,该卷积神经网络需要用到的算子分别为卷积层nn.Conv2d、激活函数nn.Relu、池化层nn.Maxpool2d、打平操作nn.Flatten、全连接层nn.Dense。这里用的自定义卷积神经网络由4层卷积+2层全连接组成,每个卷积层后接一个激活函数和最大池化层,每个池化层通过设置步长为2对特征图进行尺寸减半,因此在经过四层卷积后特征图变为输入的1/16,也就是6*6。在卷积层后接一个打平操作,将特征图从二维转换为一维,特征图打平以才能后进入全连接层,最后一层全连接层输出通道数与分类类别数一致。模型中每层输入输出通道定义如下:

卷积层1:输入通道3,输出通道8,卷积核3*3

卷积层2:输入通道8,输出通道16,卷积核3*3

卷积层3:输入通道16,输出通道32,卷积核3*3

卷积层4:输入通道32,输出通道64,卷积核3*3

全连接层1:输入288,输出128

全连接层2:输入128,输出分类数

网络实现:

class cnn_net(nn.Cell):
    """
    网络结构
    """
    def __init__(self, num_class=10, num_channel=3):
        super(cnn_net, self).__init__()
        # 定义所需要的运算
        self.conv1 = nn.Conv2d(in_channels=num_channel, out_channels=8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 16, 3)
        self.conv3 = nn.Conv2d(16, 32, 3)
        self.conv4 = nn.Conv2d(32, 64, 3)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Dense(2304, 128, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(128, num_class, weight_init=Normal(0.02))

    def construct(self, x):
        # 使用定义好的运算构建前向网络
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv4(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

三、定义损失函数计算流程

         由于网络中没有带损失函数,需要单独定义一个类NetWithLoss用于计算损失,在计算损失前,需要将标签进行onehot编码,如分类标签为[0,1,2],那么标签1经过onehot转换为[0,1,0],之后将数据送入模型进行前向计算,得到logits,使用交叉熵损失函数对logits和label计算损失。

class NetWithLoss(nn.Cell):
    def __init__(self, backbone, loss_func, classes):
        super(NetWithLoss, self).__init__()
        self.backbone = backbone
        self.loss_func = loss_func
        self.classes = classes

    def construct(self, inputs, labels):
        labels = ops.one_hot(labels, self.classes,
                        Tensor(1, dtype=mindspore.float32),
                        Tensor(0, dtype=mindspore.float32))
        logits = self.backbone(inputs)
        loss = self.loss_func(logits, labels)
        return ops.mean(loss, axis=0)

四、定义训练流程

         定义一个train函数进行训练,在训练函数中首先定义迭代次数,学习率,批大小,分类数量、输入通道、训练集、验证集、模型、损失函数、优化器等,这里使用for循环进行训练迭代,在数据集迭代过程中使用nn.TrainOneStepCell进行模型训练。在每一轮训练结束后对模型进行验证,计算模型推理准确率。

         在开启训练之前可以通过设置运行环境来觉得模型在什么设备上运行。mindspore支持CPU、GPU、以及Ascend(昇腾训练加速卡),当然,不同设备需要安装对应版本的mindspore。

def train():
    # 数据路径
    epochs = 10
    lr = 0.001
    batch_size = 32
    num_classes = 2
    input_channel = 3
    ckpt_file = 'best.ckpt'
    train_data_path = "./datasets/dogs/train"
    eval_data_path = "./datasets/dogs/val"
    train_ds = create_dataset(train_data_path, batch_size)
    eval_ds = create_dataset(eval_data_path, 1)
    eval_ds_size = eval_ds.get_dataset_size()
    net = cnn_net(num_classes, input_channel)
    opt = nn.Adam(params=net.trainable_params(), learning_rate=lr)
    loss_func = nn.SoftmaxCrossEntropyWithLogits()
    loss_net = NetWithLoss(net, loss_func, num_classes)
    train_net = nn.TrainOneStepCell(loss_net, opt)
    train_net.set_train()
    argmax = ops.Argmax(axis=0)
    best_acc = 0
    best_epoch = 0
    for epoch in range(epochs):
        train_loss = 0
        # 训练
        for data in train_ds.create_tuple_iterator():
            images = data[0]
            lables = data[1]
            loss = train_net(images, lables)
            train_loss += loss

        # 评估
        total = 0
        for data in eval_ds.create_tuple_iterator():
            images = data[0]
            lables = data[1].squeeze()
            logits = net(images)
            pred = argmax(logits.squeeze())
            if pred == lables:
                total += 1

        acc = total / eval_ds_size
        # 保存ckpt
        if acc > best_acc:
            best_acc = acc
            best_epoch = epoch + 1
            save_checkpoint(net, ckpt_file)
        ckpt_file = f'epoch{epoch+1}.ckpt'
        save_checkpoint(net, ckpt_file)
        print(f'epoch:{epoch+1}, loss:{train_loss}, acc:{acc}')
    print(f'train success, best epoch is {best_epoch}, best acc is {best_acc}')


if __name__ == '__main__':
    train()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1363.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[C++基础]-初识模板

前言 作者:小蜗牛向前冲 名言:我可以接受失败,但我不能接受放弃 如果觉的博主的文章还不错的话,还请点赞,收藏,关注👀支持博主。如果发现有问题的地方欢迎❀大家在评论区指正。 目录 一、泛型编…

【正点原子STM32连载】第五十五章 T9拼音输入法实验 摘自【正点原子】MiniPro STM32H750 开发指南_V1.1

1)实验平台:正点原子MiniPro H750开发板 2)平台购买地址:https://detail.tmall.com/item.htm?id677017430560 3)全套实验源码手册视频下载地址:http://www.openedv.com/thread-336836-1-1.html 4&#xff…

deepwalknode2vec 代码实战

提示:笔记内容来自于B站up主同济子豪兄 文章目录1. Embedding嵌入的艺术2. deepwalk2.1. 什么是图嵌入?2.2. deepwalk的步骤1、生成graph;2、利用random walk生成多个路径;3、训练表示向量的学习;4、为了解决分类个数过…

航拍遥感数据集

一、Roundabout Aerial Images for Vehicle Detection 本数据集是从无人机拍摄的西班牙环形交叉口航空图像数据集,使用PASCAL VOC XML文件进行注释,指出车辆在其中的位置。此外,还附带一个CSV文件,其中包含与捕获的环形交叉口的位…

深度学习 神经网络(2)前向传播

深度学习 神经网络(2)前向传播一、前言二、神经网络结构三、前向传播四、参考资料一、前言 前面介绍了《感知器》,类似于单个神经元细胞,现在我们用多个感知器组合成更加复杂的神经网络。本文介绍了多层神经网络通过前向传播方法…

超市营业额数据分析

文章目录1:查看单日交易额最小的3天的交易数据,并查看这3天是周几1.1:导入模块1.2:数据处理1.3:输出结果完整代码2:把所有员工的工号前面增加一位数字,增加的数字和原工号最后一位相同&#xff…

FBAR滤波器的工作原理及制备方法

近年来,随着无线通信技术朝着高频率和高速度方向迅猛发展,以及电子元器件朝着微型化和低功耗的方向发展,基于薄膜体声波谐振器(Film Bulk Acoustic Resonator,FBAR)的滤波器的研究与开发越来越受到人们的关…

酒楼拓客营销流程,酒楼宣传推广方案

随著网络时代的发展,许多行业受到了大大的冲击,其中也涵盖酒楼,在目前的情况下,对于酒楼来说,无论是互联网还是线下,引流都是最重要的。那么酒楼如何做好营销推广工作,从而提升业绩?…

乘风而起!企业级应用软件市场迅猛发展,有哪些机会可以把握?

数字化转型战略的深入,使我国企业级软件市场得到了迅速的发展,据统计,2021年我国企业级应用软件市场规模超过了600亿元,其中商业智能(BI)市场规模超过了50亿元。 得益于中国企业对于数据系统的本地化部署需…

Hadoop3 - MapReduce DB 操作

一、MapReduce DB 操作 对于本专栏的前面几篇文章的操作,基本都是读取本地或 HDFS 中的文件,如果有的数据是存在 DB 中的我们要怎么处理呢? Hadoop 为我们提供了 DBInputFormat 和 DBOutputFormat 两个类。顾名思义 DBInputFormat 负责从数…

MODBUS通信浮点数存储解析常用算法

MODBUS通信相关的基础知识,各种PLC通信程序的写法。可以参看专栏的其它文章这里不赘述。MODBUS通信时,数据帧都是以字节为单位发送和接收的,接收到的字节,如何存放和解析。就需要我们具备数据处理类的知识了,这里需要大家简单了解下有关数据结构的基础知识,这方面比较薄弱…

AcWing 蓝桥杯AB组辅导课 05、树状数组与线段树

文章目录前言一、树状数组1.1、树状数组知识点1.2、树状数组代码模板模板题:AcWing 1264. 动态求连续区间和例题例题1、AcWing 1265. 数星星【中等,信息学奥赛一本通】习题习题1:1215. 小朋友排队【中等,蓝桥杯】二、 线段树知识点…

27.5 Java集合之Set学习(基本概念,存储原理,性能测试)

文章目录1.Set接口1.1 Set的特性是什么?2.具体实现2.1 HashSet2.1.1 存储原理2.1.2 性能测试2.2 TreeSet2.2.1 存储原理2.2.2 性能测试2.3 EnumSet(了解即可)2.3.1 存储原理2.4 LinkedHashSet2.4.1 存储原理2.4.2 性能测试2.4.3 代码地址1.Se…

【Gitee】上传本地项目到 Gitee 仓库(入门篇)

本文主要介绍上传本地项目到 Gitee 仓库的过程,可以说是一个比较傻瓜的教材吧,从0开始,祝大家都能一次成功~~~ 一、前期准备 1. 配置 Gitte 创建 Gitte 账号,绑定好邮箱,并创建一个空仓库 。创建账号绑定邮箱过程这部…

【信号检测】基于小波变换的信号趋势检测和分离研究附matlab代码

✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。 🍎个人主页:Matlab科研工作室 🍊个人信条:格物致知。 更多Matlab仿真内容点击👇 智能优化算法 …

双十一好物推荐:2022年好用的数码好物分享

一年一度的双十一尽在眼前,因为双十一的优惠力度是一年中最大的一次,所以许多人都想着直接一年屯一次,一次屯一年的理念,那么作为资深剁手党的我来说,对比于选购双十一好物来说我还是比较有心得的,下面让我…

机器视觉之工业摄像机知识点(一)

本文主要记录一些基础的工业摄像机的一些简要知识点。我也是根据我觉得比较重要的来记录。作为一位算法工程师,其实是有两条路来走,即技术专家以及技术经理。这两个实际是不同的职业方向。如果你不擅于与外部沟通交流,并且具备非常强的科研和…

基于OpenHarmony的ArkUI框架进阶对于高性能容器类和持久化和原子化的运用

文章目录高性能容器类Badge原子化服务代码简析表达式持久化高性能容器类 顾名思义,容器类是一个存储类,用于存储各种数据类型的元素,并提供一系列处理数据元素的方法。ArkUI开发框架提供了两种类型的容器类,线性和非线性。这些容…

【机器学习】求矩阵的-1/2次方的方法

目录 一、背景描述 二、D^(-1/2)的理论基础 三、代码实现 四、总结 一、背景描述 今天在看如下论文的时候: 态势感知图卷积网络在电力系统连锁故障中的应用-机器学习文档类资源-CSDN文库https://download.csdn.net/download/mzy20010420/86745616?spm1001.20…

Rust之常用集合(一):向量(vector)

开发环境 Windows 10Rust 1.64.0VS Code 1.72.2 项目工程 这里继续沿用上次工程rust-demo 常用集合 Rust的标准库包括许多非常有用的数据结构,称为集合。大多数其他数据类型表示一个特定的值,但是集合可以包含多个值。与内置数组和元组类型不同&…