【机器学习300问】59、计算图是如何帮助人们理解反向传播的?

news2024/10/6 4:01:04

        在学习神经网络的时候,势必会学到误差反向传播,它对于神经网络的意义极其重大,它是训练多层前馈神经网络的核心算法,也是机器学习和深度学习领域中最为重要的算法之一。要正确理解误差反向传播,不妨借助一个工具——计算图。

一、计算图是什么?

(1)计算图的定义

        计算图是一种用来表示数学表达式和计算流程的数据结构(有向图),是一个将复杂计算过程可视化的图形模型。计算图中有节点(vertices)和边(edges)的概念,节点通常表示计算操作或者函数,而边表示数据或参数的流动,一个节点的输出可以成为另一个节点的输入。

(2)举例说明

        举例说明一下,假设有这样一个数据表达式 v=a+u 它的计算图可以画成这样:

        现在,在这个操作的前后分别再加上个操作 u=bc 和 J=3v,那么这个新的复杂一点的计算图可以画成这样:

(3)计算图的主要用途和优点

① 直观的表现计算关系

计算图中的节点代表数值或者变量,而边则表示节点之间的依赖关系,例如加法、乘法或其他复杂操作。这种图形化的表示方式直观展示了计算步骤及其顺序。

② 进行局部计算

无论全局计算是多么的复杂,都可以通过局部计算使各个节点致力于简单的计算,从而简化问题。

③ 内存优化和提升计算效率

计算图可以缓存中间计算结果,避免重复计算,提高内存使用效率和减少运行时间。

④ 自动微分和计算梯度

        在训练神经网络时,需要计算损失函数相对于所有权重的梯度来进行优化。计算图能够轻松实现自动微分,因为每一步计算的梯度可以通过链式法则从输出反向传播至输入,这一过程被称为反向传播。

二、计算图怎么实现正向传播?

(1)正向传播的过程

  1. 初始化节点值:在计算图中,输入层的每一个节点都会被赋予一个初始值,这通常是实际输入数据,比如一组特征向量。
  2. 应用权重和偏置:在计算到达隐藏层之前,输入层的节点值与对应的权重(weights)相乘。此外,每个隐藏层节点都会加上一个偏置值(bias)。这个操作对应计算图中的一个节点,其上一层的所有输出都会汇聚到这个点。
  3. 激活函数:在隐藏层节点加上权重和偏置之后,通过一个激活函数(如ReLU或Sigmoid)来引入非线性。激活函数的选择对网络的性能有着直接影响。此操作通常在计算图中表示为另外一个节点。
  4. 传递到下一层:经过激活函数处理后的值被传递到下一层,如果有多个隐藏层,则这一过程会重复进行,直到达到输出层。
  5. 输出层的计算:最后,当信息流到达输出层,相同的过程(加权、求和、添加偏置以及激活)会得到一个输出值。在分类任务中,这个值通常会通过softmax函数被转换成概率分布。
  6. 得出预测结果:这个最终经过激活处理的输出层的值就是网络的预测结果。对于不同的任务,比如分类、回归或其他,输出层会以不同的方式处理这些值来适应特定任务的需求。

(2)正向传播举例

        以逻辑回归任务为例,来介绍一下上述过程的具体实现。先把逻辑回归所用到的公式说明一下:

z=w^Tx+b=w_1x_1+w_2x_2+b

        假设数据样本只有两个特征x_1x_2,为了计算z,我们需要输入参数w_1,w_2,b,因此z的公式如上所示。逻辑回归模型的公式定义如下,其中\hat y是预测值,\sigma(z)是sigmoid函数:

\hat y = a = \sigma (z)

\sigma(z)=\frac{1}{1+e^{-z}}

        逻辑回归的损失函数定义如下,为了简单理解我给出的是单个样本的代价函数:

L(a,y)=-[ylog(a)+(1-y)log(1-a)]

        其中,a是逻辑回归的输出,y是样本的标签值(真实值)。现在我们画出这个计算图。

(3)人话理解正向传播

        正向传播很好理解,就是你把数值带入上面这个图,从左向右正向一次计算每一个节点中的数学表达式,就得到了最终的输出。 虽然很好理解,但我在这里特意说明一下,在正向传播过程中,“传播”的实际上是信号数据(就是你通过节点式子算出来的值)。

        中间怎么算的你可以忽略,正向传播就是在干这样一件事情,输入一组样本加上对应权重和偏置,得到一个最终的损失函数值。进一步理解就是说,不同的一组样本+权重+偏置会得到不同的损失函数值。因为神经网络学习的目的就是得到最优的权重和偏置(w,b合称参数),所以我把每次输入的训练样本都设成一样,但是允许更换许多许多组不同的权重+偏置的组合,就会得到许多许多不同的损失函数值。

        我想要做的事情就是想得知:当损失函数值最小的时候,对应的权重+偏置是多少,此时的参数就是最优参数。要想达成这个目的就需要误差反向传播。

三、计算图怎么实现反向传播?

(1)反向传播的过程

        计算图在实现反向传播的过程中扮演了关键角色,它提供了一种结构化的表示,使得能够高效且准确地计算神经网络中所有参数相对于损失函数的梯度。下面是利用计算图进行反向传播的基本步骤:

  1. 计算损失函数值:反向传播的起点是损失函数,使用预测结果和实际标签计算损失函数的值。损失函数衡量模型预测结果与真实值之间的差距。
  2. 初始化梯度:在损失节点处,对自身的梯度进行初始化,通常这个梯度是1,因为损失函数关于自身的导数是1。初始化所有参数(权重和偏置)的梯度为零。
  3. 利用链式法则计算每个节点梯度:从输出层开始,根据损失函数的梯度(即损失函数对输出层节点的导数),利用链式法则,将计算得到的梯度沿着计算图反向传播到每个权重和偏置。每个节点的梯度是其下游节点梯度与其本地导数的乘积累积而成。
  4. 梯度聚合:在计算图中,一个节点可能有多个输入,每个权重可能会受到多条路径上的梯度影响,需要将所有这些影响累积起来,得到最终的梯度值。
  5. 参数更新:当所有参数的梯度都通过反向传播计算出来后,根据优化算法(如梯度下降法或其变种)使用这些梯度来更新相应的权重和偏置。

(2)反向传播举例

        以一个输入节点的输入层、一个sigmoid作为激活函数的隐藏层,一个输出节点的输出层为例。先给出他的正向传播计算图:

        除了乘法和加法节点之外,还出现了自然数e为底的幂运算和除法节点。除法节点处的导数公式是:

y=\frac{1}{x}

\frac{\partial y}{\partial x}=-\frac{1}{x^2}=-y^2

反向传播时,会将上游的值乘上-y^2后再传给下游。计算图如下:

加法节点会原封不动的传给下游,计算图如下图所示:

exp节点的y=e^x它的导数就是它自己,所以上游的值要乘上这个exp(-x)后再传给下游,计算图如下所示:

乘法节点将正向传播时的值翻转后做乘法,因此这里需要乘以-1,计算图如下所示:

(3)人话理解反向传播

        既然神经网络学习的目的是学到最优权重w和偏置b(合称参数),这个“最优”怎么定义?就是指能让损失函数值最小的参数,就是最优参数。要想得到最优参数,那么我就得回答这样一个问题,即怎么根据现在误差值来调整w和b才能让误差值变小呢?这个问题可以拆解成另外两个小问题:

① 误差和参数之间它们是怎么互相影响对方的?

        建立参数与误差值之间的公式,也就是损失函数。通过损失函数就找到了他们之间的桥梁,就能知道误差想要表小,参数该如何变化。

② 如果误差变化了参数该如何变化?

        为了使误差减小,我们需要根据损失函数关于参数的梯度来调整参数。梯度给出了损失函数在当前参数值处的斜率,指向损失函数值下降最快的方向。这里有必要重提一下导数的定义:“一函数在某一点的导数描述了这个函数在这一点附近的变化率”,也就是自变量变,函数值跟着变,还不是简单的跟着变多少,还知道变的快慢。假设我有w_1,w_2,b三个参数,对损失函数求偏导就能知道每一个参数,它对应要变化多快或多慢,将这种变化率\partial w_1,\partial w_2,\partial d反向传给原来的w_1,w_2,b,就能指导参数变化的具体值:

w_{new}=w_{old}-a\partial w

b_{new}=b_{old}-a\partial b

        所以反向传播,“传播”的是误差信号在神经网络中的梯度(就是指导参数该怎么变的变化率)。具体来说,它并不是传播实际的输入数值或预测结果,而是传播从输出层到输入层的损失函数相对于网络中每个参数的梯度(梯度就是偏导数的向量表示)。这些梯度信息表明了在当前参数设置下,损失函数增大或减小的趋势。也就是说,告诉我们应当如何调整权重和偏置才能降低损失函数的值。

③ 举个小孩儿做题的例子

        反向传播就好比你正在教一个小孩做一道复杂的数学题。小孩做完题后,你会检查答案是否正确,如果不对,你会告诉他哪里错了,让他改正。在这个过程中:

  • 正向传播:就像是小孩按照步骤一步步计算题目。比如说他要计算 (a+b) × c,他先算出 a+b 的结果,然后再把这个结果乘以 c 得到最后的答案。

  • 损失函数:相当于你用来判断小孩答案对错的标准。如果他的答案离正确答案差很多,你就有一个衡量错误程度的“分数”。

  • 反向传播:是你指导小孩如何改正错误的过程。假设他最后的答案错了,你会告诉他:“你计算的最后一步有问题,你需要知道是因为 c 值没乘对还是前面 a+b 的结果就不对。”于是你从最后一个步骤开始,告诉小孩每一步对他最后答案的影响有多大(也就是计算梯度),这样他才能有针对性地调整自己的计算步骤,以便下次做得更好。

        这里的关键就在于,你没有告诉他正确答案(也就是损失函数最小的时候对应的参数,有时候正确答案你自己都不知道),而是通过提示让他自己一步步去找到正确答案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1561349.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件心学格物致知篇(5)愿望清单上篇

愿望清单 前言 最近发现愿望清单是一个很有意思的词,结合自己的一些过往经验得到一点点启发。 我发现在众多领域都有东西想伪装成它。 比如一些企业的企业战略,比如客户提出的一些软件需求,比如一些系统的架构设计指标,比如一…

k8s练习-创建一个Deployment

创建Deployment 创建一个nginx deployment [rootk8s-master home]# cat nginx-deployment.yaml apiVersion: apps/v1 kind: Deployment metadata:name: nginx-deployment spec:selector:matchLabels:app: nginx # 配置pod的labelsreplicas: 2 # 声明2个副本template:metada…

2024.3.26 ARM

SPI相关理论 概述 SPI,是Serial Peripheral interface的缩写,是串行外围设备接口,是一种高速的,全双工,同步的通信总线,并且在芯片的管脚上只占用四根线,节约了芯片的管脚,同时为P…

数字孪生|山海鲸数据管家简介及安装步骤

哈喽,大家好啊,我是雷工! 最近在学习数字孪生相关的软件山海鲸,了解到采集Modbus协议需要先安装山海鲸数据管家,本节先学习数据管家及安装步骤,以下为学习笔记: 1、简介 数据管家是帮用户进行数据管理与转发的软件,能够解决山海鲸可视化等软件对数据接入过程中的许多…

实体机双系统安装

实体机双系统安装 第一步:下载openKylin镜像 前往官网下载x86_64的镜像(https://www.openkylin.top/downloads/628-cn.html) tips:下载完镜像文件后,请先检查文件MD5值是否和官网上的一致,如果不一致请重…

【python基础教程】2. 算法的基本要素与特性

🎈个人主页:豌豆射手^ 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:python基础教程 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、…

Java中常见的锁策略

目录 乐观锁 vs 悲观锁 悲观锁: 乐观锁: 重量级锁 vs 轻量级锁 ⾃旋锁(Spin Lock) 公平锁 vs 非公平锁 可重⼊锁 vs 不可重入锁 读写锁 乐观锁 vs 悲观锁 悲观锁: 总是假设最坏的情况,每次去拿数据的时候都认为别…

蓝桥杯第1593题——二进制问题

题目描述 小蓝最近在学习二进制。他想知道 1 到 N 中有多少个数满足其二进制表示中恰好有 K 个 1。你能帮助他吗? 输入描述 输入一行包含两个整数 N 和 K。 输出描述 输出一个整数表示答案。 输入输出样例 示例 输入 7 2输出 3评测用例规模与约定 对于 30% …

YOLOv9改进策略 :主干优化 | 极简的神经网络VanillaBlock 实现涨点 |华为诺亚 VanillaNet

💡💡💡本文改进内容: VanillaNet,是一种设计优雅的神经网络架构, 通过避免高深度、shortcuts和自注意力等复杂操作,VanillaNet 简洁明了但功能强大。 💡💡💡引入VanillaBlock GFLOPs从原始的238.9降低至 165.0 ,保持轻量级的同时在多个数据集验证能够高效涨点…

跑spark的yarn模式时RM连不上的情况

在linux控制台跑spark on yarn一个测试案例,日志中总显示RM连yarn服务的时候是:0.0.0.0:8032 具体情况如下图: 我问题出现的原因,总结如下: 1.防火墙没关闭,关闭 2.spark-env.sh这个文件的YARN_CONF_DIR…

NRF52832修改OTA升级时的bootloader蓝牙MAC

NRF52832在OTA升级时,修改了APP的蓝牙MAC会导致无法升级,原因是OTA程序的蓝牙MAC没有被修改所以手机扫描蓝牙时无法连接 解决办法 在bootloader的程序里面加入修改蓝牙mac地址的代码实现原理: 在bootloader蓝牙广播开启之前修改蓝牙mac 通…

Java 处理Mysql获取树形的数据

Mysql数据&#xff1a; 代码如下&#xff1a; Entity&#xff1a; Data Accessors(chain true) public class Region {private BigInteger id;//名称private String name;//父idprivate BigInteger parentId;private List<Region> children;private Integer createTim…

基于SSM+Jsp+Mysql的固定资产管理系统

开发语言&#xff1a;Java框架&#xff1a;ssm技术&#xff1a;JSPJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包…

linux系统中启动MyCat问题总结

最近在深入学习mysql的底层原理和应用&#xff0c;今天在学习分库分表操作的时候&#xff0c;由于一些配置问题&#xff0c;导致无法正常启动MyCat。最终通过GPT和其他博客将问题解决了&#xff0c;以下是一篇关于MyCat启动异常的解决方案。 1.思路&#xff1a; 配置问题其实…

python练习二

# Demo85def pai_xu(ls_test):#创建一个列表排序函数命名为pai_xu# 对创建的函数进行注释"""这是一个关于列表正序/倒序排列的函数:param ls_test: 需要排序的列表:return:"""ls1 [int(ls_test[i]) for i in range(len(ls_test))]#对input输入的…

求组合数I(acwing)

题目描述&#xff1a; 给定n组询问&#xff0c;每组询问给定两个整数a&#xff0c;b&#xff0c;请你输出Ca^b mod(1e97)的值。 输入格式: 第一行包含整数n。 接下来n行&#xff0c;每行包含一组a和b。 输出格式: 共n行&#xff0c;每行输出一个询问的解。 …

C++刷题篇——05静态扫描

一、题目 二、解题思路 注意&#xff1a;注意理解题目&#xff0c;缓存的前提是先扫描一次 1、使用两个map&#xff0c;两个map的key相同&#xff0c;map1&#xff1a;key为文件标识&#xff0c;value为文件出现的次数&#xff1b;map2&#xff1a;key为文件标识&#xff0c;va…

Navicat设置mysql权限

新建用户&#xff1a; 注意&#xff1a;如果不生效执行刷新命令:FLUSH PRIVILEGES; 执行后再重新打开查看&#xff1b; 查询权限命令&#xff1a;1234为新建的用户名&#xff0c;localhost为访问的地址 SHOW GRANTS FOR 1234localhost;如果服务器设置服务器权限后可能会出现权…

【Docker】搭建安全可控的自定义通知推送服务 - Bark

【Docker】搭建安全可控的自定义通知推送服务 - Bark 前言 本教程基于绿联的NAS设备DX4600 Pro的docker功能进行搭建。 简介 Bark是一款为Apple设备用户设计的开源推送服务应用&#xff0c;它允许开发者、程序员以及一般用户将信息快速推送到他们自己的iPhone、iPad等设备上…

副业赚钱攻略:给工资低的你6个实用建议,闷声致富不是梦

经常有朋友向我咨询&#xff0c;哪些副业比较靠谱且能赚钱。实际上&#xff0c;对于大多数打工族而言&#xff0c;副业不仅是增加收入的途径&#xff0c;更是利用业余时间提升自我、实现价值的重要方式。 鉴于此&#xff0c;今天我想和大家分享六个值得尝试的副业&#xff0c;…