《PyTorch深度学习实践》第四讲 反向传播

news2025/1/12 18:06:14

b站刘二大人《PyTorch深度学习实践》课程第四讲反向传播笔记与代码:https://www.bilibili.com/video/BV1Y7411d7Ys?p=4&vd_source=b17f113d28933824d753a0915d5e3a90


image-20230629141914379

对于上述简单的模型可以用解析式来做,但是对于复杂模型而言,如下图每个圆圈中都有一个权重,诶个写解析式求解极其麻烦,几乎不可能

image-20230629142305685

计算图Computational Graph

  • 面对复杂网络,将其看作一个图,在图上来传播梯度,最终根据链式法则将其求出

  • 以一个两层神经网络为例

    image-20230629151310427
  • MM:Matrix Multiplication,矩阵乘法;ADD:向量加法

  • 其中绿色模块就是计算模块

    image-20230629151358261
  • 注意:上面给出的神经网络 y ^ \hat{y} y^可以对其进行展开,展开之后就会发现如果就这样一直线性展开,不管有多少层,最终得到的都能统一为一层,这样的话层数多和层数少都没区别

    image-20230629153104333
    • 解决方法:对每一层的输出加一个非线性函数,使其不能化简

      image-20230629153413641

链式法则(Chain Rule)

image-20230629153810448

步骤:

  1. Create Computational Graph(Forward)

    • 前馈运算
      • 从输入 x x x沿着边向最终的loss计算
    image-20230629155842728
  2. Local Gradient

    • 函数 f f f是用于计算输出 Z Z Z关于输入 x x x和权重 w w w的导数
    • Z = f ( x , w ) Z = f(x,w) Z=f(x,w)
    image-20230629160555055
  3. Given gradient from successive node

    • 对于输出结果 Z Z Z而言,首先要拿到最终的损失函数Loss对它的偏导 ∂ L ∂ z \frac{\partial L}{\partial z} zL
      • 该偏导是从Loss传回来的。先是从最初的输入 x x x通过前馈一步一步计算到最终的损失函数Loss(前馈过程),然后再从Loss开头一步一步往回算(反馈过程)
    image-20230629161338821
  4. **Use chain rule to compute the gradient (Backward) **

    • 拿到 ∂ L ∂ z \frac{\partial L}{\partial z} zL后,经过计算 f f f,我们的目标是得到Loss关于输入 x x x和权重 w w w的偏导,这一计算过程需使用上链式法则
    image-20230629161731489

实例一:

  • f = x ⋅ w f = x · w f=xw,令 x = 2 , w = 3 x = 2, w = 3 x=2,w=3
    • 输出Z关于输入 x x x的偏导: ∂ Z ∂ x = ∂ x ⋅ w ∂ x = w \frac{\partial Z}{\partial x} = \frac{\partial x·w}{\partial x} = w xZ=xxw=w
    • 输出Z关于权重 w w w的偏导: ∂ Z ∂ w = ∂ x ⋅ w ∂ w = x \frac{\partial Z}{\partial w} = \frac{\partial x·w}{\partial w} = x wZ=wxw=x
image-20230629161856196

前馈过程:

  • x x x w w w计算得到 z z z
image-20230629162015774

反馈过程:

  • 假设由前一步得到最终的损失函数Loss对 Z Z Z的偏导为5
image-20230629162559054

线性模型 y ^ = x ∗ w \hat{y} = x * w y^=xw的计算图

  • 模型计算图

    • 假设输入 x = 1 x = 1 x=1,权重 w = 1 w = 1 w=1,那么 y ^ = x ∗ w = 1 \hat{y} = x * w = 1 y^=xw=1
    image-20230629163941595
  • loss计算图, l o s s = ( y ^ − y ) 2 = ( x ⋅ w − y ) 2 loss = (\hat{y} - y)^2 = (x·w - y)^2 loss=(y^y)2=(xwy)2

    • ( y ^ − y ) (\hat{y} - y) (y^y)称为残差项(residual),记为 r = y ^ − y r = \hat{y} - y r=y^y
    image-20230629164743921
    • 假设 y = 2 y = 2 y=2,那么残差项 r = y ^ − y = − 1 r = \hat{y} - y = -1 r=y^y=1,残差对 y ^ \hat{y} y^的导数 ∂ y ^ − y ∂ y ^ = 1 \frac{\partial \hat{y} - y}{\partial \hat{y}} = 1 y^y^y=1

上述即为前馈过程:

image-20230629164931663
  • 不仅能计算出到下一步的输出值,还能够计算出局部的梯度

接下去就是反馈过程:

  • 首先求损失函数loss关于残差项 r r r的偏导, l o s s = r 2 loss = r^2 loss=r2

    image-20230629165221255
  • 再计算损失关于 y ^ \hat{y} y^的偏导

    image-20230629165314090
  • 最终得到损失关于权重 w w w的偏导

    image-20230629165406034

完整的计算图:

image-20230629165438941

PyTorch中如何实现前馈和反馈计算???

image-20230629165755950
import torch    # 导入pytorch库

# 训练集
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 权重
w = torch.Tensor([1.0])     # 使用pytorch中的Tensor进行定义赋值
w.requires_grad = True      # 表示需要计算梯度,默认的是False,即不计算梯度


def forward(x):
    # 定义模型:y_hat = x * w,其中w是一个张量Tensor,因此该乘法*被重载了,变成了Tensor之间的数乘
    # x需要为Tensor,如果不是,则会自动转换成Tensor
    return x * w


def loss(x, y):
    # 定义损失函数loss function
    y_pred = forward(x)
    return (y_pred - y) ** 2   # (y_hat - y)^2


print('Predict (before training)', 4, forward(4).item())

for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)  # 计算loss,结果是Tensor(前馈)
        l.backward()    # 反馈,l是Tensor
        print("\tgrad: ", x, y, w.grad.item())  # item是将梯度中的数值取出来作为一个标量
        # w是一个张量,包含data和grad,其中grad也是张量,因此是要取grad的数据data
        w.data = w.data - 0.01 * w.grad.data
        # .data是Tensor操作,.item()是将数值取出来当成标量使用

        w.grad.data.zero_()  # 更新完成后将梯度清零,否则会被累加到下一轮训练

    print("progress: ", epoch, l.item())

print('Predict (after training)', 4, forward(4).item())
image-20230629203721119 image-20230629203721119 image-20230629201909191 image-20230629203806720

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/700936.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

详解线程池的作用和实际应用以及拒绝策略

目录 线程池的作用? 线程池的意义: 线程池的参数 ​编辑 线程池任务执行的顺序 线程池拒绝策略 四种策略 应用场景分析 AbortPolicy DiscardPolicy DiscardOldestPolicy CallerRunsPolicy 线程池的作用? 优化系统架构通常包括在时间…

亚马逊平台买家注册流程

在亚马逊平台注册买家号是比较简单的。以下是亚马逊买家注册流程: 1、打开亚马逊网站:访问亚马逊的官方网站,如果要注册美国买家号,那么网址就是www.amazon.com。 2、点击"注册":在亚马逊首页的右上角&…

LLM - 搭建 ProteinGPT 结合蛋白质结构 PDB 知识的行业 ChatGPT 系统

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/131403263 论文:ProteinChat: Towards Enabling ChatGPT-Like Capabilities on Protein 3D Structures 工程:ht…

C语言学习(二十九)---内存操作函数

在上一节内容中,我们学习了有关字符串操作的函数,其中分为了限制长度和不限制长度两种方式,虽然上节内容已经在很大程度上有助于程序的实现,但是其有一个致命的缺陷,聪明的你一定已经猜到了吧,对的&#xf…

Linux 网络通信C/S、TCP/IP、Socket 最全详解( 9 ) -【Linux通信架构系列 】

系列文章目录 C技能系列 Linux通信架构系列 C高性能优化编程系列 深入理解软件架构设计系列 高级C并发线程编程 期待你的关注哦!!! 现在的一切都是为将来的梦想编织翅膀,让梦想在现实中展翅高飞。 Now everything is for the…

【算法题】动态规划中级阶段之不同的二叉搜索树、交错字符串

动态规划中级阶段 前言一、不同的二叉搜索树1.1、思路1.2、代码实现 二、不同的二叉搜索树 II2.1、思路2.2、代码实现 三、交错字符串3.1、思路3.2、代码实现 总结 前言 动态规划(Dynamic Programming,简称 DP)是一种解决多阶段决策过程最优…

Pycharm中成功配置PyQt5(External Tools),设计好界面直接生成python代码

1、安装PyQt5和PyQt5-tools 在Pycharm中设置好Python环境,点击File-Settings-Project-Python Interpreter 设置好后退出,点击窗口下的Terminal,输入 # 直接安装输入pip install pyqt5,如果太慢可以用国内镜像源,若出…

【C++实现二叉树的遍历】

目录 一、二叉树的结构二、二叉树的遍历方式三、源码 一、二叉树的结构 二、二叉树的遍历方式 先序遍历: 根–>左–>右中序遍历: 左–>根–>右后序遍历:左–>右–>根层次遍历:顶层–>底层 三、源码 注&am…

SpringBoot04:JSR303数据校验及多环境切换

目录 一、JSR303数据校验 1、如何使用? 2、常见参数 二、多环境切换 1、多配置文件 2、yaml的多文档块 3、配置文件加载位置 一、JSR303数据校验 1、如何使用? SpringBoot中可以用Validated来校验数据,如果数据异常则会统一抛出异常…

python篇---统计列表中每个数字的出现次数

python篇—统计列表中每个数字的出现次数 # -*- coding: utf-8 -*- from collections import Counterlst [1, 2, 3, 3, 4, 1, 2, 5, 5, 5] count Counter(lst) print(每个数字在列表中的出现次数:, count) # 再将collections.Counter格式转换成dict print(dict(c…

C# 难点语法讲解之虚方法(virtual)和隐藏方法的区别---从应用需求开始讲解

这里不单独讲虚方法和隐藏方法是什么&#xff0c;很多文章都有讲&#xff0c;这里只讲他们的区别和应用理解。 另外&#xff1a;如果你不懂MonoBehaviour就别管他&#xff0c; Debug.Log就是Console.WriteLine <一>、隐藏方法 一、隐藏方法的背景故事 从前有个了不起…

C++ | 多线程资源抢占bug解决

多线程资源抢占bug解决 文章目录 多线程资源抢占bug解决bug说明原因排查解决经验>>>>> 欢迎关注公众号【三戒纪元】 <<<<< bug说明 最近调试程序&#xff0c;程序在Release版本下可运行&#xff0c;一直没有问题&#xff0c;在Debug模式下编译后…

Leetcode42 接雨水

给定 n 个非负整数表示每个宽度为 1 的柱子的高度图&#xff0c;计算按此排列的柱子&#xff0c;下雨之后能接多少雨水。 输入&#xff1a;height [0,1,0,2,1,0,1,3,2,1,2,1] 输出&#xff1a;6 解释&#xff1a;上面是由数组 [0,1,0,2,1,0,1,3,2,1,2,1] 表示的高度图&#xf…

【javascript】防止内容被复制

在JavaScript中&#xff0c;我们可以使用onselectstart事件来防止页面内容被选取。此时无法选取所要的内容。 代码&#xff1a; <!DOCTYPE html> <html><head><meta charset"utf-8"><script>window.onload function() {document.bod…

计算机网络————网络层

文章目录 网络层设计思路IP地址IP地址分类IP地址与硬件地址 协议ARP和RARPIP划分子网和构造超网划分子网构造超网&#xff08;无分类编址CIDR&#xff09; ICMP 虚拟专用网VPN和网络地址转换NATVPNNAT 网络层设计思路 网络层向上只提供简单灵活的、无连接的、尽最大努力交付的数…

mysql表中出现特殊符号(逗号,点号),如何进行查询或操作

mysql表中出现特殊符号&#xff08;逗号&#xff0c;点号&#xff09;&#xff0c;如何进行查询或操作 一、背景说明二、需要把表"引"起来&#xff0c;tab键上面的那个按钮&#xff0c;不是引号 一、背景说明 当mysql表名中出现如点号&#xff08;.&#xff09;&…

安装并使用docker

1、安装docker 1.1、更新现有的包列表&#xff1a; sudo apt update 1.2、用apt安装一些允许通过HTTPS才能使用的软件包&#xff1a; sudo apt install apt-transport-https ca-certificates curl software-properties-common 1.3、将官方Docker存储库的GPG密钥添加到您的系统…

二.Elasticsearch进阶

建议从这里开始看&#xff1a;Elasticsearch快速入门及使用 Elasticsearch进阶 一.Elasticsearch检索方式1.uri 检索参数(不常用)2.uri 请求体(常用&#xff0c;也叫Query DSL) 二.Query DSL语法举例1.match全文匹配2.match_phrase短语匹配3.multi_match多字段匹配4.bool复合…

在Gradio中创建交互式代码编辑器:介绍Code模块和其功能

❤️觉得内容不错的话&#xff0c;欢迎点赞收藏加关注&#x1f60a;&#x1f60a;&#x1f60a;&#xff0c;后续会继续输入更多优质内容❤️ &#x1f449;有问题欢迎大家加关注私戳或者评论&#xff08;包括但不限于NLP算法相关&#xff0c;linux学习相关&#xff0c;读研读博…

Matlab评价模型--灰色关联度分析

评价模型–灰色关联度分析 灰色关联度分析 基本思想 灰色关联分析的基本思想 是根据序列曲线几何形状的相似程度来判断其联系是否紧密&#xff0c;曲线越接近&#xff0c;相应序列之间的关联度就越大&#xff0c;反之则越小。 此方法可用于 进行系统分析&#xff0c;也可应用…