深入探索BP神经网络【简单原理、实际应用和Python示例】

news2025/1/12 0:00:42

人工神经网络(Artificial Neural Networks)是一种受到生物神经网络启发的机器学习模型,它的应用范围广泛,包括图像识别、语音识别、自然语言处理等领域。其中,BP神经网络(Backpropagation Neural Network)是最常见和基础的神经网络之一。

背景

BP神经网络是一种 前馈神经网络,最早由David Rumelhart、Geoffrey Hinton和Ron Williams于1986年提出。它的设计灵感来自于 人脑中神经元之间的信息传递机制,通过神经元之间的连接和权重来模拟信息处理 。BP神经网络由输入层、隐含层(可以有多个)、输出层组成,每个神经元都与上一层的所有神经元相连,每个连接都有一个权重。

它就像一条信息流水管,把信息从一头输送到另一头。想象一下,你把一些问题的答案输入到这个水管的一端,然后它通过一系列的阀门和管道,最终给你正确的答案。这些阀门和管道就是我们的神经元,而它们之间的强弱连接就是权重。

与前馈神经网络不同,后馈神经网络 在信息传递方面更加复杂。前馈神经网络只向前传递信息,而后馈神经网络可以在信息传递的同时,也允许信息反馈,形成循环。这使得后馈神经网络在处理某些复杂问题时更加强大,但也更加复杂。虽然后馈神经网络在某些情况下非常有用。

解决的问题

BP神经网络主要用于解决分类和回归问题。它可以应用于诸如图像分类、文本情感分析、手写数字识别等任务。通过学习和调整神经元之间的权重,BP神经网络可以自动地发现输入数据中的模式特征,从而进行准确的预测和分类。

实现原理

梯度下降算法

BP神经网络的训练基于梯度下降算法。该算法的核心思想是最小化损失函数,使预测值与实际值之间的误差尽可能小。下面是梯度下降算法的关键步骤:

  1. 初始化权重和偏置: 刚开始,我们给网络的每个"连接"(类似神经元之间的道路)随机设置了一些权重,就像是在城市的交叉口上设置了一些标志。

  2. 前向传播: 我们把数据传递给网络,就像是车辆在城市中行驶,经过每个交叉口。每个神经元计算出一些东西,就像是每个交叉口告诉车辆往哪里走。

  3. 计算损失: 我们想要知道网络的预测是否准确,所以我们计算出它的错误,就像是检查车辆是否按照地图行驶。

  4. 反向传播: 然后,我们回头看每个连接的错误,就像是回头看每个交叉口上的标志是否设置得合适。如果错了,我们会微调标志,就像是微调道路的方向

  5. 重复迭代: 我们一遍又一遍地重复这个过程,每次都试图让预测更准确。就像是一辆车一步一步地走向正确的目的地,最终我们希望网络的预测变得非常准确,损失变得很小。

代码示例

以下将演示如何使用BP神经网络解决二分类问题(解决XOR问题)

import numpy as np
import matplotlib.pyplot as plt

# 定义激活函数
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# 初始化权重和偏置
input_size = 2
hidden_size = 3
output_size = 1
learning_rate = 0.1

w1 = np.random.randn(input_size, hidden_size)
b1 = np.zeros((1, hidden_size))
w2 = np.random.randn(hidden_size, output_size)
b2 = np.zeros((1, output_size))

# 训练数据
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([[0], [1], [1], [0]])

# 存储损失值
losses = []

# 训练神经网络
for epoch in range(10000):
    # 前向传播
    z1 = np.dot(X, w1) + b1
    a1 = sigmoid(z1)
    z2 = np.dot(a1, w2) + b2
    a2 = sigmoid(z2)

    # 计算损失
    loss = np.mean((a2 - y) ** 2)
    losses.append(loss)

    # 反向传播
    dloss_da2 = 2 * (a2 - y)
    da2_dz2 = a2 * (1 - a2)
    dz2_dw2 = a1
    dz2_db2 = 1

    dloss_dw2 = np.dot(dz2_dw2.T, dloss_da2 * da2_dz2)
    dloss_db2 = np.sum(dloss_da2 * da2_dz2, axis=0, keepdims=True)

    da2_dz1 = w2
    dz1_dw1 = X
    dz1_db1 = 1

    dloss_da1 = np.dot(dloss_da2 * da2_dz2, da2_dz1.T)
    dloss_dw1 = np.dot(dz1_dw1.T, dloss_da1 * sigmoid(z1) * (1 - sigmoid(z1)))
    dloss_db1 = np.sum(dloss_da1 * sigmoid(z1) * (1 - sigmoid(z1)), axis=0, keepdims=True)

    # 更新权重和偏置
    w1 -= learning_rate * dloss_dw1
    b1 -= learning_rate * dloss_db1
    w2 -= learning_rate * dloss_dw2
    b2 -= learning_rate * dloss_db2

    if epoch % 1000 == 0:
        print(f'Epoch {epoch}, Loss: {loss}')

# 使用训练好的模型进行预测
def predict(X):
    z1 = np.dot(X, w1) + b1
    a1 = sigmoid(z1)
    z2 = np.dot(a1, w2) + b2
    a2 = sigmoid(z2)
    return a2

# 测试
X_test = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
predictions = predict(X_test)
print("Predictions:")
print(predictions)

# 绘制损失曲线
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

在上述代码中,各个变量的意义如下:

  1. w1:第一层权重矩阵,它包含了输入数据到隐含层神经元的连接权重。其形状为 (input_size, hidden_size),其中 input_size 是输入特征的数量,hidden_size 是隐含层神经元的数量。

  2. b1:第一层偏置,它包含了隐含层每个神经元的偏置。其形状为 (1, hidden_size),用于调整神经元的激活阈值——如果偏置值很大,神经元更容易被激活;如果偏置值很小,神经元更难被激活

  3. w2:第二层权重矩阵,它包含了隐含层到输出层神经元的连接权重。其形状为 (hidden_size, output_size),其中 output_size 是输出层神经元的数量。

  4. b2:第二层偏置,它包含了输出层每个神经元的偏置。其形状为 (1, output_size),用于调整输出层神经元的激活阈值。

  5. X:训练数据集,包含输入数据的矩阵。其形状为 (样本数, input_size),其中 样本数 表示训练数据的数量。

  6. y:目标输出数据,包含每个样本的目标输出。其形状为 (样本数, output_size)

  7. losses:用于存储每个训练周期(epoch)中的损失值。损失是模型的预测与实际目标的误差。

  8. input_size:输入特征的数量,通常等于 X 矩阵的列数。

  9. hidden_size:隐含层神经元的数量,是神经网络结构中的一个超参数。

  10. output_size:输出层神经元的数量,通常等于 y 矩阵的列数。

  11. learning_rate:学习率,用于控制权重和偏置的更新步长。这是一个超参数,影响训练过程的速度和稳定性。

如何理解超参数?
“超” 表示它们在神经网络模型中处于更高的层次,是一种 控制参数的参数,是我们手动设置的。它们影响着训练的方式和结果,但不是从数据中学习得到的,两个常见的超参数是:

  1. 隐含层神经元数量(hidden_size): 这就好像是一个神秘的房间,我们在这个房间里进行了一些数学运算,以便理解和解决问题。但我们需要决定房间里有多少人。如果人太多,可能会变得复杂而慢,如果人太少,可能无法解决复杂的问题。所以,hidden_size 是一个数字,帮助我们控制神经网络中这个房间里有多少“工作人员”。
  2. 学习率(learning_rate): 学习率就像是我们在学校学习的速度。如果我们学习得太快,可能会错过重要的东西;如果我们学习得太慢,可能会浪费时间。所以,learning_rate 是一个数字,它帮助我们控制在训练神经网络时,我们每次更新权重和偏置的速度。如果它太大,可能导致不稳定,如果太小,训练可能会很慢。

这两个超参数,hidden_sizelearning_rate,是我们在构建神经网络时需要做的决策,它们会影响我们的模型是如何工作的。所以,选择适合问题的隐含层神经元数量和学习率非常重要。学习率要选择得当,使得模型在训练过程中能够迅速学习,同时又保持稳定,而隐含层神经元数量要足够适应问题的复杂度,但也不能过多,以免增加计算负担。

在上述代码中,用到的三个函数的简单解释:

  1. np.random.randn:这是一个用于生成随机数的函数。它创建一个包含随机数的数组,这些随机数遵循标准 正态分布 (均值为0,标准差为1)。在神经网络中,我们通常使用这些随机数来初始化神经网络的权重,使它们具有随机的初始值。

  2. np.zeros:这个函数用于创建一个数组,其中所有的元素都是零。在神经网络中,我们通常使用它来初始化偏置,以确保它们的初始值为零。

  3. np.dot:这是一个用于进行矩阵相乘的函数。在神经网络中,我们使用这个函数来计算不同层之间神经元的连接以及权重与输入数据之间的乘积。这有助于执行神经网络的前向传播和反向传播计算。

这些函数在神经网络的实现中非常有用,因为它们可以帮助我们进行随机初始化、初始化和数学运算,这些都是神经网络训练过程中的基本操作。

如何理解激活函数Sigmod?
它是一种用来处理神经网络中的数值的特殊函数。它的作用就像是一个开关,可以把输入的数值变成0到1之间的输出
如果我们把一个大的数值输入 Sigmoid 函数,它会输出接近于1的值,就好像是 开关被打开 一样。而如果我们输入一个负数或者很小的数,它会输出接近于0的值,就好像是开关被关闭一样。

所以,Sigmoid 函数的作用是把输入的数值 转换成一个在0到1之间的数 ,这在神经网络中常用于控制神经元的激活状态。这个函数的特点是输出总是在0到1之间,有助于神经网络学习非常复杂的模式和信息。

为什么需要设置b1、b2偏置调整神经元激活阈值呢?
简单来说,偏置就像是神经元的调节器,它帮助神经元更好地理解和处理各种不同的数据。

有些数据可能需要更小的刺激才能激活神经元,而有些数据可能需要更大的刺激。偏置就像是用来调整神经元敏感度的工具,使神经网络可以更好地学习和理解各种不同类型的数据。所以,偏置是神经网络中非常重要的部分,它增加了网络的适应能力,帮助网络更好地解决各种问题。

效果展示

经过训练的BP神经网络可以用于进行二分类任务,例如逻辑门操作,如AND、OR、XOR。在上面的代码示例(XOR)中,网络在经过足够的训练迭代后,可以准确地进行预测。
在这里插入图片描述

总结

BP神经网络是一种重要的神经网络模型,通过梯度下降算法来训练和优化模型,以解决分类和回归问题。它在机器学习和深度学习中具有广泛的应用,是许多人工智能应用的基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1092687.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

学习笔记-MongoDB(复制集,分片集集群搭建)

复制集群搭建 基本介绍 什么是复制集? 复制集是由一组拥有相同数据集的MongoDB实例做组成的集群。 复制集是一个集群,它是2台及2台以上的服务器组成,以及复制集成员包括Primary主节点,Secondary从节点和投票节点。 复制集提供了…

花2个月时间学习,面华为测开岗要30k,面试官竟说:你不是在搞笑。。。

背景介绍 计算机专业,代码能力一般,之前有过两段实习以及一个学校项目经历。第一份实习是大二暑期在深圳的一家互联网公司做前端开发,第二份实习由于大三暑假回国的时间比较短(小于两个月),于是找的实习是…

通用考勤后台管理系统

考勤后台系统,包括待办事项、人员管理、任务中心、任务详情、我的任务、客户管理、考勤功能几大功能,本后台系统以考勤打卡为主要功能,采用分屏布局的方式,简洁大方,使用方便

mysqlbinlog 日用记录

我是同步覆盖了两张表,现在想用日志恢复。 先说结论,没有恢复,因为我的日志不完整,设置了定时清理。 如果你truncate了表或者数据库,如果没有完整的日志是恢复不了数据的。 第一、mysqlbinlog 可能没开启 第二、开…

C++入门 第一篇(C++关键字, 命名空间,C++输入输出)

目录 1. C关键字 2. 命名空间 2.1 命名空间定义 2.2命名空间的使用 命名空间的使用有三种方式: 1.加命名空间名称及作用域限定符 2.使用using将命名空间中某个成员引入 3.使用using namespace 命名空间名称 引入 3. C输入&输出 4.缺省函数 4.1 缺省参…

微信开发者工具下载

一、微信开发者工具下载官网 微信开发者工具下载地址与更新日志 | 微信开放文档 (qq.com) 二、微信开发者工具界面 下载安装好后,软件图标如下图所示。 运行软件如下图所示,这时候就需要使用你的管理员账号扫码登录。 登陆后的界面,如下图…

为知笔记一个日记模板

<!DOCTYPE HTML><html><head> <meta http-equiv"Content-Type" content"text/html; charsetunicode"> <title>日记&#xff1a;</title><style id"wiz_custom_css">html, .wiz-editor-body {font-siz…

Lua调用C#类

先创建一个Main脚本作为主入口&#xff0c;挂载到摄像机上 public class Main : MonoBehaviour {// Start is called before the first frame updatevoid Start(){LuaMgr.GetInstance().Init();LuaMgr.GetInstance().DoLuaFile("Main");}// Update is called once p…

Stm32_标准库_11_ADC_光敏热敏传感器_测数值

在测量光敏传感器数值得基础上手动将通道改成热敏传感器通道即可 由于温度传感器的测量范围是-20 ~ 105摄氏度&#xff0c;所以输出温度得考虑带上符号这就需要在原有输出光照强度代码的基础上新添加几个函数 函数1&#xff1a; uint16_t AD_Getvailue(uint8_t ADC_Channel){…

C# PortraitModeFilter (人物图片)背景模糊

效果 项目 代码 using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; using OpenCvSharp; using System; using System.Collections.Generic; using System.Drawing; using System.Drawing.Imaging; using System.Linq; using System.Windows.Forms; us…

python文本转语音

概述 目前有文本转语音的技术&#xff0c;可以用在配音领域&#xff0c;我个人因为一些需求&#xff0c;所以开始寻找这方面的资源&#xff0c;目前各大平台&#xff0c;比如腾讯&#xff0c;讯飞&#xff0c;阿里&#xff0c;百度等都有这样的API服务&#xff0c;我个人是是使…

Multi Scale Supervised 3D U-Net for Kidney and Tumor Segmentation

目录 摘要1 引言2 方法2.1 预处理和数据增强2.2 网络的体系结构2.3 训练过程2.4 推理与后处理 3 实验与结果4 结论与讨论 摘要 U-Net在各种医学图像分割挑战中取得了巨大成功。一些新的、带有花里胡哨功能的架构可能在某些数据集中在使用最佳超参数时取得成功&#xff0c;但它们…

力扣-463.岛屿的周长

Idea 注意观察&#xff0c;每一个完整的方块&#xff0c;边长都是加4&#xff0c;一旦这个方块有其他的方块相邻的话&#xff0c;那么这两个方块总边长就要减少2. 因此我们遍历二维数组的时候&#xff0c;判断岛屿方块的上面还有左方是否有相邻即可 class Solution { public:in…

Linux 64位 C++协程池原理分析及代码实现

导语 本文介绍了协程的作用、结构、原理&#xff0c;并使用C和汇编实现了64位系统下的协程池。文章内容避免了协程晦涩难懂的部分&#xff0c;用大量图文来分析原理&#xff0c;适合新手阅读学习。 GitHub源码 1. Web服务器问题 现代分布式Web后台服务逻辑通常由一系列RPC请…

算法村开篇

大家好我是苏麟从今天开始我将带来算法的一些习题和心得体会等等...... 算法村介绍 我们一步步地学习算法本专栏会以闯关的方式来学习算法 循序渐进地系统的学习算法并掌握大部分面试知识 , 期待和大家一起进步 . 索大祝大家学有所成 , 前程似锦.

PyCharm运行Nosetests并导出测试报告

1. Pycharm运行Nosetests PyCharm可以使用两种方法&#xff0c;运行Nosetests测试文件&#xff1a; 1) 图形用户界面GUI a) 在PyCharm中&#xff0c;选中测试文件&#xff0c;如Tests/test_demo.py b) 鼠标右键选择Run Nosetests in test_demo.py即可执行测试 注1&#xff…

【大数据Hive】hive select 语法使用详解

目录 一、前言 二、Hive select 完整语法树 三、Hive select 操作演示 3.1 数据准备 3.1.1 创建一张表 3.1.2 将数据load加载到t_usa_covid19表 3.1.3 再创建一张分区表 3.1.4 使用动态分区插入数据 3.2 select 常用语法 3.2.1 查询所有字段或者指定字段 3.2.2 查询…

【数据库系统概论】第七章数据库设计

7.1数据库设计概述 数据库设计定义是什么&#xff1f; 数据库设计(database design)&#xff1a;数据库设计是指对于一个给定的应用环境&#xff0c;构造(设计)优化的数据库逻辑模式和物理结构&#xff0c;并据此建立数据库及其应用系统&#xff0c;使之能够有效地存储和管理…

【排序算法】详解冒泡排序及其多种优化稳定性分析

文章目录 算法原理细节分析优化1优化2算法复杂度分析稳定性分析总结 算法原理 冒泡排序(Bubble Sort) 就是从序列中的第一个元素开始&#xff0c;依次对相邻的两个元素进行比较&#xff0c;如果前一个元素大于后一个元素则交换它们的位置。如果前一个元素小于或等于后一个元素…

RootSIFT---SIFT图像特征的扩展

RootSIFT是论文 Three things everyone should know to improve object retrieval - 2012所提出的 A Comparative Analysis of RootSIFT and SIFT Methods for Drowsy Features Extraction - 2020 当比较直方图时&#xff0c;使用欧氏距离通常比卡方距离或Hellinger核时的性能…