理解前向传播、反向传播和计算图

news2024/12/28 3:46:42

1. 什么是前向传播?

前向传播(Forward Propagation)是神经网络的推理过程。它将输入数据逐层传递,通过每一层的神经元计算,最终生成输出。

前向传播的公式

假设我们有一个简单的三层神经网络(输入层、一个隐藏层和输出层),网络的每一层计算如下:

z ( 1 ) = W ( 1 ) ⋅ X + b ( 1 ) z^{(1)} = W^{(1)} \cdot X + b^{(1)} z(1)=W(1)X+b(1)
a ( 1 ) = σ ( z ( 1 ) ) a^{(1)} = \sigma(z^{(1)}) a(1)=σ(z(1))
z ( 2 ) = W ( 2 ) ⋅ a ( 1 ) + b ( 2 ) z^{(2)} = W^{(2)} \cdot a^{(1)} + b^{(2)} z(2)=W(2)a(1)+b(2)
y ^ = a ( 2 ) = softmax ( z ( 2 ) ) \hat{y} = a^{(2)} = \text{softmax}(z^{(2)}) y^=a(2)=softmax(z(2))

其中,(W) 和 (b) 分别是权重矩阵和偏置,(\sigma) 是激活函数,(\hat{y}) 是网络的输出。

代码示例

我们用 Python 和 NumPy 来实现前向传播的过程:

import numpy as np

# 输入数据和网络参数
X = np.array([[0.5, 1.2]])
W1 = np.array([[0.4, 0.3], [0.2, 0.7]])
b1 = np.array([[0.1, 0.2]])
W2 = np.array([[0.6], [0.8]])
b2 = np.array([[0.3]])

# 激活函数
def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# 前向传播
z1 = np.dot(X, W1) + b1
a1 = sigmoid(z1)
z2 = np.dot(a1, W2) + b2
output = sigmoid(z2)

print("网络输出:", output)

可视化前向传播

在这里,我们可以使用计算图来表示前向传播的过程。计算图中的每个节点代表一个操作或变量,箭头表示数据的流动。

在这里插入图片描述

2. 什么是反向传播?

反向传播(Backpropagation)是神经网络的训练过程,它通过计算损失函数的梯度来更新权重,从而最小化损失。

反向传播的原理

反向传播通过链式法则计算梯度。假设我们的损失函数是均方误差(MSE):

L = 1 2 ( y ^ − y ) 2 L = \frac{1}{2} (\hat{y} - y)^2 L=21(y^y)2

对于每个权重 (W),梯度更新规则为:

∂ L ∂ W ( 2 ) = δ ( 2 ) ⋅ a ( 1 ) \frac{\partial L}{\partial W^{(2)}} = \delta^{(2)} \cdot a^{(1)} W(2)L=δ(2)a(1)

其中,(\delta^{(2)}) 是输出层的误差:

δ ( 2 ) = y ^ − y \delta^{(2)} = \hat{y} - y δ(2)=y^y

代码示例

我们可以用 Python 代码实现一个简单的反向传播:

# 假设真实标签
y = np.array([[1]])

# 计算损失
loss = 0.5 * (output - y) ** 2

# 反向传播
d_output = output - y
d_z2 = d_output * output * (1 - output)
d_W2 = np.dot(a1.T, d_z2)

d_a1 = np.dot(d_z2, W2.T)
d_z1 = d_a1 * a1 * (1 - a1)
d_W1 = np.dot(X.T, d_z1)

# 更新权重
learning_rate = 0.1
W2 -= learning_rate * d_W2
W1 -= learning_rate * d_W1

print("更新后的 W1:", W1)
print("更新后的 W2:", W2)

反向传播的计算图

我们同样可以用计算图来表示反向传播的过程,这有助于理解梯度如何在网络中流动。

Backpropagation

3. 理解计算图

计算图(Computational Graph)是表示神经网络中计算过程的图形结构。通过计算图,我们可以直观地理解前向传播和反向传播。

在计算图中,节点表示操作(如加法、乘法、激活函数)或变量(如输入、权重),边表示数据流动。通过前向传播,计算图中的数据从输入流向输出;通过反向传播,梯度从输出反向传播到输入。

计算图的例子

下图展示了一个简单的计算图,表示前向传播和反向传播的过程。

Computational Graph

4. 实战案例:手写数字识别

让我们用一个简单的手写数字识别的案例来巩固这些概念。我们将使用一个小型的神经网络,通过前向传播预测数字,通过反向传播调整权重。

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

# 加载数据
digits = load_digits()
X = digits.data / 16.0
y = digits.target.reshape(-1, 1)

# 独热编码标签
encoder = OneHotEncoder(sparse=False)
y_onehot = encoder.fit_transform(y)

# 分割数据
X_train, X_test, y_train, y_test = train_test_split(X, y_onehot, test_size=0.3)

# 初始化网络参数
input_size = X_train.shape[1]
hidden_size = 64
output_size = y_train.shape[1]

W1 = np.random.randn(input_size, hidden_size)
b1 = np.zeros((1, hidden_size))
W2 = np.random.randn(hidden_size, output_size)
b2 = np.zeros((1, output_size))

# 训练过程
epochs = 1000
learning_rate = 0.01

for epoch in range(epochs):
    # 前向传播
    z1 = np.dot(X_train, W1) + b1
    a1 = sigmoid(z1)
    z2 = np.dot(a1, W2) + b2
    output = sigmoid(z2)
    
    # 计算损失
    loss = np.mean(0.5 * (output - y_train) ** 2)
    
    # 反向传播
    d_output = output - y_train
    d_z2 = d_output * output * (1 - output)
    d_W2 = np.dot(a1.T, d_z2)
    
    d_a1 = np.dot(d_z2, W2.T)
    d_z1 = d_a1 * a1 * (1 - a1)
    d_W1 = np.dot(X_train.T, d_z1)
    
    # 更新权重
    W2 -= learning_rate * d_W2
    W1 -= learning_rate * d_W1
    
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

通过上述代码,我们可以训练一个简单的神经网络来识别手写数字。在每个训练周期,网络通过前向传播计算输出,通过反向传播调整权重。

5. 总结

在本文中,我们深入探讨了前向传播、反向传播和计算图的概念,并通过代码示例和图示帮助理解这些复杂的过程。希望这些内容能帮助你更好地理解神经网络的工作原理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2135431.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一种简单的过某宝验证码的方式(仅做学习使用)

开篇 今天介绍一种简单的过某宝验证码的方式,用的是自动化,这样对不会js逆向的小白非常友好,只需要用到selenium框架就能轻松过某宝验证码,即模拟人的操作对滑块进行滑动。 但是首先还是需要训练验证码和标题 训练前&#xff1a…

各个大厂软件测试面试题,面试经验分享

前言 一、华为测试岗电话面试 一面 1)自我介绍 2)项目流程 >讲下H模型 3)业务流程 >项目讲解、可从贷款流程讲起 4)做过自动化吗? 5)做过接口测试吗? 可从postman和jmeter做手工接口测…

数业智能心大陆探索生成式AIGC创新前沿

近日,数业智能心大陆参与了第九届“创客中国”生成式人工智能(AIGC)中小企业创新创业大赛。在这场汇聚了众多创新力量的研讨过程中,广东数业智能科技有限公司基于多智能体的心理健康技术探索与应用成果,从众多参赛者中…

KTM580030bit 绝对角度细分器支持最多 4096 对极与一键非线性自校准集成双 16bit 2M SAR ADC

KTM5800 是一款 30bit 绝对角度细分 4096 对极编码细分器,可以与磁电阻传感器( AM R/TMR )搭配,构成一个高速高精度的非接触磁性编码器模块。它具有以非常高的采样速率 读取传感器上的差分模拟正弦和余弦信号的能力&#xf…

vue3提交按钮限制重复点击

下载lodash npm install lodash 引入并使用 <template><div click"submit()">提交</div> </template><script setup>import { debounce } from lodash;const submit debounce(() > {//业务代码},2000,{leading: true,trailing:…

ETL数据集成丨建设BI的关键前提是ETL数据集成?

背景 很多企业都购买了商业智能&#xff08;BI&#xff09;来加速数字化转型&#xff0c;但是发现仅仅依赖BI效果往往不太好。虽然通过BI&#xff0c;企业能够快速分析和可视化数据&#xff0c;然而&#xff0c;BI并不是一个万能工具&#xff0c;它虽然能帮助企业解读数据&…

rancker 图形化界面

rancker 图形化界面 图形化界面进行k8s集群的管理 rancher自带监控————普罗米修斯 #在master和两个node上都操作 [rootmaster01 opt]# rz -E rz waiting to receive. [rootmaster01 opt]# docker load -i rancher.tar ​ #在master上操作 [rootmaster01 opt]# docker pul…

90v转5v500MA内置mos芯片方案

在设计一个90V转5V500mA的DC/DC转换器方案时&#xff0c;可以考虑使用AH7550这款150KHz固定频率PWM降压&#xff08;降压&#xff09;DC/DC转换器。AH7550能够以高效率、低纹波和出色的线路和负载调节驱动0.4A负载&#xff0c;且需要最少数量的外部组件&#xff0c;使用简单&am…

【物联网技术大作业】设计一个智能家居的应用场景

前言&#xff1a; 本人的物联网技术的期末大作业&#xff0c;希望对你有帮助。 目录 大作业设计题 &#xff08;1&#xff09;智能家居的概述。 &#xff08;2&#xff09;介绍智能家居应用。要求至少5个方面的应用&#xff0c;包括每个应用所采用的设备&#xff0c;性能&am…

CAPL_构建基于UDS的刷写学习—01 Hex文件的解析

前言&#xff1a; 打算写一个系列&#xff1a;CAPL_构建基于UDS的刷写学习&#xff0c;大致写一下写作的思路 1&#xff1a;本文是第1篇首先讲解基础。首先搞清楚&#xff0c;各种不同文件&#xff08;常见的S19,hex,bin,以及汽车行业主机厂自己的各种文件CBF(奇瑞特有),VBF&…

SpringCloud Alibaba之Nacos服务注册和配置中心

&#xff08;学习笔记&#xff09;nacos-server版本&#xff1a;2.2.3 总体介绍&#xff1a; 1、Nacos介绍 官网&#xff1a;Nacos官网| Nacos 配置中心 | Nacos 下载| Nacos 官方社区 | Nacos 官网 Nacos /nɑ:kəʊs/ 是 Dynamic Naming and Configuration Service的首字…

8路模拟量采集模块,4~20mA 0~10V电流电压高速采集——DAM-3054P

阿尔泰科技 DAM-3054P为8路差分模拟量采集模块&#xff0c;高速采集&#xff0c;每通道采集速率为500sps&#xff0c;16位AD&#xff0c;支持RS485通讯接口&#xff0c;带有标准ModbusRTU协议。配备良好的人机交互界面&#xff0c;使用方便&#xff0c;性能稳定。 指标参数&…

基于图像的端到端方案实现小车在模拟城市场景中的自主导航

基于图像的端到端方案实现小车在模拟城市场景中的自主导航 FSD&#xff08;Full Self-Driving&#xff09;是特斯拉公司推出的一种自动驾驶技术&#xff0c;旨在实现完全自主的驾驶体验。FSD系统依靠大量的数据和高级的机器学习算法&#xff0c;结合车载传感器&#xff08;如摄…

docker--刚开始学不知道如何操作拉取,或拉取失败(cmd)

报 unauthorized: incorrect username or password.&#xff08;未授权&#xff09; 进行授权 在docker desktop注册账号登录好docker desktop 在cmd中进行docker登录&#xff0c;输入账号密码&#xff0c;提示Login Succeeded&#xff0c;即登录成功 docker login -u xxx(x…

yjs04——matplotlib的使用(多个坐标图)

1.多个坐标图与一个图的折线对比 1.引入包&#xff1b;字体&#xff08;同&#xff09; import matplotlib.pyplot as plt import random plt.rcParams[font.family] [SimHei] plt.rcParams[axes.unicode_minus] False 2.创建幕布 2.1建立图层幕布 一个图&#xff1a;plt.fig…

Artec Leo协助定制维修管道,让石油和天然气炼油厂不停产

以下文章来源于Artec3D埃太科三维 &#xff0c;作者小埃 挑战 在高温、狭窄的炼油厂中&#xff0c;准确测量结构复杂的受损管道和设备&#xff0c;以便设计、制造、安装定制维修解决方案&#xff0c;从而尽快完成修复。 解决方案 Artec Leo, Artec Studio, Geomagic Design X…

关于Vue2里 v-for和v-if一起用的时候会出现的问题

关于Vue2里 v-for和v-if一起用的时候会出现的问题 &#x1f389;&#x1f389;&#x1f389;欢迎来到我的博客,我是一名自学了2年半前端的大一学生,熟悉的技术是JavaScript与Vue.目前正在往全栈方向前进, 如果我的博客给您带来了帮助欢迎您关注我,我将会持续不断的更新文章!!!&…

roboguide将tp程序转化为LS文本格式的方法

不同的软件版本可能操作不同&#xff0c;但是仍然可以参考文章中的办法。 我使用的版本如图所示&#xff1a; 1.首先&#xff0c;打开任意一个工程&#xff0c;如果没有&#xff0c;可以打开自带的示例。 如图&#xff0c;我打开了自带的示例&#xff0c;在帮助文档中可以找到…

ubuntu中QT+opencv在QLable上显示摄像头

ubuntu中QTopencv在QLable上显示摄像头 饭前的一篇文章吧&#xff0c;写完吃饭走 图像在机器视觉中的重要性是不可忽视的。机器视觉是指计算机利用图像处理技术进行图像识别、分析和理解的科学与技术领域。图像是机器视觉的输入数据&#xff0c;通过分析和处理图像&#xff0…

【鸿蒙】HarmonyOS NEXT星河入门到实战7-ArkTS语法进阶

目录 1、Class类 1.1 Class类 实例属性 1.2 Class类 构造函数 1.3 Class类 定义方法 1.4 静态属性 和 静态方法 1.5 继承 extends 和 super 关键字 1.6 instanceof 检测是否实例 1.7.修饰符(readonly、private、protected 、public) 1.7.1 readonly 1.7.2 Private …