Machine Learning - Training a Model


This post continues from the previous one:
Machine Learning - Choosing a Model

To close the gap between the model's predictions and the ideal values, we can update its internal parameters: the weights and bias, which were randomly initialized with nn.Parameter() and torch.randn().
Much of the time you won't know what the ideal parameters are for a model.
Instead, it is much more fun to write code to see if the model can figure them out itself. That is where a loss function and an optimizer come in.

| Function | What does it do? | Where does it live in PyTorch? | Common values |
| --- | --- | --- | --- |
| Loss function | Measures how wrong your model's predictions (e.g. y_preds) are compared to the truth labels (e.g. y_test). The lower, the better. | PyTorch has plenty of built-in loss functions in torch.nn. | Mean absolute error (MAE) for regression problems (torch.nn.L1Loss()); binary cross entropy for binary classification problems (torch.nn.BCELoss()). |
| Optimizer | Tells your model how to update its internal parameters to best lower the loss. | You can find various optimization function implementations in torch.optim. | Stochastic gradient descent (torch.optim.SGD()); the Adam optimizer (torch.optim.Adam()). |

About MAE: mean absolute error is a loss function that measures the difference between predicted values and true values. It is the average of the absolute differences between the predictions and the targets. In PyTorch it is available as torch.nn.L1Loss.
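
As a quick sanity check (a minimal sketch with hypothetical toy tensors, not data from this post), torch.nn.L1Loss() matches the mean of the absolute differences computed by hand:

import torch
from torch import nn

# Hypothetical toy tensors, just to illustrate the computation
y_pred = torch.tensor([2.5, 0.0, 2.0])
y_true = torch.tensor([3.0, -0.5, 2.0])

loss_fn = nn.L1Loss()

# Both print tensor(0.3333): the mean of |y_pred - y_true|
print(loss_fn(y_pred, y_true))
print(torch.mean(torch.abs(y_pred - y_true)))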

About stochastic gradient descent (SGD):
SGD is a widely used optimization algorithm for training neural network models. It is a variant of gradient descent that updates the parameters at each step using a gradient estimate computed from a random sample. The basic idea is to adjust the model parameters so as to minimize the loss function, bringing the model's predictions as close as possible to the true labels. At each iteration, SGD randomly selects a small batch of samples (a mini-batch), computes the gradient of the loss with respect to the parameters on that batch, and uses it to update the parameters. Because each update only uses part of the data, SGD typically converges faster and has a lower per-step computational cost. In PyTorch it is available as torch.optim.SGD(params, lr), where

  • params is the target model parameters you’d like to optimize (e.g. the weights and bias values we randomly set before).
  • lr is the learning rate you'd like the optimizer to update the parameters with. Higher means the optimizer will try larger updates (these can sometimes be too large and the optimizer will fail to work); lower means the optimizer will try smaller updates (these can sometimes be too small and the optimizer will take too long to find the ideal values). Common starting values for the learning rate are 0.01, 0.001, and 0.0001 (see the sketch after this list).
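
A minimal sketch of what one plain SGD step does under the hood (the toy parameter below is hypothetical, not the model from this post): optimizer.step() moves each parameter against its gradient by a step of size lr.

import torch

# Hypothetical single parameter; the toy loss has gradient d(loss)/dw = 3.0
w = torch.tensor([1.0], requires_grad=True)
loss = (w * 3.0).sum()
loss.backward()

lr = 0.01
with torch.no_grad():
    w -= lr * w.grad  # the update optimizer.step() performs for plain SGD
print(w)  # tensor([0.9700], requires_grad=True)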

About the Adam optimizer:
Adam is another widely used optimization algorithm. It combines momentum with per-parameter adaptive learning rates, which makes it efficient at optimizing neural network parameters. On top of plain gradient descent, Adam adds a momentum term, which keeps updates pointing in a consistent direction and so speeds up convergence, and an adaptive term, which scales the learning rate for each parameter based on its gradient history, making the updates more robust.
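
For reference, swapping SGD for Adam is a one-line change. A minimal sketch (the toy nn.Linear below stands in for model_0 from the previous post):

import torch
from torch import nn

# A toy model standing in for model_0 from the previous post
toy_model = nn.Linear(in_features=1, out_features=1)

# Same call shape as torch.optim.SGD(); Adam's default lr is 0.001
optimizer = torch.optim.Adam(params=toy_model.parameters(), lr=0.001)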

About the learning rate:
The learning rate is a hyperparameter that controls the step size of parameter updates when training a neural network. It determines how far the parameters move along the gradient direction on each update: the larger the learning rate, the larger the step; the smaller the learning rate, the smaller the step. Choosing a good learning rate is one of the most important tuning decisions in training. If it is too large, the updates can overshoot, making the model unstable or even divergent; if it is too small, convergence slows down and training takes longer.
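
A minimal illustration on the toy function f(x) = x² (gradient 2x; not from this post): a small learning rate shrinks x steadily toward the minimum at 0, while a too-large one makes the iterates grow without bound.

def gradient_descent(x, lr, steps=5):
    """Minimize f(x) = x**2 by repeatedly stepping against the gradient 2*x."""
    for _ in range(steps):
        x = x - lr * (2 * x)
    return x

print(gradient_descent(1.0, lr=0.1))  # 0.32768: converging toward 0
print(gradient_descent(1.0, lr=1.1))  # -2.48832: oscillating and diverging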

The code is as follows:

import torch
from torch import nn  # nn.L1Loss lives in torch.nn

# Create the loss function
loss_fn = nn.L1Loss()  # MAE loss is the same as L1Loss

# Create the optimizer (model_0 is the model built in the previous post)
optimizer = torch.optim.SGD(params=model_0.parameters(),
                            lr=0.01)


Now let's build an optimization loop.
The training loop involves the model going through the training data and learning the relationships between the features and labels.
The testing loop involves going through the testing data and evaluating how good the patterns are that the model learned on the training data (the model never sees the testing data during training).
Each of these is called a "loop" because we want our model to look at (loop through) each sample in each dataset, so we implement them with for loops.

PyTorch training loop

| Number | Step name | What does it do? | Code example |
| --- | --- | --- | --- |
| 1 | Forward pass | The model goes through all of the training data once, performing its forward() function calculations. | model(x_train) |
| 2 | Calculate the loss | The model's outputs (predictions) are compared to the ground truth and evaluated to see how wrong they are. | loss = loss_fn(y_pred, y_train) |
| 3 | Zero gradients | The optimizer's gradients are set to zero (they accumulate by default) so they can be recalculated for this specific training step. | optimizer.zero_grad() |
| 4 | Perform backpropagation on the loss | Computes the gradient of the loss with respect to every model parameter to be updated (each parameter with requires_grad=True). This is known as backpropagation, hence "backwards". | loss.backward() |
| 5 | Step the optimizer (gradient descent) | Updates the parameters with requires_grad=True using the loss gradients, in order to improve them. | optimizer.step() |

PyTorch testing loop
As for the testing loop (evaluating the model), the typical steps include:

| Number | Step name | What does it do? | Code example |
| --- | --- | --- | --- |
| 1 | Forward pass | The model goes through all of the testing data once, performing its forward() function calculations. | model(x_test) |
| 2 | Calculate the loss | The model's outputs (predictions) are compared to the ground truth and evaluated to see how wrong they are. | loss = loss_fn(y_pred, y_test) |
| 3 | Calculate evaluation metrics (optional) | Alongside the loss value you may want to calculate other evaluation metrics, such as accuracy, on the test set. | Custom functions |
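
As a sketch of the "custom functions" entry above (this RMSE helper is hypothetical, not from the post), an extra regression metric can be computed alongside the MAE loss:

import torch

def rmse(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Root mean squared error: penalizes large errors more heavily than MAE."""
    return torch.sqrt(torch.mean((y_pred - y_true) ** 2))

# Example usage with toy tensors; prints tensor(0.5000)
print(rmse(torch.tensor([2.5, 0.0]), torch.tensor([3.0, -0.5])))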

The code implementation follows:

# Create the loss function
# nn.L1Loss() computes the mean absolute error (MAE).
loss_fn = nn.L1Loss()  # MAE loss is the same as L1Loss

# Create the optimizer
# torch.optim.SGD() creates a stochastic gradient descent optimizer.
# parameters() returns an iterator over all model parameters that require gradient updates.
optimizer = torch.optim.SGD(params=model_0.parameters(),
                            lr=0.01)

# Set the number of epochs (how many times the model will pass over the training data)
epochs = 200

# Create empty loss lists to track values
train_loss_values = []
test_loss_values = []
epoch_count = []

for epoch in range(epochs):
  ### Training 

  # Put model in training mode (this is the default state of a model)
  # train() sets the model to training mode
  model_0.train()

  # 1. Forward pass on train data using the forward() method inside
  y_pred = model_0(X_train)

  # 2. Calculate the loss (how different are our models predictions to the ground truth)
  loss = loss_fn(y_pred, y_train)

  # 3. Zero grad of the optimizer
  optimizer.zero_grad() 

  # 4. Loss backwards
  loss.backward()

  # 5. Progress the optimizer
  # step() performs a single parameter-update step
  optimizer.step() 

  ### Testing

  # Put the model in evaluation mode
  model_0.eval() 

  with torch.inference_mode():
    # 1. Forward pass on test data 
    test_pred = model_0(X_test)

    # 2. Calculate loss on test data 
    test_loss = loss_fn(test_pred, y_test.type(torch.float))  # predictions come in torch.float datatype, so comparisons need to be done with tensors of the same type 

    # Print out 
    if epoch % 10 == 0:
      epoch_count.append(epoch)
      train_loss_values.append(loss.detach().numpy())
      test_loss_values.append(test_loss.detach().numpy())
      print(f"Epoch: {epoch} | MAE Train Loss: {loss} | MAE Test Loss: {test_loss}")


import matplotlib.pyplot as plt

plt.plot(epoch_count, train_loss_values, label="Train loss")
plt.plot(epoch_count, test_loss_values, label="Test loss")
plt.title("Training and test loss curves")
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.legend()
plt.show()

print("The model learned the following values for weights and bias: ")
print(model_0.state_dict())
print("\nAnd the original values for weights and bias are: ")
print(f"weights: {weight}, bias: {bias}")

# Output:
Epoch: 0 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 10 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 20 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 30 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 40 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 50 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 60 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 70 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 80 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 90 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 100 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 110 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 120 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 130 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 140 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 150 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 160 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 170 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 180 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 190 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
The model learned the following values for weights and bias: 
OrderedDict([('weights', tensor([0.6990])), ('bias', tensor([0.3093]))])

And the original values for weights and bias are: 
weights: 0.7, bias: 0.3

Loss is the measure of how wrong your model is; the lower the loss, the better.

(Figure: the resulting training and test loss curves.)

If you've read this far, please give it a like to show your support~

