昇思第7天

news2024/10/6 4:06:10

模型训练

模型训练一般分为四个步骤:

构建数据集。
定义神经网络模型。
定义超参、损失函数及优化器。
输入数据集进行训练与评估。

  1. 数据集加载
import mindspore
from mindspore import nn
# 从 MindSpore 数据集包中导入 vision 和 transforms 模块。
# vision:包含处理图像数据的工具。
# transforms:包含数据转换的工具。
from mindspore.dataset import vision, transforms
# 从 MindSpore 数据集包中导入 MnistDataset 类,用于加载 MNIST 数据集。
from mindspore.dataset import MnistDataset
# 从 download 模块中导入 download 函数,用于下载数据集。
from download import download

# 指定数据集的 URL 地址。
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"

# 使用 download 函数下载数据集并解压到当前目录。
path = download(url, "./", kind="zip", replace=True)

# 定义一个数据管道函数,接收数据集路径和批量大小作为参数。
def datapipe(path, batch_size):
    # 定义图像数据的转换操作列表。
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),       # 缩放图像像素值到 [0, 1] 范围。
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),  # 标准化图像数据。
        vision.HWC2CHW()                      # 转换图像格式从 HWC(高度、宽度、通道)到 CHW(通道、高度、宽度)。
    ]
    # 定义标签数据的转换操作,将标签转换为 int32 类型。
    label_transform = transforms.TypeCast(mindspore.int32)

    # 加载指定路径的数据集。
    dataset = MnistDataset(path)
    # 对数据集的图像应用转换操作。
    dataset = dataset.map(image_transforms, 'image')
    # 对数据集的标签应用转换操作。
    dataset = dataset.map(label_transform, 'label')
    # 将数据集分批,每批包含指定数量的样本。
    dataset = dataset.batch(batch_size)
    # 返回处理后的数据集。
    return dataset

# 创建训练数据集,批量大小为 64。
train_dataset = datapipe('MNIST_Data/train', batch_size=64)

# 创建测试数据集,批量大小为 64。
test_dataset = datapipe('MNIST_Data/test', batch_size=64)
  1. 构建神经网络
 # 定义一个神经网络类 Network,继承自 nn.Cell。
class Network(nn.Cell):
    # 在初始化方法中定义网络的结构。
    def __init__(self):
        # 调用父类的初始化方法。
        super().__init__()
        # 定义一个平坦化层,用于将输入的多维数据展开为一维。
        self.flatten = nn.Flatten()
        # 定义一个顺序容器 SequentialCell,其中包含多个层顺序连接。
        self.dense_relu_sequential = nn.SequentialCell(
            # 全连接层,将输入数据的尺寸从 28*28(即 784)转换为 512。
            nn.Dense(28*28, 512),
            # ReLU 激活函数。
            nn.ReLU(),
            # 全连接层,将输入数据的尺寸从 512 转换为 512。
            nn.Dense(512, 512),
            # ReLU 激活函数。
            nn.ReLU(),
            # 全连接层,将输入数据的尺寸从 512 转换为 10(对应于 10 个类别)。
            nn.Dense(512, 10)
        )

    # 定义前向传播方法,用于计算网络的输出。
    def construct(self, x):
        # 将输入数据平坦化。
        x = self.flatten(x)
        # 依次通过顺序容器中的各层,得到最终的输出 logits。
        logits = self.dense_relu_sequential(x)
        # 返回计算得到的 logits。
        return logits

# 创建一个 Network 类的实例,表示定义好的神经网络模型。
model = Network()

3.定义超参、损失函数及优化器。

# 定义训练的参数。
# 训练的轮数,即数据集将被遍历的次数。
epochs = 3
# 每个批次的大小,即一次训练中使用的样本数。
batch_size = 64
# 学习率,即模型参数在每次更新时调整的幅度。
learning_rate = 1e-2
# 定义训练的参数。
# 训练的轮数,即数据集将被遍历的次数。
epochs = 3
# 每个批次的大小,即一次训练中使用的样本数。
batch_size = 64
# 学习率,即模型参数在每次更新时调整的幅度。
learning_rate = 1e-2

# 定义损失函数,用于计算预测结果与实际标签之间的差异。
# 使用交叉熵损失函数(CrossEntropyLoss),这是分类问题中常用的损失函数。
loss_fn = nn.CrossEntropyLoss()

# 定义优化器,用于更新模型的参数。

# 使用随机梯度下降(SGD)优化器。
# model.trainable_params() 获取模型中所有需要训练的参数。
# learning_rate 指定优化器的学习率。
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)


4.训练与评估
训练

# 定义前向函数,用于计算模型输出和损失。
def forward_fn(data, label):
    # 使用模型计算预测值(logits)。
    logits = model(data)
    # 计算预测值与真实标签之间的损失。
    loss = loss_fn(logits, label)
    # 返回损失值和预测值。
    return loss, logits

# 获取梯度函数,用于计算损失相对于模型参数的梯度。
# mindspore.value_and_grad 会计算前向函数的值和梯度。
# forward_fn: 计算损失的前向函数。
# None: 不需要计算的额外输出。
# optimizer.parameters: 需要计算梯度的参数。
# has_aux=True: 表示前向函数返回多个值。
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# 定义单步训练函数。
def train_step(data, label):
    # 计算损失和梯度。
    (loss, _), grads = grad_fn(data, label)
    # 使用优化器更新模型参数。
    optimizer(grads)
    # 返回当前步的损失值。
    return loss

# 定义训练循环函数。
def train_loop(model, dataset):
    # 获取数据集的大小(即批次的数量)。
    size = dataset.get_dataset_size()
    # 设置模型为训练模式。
    model.set_train()
    # 枚举数据集的每个批次。
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        # 执行单步训练,获取当前批次的损失值。
        loss = train_step(data, label)

        # 每 100 个批次打印一次损失值和当前批次编号。
        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

测试函数

# 定义测试循环函数,用于在测试集上评估模型的性能。
def test_loop(model, dataset, loss_fn):
    # 获取数据集的批次数量。
    num_batches = dataset.get_dataset_size()
    # 设置模型为评估模式。
    model.set_train(False)
    
    # 初始化总样本数、测试损失和正确预测数。
    total, test_loss, correct = 0, 0, 0
    
    # 枚举数据集的每个批次。
    for data, label in dataset.create_tuple_iterator():
        # 使用模型进行预测。
        pred = model(data)
        # 累加总样本数。
        total += len(data)
        # 累加测试损失。
        test_loss += loss_fn(pred, label).asnumpy()
        # 累加正确预测数。
        correct += (pred.argmax(1) == label).asnumpy().sum()
    
    # 计算平均损失。
    test_loss /= num_batches
    # 计算准确率。
    correct /= total
    
    # 打印测试结果,包括准确率和平均损失。
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

运行

# 定义损失函数和优化器。
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

# 执行多个 epoch 的训练循环。
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    # 执行训练循环。
    train_loop(model, train_dataset)
    # 在测试集上进行评估。
    test_loop(model, test_dataset, loss_fn)

print("Done!")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1887832.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

肝癌-图像分类数据集

肝癌-图像分类数据集 数据集: 链接:https://pan.baidu.com/s/18r-JS1FIv6BiyvlqDpUE0w?pwdrw5w 提取码:rw5w 数据集信息介绍: 文件夹 恶性 中的图片数量: 1008 文件夹 良性 中的图片数量: 882 所有子文件夹中的图片总数量: 1…

微软账户和本地账户有什么区别?如何切换登录账户?

Windows 操作系统是目前世界上比较流行的操作系统之一,在使用 Windows 系统的时候都需要我们进行登录,其中我们可以使用微软账户或者本地账户进行登录,那本地账户和微软账户有什么区别?下面就带大家了解一下微软账户和本地账户。 …

请不要在 Vue 中滥用“watch”功能,拜托了!

随着 Vue 3 的 Composition API 风格的普及,使用 watch 的成本越来越低。 现在,我们可以在任何地方使用 watch 来监听响应式数据。随着业务的推进,你可能会在代码中看到大量的 watch。 当你接手修改这些充满 watch 代码时,我相信…

电梯修理升级,安装【电梯节能】能量回馈设备

电梯修理升级,安装【电梯节能】能量回馈设备 1、节能率评估 15%—45% 2、降低机房环境温度,改善电梯控制系统的运行环境; 3、延长电梯使用寿命; 4、机房可以不需要使用空调等散热设备的耗电,间接节省电能。 欢迎私询哦…

使用PID算法实现DAC模拟量输出的快速调节

目录 概述 1 系统框架和算法 1.1 框架结构介绍 1.2 PID算法实现 1.2.1 理论介绍 1.2.2 离散化位置式PID 1.2.3 位置式PID算法 2 STM32Cube 配置项目 2.1 配置参数 2.2 GENERATE项目 3 功能实现 3.1 ADC采样数据功能 3.2 DAC数据转换 3.3 PID相关的调制函数 4 …

黄子韬vs徐艺洋卫生间风波

【热搜爆点】黄子韬VS徐艺洋:卫生间风波背后的职场与友情界限探讨在这个充满欢笑与意外的综艺时代,《跟我出游吧》再次以它独有的魅力,引爆了一个既尴尬又引人深思的话题——“黄子韬要上徐艺洋的卫生间?”这不仅仅是一句简单的调…

CSS|03 尺寸样式属性文本与字体属性

尺寸样式属性 height:元素高度height的值:auto 自动length 使用px定义高度% 基于包含它的块级对象的百分比高度 width:元素的宽度width的值与height一样span标签可以设置宽度、高度吗? 答:不可以,因为span标签是一个行…

mysql-sql-第十四周

学习目标: sql 学习内容: 40.查询学过「哈哈」老师授课的同学的信息 Select * from students left join score on students.stunmscore.stunm where counm (select counm from teacher left join course on teacher.teanmcourse.teanm where teache…

DCU整体硬件架构

DCU整体硬件架构 DCU整体硬件架构 首先,DCU通过PCI-E总线与CPU处理器相连,它是CPU主机系统的一个硬件扩展,其存在的目的是为了对程序某些模块或者函数进行加速。虽然DCU是原硬件系统的一个扩展,接受CPU调度指挥,但是在…

西部智慧健身小程序+华为运动健康服务

1、 应用介绍 西部智慧健身小程序为用户提供一站式全流程科学健身综合服务。用户通过登录微信小程序,可享用健康筛查、运动风险评估、体质检测评估、运动处方推送、个人运动数据监控与评估等公益服务。 2、 体验介绍西部智慧健身小程序华为运动健康服务核心体验如…

认识流式处理框架Apache Flink

目录 一、Apache Flink 的基础概念 1.1 Apache Flink是什么? 1.2 Flink的定义 二、Apache Flink 的发展史 2.1 Flink前身Stratosphere 2.2 Flink发展时间线及重大变更 三、Flink核心特性 3.1 批流一体化 3.2 同时支持高吞吐、低延迟、高性能 3.3 支持事件时…

前端接入chatgpt,实现流式文字的显示

前端接入chatgpt,实现流式文字的显示 业务需求: 项目需要接入chatgpt提供的api,后端返回流式的字符,前端接收并实时显示。 相关技术原理: 1. JS中的Stream流: 在JavaScript中,使用Stream流通常指的是处理数据流的…

react native中使用@react-navigation/native进行自定义头部

react native中使用react-navigation/native进行自定义头部 效果示例图实例代码 效果示例图 实例代码 /* eslint-disable react-native/no-inline-styles */ /* eslint-disable react/no-unstable-nested-components */ import React, { useLayoutEffect } from react; import…

ripro子主题eeesucai-child集成后台美化包(适用于设计素材站+资源下载站等)

模板介绍 最新RiPro子主题模板,Eeesucai-child模板后台美化包,使用该子主题前需要安装WordPress程序和RiPro模板。 安装教程 第一种,在wordpress后台上传主题,上传之后点启动 第二种,上传到wordpress主题目录/wp-con…

MatLab 二维图像绘制基础

MatLab 二维图像绘制基础 plot 描点绘图 %% % 二维绘图 ,plot进行描点,步长越小,越平滑 x [1:9]; y [0.1:0.2:1.7]; X x y*i; % 复数 plot(X)plot绘制矩阵 %% % 当X Y 为矩阵时,对应矩阵中的元素依次绘制 t 0:0.01:2*pi; …

将多个Excel工作表合并成一个工作表,1分钟轻松搞定!

1. 案例展示 2. 视频详解 多个工作表合并成一个工作表 3. 图文详解 第一步:相同格式(表头)的表格,并将所有表格都放在一个文件夹内“将多个工作表合并成一个工作表”(自己定义文件名) 第二步:新…

Linux 【线程池】【单例模式】【读者写者问题】

💓博主CSDN主页:麻辣韭菜💓   ⏩专栏分类:Linux初窥门径⏪   🚚代码仓库:Linux代码练习🚚   🌹关注我🫵带你学习更多Linux知识   🔝 目录 🏳️‍🌈前言 …

VSCode打开其它IDE项目注释显示乱码的解决方法

问题描述:VSCode打开Visual Studio(或其它IDE)工程,注释乱码,如下图所示: 解决方法:点击VSCode右下角的UTF-8,根据提示点击“通过编码重新打开”,再选择GB2312&#xff0…

JDBC链接kerberos认证的impala数据库报错问题解决

先上代码 public static Connection connectToImpala() {try {log.info("ketTabPath:" ketTabPath);log.info("krb5Path:" krb5Path);System.setProperty("java.security.krb5.conf", krb5Path);System.setProperty("sun.security.krb5.…

python如何输出list

直接输出list_a中的元素三种方法: list_a [1,2,3,313,1] 第一种 for i in range(len(list_a)):print(list_a[i]) 1 2 3 313 1 第二种 for i in list_a:print(i) 1 2 3 313 1 第三种,使用enumerate输出list_a方法: for i,j in enum…