【youcans的深度学习 D01】PyTorch例程:从极简线性模型开始

news2025/3/1 1:28:57

欢迎关注『youcans的深度学习』系列


【youcans的深度学习 D01】PyTorch 例程:从极简线性模型开始

    • 1. PyTorch 建模的基本步骤
    • 2. 线性模型的结构
    • 3. 建立 PyTorch 线性模型
      • 3.1 准备数据集
      • 3.2 定义线性模型类
      • 3.3 建立一个线性模型
      • 3.4 模型训练
      • 3.5 模型推断
    • 4. PyTorch 线性模型例程
    • 5. 小结


在前面的章节我们已经介绍了 Pytorch 中的数据加载和模型建立。

本节在此基础上,实现一个最简单的案例:创建一个线性网络模型。该模型虽然极简,但已经完整包括了 PyTorch 建立模型、模型训练和模型预测的关键步骤,有助于读者理解 PyTorch 深度学习的基本方法。


1. PyTorch 建模的基本步骤

使用 PyTorch 建立、训练和使用神经网络模型的基本步骤如下。

  1. 准备数据集(Prepare dataset)。
  2. 创建网络模型(Design model using Class)。
  3. 定义损失函数和优化器(Construct loss and optimizer)。
  4. 模型训练(Trainning the model)。
  5. 模型测试(Testing the model)。
  6. 模型保存与加载(Saving and loading a model)。
  7. 模型推理(Inferring model)。

2. 线性模型的结构

对于线性模型
y = w i ∗ x i + b ,   i = 1 , . . . n y = w_i * x_i + b, \space i=1,...n y=wixi+b, i=1,...n

在这里插入图片描述


3. 建立 PyTorch 线性模型

3.1 准备数据集

简单地,考虑 n=1 时的单输入单输出系统:
y = w ∗ x + b y = w * x + b y=wx+b

随机生成一组样本输入值,试验得到对应的样本输出值,构造训练样本数据集 {Xtrain, Ytrain}。

PyTorch 中的基本数据结构是张量 Tensor,训练样本数据集的输入 Xtrain 和输出 Ytrain 的类型都是 Tensor。

# (1) 建立样本数据集
Ntrain = 100  # 训练样本数目
Xtrain = torch.rand(Ntrain, 1)  # 输入值,均匀分布
noise = torch.normal(0.0, 0.1, Xtrain.shape)  # 高斯分布,加性噪音
Ytrain = 3.6 + 2.5 * Xtrain + noise  # 模拟试验的输出值
print("Xtrain.shape: {}, Ytrain.shape: {}".format(Xtrain.shape, Ytrain.shape))

3.2 定义线性模型类

使用 PyTorch 构造神经网络模型,需要运用__call__()__init__()方法定义模型类 Class。__init__()方法是类的初始化函数,类似于C++的构造函数。__call__()方法使类对象具有类似函数的功能。

# (2) 定义线性模型类
class LinearModel(torch.nn.Module):
    def __init__(self):  # 构造函数
        super(LinearModel, self).__init__()
        # 构造 linear 对象,并说明输入输出的维数,第三个参数默认为 true
        self.linear = torch.nn.Linear(1, 1)  # 包括 2 个参数 weight 和 bias

    def forward(self, x):  # 重写 forward 函数
        y_pred = self.linear(x)  # 可调用对象,计算 y=wx+b
        return y_pred

nn.Module 是所有神经网络单元(neural network modules)的基类。PyTorch在nn.Module中实现了__call__()方法,在 __call__() 方法中调用 forward 函数。

nn.Linear 定义一个神经网络的线性层 ,方法如下:

torch.nn.Linear(in_features, # 输入的神经元个数
            out_features, # 输出神经元个数
            bias=True # 是否包含偏置
            )

表示对输入 X n ∗ i X_{n*i} Xni 进行线性加权求和:
Y n ∗ o = X n ∗ i W i ∗ o + b Y_{n*o} = X_{n*i} W{i*o} + b Yno=XniWio+b

本例中输入 x 和输出 y 都是一维,nn.Linear(1, 1) 中的参数 (1, 1) 表示 i=o=1 。


3.3 建立一个线性模型

建立线性模型,包括三个步骤:

  • 实例化线性模型对象
  • 设置损失函数 Loss
  • 设置优化器 optim
    # (3) 实例化线性模型对象
    model = LinearModel()  # 实例化 LinearModel
    # (4) 构造损失函数 Loss
    criterion = torch.nn.MSELoss(reduction='sum')
    # (5) 构造优化器 optim
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # lr为学习率
    # 优化器对象创建时需要传入模型参数 model.parameters(),将扫描 module中的所有成员

torch.nn.functional 模块包含内置损失函数,均方误差损失为 mse_loss。

torch.optim.SGD 表示随机梯度下降优化器,注意要将 model 的参数 model.parameters() 传给优化器对象,以便 SGD 优化器扫描需要优化的参数。


3.4 模型训练

模型训练的基本步骤是:

  1. 前馈计算模型的输出值;
  2. 计算损失函数值;
  3. 计算权重 weight 和偏差 bias 的梯度;
  4. 根据梯度值调整模型参数;
  5. 将梯度重置为 0(用于下一循环)。
    # (6) 模型训练
    epoch_list = []
    loss_list = []
    for epoch in range(100):
        Ypred = model(Xtrain)  # 前馈计算模型输出值
        loss = criterion(Ypred, Ytrain)  # 计算损失函数
        optimizer.zero_grad()  # 梯度归零
        loss.backward()  # 误差反向传播
        optimizer.step()  # 权重更新
        epoch_list.append(epoch)  # 记录迭代次数
        loss_list.append(loss.item())  # 记录损失函数

        if epoch%10==0:
            print("Epoch {}: loss={:.4f}".format(epoch, loss.item()))

3.5 模型推断

    # (7) 模型预测
    Ypred = model(Xtrain)

4. PyTorch 线性模型例程

# DNNdemo01_v1.py
# DNNdemo of PyTroch: Linear regression
# PyTorch 例程: 01 线性回归
# Copyright: youcans@qq.com
# Crated: Huang Shan, 2023/03/11

import torch
import matplotlib.pyplot as plt

# (2) 定义线性模型类
class LinearModel(torch.nn.Module):
    def __init__(self):  # 构造函数
        super(LinearModel, self).__init__()
        # 构造 linear 对象,并说明输入输出的维数,第三个参数默认为 true
        self.linear = torch.nn.Linear(1, 1)  # 包括 2 个参数 weight 和 bias

    def forward(self, x):  # 重写 forward 函数
        ypred = self.linear(x)  # 可调用对象,计算 y=wx+b
        return ypred


if __name__ == "__main__":
    # (1) 建立样本数据集
    Ntrain = 50  # 训练样本数目
    Xtrain = torch.rand(Ntrain, 1)  # 输入值,均匀分布
    noise = torch.normal(0.0, 0.1, Xtrain.shape)  # 高斯分布,加性噪音
    Ytrain = 2.0 * Xtrain + 3.6 + noise  # 模拟试验的输出值
    print("Xtrain.shape: {}, Ytrain.shape: {}".format(Xtrain.shape, Ytrain.shape))

    # (3) 实例化线性模型对象
    model = LinearModel()  # 实例化 LinearModel

    # (4) 设置损失函数 Loss 和 优化器 optim
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # lr为学习率
    # 优化器对象创建时需要传入模型参数 model.parameters(),将扫描 module中的所有成员

    # (5) 模型训练
    epoch_list = []
    loss_list = []
    for epoch in range(10):  # 训练轮次
        Ypred = model(Xtrain)  # 前馈计算模型输出值
        loss = criterion(Ypred, Ytrain)  # 计算损失函数
        optimizer.zero_grad()  # 梯度归零
        loss.backward()  # 误差反向传播
        optimizer.step()  # 权重更新
        epoch_list.append(epoch)  # 记录迭代次数
        loss_list.append(loss.item())  # 记录损失函数

        if epoch%10==0:
            print("Epoch {}: loss={:.4f}".format(epoch, loss.item()))

    # 迭代训练后的模型参数 weight 和 bias
    print("Linear Model: y = w * x + b")
    print('w = ', model.linear.weight.item())
    print('b = ', model.linear.bias.item())

    # (6) 模型预测
    Ypred = model(Xtrain)
    print("Xtest.shape: {}, Ytest.shape: {}".format(Xtrain.shape, Ypred.shape))

    # (7) 绘图
    plt.figure(figsize=(9, 4))
    plt.suptitle("Response curve of activation function")
    plt.subplot(121)
    plt.plot(Xtrain.numpy(), Ytrain.numpy(), 'ro', label="train")
    plt.plot(Xtrain.numpy(), Ypred.detach().numpy(), 'bx', label="model")
    plt.xlabel('x'), plt.ylabel('y'), plt.legend()
    plt.title("Prediction of linear model")
    plt.subplot(122)
    plt.plot(epoch_list, loss_list)
    plt.xlabel('times'), plt.ylabel('loss')
    plt.title("Loss of Linear Model")
    plt.show()

运行结果:

Xtrain.shape: torch.Size([20, 1]), Ytrain.shape: torch.Size([20, 1])
Epoch 0: loss=309.5900
Epoch 10: loss=0.3274
Epoch 20: loss=0.2995
Epoch 30: loss=0.2807
Epoch 40: loss=0.2678
Epoch 50: loss=0.2590
Epoch 60: loss=0.2530
Epoch 70: loss=0.2489
Epoch 80: loss=0.2461
Epoch 90: loss=0.2442

Linear Model: y = w * x + b
w = 1.989990234375
b = 3.5692780017852783


在这里插入图片描述


5. 小结

本节实现了一个最简单的案例:创建一个线性网络模型。

该模型虽然极简,但已经完整包括了 PyTorch 建立模型、模型训练和模型预测的基本步骤。

本文中简化了很多环节,是为了便于读者理解和掌握 PyTorch 深度学习的核心步骤。

【本节完】


版权声明:
欢迎关注『youcans的深度学习』系列,转发请注明原文链接:
【youcans的深度学习 D01】PyTorch例程:极简线性模型
Copyright 2023 youcans, XUPT
Crated:2023-04-18


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/435889.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java-处理xml格式数据

处理xml格式数据 前言一、java处理xml格式数据1、 生成XML格式数据2、 解析XML格式数据 二、问题三、常用类及方法介绍 前言 dom4j是java中的XML API,性能优异、功能强大、开放源代码。 也是所有解析XML文件方法中最常用的! 一、java处理xml格式数据 …

榜单发布 新能源乘用车OBC赛道进入转型升级周期

新能源汽车尤其是纯电动汽车市场的快速普及,也带动一批相关核心零部件厂商做大做强。比如,以车载充电机OBC及集成电源行业为例,威迈斯、富特科技等数家公司正在冲刺IPO。 目前,车载电源领域产品主要分为三种:一是单一…

步入AIGC时代,展望人工智能发展

步入AIGC时代,展望人工智能发展 0. 前言1. 步入 AIGC 时代1.1 人工智能简介1.2 AIGC 简介1.3 AIGC 发展与应用 2. CSIG 企业行——走进合合信息2.1 活动介绍2.2 走进合合信息 3. 文档图像处理中的底层视觉技术3.1 什么是底层视觉3.2 智能图像处理技术3.3 智能图像处…

消息中间件RabbitMQ---概述和概念 【一】

1、概述 1、大多应用中,可通过消息服务中间件来提升系统异步通信、扩展解耦能力 2、消息服务中两个重要概念: 消息代理(message broker)和目的地(destination) 当消息发送者发送消息以后,将由…

C语言中数据结构——顺序表

🐶博主主页:ᰔᩚ. 一怀明月ꦿ ❤️‍🔥专栏系列:线性代数,C初学者入门训练,题解C,C的使用文章,「初学」C 🔥座右铭:“不要等到什么都没有了,才下…

java多线程详细讲解 线程的创建、线程的状态、synchronized锁、Volatile关键字、和cas锁(自旋锁 乐观锁 无锁)

java多线程详细讲解 线程的创建、线程的状态、synchronized锁、Volatile关键字、和cas锁(自旋锁 乐观锁 无锁) 一、线程的概念二、创建线程的三种方式三、线程方法Sleep、Yield、Join四、线程的执行状态五、synchronized关键字1.为什么要上锁?2.锁定的内…

SDL初识(1)

简介 SDL(Simple DirectMedia Layer) 是一个跨平台开发库,旨在通过 OpenGL 和 Direct3D 提供对音频、键盘、鼠标、操纵杆和图形硬件的低级访问。 SDL 支持 Windows、Mac OS X、Linux、iOS 和 Android。可以在源代码中找到对其他平台的支持。SDL 是用 C 语言编写的…

JavaScript【六】JavaScript中的字符串(String)

文章目录 🌟前言🌟字符串(String)🌟单引号和双引号的区别🌟属性🌟 length :字符串的长度 🌟 方法🌟 str.charAt(index);🌟 str.charCodeAt(index);🌟 String.fromCharCod…

死磕“增长”:火山引擎的实用主义

作者 | 曾响铃 文 | 响铃说 在刘慈欣的科幻小说《三体》中,地外文明为了封锁地球科技,在天文台向地球科学家展现了「宇宙闪烁」这一奇观,试图颠覆人类的认知,从而影响科技进步,促使地球科技发展陷入停滞。 如今&…

给你们讲个笑话——低代码会取代程序员

今天是正经男,我们严肃讨论一下一直以来争吵不休的取代问题。 低代码开发平台,低代码技术会取代开发人员么? 一、背景 低代码开发平台的普及,让很多公司对快速生成应用抱有很大期望。甚至有人认为,低代码开发平台未来…

MTLAB绘图

这里写目录标题 一、图例1、散点图 二、绘图1、总体图形参数2、坐标、图框、网格图框去上右边框小刻度网格坐标范围和刻度控制旋转 坐标、刻度 3、图例图例位置和方向 Location和Orientation图例加标题 、分多列 4、文本 字、字体、字号5、线型 符号6、颜色栏 colorbar7、颜色8…

【技能分享】CAD转SHP最好的方法

1、利用 ArcToolsbox 工具先将 DWG 文件转为 MDB 通过 CASS 软件生成的 DWG 文件,字段中包含有很多属性内容,所以我们先将 DWG 格式 的文件转换为 MDB 格式,再通过 MDB 转换为 SHP 格式数据进行整理。具体步骤如下: 通过 ArcTool…

2023Mathorcup高校数学建模挑战赛ABCD选题建议

提示&#xff1a;本科同学尽量选择C、D题进行作答&#xff0c;获奖率相对会高。C君认为的难度&#xff1a;AD<C<B&#xff0c;开放度&#xff1a;B<C<A<D 。 A题 量子计算机在信用评分卡组合优化中的应用 这道题目是传统的运筹学题目。需要建立客户信用等级的…

阿里ARouter 路由框架解析

一、简介 众所周知&#xff0c;在日常开发中&#xff0c;随着项目业务越来越复杂&#xff0c;项目中的代码量也越来越多&#xff0c;如果维护、扩展、解耦等成了一个非常头疼问题&#xff0c;随之孕育而生的诸如插件化、组件化、模块化等热门技术。 而其中组件化中一项的难点&…

Spring Cloud 之五:Feign使用Hystrix

系列目录&#xff08;持续更新。。。&#xff09; Spring Cloud之一&#xff1a;注册与发现-Eureka工程的创建 Spring Cloud之二&#xff1a;服务提供者注册到Eureka Server Spring Cloud之三&#xff1a;Eureka Server添加认证 Spring Cloud之四&#xff1a;使用Feign实现…

camunda如何监控流程执行

在 Camunda 中&#xff0c;可以使用 Camunda 提供的用户界面和 API 来监控流程的执行情况。以下是几种常用的监控流程执行的方式&#xff1a; 1、使用 Camunda Cockpit&#xff1a;Camunda Cockpit 是 Camunda 官方提供的流程监控和管理工具&#xff0c;可以在浏览器中访问 Co…

【百面成神】消息中间件基础7问,你能撑到第几问

前 言 &#x1f349; 作者简介&#xff1a;半旧518&#xff0c;长跑型选手&#xff0c;立志坚持写10年博客&#xff0c;专注于java后端 ☕专栏简介&#xff1a;纯手打总结面试题&#xff0c;自用备用 &#x1f330; 文章简介&#xff1a;消息中间件最基础、重要的9道面试题 文章…

Android中的MVVM架构:使用Jetpack组件实现现代化的应用架构

Android中的MVVM架构&#xff1a;使用Jetpack组件实现现代化的应用架构 Jetpack组件是构建现代Android应用的绝佳利器&#xff0c;组件化设计让构建App如此简单。 引言 随着移动应用的日益复杂和功能的不断增加&#xff0c;构建稳健、可扩展和易维护的Android应用变得越来越重…

[考研数据结构] 第3章之队列的基本知识与操作

文章目录 队列的基本概念 队列的顺序存储 顺序队列 存储类型 基本操作 循序队列 存储类型 基本操作 循环队列判空与判满的三种解决方案 方法一&#xff1a;牺牲一个存储单元 方法二&#xff1a;类型增设记录型变量size 方法三&#xff1a;类型增设标志型变量tag 队…

嵌入式【协议篇】CAN协议原理

一、CAN协议介绍 1、简介 CAN是控制器局域网络(Controller Area Network, CAN)的简称,是一种能够实现分布式实时控制的串行通信网络。 其实可以简单把CAN通信理解成开一场电话会议,当一个人讲话时其他人就听(广播),当多个人同时讲话时则根据一定规则来决定谁先讲话谁后讲…