用Pytorch实现线性回归模型

news2025/1/22 23:34:56

目录

  • 回顾
  • Pytorch实现
    • 步骤
    • 1. 准备数据
    • 2. 设计模型
      • class LinearModel
      • 代码
    • 3. 构造损失函数和优化器
    • 4. 训练过程
    • 5. 输出和测试
    • 完整代码
  • 练习

回顾

前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。
这节课主要是利用Pytorch实现线性模型。
学习器训练:

  • 确定模型(函数)
  • 定义损失函数
  • 优化器优化(SGD)

之前用过Pytorch的Tensor进行Forward、Backward计算。
现在利用Pytorch框架来实现。

Pytorch实现

步骤

  1. 准备数据集
  2. 设计模型(计算预测值y_hat):从nn.Module模块继承
  3. 构造损失函数和优化器:使用PytorchAPI
  4. 训练过程:Forward、Backward、update

1. 准备数据

在PyTorch中计算图是通过mini-batch形式进行,所以X、Y都是多维的Tensor。
在这里插入图片描述

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

2. 设计模型

在之前讲解梯度下降算法时,我们需要自己计算出梯度,然后更新权重。
在这里插入图片描述
而使用Pytorch构造模型,重点时在构建计算图和损失函数上。
在这里插入图片描述

class LinearModel

通过构造一个 class LinearModel类来实现,所有的模型类都需要继承nn.Module,这是所有神经忘了模块的基础类。
class LinearModel这种定义的模型类必须包含两个部分:

  • init():构造函数,进行初始化。
    def __init__(self):
        super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。
        # torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元
        self.linear = torch.nn.Linear(1, 1)

在这里插入图片描述

  • forward():进行前馈计算
    (backward没有被写,是因为在这种模型类里面会自动实现)

Class nn.Linear 实现了magic method call():它使类的实例可以像函数一样被调用。通常会调用forward()。

    def forward(self, x):
        y_pred = self.linear(x)#调用linear对象,输入x进行预测
        return y_pred

代码

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。
        # torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元
        self.linear = torch.nn.Linear(1, 1)
    def forward(self, x):
        y_pred = self.linear(x)#调用linear对象,输入x进行预测
        return y_pred

model = LinearModel()#实例化LinearModel()

3. 构造损失函数和优化器

采用MSE作为损失函数

torch.nn.MSELoss(size_average,reduce)

  • size_average:是否求mini-batch的平均loss。
  • reduce:降维,不用管。

在这里插入图片描述SGD作为优化器torch.optim.SGD(params, lr):

  • params:参数
  • lr:学习率

在这里插入图片描述

criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b

4. 训练过程

  1. 预测
  2. 计算loss
  3. 梯度清零
  4. Backward
  5. 参数更新
    简化:Forward–>Backward–>更新
#4. Training Cycle
for epoch in range(100):
    y_pred = model(x_data)#Forward:预测
    loss = criterion(y_pred, y_data)#Forward:计算loss
    print(epoch, loss)
    optimizer.zero_grad()#梯度清零
    loss.backward()#backward:计算梯度
    optimizer.step()#通过step()函数进行参数更新

5. 输出和测试

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

完整代码

import torch
#1. Prepare dataset
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

#2. Design Model
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。
        # torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元
        self.linear = torch.nn.Linear(1, 1)
    def forward(self, x):
        y_pred = self.linear(x)#调用linear对象,输入x进行预测
        return y_pred

model = LinearModel()#实例化LinearModel()

# 3. Construct Loss and Optimize
criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b

#4. Training Cycle
for epoch in range(100):
    y_pred = model(x_data)#Forward:预测
    loss = criterion(y_pred, y_data)#Forward:计算loss
    print(epoch, loss)
    optimizer.zero_grad()#梯度清零
    loss.backward()#backward:计算梯度
    optimizer.step()#通过step()函数进行参数更新

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

输出结果:

85 tensor(0.2294, grad_fn=)
86 tensor(0.2261, grad_fn=)
87 tensor(0.2228, grad_fn=)
88 tensor(0.2196, grad_fn=)
89 tensor(0.2165, grad_fn=)
90 tensor(0.2134, grad_fn=)
91 tensor(0.2103, grad_fn=)
92 tensor(0.2073, grad_fn=)
93 tensor(0.2043, grad_fn=)
94 tensor(0.2014, grad_fn=)
95 tensor(0.1985, grad_fn=)
96 tensor(0.1956, grad_fn=)
97 tensor(0.1928, grad_fn=)
98 tensor(0.1900, grad_fn=)
99 tensor(0.1873, grad_fn=)
w = 1.711882472038269
b = 0.654958963394165
y_pred = tensor([[7.5025]])

可以看到误差还比较大,可以增加训练轮次,训练1000次后的结果:

980 tensor(2.1981e-07, grad_fn=)
981 tensor(2.1671e-07, grad_fn=)
982 tensor(2.1329e-07, grad_fn=)
983 tensor(2.1032e-07, grad_fn=)
984 tensor(2.0737e-07, grad_fn=)
985 tensor(2.0420e-07, grad_fn=)
986 tensor(2.0143e-07, grad_fn=)
987 tensor(1.9854e-07, grad_fn=)
988 tensor(1.9565e-07, grad_fn=)
989 tensor(1.9260e-07, grad_fn=)
990 tensor(1.8995e-07, grad_fn=)
991 tensor(1.8728e-07, grad_fn=)
992 tensor(1.8464e-07, grad_fn=)
993 tensor(1.8188e-07, grad_fn=)
994 tensor(1.7924e-07, grad_fn=)
995 tensor(1.7669e-07, grad_fn=)
996 tensor(1.7435e-07, grad_fn=)
997 tensor(1.7181e-07, grad_fn=)
998 tensor(1.6931e-07, grad_fn=)
999 tensor(1.6700e-07, grad_fn=)
w = 1.9997280836105347
b = 0.0006181497010402381
y_pred = tensor([[7.9995]])

练习

用以下这些优化器替换SGD,得到训练结果并画出损失曲线图。
在这里插入图片描述
比如说:Adam的loss图:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1387329.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是TestNG以及如何创建testng.xml文件?

目录 什么是TestNG? 如何创建testng.xml文件 手动创建testng.xml 通过testng.xml运行整个包 通过testng.xml运行类 使用Eclipse创建testng.xml 本文将讨论TestNG以及如何通过执行testng.xml文件在TestNG中运行第一个测试用例。 什么是TestNG? Te…

2.1.2 一个关于y=ax+b的故事

跳转到根目录:知行合一:投资篇 已完成: 1、投资&技术   1.1.1 投资-编程基础-numpy   1.1.2 投资-编程基础-pandas   1.2 金融数据处理   1.3 金融数据可视化 2、投资方法论   2.1.1 预期年化收益率   2.1.2 一个关于yaxb的…

Unity游戏图形学 Shader结构

shader结构 shader语言 openGL:SLG跨平台 >GLSL:openGL shaderlauguge DX:微软开发,性能很好,但是不能跨平台 >HLSL:high level shader language CG:微软和Nvidia公司联合开发&#xff…

2024年腾讯云新用户优惠云服务器价格多少?

腾讯云服务器租用价格表:轻量应用服务器2核2G3M价格62元一年、2核2G4M价格118元一年,540元三年、2核4G5M带宽218元一年,2核4G5M带宽756元三年、轻量4核8G12M服务器446元一年、646元15个月,云服务器CVM S5实例2核2G配置280.8元一年…

统计学-R语言-3

文章目录 前言给直方图增加正态曲线的不恰当之处直方图与条形图的区别核密度图时间序列图洛伦茨曲线计算绘制洛伦茨曲线所需的各百分比数值绘制洛伦茨曲线 练习 前言 本篇文章是介绍对数据的部分图形可视化的图型展现。 给直方图增加正态曲线的不恰当之处 需要注意的是&#…

项目解决方案:多个分厂的视频监控汇聚到总厂

目 录 1、概述 2、建设目标及需求 2.1 建设目标 2.2 需求描述 2.3 需求分析 3. 设计依据与设计原则 3.1 设计依据 3.2设计原则 1、先进性与适用性 2、经济性与实用性 3、可靠性与安全性 4、开放性 5、可扩充性 6、追求最优化的系统设备配置…

【数据结构】C语言实现共享栈

共享栈的C语言实现 导言一、共享栈1.1 共享栈的初始化1.2 共享栈的判空1.3 共享栈的入栈1.3.1 空指针1.3.2 满栈1.3.3 入栈空间错误1.3.4 正常入栈1.3.5 小结 1.4 共享栈的查找1.5 共享栈的出栈1.6 共享栈的销毁 二、共享栈的实现演示结语 导言 大家好,很高兴又和大…

JVM-Arthas高效的监控工具

一、arthas介绍 3.选择监控哪个进程 4.进入具体进程 二、arthas的基础命令与基本操作 1.查询包含Java的系统属性: 命令:sysprop |grep java 1.查询不含Java的系统属性: 命令:sysprop | grep -v java 3.打印历史命令 命令&#…

排序算法之八:计数排序

1.计数排序思想 计数排序,顾名思义就是计算数据的个数 计数排序又称非比较排序 思想:计数排序又称为鸽巢原理,是对哈希直接定址法的变形应用。 操作步骤: 统计相同元素出现次数 根据统计的结果将序列回收到原来的序列中 计数…

20240115如何在线识别俄语字幕?

20240115如何在线识别俄语字幕? 2024/1/15 21:25 百度搜索:俄罗斯语 音频 在线识别 字幕 Bilibili:俄语AI字幕识别 音视频转文字 字幕小工具V1.2 BING:音视频转文字 字幕小工具V1.2 https://www.bilibili.com/video/BV1d34y1F7…

嵌入式软件工程师面试题——2025校招社招通用(十八)

说明: 面试群,群号: 228447240面试题来源于网络书籍,公司题目以及博主原创或修改(题目大部分来源于各种公司);文中很多题目,或许大家直接编译器写完,1分钟就出结果了。但…

sqli-labs关卡23(基于get提交的过滤注释符的联合注入)

文章目录 前言一、回顾前几关知识点二、靶场第二十三关通关思路1、判断注入点2、爆数据库名3、爆数据库表4、爆数据库列5、爆数据库关键信息 总结 前言 此文章只用于学习和反思巩固sql注入知识,禁止用于做非法攻击。注意靶场是可以练习的平台,不能随意去…

SQL-用户管理与用户权限

🎉欢迎您来到我的MySQL基础复习专栏 ☆* o(≧▽≦)o *☆哈喽~我是小小恶斯法克🍹 ✨博客主页:小小恶斯法克的博客 🎈该系列文章专栏:重拾MySQL 🍹文章作者技术和水平很有限,如果文中出现错误&am…

【JupyterLab】在 conda 虚拟环境中 JupyterLab 的安装与使用

【JupyterLab】在 conda 虚拟环境中 JupyterLab 的安装与使用 1 JupyterLab 介绍2 安装2.1 Jupyter Kernel 与 conda 虚拟环境 3 使用3.1 安装中文语言包(Optional)3.2 启动3.3 常用快捷键3.3.1 命令模式下 3.4 远程访问个人计算机3.4.1 局域网下 1 JupyterLab 介绍 官方文档: …

鸿蒙开发笔记(一):ArkTS概述及声明式UI的使用

ArkTS是HarmonyOS优选的主力应用开发语言。ArkTS围绕应用开发在TypeScript(简称TS)生态基础上做了进一步扩展,继承了TS的所有特性,是TS的超集。 ArkTS在TS的基础上主要扩展了如下能力: 基本语法:ArkTS定义…

给 Linux 主机添加 SSH 双因子认证

GitHub:https://github.com/google/google-authenticator-android 在信息时代,服务器安全愈发成为首要任务。Linux 主机通过 ssh 方式连接,当存在弱密码的情况下,通过暴力破解的方式会很容易就被攻破了,本文将向你展示…

一文搞懂系列——Linux C线程池技术

背景 最近在走读诊断项目代码时,发现其用到了线程池技术,感觉耳目一新。以前基本只是听过线程池,但是并没有实际应用。对它有一丝的好奇,于是趁这个机会深入了解一下线程池的实现原理。 线程池的优点 线程池出现的背景&#xf…

Lede(OpenWrt)安装和双宽带叠加

文章目录 一、Lede介绍1. 简介2. 相关网站 二、Lede安装1. 编译环境2. SHELL编译步骤3. 腾讯云自动化助手 三、Lede配置1. 电信接口配置2. 联通接口配置3. 多线多播配置4. 网速测试效果 一、Lede介绍 1. 简介 LEDE是一个专为路由器和嵌入式设备设计的自由和开源的操作系统。 …

HTML--JavaScript--引入方式

啊哈~~~基础三剑看到第三剑,JavaScript HTML用于控制网页结构 CSS用于控制网页的外观 JavaScript用于控制网页的行为 JavaScript引入方式 引入的三种方式: 外部JavaScript 内部JavaScript 元素事件JavaScript 引入外部JavaScript 一般情况下网页最好…