【昇思初学入门】第七天打卡-模型训练

news2025/1/14 0:46:30

训练模型

学习心得

  1. 构建数据集。这通常包括训练集、验证集(可选)和测试集。训练集用于训练模型,验证集用于调整超参数和监控过拟合,测试集用于评估模型的泛化能力。
    (mindspore提供数据集https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/api_python/mindspore.dataset.html)
  2. 定义神经网络模型。这通常涉及到选择适当的网络架构(如卷积神经网络CNN、循环神经网络RNN、全连接网络等)和激活函数。
    创建模型类:使用mindspore.nn.Cell作为基类,创建一个自定义的神经网络模型类。
    义网络层:定义所需的网络,如卷积层、全连接层、激活函数和池化层等
    实现construct方法:在construct方法中,使用定义好的网络层构建前向网络
  3. 定义超参、损失函数和优化器。
    设置超参数:设置超参数,如学习率、批次大小、训练轮数等。
    定义损失函数:选择适当的损失函数,如均方误差(MSE)用于回归问题,交叉熵损失(Cross-Entropy Loss)用于分类问题等。
    设置优化器:选择合适的优化器,如随机梯度下降(SGD)、Adam等,用于根据损失函数的梯度更新模型参数。
  4. 训练和评估。
    循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:
    训练:迭代训练数据集,并尝试收敛到最佳参数。
    验证/测试:迭代测试数据集,以检查模型性能是否提升。

笔记

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

# Download data from open datasets
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)


def datapipe(path, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = MnistDataset(path)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

train_dataset = datapipe('MNIST_Data/train', batch_size=64)
test_dataset = datapipe('MNIST_Data/test', batch_size=64)

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()

epochs = 3
batch_size = 64
learning_rate = 1e-2

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

# Define forward function
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# Define function of one-step training
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

def train_loop(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")
            
def test_loop(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(model, train_dataset)
    test_loop(model, test_dataset, loss_fn)
print("Done!")

结果
训练结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1862338.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Fusion WAN:企业出海与全球组网的数字网络底座

众多中国企业与品牌正将目光投向海外市场,积极寻求发展新机遇,并且在这一过程中取得了显著的成果。"出海"战略已经成为一些企业转型升级的关键选择。 随着国内市场的竞争日益激烈,越来越多的企业开始寻求海外市场的拓展&#xff0c…

压电风扇的显著特点及其在电子系统中的应用

压电已经存在了一个多世纪,人们发现某些晶体结构在受到机械应力时产生表面电荷。 这种形式的压电传感器是压电传感器的工作方式。与压电传感器(或发电机)类似,压电致动器(或电机)使用补丁[1,3]形式的压电陶…

探索PHP中的魔术常量

PHP中的魔术常量(Magic Constants)是一些特殊的预定义常量,它们在不同的上下文中具有不同的值。这些常量可以帮助开发者获取文件路径、行号、函数名等信息,从而方便调试和日志记录。本文将详细介绍PHP中的魔术常量,帮助…

2024地理信息相关专业大学排名

在开始之前,不得不提一下今年福耀科技大学不能招生的遗憾,不知道明年是否能一切准备就绪开始招生呢? 如果这所大学能招生了,不知道它有没有地理信息相关专业呢? 言归正转,我们现在就基于公开资料&#xf…

力扣随机一题 哈希表 排序 数组

博客主页:誓则盟约系列专栏:IT竞赛 专栏关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ 2491.划分技能点相等的团队【中等】 题目: 给你一个正整数数组…

Qt添加Dialog对话框

Qt版本:5.12.12 1.添加【模块】 Base class:可以选择QDialog、QWidget、QMainWindow 会自动生成MyDialog.h和MyDialog.cpp文件以及MyDialog.ui文件, 2.添加代码: (1)TestDialog.h #pragma once#include…

三分之一的违规行为未被发现

Gigamon 调查显示,随着漏洞的针对性越来越强、越来越复杂,企业在检测漏洞方面也面临越来越大的困难,超过三分之一的企业表示,现有的安全工具无法在漏洞发生时检测到它们。 随着混合云环境变得越来越复杂,以及恶意行为…

Docker 查看源地址/仓库地址,更改

一、源地址文件配置路径。若有docker文件夹,没有json,可以新增,复制进去内容 /etc/docker/daemon.json {"registry-mirrors": ["https://dockerhub.azk8s.cn","https://hub-mirror.c.163.com",&q…

conda如何修改虚拟环境的python版本

有时候安装虚拟环境的时候,忘记指定python的版本,本文介绍一下如何在虚拟环境创建之后,修改python的版本。 1 如果安装了Anaconda Navigator。 2 终端 参考:conda修改当前环境中的python版本_conda更换python版本-CSDN博客

电机故障检测系统的通用性限制分析

电机故障检测系统因应用环境、功能需求、经济性等多方面差异而难以实现通用。工厂与实验室在环境条件、使用频率、功能需求、成本、维护及数据处理方面有显著不同,此外,LabVIEW软件在两者中的应用和数据处理也存在差异,这进一步限制了系统的通…

初探海龟绘图

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 海龟绘图是Python内置的模块,在使用前需要导入该模块,可以使用以下几种方法导入: l 直接使用import语句导入海龟…

深度学习21-30

1.池化层作用(筛选、过滤、压缩) h和w变为原来的1/2,64是特征图个数保持不变。 每个位置把最大的数字取出来 用滑动窗口把最大的数值拿出来,把44变成22 2.卷积神经网络 (1)conv:卷积进行特征…

JS(JavaScript)的复合类型详解

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…

vue3前后端开发:响应式对象不能直接成为前后端数据传输的对象

如图所示:前端控制台打印显示数据是没问题的,后端却显示没有接收到相应数据,但是后端的确接收到了一组空数据 直接说原因:这种情况唯一的原因是没有按正确格式传递参数。每个人写错的格式各有不同,我只是说明一下我在…

大模型应用研发基础环境配置(Miniconda、Python、Jupyter Lab、Ollama等)

老牛同学之前使用的MacBook Pro电脑配置有点旧(2015 年生产),跑大模型感觉有点吃力,操作起来有点卡顿,因此不得已捡起了尘封了快两年的MateBook Pro电脑(老牛同学其实不太喜欢用 Windows 电脑做研发工作&am…

第三方软件连接虚拟机

第三方软件连接虚拟机 1 查看本机VM(VMware)虚拟机网段2 开启虚拟机系统,修改网卡配置3 重新打开网络并测试连通性4 打开VM虚拟机网络开关5 通过第三方软件建立连接6 可能遇到的问题 1 查看本机VM(VMware)虚拟机网段 子…

vite+vue3+ts项目搭建流程 (pnpm, eslint, prettier, stylint, husky,commitlint )

vitevue3ts项目搭建 项目搭建项目目录结构 项目配置自动打开项目eslint①vue3环境代码校验插件②修改.eslintrc.cjs配置文件③.eslintignore忽略文件④运行脚本 prettier①安装依赖包②.prettierrc添加规则③.prettierignore忽略文件④运行脚本 stylint①.stylelintrc.cjs配置文…

EfficientNet-V2论文阅读笔记

目录 EfficientNetV2: Smaller Models and Faster Training摘要Introduction—简介Related work—相关工作EfficientNetV2 Architecture Design—高效EfficientNetV2架构设计Understanding Training Efficiency—了解训练效率Training-Aware NAS and Scaling—训练感知NAS和缩放…

Android开发系列(九)Jetpack Compose之ConstraintLayout

ConstraintLayout是一个用于构建复杂布局的组件。它通过将子视图限制在给定的约束条件下来定位和排列视图。 使用ConstraintLayout,您可以通过定义视图之间的约束关系来指定它们的位置。这些约束可以是水平和垂直的对齐、边距、宽度和高度等。这允许您创建灵活而响…

win10修改远程桌面端口,Windows 10下修改远程桌面端口及服务器关闭445端口的操作指南

Windows 10下修改远程桌面端口及服务器关闭445端口的操作指南 一、修改Windows 10远程桌面端口 在Windows 10系统中,远程桌面连接默认使用3389端口。为了安全起见,建议修改此端口以减少潜在的安全风险。以下是修改远程桌面端口的步骤: 1. 打…