Lecture5 实现线性回归(Linear Regression with PyTorch)

news2025/1/22 15:58:39

目录

1 Pytorch实现线性回归

1.1 实现思路

1.2 完整代码

2 各部分代码逐行详解

2.1 准备数据集

2.2 设计模型

2.2.1 代码

2.2.2 代码逐行详解

2.2.3 疑难点解答

2.3 构建损失函数和优化器

2.4 训练周期

2.5 测试结果

3 线性回归中常用优化器


1 Pytorch实现线性回归

1.1 实现思路

图1 实现线性回归主要过程

图2 线性回归计算图

1.2 完整代码

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(500):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

2 各部分代码逐行详解

2.1 准备数据集

在PyTorch中,一般需要采取mini-batch形式构建数据集,也就是把数据集定义成张量(Tensor)形式,以方便后续计算。

在下面这段代码中,x_data是个二维张量,它有3个样本,每个样本有1个特征值,即维度是 (3, 1);y_data同理。不清楚的同学可以使用 x.dim() 方法和 x.shape 属性来获取张量的维度和尺寸,自行调试。简言之,在minibatch中,行表示样本,列表示feature

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

2.2 设计模型

图3 目标计算图

主要目标:构建计算图

2.2.1 代码

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
model = LinearModel()

2.2.2 代码逐行详解

class LinearModel(torch.nn.Module):

一般我们需要一个类,并继承自PyTorch的Module类,这是因为torch.nn.Module提供了很多有用的功能,使得我们可以更方便地定义、训练和使用神经网络模型。

接下来至少需要实现两个函数,即initforward

__init__方法

    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)

该方法对模型的参数进行初始化

super(LinearModel, self).__init__() 中,第一个参数 LinearModel 指定了查找的起点,即在 LinearModel 类的父类中查找;第二个参数 self 指定了当前对象,即调用该方法的对象。该语句的作用是调用 LinearModel 的父类 torch.nn.Module__init__ 方法,并对父类的属性进行初始化。这是初始化模型的一个必要语句。

接下来将一个torch.nn.Linear对象实例化并赋值给self.linear属性。torch.nn.Linear 的构造函数接收三个参数:in_features 、 out_features、bias,分别代表输入特征的数量、输出特征的数量和偏置量。

图4 Linear类构造函数参数介绍

forward方法

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

forward()方法作用是进行前馈运算,相当于计算\hat{y}=\omega x + b

注意这里相当于是重写了torch.nn.Linear 类中的forward方法。在我们重写forward后,函数将会执行的过程如下:

图5 forward前馈运算

y_pred = self.linear(x) 的作用是将输入 x 传入全连接层进行线性变换,得到输出 y_pred

最后通过实例化LinearModel类来调用模型

model = LinearModel()

2.2.3 疑难点解答

1、可能你会有疑问,代码中的backward过程体现在哪呢?

答:torch.nn.Module类构造出的对象会自动完成backward过程。Module 类及其子类在前向传递时会自动构建计算图,并在反向传播(backward)时自动进行梯度计算和参数更新。比如self.linear=torch.nn.Linear(1, 1),

这里的linear属性得到Linear类的实例后,相当于继承自Module,所以它也会自动进行backward,就无须我们再手动求导了。

2、y_pred = self.linear(x) 中,linear为什么后面可以直接跟括号呢?

这里涉及到了python语法中的可调用对象(Callable Object)知识点。在self.linear后面加括号,相当于直接在对象上加括号,相当于实现了一个可调用对象

self.linear = torch.nn.Linear(1, 1)中,相当于我们创建了一个Module对象,因为nn.Linear类继承自nn.Module类。

接着我们执行了y_pred = self.linear(x)这段代码,相当于我们调用了Moudle 类的 __call__ 方法。

于是nn.Module类的__call__方法又会进一步去自动调用模块的forward方法。

举个例子:

class Adder:
    def __init__(self, n):
        self.n = n

    def __call__(self, x):
        return self.n + x

add5 = Adder(5)
print(add5(3))  # 输出 8

在这个例子中,我们定义了一个 Adder 类,它接受一个参数 n,并且实现了 __call__ 方法。当我们创建 add5 对象时,实际上是创建了一个 Adder 对象,并且把参数 n 设置为 5。当我们调用 add5 对象时,实际上是调用了 Adder 对象的 __call__ 方法,

通过实现 __call__ 方法,我们可以让对象像函数一样被调用,这在一些场景下很有用,例如,我们可以用它来实现一个状态机、一个闭包或者一个装饰器等。

3、权重体现在哪?forward里面好像没涉及到权重值的传入?

这里 self.linear 实际上是一个 PyTorch 模块(Module),包含了权重矩阵和偏置向量,于是我们便可以用这个对象来完成下图所示计算

图6 模块成员关系图

图7 nn.Linear包含两个成员

那么权重是怎么传入forward中的呢?

torch.nn.Linear类的构造函数__init__中,它会自动创建一个nn.Parameter对象,用于存储权重,并将其注册为模型的可学习参数(Learnable Parameter)

这个nn.Parameter对象的创建代码位于nn.Linear类的__init__函数中的这一行:

图8 Linear类中的weight接收器

因此,self.linear中的weight属性实际上是从nn.Parameter对象中获取的。在forward方法中,self.linear会自动获取到它的weight属性,并用它来完成矩阵乘法的操作。

2.3 构建损失函数和优化器

criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

图9 MSE损失函数公式

torch.nn.MSELoss 是一个均方误差损失函数,用于计算模型输出与真实值之间的差异,即MSE。其中,size_average 参数指定是否对损失求均值,默认为 True,即求平均值。在这个例子中,size_average=False 意味着我们希望得到所有样本的平方误差之和。

图10 SGD随机梯度下降公式

torch.optim.SGD 是随机梯度下降优化器,用于更新神经网络中的参数。其中,model.parameters() 对神经网络中的参数进行优化,它会检查所有成员,告诉优化器需要更新哪些参数。在反向传播时,优化器会通过这些参数计算梯度并对其进行更新。lr 参数表示学习率,即每次参数更新的步长。在这个例子中,我们使用随机梯度下降作为优化器,学习率为 0.01。最后我们得到了一个优化器对象optimizer

2.4 训练周期

for epoch in range(500): # 训练500轮
    y_pred = model(x_data)  # 前向计算
    loss = criterion(y_pred, y_data)  # 计算损失
    print(epoch, loss.item())  # 打印损失值
    optimizer.zero_grad() # 梯度清零,不清零梯度的结果就变成这次的梯度+原来的梯度
    loss.backward()  # 反向传播
    optimizer.step()  # 更新权重

2.5 测试结果

循环迭代进行训练500轮。

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())
# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

输出结果部分截图:

0 23.694297790527344
1 10.621758460998535
2 4.801174163818359
3 2.208972215652466
4 1.0539695024490356
5 0.5387794971466064
6 0.3084312379360199
7 0.20490160584449768
8 0.1578415036201477
9 0.13593381643295288
10 0.12523764371871948
11 0.1195460706949234
12 0.11609543859958649

···
494 0.00010695526725612581
495 0.00010541956726228818
496 0.00010390445095254108
497 0.00010240855044685304
498 0.00010094392928294837
499 9.949218656402081e-05
w =  1.993359923362732
b =  0.015094676986336708
y_pred =  tensor([[7.9885]])

Process finished with exit code 0
 

总之,求yhat,求loss,然后backward,最后更新权重

3 线性回归中常用优化器

• torch.optim.Adagrad
• torch.optim.Adam
• torch.optim.Adamax
• torch.optim.ASGD
• torch.optim.LBFGS
• torch.optim.RMSprop
• torch.optim.Rprop
• torch.optim.SGD

阅读官方教程的更多示例:

Learning PyTorch with Examples — PyTorch Tutorials 1.13.1+cu117 documentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/350979.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络协议(七):传输层-UDP

网络协议系列文章 网络协议(一):基本概念、计算机之间的连接方式 网络协议(二):MAC地址、IP地址、子网掩码、子网和超网 网络协议(三):路由器原理及数据包传输过程 网络协议(四):网络分类、ISP、上网方式、公网私网、NAT 网络…

06- 信用卡反欺诈 (机器学习集成算法) (项目六)

本项目为 kaggle 项目 项目难点在于: 盗刷的比例占总数据量的比例较低, 直接预测为非盗刷也有 99.8% 的准确率.data.info() # 查看所有信息msno.matrix(data) # 查看缺失值axis1 时 # 删除列显示颜色种类 from matplotlib import colors plt.colormaps() # mag…

关于知识图谱TransR

论文题目 Learning Entity and Relation Embeddings for Knowledge Graph Completion 论文链接 TransR 文中指出,不管是TransE还是TransH都是将实体和关系映射同一空间,但是,一个实体可能具有多个层面的信息,不同的关系可能关注…

ray简单介绍

ray使用也有一段时间了, 这篇文章总结下ray的使用场景和用法 ray可以做什么? 总结就两点: 可以将其视为一个进程池(当然不仅限于此), 可以用于开发并发应用还可以将应用改造为分布式 基于以上两点, 有人称之为:Modern Parallel and Distributed Python 构成 Ray AI Runtim…

Redis多级缓存

文章目录一. 什么是多级缓存二. JVM进程缓存一. 什么是多级缓存 传统的缓存策略一般是请求到达Tomcat后,先查询Redis,如果未命中则查询数据库,如图: 存在下面的问题: 请求要经过Tomcat处理,Tomcat的性能…

Linux高级IO

文章目录一、五种 IO 模型1.阻塞 IO2.非阻塞 IO3.信号驱动 IO4. IO 多路转接5.异步 IO二、高级 IO 重要概念1.同步通信和异步通信2.阻塞和非阻塞fcntl 系统调用3.其他高级 IO三、I/O 多路转接之 select1.函数原型socket 就绪的条件2.理解 select 的执行过程3.使用示例4. select…

新手小白如何入门黑客技术?

你是否对黑客技术感兴趣呢?感觉成为黑客是一件很酷的事。那么作为新手小白,我们该如何入门黑客技术,黑客技术又是学什么呢? 其实不管你想在哪个新的领域里有所收获,你需要考虑以下几个问题: 首先&#xff…

Springboot扩展点之FactoryBean

前言FactoryBean是一个有意思,且非常重要的扩展点,之所以说是有意思,是因为它老是被拿来与另一个名字比较类似的BeanFactory来比较,特别是在面试当中,动不动就问你:你了解Beanfactory和FactoryBean的区别吗…

spring cloud gateway网关和链路监控

文章目录 目录 文章目录 前言 一、网关 1.1 gateway介绍 1.2 如何使用gateway 1.3 网关优化 1.4自定义断言和过滤器 1.4.1 自定义断言 二、Sleuth--链路追踪 2.1 链路追踪介绍 2.2 Sleuth介绍 2.3 使用 2.4 Zipkin的集成 2.5 使用可视化zipkin来监控微服务 总结 前言 一、网关…

ubuntu wordpress建站

nginx 安装测试 https://blog.csdn.net/leon_zeng0/article/details/113578143 ubuntu 基于apache2安装wordpress https://ubuntu.com/tutorials/install-and-configure-wordpress#7-configure-wordpress 报错403的话,是权限没搞对,解决参考https://ww…

空间误差分析:统一的应用导向处理(Matlab代码实现)

👨‍🎓个人主页:研学社的博客💥💥💞💞欢迎来到本博客❤️❤️💥💥🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密…

深度学习算法面试常问问题(二)

X86和ARM架构在深度学习侧的区别? X86和ARM架构分别应用于PC端和低功耗嵌入式设备,X86指令集很复杂,一条很长的指令就可以完成很多功能;而ARM指令集很精简,需要几条精简的短指令完成很多功能。 影响模型推理速度的因…

mysql分库分表概念及原理、ShardingSphere实现mysql集群分库分表读写分离

一:分库分表概念 1.1 为什么要对数据库进行分表 索引的极限:单表数据量达到几十万或上百万以上,使用索引性能提升也不明显。 分表使用门槛:单表行数超过 500 万行或者单表容量超过 2GB,才推荐进行分库分表。 分表适用…

MIT 6.S965 韩松课程 04

Lecture 04: Pruning and Sparsity (Part II) 文章目录Lecture 04: Pruning and Sparsity (Part II)剪枝率分析每层的敏感度自动剪枝微调和训练稀疏网络彩票假说稀疏度的系统支持不均衡负载M:N 稀疏度本讲座提纲章节 1:剪枝率分析每层的敏感度AMC: AutoML for Model…

C#:Krypton控件使用方法详解(第四讲) ——kryptonLabel

今天介绍的Krypton控件中的kryptonLabel,下面开始介绍这个控件的属性:首先介绍控件中的外观属性:Cursor属性:表示功能为鼠标移动过这个控件的时候显示光标的类型。Text属性:表示显示的文本内容。其他属性不做过多的介绍…

编写 Cypher 代码续

编写 Cypher 代码 过滤查询 查看图中的唯一性约束索引 SHOW CONSTRAINTS查看图中关系的属性类型 CALL db.schema.relTypeProperties()查看图中节点的属性类型 CALL db.schema.nodeTypeProperties()查看数据模型 CALL db.schema.visualization()用 WHERE 子句添加过滤条件 查询执…

28k入职腾讯测试岗那天,我哭了,这5个月付出的一切总算没有白费~

先说一下自己的个人情况,计算机专业,16年普通二本学校毕业,经历过一些失败的工作经历后,经推荐就进入了华为的测试岗,进去才知道是接了个外包项目,不太稳定的样子,可是刚毕业谁知道什么外包不外…

jsp营养配餐管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 jsp营养配餐管理系统 是一套完善的web设计系统,对理解JSP java编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为 TOMCAT7.0,Myeclipse8.5开发,数据库为Mysql,使…

企业带宽控制管理

在企业中保持稳定的网络性能可能具有挑战性,因为采用数字化的网络可扩展性和敏捷性应该与组织的发展同步。随着基础设施的扩展、新应用和新技术的引入,网络的带宽容量也在增加。 停机和带宽过度使用是任何组织都无法避免的两个问题,为了解决…

最新版海豚调度dolphinscheduler-3.1.3配置windows本地开发环境

0 说明 本文基于最新版海豚调度dolphinscheduler-3.1.3配置windows本地开发环境,并在windows本地进行调试和开发 1 准备 1.1 安装mysql 可以指定为windows本地mysql,也可以指定为其他环境mysql,若指定为其他环境mysql则可跳过此步。 我这…