【代码】python实现一个BP神经网络-原理讲解与代码展示

news2024/11/26 20:30:37

本文来自《老饼讲解-BP神经网络》https://www.bbbdata.com/

目录

  • 一、BP神经网络原理回顾
    • 1.1 BP神经网络的结构简单回顾
    • 1.2.BP神经网络的训练算法流程
  • 二、python实现BP神经网络代码
    • 2.1.数据介绍
    • 2.2.pytorch实现BP神经网络代码

在python中要如何使用代码实现一个BP神经网络呢?
在python中可以利用pytorch来实现BP神经网络,这是最简洁也是最常用的方法。
通过本文可以详细掌握怎么使用python的pytorch来实现一个BP神经网络。

一、BP神经网络原理回顾

1.1 BP神经网络的结构简单回顾

BP神经网络的结构如下:
BP神经网络结构图
BP神经网络由输入层、隐层、输出层组成,其中隐层可以是有多层的,整个网络以前馈式进行计算,也就是每层的输出作为下层的输入,不断套娃,直到输出层

每层的计算公式如下:
y = T ( W X + B ) y=T(WX+B) y=T(WX+B)
其中,
X:该层的输入
W:该层的权重
B:该层的阈值
T:该层的激活函数

1.2.BP神经网络的训练算法流程

梯度下降算法求解BP神经网络的流程如下:
梯度下降算法求解BP神经网络

一、先初始化一个解                                                 
二、迭代                                                                  
1. 计算所有w,b在当前处的梯度dw,db           
2. 将w,b往负梯度方向更新:                       
   w = w-lr*dw                       
   b = b-lr*db       
3. 判断是否满足退出条件,如果满足,则退出迭代

二、python实现BP神经网络代码

在python中只需要使用pytorch就可以简单实现BP神经网络,而且提供了丰富的训练算法。

2.1.数据介绍

为方便理解,不妨采用以下的简单数据:
在这里插入图片描述
上述即为sin函数在[-5,5]之间的20个采样数据

2.2.pytorch实现BP神经网络代码

下面展示在pytorch中实现BP神经网络的代码
特别说明:需要先安装pytorch包

import torch
import matplotlib.pyplot as plt 
torch.manual_seed(99)

# -----------计算网络输出:前馈式计算---------------
def forward(w1,b1,w2,b2,x):                                   
    return w2@torch.tanh(w1@x+b1)+b2

# -----------计算损失函数: 使用均方差--------------
def loss(y,py):
    return ((y-py)**2).mean()

# ------训练数据----------------
x = torch.linspace(-5,5,20).reshape(1,20)                      # 在[-5,5]之间生成20个数作为x
y = torch.sin(x)                                               # 模型的输出值y

#-----------训练模型------------------------
in_num  = x.shape[0]                                            # 输入个数
out_num = y.shape[0]                                            # 输出个数
hn  = 4                                                         # 隐节点个数
w1  = torch.randn([hn,in_num],requires_grad=True)               # 初始化输入层到隐层的权重w1
b1  = torch.randn([hn,1],requires_grad=True)                    # 初始化隐层的阈值b1
w2  = torch.randn([out_num,hn],requires_grad=True)              # 初始化隐层到输出层的权重w2
b2  = torch.randn([out_num,1],requires_grad=True)               # 初始化输出层的阈值b2

lr = 0.01                                                       # 学习率
for i in range(5000):                                           # 训练5000步
    py = forward(w1,b1,w2,b2,x)                                 # 计算网络的输出
    L = loss(y,py)                                              # 计算损失函数
    print('第',str(i),'轮:',L)                                 # 打印当前损失函数值
    L.backward()                                                # 用损失函数更新模型参数的梯度
    w1.data=w1.data-w1.grad*lr                                  # 更新模型系数w1
    b1.data=b1.data-b1.grad*lr                                  # 更新模型系数b1
    w2.data=w2.data-w2.grad*lr                                  # 更新模型系数w2
    b2.data=b2.data-b2.grad*lr                                  # 更新模型系数b2
    w1.grad.zero_()                                             # 清空w1梯度,以便下次backward
    b1.grad.zero_()                                             # 清空b1梯度,以便下次backward
    w2.grad.zero_()                                             # 清空w2梯度,以便下次backward
    b2.grad.zero_()                                             # 清空b2梯度,以便下次backward
px = torch.linspace(-5,5,100).reshape(1,100)                    # 测试数据,用于绘制网络的拟合曲线    
py = forward(w1,b1,w2,b2,px).detach().numpy()                   # 网络的预测值
plt.scatter(x, y)                                               # 绘制样本
plt.plot(px[0,:],py[0,:])                                       # 绘制拟合曲线  
print('w1:',w1)
print('b1:',b1)
print('w2:',w2)
print('b2:',b2)

运行结果如下:

.....                                            
第 4996 轮: tensor(0.0083, grad_fn=<MeanBackward0>)
第 4997 轮: tensor(0.0083, grad_fn=<MeanBackward0>)
第 4998 轮: tensor(0.0083, grad_fn=<MeanBackward0>)
第 4999 轮: tensor(0.0083, grad_fn=<MeanBackward0>)
w1: tensor([[ 0.1742],[-0.8133],[-0.6450],[-0.4054]],requires_grad=True)
b1: tensor([[ 0.8125],[0.0593],[-1.8776],[1.1220]],requires_grad=True)
w2: tensor([[-0.7753,-2.0142,1.1161,1.9635]],requires_grad=True)
b2: tensor([[0.1094]], requires_grad=True)   

运行结果
可以看到,模型根据训练数据,已经较好地拟合出sin函数曲线


相关链接:

《老饼讲解-机器学习》:老饼讲解-机器学习教程-通俗易懂
《老饼讲解-神经网络》:老饼讲解-matlab神经网络-通俗易懂
《老饼讲解-神经网络》:老饼讲解-深度学习-通俗易懂

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1848937.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Github 2024-06-22Rust开源项目日报 Top10

根据Github Trendings的统计,今日(2024-06-22统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Rust项目10Dart项目1Move项目1TypeScript项目1RustDesk: 用Rust编写的开源远程桌面软件 创建周期:1218 天开发语言:Rust, Dart协议类型:GNU …

python编程大数据分析 ,anaconda ,删除包 提示没有meta信息,无法删除tensorflow2.10,无法降级到tensorflow2.5.3

python编程大数据分析 &#xff0c;anaconda &#xff0c;删除包 提示没有meta信息&#xff0c;无法删除tensorflow2.10,无法降级到tensorflow2.5.3 pip unstall tensorflow 提示 Requirement already satisfied: tensorflow in k:\programdata\anaconda3\lib\site-p ackage…

【源码】人力资源管理系统hrm功能剖析及源码

eHR人力资源管理系统&#xff1a;功能强大的人力资源管理工具 随着企业规模的不断扩大和业务需求的多样化&#xff0c;传统的人力资源管理模式已无法满足现代企业的需求。eHR人力资源管理系统作为一种先进的管理工具&#xff0c;能够为企业提供高效、准确、实时的人力资源管理。…

为什么要学习PMP

学习PMP&#xff08;项目管理专业人士认证&#xff09;能够在职场竞争力、薪资待遇、项目管理技能等方面带来显著的提升。以下是学习PMP的具体分析&#xff1a; 1、职场竞争力 升职加薪&#xff1a;学习PMP能够提升个人在项目中的管理能力和解决问题的能力&#xff0c;从而在…

STM32学习和实践笔记(37):DMA实验

1.DMA简介 DMA&#xff0c;全称是Direct Memory Access&#xff0c;中文意思为直接存储器访问。DMA可用于实现外设与存储器之间或者存储器与存储器之间数据传输的高效性。 之所以高效&#xff0c;是因为DMA传输数据移动过程无需CPU直接操作&#xff0c;这样节省的 CPU 资源就可…

QT基础 - 文本文件读写

目录 零. 前言 一.读取文件 二. 写入文件 三. 和二进制读写的区别 零. 前言 在 Qt 中&#xff0c;对文本文件进行读写操作是常见的任务之一。这对于保存和加载配置信息、处理数据文件等非常有用。 Qt 提供了多种方式来读写文本文件&#xff0c;使得文件操作变得相对简单和…

攻防世界-intoU

下载附件发现是wav文件&#xff0c;扔Audacity里面 将采样率&#xff08;右击选择&#xff09;改为900&#xff0c;之后再查看频谱图 再将进度条拉到最后 得到flag&#xff1a; RCTF{bmp_file_in_wav}

最新版ChatGPT对话系统源码 Chat Nio系统源码

最新版ChatGPT对话系统源码 Chat Nio系统源码 支持 Vision 模型, 同时支持 直接上传图片 和 输入图片直链或 Base64 图片 功能 (如 GPT-4 Vision Preview, Gemini Pro Vision 等模型) 支持 DALL-E 模型绘图 支持 Midjourney / Niji 模型的 Imagine / Upscale / Variant / Re…

二,SpringFramework

二、SpringFramework实战指南 目录 一、技术体系结构 1.1 总体技术体系1.2 框架概念和理解 二、SpringFramework介绍 2.1 Spring 和 SpringFramework概念2.2 SpringFramework主要功能模块2.3 SpringFramework 主要优势 三、Spring IoC容器和核心概念 3.1 组件和组件管理概念3…

【深度学习驱动流体力学】OpenFOAM框架剖析

目录 1. applications 目录solvers:存放各种求解器。mesh:网格生成相关工具。2. src 目录3. tutorials 目录其他主要目录和文件参考OpenFOAM 源码文件目录的框架如下,OpenFOAM 是一个开源的计算流体力学 (CFD) 软件包,其源码文件结构设计精巧,分为多个主要目录,每个目录都…

jeecg-boot项目的部署-windows系统

一、基础环境的准备&#xff1a; 1、后台基础环境&#xff1a;JDK、redis、数据库&#xff1a;sqlserver 2、前端基础环境&#xff1a;nginx redis和nginx的安装都很方便&#xff0c;直接去对应的官网&#xff0c;下载zip压缩包&#xff0c;然后解压&#xff0c;执行.exe文件…

制作WIFI二维码,实现一键扫描连接WIFI

在现代社会&#xff0c;Wi-Fi已成为我们日常生活中不可或缺的一部分。无论是在家庭、办公室还是公共场所&#xff0c;我们都希望能够快速方便地连接到Wi-Fi网络。下面小编就来和大家分享通过制作WIFI二维码&#xff0c;来实现一键扫描就可以连接WIFI的方法。连接WIFI不用在告诉…

计算机网络 VLAN间路由单臂路由

一、理论知识 VLAN是一种将物理网络划分成多个逻辑网络的方法。不同的VLAN属于不同的网段&#xff0c;因此互相通信需要通过路由器进行路由。通常情况下&#xff0c;在同一VLAN内的设备可以直接通信&#xff0c;而不同VLAN之间的设备则需要通过路由器转发数据。本实验利用单臂…

洛谷——P2824 排序

题目来源&#xff1a;[HEOI2016/TJOI2016] 排序 - 洛谷https://www.luogu.com.cn/problem/P2824 问题思路 本文介绍一种二分答案的做法&#xff0c;时间复杂度为&#xff1a;(nm)*log(n)*log(n).本题存在nlog(n)的做法&#xff0c;然而其做法没有二分答案的做法通俗易懂. 默认读…

水系统阻力计算

所谓水泵的选取计算其实就是估算&#xff08;很多计算公式本身就是估算的&#xff09;&#xff0c;估算分的细致些考虑的内容全面些就是精确的计算。 特别补充&#xff1a;当设计流量在设备的额定流量附近时&#xff0c;上面所提到的阻力可以套用&#xff0c;更多的是往往都大…

【前端技术】标签页通讯localStorage、BroadcastChannel、SharedWorker的技术详解

&#x1f604; 19年之后由于某些原因断更了三年&#xff0c;23年重新扬帆起航&#xff0c;推出更多优质博文&#xff0c;希望大家多多支持&#xff5e; &#x1f337; 古之立大事者&#xff0c;不惟有超世之才&#xff0c;亦必有坚忍不拔之志 &#x1f390; 个人CSND主页——Mi…

Apple - Text Attribute Programming Topics

本文翻译整理自&#xff1a;Text Attribute Programming Topics&#xff08;更新日期&#xff1a;2004-02-16 https://developer.apple.com/library/archive/documentation/Cocoa/Conceptual/TextAttributes/TextAttributes.html#//apple_ref/doc/uid/10000088i 文章目录 一、文…

http发展史(http0.9、http1.0、http1.1、http/2、http/3)详解

文章目录 HTTP/0.9HTTP/1.0HTTP/1.1队头阻塞&#xff08;Head-of-Line Blocking&#xff09;1. TCP 层的队头阻塞2. HTTP/1.1 的队头阻塞 HTTP/2HTTP/3 HTTP/0.9 发布时间&#xff1a;1991年 特点&#xff1a; 只支持 GET 方法没有 HTTP 头部响应中只有 HTML 内容&#xff0…

C语言入门系列:可迁移的数据类型

文章目录 1&#xff0c;精确宽度类型(exact-width integer type)2&#xff0c;最小宽度类型&#xff08;minimum width type&#xff09;3&#xff0c;最快的最小宽度类型&#xff08;fast minimum width type&#xff09;4&#xff0c;可以保存指针的整数类型。5&#xff0c; …

云原生微服务开发日趋成熟:有效拥抱左移以改善交付

在软件工程和应用程序开发方面&#xff0c;云原生已经成为许多团队的常用术语。当人们调查云原生的世界时&#xff0c;他们经常会得出这样的观点&#xff1a;云原生的整个过程都是针对大型企业应用程序的。几年前&#xff0c;情况可能确实如此&#xff0c;但随着 Kubernetes 等…