昇思25天学习打卡营第08天 | 模型训练

news2024/10/5 14:39:39

昇思25天学习打卡营第08天 | 模型训练

文章目录

  • 昇思25天学习打卡营第08天 | 模型训练
    • 超参数
    • 损失函数
    • 优化器
      • 优化过程
    • 训练与评估
    • 总结
    • 打卡

模型训练一般遵循四个步骤:

  1. 构建数据集
  2. 定义神经网络模型
  3. 定义超参数、损失函数和优化器
  4. 输入数据集进行训练和评估

构建数据集和网络模型在之前的内容在已经涉及,不再赘述。

超参数

超参数(Hyperparameters)是可以调整的参数,可以控制模型训练的过程。

深度学习模型多采用随机梯度下降算法SGD进行优化:
w t + 1 = w t − η 1 n ∑ x ∈ B ∇ l ( x , w t ) w_{t+1}=w_t- \eta\frac1n\sum_{x\in B}\nabla l(x,w_t) wt+1=wtηn1xBl(x,wt)
其中, η \eta η是学习率, n n n是batch大小,都是超参数,这两个参数是直接影响模型性能收敛的重要参数。
一般会定义三个超参数:

  • epoch:遍历数据集的次数
  • batch size:每个批次数据的大小。size 过小导致花费时间多,梯度震荡严重,不利于收敛;size 过大容易陷入局部极小值。
  • learning rate:学习率国小会导致收敛速度变慢;过大则可能会导致训练不收敛。

损失函数

损失函数用于评估模型预测值和目标值之间的误差。
常见的损失函数包括:

  • nn.MSELoss:均方误差,用于回归
  • nn.NLLLoss:负对数似然,用于分类
  • nn.CrossEntropyLoss:结合了nn.LogSoftmaxnn.NLLLoss,可以对logits进行归一化并计算预测误差
loss_fn = nn.CrossEntropyLoss()

优化器

优化器内部定义了模型参数的优化过程,所有的优化逻辑都封装在优化器对象中。

optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

优化过程

通过自动微分获得的微分函数,计算参数对应的梯度,并传入优化器中,即可实现参数优化。

grads = grad_fn(inputs)
optimizer(grads)

训练与评估

遍历一次数据集被称为一轮(epoch),每轮执行训练时包含两个步骤:

  1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
  2. 验证/测试:迭代测试数据集,检查模型性能是否提升。
# Define forward function
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# Define function of one-step training
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

def train_loop(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

def test_loop(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程一般为:

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(model, train_dataset)
    test_loop(model, test_dataset, loss_fn)

总结

这一节的内容对深度学习模型训练的一般过程进行了详细的介绍,从数据集构建到模型定义,接着定义超参数并选择合适的值,创建损失函数和优化器对象完成训练前的准备。通过封装一个模型调用和loss计算的前向计算函数并自动微分,在每个epoch中计算loss并优化参数,从而完成模型的训练。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1901181.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Git 运用小知识

1.Git添加未完善代码的解决方法 1.1 Git只是提交未推送 把未完善的代码提交到本地仓库 只需点击撤销提交,提交的未完善代码会被撤回 代码显示未提交状态 1.2 Git提交并推送 把未完善的代码提交并推送到远程仓库 点击【未完善提交并推送】的结点选择还原提交&#x…

前端面试题20(防抖函数)

在前端开发中,防抖(debounce)函数是一种常见的优化技术,用于控制函数的执行频率,避免在短时间内重复调用同一函数。这在处理如用户输入、窗口尺寸变化或鼠标移动等高频事件时特别有用,可以显著提升应用程序…

最小权顶点覆盖问题-优先队列分支限界法-C++

问题描述: 给定一个赋权无向图 G(V,E),每个顶点 v∈V 都有一个权值 w(v)。如果 U⊆V,U⊆V,且对任意(u,v)∈E 有 u∈U 或 v∈U,就称 U 为图 G 的一个顶点覆盖。G 的最小权顶点覆盖是指 G 中所含顶点权之和最小的顶点覆盖。对于给定…

AttackGen:一款基于LLM的网络安全事件响应测试工具

关于AttackGen AttackGen是一款功能强大的网络安全事件响应测试工具,该工具利用了大语言模型和MITRE ATT&CK框架的强大功能,并且能够根据研究人员选择的威胁行为组织以及自己组织的详细信息生成定制化的事件响应场景。 功能介绍 1、根据所选的威胁行…

springboot项目多模块工程==1搭建

1、新建父工程 采用springboot工程作为父工程搭建方便依赖选择,在这个基础上进行maven的pom父子模块结构调整。该工程选择mave进行依赖管理 2、springboot 版本及相关依赖选择 3、删除工程目录src,并修改pom 由于该父工程只作为依赖的统一管理,因此将…

Python实战训练(方程与拟合曲线)

1.方程 求e^x-派(3.14)的解 用二分法来求解,先简单算出解所在的区间,然后用迭代法求逼近解,一般不能得到精准的解,所以设置一个能满足自己进度的标准来判断解是否满足 这里打印出解x0是因为在递归过程中…

CentOS 7安装Elasticsearch7.7.0和Kibana

一. 准备安装包 elasticsearch和kibana:官网历史版本找到并下载(https://www.elastic.co/cn/downloads/past-releases#elasticsearch)ik分词器:GitHub下载(https://github.com/infinilabs/analysis-ik/releases/tag/v…

3.js - 裁剪平面(clipIntersection:交集、并集)

看图 代码 // ts-nocheck// 引入three.js import * as THREE from three// 导入轨道控制器 import { OrbitControls } from three/examples/jsm/controls/OrbitControls// 导入lil.gui import { GUI } from three/examples/jsm/libs/lil-gui.module.min.js// 导入tween import …

Interpretability 与 Explainability 机器学习

「AI秘籍」系列课程: 人工智能应用数学基础人工智能Python基础人工智能基础核心知识人工智能BI核心知识人工智能CV核心知识 Interpretability 模型和 Explainability 模型之间的区别以及为什么它可能不那么重要 当你第一次深入可解释机器学习领域时,你会…

WEB编程-了解Tomcat服务器

第⼀章⽹络编程 1.1 概述 计算机⽹络:是指将地理位置不同的具有独⽴功能的多台计算机及其外部设备,通过通信线路连接起来,在⽹络 操作系统、⽹络管理软件及⽹络通信协议的管理和协调下,实现资源共享和信息传递的计算机系统。 …

cs224n作业3 代码及运行结果

代码里要求用pytorch1.0.0版本,其实不用也可以的。 【删掉run.py里的assert(torch.version “1.0.0”)即可】 代码里面也有提示让你实现什么,弄懂代码什么意思基本就可以了,看多了感觉大框架都大差不差。多看多练慢慢来,加油&am…

前端位置布局汇总

1、位置:绝对位置和相对位置 绝对位置 style"position: absolute;left: 218px;top: 0%;" style"position: absolute;bottom:5px;right:5px ;" 相对位置 :margin外边距 padding内边距 style"border:1px solid black;width:200px;text-ali…

vue事件处理v-on或@

事件处理v-on或 我们可以使用v-on指令(简写)来监听DOM事件,并在事件触发时执行对应的Javascript。用法:v-on:click"methodName"或click"hander" 事件处理器的值可以是: 内敛事件处理器&#xff1…

Yolo v7网络实现细节(一)

Yolo v7网络实现细节 YOLO v7网络架构的整体介绍 不同GPU和对应模型: ​​​​​​​边缘GPU:YOLOv7-tiny普通GPU:YOLOv7​​​​​​​云GPU的基本模型: YOLOv7-W6 激活函数: YOLOv7 tiny: leaky ReLU其…

南方健康2024米思会:科普患教赋能医药增长闭环,千亿蓝海市场大爆发!

2024年6月25日-28日,在中国•南太湖举办的2024米思会如约而至,顺利落下帷幕,本次大会以“韧进启新局”为主题,以不懈进取的“韧劲”,立身破局,迎变启新。通过4天3夜的思想碰撞和互动交流,引领行…

使用shell脚本实现DM8开机自动启动

编写shell脚本 #!/bin/bashsu -dmdba >>EOF cd /home/dmdba/dmdbms/bin ./DmServiceDMTEST start echo "dm start ... " EOF注意:DmServiceDMTEST每个服务器设置的不一样,根据实际进行更换 授权脚本可执行权限 chmod -x /dmdata/dmse…

策略为王股票软件源代码-----如何修改为自己软件61----资讯菜单修改-----举例---------调用同花顺网页------

http://stock.sina.com.cn 将原来的新浪行情,修改为同花顺, 搜索 stock.sina.com.cn... StkUI\View\InfoView.cpp(58):char

【C++:默认成员函数初始化列表】

构造函数 特点 没有返回值支持函数重载对象实例化时,编译器自动调用作用不是构造,而是初始化函数名与类名相同无参函数和全缺省的函数,不用传参就能调用的函数叫做默认构造函数 构造函数是一个特殊的成员函数 注:无参构造函数在实…

Lock4j简单的支持不同方案的高性能分布式锁实现及源码解析

文章目录 1.Lock4j是什么?1.1简介1.2项目地址1.3 我之前手写的分布式锁和限流的实现 2.特性3.如何使用3.1引入相关依赖3.2 配置redis或zookeeper3.3 使用方式3.3.1 注解式自动式3.3.2 手动式 4.源码解析4.1项目目录4.2实现思路 5.总结 1.Lock4j是什么? 1.1简介 lock4j是苞米…

平均102天 Accept的国产医学顶刊,影响因子4连涨,还免版面费!

《Asian Journal of Pharmaceutical Sciences》 (亚洲药物制剂科学) 是由沈阳药科大学主办、Elsevier合作出版的全英文药剂学学术期刊,是“中国科技期刊卓越行动计划”资助期刊,现已被SCIE、PubMed Central、Scopus和DOAJ等国际著名检索系统收录&#xf…