Linear Regression with PyTorch 用PyTorch实现线性回归

news2024/9/21 4:23:57

文章目录

    • 4、Linear Regression with PyTorch 用PyTorch实现线性回归
      • 4.1 Prepare dataset 准备数据集
      • 4.2 Design Model 设计模型
        • 4.2.1 __call__() 作用
      • 4.3 Construct Loss and Optimizer 构造损失和优化器
      • 4.4 Training Cycle 训练周期
      • 4.5 Test Model 测试模型
      • 4.6 Different Optimizer
        • 4.6.1 Adagrad
        • 4.6.2 Adam
        • 4.6.3 Adamax
        • 4.6.4 ASGD
        • 4.6.5 LBFGS
        • 4.6.6 RMSprop
        • 4.6.7 Rprop
        • 4.6.8 SGD
      • 4.7 More Example

4、Linear Regression with PyTorch 用PyTorch实现线性回归

4.1 Prepare dataset 准备数据集

在PyTorch中,计算图是以小批量的方式进行的,所以 X 和 Y 是 3×1 的张量:

import torch
from torch import nn

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

注意:根据广播机制,所以 w 和 b 也都为 3 * 1(3行1列)。参考:什么是广播机制

我们来复习一下梯度下降算法

4.2 Design Model 设计模型

class Liang(nn.Module):
    def __init__(self):
        super(Liang, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred


model = Liang()

说明:

  • 我们的模型类应该继承自 nn.Module 模块,它是所有神经网络模块的基类。
  • 必须实现成员方法 __init__()forward()
  • 构造对象:nn.Linear() 就是上图中的 Linear Unit,包含 weightbias
  • nn.Linear 类已经实现了神奇的方法__call__(),它使类的实例可以被调用(就像一个函数一样);通常情况下 forward() 将被调用。

参考文档:Linear

4.2.1 call() 作用

当我们不清楚会传入多少变量时(或传入变量过多时):

def func(a, b, c, x, y):
    pass


func(1, 2, 3, x=4, y=5)

将变量替换成 *args,将其打印会输出一个元组;将变量替换成 **kwargs,将其打印会输出一个字典

def func(*args, **kwargs):
    print(args)  # (1, 2, 3)
    print(kwargs)  # {'x': 4, 'y': 5}


func(1, 2, 3, x=4, y=5)

实例:

class Liang:
    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):
        print('Hello' + str(args[0]))  # Hello1


liang = Liang()
liang(1, 2, 3)

4.3 Construct Loss and Optimizer 构造损失和优化器

criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

报错:UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.

即:size_averagereduce args 将被弃用,请使用 reduction='sum' 代替。

criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

参考文档:MSELoss

参考文档:SGD

4.4 Training Cycle 训练周期

for epoch in range(100):
    y_pred = model(x_data)  # Forward: Predict
    loss = criterion(y_pred, y_data)  # Forward: Loss
    print(epoch, loss)

    optimizer.zero_grad()  # The grad computed by .backward() will be accumulated. So before backward, remember set the grad to ZERO!!!
    loss.backward()  # Backward: Autograd
    optimizer.step()  # Update

4.5 Test Model 测试模型

import torch
from torch import nn

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])


class Liang(nn.Module):
    def __init__(self):
        super(Liang, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred


model = Liang()

criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(1000):
    y_pred = model(x_data)  # Forward: Predict
    loss = criterion(y_pred, y_data)  # Forward: Loss
    print(epoch, loss)

    optimizer.zero_grad()  # The grad computed by .backward() will be accumulated. So before backward, remember set the grad to ZERO!!!
    loss.backward()  # Backward: Autograd
    optimizer.step()  # Update

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

4.6 Different Optimizer

如果想要直观的看出每个优化器的效果,那我们可以借助 matplotlib 画图来展现:

1、导包
import matplotlib.pyplot as plt

2、创建空列表(存放:迭代次数 + 损失值)
epoch_list = []
loss_list = []

3、向列表中添加元素
epoch_list.append(epoch)
loss_list.append(loss.item())

4、画图
plt.plot(epoch_list, loss_list) # 横纵坐标值
plt.xlabel('Epoch') # x轴名称
plt.ylabel('Loss') # y轴名称
plt.title('SGD') # 图标题
plt.show() # 展示

4.6.1 Adagrad

参考文档:Adagrad

4.6.2 Adam

参考文档:Adam

4.6.3 Adamax

参考文档:Adamax

4.6.4 ASGD

参考文档:ASGD

4.6.5 LBFGS

参考文档:LBFGS

TypeError: step() missing 1 required positional argument: 'closure'

LBFGS要传递闭包,暂未解决!

4.6.6 RMSprop

参考文档:RMSprop

4.6.7 Rprop

参考文档:Rprop

4.6.8 SGD

参考文档:SGD

4.7 More Example

https://pytorch.org/tutorials/beginner/pytorch_with_examples.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/133191.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

redis缓存淘汰策略

定时删除 Redis不可能时时刻刻遍历所有被设置了生存时间的key,来检测数据是否已经到达过期时间,然后对它进行删除。 立即删除能保证内存中数据的最大新鲜度,因为它保证过期键值会在过期后马上被删除,其所占用的内存也会随之释放。…

zookeeper学习笔记2(小D课堂)

zookeeper数据模型: 我们的zookeeper是以节点的形式存在的,这样的形式和数据结构中的树的形式很像。同时也很像我们的linux的结构,例如linux的/user/local目录下可以有我们的/usr/local/tomcat目录。这样的节点形式。 我们的zookeeper中的每…

算法练习-常用查找算法复现

一个不知名大学生,江湖人称菜狗 original author: jacky Li Email : 3435673055qq.com Time of completion:2023.1.1 Last edited: 2023.1.1 目录 算法练习-常用查找算法复现(PS:1 -- 3自己写的,4、5懒得写了&#xf…

PHP开发者之路

我们经常会发现,历时四年软件专业的大学生毕业居然找不到工作,即便找到了工作也只能是做一些简单的辅助性工作。 那么我们不禁要问,究竟是什么原因让我们可爱的大学生们学而无用,或者用而不学呢? 我认为主要是因为现…

三角形年份aabb3n+1近似计算阶乘之和数据统计水仙花数韩信点兵倒三角形子序列的和分数化小数排列蛇形填数sprintf竖式问题

目录 P16_习题1-6_三角形 P16_习题1-7_年份 P20_eg2-1_aabb 为什么是int n a*1100 b*11 为什么要将向下取整? P22_eg2-2_3n1问题 P24_eg2-3_近似计算 P25_eg2-4_阶乘之和 P27_eg2-5_数据统计 P34_习题2-1_水仙花数 P34_习题2-2_韩信点兵 P34_习题2-3_倒…

Fragment全文详解(由浅入深_源码分析)

相信android开发者们一定或多或少的用过Fragment,但是对于其更深层次的原理我猜可能大部分应该都没有了解过,今天这里就由浅入深,整体对Fragment做一个全面解析。 基础介绍 Fragment是什么以及为什么要有Fragment呢? Fragment直…

长沙烟火气回来了,颐而康客流回暖为什么这么快?

随着一大批阳康的人们走出家门,长沙这座消费之城也逐步恢复了往日的活力。车多起来了、路堵起来了、线下店铺恢复营业了、长沙的烟火气息又回来了。 在颐而康万家丽西子店的大厅里,等候休息区已经坐满了顾客,他们有的在等待,有的…

Centos6从零开始安装mysql和tomcat后台环境,并成功部署Tomcat项目

最近因为搞定了一些环境的搭建因为项目过于老旧的缘故我从centosstream9一直改换7一直到6都没有成功一直到改成6.5的32位版本才算是成功搭建完成所以特地来写一篇文章记录一下。 首先我的liunx使用版本是 centos6.5 32位 java版本:jdkCentos6从零开始安装mysql和tom…

7-6 整除光棍

这里所谓的“光棍”,并不是指单身汪啦~ 说的是全部由1组成的数字,比如1、11、111、1111等。传说任何一个光棍都能被一个不以5结尾的奇数整除。比如,111111就可以被13整除。 现在,你的程序要读入一个整数x,这个整数一定…

【Kuangbin数论】阿拉丁和飞毯

4577. 阿拉丁和飞毯 - AcWing题库 题意&#xff1a; 思路&#xff1a; 就是去求x和y 使得 1.x!y 2.x*ya 3.min(x,y)b 一开始想的是去根号n地枚举a的约数 &#xff0c;然后直接统计 但是这样肯定T&#xff0c;所以换成dfs枚举约数去了 但是也T了 首先a*a<b的话直接特…

前端 | 手把手教你装饰你的github profile(github 首页)

1.创建存储库 您可以创建一个与您的 github 帐户名同名的存储库 添加README文件 2.编辑README.md 现在&#xff0c;可以根据自己的喜好修改 repo 中的自述文件&#xff0c;但我在考虑包含哪些信息时查看了其他开发人员的资料。通常包括简短的介绍、使用的技术堆栈和联系方式…

Buildroot编译hisi平台根文件系统

Buildroot编译hisi平台根文件系统 文章目录1. 下载Buildroot源码2. Menuconfig配置3. 编译Buildroot3.1 手动下载软件包3.2 kernel header 报错3.3 arm-hisiv300-linux-gcc-ar&#xff1a;cannot find plugin liblto_plugin.so3.4 /media/data/hisi/buildroot-2022.02.8/output…

C++类的多种构造函数

目录默认构造函数普通构造函数拷贝构造函数转换构造函数移动构造函数举例两个场景下面以Complex 复数类来学习C类中的各种构造函数; #include <iostream> using namespace std;//复数类 class Complex{friend ostream & operator<<(ostream &out, Complex…

2022年终结——人生中最美好的一站

文章目录前言回顾2022工作上学习上投资上生活上展望2023工作学习投资生活总结有一种责任与压力&#xff0c;叫做上有老下有小&#xff0c;但有一种幸福也叫做上有老下有小&#xff0c;当你遭遇挫折与困难时&#xff0c;这些“老小”以及那个同龄的“她”是你坚实的后盾&#xf…

Redisson中的“琐事”

文章目录前言锁分类Redisson可重入锁&#xff08;Reentrant Lock&#xff09;公平锁&#xff08;Fair Lock&#xff09;联锁&#xff08;MultiLock&#xff09;红锁&#xff08;RedLock&#xff09;读写锁&#xff08;ReadWriteLock&#xff09;信号量&#xff08;Semaphore&am…

【C++】左值、右值、语义移动和完美转发

右值引入的目的是为了对象移动&#xff1a; 因为在很多情况下&#xff0c;对象拷贝会经常发生&#xff0c;但是很多对象在拷贝后就直接被销毁了。这对性能是一个很大损耗。在重新分配内存的时候&#xff0c;从旧的内存将元素拷贝到新的内存中是不必要的。更好的方法是移动元素。…

论文投稿指南——中文核心期刊推荐(天文、测绘学)

【前言】 &#x1f680; 想发论文怎么办&#xff1f;手把手教你论文如何投稿&#xff01;那么&#xff0c;首先要搞懂投稿目标——论文期刊 &#x1f384; 在期刊论文的分布中&#xff0c;存在一种普遍现象&#xff1a;即对于某一特定的学科或专业来说&#xff0c;少数期刊所含…

使用Kalibr问题汇总:ModuleNotFoundError: No module named ‘wx‘

问题1&#xff1a; 报错&#xff1a;/kalibr_ws/src/Kalibr/Schweizer-Messer/sm_python/python/sm/PlotCollection.py", line 4, in import wx ModuleNotFoundError: No module named ‘wx’ 解决&#xff1a; sudo apt-get install python3-wxgtk4.0问题2&#xff1…

MySQL补齐函数LPAD和RPAD之SQLite解决方案

工作中经常需要对数据进行清洗&#xff0c;并对个别字段进行格式化处理&#xff0c;像 字符串左右补齐。MySQL数据库自带有LPAD()、RPAD()&#xff0c;而SQLite数据库没有的相应函数&#xff0c;需要自己转换。 目录 1、MySQL数据库 1.1、MySQL左右补全函数 1.2、实践验证 …

阶段性回顾(5)与一些题目实例(数组合并,有序判断,删除元素,进制问题等)

tips 1. 内存栈区的使用习惯是先使用高地址&#xff0c;再使用低地址。并且你还要清楚&#xff1a;随着数组下标的增大&#xff0c;其元素的地址也是在不断变高&#xff1b;对于一个占多个内存单元的变量进行取地址&#xff0c;取出来的是其所占内存空间最低地址的内存单元的地…