重新回顾反向传播与梯度下降:训练神经网络的基石

news2024/11/5 11:44:57

有关反向传播与梯度下降:流程与公式推导

  • 背景
  • 前向传播
  • 反向传播

背景

  反向传播则是一种训练神经网络的算法,目前我们使用的深度学习模型大都是通过这种方法训练的。它的核心思想是通过计算损失函数相对于每个参数的导数,来更新神经网络中的权重和偏置。反向传播负责计算梯度,梯度下降负责利用这些梯度来更新参数。两者的区别就是它们的目的不同:梯度下降是为了更新模型的权重和偏置;反向传播是为了在神经网络中获取损失函数关于每一层参数的导数(即每一层神经网络权重的梯度),为梯度下降提供依据。

前向传播

  我们首先回顾一个简单的神经网络,使用两个输入,隐藏层有两个神经元(带有ReLU激活函数),以及一个预测(输出层):
在这里插入图片描述
  每个隐藏神经元都在执行以下过程:
在这里插入图片描述
  其中,input是我们输入的特征。weights是我们用来乘以输入的系数,我们算法的目标就是找到最优权重。Linear Weighted Sum将输入和权重的乘积相加,并加上一个偏置项b。之后经过一个激活函数增加非线性,ReLU是最常用的激活函数。
  在我们上面提到的那个简单的神经网络中,hidden layer存储了多个神经元以学习数据模式。一个神经网络有可能是多层的,即包含多个hidden layers。
  要训练一个神经网络,首先要让它生成预测,这被称为前向传播,即数据从第一层到最后一层(也称为输出层)遍历所有神经元的过程。以我们提到的简单的神经网络作为例子,我们首先为其创建一些任意的权重、偏置和输入:Input: [ 0.9 , 1.0 ] [0.9, 1.0] [0.9,1.0]、输入到隐藏层的权重 W 1 W_1 W1:神经元1: W 1 , 1 = [ 0.2 , 0.3 ] W_{1,1}= [0.2, 0.3] W1,1=[0.2,0.3]
神经元2: W 1 , 2 = [ 0.4 , 0.5 ] W_{1,2}=[0.4, 0.5] W1,2=[0.4,0.5]、隐藏层偏置 b 1 : [ 0.1 , 0.2 ] b_1:[0.1, 0.2] b1:[0.1,0.2]、隐藏层到输出层的权重 W 2 : [ 0.5 , 0.6 ] W_2:[0.5, 0.6] W2:[0.5,0.6]、输出层偏置 b 2 : [ 0.4 ] b_2:[0.4] b2:[0.4]。target/label设置为 [ 2.0 ] [2.0] [2.0]
  初始化之后,我们现在可以进行前向传播,过程如下:从输入到隐藏层的线性加权和 z 1 1 z¹_{1} z11 z 2 1 z¹_{2} z21为:
z 1 1 = W 1 , 1 ⋅ I n p u t + b 1 , 1 = [ 0.2 , 0.3 ] ⋅ [ 0.9 , 1.0 ] + 0.1 = 0.58 z¹_{1}=W_{1,1}\cdot Input+b_{1,1}=[0.2,0.3]\cdot[0.9,1.0]+0.1=0.58 z11=W1,1Input+b1,1=[0.2,0.3][0.9,1.0]+0.1=0.58 z 1 1 = W 1 , 2 ⋅ I n p u t + b 1 , 2 = [ 0.4 , 0.5 ] ⋅ [ 0.9 , 1.0 ] + 0.2 = 1.06 z¹_{1}=W_{1,2}\cdot Input+b_{1,2}=[0.4,0.5]\cdot[0.9,1.0]+0.2=1.06 z11=W1,2Input+b1,2=[0.4,0.5][0.9,1.0]+0.2=1.06  然后,我们再隐藏层执行ReLU激活函数,得到 a 1 , 1 1 a^1_{1,1} a1,11 a 1 , 2 1 a^1_{1,2} a1,21,然后生成整个网络的输出,这一步不涉及激活函数: z 1 2 = W 2 ⋅ [ a 1 , 1 1 , a 1 , 2 1 ] + b 2 = [ 0.5 , 0.6 ] ⋅ [ 0.58 , 1.06 ] + 0.4 = 1.326 z^{2}_1=W_2\cdot[a^1_{1,1},a^1_{1,2}]+b_2=[0.5,0.6]\cdot[0.58,1.06]+0.4=1.326 z12=W2[a1,11,a1,21]+b2=[0.5,0.6][0.58,1.06]+0.4=1.326  现在,我们就完成了第一次前向传播。这个过程可以直观地展示出来:
在这里插入图片描述

反向传播

  完成前向传播后,我们拿到了网络的预测,我们希望通过预测和真实值的误差更新网络权重和偏置,以最小化网络预测结果的误差,这一步是通过反向传播算法实现的。
  接下来,让我们深入了解这一算法的原理,反向传播旨在计算每个权重和偏置相对于误差(损失)的偏导数。然后,使用梯度下降法更新每个参数,从而最小化每个参数引起的误差(损失)。我们通过一个使用计算图的“简单”例子来说明。考虑以下函数: f ( x , y , z ) = z ( x − y ) f(x,y,z)=z(x-y) f(x,y,z)=z(xy)  我们将其绘制为计算图:
在这里插入图片描述
  这是一个关于如何计算 f ( x , y , z ) f(x,y,z) f(x,y,z)的流程图,将 p p p表示为 x − y x-y xy,现在,让我们代入一些数值:
在这里插入图片描述
  计算 f(x,y,z) 的最小值需要使用微积分,特别是,我们需要知道 f ( x , y , z ) f(x,y,z) f(x,y,z)关于其三个变量 x 、 y 、 z x、y、z xyz的偏导数。我们可以从计算 p = x − y p=x-y p=xy f = p z f=pz f=pz的偏导数开始: p = x − y        ∂ p ∂ x = 1 , ∂ p ∂ y = 1 p=x-y\ \ \ \ \ \ \frac{\partial p}{\partial x}=1,\frac{\partial p}{\partial y}=1 p=xy      xp=1,yp=1 f = p z        ∂ f ∂ p = z , ∂ f ∂ z = p f=pz\ \ \ \ \ \ \frac{\partial f}{\partial p}=z,\frac{\partial f}{\partial z}=p f=pz      pf=z,zf=p  进一步地,我们使用链式法则来求解这些偏导: ∂ f ∂ x \frac{\partial f}{\partial x} xf ∂ f ∂ y \frac{\partial f}{\partial y} yf ∂ f ∂ z \frac{\partial f}{\partial z} zf,以x为例: ∂ f ∂ x = ∂ f ∂ p ⋅ ∂ p ∂ x = z \frac{\partial f}{\partial x}=\frac{\partial f}{\partial p}\cdot \frac{\partial p}{\partial x}=z xf=pfxp=z  通过这种方式,我们对 y , z y,z yz进行同样的操作,求得 ∂ f ∂ y = z \frac{\partial f}{\partial y}=z yf=z ∂ f ∂ z = x − y \frac{\partial f}{\partial z}=x-y zf=xy,现在,我们可以在计算图上写下这些梯度及其对应的值:
在这里插入图片描述
  请注意,在训练网络的过程中,我们的期望是最小化损失函数,那么这里的 f ( x , y , z ) f(x,y,z) f(x,y,z)就相当于我们的损失函数,梯度下降法通过沿梯度的反方向更新值 ( x , y , z ) (x,y,z) (x,y,z)的一个小量来工作。例如,对于 x: x : = x − h ∂ f ∂ x x:=x-h\frac{\partial f}{\partial x} x:=xhxf  那么这里的 h h h指的就是学习率,它决定了我们优化速度地大小,当我们将学习率设置为0.1,那么x被更新为3.7,我们再来观察现在的输出:
在这里插入图片描述
  输出变小了,也就是说,它正在被最小化!现在,让我们将这个过程应用到上面提到的简单神经网络示例中。请记住,我们的预测值是 1.326,假设目标值是 2.0。以均方误差作为损失函数: l o s s = 0.5 ( z 1 2 − t a r g e t ) 2 = 0.5 ( 1.326 − 2 ) 2 = 0.37826 loss = 0.5(z_1^{2} - target)^2 = 0.5(1.326 - 2)^2 = 0.37826 loss=0.5(z12target)2=0.5(1.3262)2=0.37826  请注意,在我们网络中使用的符号,上标表示层数。接下来,我们要计算损失相对于预测值的梯度: ∂ l o s s ∂ z 1 2 = z 1 2 − t a r g e t = − 0.674 \frac{\partial loss}{\partial z_1^{2}}=z_1^{2}-target=-0.674 z12loss=z12target=0.674  现在,我们需要计算损失相对于输出层偏置和权重(W_2 和 b_2)的梯度: ∂ l o s s ∂ W 2 , 1 = ∂ l o s s ∂ z 1 2 ⋅ ∂ z 1 2 ∂ W 2 , 1 = − 0.674 × 0.58 = − 0.39092 \frac{\partial loss}{\partial W_{2,1}} = \frac{\partial loss}{\partial z_1^{2}} \cdot \frac{\partial z_1^{2}}{\partial W_{2,1}} = -0.674 \times 0.58 = -0.39092 W2,1loss=z12lossW2,1z12=0.674×0.58=0.39092 ∂ l o s s ∂ W 2 , 2 = ∂ l o s s ∂ z 1 2 ⋅ ∂ z 1 2 ∂ W 2 , 2 = − 0.674 × 1.06 = − 0.71444 \frac{\partial loss}{\partial W_{2,2}} = \frac{\partial loss}{\partial z_1^{2}} \cdot \frac{\partial z_1^{2}}{\partial W_{2,2}} = -0.674 \times 1.06 = -0.71444 W2,2loss=z12lossW2,2z12=0.674×1.06=0.71444 ∂ l o s s ∂ b 2 = ∂ l o s s ∂ z 1 2 ⋅ ∂ z 1 2 ∂ b 2 = − 0.674 × 1 = − 0.674 \frac{\partial loss}{\partial b_2} = \frac{\partial loss}{\partial z_1^{2}} \cdot \frac{\partial z_1^{2}}{\partial b_2} = -0.674 \times 1 = -0.674 b2loss=z12lossb2z12=0.674×1=0.674  这些表达式看起来可能很复杂,但我们其实只是进行了部分微分,并多次应用了链式法则,根据这一公式: z 1 2 = W 2 , 1 × a 1 1 + W 2 , 2 × a 2 1 + b 2 z_1^{2}=W_{2,1}\times a_1^{1}+W_{2,2}\times a_2^{1}+b_2 z12=W2,1×a11+W2,2×a21+b2  最后一步是使用梯度下降法更新参数,学习率设置为 h = 0.1 h=0.1 h=0.1 W 2 , 1 : = W 2 , 1 − h ∂ ( loss ) ∂ ( W 2 , 1 ) = 0.5 − ( 0.1 ) ( − 0.39092 ) = 0.539092 W_{2,1} :=W_{2,1} - h \frac{\partial (\text{loss})}{\partial (W_{2,1})}=0.5 - (0.1)(-0.39092)=0.539092 W2,1:=W2,1h(W2,1)(loss)=0.5(0.1)(0.39092)=0.539092 W 2 , 2 : = W 2 , 2 − h ∂ ( loss ) ∂ ( W 2 , 2 ) = 0.6 − ( 0.1 ) ( − 0.71444 ) = 0.671444 W_{2,2} :=W_{2,2} - h \frac{\partial (\text{loss})}{\partial (W_{2,2})}=0.6 - (0.1)(-0.71444)=0.671444 W2,2:=W2,2h(W2,2)(loss)=0.6(0.1)(0.71444)=0.671444 b 2 : = b 2 − h ∂ ( loss ) ∂ ( b 2 ) = 0.4 − ( 0.1 ) ( − 0.674 ) = 0.4674 b_2 :=b_2 - h \frac{\partial (\text{loss})}{\partial (b_2)}=0.4 - (0.1)(-0.674)=0.4674 b2:=b2h(b2)(loss)=0.4(0.1)(0.674)=0.4674  我们已经更新了输出层的权重和偏置!接下来,我们要对隐藏层的权重和偏置重复这个过程,并使用梯度下降法更新这些权重: w 1 , 1 : = w 1 , 1 − h ∂ l o s s ∂ w 1 , 1 = 0.23033 w_{1,1}:=w_{1,1}-h\frac{\partial loss}{\partial w_{1,1}}=0.23033 w1,1:=w1,1hw1,1loss=0.23033 w 1 , 2 : = w 1 , 2 − h ∂ l o s s ∂ w 1 , 2 = 0.3337 w_{1,2}:=w_{1,2}-h\frac{\partial loss}{\partial w_{1,2}}=0.3337 w1,2:=w1,2hw1,2loss=0.3337 b 1 : = b 1 − h ∂ l o s s ∂ b 1 = 0.1337 b_1:=b_1-h\frac{\partial loss}{\partial b_1}=0.1337 b1:=b1hb1loss=0.1337 w 2 , 1 : = w 2 , 1 − h ∂ l o s s ∂ w 2 , 1 = 0.4364 w_{2,1}:=w_{2,1}-h\frac{\partial loss}{\partial w_{2,1}}=0.4364 w2,1:=w2,1hw2,1loss=0.4364 w 2 , 2 : = w 2 , 2 − h ∂ l o s s ∂ w 2 , 2 = 0.54044 w_{2,2}:=w_{2,2}-h\frac{\partial loss}{\partial w_{2,2}}=0.54044 w2,2:=w2,2hw2,2loss=0.54044 b 2 : = b 2 − h ∂ l o s s ∂ b 2 = 0.24044 b_2:=b_2-h\frac{\partial loss}{\partial b_2}=0.24044 b2:=b2hb2loss=0.24044  这样,我们就进行了一次完整的反向传播迭代!对于训练集中所有样本的一次前向传播和一次反向传播构成一个训练周期(epoch)。下一步是使用更新后的权重和偏置进行另一次前向传播,使用新权重和偏置的前向传播结果为1.618296,这更接近目标值 2,因此网络已经“学习”到了更好的权重和偏置,这就是机器学习的实际应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2230842.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java | Leetcode Java题解之第524题通过删除字母匹配到字典里最长单词

题目&#xff1a; 题解&#xff1a; class Solution {public String findLongestWord(String s, List<String> dictionary) {int m s.length();int[][] f new int[m 1][26];Arrays.fill(f[m], m);for (int i m - 1; i > 0; --i) {for (int j 0; j < 26; j) {…

PHP合成图片,生成海报图,poster-editor使用说明

之前写过一篇使用Grafika插件生成海报图的文章&#xff0c;但是当我再次使用时&#xff0c;却发生了错误&#xff0c;回看Grafika文档&#xff0c;发现很久没更新了&#xff0c;不兼容新版的GD&#xff0c;所以改用了intervention/image插件来生成海报图。 但是后来需要对海报…

机器人领域中的scaling law:通过复现斯坦福机器人UMI——探讨数据规模化定律(含UMI的复现关键)

前言 在24年10.26/10.27两天&#xff0c;我司七月在线举办的七月大模型机器人线下营时&#xff0c;我们带着大家一步步复现UMI「关于什么是UMI&#xff0c;详见此文&#xff1a;UMI——斯坦福刷盘机器人&#xff1a;从手持夹持器到动作预测Diffusion Policy(含代码解读)」&…

丝杆支撑座的更换与细节注意事项

丝杆支撑座是支撑连接丝杆和电机的轴承支撑座&#xff0c;分固定侧和支撑侧&#xff0c;它们都有用预压调整的JIS5级的交界处球轴承。在自动化设备中是常用的传动装置&#xff0c;作为核心部件&#xff0c;对设备精度、稳定性和生产效率产生直接影响。在长时间运行中&#xff0…

行业深耕+全球拓展双轮驱动,用友U9 cloud加速中国制造全球布局

竞争加剧、供应链动荡、出海挑战……在日益激烈的市场竞争和新的全球化格局中&#xff0c;中国制造业的数智化转型已经步入深水区。 作为面向中型和中大型制造业的云ERP&#xff0c;用友U9 cloud一直是中国制造业转型升级的参与者和见证者。自2021年发布以来&#xff0c;用友U…

C#实现word和pdf格式互转

1、word转pdf 使用nuget&#xff1a; Microsoft.Office.Interop.Word winform页面&#xff1a; 后端代码&#xff1a; //using Spire.Doc; //using Spire.Pdf; using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using Sy…

Spring MVC 完整生命周期和异常处理流程图

先要明白 // 1. 用户发来请求: localhost:8080/user/1// 2. 处理器映射器(HandlerMapping)的工作 // 它会找到对应的Controller和方法 GetMapping("/user/{id}") public User getUser(PathVariable Long id) {return userService.getById(id); }// 3. 处理器适配…

wps宏代码学习

推荐学习视频&#xff1a;https://space.bilibili.com/363834767/channel/collectiondetail?sid1139008&spm_id_from333.788.0.0 打开宏编辑器和JS代码调试 工具-》开发工具-》WPS宏编辑器 左边是工程区&#xff0c;当打开多个excel时会有多个&#xff0c;要注意不要把…

vscode | 开发神器vscode快捷键删除和恢复

目录 快捷键不好使了删除快捷键恢复删除的快捷键 在vscode使用的过程中&#xff0c;随着我们自身需求的不断变化&#xff0c;安装的插件将会持续增长&#xff0c;那么随之而来的就会带来一个问题&#xff1a;插件的快捷键重复。快捷键重复导致的问题就是快捷键不好使了&#xf…

Java-02

笔试算法&#xff1a; 41. 回文串 我们称一个字符串为回文串&#xff0c;当且仅当这个串从左往右和从右往左读是一样的。例如&#xff0c;aabbaa、a、abcba 是回文串&#xff0c;而 ab、ba、abc 不是回文串。注意单个字符也算是回文串。 现在&#xff0c;给你一个长度为n的…

《数字图像处理基础》学习05-数字图像的灰度直方图

目录 一&#xff0c;数字图像的数值描述 &#xff11;&#xff0c;二值图像 &#xff12;&#xff0c;灰度图像 3&#xff0c;彩色图像 二&#xff0c;数字图像的灰度直方图 一&#xff0c;数字图像的数值描述 在之前的学习中&#xff0c;我知道了图像都是二维信息&…

6.1、实验一:静态路由

源文件获取&#xff1a;6.1_实验一&#xff1a;静态路由.pkt: https://url02.ctfile.com/f/61945102-1420248902-c5a99e?p2707 (访问密码: 2707) 一、目的 理解路由表的概念 会使用基础命令 根据需求正确配置静态路由 二、准备实验 1.实验要求 让PC0、PC1、PC2三台电脑…

集成ruoyi-it管理系统,遇到代码Bug

前言&#xff1a;这次ruoyi框架开发it管理系统&#xff0c;出现很多问题&#xff0c;也有学到很多东西&#xff0c;出现几个问题&#xff0c;希望下次项目不会出现或者少出现问题&#xff1b;其中还是有很多基础知识有些忘记&#xff0c;得多多复习 1&#xff1a;当写的代码没…

解决Redis缓存穿透(缓存空对象、布隆过滤器)

文章目录 背景代码实现前置实体类常量类工具类结果返回类控制层 缓存空对象布隆过滤器结合两种方法 背景 缓存穿透是指客户端请求的数据在缓存中和数据库中都不存在&#xff0c;这样缓存永远不会生效&#xff0c;这些请求都会打到数据库 常见的解决方案有两种&#xff0c;分别…

基于微信小程序的校园失物招领系统的研究与实现(V4.0)

博主介绍&#xff1a;✌stormjun、8年大厂程序员经历。全网粉丝15w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&…

vscode 创建 vue 项目时,配置文件为什么收缩到一起展示了?

一、前言 今天用 vue 官方脚手架创建工程&#xff0c;然后通过 vscode 打开项目发现&#xff0c;配置文件都被收缩在一起了。就像下面这样 这有点反直觉&#xff0c;他们应该是在同一层级下的&#xff0c;怎么会这样&#xff0c;有点好奇&#xff0c;但是打开资源管理查看&…

LInux系统编程(二)操作系统和进程

目录 一、前言&#xff1a;冯诺依曼体系结构 1、图中各个单元的介绍 2、值得注意的几点 二、操作系统 1、操作系统分层图 2、小总结 三、 进程&#xff08;重点&#xff09; 1、进程的基本概念 2、存放进程信息的数据结构——PCB&#xff08;Linux 下称作 task_struct…

HNU-小学期-专业综合设计

写在前面 选题&#xff1a;大数据技术-智慧交通预测系统 项目github地址&#xff08;如果有用麻烦点个star与follow&#xff09;&#xff1a;https://github.com/wolfvoid/HNU-ITPS &#xff08;全部代码以及如何部署参见README&#xff09; 项目报告&#xff1a;如下&…

Linux特种文件系统--tmpfs文件系统

tmpfs类似于RamDisk&#xff08;只能使用物理内存&#xff09;&#xff0c;使用虚拟内存&#xff08;简称VM&#xff09;子系统的页面存储文件。tmpfs完全依赖VM&#xff0c;遵循子系统的整体调度策略。说白了tmpfs跟普通进程差不多&#xff0c;使用的都是某种形式的虚拟内存&a…

PLC会被卡脖子吗?冗余技术才是中型和大型PLC的门槛

美方称北京天圣华参与高超音速武器的研发和空对空导弹的生产&#xff0c;因此把北京天圣华列入实体制裁清单。据说因为天圣华向和中国军方相关研究机构出售了西门子的建模软件&#xff0c;并为军工项目的也就做出了积极贡献&#xff0c;因此美方对西门子施压。 西门子是全球最大…