pytorch深度学习基础 8 (使用PyTorch的内置功能和默认参数来构建和训练一个简单的线性模型)

news2024/9/22 9:46:57

co

上面几节都是自定义了很多东西,比如模型的权重,偏置的大小,学习率,损失函数等等,但是实际上pytorch有很多内置的函数以及默认的参数可以对我们的模型部分进行替换,效果也是非常好的,今天我们就来试试使用PyTorch的内置功能和默认参数来构建和训练一个简单的线性模型

  1. 模型定义:使用nn.Linear(1, 1)定义了一个简单的线性模型,这意味着输入特征数量和输出特征数量都是1。这是一个非常基本的模型,非常适合入门级的线性回归任务。

  2. 优化器:使用了optim.SGD作为优化器,这是随机梯度下降的一个实现,适用于大多数优化问题。设置学习率为1e-2,这是一个相对较大的学习率,但在这种情况下可能仍然适用,因数据量和模型都很简单。

  3. 损失函数:选择了nn.MSELoss()作为损失函数,即均方误差损失,这是回归问题中最常用的损失函数之一。

  4. 内存(RAM)使用

    • PyTorch在默认情况下会管理内存使用,包括自动梯度计算所需的内存。当调用.backward()时,PyTorch会自动计算所有需要梯度的张量的梯度,并存储在内存中。
    • 在训练循环中,内存使用可能会随着数据批次的大小和模型复杂度的增加而增加。然而,由于数据和模型都非常简单,内存使用应该保持在较低水平。
  5. 默认参数

    • PyTorch中的许多函数和类都有默认参数,这些参数在大多数情况下都是合理的起点。例如,在nn.Linear中,默认不使用偏置项(但可以通过设置bias=True来启用),在optim.SGD中,默认学习率、动量等参数可以根据需要进行调整。
    • 在我的代码中,已经明确设置了学习率,但没有修改其他默认参数,这在大多数情况下是可行的。
  6. 训练循环

    • 训练循环,包括前向传播、计算损失、梯度归零、反向传播和参数更新。
    • 注意,在每次迭代时,对整个训练集进行了训练,这被称为批量梯度下降(Batch Gradient Descent)。在大数据集上,这可能会导致训练速度非常慢,因为每次迭代都需要计算整个数据集的梯度。然而,在数据集很小的情况下,这是可以接受的

大致的思想有了小编说一下这次的模型和之前模型的不同点

linear_model = nn.Linear(1, 1) 

nn.Linear(1, 1)是PyTorch中的一个线性层(全连接层),用于执行线性变换。具体来说,它将输入数据从一维空间映射到另一维空间。

参数解释

  • 1, 1:这两个参数分别表示输入特征的数量和输出特征的数量。在这个例子中,输入和输出特征的数量都是1。

功能

  • 线性变换nn.Linear(1, 1)执行的线性变换可以表示为 y=wx+b,其中 w 是权重,b 是偏置,x 是输入,y 是输出。
  • 权重和偏置:这个层会自动初始化权重 w 和偏置 b,并在训练过程中通过反向传播进行调整。

应用场景

  • 简单回归问题:在简单的线性回归问题中,可以使用nn.Linear(1, 1)来拟合一个输入和一个输出之间的关系。
  • 神经网络的一部分:在更复杂的神经网络中,nn.Linear(1, 1)可以作为网络的一个层,用于处理输入
    optimizer = optim.SGD( # 构建随机梯度下降(SGD)优化器
        linear_model.parameters(),
        lr=1e-2)   # linear_model.parameters(): # 这是一个生成器,返回线性模型中所有需要优化的参数
    

    定义了一个优化器,具体来说是使用随机梯度下降(SGD)优化器来优化一个线性模型的参数。linear_model.parameters(): 这是一个生成器,返回线性模型中所有需要优化的参数。这些参数在训练过程中会根据损失函数的梯度进行更新

  • 如果大家好奇这个默认的params的参数到底是什么呢,不妨运行一下下面的代码试试

  • print(list(linear_model.parameters()))
    

    最后的改变就是原来我们的损失函数是自己定义的函数,现在我们使用自带的函数进行替换,计算的结果是一样的

  • loss_fn = nn.MSELoss()

    综上所述,我们整合一下代码

  • import numpy as np
    import torch
    import torch.optim as optim
    
    torch.set_printoptions(edgeitems=2, linewidth=75)
    
    t_c = [0.5,  14.0, 15.0, 28.0, 11.0,  8.0,  3.0, -4.0,  6.0, 13.0, 21.0]
    t_u = [35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4]
    print(torch.tensor(t_c).shape)
    t_c = torch.tensor(t_c).unsqueeze(1) # 为每个张量增加一个维度
    t_u = torch.tensor(t_u).unsqueeze(1)
    
    # print(t_u)
    n_samples = t_u.shape[0]
    n_val = int(0.2 * n_samples)
    
    shuffled_indices = torch.randperm(n_samples)
    
    train_indices = shuffled_indices[:-n_val]
    val_indices = shuffled_indices[-n_val:]
    
    
    t_u_train = t_u[train_indices]
    t_c_train = t_c[train_indices]
    
    t_u_val = t_u[val_indices]
    t_c_val = t_c[val_indices]
    
    t_un_train = 0.1 * t_u_train
    t_un_val = 0.1 * t_u_val
    
    import torch.nn as nn
    
    linear_model = nn.Linear(1, 1) # 定义了一个简单的线性模型。nn.Linear是一个全连接层,用于实现线性变换
    linear_model(t_un_val)
    x = torch.ones(10, 1)
    linear_model(x)
    optimizer = optim.SGD( # 构建随机梯度下降(SGD)优化器
        linear_model.parameters(),
        lr=1e-2)   # linear_model.parameters(): 这是一个生成器,返回线性模型中所有需要优化的参数
    print(list(linear_model.parameters()))
    
    
    def training_loop(n_epochs, optimizer, model, loss_fn, t_u_train, t_u_val,
                      t_c_train, t_c_val):
        for epoch in range(1, n_epochs + 1):
            t_p_train = model(t_u_train)
            loss_train = loss_fn(t_p_train, t_c_train)
    
            t_p_val = model(t_u_val)
            loss_val = loss_fn(t_p_val, t_c_val)
    
            optimizer.zero_grad()
            loss_train.backward()
            optimizer.step() # 根据梯度和学习率等参数来更新模型的参数
    
            if epoch == 1 or epoch % 1000 == 0:
                print(f"Epoch {epoch}, Training loss {loss_train.item():.4f},"
                      f" Validation loss {loss_val.item():.4f}")
    
    
    
    training_loop(
        n_epochs = 3000,
        optimizer = optimizer,
        model = linear_model,
        loss_fn = nn.MSELoss(), # PyTorch 中用于计算均方误差(Mean Squared Error, MSE)的损失函数
        t_u_train = t_un_train,
        t_u_val = t_un_val,
        t_c_train = t_c_train,
        t_c_val = t_c_val)
    
    

从结果来看,效果还是很好的 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2074936.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

客户信任的秘密武器:为什么每个网站都需要SSL证书?

SSL证书,是网络安全的一把钥匙,它不仅能够锁住数据的安全,还能够建立起用户与网站之间的信任桥梁。在这个数字化日益发展的时代,每个网站都需要配备SSL证书,其背后的原因是多方面的,涉及到技术、安全、信任…

一文掌握数据要素、数据资源、数据资产、数字资产、数据管理、数据治理、数字资产入表是什么?以及关系

数据要素、数据资源、数据资产、数字资产、数据管理、数据治理、数字资产入表到底是什么呢?他们之间是什么关系呢? 数据要素是构建块,数据资源是这些构建块的集合,而数据资产则是具有价值的资源。数据管理和数据治理则确保这些数据…

Lesson 87 A car crash

Lesson 87 A car crash 词汇 attendant n. 接待员,随从 构成:attend v. 出席,参加    -ant / -ent 人 例如:student 学生    assistant 助理 相关:attendance n. 出勤率 例句:Conan以前是一个好接待…

【已解决】我可以再docker里面装Nginx,然后再Nginx下装java吗?

我可以再docker里面装Nginx,然后再Nginx下装java吗? Docker 是一个开源的应用容器引擎,它允许开发者打包他们的应用以及应用的运行环境到一个可移植的容器中,然后发布到任何流行的 Linux 机器上,也可以实现虚拟化。Docker 容器通常…

私域流量池|家政小程序开发,便捷服务新模式

随着时代的进步和需求的日益增长,家政服务行业也迎来了显著的发展提升。随着科技的不断发展,数字化已经成为各行各业的重要趋势。家政小程序因此而应运而生,成为提高家政服务效率的智能化工具。不仅满足了用户对服务的灵活性需求,…

2024年开发者必备的一款服务端组件

最新技术资源(建议收藏) https://www.grapecity.com.cn/resources/ 前言 在现代工作环境中,信息的处理和管理是至关重要的。表格是一种常见的数据呈现和整理工具,被广泛应用于各行各业。然而,随着技术的不断发展&…

NC设计LRU缓存结构

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 描述 设计LRU(最近…

leetcode215. 数组中的第K个最大元素,小根堆/快排思想

leetcode215. 数组中的第K个最大元素 给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。 请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。 你必须设计并实现时间复杂度为 O(n) 的算法解决此问题。…

vue面试集合

缓存 浏览器缓存和http缓存 浏览器缓存&#xff1a; 1&#xff0c;简单的缓存方式有cookie&#xff0c;localStorage和sessionStorage。 2&#xff0c;vue中keep-alive缓存动态组件&#xff1a; 全部缓存&#xff1a;使用<keep-alive>标签包裹缓存路由&#xff0c;ro…

JAVA电子器件制造行业生产管理系统计算机毕设计算机毕业设计

项目开发意义 目前小型企业基本上是采用人工完成生产及物料的车间计划,由于企业运作是以订单驱动而非计划生产,人工手段无法及时随新订单的到来更新计划,造成计划偏离实际;各个生产单位(车间)各自为战,分别提出物料、设备、专用工具的需求,在整个企业层面上很难较精确地控制物料…

机器学习:集成学习之随机森林

目录 前言 一、集成学习 1.集成学习的含义 2.集成学习的代表 3.集成学习的应用 二、随机森林 1.随机森林的特点 2.随机森林生成步骤 3.随机森林优点 4.随机森林的缺点 三、代码实现 1.完整代码 2.数据预处理 3.创建并训练模型 4.测试模型 总结 前言 随机森林是…

集合及数据结构第十二节(下)————哈希表、字符串常量池和练习题

系列文章目录 集合及数据结构第十二节&#xff08;下&#xff09;————哈希表、字符串常量池和练习题 哈希表、字符串常量池和练习题 哈希表的概念冲突-概念冲突-避免冲突-解决冲突严重时的解决办法冲突严重时的解决办法的实现性能分析和 java 类集的关系Hashmap的使用案…

R8RS标准之重要特性及用法实例(四十)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 新书发布&#xff1a;《Android系统多媒体进阶实战》&#x1f680; 优质专栏&#xff1a; Audio工程师进阶系列…

LDR6500Type-C pd OTGi协议芯片讲解

LDR6500是一款由乐得瑞科技推出的USB-C DRP&#xff08;Dual Role Port&#xff0c;双角色端口&#xff09;接口USB PD&#xff08;Power Delivery&#xff0c;功率传输&#xff09;通信芯片。这款芯片具备一系列先进的功能和特点&#xff0c;特别适合于手机音频转接器、USB Ty…

QT中引入SQLITE3数据库

1、把sqlite3.dll、.h、.lib这三个文件拷贝到工程目录下 2、在pro文件中配置一下即可 LIBS $$PWD/sqlite3.lib 3、保存一下pro文件 4、引入sqlite3.h头文件 5、验证 先新建一个文件夹data&#xff0c;若没有user.db&#xff0c;则会自动新建&#xff1b;有就直接使用 运行成…

UTONMOS:探索未来游戏的元宇宙纪元新篇章

元宇宙游戏&#xff0c;作为融合了虚拟现实&#xff08;VR&#xff09;、增强现实&#xff08;AR&#xff09;、区块链、人工智能&#xff08;AI&#xff09;等前沿技术的综合性数字世界&#xff0c;元宇宙游戏不仅重新定义了游戏的边界&#xff0c;更预示着一个沉浸式、交互性…

YOlOV5入门教程

前言 因项目需求&#xff0c;所以要使用yolo进行操作&#xff0c;现在对yolov5进行教程&#xff0c;代码可以在这下载&#xff1a;https://github.com/ultralytics/yolov5 项目结构 下载完成后可以看到资源如图所示。 1.1.github文件夹 ISSUE_TEMPLATE 目录 这个目录下的文件…

Cesium 展示——绘制水面动态升高

文章目录 需求分析需求 如图,绘制水面动态升高,作为洪水淹没的效果 分析 我们首先需要绘制一个面然后给这个面一个高度,在回调函数中进行动态设置值【这里有两种,一种是到达水面一定高度停止升高,一种是水面重新升高】/*** @description :洪水淹没* @author : Hukang*…

关闭IDEA启动画面

新版IDEA启动时启动画面居中且无法最小化&#xff0c;所以想把它给隐藏掉。&#xff08;此操作不会加快启动速度&#xff09; 在快捷方式后加入参数 nosplash&#xff0c;记得有个空格。