从0开始深度学习(6)——Pytorch动态图机制(前向传播、反向传播)

news2024/11/17 14:20:51

PyTorch 的动态计算图机制是其核心特性之一,它使得深度学习模型的开发更加灵活和高效。

0 计算图

计算图(Computation Graph)是一种用于表示数学表达式或程序流程的图形结构,可以将复杂的表达式分解成一系列简单的操作,并以节点和边的形式展示这些操作及其之间的关系,能够清晰地展示计算过程中的依赖关系

  • 节点(Nodes): 表示变量或常量,也可以表示操作(如加法、乘法等)。
  • 边(Edges) :表示数据流的方向,即一个操作的结果如何作为输入传递给下一个操作。

举例说明

假设我们有如下的表达式:
x = a + b x=a+b x=a+b
y = c ∗ d y=c*d y=cd
z = x + y z=x+y z=x+y
在这里插入图片描述

其中 a , b , c , d a,b,c,d a,b,c,d是输入变量, z z z是输出变量。

假设令 a = 1 , b = 2 , c = 3 , d = 4 a=1,b=2,c=3,d=4 a=1,b=2,c=3,d=4,则可以根据上面的计算图,计算出 z z z,计算过程如下:
x = a + b = 1 + 2 = 3 x=a+b=1+2=3 x=a+b=1+2=3
y = c ∗ d = 3 ∗ 4 = 12 y=c*d=3*4=12 y=cd=34=12
z = x + y = 12 + 3 = 5 z=x+y=12+3=5 z=x+y=12+3=5

1 前向传播、反向传播

但是在机器学习中,计算图不仅展示了计算路径,还为计算梯度提供了基础,这里涉及到前向传播和反向传播

1.1 前向传播(计算输出)

根据当前的模型参数计算网络的输出,前向传播的输出作用是:

  • 评估模型性能: 通过比较网络的预测值与实际值,可以计算损失函数的值,从而评估模型当前的性能。
  • 准备反向传播: 前向传播计算出的中间结果(例如每个隐藏层的输出)是反向传播中计算梯度所必需的。上面的计算过程就是前向传播。
    在这里插入图片描述

1.2 反向传播(计算梯度)

目的是计算损失函数关于每个模型参数的梯度。这些梯度信息用于更新模型参数,使模型在下一次迭代中表现得更好,关于梯度是什么,在从0开始深度学习(2)——自动微分解释了梯度的概念和计算过程。
在这里插入图片描述

现在依然以上面的例子举例,假设我么们有一个损失函数 L L L,并且 L L L关于 z z z的梯度已知为 ∂ L ∂ z = 1 \frac{\partial L}{\partial z}=1 zL=1(通常都这样设置,且适用于多种函数),目标是计算 a , b , c , d a,b,c,d a,b,c,d关于 L L L的梯度,反着计算。
在这里插入图片描述

1 静态计算图

静态图是通过先定义后运行的方式,先搭建图,然后再输入数据进行计算,典型代表是Tensorflow 1.0 版本,Tensorflow名字的来源就是因为张量Tensor在预先定义的图中流动(Flow)
在这里插入图片描述

2 动态计算图

动态图是指计算图的运算和搭建同时运行,也就是可以先计算前面的节点的值,再根据这些值搭建后面的图,如下图所示:
在这里插入图片描述

2.1核心概念

2.1.1 张量

参考从0开始深度学习(1)—— 代码实现线性代数中的概念解释

2.1.2 自动微分

pytorch框架中借助自动微分机制来实现动态计算图,意味着计算图是在运行时动态构建的,它使得计算复杂模型的梯度变得非常方便。

  • 计算图在每次前向传播时重新构建。这意味着每个操作都会记录下来,并在反向传播时按需计算梯度。
  • 每个张量(Tensor)都有一个 .grad 属性来存储梯度,以及一个 .grad_fn 属性来记录创建该张量的操作。

关键属性:

  1. data:包含tensor的实际数据
  2. grad:存储tensor的梯度。在开始训练前,我们会使用zero_grad()把梯度清零,第一次反向传播后, x.grad会包含当前损失函数关于x的梯度,如果第二次反向传播前不清零,则新的梯度会累加x.grad,即包含两次反向传播的梯度之和
  3. grad_fn:记录了创建tensor的操作,表明该tensor是通过何种前向传播计算出来的,例如x是通过加法计算出来的,则x.grad_fn指向一个加法操作的函数对象

2.2 举例说明

依然以上述例子为例,我先举一个静态图的例子,是在TensorFlow 1.x 中,需要提前定义好整个计算图,然后在会话中执行:

import tensorflow as tf

# 定义输入和参数
a = tf.placeholder(tf.float32, shape=[])
b = tf.placeholder(tf.float32, shape=[])
c = tf.placeholder(tf.float32, shape=[])
d = tf.placeholder(tf.float32, shape=[])

# 前向传播,此处无法查看,只能在Session中查看
x = a + b
y = c * d
z = x + y

# 定义损失函数
L = z

# 定义梯度
grad_a, grad_b, grad_c, grad_d = tf.gradients(L, [a, b, c, d])

# 创建会话
with tf.Session() as sess:
    # 前向传播
    print("Starting forward pass...")
    x_val, y_val, z_val = sess.run([x, y, z], feed_dict={a: 1.0, b: 2.0, c: 3.0, d: 4.0})
    print(f"x: {x_val}, grad_fn: N/A")
    print(f"y: {y_val}, grad_fn: N/A")
    print(f"z: {z_val}, grad_fn: N/A")

    # 打印前向传播的结果
    print(f"x: {x_val}")  # 3
    print(f"y: {y_val}")  # 12
    print(f"z: {z_val}")  # 15

    # 反向传播
    print("Starting backward pass...")
    grad_a_val, grad_b_val, grad_c_val, grad_d_val = sess.run([grad_a, grad_b, grad_c, grad_d], feed_dict={a: 1.0, b: 2.0, c: 3.0, d: 4.0})

    # 打印梯度
    print(f"a.grad: {grad_a_val}")  # 1
    print(f"b.grad: {grad_b_val}")  # 1
    print(f"c.grad: {grad_c_val}")  # 4
    print(f"d.grad: {grad_d_val}")  # 3

接下来举一个动态图的例子
pytorch动态图例子:

import torch

# 定义输入和参数
a = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([2.0], requires_grad=True)
c = torch.tensor([3.0], requires_grad=True)
d = torch.tensor([4.0], requires_grad=True)

# 前向传播,在定义操作中,就可以查看梯度和操作,不像静态图中,必须在Session中才能查看
print("Starting forward pass...")
x = a + b
print(f"x: {x.item()}, grad_fn: {x.grad_fn}")
# 输出结果:x: 3.0, grad_fn: <AddBackward0 object at 0x000002113E68B790>

y = c * d
print(f"y: {y.item()}, grad_fn: {y.grad_fn}")
# 输出结果:y: 12.0, grad_fn: <MulBackward0 object at 0x000002113D43C070>

z = x + y
print(f"z: {z.item()}, grad_fn: {z.grad_fn}")
# z: 15.0, grad_fn: <AddBackward0 object at 0x000002113E68B790>

# 打印前向传播的结果
print(f"x: {x.item()}")  # 3
print(f"y: {y.item()}")  # 12
print(f"z: {z.item()}")  # 15
'''
输出结果:
x: 3.0
y: 12.0
z: 15.0
'''

# 假设损失函数 L = z
L = z

# 反向传播
print("Starting backward pass...")
L.backward()

# 打印梯度
print(f"a.grad: {a.grad.item()}")  # 1
print(f"b.grad: {b.grad.item()}")  # 1
print(f"c.grad: {c.grad.item()}")  # 4
print(f"d.grad: {d.grad.item()}")  # 3
'''
输出结果:.
a.grad: 1.0
b.grad: 1.0
c.grad: 4.0
d.grad: 3.0
'''

3 利用动态图机制,修改流程

我们这里希望根据 a a a的值来决定计算流程,详情看注释

import torch

# 定义输入和参数
a = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([2.0], requires_grad=True)

# 采用简洁的代码,可以动态的修改计算流程,控制更灵活
if a.item() > 0:
    x = a + b
    print(f"x: {x.item()}")  # 3
else:
    x = a - b

# 前向传播
y = x * 2
print(f"y: {y.item()}")  # 6

# 反向传播
y.backward()

# 打印梯度
print(f"a.grad: {a.grad.item()}")  # 2
print(f"b.grad: {b.grad.item()}")  # 2

如果是在静态图中需要通过控制流来决定,与动态图相比,操作更复杂,不直观

# 条件分支
x = tf.cond(a > 0, lambda: a + b, lambda: a - b)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2173521.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Materials Studio零基础专题培训重磅来袭

一、软件介绍 Materials Studio是一款由美国Accelrys公司开发的新一代材料计算软件&#xff0c;专为材料科学领域的研究者设计&#xff0c;能够运行在PC上进行各种模拟研究。以下是对它的具体介绍&#xff1a; 1. 软件简介 定义与用途&#xff1a;Materials Studio是一款专门…

基于SpringBoot+Vue的智慧博物馆管理系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码 精品专栏&#xff1a;Java精选实战项目…

MapReduce学习与理解

MapReduce为google分布式三驾马车之一。分别为《The Google File System》、《MapReduce: Simplified Data Processing on Large Clusters》、《Bigtable: A Distributed Storage System for Structured Data》。三遍论文奠定了分布式存储和计算的基础。本篇文章来说说mapreduc…

在 commit 里使用 emoji~

在 git commit 上使用 emoji 提供了一种简单的方法&#xff1a;仅通过查看所使用的 emoji 来确定提交的目的或意图&#xff0c;非常好理解&#xff0c;阅读体验很棒。 ‍ 效果 以我的 博客项目 为例&#xff0c;可以看到不少的 emoji &#xff1a; ‍ ‍ 使用方法 直接在…

【Linux】驱动的基本架构和编译

驱动源码 /** Silicon Integrated Co., Ltd haptic sih688x haptic driver file** Copyright (c) 2021 kugua <daokuan.zhusi-in.com>** This program is free software; you can redistribute it and/or modify it* under the terms of the GNU General Public Licen…

python基础库

文章目录 1.研究目的2.platform库介绍3.代码4.结果展示 1.研究目的 最近项目中需要利用python获取计算机硬件的一些基本信息,查阅资料,.于是写下这篇简短的博客,有问题烦请提出,谢谢-_- 2.platform库介绍 platform 库是 Python 的一个内置库&#xff0c;可以让我们轻松地获取…

京东面试:RR隔离mysql如何实现?什么情况RR不能解决幻读?

尼恩说在前面 在40岁老架构师 尼恩的读者交流群(50)中&#xff0c;最近有小伙伴拿到了一线互联网企业如得物、阿里、滴滴、极兔、有赞、希音、百度、网易、美团的面试资格&#xff0c;遇到很多很重要的面试题&#xff1a; 谈谈&#xff1a;mysql 事务隔离的底层原理&#xff1…

Cilium + ebpf 系列文章- (七)Cilium-BGP-自定义定时器-ebgp多跳-优雅重启-MD5加密-传播团体字

一、自定义定时器 这里指的是自定义: Keepalive Interval: 缺省值为30秒。Keepalive用于维护邻居关系&#xff0c;如果在协商的保持时间内没有收到Keepalive消息&#xff0c;则BGP将断开邻居连接。 Hold Time:缺省值是Keepalive时间的3倍&#xff0c;即90秒。这是BGP在关闭连…

spark计算引擎-架构和应用

一Spark 定义&#xff1a;Spark 是一个开源的分布式计算系统&#xff0c;它提供了一个快速且通用的集群计算平台。Spark 被设计用来处理大规模数据集&#xff0c;并且支持多种数据处理任务&#xff0c;包括批处理、交互式查询、机器学习、图形处理和流处理。 核心架构&#x…

c++九月27日

1.顺序表 #ifndef ARRAYLIST_H #define ARRAYLIST_H#include <iostream> #include <stdexcept>template <typename T> class ArrayList { private:T* data; // 存储数据的数组int capacity; // 数组容量int size; // 当前元素数量publ…

【linux】基础IO(下)

8. 理解文件系统 8.1. 认识硬件 --- 磁盘 唯一的机械设备&#xff0c;也是一个外设 注意&#xff1a; 磁头是一面一个&#xff0c;磁头和盘面不接触在软件设计上&#xff0c;设计者会有意识地将相关数据放在一起一般来说&#xff0c;运动越少&#xff0c;效率越高&#xff1…

vue单点登录异步执行请求https://xxx.com获取并处理数据

一、请求一个加密地址获取access_token再拼接字符串再次请求 接口返回数据 异步执行请求该地址获取数据并处理 二、请求代码第二步使用 access_token 获取 auth_key // 第二步&#xff1a;使用 access_token 获取 auth_keyconst access_token tokenData.access_token;const …

什么是NAND Flash?

什么是NAND Flash? NAND闪存是一种非易失性存储器技术&#xff0c;它彻底改变了数字时代的数据存储。它是闪存的一种形式&#xff0c;这意味着它可以被电擦除和重新编程。NAND闪存以NAND&#xff08;NOT-AND&#xff09;逻辑门命名&#xff0c;该逻辑门用于其基本架构。术语“…

服务运营 | 竞价风暴:在线广告交易的实时拍卖与定价艺术

编者按&#xff1a; 在广告交易领域&#xff0c;尤其是谷歌等平台的广告交易中&#xff0c;每一次广告展示——即向特定浏览者展示广告的机会——都是由出版商&#xff08;publisher&#xff09;&#xff0c;例如《纽约时报》网站&#xff0c;通过实时拍卖的方式出售给广告商。…

中航通用飞机社招入职笔试:SHL题库综合能力性格问卷题型分析、高分攻略

中航通用飞机有限责任公司是中国航空工业集团有限公司旗下的大型国有企业&#xff0c;专注于通用航空产品的研制、通航运营与服务、航空零部件制造等业务。公司注册资本133.66亿元人民币&#xff0c;总资产约667亿元&#xff0c;员工人数超过16000人。产品线丰富&#xff0c;包…

8.使用 VSCode 过程中的英语积累 - Help 菜单(每一次重点积累 5 个单词)

前言 学习可以不局限于传统的书籍和课堂&#xff0c;各种生活的元素也都可以做为我们的学习对象&#xff0c;本文将利用 VSCode 页面上的各种英文元素来做英语的积累&#xff0c;如此做有 3 大利 这些软件在我们工作中是时时刻刻接触的&#xff0c;借此做英语积累再合适不过&a…

C# 用Timer控件简单写一个倒计时60s功能

先放界面上一个Label和一个Timer控件&#xff0c;Label用来展示倒计时秒数 添加事件 设置属性&#xff0c;设置每隔一秒执行一次 放代码&#xff1a; //设置时间控件开始运行&#xff0c;具体放在哪里看具体需求 this.timer1.Start();//定义一个全局变量表示秒数 int time…

【手机直连卫星】除了华为Mate 60 Pro,支持卫星通信的手机还有哪些款

2023年底&#xff0c;华为推出的Mate 60 Pro手机&#xff0c;开创了智能手机卫星通信的新纪元。它支持卫星电话通话和短信功能&#xff0c;让用户即使在偏远山野或深海之上也能保持与外界的联系。这一技术的加入&#xff0c;无疑为户外探险者和遥远地区的工作者提供了难以估量的…

影院管理革新:小徐的Spring Boot应用

第二章开发技术介绍 2.1相关技术 小徐影城管理系统是在Java MySQL开发环境的基础上开发的。Java是一种服务器端脚本语言&#xff0c;易于学习&#xff0c;实用且面向用户。全球超过35&#xff05;的Java驱动的互联网站点使用Java。MySQL是一个数据库管理系统&#xff0c;因为它…

港科夜闻 | 香港科大颁授荣誉大学院士予五位杰出人士

关注并星标 每周阅读港科夜闻 建立新视野 开启新思维 1、香港科大颁授荣誉大学院士予五位杰出人士。香港科大9月24日向五位杰出人士颁授荣誉大学院士&#xff0c;他们分别为包弼德教授、简吴秋玉女士、高秉强教授、吴永顺先生及容永祺博士(按姓氏英文字母排序)。荣誉大学院士颁…