昇思25天学习打卡营第02天 | 快速入门

news2024/10/6 12:21:17

昇思25天学习打卡营第02天 | 快速入门

文章目录

  • 昇思25天学习打卡营第02天 | 快速入门
    • 数据准备
    • 网络构建
    • 模型训练
    • 模型测试
    • 迭代数据集
    • 模型保存
    • 加载模型
    • 总结
    • 打卡

数据准备

MindSpore通过DatasetTransforms实现高效的数据预处理

使用download下载数据,并创建数据集对象:

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

数据集中的数据可以通过create_tuple_iteratorcreate_dict_iterator进行访问:

for image, label in test_dataset.create_tuple_iterator():
    print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")
    print(f"Shape of label: {label.shape} {label.dtype}")
    break
	
for data in test_dataset.create_dict_iterator():
    print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")
    print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")
    break

原始的数据通常不能满足需求,需要通过数据流水线(Data Processing Pipeline)指定map、batch、shuffle等操作进行处理,并将数据打包为指定大小的batch:

def datapipe(dataset, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset
	
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

网络构建

通过继承nn.Cell类,并重写__init__construct方法来自定义网络结构。

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
print(model)

__init__方法中的dense_relu_sequential定义了网络的层次结构,由Dense和ReLU组成。
construct方法描述对输入数据的变换。

模型训练

一个模型的训练需要经过三个步骤:

  1. 正向计算:计算模型的预测值,并于正确标签求loss;
  2. 反向传播:利用自动微分机制,求模型参数对loss的梯度值;
    3.参数优化: 根据梯度值更新参数。

MindSpore使用函数式自动微分机制,因此需要实现:

  1. 定义正向计算函数:
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)

def forward_fn(data, label):
	logits = model(data)
	loss = loss_fn(logits, label)
	return loss, logits
  1. 使用value_and_grad获得梯度计算函数:
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
  1. 定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化:
def train_step(data, label):
	(loss, _), grads = grad_fn(data, label)
	optimizer(grads)
	return loss
	
def train(model, dataset):
	size = dataset.get_dataset_size()
	model.set_train()
	for batch, (data, label) in enumerate( dataset.create_tuple_iterator()):
	loss = train_step(data, label)
	
	if batch % 100 == 0:
		loss, current = loss.asnumpy(), batch
		print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

模型测试

通过测试函数来评估模型的性能:

def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

迭代数据集

完整的遍历一次数据集成为一个epoch

epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(model, train_dataset)
    test(model, test_dataset, loss_fn)
print("Done!")

模型保存

通过save_checkpoint保存模型的参数:

mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

加载模型

模型的加载分为两步:

  1. 重新实例化网络模型:
model = Network()
  1. 加载模型参数,并加载至模型上:
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

加载的模型可以直接用于预测推理:

model.set_train(False)
for data, label in test_dataset:
    pred = model(data)
    predicted = pred.argmax(1)
    print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')
    break

总结

通过这一节的内容,对一个网络的诞生有了大概的认识,从原始数据到数据集的处理,简单网络结构的搭建,训练中自动微分机制的使用方法等都有了一定的了解。此外还有模型的保存与加载方法,为之后的深入学习奠定了基础。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1879794.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Qt项目天气预报(8) - 绘制温度曲线 + 回车搜索(最终篇)

全部内容在专栏: Qt项目 天气预报_mx_jun的博客-CSDN博客 目录 绘制温度曲线 事件过滤器在子控件上绘图 子控件下载事件过滤器 事件过滤器进行绘图 - eventFilter 画初步高温曲线 画初步低温曲线 效果演示 画低温曲线 画高温曲线 效果演示 按下回车搜索: …

LDM论文解读

论文名称:High-Resolution Image Synthesis with Latent Diffusion Models 发表时间:CVPR2022 作者及组织:Robin Rombach, Andreas Blattmann, Dominik Lorenz,Patrick Esser和 Bjorn Ommer, 来自Ludwig Maximilian University of Munich &a…

AI奏响未来乐章:音乐界的革命性变革

AI在创造还是毁掉音乐 引言 随着科技的飞速发展,人工智能(AI)正在逐渐渗透到我们生活的每一个角落,音乐领域也不例外。AI技术的引入,不仅为音乐创作、教育、体验带来了革命性的变革,更为整个音乐产业注入了…

顺序表应用——通讯录

在本篇之前的顺序表专题我们已经学习的顺序表的实现,了解了如何实现顺序表的插入和删除等功能,那么在本篇当中就要学习基于顺序表来实现通讯录,在通讯录当中能实现联系人的增、删、查改等功能,接下来就让我们一起来实现通讯录吧&a…

民用无人机企业招标投标需要资质证书详解

一、基础资质 在民用无人机企业的招标投标过程中,基础资质是首要考虑的因素。这些资质通常包括企业注册资质、税务登记证、组织机构代码证等。 1.1 企业注册资质 企业应具备合法的注册资质,即营业执照。该执照应包含企业名称、注册地址、法定代表人、…

Java [数据结构] Deque与Queue

🤺深入理解 Java 中的 Deque 和 Queue🤺 在现代软件开发中,数据结构是构建高效、可维护代码的基础。 Java 作为一门广泛应用的编程语言,其丰富的集合框架(Collections Framework)为开发者提供了多种强大的…

Zabbix 排坑版 Centos7

systemctl stop firewalld;systemctl disable firewalld;setenforce 0sed -i s/SELINUXenforcing/SELINUXdisabled/ /etc/selinux/configzabbix源地址,可以自己选版本,安装都大差不差 rpm -Uvh https://repo.zabbix.com/zabbix/5.0/rhel/7/x86_64/zabbix-release-5…

【车载开发系列】S32 Design Studio工具安装步骤

【车载开发系列】S32 Design Studio工具安装步骤 S32 Design Studio工具安装步骤 【车载开发系列】S32 Design Studio工具安装步骤※关键字提炼※一. 准备工作二. 下载安装包三. 获取License许可四. 开始预安装五. 开始正式安装六. 启动软件七. 安装插件八. 卸载插件九. 确认安…

【操作系统】进程管理——进程控制和进程通信(个人笔记)

学习日期:2024.6.30 内容摘要:进程控制的概念,进程控制相关的“原语”,进程通信 进程控制 原语 进程控制用“原语”实现。原语是一种特殊的程序,它的执行具有原子性,也就是说,这段程序的执行…

vs code python开发笔记

目录 选择python 解析器 安装插件 不全: 调试启动目录问题: 2.选择python解释器 选择python 解析器 ctrl shift P select interpreter 安装插件 不全: remote ssh python debuger 左下角,点击左右左右箭头,远程…

后端之路第三站(Mybatis)——JDBC跟Mybatis、lombok

一、什么是JDBC JDBC就是sun公司研发的一套通过java来操控数据库的工具,对应不同的数据库系统有不同的JDBC,而他们统称【驱动】,这就是上一篇我们提到创建Mybatis项目时要引入的依赖、以及连接数据库四要素里的第一要素。 JDBC有自己一套原始…

Redis 7.x 系列【8】数据类型之哈希(Hash)

有道无术,术尚可求,有术无道,止于术。 本系列Redis 版本 7.2.5 源码地址:https://gitee.com/pearl-organization/study-redis-demo 文章目录 1. 概述2. 常用命令2.1 HSET2.2 HGET2.3 HDEL2.4 HEXISTS2.5 HGETALL2.6 HKEYS2.7 HLE…

grpc学习golang版( 五、多proto文件示例 )

系列文章目录 第一章 grpc基本概念与安装 第二章 grpc入门示例 第三章 proto文件数据类型 第四章 多服务示例 第五章 多proto文件示例 第六章 服务器流式传输 第七章 客户端流式传输 第八章 双向流示例 文章目录 一、前言二、定义proto文件2.1 公共proto文件2.2 语音唤醒proto文…

探索MySQL核心技术:理解索引和主键的关系

在数据密集型应用中,数据库的性能往往是决定一个应用成败的重要因素之一。其中,MySQL作为一种开源关系型数据库管理系统,以其卓越的性能和丰富的功能被广泛应用。而在MySQL数据库优化的众多技巧中,索引和主键扮演着极其重要的角色…

5、Python之rich:GUI之外,终端呈现也能玩出花

引言 在Python系列文章的上一篇中,我们从print的定义出发,进一步探索了print()函数更多的用法,尤其是一些哪怕是Python老手也可能忽略的用法。没有阅读的或者需要回顾print()及输出格式化的扩展用法,可以查看上一篇文章。 虽然pr…

2024/6/30周报

文章目录 摘要ABSTRACT文献阅读题目问题本文贡献方法LSTMTCN模型总体架构 实验实验结果 深度学习TCN-LSTM代码运行结果 总结 摘要 本周阅读了一篇关于TCN和LSTM进行光伏功率预测的文章,本文提出了一种利用LSTM-TCN预测光伏功率的新模型。它由长短期记忆和时间卷积网…

可编程定时计数器8253/8254 - 8253控制字

8253控制字 概述 图7-45中左下角的是控制字寄存器,其操作端口是0x43,它是8位大小的寄存器 控制字寄存器也称为模式控制器,在控制字寄存器中保存的内容称为控制字,控制字用来设置所指定的计数器(通道)的工作方式、读写格式及数制&#xff0c…

emptyDir + initContainer实现ConfigMap的动态更新(K8s相关)

1. 絮絮叨叨 K8s部署服务时,一般都需要使用ConfigMap定义一些配置文件例如,部署分布式SQL引擎Presto,会在ConfigMap中定义coordinator、worker所需的配置文件以node.properties为例,node.environment和node.data-dir的值将由Helm…

48 - 按日期分组销售产品(高频 SQL 50 题基础版)

48 - 按日期分组销售产品 -- group_concat 分组拼接selectsell_date,count(distinct product) num_sold,group_concat(distinct product order by product separator ,) products fromActivities group bysell_date;

监控电脑的软件有哪些?精选8大监控电脑的软件

根据当前市场反馈和功能评价,以下是八款备受推崇的电脑监控软件推荐,适合不同企业和组织的监控与管理需求: 1.安企神监控软件 特点:全面的局域网监控工具,擅长网络设备监控、网络性能管理和故障诊断。提供员工电脑屏幕…