bp(back propagation)

news2025/1/18 6:52:59

文章目录

  • 定义
  • 过程
    • 前向传播计算过程
    • 计算损失函数(采用均方误差MSE)
    • 反向传播误差(链式法则)
    • 计算梯度
    • 更新参数
  • 简单实例

定义

反向传播全名是反向传播误差算法(Backpropagation),是一种监督学习方法,用于调整神经网络中权重参数,以最小化模型的预测误差。反向传播通过计算损失函数相对于网络参数的梯度,然后使用梯度下降等优化算法来更新参数,从而提高模型的性能

核心思想是从网络的输出层向输入层传播误差信号,然后根据这些误差信号来更新网络中的权重和偏差,以使模型的输出更接近实际目标

过程

前向传播:将输入数据输入到网络中,通过神经网络的前向传播计算出每个神经元的输出结果,直到输出层输出最终的结果。
计算损失函数:将神经网络输出的结果与真实标签进行比较,计算出损失函数的值。
反向传播误差:从输出层开始,计算每个神经元的误差,然后向前计算每个神经元的误差,直到计算出输入层每个神经元的误差。具体来说,我们首先计算输出层的误差,然后反向传播到前一层隐藏层,计算隐藏层的误差,并将误差反向传播到更早的层,直到计算出输入层的误差。这个过程可以使用链式法则来计算。
计算梯度:根据误差计算每个神经元的梯度,即损失函数对每个神经元权重和偏差的偏导数。
更新参数:使用梯度下降法或其他优化方法来更新每个神经元的权重和偏差,使得损失函数的值最小化。

设置初始值如图:
 目标:给出输入数据i1,i2(0.05和0.10),使输出尽可能与原始输出o1,o2(0.01和0.99)接近

神经网络中信息传递的基本原理:
在神经网络中,前一层神经元的输出通常作为下一层神经元的输入

前向传播计算过程

y=wx+b
例如上图中h1的输入input1 = w1i1+w2i2+b1
假设采用的sigmoid激活函数,则输出output1 = 1/(1+e^-input1)
而这个输出output1又会作为下一层神经元o1、o2的输入
由此我们可以轻松计算出最后的输出

计算损失函数(采用均方误差MSE)

将前向传播计算得到的输出与原始输出比较
MSE = Σ(预测值 - 实际观测值)^2 / 观测值的数量

反向传播误差(链式法则)

从输出层开始,计算每个神经元的误差,然后向前计算每个神经元的误差,直到计算出输入层每个神经元的误差

计算梯度

当前层的梯度 = 当前层输出的梯度 × 当前层输入的梯度
以b站某up主讲解中的一个片段为例:
其中绿色字体代表的是输入值,红色字体为反向传播的梯度

我们以最左上角w0,x0那为例,假设得到的是y0,(再次提醒,上面绿色字体为前向传播计算的值)
则x0的梯度为:y0对x0求导再乘以0.2(当前层输出的梯度),最后得到0.40(至于0.39可能是笔误)
同理w0的梯度:y0对w0求导再乘以0.2,即-1.0*0.2 = 0.20

更新参数

设置一个学习率,根据之前计算的梯度就可以进行参数权重更新了

简单实例

import numpy as np

# 定义 Sigmoid 激活函数
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# 定义 Sigmoid 激活函数的导数
def sigmoid_derivative(x):
    return x * (1 - x)

# 创建神经网络类
class NeuralNetwork:
    def __init__(self, input_size, hidden_size, output_size):
        # 初始化权重
        self.weights_input_hidden = np.random.rand(input_size, hidden_size)
        self.weights_hidden_output = np.random.rand(hidden_size, output_size)

    def forward(self, X):
        # 前向传播
        self.hidden_input = np.dot(X, self.weights_input_hidden)
        self.hidden_output = sigmoid(self.hidden_input)
        self.output = np.dot(self.hidden_output, self.weights_hidden_output)
        return self.output

    def backward(self, X, y, output, learning_rate):
        # 反向传播
        error = y - output
        delta_output = error
        error_hidden = delta_output.dot(self.weights_hidden_output.T)
        delta_hidden = error_hidden * sigmoid_derivative(self.hidden_output)

        # 更新权重
        self.weights_hidden_output += self.hidden_output.T.dot(delta_output) * learning_rate
        self.weights_input_hidden += X.T.dot(delta_hidden) * learning_rate

    def train(self, X, y, learning_rate, epochs):
        for i in range(epochs):
            output = self.forward(X)
            self.backward(X, y, output, learning_rate)
            if i % 1000 == 0:
                loss = np.mean(np.square(y - output))
                print(f"Epoch {i}, Loss: {loss}")

# 示例数据
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([[0], [1], [1], [0]])

# 创建神经网络
input_size = 2
hidden_size = 4
output_size = 1
learning_rate = 0.1
epochs = 10000

nn = NeuralNetwork(input_size, hidden_size, output_size)

# 训练神经网络
nn.train(X, y, learning_rate, epochs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1114684.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HZOJ-271: 滑动窗口

题目描述 ​ 给出一个长度为 N� 的数组,一个长为 K� 的滑动窗口从最左移动到最右,每次窗口移动,如下图: 找出窗口在各个位置时的极大值和极小值。 输入 ​ 第一行两个数 N,K�,�。 …

win yolov5.7 tensorRT推理

安装TensorRT 下载tensorrt8.xx版本,适用于Windows的cuda11.x的版本 官方下载地址https://developer.nvidia.com/nvidia-tensorrt-8x-download 把tensorRT里面的bin、include、lib添加到本机CUDA中,CUDA需要加入环境变量中 配置虚拟环境 torch的版本…

Confluence最新版本(8.6)安装

软件获取 Confluence 历史版本下载地址:Confluence Server 下载存档 | Atlassian Atlassian-agent.jar https://github.com/haxqer/confluence/releases/download/v1.3.3/atlassian-agent.jar MySQL 驱动包 MySQL :: Download MySQL Connector/J (Archived Ve…

举个栗子!Alteryx 技巧(6):从 API 中提取数据

你听说过从 API 中提取数据吗?API 是指应用编程接口,是计算机之间或计算机程序之间的连接,它是一种软件接口,让不同的软件进行信息共享。对于很多数据分析师来说,他们常常需要从 API 中提取数据,那么如何快…

ASEMI肖特基二极管MBR10200CT在电子电路中起什么作用

编辑-Z 肖特基二极管MBR10200CT是一种特殊类型的二极管,可以在电子电路中起到多种作用。 首先,MBR10200CT具有非常低的正向电压丢失和快速的开关速度。这使得它非常适合用作整流器。在直流电源中,MBR10200CT可以将交流电转换为直流电&#x…

漏洞复现--蓝凌EIS智慧协同平台任意文件上传

免责声明: 文章中涉及的漏洞均已修复,敏感信息均已做打码处理,文章仅做经验分享用途,切勿当真,未授权的攻击属于非法行为!文章中敏感信息均已做多层打马处理。传播、利用本文章所提供的信息而造成的任何直…

【LCR 159. 库存管理 III】

目录 一、题目描述二、算法原理三、代码实现 一、题目描述 二、算法原理 三、代码实现 class Solution { public:int getrandom(int left,int right,vector<int>& stock){return stock[rand()%(right-left1)left];}void qsort(int l,int r,vector<int>& s…

【uni-app+vue3】分享功能,App端调用手机的系统分享,可分享到微信、QQ、朋友圈等,已实现

前言&#xff1a; uniapp分享功能&#xff0c;之前考虑的思路一直都是如何使用uni.share()实现 需要下载插件&#xff0c;需要申请微信appid&#xff0c;插件已下载&#xff0c;主要就是没有appid&#xff0c;所以只好换方向 实现效果&#xff1a; 在需要分享的页面&#xff0c…

嵌入式学习笔记(59)内存管理之栈

1.6.1.什么是栈&#xff08;Stack&#xff09; 栈是一种数据结构&#xff0c;C语言中使用栈来存放局部变量。 1.6.2.栈管理内存的特点&#xff08;小内存、自动化&#xff09; 先进后出 FILO&#xff08;First In Last Out&#xff09; 栈 先进先出 FIFO&#xff08;First …

聚类方法总结及code

参考链接 聚类方法的分类&#xff1a; 类别包括的主要算法划分方法K-Means算法&#xff08;均值&#xff09;、K-medoids算法&#xff08;中心点&#xff09;、K-modes算法&#xff08;众数&#xff09;、k-prototypes算法、CLARANS&#xff08;基于选择&#xff09;&#xf…

jmeter接口测试避坑指南

接口测试看着很简单&#xff0c;但是操作过程中还是出现很多问题&#xff0c;现总结如下&#xff1a; 一、jmeter中乱码问题 可在jmeter.properties 这个文件里面找到sampleresult.default.encodingxx&#xff0c;后面xx改成utf-8&#xff0c;然后取消注释。 解决jmeter的bod…

编译和连接

前言&#xff1a;哈喽小伙伴们&#xff0c;从我们开始学习C语言到实现如今的成果&#xff0c;可以说我们对C语言的掌握已经算是精通级别了&#xff0c;但是我们只学习了怎么写代码&#xff0c;却没怎么了解过代码的背后是怎么工作的。 那么今天这篇文章我们一起来学习C语言的最…

spark获取hadoop服务token

spark 作业一直卡在accepted 问题现象问题排查1.查看yarn app日志2.问题分析与原因 问题现象 通过yarn-cluster模式提交spark作业&#xff0c;客户端日志一直卡在submit app&#xff0c;没有运行 问题排查 1.查看yarn app日志 appid已生成&#xff0c;通过yarn查看app状态为…

Linux的权限管理操作

1、权限的基本概念 在多用户计算机系统的管理中&#xff0c;权限是指某个特定的用户具有特定的系统资源使用权利。 在Linux 中分别有读、写、执行权限&#xff1a; 权限针对文件权限针对目录读r表示可以查看文件内容&#xff1b;cat表示可以(ls)查看目录中存在的文件名称写w…

css钟表数字样式

如图&#xff1a; 代码 font-size: 28px;font-family: Yourname;font-weight: 500;color: #00e8ff;

YOLOv5/v7/v8改进实验(七)之使用timm更换YOLOv8模型主干网络Backbone篇

🚀🚀 前言 🚀🚀 timm 库实现了最新的几乎所有的具有影响力的视觉模型,它不仅提供了模型的权重,还提供了一个很棒的分布式训练和评估的代码框架,方便后人开发。更难能可贵的是它还在不断地更新迭代新的训练方法,新的视觉模型和优化代码。本章主要介绍如何使用timm替…

企业实施MES管理系统时,是否需要一物一码

随着科技的不断进步&#xff0c;企业对于生产管理和效率提升的需求日益增长。MES管理系统作为一种先进的生产管理工具&#xff0c;受到了越来越多企业的关注和应用。然而&#xff0c;在部署MES管理系统时&#xff0c;是否需要实施一物一码的条码进行管理&#xff0c;往往成为企…

“人间烟火”背后,长沙招商引资再出圈

连续多年&#xff0c;长沙荣膺全国最具幸福感城市。同时&#xff0c;长沙也被誉为“中部崛起的引擎城市”。长沙不仅有网红城市的人间烟火气&#xff0c;更以创新的精神&#xff0c;优质的营商环境&#xff0c;高效的政府服务&#xff0c;丰富的人才资源和深厚的产业基础&#…

徐宗本院士将在2023年CCF中国软件大会上作特邀报告

2023年CCF中国软件大会&#xff08;CCF ChinaSoft 2023&#xff09;邀请徐宗本院士作大会特邀报告。 特邀嘉宾 徐宗本 中国科学院院士 数学家、中国科学院院士、西安交通大学教授。 主要从事应用数学、智能信息处理、机器学习基础理论研究。曾提出稀疏信息处理的L(1/2)正则化理…

uni-app:实现当前时间的获取,并且根据当前时间判断所在时间段为早上,下午还是晚上

效果图 核心代码 获取当前时间 toString()方法将数字转换为字符串 padStart(2, 0)&#xff1a;padStart()方法用于在字符串头部填充指定的字符&#xff0c;使其达到指定的长度。该方法接受两个参数&#xff1a;第一个参数为期望得到的字符串长度&#xff0c;第二个参数为要填充…