昇思25天学习打卡营第1天|MindSpore 全流程操作指南

news2024/10/7 9:26:52

目录

MindSpore 库相关操作的导入指南

处理数据集

网络构建

模型训练

保存模型

加载模型


MindSpore 库相关操作的导入指南


        首先,我们导入了 MindSpore 这个库的整个模块。然后,从 MindSpore 库中引入了 nn 模块,一般来说,它是和神经网络有关的。接着,从 MindSpore 的数据集模块里导入了 vision 模块(可能和图像处理有关系)以及 transforms 模块(可能会用于进行数据的转换操作)。最后,从 MindSpore 的数据集模块中导入了 MnistDataset ,通常情况下,这是用来处理 MNIST 数据集的类或者模块。总的来讲,这些导入语句为后面使用 MindSpore 库中的相关功能去处理数据以及构建神经网络模型提前做好了准备。

        代码如下:

import mindspore  
from mindspore import nn  
from mindspore.dataset import vision, transforms  
from mindspore.dataset import MnistDataset  

 处理数据集


        获取 MNIST 数据集的压缩文件,并把它保存到本地。首先,从名为“download”的模块中导入了“download”函数,这个函数是用于下载文件的。然后,定义要下载的文件的 URL,即“https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/ notebook/datasets/MNIST_Data.zip”。接着,调用“download”函数,把文件从上述指定的 URL 下载到当前目录(./),同时将文件类型设定为“zip”。而且,如果本地已经存在该文件,会将其覆盖。

        代码如下:

# Download data from open datasets  
from download import download  
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \  
     "notebook/datasets/MNIST_Data.zip"  
path = download(url, "./", kind="zip", replace=True)  

        运行如下:

        创建了一个叫做“train_dataset”的对象,它把“MNIST_Data/train”这个路径下的数据当作训练数据集来使用。另外还创建了一个叫做“test_dataset”的对象,它以“MNIST_Data/test”这个路径下的数据作为测试数据集。

        代码如下:

train_dataset = MnistDataset('MNIST_Data/train')  
test_dataset = MnistDataset('MNIST_Data/test')  

       打印出名为 train_dataset 的数据集对象的列名。get_col_names() 这个方法通常用于获取数据集的列名信息。

    代码如下:

print(train_dataset.get_col_names())  

        定义了一个叫做“datapipe”的函数,其作用是对输入的数据集“dataset”开展一系列的数据处理与批处理操作。此函数有两个参数,分别是“dataset”(数据集)和“batch_size”(批大小)。

        在“image_transforms”这个列表里,设定了一系列针对图像数据的转换操作。而“label_transform”则明确了对标签数据的类型转换,将其转为“mindspore.int32”类型。

        关于数据的映射和批处理:把图像的转换操作运用到数据集中名为“image”的列上;把标签的转换操作应用到名为“label”的列上。并且按照规定的“batch_size”对数据集进行批处理。

        函数的返回结果:最终,这个函数会返回经过处理和批处理之后的数据集。

        总的来讲,这个函数旨在对输入的数据集进行预先处理和批处理,为后续的模型训练或者其他操作做好准备。

        代码如下:

def datapipe(dataset, batch_size):  
    image_transforms = [  
        vision.Rescale(1.0 / 255.0, 0),  
       vision.Normalize(mean=(0.1307,), std=(0.3081,)),  
       vision.HWC2CHW()  
    ]  
    label_transform = transforms.TypeCast(mindspore.int32)  
  
    datasetdataset = dataset.map(image_transforms, 'image')  
   datasetdataset = dataset.map(label_transform, 'label')  
    datasetdataset = dataset.batch(batch_size)  
    return dataset  

        通过调用 test_dataset 的 create_tuple_iterator 方法来创建一个迭代器,接着利用 for 循环来对其中的图像和标签数据进行遍历。

        在每一轮循环时,会把图像的形状、数据的类型,还有标签的形状、数据的类型都打印出来。并且在完成第一轮循环之后,使用 break 语句来结束循环,也就是说,只处理并打印一组数据。

        代码如下:

# Map vision transforms and batch dataset  
train_dataset = datapipe(train_dataset, 64)  
test_dataset = datapipe(test_dataset, 64) 

        使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问,查看数据和标签的shape和datatype。

for image, label in test_dataset.create_tuple_iterator():  
    print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")  
    print(f"Shape of label: {label.shape} {label.dtype}")  
    break
for data in test_dataset.create_dict_iterator():  
    print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")  
    print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")  
    break 

网络构建


        在 Python 的 PyTorch 框架中,定义了一个叫做 Network 的神经网络模型类,接着创建了这个模型的实例 model,最后把模型的信息给打印出来。这个叫做 Network 的类是继承自 nn.Cell 的。其中 def init(self): 这部分是类的初始化方法。在这个方法里,创建了一个展平层,还构建了一个由多层全连接层以及 ReLU 激活函数组成的序列。另外,定义了模型的前向传播计算流程,先对输入的 x 进行展平操作,然后通过之前构建好的序列层得到预测结果 logits。之后创建了 Network 类的实例,这就相当于得到了一个具体的神经网络模型。最后打印出来的是创建好的模型对象的相关信息,通常会涵盖模型的结构以及参数等方面。

        代码如下:

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
print(model)

模型训练


        首先,将损失函数 loss_fn 实例化为交叉熵损失,同时把优化器 optimizer 设定为随机梯度下降(SGD),并把学习率设成了 1e-2。
接着:定义了前向传播函数 forward_fn,这个函数接收数据和标签作为输入,通过模型得出预测结果 logits,然后计算损失。获取了梯度函数 grad_fn,它是用来计算前向传播函数的梯度的。定义了单步训练函数 train_step,进行梯度的计算和优化器的更新操作,并且返回损失值。
最后,定义了训练函数 train:通过 dataset.get_dataset_size() 来获取数据集的规模大小。将模型设置为训练模式。对数据集中的批次进行遍历,针对每个批次的数据和标签执行 train_step 函数来开展训练。每经过 100 个批次,就打印一次当前的损失值,以及当前批次和总批次的信息。
print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]") 这行代码的作用是按照特定格式输出当前的损失值、当前批次以及总批次的信息,以这样的方式展示,能方便我们观察训练过程中损失的变化状况。

        代码如下:

# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)

# 1. Define forward function
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# 3. Define function of one-step training
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

def train(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

        除训练外,我们定义测试函数,用来评估模型的性能。

        代码如下:

def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

        训练过程需多次迭代数据集,一次完整的迭代称为一轮(epoch)。在每一轮,遍历训练集进行训练,结束后使用测试集进行预测。打印每一轮的loss值和预测准确率(Accuracy),可以看到loss在不断下降,Accuracy在不断提高。

        代码如下:

epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(model, train_dataset)
    test(model, test_dataset, loss_fn)
print("Done!")

保存模型


        借助 MindSpore 库的相应功能来保存模型的检查点,把模型对象 model 存储为名为 model.ckpt 的文件。接着打印出“Saved Model to model.ckpt”,这意味着“已经把模型保存至 model.ckpt”。

        代码如下:

# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

加载模型


        首先,创建一个随机初始化的模型对象 model 。接下来,加载名为 model.ckpt 的检查点文件,并把其中包含的参数加载到这个模型之中。在此之后,获取那些未能成功加载的参数,并将其打印出来。最终,把模型设定为非训练的模式。

        在后续的操作中,针对测试数据集中的数据以及对应的标签,利用模型来进行预测。从预测的结果中获取最大值所在的索引,将其作为预测的类别。然后打印出前 10 个预测的结果以及与之相对应的实际标签。在完成对第一个批次数据的处理之后,停止操作。

     代码如下:

# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)


model.set_train(False)
for data, label in test_dataset:
    pred = model(data)
    predicted = pred.argmax(1)
    print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')
    break

        最后效果如下截图所示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1883391.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaEE—什么是服务器?以及Tomcat安装到如何集成到IDEA中?

目录 ▐ 前言 ▐ JavaEE是指什么? ▐ 什么是服务器? ▐ Tomcat安装教程 * 修改服务端口号 ▐ 将Tomcat集成到IDEA中 ▐ 测试 ▐ 结语 ▐ 前言 至此,这半年来我已经完成了JavaSE,Mysql数据库,以及Web前端知识的学习了&am…

ROS2在rviz2中实时显示轨迹和点

本文是将《ROS在rviz中实时显示轨迹和点》博客中rviz轨迹显示转为ROS2环境中的rviz2显示。 ros2的工作空间创建这里就不展示了。 包的创建 ros2 pkg create --build-type ament_cmake showpath --dependencies rclcpp nav_msgs geometry_msgs tf2_geometry_msgsshowpath.cpp…

【微服务网关——Websocket代理】

1.Websocket协议与原理 1.1 连接建立协议 1.1.1 客户端发起连接请求 客户端通过 HTTP 请求发起 WebSocket 连接。以下是一个 WebSocket 握手请求的例子: GET /chat HTTP/1.1 Host: server.example.com Upgrade: websocket Connection: Upgrade Sec-WebSocket-Key…

python 中的 下划线_ 是啥意思

在 Python 中,_(下划线)通常用作占位符,表示一个变量名,但程序中不会实际使用这个变量的值。 目录 忽略循环变量:忽略函数返回值:在解释器中使用:举例子1. 忽略循环变量2. 忽略不需…

APP逆向 day8 JAVA基础3

一.前言 昨天我们讲了点java基础2.0,发现是又臭又长,今天就是java基础的最后一章,也就是最难的,面向对象。上一末尾也是提到了面向对象,但是面向对象是最重要的,怎么可能只有这么短呢?所以今天…

怎样将word默认Microsoft Office,而不是WPS

设置——>应用——>默认应用——>选择"word"——>将doc和docx都选择Microsoft Word即可

【嵌入式DIY实例】- LCD ST7735显示DHT11传感器数据

LCD ST7735显示DHT11传感器数据 文章目录 LCD ST7735显示DHT11传感器数据1、硬件准备与接线2、代码实现本文介绍如何将 ESP8266 NodeMCU 板 (ESP-12E) 与 DHT11 (RHT01) 数字湿度和温度传感器连接。 NodeMCU 从 DHT11 传感器读取温度(以 C 为单位)和湿度(以 rH% 为单位)值,…

连锁品牌如何做宣传?短视频矩阵工具助轻松千万流量曝光!

今天给大家分享一家烘焙行业连锁品牌(可可同学),通过小魔推获得了1031万的话题曝光,旗下的连锁门店登顶同城人气榜单第一名,​让自己的流量和销量获得双增长 01 品牌连锁店如何赋能旗下门店? 作为一家全国…

昇思25天学习打卡营第13天|基于MobileNetV2的垃圾分类

MobileNetv2模型原理介绍 相比于传统的卷积神经网络,MobileNet网络使用深度可分离卷积(Depthwise Separable Convolution)的思想在准确率小幅度降低的前提下,大大减小了模型参数与运算量。并引入宽度系数α和分辨率系数β使模型满…

【面试题】IPS(入侵防御系统)和IDS(入侵检测系统)的区别

IPS(入侵防御系统)和IDS(入侵检测系统)在网络安全领域扮演着不同的角色,它们之间的主要区别可以归纳如下: 功能差异: IPS:这是一种主动防护设备,不仅具备检测攻击的能力&…

利用pyecharts制作2023全国GDP分布图

完整代码: from pyecharts import options as opts from pyecharts.charts import Map import pandas as pddf pd.read_excel(各省份GDP.xlsx) # print(df.head())year 2023 info df[[省份,year]] # print(info)info_list info.values.tolist() print(info_lis…

YTM32的HA系列微控制器启动过程详解

YTM32的HA系列微控制器启动过程详解 文章目录 YTM32的HA系列微控制器启动过程详解IntroductionPricinple & MachenismHA01的内存地址空间BOOT ROM简介安全启动Security Boot快速从Powerdown模式下唤醒对内核进行例行自检(Structural Core Self-Test,…

Python容器 之 字符串--下标和切片

1.下标(索引) 一次获取容器中的一个数据 1, 下标(索引), 是数据在容器(字符串, 列表, 元组)中的位置, 编号 2, 一般来说,使用的是正数下标, 从 0 开始 3, 作用: 可以通过下标来获取具体位置的数据. 4, 语法: 容器[下标] 5, Python 中是支持…

SuperMap GIS基础产品FAQ集锦(20240701)

一、SuperMap iDesktopX 问题1:对于数据提供方提供的osgb格式的数据,如何只让他生成一个s3mb文件呢?我用倾斜入库的方式会生成好多个s3mb缓存文件 11.1.1 【解决办法】不能控制入库后只生成一个s3mb文件;可以在倾斜入库的时候设…

基于Java实现图像浏览器的设计与实现

图像浏览器的设计与实现 前言一、需求分析选题意义应用意义功能需求关键技术系统用例图设计JPG系统用例图图片查看系统用例图 二、概要设计JPG.javaPicture.java 三、详细设计类图JPG.java UML类图picture.java UML类图 界面设计JPG.javapicture.java 四、源代码JPG.javapictur…

Leetcode.1735 生成乘积数组的方案数

题目链接 Leetcode.1735 生成乘积数组的方案数 rating : 2500 题目描述 给你一个二维整数数组 q u e r i e s queries queries ,其中 q u e r i e s [ i ] [ n i , k i ] queries[i] [n_i, k_i] queries[i][ni​,ki​] 。第 i i i 个查询 q u e r i e s [ i …

Esxi硬件日志告警

原创作者:运维工程师 谢晋 Esxi硬件日志告警 故障描述故障处理 故障描述 主机报错硬件对象状态告警 在Esxi监控硬件内发现Systemctl Manager Module 1 Event log 0报警,该报警是Esxi事件日志保存空间满了,需要清理空间。 故障处理 开启…

整除分块的题目

链接 思路: 求1到n中的因数个数和等价于求,设x为因子,就是求x在1到n里出现了几次,求1到n里是x的倍数的数有几个,即n/x。需要用整除分块,n/i的值是分块分部的,右端点是n/(n/i)。 代…

相机网线RJ45连接器双端带线5米8芯绿色网线注塑成型

相机网线RJ45连接器双端带线5米8芯绿色网线注塑成型,这款网线采用了环保的绿色材质,线长5米,足够满足大多数拍摄场景的需求。更重要的是,它采用了8芯设计,保证了数据传输的稳定性和高速性。在接口方面,它采…