从零开始学习线性回归:理论、实践与PyTorch实现

news2025/1/21 22:03:00

文章目录

  • 🥦介绍
  • 🥦基本知识
  • 🥦代码实现
  • 🥦完整代码
  • 🥦总结

🥦介绍

线性回归是统计学和机器学习中最简单而强大的算法之一,用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性回归的理论基础、数学公式以及如何使用PyTorch实现一个简单的线性回归模型。

🥦基本知识

线性回归的数学基础
线性回归的核心思想是建立一个线性方程,它表示了自变量(输入特征)与因变量(输出)之间的关系。这个线性方程通常表示为:
y=b0+b1x1+b2x2+…+bpxp
y=b0​+b1​x1​+b2​x2​+…+bp​xp​

其中, y y y 是因变量, x 1 , x 2 , … , x p x_1, x_2, \ldots, x_p x1,x2,,xp 是自变量, b 0 , b 1 , b 2 , … , b p b_0, b_1, b_2, \ldots, b_p b0,b1,b2,,bp 是模型的参数, p p p 是特征的数量。我们的目标是找到最佳的参数值,以最小化模型的误差。

损失函数
为了找到最佳参数,我们需要定义一个损失函数来度量模型的性能。在线性回归中,最常用的损失函数是均方误差(MSE),它表示了模型预测值与实际值之间的平方差的平均值:
在这里插入图片描述
其中, n n n 是样本数量, y i y_i yi 是实际值, y ^ i \hat{y}_i y^i 是模型的预测值。

梯度下降优化
为了最小化损失函数,我们使用梯度下降算法。梯度下降通过计算损失函数相对于参数的梯度,并迭代地更新参数,以减小损失。更新规则如下:
在这里插入图片描述
其中, b j b_j bj 是第 j j j个参数, α \alpha α 是学习率, ∂ ∂ b j M S E \frac{\partial}{\partial b_j} MSE bjMSE 是损失函数对参数 b j b_j bj的偏导数。

🥦代码实现

如果你想知道实现线性回归的大体步骤,下图可以充分进行说明
在这里插入图片描述

  • 准备数据
  • 设计模型(计算) y ^ i \hat{y}_i y^i
  • 构造损失和优化器
  • 训练周期(前向,反向 ,更新)

本节还是以刘二大人的视频讲解为例,结尾会设置传送门

class LinearModel(torch.nn.Module):
	def __init__(self):
		super(LinearModel, self).__init__() # 调用父类的构造函数
		self.linear = torch.nn.Linear(1, 1)  # 参数详情下图展示
	def forward(self, x):
		y_pred = self.linear(x)   # x代表输入样本的张量
		return y_pred
model = LinearModel()

所以模型类都要继承Module,此类主要包含两个函数一个是构造函数(初始化对象时调用),另一个是前向计算

好奇的小伙伴会思考为何没有反向(backward),这是因为Module会帮你进行,但是如果后期自己有更高效的方法可以自行设置。
在这里插入图片描述

  • 第一个参数 in_features:这是输入特征的数量。在这里,表示我们的模型只有一个输入特征。如果你有多个输入特征,你可以将这个参数设置为输入特征的数量。

  • 第二个参数 out_features:这是输出特征的数量。这表示我们的模型将生成一个输出。在线性回归中,通常只有一个输出,因为我们试图预测一个连续的数值。

  • 第三个参数:意思是要不要偏置量。默认true

通常情况下特征代表列,比如我们有一个n×2的y和一个n×3的x,那么我们需要一个3×2的权重,有的书中会在两边做转置,但无论咋样目的都是为了让这个矩阵乘法成立

criterion = torch.nn.MSELoss(size_average=False)  # 使用均方误差损失 
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 使用随机梯度下降优化器

在这里插入图片描述
在这里插入图片描述
model.parameters() 用于告诉优化器哪些参数需要在训练过程中进行更新,这包括模型的权重和偏置项等。在线性回归示例中,模型的参数包括权重和偏置项。

优化器的选择有许多大家可以都试试看看
在这里插入图片描述

之后就进行训练了

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data) 
    print(epoch, loss.item())
    optimizer.zero_grad()   # 归零
    loss.backward()  # 反向
    optimizer.step()  # 更新
print('w = ', model.linear.weight.item()) 
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]]) 
y_test = model(x_test)
print('y_pred = ', y_test.data)

🥦完整代码

x_data = torch.Tensor([[1.0], [2.0], [3.0]]) 
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__() 
        self.linear = torch.nn.Linear(1, 1)
    def forward(self, x):
        y_pred = self.linear(x) 
        return y_pred
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data) 
    print(epoch, loss.item())
    optimizer.zero_grad() 
    loss.backward()
    optimizer.step()
print('w = ', model.linear.weight.item()) 
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]]) 
y_test = model(x_test)
print('y_pred = ', y_test.data)
predicted = model(x_data).detach().numpy()
plt.scatter(x_data, y_data, label='Original data')
plt.plot(x_data, predicted, label='Fitted line', color='r')
plt.legend()
plt.show()

运行结果如下
在这里插入图片描述

在这里插入图片描述

🥦总结

在本篇博客中,我们使用PyTorch实现了一个简单的线性回归模型,并使用随机生成的数据对其进行了训练和可视化。线性回归是一个入门级的机器学习模型,但它为理解模型训练和预测的基本概念提供了一个很好的起点。

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1063886.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Texifier 专业打造,让你的LaTeX编辑更高效!

作为LaTeX用户,你一定知道寻找一款优秀的编辑工具是多么重要。而Texifier(原Texpad)就是你在Mac上寻找的完美解决方案!它是一款专业的LaTeX编辑工具,为你带来高效、便捷的编辑体验。 Texifier拥有丰富的功能&#xff…

Linux网络编程系列之UDP协议编程

一、什么是UDP协议 UPD协议(User Datagram Protocol,用户数据报协议)是Internet协议族中的一个无连接协议,属于传输层,它不保证数据传输的可靠性或完整性,只是把应用程序发给网络层的数据封装成数据包进行传…

VL53L5CX驱动开发(1)----驱动TOF进行区域检测

VL53L5CX驱动开发----1.驱动TOF进行区域检测 闪烁定义视频教学样品申请源码下载主要特点硬件准备技术规格系统框图应用示意图区域映射生成STM32CUBEMX选择MCU 串口配置IIC配置X-CUBE-TOF1串口重定向代码配置Tera Term配置演示结果 闪烁定义 VL53L5CX是一款先进的飞行感应&…

【C语言】利用数组处理批量数据(字符数组)

前言:前面已经介绍了,字符数据是以字符的ASCII代码存储在存储单元中的,一般占一个字节。由于ASCII代码也属于整数形式,因此在C99标准中,把字符类型归纳为整型类型中的一种。 💖 博主CSDN主页:卫卫卫的个人主页 &#x…

使用ebpf 监控linux内核中的nat转换

1.简介 Linux NAT(Network Address Translation)转换是一种网络技术,用于将一个或多个私有网络内的IP地址转换为一个公共的IP地址,以便与互联网通信。 在k8s业务场景中,业务组件之间的关系十分复杂. 由于 Kubernete…

浅谈在操控器类中,为何要通过osgGA::CameraManipulator的逆矩阵改变视点位置

在osg代码目录下的include\osgGA目录存放了很多osg自带的操控器类,这些操控器类都派生自osgGA::CameraManipulator,而这个CameraManipulator又派生自osgGA::GUIEventHandler,可见其本质上是个事件处理类。因此它首先会接收事件,比…

月薪20k的软件测试工程师都要具备什么能力?你跟大佬的差距在哪?

第一,强大的业务能力:很熟悉业务流程,熟悉业务模块、数据、架构,测试所需资源。了解测试所需时间。 第二,发现bug能力:一般问题发现的能力,隐性问题发现能力,连带问题发现能力&…

专为实现最高性能和效率而设计,SQN3242UCKGTA、SQN3220SC、SQN3220 LTE-A Cat 6 模块【SKY85735-11射频前端】

一、SQN3242UCKGTA、SQN3220SC、SQN3220 LTE-A Cat 6 模块 1、简介 Sequans 的 Cassiopeia 是 Cat4 和 Cat6 LTE-Advanced 平台系列,包括集成了高性能网络和应用 CPU 的 SQN3220 Cat6 基带 SoC 和 SQN3220SC Cat4 基带 SoC、Sequans 的 SQN3242 LTE 优化收发器、经…

Pyhon-每日一练(1)

🌈write in front🌈 🧸大家好,我是Aileen🧸.希望你看完之后,能对你有所帮助,不足请指正!共同学习交流. 🆔本文由Aileen_0v0🧸 原创 CSDN首发🐒 如…

大华智慧园区前台任意文件上传(1day)

声明 本文仅用于技术交流,请勿用于非法用途 由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,文章作者不为此承担任何责任。 漏洞简介 大华智慧园区综合管理平台是一个集智能化、信息化、网络化、…

3D人脸生成的论文

一、TECA 1、论文信息 2、开源情况:comming soon TECA: Text-Guided Generation and Editing of Compositional 3D AvatarsGiven a text description, our method produces a compositional 3D avatar consisting of a mesh-based face and body and NeRF-based ha…

总结三:计算机网络面经

文章目录 1、简述静态路由和动态路由?2、说说有哪些路由协议,都是如何更新的?3、简述域名解析过程,本机如何干预域名解析?4、简述 DNS 查询服务器的基本流程是什么?DNS 劫持是什么?5、简述网关的…

CCS安装和运行TMS320F28004x第一个程序

1. CCS安装 TI 的MCU或者DSP,官方的集成开发环境是 Code Composer Studio™ ,要开发TI的芯片,首先需要安装 CCS 环境。 CCS 软件可以到下面的 TI 官网下载: https://www.ti.com.cn/tool/cn/CCSTUDIO 下载完之后,点击…

Pandas vs SQL全面对比

前一段时间给大家详解过 Pandas 的用法,今天再来分享下 Pandas 与 SQL 的对比。 Pandas 和 SQL 有很多相似之处,都是对二维表的数据进行查询、处理,都是数据分析中常用的工具。 对于只会 Pandas 或只会 SQL 的朋友,可以通过今天…

【QT5-程序控制电源-RS232-SCPI协议-上位机-基础样例【1】】

【QT5-程序控制电源-RS232-SCPI协议-上位机-基础样例【1】】 1、前言2、实验环境3、自我总结1、基础了解仪器控制-熟悉仪器2、连接SCPI协议3、选择控制方式-程控方式-RS2324、代码编写 4、熟悉协议-SCPI协议5、测试实验-测试指令(1)硬件连接(…

学习记忆——图像篇——记忆古诗词

《长歌行》 青青园中葵,朝露待日晞。 阳春布德泽,万物生光辉。 常恐秋节至,焜黄华叶衰。 百川东到海,何时复西归? 少壮不努力,老大徒伤悲!

wisemodel 始智AI - 小记

文章目录 关于 wisemodel 始智AI 关于 wisemodel 始智AI https://www.wisemodel.cn/home 旨在打造中国版 “HuggingFace” 该社区汇聚了清华 / 智谱 chatglm2-6B、Stable Diffusion V1.5、alphafold2、seamless m4t large 等模型,以及 shareGPT、ultrachat、moss-…

80%测试员被骗,关于jmeter 的一个弥天大谎!

jmeter是目前大家都喜欢用的一款性能测试工具,因为它小巧、简单易上手,所以很多人都愿意用它来做接口测试或者性能测试,因此,在目前企业中,使用各个jmeter的版本都有,其中以jmeter3.x、4.x的应该居多。 但是…

网络安全行业真的内卷了吗?网络安全就业必看

前言 有一个特别流行的词语叫做“内卷”: 城市内卷太严重了,年轻人不好找工作;教育内卷;考研内卷;当然还有计算机行业内卷…… 这里的内卷当然不是这个词原本的意思,而是“过剩”“饱和”的替代词。 按照…

c++ 学习 之 运算符重载 知识要点

我们要好好分清楚一些运算符的结果为 左值还是 右值 赋值与调用