昇思25天学习打卡营第1天|基本介绍与快速入门

news2024/12/30 3:55:45

先贴上打卡截图
在这里插入图片描述

基本介绍

首先来看基本介绍,昇思MindSpore是华为的一个全场景深度学习框架,属于昇腾AI全栈的一部分。
总体架构如下图所示(来自官方学习材料)
MindSpore-arch
从对底层多样性硬件适用的Runtime到应用层面的Model Zoo、科学计算、API等一应俱全。为用户提供不同Level的Python API,使不同层次的用户都可以很好地利用框架。MindSpore还支持全场景统一部署、分布式训练原生,总体来看是值得学习的一个框架。

快速入门

下面来看快速入门板块,还是经典任务手写体数字识别,数据集是Mnist。
导入必要的包:

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

数据集

下载数据集:

# Download data from open datasets
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

初始化数据集对象:

train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

定义数据处理流水线,指定图像和标签的transform,map和batch:

def datapipe(dataset, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

这一点挺有意思,至少是本人第一次看到以流水线的思想来处理数据。
对数据两种不同的迭代访问方式:

for image, label in test_dataset.create_tuple_iterator():
    print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")
    print(f"Shape of label: {label.shape} {label.dtype}")
    break

for data in test_dataset.create_dict_iterator():
    print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")
    print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")
    break

模型

网络定义如下:

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
print(model)

一个基本的MLP,与PyTorch基本一致,不同的是基类叫Cell

训练

损失函数采用交叉熵,优化器使用SGD,定义如下:

# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)

一个经典的bp神经网络的三步:

  1. 前向计算
  2. 反向传播
  3. 优化参数

定义如下:

# 1. 前向计算,就是调用模型并计算loss
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# 2. 自动计算梯度的函数,即反向传播,传入前向计算函数
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# 3. 定义一步训练,最后调用优化器来优化参数
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

定义测试函数来评估模型性能:

def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

一个epoch训练:

def train(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

训练3个epoch:

epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(model, train_dataset)
    test(model, test_dataset, loss_fn)
print("Done!")

训练结果:
在这里插入图片描述

保存、加载模型和推理

保存模型:

# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

加载模型:

# 首先实例化一个模型对象
model = Network()
# 读取checkpoint并加载参数到模型
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

推理:

model.set_train(False)
for data, label in test_dataset:
    pred = model(data)
    predicted = pred.argmax(1)
    print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')
    break

推理结果:
Predicted: "[6 5 5 6 3 0 7 0 3 2]", Actual: "[6 3 5 6 3 0 7 0 3 2]"

学习心得

昇思的API入门还是挺简单的,今天主要是从经典的手写体数字识别任务熟悉昇思的API,构建了一个三层的MLP,使用ReLU激活函数,损失函数使用交叉熵,batch size取64,优化器使用SGD,训练3个epoch后得到了测试集94%的准确率,在训练完成后保存了参数,初始化了一个新的模型对象,并加载训练完成的参数进行了推理,基本完成任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1864781.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows bat 提取多个目录下的文件,到一个目录

批处理命令 echo off setlocalrem 设置源目录和目标目录 set "sourceDirE:\motrix" set "targetDirE:\新建文件夹"rem 创建目标目录,如果不存在 if not exist "%targetDir%" mkdir "%targetDir%"rem 循环遍历源目录中的所…

微信小程序服务器从腾讯云迁移到阿里云出现的坑

微信小程序服务器从腾讯云迁移到阿里云出现的坑 背景 原先小程序后台服务器到期,因为之前买的是腾讯云新用户,便宜,到期后续费金额懂的都懂。就在阿里云用新用户买了个新的,遂把服务全转到了阿里云服务器上。 此时,域…

[SAP ABAP] 汇总内表数据

在加入新数据记录时&#xff0c;将非数值字段具有相同内容记录的数值字段汇总 语法格式 COLLECT <wa> INTO <itab>. <wa>&#xff1a;代表工作区 <itab>&#xff1a;代表内表 示例1 结果显示&#xff1a;

解决google chrome helper 内存占用较高!

导语&#xff1a;mac 后台有很多 google chrome helper 线程并且内存占用较高。一直怀疑是IDEA插件的锅&#xff0c;并不是&#xff01; 查看是哪个网页&#xff0c;哪个插件占用内存。 chrome 更多工具 -> 任务管理器&#xff1a; 找到罪魁祸首&#xff0c;关闭标签页或者…

基于前馈神经网络的姓氏分类任务(基础)

1、认识前馈神经网络 What is it 图1-1 前馈神经网络结构 人们大多使用多层感知机&#xff08;英语&#xff1a;Multilayer Perceptron&#xff0c;缩写&#xff1a;MLP&#xff09;作为前馈神经网络的代名词&#xff0c;但是除了MLP之外&#xff0c;卷积神经网络&#xff08…

Linux通用LInux高危漏洞(CVE-2024-1086)修复案例

一、漏洞描述 2024年3月28日&#xff0c;监 Linux kernel权限提升漏洞&#xff08;CVE-2024-1086&#xff09;的PoC/EXP在互联网上公开&#xff0c;该漏洞的CVSS评分为7.8&#xff0c;目前漏洞细节已经公开披露&#xff0c;美国网络安全与基础设施安全局&#xff08;CISA&…

[leetcode]search-insert-position 搜索插入位置

. - 力扣&#xff08;LeetCode&#xff09; class Solution { public:int searchInsert(vector<int>& nums, int target) {int left 0, right nums.size()-1;while(left <right) {int mid left (right-left)/2;if(nums[mid] target){return mid;} else if(nu…

视觉与味蕾的交响:红酒与艺术的无界狂欢,震撼你的感官世界

在浩瀚的艺术海洋中&#xff0c;红酒以其不同的魅力&#xff0c;成为了一种跨界整合的媒介。当雷盛红酒与艺术相遇&#xff0c;它们共同呈现出一场特别的视觉盛宴&#xff0c;让人沉醉在色彩与光影的交织中&#xff0c;感受红酒与艺术的无界碰撞。 雷盛红酒&#xff0c;宛如一件…

DNS入门指南:企业DNS系统架构趋势解读

诞生于1987年的DNS是互联网和IT基础设施中发生的几乎所有事情的起点。从最初的简单域名解析到现在的智能解析、安全解析&#xff0c;伴随技术的变化与演进&#xff0c;DNS系统也在发生着诸多的变化。总体来说DNS系统的发展有着五大趋势&#xff0c;本文将会逐一进行解读。   …

DataX数据迁移

DataX数据迁移 访问DataX Web管理页面&#xff1a; http://ip:9527/index.html 用户名&#xff1a;admin&#xff0c;密码&#xff1a;123456 本文中示例将SqlServer数据增量同步到MySql中。 增量同步同步时&#xff0c;MySql中的新字段设置默认值 1. 查看执行器是否注册成…

【吊打面试官系列-Mysql面试题】说说对 SQL 语句优化有哪些方法?

大家好&#xff0c;我是锋哥。今天分享关于 【说说对 SQL 语句优化有哪些方法&#xff1f;】面试题&#xff0c;希望对大家有帮助&#xff1b; 说说对 SQL 语句优化有哪些方法&#xff1f; 1、Where 子句中&#xff1a;where 表之间的连接必须写在其他 Where 条件之前&#xff…

品牌电商维权:应对渠道低价与假货的有力举措

在品牌治理渠道的过程中&#xff0c;会遭遇各种各样的问题&#xff0c;其中低价现象尤为突出。低价往往导致经销商被迫跟价&#xff0c;而未授权的店铺则更加不受管控&#xff0c;更容易出现低价情况。然而&#xff0c;低价本身不能直接作为品牌管控渠道的充分理由&#xff0c;…

聚类距离度量(保姆级讲解,包学会~)

在机器学习的聚类中&#xff0c;我们通常需要使用距离来进行类的划分&#xff0c;或者比较不同类之间的各种距离&#xff0c;这里我们介绍西瓜书上所提出的一些距离计算方式。 首先介绍一下距离的一些性质&#xff1a; 西瓜书上给出了四条性质&#xff0c;第一个是非负性&#…

MATLAB-DBO-CNN-SVM,基于DBO蜣螂优化算法优化卷积神经网络CNN结合支持向量机SVM数据分类(多特征输入多分类)

DBO-CNN-SVM&#xff0c;基于DBO蜣螂优化算法优化卷积神经网络CNN结合支持向量机SVM数据分类(多特征输入多分类) 1.数据均为Excel数据&#xff0c;直接替换数据就可以运行程序。 2.所有程序都经过验证&#xff0c;保证程序可以运行。 3.具有良好的编程习惯&#xff0c;程序均…

你的企业“赚钱能力”,银行怎么看?聊聊税贷与票贷背后的门道

大家都听过“税贷”和“票贷”吧&#xff1f;特别是这两年&#xff0c;国家扶持中小微企业&#xff0c;这些名词更是火得不行。但你知道吗&#xff0c;税贷和票贷并不是只看税和票那么简单。今天&#xff0c;咱就来聊聊这背后的门道&#xff08;最后附上&#xff1a;企业信用贷…

四川赤橙宏海商务信息咨询有限公司一站式抖音电商服务

在数字化浪潮汹涌的当下&#xff0c;电商行业正以前所未有的速度发展&#xff0c;而抖音电商作为其中的佼佼者&#xff0c;更是吸引了无数商家和消费者的目光。在这个充满机遇与挑战的市场中&#xff0c;四川赤橙宏海商务信息咨询有限公司凭借其专业的服务和丰富的经验&#xf…

QCC51XX---开启手机log日志

QCC51XX---系统学习目录_trbi200软件-CSDN博客 目录 1.Vivo 2.华为 3.小米 4.三星 5.oppo 1.Vivo *#*#112#*#* 输入命令后会进入log日志系统(由于版本原因,界面可能不同),打开log开关,log就会在后台自动录制。 点击设置,则可进入图1(右边)的界面,可以导出log,导出…

使用IPXProxy动态住宅代理进行数据采集有哪些优势?

​随着全球化进程的加速&#xff0c;企业在网络信息时代必须利用先进的工具来扩展业务边界。动态住宅代理作为一种高效的网络代理服务&#xff0c;在许多关键业务场景中展现了其独特价值。本文将深入探讨动态住宅代理在数据分析、SEO优化、网络营销等领域的广泛应用&#xff0c…

【权威主办|检索稳定】2024年法律、教育与社会发展国际会议 (LESD 2024)

2024年法律、教育与社会发展国际会议 (LESD 2024) International Conference on Law, Education and Social Development in 2024 【重要信息】 大会地点&#xff1a;成都 官网地址&#xff1a;http://www.iclesd.com 投稿邮箱&#xff1a;iclesdsub-conf.com 【注意&#xff1…

STM32烧写hex及bin文件的五种方法

一.STVP 1.概述 STVP是ST早期的一款下载编程工具&#xff0c;支持早期的ST早期的芯片&#xff08;比如ST7系列&#xff09;&#xff0c;也支持STM8、 STM32。 该工具虽然相对ST-LINK utility、STM32CubeProg比较老&#xff0c;但该工具官方在2017年还进行了维护&#xff0c;现…