昇思25天学习打卡营第1天|快速入门-Mnist手写数字识别

news2024/10/7 0:31:13

学习目标:熟练掌握MindSpore使用方法

学习心得体会,记录时间
在这里插入图片描述

  • 了解MindSpore总体架构
  • 学会使用MindSpore
  • 简单应用时间-手写数字识别

一、MindSpore总体架构

华为MindSpore为全场景深度学习框架,开发高效,全场景统一部署特点。
在这里插入图片描述


二、学会使用MindSpore

2.1 jupyter云上开发配置

方法一、在昇思大模型平台上有应有的环境,还可以申请使用算力,在自己电脑上需要下载mindspore安装,再安装依赖库down等

  1. 登录官方网网址;
  2. 注册账号
  3. 进入AI实验室
  4. 申请算力。启用算力支持,进入jupyter云上开发,即可开始你的算法设计。
    在这里插入图片描述
    在这里插入图片描述

2.2 本地开发配置Mindspore

方法二、本地搭建mindspore环境,安装相关依赖库,即可开始算法设计
比如在本地电脑anaconda3上配置mindspore框架环境。

  1. anaconda prompt命令窗口创建环境
conda create -n mindspore python=3.9.19```
  1. 切换到该环境
activate mindspore
  1. 安装mindspore
pip install mindspore
  1. 安装依赖库download,加载常用数据集
pip install download

在这里插入图片描述

2.3 制作数据集

1.直接导入MnistDataset
在这里插入图片描述
mindspore和其他成熟的框架,如torch,类似。包含处理深度学习和数据集的方法,如nn,transforms,vision等;以及常用的数据集API,mindspore.dataset可供加载的数据集,如MNIST、CIFAR-10、CIFAR-100、VOC、COCO、ImageNet、CelebA、CLUE等,也支持加载业界标准格式的数据集,包括MindRecord、TFRecord、Manifest等。此外,用户还可以使用此模块定义和加载自己的数据集。

import mindspore.dataset as ds
import mindspore.dataset.transforms as transforms
import mindspore.dataset.vision as vision

常用数据集术语说明如下:
Dataset,所有数据集的基类,提供了数据处理方法来帮助预处理数据。
SourceDataset,一个抽象类,表示数据集管道的来源,从文件和数据库等数据源生成数据。
MappableDataset,一个抽象类,表示支持随机访问的源数据集。
Iterator,用于枚举元素的数据集迭代器的基类。


  1. 生成自定义数据集示例如下:
import numpy as np
import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms

# 构造图像和标签
data1 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
data2 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
data3 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
data4 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)

label = [1, 2, 3, 4]

# 加载数据集
dataset = ds.NumpySlicesDataset(([data1, data2, data3, data4], label), ["data", "label"])

# 对data数据增强
dataset = dataset.map(operations=vision.RandomCrop(size=(250, 250)), input_columns="data")
dataset = dataset.map(operations=vision.Resize(size=(224, 224)), input_columns="data")
dataset = dataset.map(operations=vision.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                                  std=[0.229 * 255, 0.224 * 255, 0.225 * 255]),
                      input_columns="data")
dataset = dataset.map(operations=vision.HWC2CHW(), input_columns="data")

# 对label变换类型
dataset = dataset.map(operations=transforms.TypeCast(ms.int32), input_columns="label")

# batch操作
dataset = dataset.batch(batch_size=2)

# 创建迭代器
epochs = 2
ds_iter = dataset.create_dict_iterator(output_numpy=True, num_epochs=epochs)
for _ in range(epochs):
    for item in ds_iter:
        print("item: {}".format(item), flush=True)

实验输出结果:
在这里插入图片描述

三、手写数字识别

pycharm IDE工具创建工程项目,搭载前面配置的环境mindspore,编写.py文件。数据处理py,模型训练测试py。

  1. 数据集
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
# 运行过一次,后面就是
# path = download(url, "./", kind="zip", replace=True)
train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')
# print(train_dataset.get_col_names())


# MindSpore的dataset使用数据处理流水线(Data Processing Pipeline)
def datapipe(dataset, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset


# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)


if __name__ == "__main__":
    for image, label in test_dataset.create_tuple_iterator():
        print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")
        print(f"Shape of label: {label.shape} {label.dtype}")
        break

在这里插入图片描述

  1. 模型
# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28 * 28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits
model = Network()
# print(model)

3.损失函数,优化器,学习率

# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()  # 损失函数
optimizer = nn.SGD(model.trainable_params(), 1e-2)  # 优化器函数,学习率0.01

前向传播函数

# 1. Define forward function
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

梯度函数

# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

梯度反向优化函数

# 3. Define function of one-step training
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss
  1. 模型训练
def train(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

  1. 模型保存
mindspore.save_checkpoint(model, "./saveModels/mnistModel.ckpt")
print("Saved Model to mnistModel.ckpt")

6.模型测试

def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
  1. 模型加载测试
# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("./saveModels/mnistModel.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)
  1. 测试加载模型
model.set_train(False)
for data, label in test_dataset:
    pred = model(data)
    predicted = pred.argmax(1)
    print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')
    break

9.训练测试结果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1868186.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何ubuntu安装wine/deep-wine运行exe程序(包括安装QQ/微信/钉钉)

1.失败的方法: ubuntu22.04尝试下面这个链接方法没有成功, ubuntu22.04安装wine9.0_ubuntu 22.04 wine-CSDN博客 上面链接里面也提供了wine官方方法,链接如下:https://wiki.winehq.org/Ubuntu_zhcn 但是运行最后一个命令时候报…

win10 C:\Users\Administrator

win10 C:\Users\Administrator C:\Users\Administrator\Documents\ C:\Users\Administrator\Pictures C:\Users\Administrator\Favorites C:\Users\Administrator\Links C:\Users\Administrator\Videos

【C++】————内存管理

作者主页: 作者主页 本篇博客专栏:C 创作时间 :2024年6月26日 一、C内存分布 我们先来看一串代码: int globalVar 1; static int staticGlobalVar 1; void Test() {static int staticVar 1;int localVar 1;int num1[10] …

电脑怎么设置锁屏密码?这3个方法你知道吗

在日常生活中,电脑已成为我们工作和娱乐的重要工具。为了保护个人信息和数据安全,设置锁屏密码是必不可少的一步。通过设置锁屏密码,您可以有效防止未经授权的访问,确保电脑上的隐私和数据不被泄露。本文将详细介绍电脑怎么设置锁…

STM32_hal_STM32Cude_实现RT—Thread系统

1stm32cude下载系统 1.-2下载显示绿色的为下载成功 2为项目导入系统---点击如下选项 2-1选中如下 意思为 kemel 系统内核 shell shell的实现 device 设备 2-2可以看到项目选项中多了如图选项 3实现led闪烁 3-1 定义两个引脚用于控制led 3-2选择时钟源 3-3更改延迟函数…

Planned independent reguirement can only be maintained via the network

背景:用户上线ps系统,物料用策略70跑需求 但是因为通用料被改了策略,改成其他的了,影响到计划独立需求了。 如果用户不需要了哪个料就会把数量改为0,或者直接删掉物料。之前建议是改成0,这样还有个记录在…

【Python机器学习】交互特征与多项式特征

对于线性模型来说,想要丰富特征,还有一种方法是添加原始数据的交互特征和多项式特征。这种特征工程通常用于统计建模,但也经常用于实际的机器学习应用中。 交互特征 上一篇的例子里,线性模型对wave数据集的的每个箱子都学到一个…

基于稀疏矩阵方法的剪枝压缩模型方案总结

1.简介 1.1目的 在过去的一段时间里,对基于剪枝的模型压缩的算法进行了一系列的实现和实验,特别有引入的稀疏矩阵的方法实现了对模型大小的压缩,以及在部分环节中实现了模型前向算法的加速效果,但是总体上模型加速效果不理想。所…

从零到一打造自己的大模型:模型训练

前言 最近看了很多大模型,也使用了很多大模型。对于大模型理论似乎很了解,但是好像又缺点什么,思来想去决定自己动手实现一个 toy 级别的模型,在实践中加深对大语言模型的理解。 在这个系列的文章中,我将通过亲手实践…

【面试题】Spring面试题

目录 Spring Framework 中有多少个模块,它们分别是什么?Spring框架的设计目标、设计理念?核心是什么?Spring框架中都用到了哪些设计模式?Spring的核心机制是什么?什么是Spring IOC容器?什么是依…

竞赛选题 python区块链实现 - proof of work工作量证明共识算法

文章目录 0 前言1 区块链基础1.1 比特币内部结构1.2 实现的区块链数据结构1.3 注意点1.4 区块链的核心-工作量证明算法1.4.1 拜占庭将军问题1.4.2 解决办法1.4.3 代码实现 2 快速实现一个区块链2.1 什么是区块链2.2 一个完整的快包含什么2.3 什么是挖矿2.4 工作量证明算法&…

鸿蒙面试心得

自疫情过后,java和web前端都进入了冰河时代。年龄、薪资、学历都成了找工作路上躲不开的门槛。 年龄太大pass 薪资要高了pass 学历大专pass 好多好多pass 找工作的路上明明阳关普照,却有一种凄凄惨惨戚戚说不清道不明的“优雅”意境。 如何破局&am…

修复:cannot execute binary file --- ppc64le 系统架构

前言: 修复node_exporter,引用pprof包,对源码编译后在 Linux 系统下执行程序运行时,发生了报错,报错信息:cannot execute binary file: Exec format error。 开始以为编译有问题,检查发现;该l…

正规的外盘期货开户指南避坑!

一:最正规最靠谱的外盘期货开户方式。那就是直开香港账户,需要基本证件、护照、境外卡等。 如果你满足以上条件,可以直接在香港外盘期货公司的营业部或线上官网开户。 优点:安全正规,银期转账。 缺点:保…

Java - 程序员面试笔记记录 实现 - Part1

社招又来学习 Java 啦,这次选了何昊老师的程序员面试笔记作为主要资料,记录一下一些学习过程。 1.1 Java 程序初始化 Java 程序初始化遵循规则:静态变量优于动态变量;父类优于子类;成员变量的定义顺序; …

1. jenkins持续集成交付

jenkins持续集成交付 一、jenkins介绍二、jenkins的安装部署1、下载jenkins2、安装jenkins3、修改插件下载地址4、初始化jenkins 一、jenkins介绍 持续集成交付, CI/CD 偏开发、项目编译、部署、更新 二、jenkins的安装部署 1、下载jenkins [rootjenkins ~]# wge…

LLM 推理:Nvidia TensorRT-LLM 与 Triton Inference Server

随着LLM越来越热门,LLM的推理服务也得到越来越多的关注与探索。在推理框架方面,tensorrt-llm是非常主流的开源框架,在Nvidia GPU上提供了多种优化,加速大语言模型的推理。但是,tensorrt-llm仅是一个推理框架&#xff0…

算法设计与分析--分布式系统作业及答案

分布式系统 作业参考答案2.1 分析在同步和异步模型下,convergecast 算法的时间复杂性。2.2 G 里一结点从 pr 可达当且仅当它曾设置过自己的 parent 变量。2.3 证明 Alg2.3 构造一棵以 Pr 为根的 DFS 树。2.4 证明 Alg2.3 的时间复杂度为 O(m)。2.5 修改 Alg2.3 获得…

限域传质分离膜兼具高渗透性、高选择性特点 未来应用前景广阔

限域传质分离膜兼具高渗透性、高选择性特点 未来应用前景广阔 分离膜是一种具有选择性透过功能的薄层材料。限域传质分离膜是基于限域传质机制的分离膜,兼具高渗透性、高选择性的特点。限域传质是流体分子通过与其运动自由程相当传质空间的过程,流体分子…

网络安全 DVWA通关指南 Cross Site Request Forgery (CSRF)

DVWA Cross Site Request Forgery (CSRF) 文章目录 DVWA Cross Site Request Forgery (CSRF)DVWA Low 级别 CSRFDVWA Medium 级别 CSRFDVWA High 级别 CSRFDVWA Impossible 级别 CSRF CSRF是跨站请求伪造攻击,由客户端发起,是由于没有在执行关键操作时&a…