神经网络反向传播算法

news2024/11/20 10:40:58

今天我们来看一下神经网络中的反向传播算法,之前介绍了梯度下降与正向传播~       神经网络的反向传播

专栏:💎实战PyTorch💎

反向传播算法(Back Propagation,简称BP)是一种用于训练神经网络的算法。 

反向传播算法是神经网络中非常重要的一个概念,它由Rumelhart、Hinton和Williams于1986年提出。这种算法基于梯度下降法来优化误差函数,利用了神经网络的层次结构来有效地计算梯度,从而更新网络中的权重和偏置。

基本工作流程:

  1. 通过正向传播得到误差,所谓正向传播指的是数据从输入到输出层,经过层层计算得到预测值,并利用损失函数得到预测值和真实值之前的误差。
  2. 通过反向传播把误差传递给模型的参数,从而对网络参数进行适当的调整,缩小预测值和真实值之间的误差。
  3. 反向传播算法是利用链式法则进行梯度求解,然后进行参数更新。对于复杂的复合函数,我们将其拆分为一系列的加减乘除或指数,对数,三角函数等初等函数,通过链式法则完成复合函数的求导。

我们通过一个例子来简单理解下 BP 算法进行网络参数更新的过程🧧:

如图我们在最下边输入两个维度的值进入神经网络:0.05、0.1 ,经过两个隐藏层(每层两个神经元),每个神经元有两个值,左边为输入值,右边是经过激活函数后的输出值;经过这个神经网络后的输出值为:m1、m2,实际值为0.01、0.99 🌠

设置的初始权重w1,w2,...w8分别为0.15、0.20、0.25、0.30、0.30、0.35、0.55、0.60

我们通过计算得到损失函数Error = 1/2 ((m1- target1)2 + (m2 - target2)2) = 0.2988

w5和w7均可以通过求三次导来求梯度,而w1,w3则不能直接通过L降序求导,我们需要求从L到m1,m1到o1,o1到k1,k1到h1,h1到w1:

由于w1是输出两个方向分别到o1和o2,所以是两个方向的梯度求和~

我们也发现所以激活函数都是要可微的~

其他的网络参数更新过程和上面的求导过程是一样的,这里就不过多赘述,我们直接看一下代码。

反向传播代码 

我们先来回顾一些Python中类的一些小细节:

🌈在Python中,使用super()函数可以调用父类的方法。这在子类中重写父类方法时非常有用,因为它允许你调用父类的实现,而不是完全覆盖它

class Parent:
    def __init__(self):
        print("Parent init")

class Child(Parent):
    def __init__(self):
        super().__init__()
        print("Child init")

c = Child()

# 输出
Parent init
Child init

🌈当我们创建一个Child类的实例时,它会首先调用Parent类的__init__方法(通过super().__init__()),然后执行Child类的__init__方法,与类的__init__方法(构造方法)对应的类关闭时自动调用的方法是__del__方法。对象不再被使用时,Python解释器会自动调用这个方法。通常在这个方法中进行一些清理工作,比如释放资源、关闭文件等。

反向传播实现

import torch
import torch.nn as nn
import torch.optim as optim


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 2)

        # 网络参数初始化w1/w2/w3/w4
        self.linear1.weight.data = torch.tensor([[0.15, 0.20], [0.25, 0.30]])
        # w5/w6/w7/w8
        self.linear2.weight.data = torch.tensor([[0.40, 0.45], [0.50, 0.55]])
        # 截距b
        self.linear1.bias.data = torch.tensor([0.35, 0.35])
        self.linear2.bias.data = torch.tensor([0.60, 0.60])
    # 定义前向传播的行径
    def forward(self, x):

        x = self.linear1(x)
        x = torch.sigmoid(x)
        x = self.linear2(x)
        x = torch.sigmoid(x)

        return x


if __name__ == '__main__':

    inputs = torch.tensor([[0.05, 0.10]])
    target = torch.tensor([[0.01, 0.99]])

    # 获得网络输出值
    net = Net()
    output = net(inputs)
    # print(output)  # tensor([[0.7514, 0.7729]], grad_fn=<SigmoidBackward>)

    # 计算误差
    loss = torch.sum((output - target) ** 2) / 2
    # print(loss)  # tensor(0.2984, grad_fn=<DivBackward0>)

    # 优化方法
    optimizer = optim.SGD(net.parameters(), lr=0.5)

    # 梯度清零
    optimizer.zero_grad()

    # 反向传播
    loss.backward()

    # 打印 w5、w7、w1 的梯度值
    print(net.linear1.weight.grad.data)
    # tensor([[0.0004, 0.0009],
    #         [0.0005, 0.0010]])

    print(net.linear2.weight.grad.data)
    # tensor([[ 0.0822,  0.0827],
    #         [-0.0226, -0.0227]])

    # 打印网络参数
    optimizer.step()
    print(net.state_dict())
    # OrderedDict([('linear1.weight', tensor([[0.1498, 0.1996], [0.2498, 0.2995]])),
    #              ('linear1.bias', tensor([0.3456, 0.3450])),
    #              ('linear2.weight', tensor([[0.3589, 0.4087], [0.5113, 0.5614]])),
    #              ('linear2.bias', tensor([0.5308, 0.6190]))])
  • optimizer.step() 相当于是将w和b所有参数更新一步的过程

🌈关于nn.Linear的使用

import torch
import torch.nn.functional as F
import torch.nn as nn

# 均匀分布随机初始化

linear = nn.Linear(5, 3)
# 从0-1均匀分布产生参数
nn.init.uniform_(linear.weight)
print(linear.weight.data)

nn.Linear是PyTorch中用于创建线性层的类,也被称为全连接层。它的主要作用是将输入数据与权重矩阵相乘并加上偏置,然后通常会通过一个非线性激活函数进行转换。 

  1. 在函数内部,创建一个线性层,输入维度为5,输出维度为3;
  2. 使用nn.init.uniform_()函数对线性层的权重进行均匀分布随机初始化;
  3. 打印线性层的权重数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1636569.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一键PDF水印添加工具

一键PDF水印添加工具 引言优点1. 精准定位与灵活布局2. 自由旋转与透明度调控3. 精细化页码选择4. 全方位自定义水印内容5. 无缝整合工作流程 功能详解结语工具示意图【工具链接】 引言 PDF作为最常用的文档格式之一&#xff0c;其安全性和版权保护显得尤为重要。今天&#xff…

qcheckbox互斥 也就是单选 纯代码实现 没有ui界面转到槽

1.init&#xff08;&#xff09;函数把所有的qcheckbox找到&#xff0c;然后通过信号与槽&#xff0c;做到点击哪个qcheckbox&#xff0c;哪个qcheckbox就发出信号 2.checkchange&#xff08;&#xff09;槽函数&#xff0c;通过42行拿到是哪个qcheckbox发出的信号&#xff0c…

怎么用微信小程序实现远程控制台球室

怎么用微信小程序实现远程控制台球室呢&#xff1f; 本文描述了使用微信小程序调用HTTP接口&#xff0c;实现控制台球室&#xff0c;控制球台上方的照明灯&#xff0c;单台设备可控制多张球台的照明灯。 可选用产品&#xff1a;可根据实际场景需求&#xff0c;选择对应的规格 …

PVDF-SiO₂复合纳米纤维膜

PVDF-SiO₂复合纳米纤维膜是一种结合了聚偏氟乙烯&#xff08;PVDF&#xff09;和二氧化硅&#xff08;SiO₂&#xff09;纳米粒子的新型复合材料。这种材料通常通过静电纺丝技术或其他纤维制备技术制备而成&#xff0c;具有许多良好的性能和广泛的应用前景。 PVDF是一种热塑性…

中兴UME网管LTE共享参数配置-PLMN添加

本文为中兴设备UME网管电联中频共享参数配置&#xff0c;PLMN添加参数配置部分&#xff0c;因UME与U&#xff13;&#xff11;网管添加PLMN配置区别较大&#xff0c;UME网管需同时配置运营商EN&#xff0d;DC策略&#xff0c;相关配置流程及参数配置如下文。 PLMN eNodeB CU …

《Python编程从入门到实践》day19

#昨日知识点回顾 使用unittest模块测试单元和类 #今日知识点学习 第12章 武装飞船 12.1 规划项目 游戏《外星人入侵》 12.2 安装pygame 终端管理器执行 pip install pygame 12.3 开始游戏项目 12.3.1 创建Pygame窗口及响应用户输入 import sysimport pygameclass…

一个类实现Mybatis的SQL热更新

引言 平时用SpringBootMybatis开发项目&#xff0c;如果项目比较大启动时间很长的话&#xff0c;每次修改Mybatis在Xml中的SQL就需要重启一次。假设项目重启一次需要5分钟&#xff0c;那修改10次SQL就过去了一个小时&#xff0c;成本有点太高了。关键是每次修改完代码之后再重…

【webrtc】MessageHandler 2: 基于线程的消息处理:以PeerConnectionClient为例

PeerConnectionClient 前一篇 nullaudiopoller 并么有场景线程,而是就是在当前线程直接执行的, PeerConnectionClient 作为一个独立的客户端,默认的是主线程。 PeerConnectionClient 同时维护客户端的信令状态,并且通过OnMessage实现MessageHandler 消息处理。 目前只处理一…

CCF-CSP真题题解:201403-1 相反数

201403-1 相反数 #include <iostream> #include <cstring> #include <algorithm> using namespace std;const int MAXN 510;int n, a[MAXN]; int cnt[MAXN];int main() {scanf("%d", &n);for (int i 0; i < n; i) { scanf("%d"…

【分治算法】【Python实现】最接近点对

文章目录 [toc]问题描述一维最接近点对算法Python实现 二维最接近点对算法分治算法时间复杂性Python实现 个人主页&#xff1a;丷从心 系列专栏&#xff1a;分治算法 学习指南&#xff1a;Python学习指南 问题描述 给定平面上 n n n个点&#xff0c;找其中的一对点&#xff…

Python 深度学习(二)

原文&#xff1a;zh.annas-archive.org/md5/98cfb0b9095f1cf64732abfaa40d7b3a 译者&#xff1a;飞龙 协议&#xff1a;CC BY-NC-SA 4.0 第五章&#xff1a;图像识别 视觉可以说是人类最重要的感官之一。我们依赖视觉来识别食物&#xff0c;逃离危险&#xff0c;认出朋友和家人…

【C++题解】1044. 找出最经济型的包装箱型号

问题&#xff1a;1044. 找出最经济型的包装箱型号 类型&#xff1a;多分支结构 题目描述&#xff1a; 已知有 A&#xff0c;B&#xff0c;C&#xff0c;D&#xff0c;E 五种包装箱&#xff0c;为了不浪费材料&#xff0c;小于 10 公斤的用 A 型&#xff0c;大于等于 10 公斤小…

浅论汽车研发项目数字化管理之道

随着汽车行业竞争不断加剧&#xff0c;汽车厂商能否快速、高质地推出贴合市场需求的新车型已经成为车企竞争的重要手段&#xff0c;而汽车研发具备流程复杂、专业领域多、协作难度大、质量要求高等特点&#xff0c;企业如果缺少科学健全的项目管理体系&#xff0c;将会在汽车研…

应用监控(Prometheus + Grafana)

可用于应用监控的系统有很多&#xff0c;有的需要埋点(切面)、有的需要配置Agent(字节码增强)。现在使用另外一个监控系统 —— Grafana。 Grafana 监控面板 这套监控主要用到了 SpringBoot Actuator Prometheus Grafana 三个模块组合的起来使用的监控。非常轻量好扩展使用。…

光伏管理系统:降本增效解决方案。

现在是光伏发展的重要节点&#xff0c;如何在众多同行中脱颖而出并且有效的达到降低成本、提高效率也是很多企业都在考虑的问题&#xff0c;鹧鸪云的团队研发出了光伏管理系统&#xff0c;通过更高效、更智能、更全面的管理方式来帮助企业实现降本增效的转型&#xff0c;小编带…

记录AE学习查漏补缺(持续补充中。。。)

记录AE学习查漏补缺 常用win下截图WinShifts导入AI/PS工程文件将图层上移一个位置或者下移一个位置展示/关闭图层标线/标度放大面板适应屏幕大小 CtrlAltF 关键帧熟记关键参数移动锚点位置加选一个关键参数快速回到上下一帧隐藏/显示图层关键帧拉长缩短关键帧按着鼠标左键不松手…

【面试经典 150 | 回溯】单词搜索

文章目录 写在前面Tag题目来源解题思路方法一&#xff1a;回溯 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法&#xff0c;两到三天更新一篇文章&#xff0c;欢迎催更…… 专栏内容以分析题目为主&#xff0c;并附带一些对于本题涉及到的数据结构等内容进行回顾…

Golang | Leetcode Golang题解之第59题螺旋矩阵II

题目&#xff1a; 题解&#xff1a; func generateMatrix(n int) [][]int {matrix : make([][]int, n)for i : range matrix {matrix[i] make([]int, n)}num : 1left, right, top, bottom : 0, n-1, 0, n-1for left < right && top < bottom {for column : lef…

Flask表单详解

Flask表单详解 概述跨站请求伪造保护表单类把表单渲染成HTML在视图函数中处理表单重定向和用户会话Flash消息 概述 尽管 Flask 的请求对象提供的信息足够用于处理 Web 表单&#xff0c;但有些任务很单调&#xff0c;而且要重复操作。比如&#xff0c;生成表单的 HTML 代码和验…

【stomp 实战】spring websocket用户消息发送源码分析

这一节&#xff0c;我们学习用户消息是如何发送的。 消息的分类 spring websocket将消息分为两种&#xff0c;一种是给指定的用户发送&#xff08;用户消息&#xff09;&#xff0c;一种是广播消息&#xff0c;即给所有用户发送消息。那怎么区分这两种消息呢?那就是用前缀了…