【深度学习】基于华为MindSpore的手写体图像识别实验

news2025/1/11 22:40:50

1 实验介绍

1.1 简介

Mnist手写体图像识别实验是深度学习入门经典实验。Mnist数据集包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到255。为简单起见,每个图像都被平展并转换为784(28*28)个特征的一维numpy数组。

1.2 实验目的

  1. 学会如何搭建全连接神经网络。
  2. 掌握搭建网络过程中的关键点。
  3. 掌握分类任务的整体流程。

2.2 实验环境要求 

推荐在华为云ModelArts实验平台完成实验,也可在本地搭建python3.7.5和MindSpore1.0.0环境完成实验。

2.3 实验总体设计

d6b6f298ac344c8f9d3d3396292e6dd5.png

 

创建实验环境:在本地搭建MindSpore环境。

导入实验所需模块:该步骤通常都是程序编辑的第一步,将实验代码所需要用到的模块包用import命令进行导入。

导入数据集并预处理:神经网络的训练离不开数据,这里对数据进行导入。同时,因为全连接网络只能接收固定维度的输入数据,所以,要对数据集进行预处理,以符合网络的输入维度要求。同时,设定好每一次训练的Batch的大小,以Batch Size为单位进行输入。

模型搭建:利用mindspore.nn的cell模块搭建全连接网络,包含输入层,隐藏层,输出层。同时,配置好网络需要的优化器,损失函数和评价指标。传入数据,并开始训练模型。

模型评估:利用测试集进行模型的评估。

2.4 实验过程

2.4.1 搭建实验环境

Windows下MindSpore实验环境搭建并配置Pycharm请参考【机器学习】Windows下MindSpore实验环境搭建并配置Pycharm_在pycharm上安装mindspore_弓长纟隹为的博客-CSDN博客

官网下载MNIST数据集 MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

在MNIST文件夹下建立train和test两个文件夹,train中存放train-labels-idx1-ubyte和train-images-idx3-ubyte文件,test中存放t10k-labels-idx1-ubyte和t10k-images-idx3-ubyte文件。

2.4.2  模型训练、测试及评估

#导入相关依赖库
import  os
import numpy as np
from matplotlib import pyplot as plt
import mindspore as ms
#context模块用于设置实验环境和实验设备
import mindspore.context as context
#dataset模块用于处理数据形成数据集
import mindspore.dataset as ds
#c_transforms模块用于转换数据类型
import mindspore.dataset.transforms as C
#vision.c_transforms模块用于转换图像,这是一个基于opencv的高级API
import mindspore.dataset.vision as CV
#导入Accuracy作为评价指标
from mindspore.nn.metrics import Accuracy
#nn中有各种神经网络层如:Dense,ReLu
from mindspore import nn
#Model用于创建模型对象,完成网络搭建和编译,并用于训练和评估
from mindspore.train import Model
#LossMonitor可以在训练过程中返回LOSS值作为监控指标
from mindspore.train.callback import  LossMonitor
#设定运行模式为动态图模式,并且运行设备为昇腾芯片
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
#MindSpore内置方法读取MNIST数据集
ds_train = ds.MnistDataset(os.path.join(r'D:\Dataset\MNIST', "train"))
ds_test = ds.MnistDataset(os.path.join(r'D:\Dataset\MNIST', "test"))

print('训练数据集数量:',ds_train.get_dataset_size())
print('测试数据集数量:',ds_test.get_dataset_size())
#该数据集可以通过create_dict_iterator()转换为迭代器形式,然后通过get_next()一个个输出样本
image=ds_train.create_dict_iterator().get_next()
#print(type(image))
print('图像长/宽/通道数:',image['image'].shape)
#一共10类,用0-9的数字表达类别。
print('一张图像的标签样式:',image['label'])
DATA_DIR_TRAIN = "D:/Dataset/MNIST/train" # 训练集信息
DATA_DIR_TEST = "D:/Dataset/MNIST/test" # 测试集信息

def create_dataset(training=True, batch_size=128, resize=(28, 28), rescale=1 / 255, shift=-0.5, buffer_size=64):
    ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)

    # 定义改变形状、归一化和更改图片维度的操作。
    # 改为(28,28)的形状
    resize_op = CV.Resize(resize)
    # rescale方法可以对数据集进行归一化和标准化操作,这里就是将像素值归一到0和1之间,shift参数可以让值域偏移至-0.5和0.5之间
    rescale_op = CV.Rescale(rescale, shift)
    # 由高度、宽度、深度改为深度、高度、宽度
    hwc2chw_op = CV.HWC2CHW()

    # 利用map操作对原数据集进行调整
    ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op])
    ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))
    # 设定洗牌缓冲区的大小,从一定程度上控制打乱操作的混乱程度
    ds = ds.shuffle(buffer_size=buffer_size)
    # 设定数据集的batch_size大小,并丢弃剩余的样本
    ds = ds.batch(batch_size, drop_remainder=True)

    return ds
#显示前10张图片以及对应标签,检查图片是否是正确的数据集
dataset_show = create_dataset(training=False)
data = dataset_show.create_dict_iterator().get_next()
images = data['image'].asnumpy()
labels = data['label'].asnumpy()

for i in range(1,11):
    plt.subplot(2, 5, i)
    #利用squeeze方法去掉多余的一个维度
    plt.imshow(np.squeeze(images[i]))
    plt.title('Number: %s' % labels[i])
    plt.xticks([])
plt.show()

# 利用定义类的方式生成网络,Mindspore中定义网络需要继承nn.cell。在init方法中定义该网络需要的神经网络层
# 在construct方法中梳理神经网络层与层之间的关系。
class ForwardNN(nn.Cell):
    def __init__(self):
        super(ForwardNN, self).__init__()
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()
        self.fc1 = nn.Dense(784, 512, activation='relu')
        self.fc2 = nn.Dense(512, 256, activation='relu')
        self.fc3 = nn.Dense(256, 128, activation='relu')
        self.fc4 = nn.Dense(128, 64, activation='relu')
        self.fc5 = nn.Dense(64, 32, activation='relu')
        self.fc6 = nn.Dense(32, 10, activation='softmax')

    def construct(self, input_x):
        output = self.flatten(input_x)
        output = self.fc1(output)
        output = self.fc2(output)
        output = self.fc3(output)
        output = self.fc4(output)
        output = self.fc5(output)
        output = self.fc6(output)
        return output

lr = 0.001
num_epoch = 10
momentum = 0.9

net = ForwardNN()
#定义loss函数,改函数不需要求导,可以给离散的标签值,且loss值为均值
loss = nn.loss.SoftmaxCrossEntropyWithLogits( sparse=True, reduction='mean')
#定义准确率为评价指标,用于评价模型
metrics={"Accuracy": Accuracy()}
#定义优化器为Adam优化器,并设定学习率
opt = nn.Adam(net.trainable_params(), lr)


#生成验证集,验证机不需要训练,所以不需要repeat
ds_eval = create_dataset(False, batch_size=32)
#模型编译过程,将定义好的网络、loss函数、评价指标、优化器编译
model = Model(net, loss, opt, metrics)

#生成训练集
ds_train = create_dataset(True, batch_size=32)
print("============== Starting Training ==============")
#训练模型,用loss作为监控指标,并利用昇腾芯片的数据下沉特性进行训练
model.train(num_epoch, ds_train,callbacks=[LossMonitor()],dataset_sink_mode=True)

#使用测试集评估模型,打印总体准确率
metrics_result=model.eval(ds_eval)
print(metrics_result)

20dd127b552846908ef8ff2628cafaed.png

d07ee4e815a5405b97b8b709bf63a160.png

备注:

若报错 AttributeError: ‘DictIterator’ object has no attribute ‘get_next’ ,这是说MindSpore数据类中缺少 “get_next”这个方法,但是在MNIST图像识别的官方代码中却使用了这个方法,这就说明MindSpore官方把这个变成私密方法。

只需要在源码iterators.py中找到DictIterator这个类,将私有方法变成公有方法就行了(即去掉最前面的下划线)。

参考mindspore 报错 AttributeError: ‘DictIterator‘ object has no attribute ‘get_next‘_create_dict_iterator_TNiuB的博客-CSDN博客

MindSpore:前馈神经网络时报错‘DictIterator‘ has no attribute ‘get_next‘_skytier的博客-CSDN博客

ef3de1aeb10343b8a2ca38887ff8a3c0.png

更多问题请参考Window10 上MindSpore(CPU)用LeNet网络训练MNIST - 知乎 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/457484.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

看完这篇文章你就彻底懂啦{保姆级讲解}-----(LeetCode刷题142环形链表II) 2023.4.24

目录 前言算法题(LeetCode刷题142环形链表II)—(保姆级别讲解)分析题目:算法思想环形链表II代码:补充 结束语 前言 本文章一部分内容参考于《代码随想录》----如有侵权请联系作者删除即可,撰写…

ESP32设备驱动-LIS3MDL磁场传感器驱动

LIS3MDL磁场传感器驱动 文章目录 LIS3MDL磁场传感器驱动1、LIS3MDL介绍2、硬件准备3、软件准备4、驱动实现1、LIS3MDL介绍 LIS3MDL 具有4/8/12/16 高斯的用户可选满量程。自检功能允许用户在最终应用中检查传感器的功能。该设备可以被配置为生成用于磁场检测的中断信号。 LIS…

Vue 3 第十四章:组件五(内置组件-transitiontransition-group)

文章目录 1. transition组件1.1. 基本用法1.2. css过渡class介绍1.3. 过渡效果命名1.3.1. 基本用法 1.4. 配合自定义动画&#xff08;animation&#xff09;使用1.5. 自定义过渡 class1.6. <Transition>组件生命周期1.7. transition 常用场景 2. transition-group组件2.1…

Java基础(十一)日期时间API

1 JDK8之前&#xff1a;日期时间API 1.1 java.lang.System类的方法 System类提供的public static long currentTimeMillis()&#xff1a;用来返回当前时间与1970年1月1日0时0分0秒之间以毫秒为单位的时间差。 此方法适用于计算时间差。 计算世界时间的主要标准有&#xff1a;…

SCAU 统计学 实验6

要确定不同培训方式对产品组装时间是否有显著影响&#xff0c;我们可以使用单因素方差分析&#xff08;One-way ANOVA&#xff09;。我们将使用以下数据&#xff1a; 培训方式 A 的样本数据 培训方式 B 的样本数据 培训方式 C 的样本数据 显著性水平&#xff08;α&#xff09…

windows下springboot集成ELK

ELK ElasticSearch Logstash Kibana的集合。ELK主要用于日志的集中管理、快速查询和分析。主要是通过 Logstash 将应用系统的日志通过 input 收集&#xff0c;然后通过内部整理&#xff0c;通过 output 输出到 Elasticsearch 中&#xff0c;其实就是建立了一个 index&#x…

【利刃出鞘】链式思维利用ChatGPT,让其成为工作中的利剑?附带初学者扫盲SpringBoot

【利刃出鞘】链式思维利用ChatGPT&#xff0c;让其成为工作中的利剑 一、一点思考二、技术学习——链式思维2.1 springboot注册bean的几种方式&#xff1f;2.2 springboot Component 注册的原理&#xff1f;2.3 springboot引用注册的Bean原理&#xff1f;2.4 private final MyB…

26-第一个Servlet项目

目录 1.Servlet是什么&#xff1f; 2.第一个Servlet项目 2.1.创建Maven项目 2.2.引入Servlet依赖&#xff08;将Maven项目改为Servlet项目(尚不完整)&#xff09; 2.3.完善Servlet项目目录——源代码目录&单元测试目录&#xff08;非必须&#xff09; 2.4.编写代码 …

4月24日作业

作业1 #include <iostream> using namespace std; template <typename T> class Node { private: T* p; //指针指向栈的首地址 int maxsize; //栈最大容量 int top-1; //栈顶 public: Node(){} //无参构造 Node(int max):maxsize(max)//有参构造 填最大容…

2022 ICPC Gran Premio de Mexico Repechaje 题解

目录 A. Average Walk&#xff08;签到&#xff09; 题意&#xff1a; 思路&#xff1a; 代码&#xff1a; C. Company Layoffs&#xff08;签到&#xff09; 题意&#xff1a; 思路&#xff1a; 代码&#xff1a; D. Denji1&#xff08;模拟/二分&#xff09; 思路&am…

Bsah shell的操作环境

文章目录 Bsah shell的操作环境路径与命令查找顺序使用案例 bash的登录与欢迎信息&#xff1a;/etc/issue、/etc/motdbash的环境配置文件如下login与non-login shell/etc/profile(login shell 才会读)~/.bash_profile(login shell 才会读)source&#xff1a;读入环境配置文件的…

简单介绍一下什么是“工作内存”和“主内存”(JMM中的概念)

在学习Java多线程编程里&#xff0c; volatile 关键字保证内存可见性的要点时&#xff0c;看到网上有些资料是这么说的&#xff1a;线程修改一个变量&#xff0c;会把这个变量先从主内存读取到工作内存&#xff1b;然后修改工作内存中的值&#xff0c;最后再写回到主内存。 对…

【基于gcc】手把手教你移植RT-Thread到STM32

前言 网上大多数移植RT-Thread系统的教程都是基于Keil的&#xff0c;下面将带来基于gcc版本的移植教程&#xff0c;若你还没有基于gcc的环境&#xff0c;可以查看我的这篇文章&#xff1a;VSCode搭建STM32开发环境 1、下载RT-Thread源码 RT-Thread有好几个版本&#xff0c;我…

小程序路由跳转

小程序中的路由只是单纯页面地址的跳转&#xff0c;一般在页面中使用 navigator 组件来实现&#xff0c;也有很多场景需要在 js 中根据逻辑的执行结果跳转到某个页面&#xff0c;比如&#xff1a;如果检测到用户尚未登录就需要给他跳转到登录页面 1.1navigate navigate 跳转到…

软件工程开发文档写作教程(04)—开发文档的编制策略

本文原创作者&#xff1a;谷哥的小弟作者博客地址&#xff1a;http://blog.csdn.net/lfdfhl本文参考资料&#xff1a;电子工业出版社《软件文档写作教程》 马平&#xff0c;黄冬梅编著 开发文档编制策略 文档策略是由上级(资深)管理者制订的&#xff0c;对下级开发单位或开发人…

银河麒麟 Server V10 离线源建立+部署

前言 这国产操作系统真神奇&#xff0c;docker CentOS7&#xff0c; MySQL CentOS8 简直了&#xff0c;这缝合技术真是绝了&#xff01; docker CentOS7 能装最新版 23 很顺利的&#xff01; MySQL CentOS8 也是最新版8.0.33的&#xff0c;也很顺利&#xff01; 系统版本 …

权威解析,软件测试的当下分析现状

Parasoft是一家专门提供软件测试解决方案的公司&#xff0c;Parasoft通过其经过市场验证的自动化软件测试工具集成套件&#xff0c;帮助企业持续交付高质量的软件。Parasoft的技术支持嵌入式、企业和物联网市场&#xff0c;通过将静态代码分析和单元测试、Web UI和API测试等所有…

详解树与二叉树的概念,结构,及实现(下篇)

目录 一&#xff0c; 二叉树链式实现 1. 前置说明 2. 二叉树遍历&#xff08;主打的就是一个分治思想&#xff09; 2. 1 前序遍历 2. 2 中序遍历 2. 3 后序遍历 2. 4 层序遍历 3. 二叉树结点个数及高度 3. 1 二叉树节点个数 3. 2 二叉树叶子节点个数 3. 3 二叉树第…

数字信号处理技术(三)自适应噪声完备集合经验模态分解(CEEMDAN)-Python代码

本文仅对自适应噪声完备集合经验模态分解&#xff08;CEEMDAN&#xff09;的原理简单介绍和重点介绍模型的应用。 1. CEEMDAN原理 CEEMDAN&#xff08;Complete Ensemble Empirical Mode Decomposition with Adaptive Noise&#xff09;的中文名称是自适应噪声完备集合经验模…

一文全解经典机器学习算法之支持向量机SVM(关键词:SVM,对偶、间隔、支持向量、核函数、特征空间、分类)

文章目录 一&#xff1a;概述二&#xff1a;间隔与支持向量三&#xff1a;对偶问题&#xff08;1&#xff09;什么是对偶问题&#xff08;2&#xff09;SVM对偶问题&#xff08;3&#xff09;SMO算法 四&#xff1a;核函数&#xff08;1&#xff09;核函数的概述和作用&#xf…