详解反向传播(BP)算法

news2024/10/7 0:54:25

文章目录

      • what(是什么)
      • where(用在哪)
      • How(原理&&怎么用)
          • 原理以及推导过程
          • pytorch中的反向传播

what(是什么)

反向传播算法(Backpropagation)是一种用于训练人工神经网络的常见方法。它通过计算网络预测与实际结果之间的误差,然后反向传播这个误差来调整网络中每个权重的值,从而逐步优化网络的学习过程

在这里插入图片描述

where(用在哪)

绝大多数的神经网络都会使用反向传播算法进行网络权重以及阈值的更新,简单列举部分典型的使用场景如下

反向传播算法
前馈神经网络
多层感知机
卷积神经网络
循环神经网络
深度神经网络

How(原理&&怎么用)

原理以及推导过程

下面重点介绍反向传播算法的推导流程

在这里插入图片描述

假设有以上简单的神经网路模型,分为输入层、隐藏层、输出层。其中隐藏层包括4个神经元、输出层包括2个神经元。
假设输出层的两个神经元为 y 1 y_1 y1 y 2 y_2 y2,其激活阈值分别为 β \beta β γ \gamma γ,两个神经元的输入分别为 y 1 i n y_{1in} y1in y 2 i n y_{2in} y2in,输出分别为 y 1 ^ \hat{y_1} y1^ y 2 ^ \hat{y_2} y2^
假设隐藏层四个神经元为 h 1 h_1 h1 h 2 h_2 h2 h 3 h_3 h3 h 4 h_4 h4,其中 h 1 h_1 h1的激活阈值为 δ \delta δ,神经元 h 1 h_1 h1的输入值为 h i n h_{in} hin,输出值为 h o u t h_{out} hout
假设输入层两个神经元为 x 1 x_1 x1 x 2 x_2 x2,其中神经元 x 1 x_1 x1的输出为 x o u t x_{out} xout
假设神经元 x 1 x_1 x1到神经元 h 1 h_1 h1的连接权重为 W 11 W_{11} W11,神经元 h 1 h_1 h1到神经元 y 1 y_1 y1 y 2 y_2 y2的连接权重分别为 W 21 W_{21} W21 W 22 W_{22} W22
假设神经元的激活函数为sigmoid函数,sigmoid激活函数的表达式:
f ( x ) = 1 1 − e − x f(x)=\frac{1}{1-e^{-x}} f(x)=1ex1
该激活函数有一个非常好的性质:
f ′ ( x ) = f ( x ) ( 1 − f ( x ) ) f'(x)=f(x)(1-f(x)) f(x)=f(x)(1f(x))
下面,详细介绍连接权重 W W W以及激活阈值的更新过程。
首先,给出 W 21 W_{21} W21以及 β \beta β的更新公式,其中, W 21 W_{21} W21更新公式为:
W 21 = W 21 + η ∗ Δ W 21 W_{21}=W_{21}+\eta*\Delta W_{21} W21=W21+ηΔW21
同理, β \beta β更新公式为:
β = β + η ∗ Δ β \beta=\beta+\eta*\Delta \beta β=β+ηΔβ

在以上公式中,只有 Δ W 21 \Delta W_{21} ΔW21以及 Δ β \Delta \beta Δβ未知,需要计算。而已知的是样本,也就是 ( x , y ) (x,y) (x,y),那么我们将通过样本数据来表达出上述 Δ W 21 \Delta W_{21} ΔW21以及 Δ β \Delta \beta Δβ
根据反向传播算法, Δ W 21 \Delta W_{21} ΔW21以及 Δ β \Delta \beta Δβ分别为最终的误差对 W 21 W_{21} W21以及 β \beta β的偏导数。假设采用的损失函数为:
L o s s = 1 2 ( y 1 − y 1 ^ ) 2 + 1 2 ( y 2 − y 2 ^ ) 2 Loss=\frac{1}{2}(y_1-\hat{y_1})^2+\frac{1}{2}(y_2-\hat{y_2})^2 Loss=21(y1y1^)2+21(y2y2^)2
扩展到输出层有k个神经元的情况:
L o s s = 1 2 Σ 1 k ( y i − y i ^ ) 2 Loss=\frac{1}{2}\Sigma_1^k(y_i-\hat{y_i})^2 Loss=21Σ1k(yiyi^)2
而从输出端看,能得到以下表达式:
y 1 ^ = f ( y 1 i n − β ) = f ( W 21 h o u t − β ) \hat{y_1}=f(y_{1in}-\beta)=f(W_{21}h_{out}-\beta) y1^=f(y1inβ)=f(W21houtβ)
y 1 ^ \hat{y_1} y1^带入到损失函数中,也就是:
L o s s = 1 2 ( y 1 − f ( W 21 h o u t − β ) ) 2 + 1 2 ( y 2 − f ( W 22 h o u t − γ ) ) 2 Loss = \frac{1}{2}(y_1-f(W_{21}h_{out}-\beta))^2+\frac{1}{2}(y_2-f(W_{22}h_{out}-\gamma))^2 Loss=21(y1f(W21houtβ))2+21(y2f(W22houtγ))2
如此,便得出损失和 W 21 W_{21} W21之间的代数关系式,接下来只需要对该表达式求导即可得到 Δ W 21 \Delta W_{21} ΔW21以及 Δ β \Delta \beta Δβ

首先, ∂ L o s s ∂ W 21 \frac{\partial Loss}{\partial W_{21}} W21Loss的计算公式为:
∂ L o s s ∂ W 21 = [ y 1 − f ( W 21 h o u t − β ) ] ∗ [ − f ′ ( W 21 h o u t − β ) ] ∗ h o u t = − [ y 1 − f ( W 21 h o u t − β ) ] ∗ f ( W 21 h o u t − β ) [ 1 − ( f ( W 21 h o u t − β ) ) ] ∗ h o u t = − ( y 1 − y 1 ^ ) ∗ y 1 ^ ∗ ( 1 − y 1 ^ ) ∗ h o u t \begin{aligned} \frac{\partial Loss}{\partial W_{21}} & = [y_1-f(W_{21}h_{out}-\beta)]*[-f'(W_{21}h_{out}-\beta)]*h_{out} \\ & =- [y_1-f(W_{21}h_{out}-\beta)]*f(W_{21}h_{out}-\beta)[1-(f(W_{21}h_{out}-\beta))]*h_{out} \\ & = -(y_1-\hat{y_1})*\hat{y_1}*(1-\hat{y_1})*h_{out} \end{aligned} W21Loss=[y1f(W21houtβ)][f(W21houtβ)]hout=[y1f(W21houtβ)]f(W21houtβ)[1(f(W21houtβ))]hout=(y1y1^)y1^(1y1^)hout
同样地, ∂ L o s s ∂ β \frac{\partial Loss}{\partial \beta} βLoss的计算公式为:
∂ L o s s ∂ β = [ y 1 − f ( W 21 h o u t − β ) ] ∗ [ − f ′ ( W 21 h o u t − β ) ] ∗ ( − 1 ) = [ y 1 − f ( W 21 h o u t − β ) ] ∗ f ( W 21 h o u t − β ) [ 1 − ( f ( W 21 h o u t − β ) ) ] = ( y 1 − y 1 ^ ) ∗ y 1 ^ ∗ ( 1 − y 1 ^ ) \begin{aligned} \frac{\partial Loss}{\partial \beta} & = [y_1-f(W_{21}h_{out}-\beta)]*[-f'(W_{21}h_{out}-\beta)]*(-1) \\ & = [y_1-f(W_{21}h_{out}-\beta)]*f(W_{21}h_{out}-\beta)[1-(f(W_{21}h_{out}-\beta))] \\ & = (y_1-\hat{y_1})*\hat{y_1}*(1-\hat{y_1}) \end{aligned} βLoss=[y1f(W21houtβ)][f(W21houtβ)](1)=[y1f(W21houtβ)]f(W21houtβ)[1(f(W21houtβ))]=(y1y1^)y1^(1y1^)
由于梯度下降法,需要沿着负梯度方向,所以, Δ W 21 = − ∂ L o s s ∂ W 21 \Delta W_{21}=-\frac{\partial Loss}{\partial W_{21}} ΔW21=W21Loss Δ β = − ∂ L o s s ∂ β \Delta \beta=-\frac{\partial Loss}{\partial \beta} Δβ=βLoss,从而得出 W 21 , β W_{21},\beta W21,β的更新公式为:
W 21 = W 21 + η ∗ Δ W 21 = W 21 − η ∗ ∂ L o s s ∂ W 21 = W 21 + η ∗ ( y 1 − y 1 ^ ) ∗ y 1 ^ ∗ ( 1 − y 1 ^ ) ∗ h o u t \begin{aligned} W_{21} &= W_{21} + \eta*\Delta W_{21} \\ & = W_{21}-\eta * \frac{\partial Loss}{\partial W_{21}} \\ & =W_{21}+\eta *(y_1-\hat{y_1})*\hat{y_1}*(1-\hat{y_1})*h_{out} \end{aligned} W21=W21+ηΔW21=W21ηW21Loss=W21+η(y1y1^)y1^(1y1^)hout

β = β + η ∗ Δ β = β − η ∗ ∂ L o s s ∂ β = β − η ∗ ( y 1 − y 1 ^ ) ∗ y 1 ^ ∗ ( 1 − y 1 ^ ) \begin{aligned} \beta & = \beta+\eta*\Delta \beta \\ & = \beta-\eta* \frac{\partial Loss}{\partial \beta} \\ & = \beta-\eta *(y_1-\hat{y_1})*\hat{y_1}*(1-\hat{y_1}) \end{aligned} β=β+ηΔβ=βηβLoss=βη(y1y1^)y1^(1y1^)

使用同样的方式,可以对 W 11 , δ W_{11},\delta W11,δ的梯度公式进行计算和更新。

pytorch中的反向传播

下面举例说明在pytorch中,如何使用反向传播算法来更新权重以及阈值。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# 定义一个复杂的神经网络
class ComplexNet(nn.Module):
    def __init__(self):
        super(ComplexNet, self).__init__()
        self.fc1 = nn.Linear(10, 50)  # 输入大小为10,输出大小为50
        self.fc2 = nn.Linear(50, 20)  # 输入大小为50,输出大小为20
        self.fc3 = nn.Linear(20, 1)   # 输入大小为20,输出大小为1

    def forward(self, x):
        x = F.relu(self.fc1(x))  # 使用ReLU作为激活函数
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建网络实例
model = ComplexNet()

# 定义损失函数
criterion = nn.MSELoss()

# 随机生成一些输入和目标输出数据
input_data = torch.randn((32, 10))  # 32个样本,每个样本特征数为10
target_output = torch.randn((32, 1))  # 对应的32个目标输出

# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 训练模型
model.train()	# 设置模型为训练模式
epochs = 1000
for epoch in range(epochs):
    # 梯度清零
    optimizer.zero_grad()

    # 前向传播
    output = model(input_data)

    # 计算损失
    loss = criterion(output, target_output)

    # 反向传播
    loss.backward()

    # 更新模型参数
    optimizer.step()

    # 每隔一段时间输出一下损失值
    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

# 打印模型结构
print(model)

pythrch中,输入在流经每一个神经元时,会构建一个动态计算图(与tensorflow不同,tensorflow为静态计算图),记录了每个神经元的输入输出信息。在反向传播时, loss.backward()会根据已知的样本数据以及神经元的输入输出信息,计算连接权重以及阈值的梯度,然后optimizer.step()来实现对权重和阈值的更新。需要注意的是,在每一个mini-batch开始前,需要使用optimizer.zero_grad()对梯度置零。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1886895.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

为什么是视频传输用YUV格式,而放弃RGB格式?

😎 作者介绍:我是程序员行者孙,一个热爱分享技术的制能工人。计算机本硕,人工制能研究生。公众号:AI Sun,视频号:AI-行者Sun 🎈 本文专栏:本文收录于《音视频》系列专栏&…

如何寻找一个领域的顶级会议,并且判断这个会议的影响力?

如何寻找一个领域的顶级会议,并且判断这个会议的影响力? 会议之眼 快讯 很多同学都在问:学术会议不是期刊,即使被SCI检索,也无法查询影响因子。那么如何知道各个领域的顶级会议,并对各个会议有初步了解呢…

Qt加载SVG矢量图片,放大缩小图片质量不发生变化。

前言: 首先简单描述下SVG: SVG 意为可缩放矢量图形(Scalable Vector Graphics)。 SVG 使用 XML 格式定义图像。 给界面或者按钮上显示一个图标或背景图片,日常使用.png格式的文件完全够用,但是有些使用场景需要把图…

代码随想录第41天|动态规划

322. 零钱兑换 dp[j] : 最小硬币数量, j 为金额(相当于背包空间)递推公式 : dp[j] min(dp[j - coins[i]] 1, dp[j])初始化: 需要一个最大值, 避免覆盖, dp[0] 0遍历顺序: 钱币有序无序不影响, 因为求解最小个数, 结果相同(先遍历物品后背包, 先背包后物品都可) class Solut…

NSSCTF-Web题目21(文件上传-phar协议、RCE-空格绕过)

目录 [NISACTF 2022]bingdundun~ 1、题目 2、知识点 3、思路 [FSCTF 2023]细狗2.0 4、题目 5、知识点 6、思路 [NISACTF 2022]bingdundun~ 1、题目 2、知识点 文件上传,phar伪协议 3、思路 点击upload,看看 这里提示我们可以上传图片或压缩包&…

Nginx主配置文件---Nginx.conf

nginx主配置文件的模块介绍 全局块: 全局块是配置文件从开始到 events 块之间的部分,其中指令的作用域是 Nginx 服务器全局。主要指令包括: user:指定可以运行 Nginx 服务的用户和用户组,只能在全局块配置。例如&…

Linux多线程【线程互斥】

文章目录 Linux线程互斥进程线程间的互斥相关背景概念互斥量mutex模拟抢票代码 互斥量的接口初始化互斥量销毁互斥量互斥量加锁和解锁改进模拟抢票代码(加锁)小结对锁封装 lockGuard.hpp 互斥量实现原理探究可重入VS线程安全概念常见的线程不安全的情况常…

【开发环境】MacBook M系列芯片环境下搭建完整Python开发环境

文章目录 Anaconda和Python的关系?1. Python2. Anaconda 安装AnacondaPycharm整合Anaconda运行你的Python代码 Anaconda和Python的关系? 如果有简单了解过Python语言的,那么你很容易就会听到有人会叫你安装Anaconda。 那么Anaconda是什么&am…

编译原理2

推导和短语 推导 推导过程中,每一步推导都是对句型的 最右非终结符 进行替换,最右推导(规范推导); 短语 用 β 替换 A,则 β 就是 关于A 的一个短语; 直接短语是短语范围内的一步推导; 直接短语可能不…

Rust学习笔记007:Trait --- Rust的“接口”

Trait 在Rust中,Trait(特质)是一种定义方法集合的机制,类似于其他编程语言中的接口(java)或抽象类(c的虚函数)。 。Trait 告诉 Rust 编译器: 某种类型具有哪些并且可以与其它类型共享的功能Trait:抽象的…

[ROS 系列学习教程] 建模与仿真 - 使用 ros_control 控制差速轮式机器人

ROS 系列学习教程(总目录) 本文目录 一、差速轮式机器人二、差速驱动机器人运动学模型三、对外接口3.1 输入接口3.2 输出接口 四、控制器参数五、配置控制器参数六、编写硬件抽象接口七、控制机器人移动八、源码 ros_control 提供了多种控制器,其中 diff_drive_cont…

Datawhale - 角色要素提取竞赛

文章目录 赛题要求一、赛事背景二、赛事任务三、评审规则1.平台说明2.数据说明3.评估指标4.评测及排行 四、作品提交要求五、 运行BaselineStep1:下载相关库Step2:配置导入Step3:模型测试Step4:数据读取Step5:Prompt设…

工业 web4.0UI 风格品质卓越

工业 web4.0UI 风格品质卓越

【力扣 - 每日一题】3115. 质数的最大距离(一次遍历、头尾遍历、空间换时间、埃式筛、欧拉筛、打表)Golang实现

原题链接 题目描述 给你一个整数数组 nums。 返回两个(不一定不同的)质数在 nums 中 下标 的 最大距离。 示例 1: 输入: nums [4,2,9,5,3] 输出: 3 解释: nums[1]、nums[3] 和 nums[4] 是质数。因此答…

WPF自定义模板--Button

属性&#xff1a; TemplateBinding&#xff1a;用于在ControlTemplate中绑定到控件的属性&#xff0c;例如Background、BorderBrush等。TargetType&#xff1a;指定该模板应用于哪种控件类型。在这个例子中&#xff0c;是Button。 标准的控件模板代码&#xff1a; <Style…

线性代数大题细节。

4.4 方程组解的结构&#xff08;二&#xff09;_哔哩哔哩_bilibili

eNSP中WLAN的配置和使用

一、基础配置 1.拓扑图 2.VLAN和IP配置 a.R1 <Huawei>system-view [Huawei]sysname R1 GigabitEthernet 0/0/0 [R1-GigabitEthernet0/0/0]ip address 200.200.200.200 24 b.S1 <Huawei>system-view [Huawei]sysname S1 [S1]vlan 100 [S1-vlan100]vlan 1…

vue3 window.location 获取正在访问的地址,也可以通过useRoute来获取相关信息。

1、一般我们在开发的vue3项目的时候&#xff0c;地址是这样&#xff1a;http://192.168.1.101:3100/#/login 然后我们在布署完成以后一般是这样https://xxx.yyyyy.com/uusys/#/login 其实xxx可以是www&#xff0c;也可以是一个二级域名 yyyyy.com是域名&#xff0c;uusys一般…

家政小程序的开发:打造现代式便捷家庭服务

随着现代生活节奏的加快&#xff0c;人们越来越注重生活品质与便利性。在这样的背景下&#xff0c;家政服务市场迅速崛起&#xff0c;成为许多家庭日常生活中不可或缺的一部分。然而&#xff0c;传统的家政服务往往存在信息不对称、服务效率低下等问题。为了解决这些问题&#…

Windows编程上

Windows编程[上] 一、Windows API1.控制台大小设置1.1 GetStdHandle1.2 SetConsoleWindowInfo1.3 SetConsoleScreenBufferSize1.4 SetConsoleTitle1.5 封装为Innks 2.控制台字体设置以及光标调整2.1 GetConsoleCursorInfo2.2 SetConsoleCursorPosition2.3 GetCurrentConsoleFon…