昇思25天学习打卡营第2天|MindSpore快速入门

news2024/9/22 21:22:57

打卡

目录

打卡

快速入门案例:minist图像数据识别任务

案例任务说明

流程

1 加载并处理数据集

2 模型网络构建与定义

3 模型约束定义

4 模型训练

5 模型保存

6 模型推理

相关参考文档入门理解

MindSpore数据处理引擎

模型网络参数初始化

模型优化器

损失函数

代码

安装

从模型训练到预测推理

self_main_train_and_save.py

self_dataprocess.py

self_network.py

self_modeltrain.py

self_modeltest.py

self_predict.py


快速入门案例:minist图像数据识别任务

案例任务说明

MINIST数据集是有标签的图像数据,图像数据是0-9的手写阿拉伯数字。其中,训练集有6W个,测试集1W个。

目的是训练一个可以高效识别手写阿拉伯数字的模型。

流程

1 加载并处理数据集

涉及到的mindspore接口 mindspore.dataset。例如对数据集的map、batch、shuffle等操作,数据列名获取,对数据集进行迭代访问、查看数据和标签的shape和datatype等。

2 模型网络构建与定义

涉及到 mindspore.nn 类。例如用户可继承nn.Cell类来自定义网络结构,其中的construct类函数包含数据(Tensor)的变换过程。。

3 模型约束定义

包括损失函数、优化器等。如 nn.CrossEntropyLoss() 、nn.SGD(model.trainable_params(), 1e-2)

4 模型训练

- 定义训练函数,用set_train设置为训练模式,执行正向计算、反向传播和参数优化。

- 定义测试函数,用来评估模型的性能。

5 模型保存

- 两种保存方式:

1)模型参数保存:mindspore.save_checkpoint(model, "model.ckpt")

2)统一的中间表示(Intermediate Representation,IR)的保存,MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

6 模型推理

- 两种加载方式:

1)模型参数加载: 

> model = network()

> param_dict = mindspore.load_checkpoint("model.ckpt");  

param_not_load, _ = mindspore.load_param_into_net(model, param_dict)

2)统一的中间表示(Intermediate Representation,IR)的加载:

> mindspore.set_context(mode=mindspore.GRAPH_MODE)
> graph = mindspore.load("model.mindir")
> model = nn.GraphCell(graph)  ## nn.GraphCell 仅支持图模式。
> outputs = model(inputs)

保存与加载 — MindSpore master 文档

相关参考文档入门理解

MindSpore数据处理引擎

MindSpore 通过对外暴露API层来构建数据图;内部的Data Processing Pipeline 层用来进行数据加载和预处理多步并行流水线。
高性能数据处理引擎 — MindSpore master 文档

MindSpore 通过数据集(Dataset)和数据变换(Transforms)实现高效的数据预处理。

数据集 Dataset — MindSpore master 文档

数据变换 Transforms — MindSpore master 文档

模型网络参数初始化

Initializer是MindSpore内置的参数初始化基类,所有内置参数初始化方法均继承该类。mindspore.nn中提供的神经网络层封装均提供weight_initbias_init等入参,可以直接使用实例化的Initializer进行参数初始化。

参数初始化 — MindSpore master 文档

模型优化器

优化器 — MindSpore master 文档

损失函数

损失函数 — MindSpore master 文档

代码

安装

pip/conda均可:

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1

从模型训练到预测推理

训练:

python self_main_train_and_save.py

推理:

python self_predict.py

self_main_train_and_save.py

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

# 用download库从公开华为云obs桶下载 MINIST 数据集并解压。因为mindspore.dataset 提供的接口仅支持解压后的数据文件 
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)    
 
## 1 加载数据集
train_dataset = MnistDataset('MNIST_Data/train', shuffle=False)
test_dataset = MnistDataset('MNIST_Data/test')
print(train_dataset.get_col_names())   # 打印数据集中包含的数据列名,用于dataset的预处理。输出['image', 'label']


## 2 MindSpore的dataset使用数据处理流水线,这里将处理好的数据集打包为大小为64的batch。
from self_dataprocess import datapipe
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)  
test_dataset = datapipe(test_dataset, 64)  

## 3 数据集加载后,一般以迭代方式获取数据,然后送入神经网络中进行训练。可使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问,查看数据和标签的shape和datatype。
for image, label in test_dataset.create_tuple_iterator():
    print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")
    print(f"Shape of label: {label.shape} {label.dtype}")
    break
    “”“
    Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32
    Shape of label: (64,) Int32
    ”“”
for data in test_dataset.create_dict_iterator():
    print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")
    print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")
    break


## 4 模型训练
from self_network import Network
from self_modeltrain import train, loss_fn 
from self_modelteset import test
model = Network()
epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(model, train_dataset)
    test(model, test_dataset, loss_fn)
print("Done!")


## 5 保存模型
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

self_dataprocess.py

from mindspore.dataset import vision, transforms
def datapipe(dataset, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

self_network.py

# Define model
from mindspore import nn

class Network(nn.Cell): 
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )
    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits


def check_network():
    model = Network()
    print(model)

self_modeltrain.py

# Instantiate loss function and optimizer
from mindspore import nn

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)

# 1. Define forward function
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# 3. Define function of one-step training
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss


def train(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()     ## 设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 BatchNorm),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。默认True
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)
        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

self_modeltest.py

from mindspore import nn 

def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

self_predict.py

## 加载模型
from self_network import Network

# Instantiate a random initialized model
model = Network()

# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)  
print(param_not_load)   ## param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

## 加载后的模型可以直接用于预测推理。
model.set_train(False)
for data, label in test_dataset:
    pred = model(data)
    predicted = pred.argmax(1)
    print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')
    break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1901540.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

优化路由,优化请求url

1、使用父子关系调整下使其更加整洁 2、比如说我修改了下url,那所有的页面都要更改 优化:把这个url抽出来,新建一个Api文件夹用于存放所有接口的url,在业务里只需要关注业务就可以 使用时 导包 发请求 如果想要更改路径,在这里…

ReAct Agent 分享回顾

在人工智能的迅速发展中,ReAct Agent作为一项前沿技术,受到越来越多的关注。本文结合ReAct Agent 提出者的访谈内容,探讨ReAct Agent的研究背景、技术挑战、未来展望,以及它与大模型的紧密联系,分析其科研成果与商业化…

迅捷PDF编辑器合并PDF

迅捷PDF编辑器是一款专业的PDF编辑软件,不仅支持任意添加文本,而且可以任意编辑PDF原有内容,软件上方的工具栏中还有丰富的PDF标注、编辑功能,包括高亮、删除线、下划线这些基础的,还有规则或不规则框选、箭头、便利贴…

使用Docker、Docker-compose部署单机版达梦数据库(DM8)

安装前准备 Linux Centos7安装:https://blog.csdn.net/andyLyysh/article/details/127248551?spm1001.2014.3001.5502 Docker、Docker-compose安装:https://blog.csdn.net/andyLyysh/article/details/126738190?spm1001.2014.3001.5502 下载DM8镜像 …

动态颤抖的眼睛效果404页面源码

动态颤抖的眼睛效果404页面源码, 源码由HTMLCSSJS组成,记事本打开源码文件可以进行内容文字之类的修改,双击html文件可以本地运行效果,也可以上传到服务器里面,重定向这个界面 动态颤抖的眼睛效果404页面源码

【密码学】密码学五要素

密码学五要素是密码系统的基本组成部分,这五个要素共同构成了密码系统的框架。在实际应用中,密码系统的安全性依赖于密钥的安全管理以及算法的强度。 如果任何一方被泄露或破解,那么整个密码系统都将面临风险。因此,在设计和使用密…

关于多人开发下git pull报错代码冲突问题的解决方案

关于多人开发下git pull报错代码冲突问题的解决方案 问题描述 最近多人开发项目习惯性先 git pull 来更新代码的时候,遇到了下面的问题:error: Your local changes to the following files would be overwritten by merge: Please, commit your change…

医疗器械FDA | FDA如何对医疗器械网络安全认证进行审查?

FDA医械网络安全文件出具​https://link.zhihu.com/?targethttps%3A//www.wanyun.cn/Support%3Fshare%3D24315_ea8a0e47-b38d-4cd6-8ed1-9e7711a8ad5e FDA对医疗器械的网络安全认证进行审查时,主要关注以下几个方面,以确保医疗器械在网络环境中的安全性…

vulhub靶场之DEVGURU:1

1 信息收集 1.1 主机发现 arp-scan -l 发现主机IP地址为“192.168.1.11 1.2 端口发现 nmap -sS -sV -A -T5 -p- 192.168.1.11 发现端口为:22,80,8585 1.3 目录扫描 dirsearch -u 192.168.1.11 发现存在git泄露 2 文件和端口访问 2…

力扣5----最长回文子串

给你一个字符串 s,找到 s 中最长的回文子串 示例 1: 输入:s "babad" 输出:"bab" 解释:"aba" 同样是符合题意的答案。示例 2: 输入:s "cbbd" 输出…

嵌入式通信协议全解析:SPI、I²C、UART详解(附带面试题)

目录 一、什么是通信 二、 通信的分类 同步通信(Synchronous Communication) 异步通信(Asynchronous Communication) 不同协议标准区分图: UART UART的特点: UART的通信过程: UART的配置…

Linux多进程和多线程(四)进程间通讯-定时器信号和子进程退出信号

多进程(四) 定时器信号alarm()函数示例alarm()函数的限制定时器信号的实现原理setitimer()函数setitimer()和alarm()函数的区别 setitimer() old_value参数的示例 对比alarm()区别总结: 子进程退出信号 示例: 多进程(四) 定时器信号 SIGALRM 信号是用来通知进程…

ctfshow web 36d 练手赛

不知所措.jpg 没啥用然后测试了网站可以使用php伪达到目的 ?filephp://filter/convert.base64-encode/resourcetest/../index.<?php error_reporting(0); $file$_GET[file]; $file$file.php; echo $file."<br />"; if(preg_match(/test/is,$file)){inclu…

统一视频接入平台LntonCVS视频监控平台具体功能介绍

LntonCVS视频监控平台是一款基于H5技术开发的安防视频监控解决方案&#xff0c;专为全球范围内不同品牌、协议及设备类型的监控产品设计。该平台提供了统一接入管理&#xff0c;支持标准的H5播放接口&#xff0c;使其他应用平台能够快速集成视频功能。无论开发环境、操作系统或…

24-7-6-读书笔记(八)-《蒙田随笔集》[法]蒙田 [译]潘丽珍

文章目录 《蒙田随笔集》阅读笔记记录总结 《蒙田随笔集》 《蒙田随笔集》蒙田&#xff08;1533-1592&#xff09;&#xff0c;是个大神人&#xff0c;这本书就是250页的样子&#xff0c;但是却看了好长好长时间&#xff0c;体会还是挺深的&#xff0c;但看的也是不大仔细&…

《第一行代码》小结

文章目录 一. Android总览1. 系统架构2. 开发环境3. 在红米手机上运行4. 项目资源详解4.1 整体结构4.2 res文件4.3 build.gradle文件 二. Activity0. 常用方法小结1. 创建一个Activity 一. Android总览 1. 系统架构 应用层&#xff1a;所有安装在手机上的应用程序 应用框架层&…

vb.netcad二开自学笔记3:启动与销毁

Imports Autodesk.AutoCAD.ApplicationServicesImports Autodesk.AutoCAD.EditorInputImports Autodesk.AutoCAD.RuntimePublic Class WellcomCADImplements IExtensionApplicationPublic Sub Initialize() Implements IExtensionApplication.InitializeMsgBox("net程序已…

字节跳动与南开联合开源 StoryDiffusion:一键生成漫画和视频故事的神器!完全免费!

大家好&#xff0c;我是程序员X小鹿&#xff0c;前互联网大厂程序员&#xff0c;自由职业2年&#xff0c;也一名 AIGC 爱好者&#xff0c;持续分享更多前沿的「AI 工具」和「AI副业玩法」&#xff0c;欢迎一起交流~ 漫画&#xff0c;是多少人童年的回忆啊&#xff01; 记得小学…

Sahi+Yolov10

一、前言 了解到Sahi&#xff0c;是通过切图&#xff0c;实现提高小目标的检测效果。sahi 目前支持yolo5\yolo8\mmdet\detection2 等等算法&#xff0c;本篇主要通过实验onnx加载模型的方式使sahi支持yolov10。 二、代码 &#xff08;1&#xff09;转换模型 首先使用 conda创…

EtherCAT转Profinet网关配置说明第一讲:配置软件安装及介绍

网关XD-ECPNS20为EtherCAT转Profinet协议网关&#xff0c;使EtherCAT协议和Profinet协议两种工业实时以太网网络之间双向传输 IO 数据。适用于具有EtherCAT协议网络与Profinet协议网络跨越网络界限进行数据交换的解决方案。 本网关通过上位机来进行配置。 首先安装上位机软件 一…