PyTorch实例:简单线性回归的训练和反向传播解析

news2024/10/7 6:42:38

文章目录

  • 🥦引言
  • 🥦什么是反向传播?
  • 🥦反向传播的实现(代码)
  • 🥦反向传播在深度学习中的应用
  • 🥦链式求导法则
  • 🥦总结

🥦引言

在神经网络中,反向传播算法是一个关键的概念,它在训练神经网络中起着至关重要的作用。本文将深入探讨反向传播算法的原理、实现以及在深度学习中的应用。

🥦什么是反向传播?

反向传播(Backpropagation)是一种用于训练神经网络的监督学习算法。它的基本思想是通过不断调整神经网络中的权重和偏差,使其能够逐渐适应输入数据的特征,从而实现对复杂问题的建模和预测。

反向传播算法的核心思想是通过计算损失函数(Loss Function)的梯度来更新神经网络中的参数,以降低预测值与实际值之间的误差。这个过程涉及到两个关键步骤:前向传播(Forward Propagation)和反向传播。

  • 前向传播(forward):在前向传播过程中,输入数据通过神经网络,每一层都会进行一系列的线性变换和非线性激活函数的应用,最终得到一个预测值。这个预测值会与实际标签进行比较,得到损失函数的值。

  • 反向传播(backward):在反向传播过程中,我们计算损失函数相对于网络中每个参数的梯度。这个梯度告诉我们如何微调每个参数,以减小损失函数的值。梯度下降算法通常用于更新权重和偏差。

🥦反向传播的实现(代码)

要实现反向传播,我们需要选择一个损失函数,通常是均方误差(Mean Squared Error)或交叉熵(Cross-Entropy)。然后,我们计算损失函数相对于每个参数的偏导数(梯度)。这可以使用链式法则来完成,从输出层向后逐层传递。

接下来,我们使用梯度下降或其变种来更新权重和偏差。梯度下降的核心思想是沿着梯度的反方向调整参数,以降低损失函数的值。这个过程不断迭代,直到损失函数收敛到一个较小的值或达到一定的迭代次数。

在代码实现前,我能先了解一下反向传播是怎么个事,下文主要以图文的形式进行输出
这里我们回顾一下梯度,首先假设一个简单的线性模型
在这里插入图片描述
接下来,我们展示一下什么是前向传播(其实就是字面的意思),在神经网络中通常以右面的进行展示,大概意思就是输入x与权重w进行乘法运算,得到了y’
在这里插入图片描述
下图是随机梯度下降的核心公式以及损失函数的导数
在这里插入图片描述
下图是一个两层的神经网络
在这里插入图片描述
如果以图画的形式理解可以从下图进行理解
首先还是和之前的一样,进行输入和权重的矩阵乘法(这里刘二大人推荐一个查询书籍MatrixCookbook)
在这里插入图片描述
之后引入b,不理解的小伙伴可以当做截距
在这里插入图片描述
那么下图框框里面的就是一层神经网络
在这里插入图片描述
那么两层也就可以清晰的得到了,最后得到了y’
在这里插入图片描述

刚刚的描述过于笼统,接下来详细介绍一下前向和后向
在前向传播运算中,f里面进行了z对x和w的偏导求解
在这里插入图片描述
在反向传播里,损失loss对z的偏导,以及经过f后,求得loss对x和w的偏导。按理说我们只用权重w,但是如果x是上一层的输出(多层神经网络)那就需要了,至于loss对x和w的偏导怎么求参考结尾的链式求导法则
在这里插入图片描述

接下来我们可以假设x=2,w=3,手动的求解loss对x和w的偏导,求完就可以对权重的更新了
在这里插入图片描述
也可以从如下的计算图进行清晰的展示前后向传播
在这里插入图片描述
如果x=2,y=4,我写了一下如果错了欢迎指正
在这里插入图片描述
这里粗略的解释一下pytorch中的tensor,大概意思是它重要,其中还有包含了可以存储数值的data和存储梯度的grad
在这里插入图片描述

w.requires_grad = True # 默认是不自动计算梯度,需自行设计

如下是完整的代码(带注释)

import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = torch.Tensor([1.0])
w.requires_grad = True

def forward(x):
    return x * w # 这里的权重w是tensor
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

print("predict (before training)",  4, forward(4).item())
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)  # 前向传播
        l.backward()  # 后向传播
        print('\tgrad:', x, y, w.grad.item())  # item是为了防止计算图
        w.data = w.data - 0.01 * w.grad.data  # 注意不要直接取grad,因为这也属于重新创建计算图,只要值就好
        w.grad.data.zero_()  # 注意要清零否者会造成loss对w的导数一直累加,下图说明
    print("progress:", epoch, l.item())
print("predict (after training)", 4, forward(4).item())
  • 循环进行模型训练,这里设置了100个训练周期(epochs)。

  • 在每个周期内,遍历输入数据 x_data 和对应的目标数据 y_data。

  • 对于每个数据点,计算前向传播,然后进行反向传播以计算梯度。

  • 打印出每次反向传播后权重 w 的梯度值。

  • 更新权重 w,使用梯度下降法更新参数,以最小化损失函数。

  • 在更新权重之前,使用 .grad.data.zero_() 来清零梯度,以防止梯度累积。

  • .item() 的作用是将张量中的值提取为Python标量,以便进行打印

在这里插入图片描述
运行结果如下
在这里插入图片描述

🥦反向传播在深度学习中的应用

反向传播算法在深度学习中具有广泛的应用,它使神经网络能够学习复杂的特征和模式,从而在图像分类、自然语言处理、语音识别等各种任务中取得了显著的成就。

以下是反向传播在深度学习中的一些应用:

  • 图像分类:卷积神经网络(CNNs)使用反向传播来学习图像特征,用于图像分类任务。

  • 自然语言处理:循环神经网络(RNNs)和变换器(Transformers)等模型使用反向传播来学习文本数据的语义表示,用于机器翻译、情感分析等任务。

  • 强化学习:在强化学习中,反向传播可以用于训练智能体,使其学会在不同环境中做出合适的决策。

  • 生成对抗网络:生成对抗网络(GANs)使用反向传播来训练生成器和判别器,从而生成逼真的图像、音频或文本。

🥦链式求导法则

在神经网络中,链式求导法则是一个关键的概念,用于计算神经网络中的权重参数的梯度,从而进行反向传播(backpropagation)算法,这是训练神经网络的核心。下面以一个简单的神经网络为例,说明链式求导法则在神经网络中的应用:

假设我们有一个简单的神经网络,包含一个输入层、一个隐藏层和一个输出层。网络的输出可以表示为:

y = f(g(h(x)))

其中:

x 是输入数据。
h(x) 是隐藏层的激活函数。
g(h(x)) 是输出层的激活函数。
f(g(h(x))) 是网络的最终输出。

我们想要计算损失函数关于网络输出 y 的梯度,以便更新网络的权重参数以最小化损失。使用链式求导法则,我们可以将这个问题分解成多个步骤:

  • 首先,计算损失函数关于网络输出 y 的梯度 ∂L/∂y,其中 L 是损失函数。

  • 接下来,计算输出层的激活函数关于其输入的梯度 ∂g(h(x))/∂h(x)。

  • 然后,计算隐藏层的激活函数关于其输入的梯度 ∂h(x)/∂x。

  • 最后,将这些梯度相乘,得到损失函数关于输入数据 x 的梯度 ∂L/∂x,并用它来更新网络的权重参数。

链式求导法则允许我们将整个过程分解为这些步骤,并在每个步骤中计算局部梯度。这是神经网络中反向传播算法的关键,它允许我们有效地更新网络的参数,以便网络能够学习从输入到输出的复杂映射关系。

🥦总结

反向传播是深度学习中的核心算法之一,它使神经网络能够自动学习复杂的特征和模式,从而在各种任务中取得了巨大的成功。理解反向传播的原理和实现对于深度学习从业者非常重要,它是构建和训练神经网络的基础。希望本文对您有所帮助,深入了解反向传播将有助于更好地理解深度学习的工作原理和应用。

本文根据b站刘二大人《PyTorch深度学习实践》完结合集学习后加以整理,文中图文均不属于个人。

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1062277.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第八章 排序 四、冒泡排序

目录 一、算法思想 二、例子 三、代码实现 四、验证 五、算法性能分析 注意:要分清楚交换次数和移动次数 六、总结 一、算法思想 从后往前,两两比较相邻元素的值,若为逆序,则交换它们的值,直到全部比较完。 二…

typescript: Builder Pattern

/*** file: CarBuilderts.ts* TypeScript 实体类 Model* Builder Pattern* 生成器是一种创建型设计模式, 使你能够分步骤创建复杂对象。* https://stackoverflow.com/questions/12827266/get-and-set-in-typescript* https://github.com/Microsoft/TypeScript/wiki/…

制作 3 档可调灯程序编写

PWM 0~255 可以将数据映射到0 75 150 225 尽可能均匀电压间隔

Python的NumPy库(一)基础用法

NumPy库并不是Python的标准库,但其在机器学习、大数据等很多领域有非常广泛的应用,NumPy本身就有比较多的内容,全部的学习可能涉及许多的内容,但我们在这里仅学习常见的使用,这些内容对于我们日常使用NumPy是足够的。 …

【Python】datetime 库

# timedelta(days, seconds, microseconds,milliseconds, minutes, hours, weeks) 默认按顺序传递参数 # 主要介绍 datetime.datetime 类 # 引入 from datetime import datetime today datetime.now() # 获取当前时间 2023-10-05 15:58:03.218651 today1 datetime.utcnow() #…

经典算法-----汉诺塔问题

前言 今天我们学习一个老经典的问题-----汉诺塔问题,可能在学习编程之前我们就听说过这个问题,那这里我们如何去通过编程的方式去解决这么一个问题呢?下面接着看。 汉诺塔问题 问题描述 这里是引用汉诺塔问题源自印度一个古老的传说&#x…

Ubuntu 22.04 安装Nvidia显卡驱动、CUDA、cudnn

GPU做深度学习比CPU要快很多倍,用Ubuntu跑也有一定的优势,但是安装Nvidia驱动有很多坑 Ubuntu版本:22.04.3 LTS 分区: /boot分配 1G ,剩下都分给根目录/ 显卡:GTX 1050 Ti 坑1:用Ubuntu自带的 …

ESP32上电到app_main()的过程梳理

前言 (1)如果有嵌入式企业需要招聘校园大使,湖南区域的日常实习,任何区域的暑假Linux驱动实习岗位,可C站直接私聊,或者邮件:zhangyixu02gmail.com,此消息至2025年1月1日前均有效 &am…

【单片机】16-LCD1602和12864和LCD9648显示器

1.LCD显示器相关背景 1.LCD简介 (1)显示器,常见显示器:电视,电脑 (2)LCD(Liquid Crystal Display),液晶显示器,原理介绍 (3&#xff…

哈希表的总结

今天刷了力扣的第一题(1. 两数之和 - 力扣(LeetCode)),是一道用暴力解法就可以完成的题目(两个for循环),但是官方解答给出了用哈希表的解法,用空间换时间,时间复杂度从O(…

Jmeter排查正则表达式提取器未生效问题

今天在使用Jmeter的时候遇到一个很简单的问题,使用正则表达式提取token一直未生效,原因是正则表达式中多了一个空格。虽然问题很简单,但是觉得排查问题的方法很普适,所以记录下,也希望能够给遇到问题的大家一个参考。 …

蓝桥杯每日一题2023.10.5

3420. 括号序列 - AcWing题库 题目描述 题目分析 对于这一我们需要有前缀知识完全背包 完全背包的朴素写法&#xff1a; #include<bits/stdc.h> using namespace std; const int N 1010; int n, m, v[N], w[N], f[N][N]; int main() {cin >> n >> m;fo…

MySQL数据库入门到精通——进阶篇(3)

黑马程序员 MySQL数据库入门到精通——进阶篇&#xff08;3&#xff09; 1. 锁1.1 锁-介绍1.2 锁-全局锁1.3 锁-表级锁1.3.1 表级锁-表锁1.3.2 表级锁元数据锁( meta data lock&#xff0c;MDL)1.3.3 表级锁-意向锁1.3.4 表级锁意向锁测试 1.4 锁-行级锁1.4.1 行级锁-行锁1.4.2…

计算机网络 (中科大郑烇老师)笔记(一)概论

目录 0 引言1 什么是Internet&#xff1f;1.1 网络、计算机网络、互联网1.2 什么是Internet&#xff1f;&#xff1a;从服务角度看 2 什么是协议&#xff1f;3 网络的结构&#xff08;子系统&#xff09;3.1 网络边缘3.2 网络核心&#xff1a;分组交换、线路交换3.3 接入网、物…

【13】c++设计模式——>工厂模式

简单工厂模式的弊端 简单工厂模式虽然简单&#xff0c;但是违反了设计模式中的开放封闭原则&#xff0c;即工厂类在数据增加时需要被修改&#xff0c;而我们在设计时对于已经设计好的类需要避免修改的操作&#xff0c;而选用扩展的方式。 工厂模式设计 简单工厂模式只有一个…

天地无用 - 修改朋友圈的定位: 高德地图 + 爱思助手

1&#xff0c;电脑上打开高德地图网页版 高德地图 (amap.com) 2&#xff0c;网页最下一栏&#xff0c;点击“开放平台” 高德开放平台 | 高德地图API (amap.com) 3&#xff0c;在新网页中&#xff0c;需要登录高德账户才能操作。 可以使用手机号和验证码登录。 4&#xff0c…

探秘前后端开发世界:猫头虎带你穿梭编程的繁忙街区,解锁全栈之路

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…

香蕉叶病害数据集

1.数据集 第一个文件夹为数据增强&#xff08;旋转平移裁剪等操作&#xff09;后的数据集 第二个文件夹为原始数据集 2.原始数据集 Cordana文件夹&#xff08;162张照片&#xff09; healthy文件夹&#xff08;129张&#xff09; Pestalotiopsis文件夹&#xff08;173张照片&…

用Python实现一个电影订票系统!

一、整体结构图 二、代码分解 2.1 infos.py 一部电影的详细信息适合用 字典 结构来存储&#xff0c;我们可以给字典里添加多个键值对来保存电影的名称、座位表和宣传时用的字符画&#xff0c;比如电影《泰坦尼克号》的详细信息就可以按下面的形式保存到字典 titanic 中&#…

Tomcat在CentOS上的安装部署

目录 1. Tomcat简介 2. 安装 2.1 安装JDK环境 2.1.1 下载JDK软件 2.1.2 登陆Linux系统&#xff0c;切换到root用户 2.1.3 通过FinalShell&#xff0c;上传下载好的JDK安装包 2.1.4 创建文件夹&#xff0c;用来部署JDK&#xff0c;将JDK和Tomcat都安装部署到&#…