《PyTorch深度学习实践》第五讲 用PyTorch实现线性回归

news2025/1/11 14:01:19

b站刘二大人《PyTorch深度学习实践》课程第五讲用PyTorch实现线性回归笔记与代码:https://www.bilibili.com/video/BV1Y7411d7Ys?p=5&vd_source=b17f113d28933824d753a0915d5e3a90


PyTorch官网教程:https://pytorch.org/tutorials/beginner/pytorch_with_examples.html


PyTorch Fashion

  1. 准备数据集
  2. 设计模型,写成类的形式(nn.Module)
    • 前向传播,计算 y ^ \hat{y} y^
  3. 构造损失函数loss和优化器(使用PyTorch的API)
    • 构造loss用于反向传播;优化器用于更新梯度
  4. 写训练周期(前馈 -> 反馈 -> 更新)

线性回归第一步:准备数据集

  • 在PyTorch中,计算图是采用的mini-batch形式计算
import torch

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

线性回归第二步:设计模型

  • 线性单元

    • 要确定权重 w w w的维度,则需要知道输入 x x x和输出 y ^ \hat{y} y^的维度;
    image-20230630102424959
  • 将模型定义成一个类

"""
Our model class should be inherit from nn.Module, which is Base class for all neural network modules
模型类都是从nn.Module继承,nn.Module是所有神经网络模型的基类
成员方法至少包含__init__()和forward()
"""
class LinearModel(torch.nn.Module):
    def __int__(self):
        # 构造函数,用于初始化对象
        super(LinearModel, self).__int__()  # super是调用父类的构造,第一个参数LinearModel是类名称
        self.linear = torch.nn.Linear(1, 1) # 构造对象。nn.Linear包含两个张量成员:权重w和偏置b
        
    def forward(self, x):
        # 前馈计算
        y_pred = self.linear(x)	# y_hat,在一个对象(linear)后面加括号,表明实现了一个可调用的对象
        return y_pred


model = LinearModel()  # 实例化,model是可调用的,如model(x),x会传入forward中
image-20230629213425516
  • in_features:输入样本的维度(特征)

  • out_features:输出样本的维度

    image-20230630104714599
  • *args:表示可变参数,会存放所有未命名的变量参数,在函数调用的时候自动组装为一个元组

  • **kwargs:表示关键字参数,在函数内部自动组装成一个字典

    # 例子:假设定义一个func函数,并定义了形参
    def func(a, b, c, x, y):
        pass
    
    # 在调用的时候,传入的实参必须要和形参对应
    func(1, 2, 3, x=3, y=5)
    
    # 问题是如果调用的时候参数更多该怎么办?
    func(1, 2, 4, 3, x=3, y=5) # 比上面多一个值,这样调用就会出错
    
    ---
    # 对func进行修改,将a,b,c换成*args,那么在调用func的时候所有没有命名的实参都会传到args中
    def func(*args, x, y):
        pass
    
    ---
    # 对于x和y这种命名的参数可以写成**kwargs,在调用func的时候命名的实参都会传到kwargs中
    def func(*args, **kwargs):
        pass
    
    image-20230630113754366
# 定义一个可调用的类
class Foobar:
    def __init__(self):
        # 先定义__init__,因为没起作用就写个pass
        pass
    
    # 要想对象可调用,则需要定义一个__call__函数。pycharm中会自动提示如下形式
    #  *args:表示可变参数,会存放所有未命名的变量参数,在函数调用的时候自动组装为一个元组
    #  **kwargs:表示关键字参数,在函数内部自动组装成一个字典
    def __call__(self, *args, **kwargs):
        print("Hello" + str(args[0]))  # 假设就接受args的第一个参数
        
        
foobar = Foobar()  # 定义一个Foobar类的变量foobar
# 由于类中定义了__call__()函数,所以可以进行如下操作,给foobar传入参数
foobar(1, 2, 3)
image-20230630114332162
  • PyTorch中的Module的call函数里面有一条语句是要调用forward(),因此在我们自己写的module类中必须要实现forward()来覆盖掉父类中的forward()

线性回归第三步:构造loss函数和优化器

criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
  • 损失函数使用MSE

    • MSELoss继承自nn.Module,参与计算图的构建
    image-20230630120312661 image-20230630120353109
    • size_average:是否要求均值(可求可不求)
    • reduce:是否要降维(一般只考虑size_average)
  • 优化器使用SGD

    • torch.optim.SGD()是一个类,与nn.Module无关,不参与计算图的构建

      image-20230630120854394
    • model.parameters()是权重

      • model中并没有定义相应的权重,但里面的成员函数linear有权重
      • 方法parameters是继承自Module,它会检查model中的所有成员函数,如果成员中有相应的权重,那就将其都加到最终的训练结果上
    • lr:learning rate,一般都设定一个固定的学习率


线性回归第四步:训练过程

三个步骤:

  • 前馈
  • 反馈
    • 开始反馈前要先将梯度归零
  • 更新
for epoch in range(100):
    y_pred = model(x_data)              # 前馈:计算y_hat
    loss = criterion(y_pred, y_data)    # 前馈:计算损失
    print(epoch, loss.item())

    optimizer.zero_grad()   # 反馈:在反向传播开始将上一轮的梯度归零
    loss.backward()         # 反馈:反向传播(计算梯度)
    optimizer.step()        # 更新权重w和偏置b
image-20230630125150399

完整的代码(包含模型测试和loss曲线绘制)

import torch
import matplotlib.pyplot as plt

# 数据集
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

# 用于绘图
epoch_list = []
loss_list =[]

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred


model = LinearModel()

# criterion = torch.nn.MSELoss(size_average=False) pytorch更新后被弃用了
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练过程
for epoch in range(100):
    y_pred = model(x_data)              # 前馈:计算y_hat
    loss = criterion(y_pred, y_data)    # 前馈:计算损失
    print(epoch, loss.item())
    epoch_list.append(epoch)
    loss_list.append(loss.item())

    optimizer.zero_grad()   # 反馈:在反向传播开始将上一轮的梯度归零
    loss.backward()         # 反馈:反向传播(计算梯度)
    optimizer.step()        # 更新权重w和偏置b

# 输出权重和偏置
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

# 测试模型
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

# 绘制loss曲线
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
  • 训练100轮:
image-20230630125329678
  • 训练1000轮:
image-20230630125424128

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/704168.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue3 elementplus table根据某id相同合并单元格

根据表格中id相同的合并单元格 1.标签上加入合并方法 <el-table:data"tableData.data"selection-change"handleSelectionChange":span-method"arraySpanMethod">/*** 合并行*/ interface SpanMethodProps {row: ListPageType;column: Tabl…

Edge浏览器提示您开启了窗口拦截程序解决方法

最近在使用edge浏览器兼容性的时候&#xff0c;发现登录窗口弹出后&#xff0c;经常被拦截&#xff0c;后面经过在网上上和自己实际测试&#xff0c;终于解决了这个问题。 操作步骤如下&#xff1a; 第一步&#xff0c;找到右上角三个点的图标&#xff0c;点击一哈 第二步&am…

基于Java校园教务系统设计实现(源码+lw+部署文档+讲解等)

博主介绍&#xff1a;✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专…

无权限复制时怎样获取内容

在body中contenteditable"true"&#xff0c;然后直接在html文档中复制

【Spring Boot统一功能处理】用户登录权限校验与拦截器,拦截器与传统的校验方式想比有什么好处呢? ? ?我们一起去探索其中的奥秘吧! ! !

前言: 大家好,我是良辰丫,今天我们要学习Spring Boot统一功能处理,什么叫统一功能呢?我们在javaEE初阶学习过前后端交互,约定交互时的统一格式,其中这种约定就是一个统一功能.&#x1f48c;&#x1f48c;&#x1f48c; &#x1f9d1;个人主页&#xff1a;良辰针不戳 &#x1f…

VS Code报错 No module named ‘torch‘ (但已经安装了pytorch)

一、复现错误程序 创建一个python文件 test.py&#xff0c;其内容为&#xff1a; import torch print(torch.__version__)使用VS Code打开并运行该程序时&#xff0c;会出现以下错误&#xff1a; ModuleNotFoundError: No module named ‘torch’ 二、解决方案 选择适合的Pyt…

Python测试应用与工具

文章目录 前言环境准备unittestpytestpytest插件 mock最后 前言 例如&#xff1a;随着人工智能的不断发展&#xff0c;机器学习这门技术也越来越重要&#xff0c;很多人都开启了学习机器学习&#xff0c;本文就介绍了机器学习的基础内容。 Python测试应用与公具 今天跟大家分享…

MVTEC 3D dataset

官网&#xff1a;https://www.mvtec.com/company/research/datasets/mvtec-3d-ad/downloads https://www.mvtec.com/company/research/datasets/mvtec-3d-adhttps://www.mvtec.com/company/research/datasets/mvtec-3d-ad 数据大小&#xff1a;13个G 1. 介绍 MVTec 3D异常检测…

【lora模块调试:亿百特lora-型号E22-400T30D-V=初步调试踩坑-认识模块-了解协议(1)】

【lora模块调试&#xff1a;亿百特lora-型号E22-400T30D-V初步调试踩坑-认识模块-了解协议&#xff08;1&#xff09;】 1、概述2、实验环境3-1&#xff1a;先行了解3-2&#xff1a;经验总结4、硬件线路连接方式1、厂家提供的ttl转usb的模块2、使用开发板上的串口3、自己弄个转…

基于Java+Vue前后端分离网上拍卖系统设计实现(源码+lw+部署文档+讲解等)

博主介绍&#xff1a;✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专…

智能批量剪辑系统源码开发者如何减少服务器成本?

一、智能混剪批量剪辑自研与接入第三方“如阿里云”接口的差别 智能混剪批量剪辑自研和接入第三方"如阿里云"接口的差别主要在于技术实现和功能定制。自研混剪系统需要团队投入大量时间和资源来研发和维护&#xff0c;并且能够根据用户需求定制和优化功能&#xff0…

6 中断概览(STM32HAL库)

目录 中断概览 STM32异常和中断介绍 STM32的异常一览 STM32的中断表一览 中断的优先级 中断的优先级分组 优先级分组 嵌套向量中断控制器(NVIC)功能 中断概览 什么是中断&#xff1f; 中断是指计算机运行过程中&#xff0c;出现某些意外情况需主机干预时&#xff0c;机器…

如何理解Spring Bean?

文章目录 一、什么是 Spring Bean&#xff1f;二、定义Spring Bean 有哪些方式&#xff1f;三、Spring 容器是如何加载 Bean 的&#xff1f; 我一共分三段来介绍&#xff0c;首先&#xff0c;介绍什么是 Spring Bean&#xff1f;然后&#xff0c;定义Spring Bean 有哪些方式&am…

typescript Constructor Set requires ‘new‘

使用typescript的class继承时报错 “构造函数集需要’new’” ts代码 class MySet extends Set {constructor() {super();}let myset new MySet();控制台错误 只需要在tsconfig.json文件中添加以下配置即可 "compilerOptions": {"target": "es6…

面试常问 什么是回表?为什么需要回表?

小伙伴们在面试的时候&#xff0c;有一个特别常见的问题&#xff0c;那就是数据库的回表。什么是回表&#xff1f;为什么需要回表&#xff1f; 索引结构 要搞明白这个问题&#xff0c;需要大家首先明白 MySQL 中索引存储的数据结构。这个其实很多小伙伴可能也都听说过&#xf…

SQL方言:传统关系型数据库下的方言对比

前言&#xff1a; 技术多元化是一个趋势&#xff0c;多语言并存&#xff0c;多数据库适配&#xff0c;多环境兼容>< 场景&#xff1a; 当从SQL Server数据库迁移到MySql数据库或者Oracle数据库&#xff0c;甚至国产化数据库&#xff0c;不同数据库之间可以自定义切换&…

实现firebase FCM和Analytics

前提&#xff1a;1.需要vpn 2.带有google 服务的手机 注意&#xff01;&#xff01;&#xff01; 这个在2023年6月30日时还是测试版&#xff0c;所以手机有概率接收不到消息 编写代码前需要在https://console.firebase.google.com/ 配置好参数 这里的token值需要填写代码内的i…

macOS 系统 安装 Kafka 快速入门

博主 默语带您 Go to New World. ✍ 个人主页—— 默语 的博客&#x1f466;&#x1f3fb; 《java 面试题大全》 &#x1f369;惟余辈才疏学浅&#xff0c;临摹之作或有不妥之处&#xff0c;还请读者海涵指正。☕&#x1f36d; 《MYSQL从入门到精通》数据库是开发者必会基础之…

神策(Android)- 集成基础埋点的整个过程

记得最早以前都是用友盟全家桶&#xff0c;埋点是用友盟&#xff0c;推送也是用友盟&#xff1b;但是近俩年我参与开发的app&#xff0c;埋点都是用神策、推送都是用极光私服&#xff0c;分享都是去对应集成对应平台的SDK 神策篇 神策&#xff08;Android&#xff09;- 集成基…

2023-6-30-第十二式组合模式

&#x1f37f;*★,*:.☆(&#xffe3;▽&#xffe3;)/$:*.★* &#x1f37f; &#x1f4a5;&#x1f4a5;&#x1f4a5;欢迎来到&#x1f91e;汤姆&#x1f91e;的csdn博文&#x1f4a5;&#x1f4a5;&#x1f4a5; &#x1f49f;&#x1f49f;喜欢的朋友可以关注一下&#xf…