【笔记】打卡01 | 初学入门

news2024/11/23 1:04:15

1 基本介绍

MindSpore Data(数据处理层)
ModelZoo(模型库)
MindSpore Science(科学计算),包含了业界领先的数据集、基础模型、预置高精度模型和前后处理工具
MindSpore Insight(可视化调试调优工具),能够可视化地查看训练过程、优化模型性能、调试精度问题、解释推理结果

2 快速入门

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

处理数据集

下载Mnist数据集

# Download data from open datasets
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

在这里插入图片描述

训练集、测试集

train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

在这里插入图片描述
列名:图片 和 对应标签(分类)

数据处理流水线(Data Processing Pipeline)

参数:数据集、batch_size

def datapipe(dataset, batch_size):
    image_transforms = [                    
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

首先,数据变换(Transforms):1、对输入数据(即图片)2、对输出(即标签);
然后,map对图像数据及标签进行变换处理;
最后,将处理好的数据集打包为大小为64的batch

train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

对数据集进行迭代访问

for data in test_dataset.create_dict_iterator():
    print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")
    print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")
    break

网络构建

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
print(model)

mindspore.nn类是构建所有网络的基类,也是网络的基本单元。

  • 自定义网络时,可以继承nn.Cell
  • __init__包含所有网络层的定义
  • construct(类似前向传播??)包含数据(Tensor)的变换过程。

模型训练

定义损失函数、优化器

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)

一个完整的训练过程(step)需要实现以下三步:

1. 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
2. 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
3. 参数优化:将梯度更新到参数上。

定义正向计算函数。

def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

使用value_and_grad通过函数变换获得梯度计算函数。

grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

one-step training

def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化。

def train(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

定义测试函数:用来评估模型的性能。

def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程需多轮(epoch)训练数据集

epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(model, train_dataset)
    test(model, test_dataset, loss_fn)
print("Done!")

在这里插入图片描述

保存模型

模型训练完成后,需要保存其参数。

mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

加载模型

加载保存的权重

# 1、重新实例化模型对象,构造模型
model = Network()
# 加载模型参数,并将其加载至模型上。
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

打卡-时间

from datetime import datetime
import pytz
# 设置时区为北京时区
beijing_tz = pytz.timezone('Asia/shanghai')
# 获取当前时间,并转为北京时间
current_beijing_time = datetime.now(beijing_tz)
# 格式化时间输出
formatted_time = current_beijing_time.strftime('%Y-%m-%d %H:%M:%S')
print("当前北京时间:",formatted_time,'your name')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1841183.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【stm32-新建工程-HAL库版本】

stm32-新建工程-HAL库版本 ■ HAL库版本目录■ Drivers■ Middlewares 文件夹, 同寄存器版本一样。■ Output 文件夹, 同寄存器版本一样。■ Projects 文件夹, 同寄存器版本一样。■ User 文件夹 ■ HAL库版本目录 ■ Drivers ① &#xff0c…

Vitis Accelerated Libraries 学习笔记--OpenCV 安装指南

目录 1. 简介 2. 安装过程 2.1 安装准备 2.2 常见错误 2.2.1 核心共享库报错 3. 通过实例测试 4. 总结 1. 简介 使用Vitis Vision Library Vitis 视觉库,为什么要安装opencv库? 在使用Vitis Vision Library时,安装OpenCV库是因为许…

前端 CSS 经典:边框转圈动画效果

前言:首先我们要知道 css 动画只对数值类的 CSS 属性起作用。要实现边框转圈动画效果,实际就是渐变背景的旋转。但是在以前,渐变背景是不支持动画的。现在我们可以利用浏览器新出的 Houdini API 来实现这个动画效果。Houdini API 特别强大&am…

leetcode119 杨辉三角②

给定一个非负索引 rowIndex,返回「杨辉三角」的第 rowIndex 行。 在「杨辉三角」中,每个数是它左上方和右上方的数的和。 示例 1: 输入: rowIndex 3 输出: [1,3,3,1]示例 2: 输入: rowIndex 0 输出: [1]示例 3: 输入: rowIndex 1 输出: [1,1] pub…

centos 7.8 安装sql server 2019

1.系统环境 centos 7.8 2.数据库安装文件准备 下载 SQL Server 2019 (15.x) Red Hat 存储库配置文件 sudo curl -o /etc/yum.repos.d/mssql-server.repo https://packages.microsoft.com/config/rhel/7/mssql-server-2019.repo 采用yum源进行不安装下载,这时yum 会自动检测…

知觉感知:AI深层理解的关键

在人工智能(AI)的广阔领域中,一个核心议题始终萦绕在科学家和哲学家的心头:人工智能是否需要感知能力,以实现对意义的深层理解?这一议题不仅关乎技术的边界,更触及了人类心智的本质。从Stevan H…

[笔记] CCD相机测距相关的一些基础知识

1.35mm胶片相机等效焦距 https://zhuanlan.zhihu.com/p/419616729 拿到摄像头拍摄的数码照片后,我们会看到这样的信息: 这里显示出了两个焦距:一个是实际焦距:5mm,一个是等效焦距:25mm。 实际焦距很容易…

数据结构:4.1.1二叉搜素树及查找

静态查找:要找的集合的元素是不动的,主要是find操作,没有delete操作 动态查找:要查找的集合会经常发生插入删除的操作 静态查找的一个很好的方法就是二分查找 把数据直接放在树上 结点右子树的值>结点的值>结点左子树的…

Flutter-无限循环滚动标签

1. 序章 在现代移动应用开发中,滑动视图是常见的交互模式之一。特别是当你需要展示大量内容时,使用自动滚动的滑动视图可以显著提升用户体验。在这篇文章中,我们将讨论如何使用 Flutter 实现一个自动滚动的列表视图。 2. 效果 3. 实现思路 …

Java安全

Java安全 Java2Sec靶场搭建 靶场地址 https://github.com/bewhale/JavaSec 查看数据库配置文件,mysql,用户名密码根据自己数据库密码更改 使用小皮面板的mysql,新建一个数据名为javasec的数据库 运行javasec.sql文件 下载运行jar包即可 …

STM32单片机I2C通信详解

文章目录 1. I2C通信概述 2. 硬件电路 3. I2C时序基本单元 4. I2C时序 4.1 指定地址写 4.2 当前地址读 4.3 指定地址读 5. I2C外设 6. I2C框图 7. I2C基本结构 8. 主机发送 9. 主机接收 10. 软件和硬件波形对比 11. 代码示例 1. I2C通信概述 I2C(Inter-Integrat…

定制汽车霍尔传感器

磁电效应霍尔传感器、饱和霍尔传感器、非线性霍尔传感器 霍尔传感器原理 霍尔传感器的工作原理基于霍尔效应,即当一块通有电流的金属或半导体薄片垂直地放在磁场中时,薄片的两端会产生电位差。这种现象称为霍尔效应,两端具有的电位差值称为…

YoloV8改进策略:Block篇|即插即用|StarNet,重写星操作,使用Block改进YoloV8(全网首发)

摘要 本文主要集中在介绍和分析一种新兴的学习范式——星操作(Star Operation),这是一种通过元素级乘法融合不同子空间特征的方法,通过元素级乘法(类似于“星”形符号的乘法操作)将不同子空间的特征进行融…

(0撸)Nillion空投教程,现在开启激励测试网

Nillion 是一个去中心化的公共网络,通过启用去中心化信用评分、去中心化可信执行环境、私有 NFT、去中心化安全存储服务等用例来解锁 Web3 中重要的新实用程序 Nillion已经完成2000万美元融资、由Distributed Global领投、同时还有HashKey Capital、GSR等多家机构参…

opencascade AIS_InteractiveContext源码学习1 object display management 对象显示管理

AIS_InteractiveContext 前言 交互上下文(Interactive Context)允许您在一个或多个视图器中管理交互对象的图形行为和选择。类方法使这一操作非常透明。需要记住的是,对于已经被交互上下文识别的交互对象,必须使用上下文方法进行…

visual studio error MSB8008:

新项目编译的时候,可能由于编译器的版本不一致导致的问题。 你的电脑上有两个不同版本的VS,或者你的程序拷贝到别人的电脑上去运行,或者你是从别人那里拷贝来的项目,而你们俩用的VS版本不一样,就会在运行的时候出现这…

代码随想录算法训练营第四十三天|518. 零钱兑换 II ,LCR 103. 零钱兑换,377. 组合总和 Ⅳ

518. 零钱兑换 II - 力扣(LeetCode) class Solution {public int change(int amount, int[] coins) {// 创建dp数组,dp[i][j] 表示使用前i个硬币(下标为0的硬币是前1个)凑成总金额j的硬币组合数int[][] dp new int[coins.length …

AI在创造还是毁掉音乐

目录 1.概述 2.整体介绍 3.人机合作 3.1. AI作为助手 3.2. AI作为创意源泉 3.3. 交互式合作 3.4. AI独立创作与音乐人后期制作 3.5.试点和挑战 4.伦理道德 4.1.AI替代人类的可能性 4.2.伦理道德问题 4.3.平衡技术与创造力 4.4.小结 1.概述 近期音乐大模型的兴起&…

Python Textract库:文本提取

更多Python学习内容:ipengtao.com Textract是一个强大的Python库,用于从各种文件格式中提取文本。无论是PDF、Word文档、Excel电子表格、HTML页面还是图像,Textract都能有效地提取其中的文本内容。Textract通过集成多种开源工具和库&#xff…

Spark Core内核调度机制详解(第5天)

系列文章目录 如何构建DAG执行流程图 (掌握)如何划分Stage阶段 (掌握)Driver底层是如何运转 (掌握)确定需要构建多少分区(线程) (掌握) 文章目录 系列文章目录引言一、Spark内核调度(掌握)1.1、内容概述1.2、RDD的依赖1.3、DAG和Stage1.4、Spark Shuffl…