5 使用pytorch实现线性回归

news2025/1/11 3:58:40

文章目录

    • 前提的python的知识补充
    • 基本流程
    • 准备数据
    • 构造计算图
    • loss以及oprimizer
    • 循环训练
    • 课程代码

课程来源: 链接
课程内容参考: 链接
以及(强烈推荐)Birandaの

前提的python的知识补充

pytorch 之 call, init,forward
pytorch系列nn.Modlue中call的进一步解释

即可以了解调用mode实际上就是在调用其中的forward函数。

基本流程

  1. 准备数据集
  2. 设计用于计算最终结果的模型
  3. 构造损失函数及优化器
  4. 设计循环周期——前馈、反馈、更新

准备数据

在原先的题设中, x , y ^ ∈ R x,\widehat y \in R x,y R
在pytorch中,若使用mini-batch的算法,一次性求出一个批量的 y ^ \widehat y y ,则需要 x x x以及 y ^ \widehat y y 作为矩阵参与计算,此时利用其广播机制,可以将原标量参数 ω \omega ω扩写为同维度的矩阵 [ w ] [w] [w],参与运算而不改变其Tensor的性质。
对于矩阵,行表示样本,列表示特征
在这里插入图片描述
代码部分:

import torch
#数据作为矩阵参与Tensor计算
x_data = torch.Tensor([1.0],[2.0],[3.0])
y_data = torch.Tensor([2.0],[4.0],[6.0])

构造计算图

在如下线性模型的计算图中,红框区域为线性单元,其中的 ω \omega ω以及 b b b是需要反复训练确定的,在设计时,需要率先设计出此二者的维度。
而由于公式
y ^ = ω x + b \widehat y = \omega x +b y =ωx+b
因此,只要确定了$\widehat y 以及 以及 以及x$的维度,就可以确定上述两个量的维度大小。
在这里插入图片描述

代码部分:

# 固定继承于Module
class LinearModel(torch.nn.Module):
    # 构造函数初始化
    def __init__(self):
        # 调用父类的init
        super(LinearModel, self).__init__()
        # Linear对象包括weight(w)以及bias(b)两个成员张量
        self.linear = torch.nn.Linear(1,1)

    #前馈函数forward,对父类函数中的overwrite
    def forward(self, x):
        #调用linear中的call(),以利用父类forward()计算wx+b
        y_pred = self.linear(x)
        return y_pred
    #反馈函数backward由module自动根据计算图生成
model = LinearModel()

在这里插入图片描述

loss以及oprimizer

按照上述理论,由于前边的计算过程都是针对矩阵的,因此最后的 l o s s loss loss也是矩阵,但由于要进行反向传播调整参数,因此 l o s s loss loss应当是个标量,因此要对矩阵 [ l o s s ] [loss] [loss]内的每个量求和求均值(MSE)。
l o s s = 1 N Σ [ l o s s 1 ⋮ l o s s n ] loss = \frac{1}{N}\Sigma \begin{bmatrix} {loss_1}\\ {\vdots}\\ {loss_n}\\ \end{bmatrix} loss=N1Σ loss1lossn

优化器并不构建计算图,生成的优化器对象可以直接对整个模型进行优化

代码:

criterion = torch.nn.MSELoss(size_average=False)
#model.parameters()用于检查模型中所能进行优化的张量
#learningrate(lr)表学习率,可以统一也可以不统一
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

在这里插入图片描述
关于model.parameters(),即将model当中所有权重参数传入到优化器当中,相当于进行更新。

循环训练

  1. 前馈计算预测值与损失函数
  2. forward前馈计算预测值即损失loss
  3. 梯度或前文清零并进行backward
  4. 更新参数
for epoch in range(100):
    #前馈计算y_pred
    y_pred = model(x_data)
    #前馈计算损失loss
    loss = criterion(y_pred,y_data)
    #打印调用loss时,会自动调用内部__str__()函数,避免产生计算图
    print(epoch,loss)
    #梯度清零
    optimizer.zero_grad()
    #梯度反向传播,计算图清除
    loss.backward()
    #根据传播的梯度以及学习率更新参数
    optimizer.step()

课程代码

import torch
#数据作为矩阵参与Tensor计算
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])

#固定继承于Module
class LinearModel(torch.nn.Module):
    #构造函数初始化
    def __init__(self):
        #调用父类的init
        super(LinearModel, self).__init__()
        #Linear对象包括weight(w)以及bias(b)两个成员张量
        self.linear = torch.nn.Linear(1,1)

    #前馈函数forward,对父类函数中的overwrite
    def forward(self, x):
        #调用linear中的call(),以利用父类forward()计算wx+b
        y_pred = self.linear(x)
        return y_pred
    #反馈函数backward由module自动根据计算图生成
model = LinearModel()

#构造的criterion对象所接受的参数为(y',y)
criterion = torch.nn.MSELoss(size_average=False)
#model.parameters()用于检查模型中所能进行优化的张量
#learningrate(lr)表学习率,可以统一也可以不统一
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(1000):
    #前馈计算y_pred
    y_pred = model(x_data)
    #前馈计算损失loss
    loss = criterion(y_pred,y_data)
    #打印调用loss时,会自动调用内部__str__()函数,避免产生计算图
    print(epoch,loss)
    #梯度清零
    optimizer.zero_grad()
    #梯度反向传播,计算图清除
    loss.backward()
    #根据传播的梯度以及学习率更新参数
    optimizer.step()

 #Output
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

#TestModel
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)

print('y_pred = ',y_test.data)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/182473.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python进行因子分析

1 因子分析 1.1 定义 因子分析法(Factor Analysis)是一种利用降维的思想,从研究原始变量相关矩阵内部的依赖关系出发,把一些具有错综复杂关系的变量归结为少数几个综合因子的一种多变量统计分析方法。其优势在于不仅可以在减少大量指标分析的工作量的同…

尚硅谷hello scala-配置idea2022.1.2版本创建scala2.11.8版本maven文件

0_前置说明 软件版本 idea2022.1.2 scala2.11.8 java1.8.0_144 尚硅谷资源下载 关注b站尚硅谷 idea资源 百度网盘:https://pan.baidu.com/s/1Gbavx34OfF29LZqJ8dc85g?pwdyyds 提取码: yyds B站直达:​https://www.bilibili.com/video/BV1CK411d7a…

LabVIEW NI Linux Real-Time深层解析

LabVIEW NI Linux Real-Time深层解析NI LabVIEW Real-Time模块支持NI Linux Real-Time操作系统,在选定的NI硬件上提供。本文介绍了具体的新特性和高级功能,可让您为应用充分利用NI Linux Real-Time。Linux Shell支持NI Linux Real-Time操作系统提供了全面…

《Linux0.11源码趣读》学习笔记day6

到上次记录,整个操作系统的全部代码就已经从硬盘加载到内存中了,然后这些代码又通过jmpi跳转到0x90200处,即硬盘第二个扇区开始处的内容 这些内容就是第二个操作系统源代码文件setup.s 不过现在先来看一下操作系统的编译过程 操作系统的编译…

后端学习 - Docker

文章目录基本概念三个核心概念:镜像、容器、仓库联合文件系统 UnionFS常用命令Docker File基本概念 一次配置,处处使用运行在同一宿主机上的容器是相互隔离的,各自拥有独立的文件系统容器模型和虚拟机模型的主要区别 相较于虚拟机而言&#…

【Pytorch项目实战】之生成式网络:编码器-解码器、自编码器AE、变分自编码器VAE、生成式对抗网络GAN

文章目录生成式网络 - 生成合成图像算法一:编码器-解码器算法二:自编码器(Auto-Encoder,AE)算法三:变分自编码器(Variational Auto Encoder,VAE)算法四:生成式…

九型人格是什么?

九型人格是什么? 九型人格学(Enneagram/Ninehouse)是一个有2000多年历史的古老学问,它按照人们习惯性的思维模式,情绪反应和行为习惯等性格特质,将人的性格分为九种,又被称为九柱图,起源于中亚西亚地区,和中国的八卦图有点像,近代的九型是由六十年代智利的一位心理学…

计算机组成原理 | 第四章:存储器 | 存储器与CPU连接 | 存储器的校验 | Cache容量计算

文章目录📚概述🐇存储器分类🐇存储器的层次结构🥕原理🥕主存速度慢的原因🥕存储器三个主要特征的关系🥕缓存-主存层次和主存-辅存层次⭐️📚主存储器🐇概述🥕…

【opencv】Haar分类器及Adaboost算法人脸识别理论讲解

提到opencv,就不得不提其图像识别能力,最近旷世开源的YoloX项目兴起,作为目前Yolo系列中的最强者,本人对其也很感兴趣,但是完全没用机器学习和计算机视觉的基础,知其然,不知其所以然,于是想稍稍入坑一下opencv图像识别,了解一下相关算法,(说不定以后毕设会用到呢)。…

磨金石教育影视干货分享|朋友亲身经历—给新人剪辑师的三个建议

大学的时候有一个同学很喜欢视频剪辑。平时没事就蹲在电脑前,下载一些素材,自学剪辑软件,慢慢的搞一些创意剪辑。那时候自媒体短视频已经很火爆,这位同学剪辑的视频,不管质量如何就往上面发。一开始我们对于新事物的认…

Java---微服务---分布式搜索引擎elasticsearch(2)

分布式搜索引擎elasticsearch(2)1.DSL查询文档1.1.DSL查询分类1.2.全文检索查询1.2.1.使用场景1.2.2.基本语法1.2.3.示例1.2.4.总结1.3.精准查询1.3.1.term查询1.3.2.range查询1.3.3.总结1.4.地理坐标查询1.4.1.矩形范围查询1.4.2.附近查询1.5.复合查询1…

SpringBoot+Vue项目学生读书笔记共享平台

文末获取源码 开发语言:Java 框架:springboot JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7/8.0 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea Maven包:Maven3.3.9 浏…

一起自学SLAM算法:10.3 机器学习与SLAM

连载文章,长期更新,欢迎关注: 前面已经分析过的8种SLAM算法案例(Gmapping、Cartographer、LOAM、ORB-SLAM2、LSD-SLAM、SVO、RTABMAP和VINS)都可以称为传统方法,因为这些算法都是在人为精心设计的特定规则下…

电子技术——MOS放大器基础

电子技术——MOS放大器基础 我们已经学过MOS可以当做一个压控流源,使用栅极电压 vGSv_{GS}vGS​ 控制漏极电流 iDi_DiD​ 。尽管两个量的关系不是线性的,稍后我们将会介绍偏置在线性区的工作方法。 构建压控压源放大器 现在,我们有了一个压…

【Java|golang】1664. 生成平衡数组的方案数---奇数前缀和 + 偶数前缀和

给你一个整数数组 nums 。你需要选择 恰好 一个下标(下标从 0 开始)并删除对应的元素。请注意剩下元素的下标可能会因为删除操作而发生改变。 比方说,如果 nums [6,1,7,4,1] ,那么: 选择删除下标 1 ,剩下…

[JAVA安全]JACKSON反序列化

前言 ackson是一个开源的Java序列化和反序列化工具&#xff0c;可以将Java对象序列化为XML或JSON格式的字符串&#xff0c;以及将XML或JSON格式的字符串反序列化为Java对象。 依赖 <dependency><groupId>com.fasterxml.jackson.core</groupId><artifact…

中国省际铁路通行时间数据

一、数据简介 本数据来自南京大学长江产业经济研究院《全国统一大市场下的省际铁路交通研究报告》的附录部分。中国的铁路&#xff08;高铁&#xff09;建设取得了辉煌成果。但受铁路时刻众多、历史数据不容易搜集整理的限制&#xff0c;学术与政策研究者一直无法对铁路建设的时…

三、JDBC详解

教程相关资料&#xff1a;https://www.aliyundrive.com/s/wMiqbd4Zws6 1&#xff0c;JDBC概述 在开发中我们使用的是java语言&#xff0c;那么势必要通过java语言操作数据库中的数据。这就是接下来要学习的JDBC。 1.1 JDBC概念 JDBC 就是使用Java语言操作关系型数据库的一套…

Nacos 配置中心 服务端推送变更源码讲解

目录 1. 配置引起变更的两种方式 1.1 后台管理直接操作 1.2 NacosClient 调用 RPC 接口 2. 变更事件处理 AsyncNotifyService 2.1 HTTP 任务 2.2 RPC任务 2.3 NacosServer 其他节点接收到消息后如何处理 3. 客户端推送实现&#xff1a;DumpService.dump 接着上一篇 Nac…

1601_读Dennis M. Ritchie and Ken Thompson的The UNIX TimeSharing System_Unix分时复用系统

全部学习汇总&#xff1a; GreyZhang/g_unix: some basic learning about unix operating system. (github.com) 很久之前就闻听了UNIX的大名&#xff0c;也看过很多相关的故事类文章。其中最让我印象深刻的莫过于Ken发明UNIX的故事以及这个系统对于Linux以及GNU的OS的影响&…