机器学习之模型训练

news2024/11/22 15:43:01

前言

模型训练一般分为四个步骤:

  1. 构建数据集。
  2. 定义神经网络模型。
  3. 定义超参、损失函数及优化器。
  4. 输入数据集进行训练与评估。

有了数据集和模型后,可以进行模型的训练与评估。

构建数据集

定义神经网络模型

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()

从网络构建中加载代码,构建一个神经网络模型。

定义超参、损失函数和优化器

超参

超参数是可以调整的参数,可以控制深度学习模型训练优化的过程,包括训练轮次、批次大小和学习率等。这些超参数的取值会影响模型的训练和收敛速度,其中学习率在迭代过程中控制模型的学习进度。

损失函数

损失函数用于评估模型预测值和目标值之间的误差,帮助模型降低误差并提高预测准确性。常见的损失函数包括均方误差和负对数似然,用于回归和分类任务。nn.CrossEntropyLoss结合了多种损失函数的功能,对模型的预测结果进行归一化并计算误差。

优化器

模型优化是通过调整模型参数来减少模型误差的过程,MindSpore提供了多种优化算法的实现,称之为优化器。优化器内部定义了模型参数优化过程,所有优化逻辑都封装在优化器对象中。在这里,使用了SGD(随机梯度下降)优化器。

训练与评估

设置了超参、损失函数和优化器后,我们就可以循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:训练和验证/测试。在训练阶段,模型通过迭代训练数据集来调整参数,以尝试收敛到最佳参数。而在验证/测试阶段,模型通过迭代测试数据集来评估模型的性能是否提升。这种流程的循环迭代可以帮助模型不断学习和优化,以达到更好的性能和准确度。

# Define forward function
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# Define function of one-step training
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

def train_loop(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")
def test_loop(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(model, train_dataset)
    test_loop(model, test_dataset, loss_fn)
print("Done!")

总结

模型训练一般包括构建数据集、定义神经网络模型、定义超参数、损失函数和优化器,以及输入数据集进行训练和评估。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1898983.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第一篇——导论:数学通识课的体系和学习攻略

目录 一、背景介绍二、思路&方案三、过程1.思维导图2.文章中经典的句子理解3.学习之后对于投资市场的理解4.通过这篇文章结合我知道的东西我能想到什么? 四、总结五、升华 一、背景介绍 数学的认知大厦;之前听的时候就觉得很重要,本次又…

性能压测 -优化 Nginx的动静分离

两件事情 1.以后将所有的项目的静态资源都应该放在nginx里面 2.nginx 规则:/static/***所有请求都有nginx直接返回 nginx 配置一下配置文件,然后把html 的静态资源,绑定好是Nginx优先级高的静态资源路径,就去交给nginx静态资源…

h5 video 标签播放经过 java 使用 ws.schild( jave、ffmpeg ) 压缩后的 mp4 视频只有声音无画面的问题排查记录

1. 引入 ws.schild MAVEN 依赖&#xff1a; <dependency><groupId>ws.schild</groupId><artifactId>jave-all-deps</artifactId><version>3.5.0</version></dependency><dependency><groupId>ws.schild</grou…

基于Spring Boot框架的EAM系统设计与实现

摘 要&#xff1a;文章设计并实现一个基于Spring Boot框架的EAM系统&#xff0c;以应对传统人工管理模式存在的低效与信息管理难题。系统利用Java语言、JSP技术、MySQL数据库等技术栈&#xff0c;构建了一个B/S架构的高效管理平台&#xff0c;提升了资产管理的信息化水平。该系…

分班查询系统怎么制作?

新学年的临近&#xff0c;教师们的工作清单再次膨胀&#xff0c;充满各种任务。开学前的准备总是让人忙碌不已&#xff0c;从课程规划到教室布置&#xff0c;再到与家长的沟通&#xff0c;每一个环节都至关重要。尤其是分班结果的公布&#xff0c;这项工作虽然看起来简单&#…

Qwen1.5-1.8b部署

仿照ChatGLM3部署&#xff0c;参考了Qwen模型的文档&#xff0c;模型地址https://modelscope.cn/models/qwen/Qwen1.5-1.8B-Chat/summary http接口 服务端代码api.py from fastapi import FastAPI, Request from transformers import AutoTokenizer, AutoModelForCausalLM, …

强化学习后的数学原理:随机近似与梯度下降

概述 这节课的作用&#xff1a; 本节课大纲如下&#xff1a; Motivating examples 先回顾一下 mean estimation &#xff1a; 为什么总数反复提到这个 mean estimation&#xff0c;就是因为 RL 当中有非常多的 expectation&#xff0c;后面就会知道除了 state value 这些定义…

PySide6 实现资源的加载:深入解析与实战案例

目录 1. 引言 2. 加载内置资源 3. 使用自定义资源文件&#xff08;.qrc&#xff09; 创建.qrc文件 编译.qrc文件 加载资源 4. 动态加载UI文件 使用Qt Designer设计UI 加载UI文件 5. 注意事项与最佳实践 6. 结论 在开发基于PySide6的桌面应用程序时&…

博途通讯笔记1:1200与1200之间S7通讯

目录 一、添加子网连接二、创建PUT GET三、各个参数的意义 一、添加子网连接 二、创建PUT GET 三、各个参数的意义

新手高效指南:电子元器件BOM表创建/制作及配单全教程

在科技日新月异的今天&#xff0c;电子产品设计与制造不仅是创新精神的展现&#xff0c;更是对精确度与效率的不懈追求。在这个过程中&#xff0c;一份精细且全面的BOM&#xff08;物料清单&#xff09;犹如一座桥梁&#xff0c;连接着创意与现实世界。BOM不仅细致记录了产品所…

如何优化圆柱晶振32.768KHz的外壳接地?

圆柱晶振32.768KHz在电子设备中扮演着重要的角色&#xff0c;其精确的时钟信号对于许多应用至关重要。为了确保晶振的稳定性和准确性&#xff0c;外壳接地是一个关键步骤。 一、外壳接地的目的 外壳接地的主要目的是为了防止信号干扰。当晶振的外壳接地后&#xff0c;它相当于…

16-JS封装:extend方法

目录 一、封装需求 二、实现1&#xff1a;jQuery.extend 三、实现2&#xff1a;通过原型jQuery.fn.extend 四、优化 一、封装需求 封装需求&#xff1a; $.extend&#xff1a; var obj{ name:"xxx",age:18} var obj3{ gender:"女"} var obj2{}; 将obj、…

S7-1200PLC学习记录

文章目录 前言一、S7-12001.数字量输入模块2. PNP接法和NPN接法 二、博图软件1. 位逻辑运算Part1. 添加新设备&#xff08;添加PLC&#xff09;Part2. 添加信号模块Part3. 添加信号板中模块Part4. 添加新块Part5. Main编程文件案例1案例2 -( S )- 和 -( R )-完整操作过程&#…

ERROR | Web server failed to start. Port 8080 was already in use.

错误提示&#xff1a; *************************** APPLICATION FAILED TO START ***************************Description:Web server failed to start. Port 8080 was already in use.Action:Identify and stop the process thats listening on port 8080 or configure thi…

C++——模板详解(下篇)

一、非类型模板参数 模板参数分为类型形参与非类型形参。 类型形参即&#xff1a;出现在模板参数列表中&#xff0c;跟在class或者typename之后的参数类型名称。 非类型形参&#xff0c;就是用一个常量作为类&#xff08;函数&#xff09;模板的一个参数&#xff0c;在类&#…

Swift 中的方法调用机制

Swift 方法调用详解&#xff1a;与 Objective-C 的对比、V-Table 机制、Witness Table 机制 在 iOS 开发中&#xff0c;Swift 和 Objective-C 是两种常用的编程语言。尽管它们都能用于开发应用程序&#xff0c;但在方法调用的底层机制上存在显著差异。本文将详细介绍 Swift 的…

maven项目使用netty,前端是vue2,实现通讯

引入的java包 <!-- 以下是即时通讯--><!-- Netty core modules --><dependency><groupId>io.netty</groupId><artifactId>netty-all</artifactId><version>4.1.76.Final</version> <!-- 使用最新的稳定版本…

C++中的引用——引用的注意事项

1.引用必须初始化 2.引用在初始化后不可以改变 示例&#xff1a; 运行结果&#xff1a;

03:EDA的进阶使用

使用EDA设计一个38译码器电路和245放大电路 1、38译码器1.1、查看74HC138芯片数据1.2、电路设计 2、245放大电路2.1、查看数据手册2.2、设计电路 3、绘制PCB3.1、导入3.2、放置3.3、飞线3.4、特殊方式连接GND3.5、泪滴3.6、配置丝印和划分区域3.7、添加typc接口供电 1、38译码器…

案例精选 | 聚铭网络助力南京市玄武区教育局构建内网日志审计合规体系

南京市玄武区教育局作为江苏省教育领域的先锋机构&#xff0c;其工作重点涵盖了教育政策的实施、教育现代化与信息化的融合、教育资源的优化、教育质量的提升以及教育公平的促进。在这一背景下&#xff0c;网络安全管理成为了确保教育信息化顺利推进的关键环节之一。 根据玄武…