BP神经网络的最简Python实现

news2025/1/25 9:22:52

文章目录

    • 神经元
    • BP原理及实现
    • 测试

BP,就是后向传播(back propagation),说明BP网络要向后传递一个什么东西,这个东西就是误差。

而神经网络,就是由神经元组成的网络,所以在考虑BP之前,还不得不弄清楚神经元是什么。

神经元

泛泛地说,神经元,就是一个函数 w ( x ) w(x) w(x),而且这个函数往往比较友好,可能是一个线性函数,可以表示为

w ( x ) = ∑ i w i x i w(x)=\sum_i w_ix_i w(x)=iwixi

其中 x i x_i xi x x x的诸分量,而且这个分量很可能不是一个标量,而是一个数组,甚至矩阵,即多维数组。

神经网络的目的,就是在给定 x , y x,y x,y的情况下,根据 y = w ( x ) y=w(x) y=w(x)求出 w i w_i wi的值。

如果仅仅是这个程度,那么神经元的地位就和矩阵计算发生了重叠,神经网络就直接退化成最小二乘法了,所以神经元在进行线性代数的计算之后,要引入一个非线性的函数;同时考虑到神经元对中心的偏移,所以网络中还要加一个 B B B参数,故整体可以表示为

w ( x ) = f ( ∑ i w i x i + b i ) w(x)=f\bigg(\sum_i w_ix_i+b_i\bigg) w(x)=f(iwixi+bi)

这里的 f f f是一个非线性函数,一般叫做激活函数,被这个函数一激活,那么平平无奇的矩阵计算,就华华丽丽地变身成了神经元。

这个激活函数并不需要十分复杂,但一般要起到归一化的作用,比如下面将要用到的sigmoid函数

σ ( x ) = 1 1 + exp ⁡ − x \sigma(x)=\frac{1}{1+\exp -x} σ(x)=1+expx1

BP原理及实现

单个神经元虽然组不成网络,但也有一个专门的名称,即单层感知机,可用于数据二分,由于用途单一,所以就不讲了,直接开始多层的BP网络的编写。

那么在书写之前,先声明一下要做的事情。现有一组 x x x和一组 y y y,我们希望建立一个神经网络 W W W,在 x x x经过 W W W之后,得到的结果 Y Y Y y y y的差值小于误差要求。

如前文所述,神经网络是由多个神经元组成的,而神经元最重要的两个参数是 W W W B B B,用以完成对函数 y y y的拟合,这个拟合的过程,即由 x x x得到 Y Y Y的过程,便是前向过程,相应地 W 1 W_1 W1 B 1 B_1 B1构成的就是前向网络。

而后向过程传递的是误差,这也同样需要神经元的帮助,从而需要另一组神经元 W 2 , B 2 W_2, B_2 W2,B2,这些参数的初值可以随机选取。考虑到BP网络实现起来非常轻便,故先将代码列在下面,然后再解读其含义

import numpy as np
sigmoid = lambda x : 1/(1+np.exp(-np.array(xs)))

def bpnn(xs, ys, nIter=101, th=0.005, nHide=10): 
    W1 = np.random.rand(nHide) 
    B1 = np.random.rand(nHide) 
    W2 = np.random.rand(nHide) 
    B2 = np.random.rand()
    for k in range(nIter):
        Y = []
        for xi,yi in zip(xs,ys):
            L1 = xi*W1-B1                # 隐含层输入数据
            L2 = sigmoid(L1)          #隐含层的输出数据
            Y.append(np.sum(W2*L2) - B2) #模型输出
            err = Y[-1] - yi                # 模型误差
            ##反馈,修改参数
            dB2 = -1*th*err
            dW2 = err*th*L2
            dB1 = W2*L2*(1-L2)*(-1)*err*th
            dW1 = W2*L2*(1-L2)*xi*err*th
            W1 = W1 - dW1
            B1 = B1 - dB1
            W2 = W2 - dW2
            B2 = B2 - dB2
        if k%100==0:
            print(k)
    return Y

bpnn的输入除了将要被拟合的xs,ys之外,还有迭代次数nIter, 学习率th以及神经元的权重个数nHide

其中, W 1 , B 1 , W 2 , B 2 W_1,B_1,W_2,B_2 W1,B1,W2,B2便是上文提到的网络参数,前两者用于前向传播,也就是根据 x s xs xs得到 Y Y Y,后两者用于后向传播,就是根据 Y Y Y y s ys ys得到误差。

这段程序的核心过程是for xi,yi in zip(xs,ys)中的内容,表示对每一组 x , y x,y x,y点对进行神经网络的学习,其中 L 1 L_1 L1 L 2 L_2 L2对应 W 1 , B 1 W_1, B_1 W1,B1 W 2 , B 2 W_2, B_2 W2,B2的两个隐藏层。

数据的流动过程为

x->L1->L2->Y->err

首先第一步,根据当前的 W 1 , B 1 W_1,B_1 W1,B1计算得到L1,通过激活函数,生成L2,通过L2得到拟合结果Y,通过比对Yy的值,就可以得到误差err。此为前向传播过程。

接下来就是根据err,来逆推参数的变化量,其基本流程为前向传播的逆过程,第一步得到dB2,其值为

δ B 2 = − θ ∗ σ \delta B_2=-\theta*\sigma δB2=θσ

其中 θ , σ \theta, \sigma θ,σ分别对应代码中的th, err,表示根据学习率和误差,对B2参数进行微调。

然后是dW2,值为

δ W 2 = θ ∗ σ ∗ L 2 \delta W_2=\theta*\sigma*L_2 δW2=θσL2

即除了受到误差和学习率的影响之外,也要考虑L2的影响。那么这个值是怎么来的呢?

σ = w 2 L 2 + b 2 − y ∂ σ ∂ w 2 = L 2 \begin{aligned} \sigma &= w_2L_2+b_2-y\\ \frac{\partial\sigma}{\partial w_2}&=L_2 \end{aligned} σw2σ=w2L2+b2y=L2

∂ σ ∂ w 2 \frac{\partial\sigma}{\partial w_2} w2σ相当于是 w 2 w_2 w2的微小变动对 σ \sigma σ的影响。之所以要引入学习率,乃因直接通过 σ L 2 \sigma L_2 σL2得到的值可能过大,从而跳过真值。

第三步是dB1,值为

δ B 1 = − θ σ w 2 L 2 ( 1 − L 2 ) \delta B_1=-\theta\sigma w_2L_2(1-L_2) δB1=θσw2L2(1L2)

其缘由为

σ = w 2 L 2 + b 2 − y = w 2 f ( w 1 L 1 + b 1 ) + b 2 − y ∂ σ ∂ w 2 = ∂ f ∂ w 1 \begin{aligned} \sigma&=w_2L_2+b_2-y=w_2f(w_1L_1+b1)+b_2-y\\ \frac{\partial\sigma}{\partial w_2}&=\frac{\partial f}{\partial w_1} \end{aligned} σw2σ=w2L2+b2y=w2f(w1L1+b1)+b2y=w1f

考虑到 f f f的表达式 f ( x ) = 1 1 + exp ⁡ − x f(x)=\frac{1}{1+\exp -x} f(x)=1+expx1,则

d f d x = e − x ( 1 + e − x ) = ( 1 − 1 1 + e − x ) ( 1 1 + e − x ) = f ( x ) ( 1 − f ( x ) ) \frac{\text d f}{\text d x}=\frac{e^{-x}}{(1+e^{-x})}=(1-\frac{1}{1+e^{-x}})(\frac{1}{1+e^{-x}})=f(x)(1-f(x)) dxdf=(1+ex)ex=(11+ex1)(1+ex1)=f(x)(1f(x))

考虑到其所谓的 x x x w 1 L 1 w_1L_1 w1L1构成,所以得到的结果就是w_1L_2(1-L_2)

同理,得到最后一组参数W1的更新方案,

测试

最后,又到了喜闻乐见的测试环节

if __name__=="__main__":
    xs = np.arange(1000)*0.01
    ys = np.sin(xs)
    Y = bpnn(xs, ys, nIter=501, th=0.005, nHide=10)
    fig = plt.figure()
    
    plt.plot(xs,ys)
    plt.plot(xs,Y,color='red',linestyle='--')
    plt.show()

得到结果为

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/104973.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

endata 电影票房响应数据破解

本文仅供参考学习,如有侵权可联系本人 目标网站 aHR0cHM6Ly93d3cuZW5kYXRhLmNvbS5jbi9Cb3hPZmZpY2UvQk8vWWVhci9pbmRleC5odG1s加密入口分析 在异步请求那里可以看到请求接口,请求参数并未加密只是响应内容进行了加密,暂时也无法判断加密方…

JavaWeb的Servlet学习之Request03

目录 1.Request 1.1Request执行流程 1.2request对象和response对象的原理 1.3 request对象继承体系结构 1.4request功能: 1.3.1获取请求消息数据 1.获取请求行数据 2.获取请求头 3.获取请求体数据 4.其他功能 4.1获取请求参数通用方式:不论get…

开源CA搭建-基于openssl实现数字证书的生成与分发

目录 一、前言 二、openssl介绍 三、openssl的常用用法 (一)单向加密 (二)生成随机数 (三)生成公钥,私钥 1.生成私钥 2.提取公钥 四、搭建CA (一)创建根CA私钥…

Linux的camera驱动 摄像头调试方法

CameraInfo类用来描述相机信息,通过Camera类中getCameraInfo(int cameraId, CameraInfo cameraInfo)方法获得, 主要包括以下两个成员变量facing,facing 代表相机的方向, 它的值只能是CAMERA_FACING_BACK(后置摄像头&am…

Golang 【basic_leaming】1 基本语法

阅读目录Go 语言变量Go 语言 - 声明变量1. 标准格式2. 批量格式Go 语言 - 初始化与定义变量1. 标准格式2. 编译器推导类型格式3. 短变量声明与初始化Go语言 - 多变量同时赋值Go 语言 - 匿名变量参考资料Go 语言整型(整数类型)1 自动匹配平台的 int 和 un…

新项目为什么决定用 JDK 17了

大家好,我是风筝。公众号「古时的风筝」,专注于后端技术,尤其是 Java 及周边生态。文章会收录在 JavaNewBee 中,更有 Java 后端知识图谱,从小白到大牛要走的路都在里面。 最近在调研 JDK 17,并且试着将之前…

阴差阳错,阴阳之变

北京的第一批“杨康”们已经返回到工作岗位,这其中就包括我。简单总结一下我的感染和康复过程,给大家做个样本吧。我属于北京放开的第一波感染者,12.9日当天感觉嗓子干,毫不犹豫,果然是中招了;周末开始发烧…

特朗普发行NFT惹群嘲,上线售罄现“真香定律”

文/章鱼哥出品/陀螺财经特朗普14日在其创建的社交平台truth social上发帖称,“美国需要一个超级英雄”。他还预告自己将于当地时间15日宣布“重大消息”。据《新闻周刊》报道,特朗普当日在其社交平台上发了一段十几秒的视频,里面有一个他站在…

Windows实时运动控制软核(三):LOCAL高速接口测试之C++

今天,正运动小助手给大家分享一下MotionRT7的安装和使用,以及使用C对MotionRT7开发的前期准备。 01 MotionRT7简介 MotionRT7是深圳市正运动技术推出的跨平台运动控制实时内核,也是国内首家完全自主自研,自主可控的Windows运动控…

Linux搭建测试环境详细步骤

本文讲解如何在Linux CentOS下部署Java Web项目的步骤 环境准备 (1)Linux系统(2)JDK(3)Tomcat (4)MySQL工具下载 一、Linux系统 本文主要是Linux CentOS7为例 自己在家练习小项…

[拆轮子] PaddleDetection 中的 COCODataSet 是怎么写的

今日,辗转反侧,该💩的代码就是跑不成功,来看看 COCODataSet 到底是怎么写的,本文只参考当前版本的代码,当前版本 PaddleDetection2.5 COCODataSet 源码见附录 COCODataSet 类内部就三个函数: …

词义和词义消歧

Synsets(“synonym sets”, effectively senses) are the basic unit of organization in WordNet.同义词集 对于许多应用程序,我们希望消除歧义 • 我们可能只对一种含义感兴趣 • 在网络上搜索chemical plant 化工厂,我们不想搜到香蕉中的化学物质 所以…

【SpringBoot扩展点】 容器刷新前回调ApplicationContextInitializer

本文将作为Spring系列教程中源码版块的第一篇,整个源码系列将分为两部分进行介绍;单纯的源码解析,大概率是个吃力没人看的事情,因此我们将结合源码解析,一个是学习下别人的优秀设计,一个是站在源码的角度看…

【MySQL】索引和事务重点知识汇总

目录1.索引:1.1 索引的使用:1.2 索引背后的核心数据结构:1.2.1 先认识 B 树(N叉搜索树):1.2.2 再认识 B 树(N叉搜索树):2.事务:2.1 隔离性:2.1.1 脏读问题:2.1.2 不可重复读问题:2.1.3 幻读问题:2.1.4 总结:2.1.5 隔离级别:1.索引: 索引存在的意义就是为了提高查询到效率.索引…

【AI理论学习】Python机器学习中的特征选择

Python机器学习中的特征选择特征选择方法特征选择的Python库使用Scikit-learn实现特征选择方差卡方检验ANOVALasso正则化递归特征消除使用Feature-engine进行特征选择单变量特征选择相关性Python 中的更多特性选择方法参考资料任何数据科学项目的一个重要步骤是选择最具预测性的…

vue实现文件下载

引言 最近在自己做项目的需求的过程中,需要vuespringboot实现文件的下载功能(导出博客文件)。 问题重现 在我后端文件下载接口开发完成后,使用vue前端去进行对接时出现了问题。 我是直接使用的axios去进行请求接口&#xff0c…

Python 炫技操作:条件语句的七种写法

原代码 这是一段非常简单的通过年龄判断一个人是否成年的代码,由于代码行数过多,有些人就不太愿意这样写,因为这体现不出自己多年的 Python 功力。 if age > 18:return "已成年" else:return "未成年"下面我列举了六…

SwiftUI 中创建谷歌字体浏览器

Google Fonts是设计用户界面时使用的免费字体的转到站点。本教程将展示如何编写一个简单的工具来预览这些字体,而无需在系统中注册每种字体。 该应用程序包含一个拆分视图,该视图在左侧面板中包含字体列表。右侧面板将显示字体样式选项的预览。 项目设置 创建一个名为 Googl…

Vue2之webpack篇(一)

目录 前言 1、什么是webpack? 2、传统开发模式 一、传统开发模式 1、场景 2、问题 3、原因 4、解决方案 二、ES6模块化 1、ES6的解决方案 3、拓展 4、取别名 5、*搭配取别名 6、导出default{} 三、CommonJS规范 1、推荐文档 2、使用CommonJS规范解决方…

十二、DockerFile构建过程解析

1、概述 Dockerfile是用来构建Docker镜像的文本文件,是由一条条构建镜像所需的指令和参数构成的脚本。 在Docker 常用命令篇中,我们已经知道了2中构建镜像的方式 export\import 和 commit方式。这两种方式都需要先运行并创建容器,然后在容器…