Mindspore实现手写数字识别

news2025/3/3 18:40:37

废话不多说,首先说一下我使用的环境:

python3.9

mindspore 2.1

使用jupyter notebook

Step1:导入相关依赖的包

import os
from matplotlib import pyplot as plt
import numpy as np
import mindspore as ms
import mindspore.context as context
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.nn.metrics import Accuracy
from mindspore import nn
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')

 Step2:下载mindspore官方的手写数字识别的数据集:

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

Step3:打印数据的相关信息

DATA_DIR_TRAIN = "MNIST_Data/train" # 训练集信息
DATA_DIR_TEST = "MNIST_Data/test" # 测试集信息
#读取数据
ds_train = ds.MnistDataset(DATA_DIR_TRAIN)
ds_test = ds.MnistDataset(DATA_DIR_TEST )
#显示数据集的相关特性
print('训练数据集数量:',ds_train.get_dataset_size())
print('测试数据集数量:',ds_test.get_dataset_size())
image=ds_train.create_dict_iterator().__next__()
print('图像长/宽/通道数:',image['image'].shape)
print('一张图像的标签样式:',image['label']) #一共 10 类,用 0-9 的数字表达类别


Step4:数据预处理函数,对数据进行归一化、裁剪成指定大小、HWC转换为CHW,最后使用map函数进行映射。设定打乱数据集的操作和设置batch_size的大小。 

def create_dataset(training=True, batch_size=128, resize=(28, 28),
                    rescale=1/255, shift=0, buffer_size=64):
    ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)
    # 定义 Map 操作尺寸缩放,归一化和通道变换
    resize_op = CV.Resize(resize)
    rescale_op = CV.Rescale(rescale,shift)
    hwc2chw_op = CV.HWC2CHW()
    # 对数据集进行 map 操作
    ds = ds.map(input_columns="image", operations=[rescale_op,resize_op, hwc2chw_op])
    ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))
    #设定打乱操作参数和 batchsize 大小
    ds = ds.shuffle(buffer_size=buffer_size)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

Step5:画出来前10张·图片看看效果

#显示前 10 张图片以及对应标签,检查图片是否是正确的数据集
ds = create_dataset(training=False)
data = ds.create_dict_iterator().__next__()
images = data['image'].asnumpy()
labels = data['label'].asnumpy()
plt.figure(figsize=(15,5))
for i in range(1,11):
    plt.subplot(2, 5, i)
    plt.imshow(np.squeeze(images[i]))
    plt.title('Number: %s' % labels[i])
    plt.xticks([])
plt.show()

Step6:构建网络模型

#创建模型。模型包括 3 个全连接层,最后输出层使用 softmax 进行多分类,共分成(0-9)10 类
class ForwardNN(nn.Cell):
    def __init__(self):
        super(ForwardNN, self).__init__()
        self.conv1 = _conv3x3(1, 64, stride=1)
        self.bn1 = _bn(64)
        self.relu = ops.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
        self.resblock1 = ResidualBlock(64,128,stride=1)
        self.resblock2 = ResidualBlock(128,128,stride=1)
        self.flatten = nn.Flatten()
        self.GAP = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Dense(128,10)
    def construct(self, input_x):
        x = self.conv1(input_x) # 第一层卷积 7X7,步长为 2
        x = self.bn1(x) # 第一层的 Batch Norm
        x = self.relu(x) # Rule 激活层
        x = self.maxpool(x) # 最大池化 3X3,步长为 2
        x = self.resblock1(x)
        x = self.resblock2(x)
#         x = self.fc(self.flatten(self.GAP(x)))
        
        return x
in_ = ms.Tensor(np.random.randn(32,1,28,28).astype(np.float32))
print(in_.shape)
model = ForwardNN()
aa = model(in_)
print(aa.shape)

我这里使用的是自己搭建的有两个resBlock的卷积网络,大家可以自己尝试,也可以使用全连接网络试试。

Step7:设置超参数和相关的指标

#创建网络,损失函数,评估指标 优化器,设定相关超参数
lr = 0.001
num_epoch = 10
momentum = 0.9
net = ForwardNN()
loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
metrics={"Accuracy": Accuracy()}
opt = nn.Adam(net.trainable_params(), lr)

这就开始训练了。

Step8:开始预测

#使用测试集评估模型,打印总体准确率
metrics=model.eval(ds_eval)
print(metrics)

 

准确率97%还不错。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1279703.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【SpringMVC】Spring Web MVC入门(一)

文章目录 前言什么是Spring Web MVC?什么是MVC什么是Spring MVC? Spring Boot 和 Spring MVC 的区别什么是Spring Boot?关系和区别 Spring MVC 学习注解介绍1. SpringBootApplication2. RestController3. RequestMapping3.1 RequestMapping 使…

CSS3 修改滚动条样式

上图: 上代码: /* 修改垂直滚动条 */ .right-list::-webkit-scrollbar {width: 2px; /* 修改宽度 */height: 5px; /* 修改高度 */ } /* 修改滚动条轨道背景色 */ .right-list::-webkit-scrollbar-track {background-color: #f1f1f1; } /* 修改滚动条滑块…

找不到DNS地址的解决方案

找不到DNS地址的解决方案 第一种解决方案:刷新DNS缓存第二种解决方案: 配置Internet协议版本4(TCP/IPv4)配置IP地址配置DNS地址 如何查看本机IPv4地址、子网掩码与默认网关 第一种解决方案:刷新DNS缓存 WINR输入cmd回…

GEE:Sobel算子卷积

作者:CSDN _养乐多_ 本文将深入探讨边缘检测中的一个经典算法,即Sobel算子卷积。我们将介绍该算法的基本原理,并演示如何在Google Earth Engine中应用Sobel算子进行图像卷积操作。并以试验区NDVI为例子,研究区真彩色影像、NDVI图…

python毕业设计论文选题管理系统b615y

毕业论文管理方式效率低下,为了提高效率,特开发了本毕业论文管理系统。本毕业论文管理系统主要实现的功能模块包括学生模块、导师模块和管理员模块三大部分,具体功能分析如下: (1)导师功能模块:…

站群优化工具,站群优化方案策略

站群优化,作为网络推广的一项重要策略,站群的构建和优化对于提升网站在搜索引擎中的排名、吸引目标流量、增加用户粘性等方面有着不可忽视的作用。 站群优化方案 站群优化并非简单的堆积大量网站,更要注重质量和策略。在构建站群时&#xff…

大数据技术之Flume(超级详细)

大数据技术之Flume(超级详细) 第1章 概述 1.1 Flume定义 Flume是Cloudera提供的一个高可用的,高可靠的,分布式的海量日志采集、聚合和传输的系统。Flume基于流式架构,灵活简单。 1.2 Flume组成架构 Flume组成架构如…

WPF绘图---Canvas中Polygon屏幕居中显示

问题描述 在一个Canvas中绘制了多个Polygon&#xff0c;由于坐标可能超出界面显示范围&#xff0c;需要将绘制的Polygon居中显示&#xff0c;并且缩放至界面大小&#xff0c;效果如下&#xff1a; xaml代码 <Borderx:Name"border"Background"#fff"Cli…

三个写法统计整数前导0个数

从键盘输入一个整数(可能有前导0)&#xff0c;编程统计其前导0个数&#xff0c;其法有三。 (笔记模板由python脚本于2023年12月03日 12:32:32创建&#xff0c;本篇笔记适合对python整型int和字符型str熟悉的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网&#xff1a;http…

C++17中的结构化绑定

C17中的结构化绑定(structured binding):将指定名称绑定到初始化程序的子对象或元素。简而言之&#xff0c;它们使我们能够从元组或结构中声明多个变量。与引用一样&#xff0c;结构化绑定是现有对象的别名&#xff1b;与引用不同&#xff0c;结构化绑定不必是引用类型(referen…

华天动力-OA8000 MyHttpServlet 文件上传漏洞复现

0x01 产品简介 华天动力OA是一款将先进的管理思想、 管理模式和软件技术、网络技术相结合&#xff0c;为用户提供了低成本、 高效能的协同办公和管理平台。 0x02 漏洞概述 华天动力OA MyHttpServlet 存在任意文件上传漏洞&#xff0c;未经身份认证的攻击者可上传恶意的raq文件…

图片处理OpenCV IMDecode模式说明【生产问题处理】

OpenCV IMDecode模式说明【生产问题处理】 1 前言 今天售后同事反馈说客户使用我们的图片处理&#xff0c;将PNG图片处理为JPG图片之后&#xff0c;变为了白板。 我们图片处理使用的是openCV来进行处理 2 分析 2.1 图片是否损坏&#xff1a;非标准PNG头部 于是&#xff0c;马…

Git中如何按日期进行checkout

Git的checkout命令 在Git中&#xff0c;checkout命令是常用的操作之一。它允许我们切换到不同的分支或指定的提交。通过checkout命令&#xff0c;我们可以在代码库中切换到特定的提交版本&#xff0c;这也意味着我们可以按日期进行checkout。 按日期进行checkout的方法 要按…

SmartSoftHelp8,C#简易编程,测试工具

using System; using System.Data; using System.Drawing; using System.IO; using System.Text; using System.Runtime.InteropServices; using System.Threading; using System.Windows.Forms; /// <summary> /// 编程实验室空间名称 /// </summary> namespa…

边缘与云或边缘加云:前进的方向是什么?

边缘计算使数据处理更接近数据源&#xff0c;以及由此产生的行动或决策的对象。通过设计&#xff0c;它可以改变数十亿物联网和其他设备存储、处理、分析和通信数据的方式。 边缘计算使数据处理更接近数据源&#xff0c;以及由此产生的行动或决策的对象。这与传统的体系结构形成…

局域网传输神器LocalSend

局域网文件传输神器 LocalSend 注意只能在相同局域网用才能使用&#xff08;比如用同一个wifi&#xff09;&#xff0c;通常作为办公用品 安装包下载 在gitHub&#xff0c;最好科学上网一下 LocalSend官网 选择最后更新版本 选择手机或电脑以及自己的系统 安装使用 傻瓜…

Leetcode—1423.可获得的最大点数【中等】

2023每日刷题&#xff08;四十八&#xff09; Leetcode—1423.可获得的最大点数 思路&#xff1a;逆向求长为 n−k 的连续子数组和的最小值 参考灵茶山艾府题解 实现代码 class Solution { public:int maxScore(vector<int>& cardPoints, int k) {int mins 0, …

非标设计之螺纹螺丝选型二

目录 一、螺丝的表面处理工艺&#xff1a;镀锌工艺&#xff1a;渗锌工艺&#xff1a;热浸锌工艺&#xff1a;达克罗工艺&#xff1a;镀镍工艺&#xff1a;氧化&#xff08;发黑&#xff09;工艺&#xff1a;电泳黑工艺&#xff1a;不锈钢螺钉&#xff1a; 二、按照颜色分工艺&a…

TensorRT之LeNet5部署(onnx方式)

文章目录 前言LeNet-5部署1.ONNX文件导出2.TensorRT构建阶段(TensorRT模型文件)&#x1f9c1;创建Builder&#x1f367;创建Network&#x1f36d;使用onnxparser构建网络&#x1f36c;优化网络&#x1f361;序列化模型&#x1f369;释放资源 3.TensorRT运行时阶段(推理)&#x…

SVN下载使用和说明

一、SVN <1>SVN的简介 1、svn是什么&#xff1f; 2、作用 3、基本操作 <2>服务器端的软件下载和安装 1、下载 2、查看环境变量 3、验证安装是否成功 <3>创建项目版本库 1、创建项目版本库&#xff08;svn reponsitory&#xff09; 2、svn版本控制文件说明…