神经网络基础--什么是正向传播??什么是方向传播??

news2024/11/28 4:39:35

前言

  • 本专栏更新神经网络的一些基础知识;
  • 这个是本人初学神经网络做的笔记,仅仅堆正向传播、方向传播就行了了一个讲解,更加系统的讲解,本人后面会更新《李沐动手学习深度学习》,会更有详细讲解;
  • 案例代码基于pytorch;
  • 欢迎收藏 + 关注, 本人将会持续更新。

文章目录

  • 正向传播与反向传播
    • 梯度下降法
      • 简介
      • 不同梯度下降法区别
    • 前向传播
    • 反向传播算法
      • 简介
      • 案例介绍原理
      • 代码展示

正向传播与反向传播

梯度下降法

简介

表达式

w i j n e w = w i j o l d − η ∂ E ∂ w i j w_{ij}^{new}= w_{ij}^{old} - \eta \frac{\partial E}{\partial w_{ij}} wijnew=wijoldηwijE

其中 n \text{n} n 是学习率,控制梯度收敛的快慢。

深度学习几个基础概念

  1. Eopch:使用数据对模型进行完整训练,一次
  2. Batch:使用训练集中小部分样本对模型权重进行方向传播更新
  3. Iteration:使用一个 Batch 数据对模型进行一次参数更新

不同梯度下降法区别

梯度下降方式Training Set SizeBatch SizeNumber of Batches
BGD(批量梯度下降)NN1
SGD(随机梯度下降)N1N
Mini-Batch(小批量梯度下降)NBN / B + 1
  • 批量梯度下降法是最原始的形式,它是指在每一次迭代时使用所有样本来进行梯度的更新;
  • 随机梯度下降法不同于批量梯度下降,随机梯度下降是在每次迭代时使用一个样本来对参数进行更新;
  • 小批量梯度下降相当于是前两个的总和。
  • 具体优缺点:后面更新**李沐老师《动手学习深度学习》**会有更详细解释。

前向传播

前向传播就是输入x,在神经网络中一直向前计算,一直到输出层为止,图像如下:

在这里插入图片描述

在网络的训练过程中经过前向传播后得到的最终结果跟训练样本的真实值总是存在一定误差,这个误差便是损失函数。想要减小这个误差,就用损失函数Loss,从后往前,依次求各个参数的偏导,这就是反向传播.

反向传播算法

简介

BP 算法也叫做误差反向传播算法,它用于求解模型的参数梯度,从而使用梯度下降法来更新网络参数。它的基本工作流程如下:

  1. 通过正向传播得到误差,正向传播指的是数据从输入–> 隐藏层–>输出层,经过层层计算得到预测值,并利用损失函数得到预测值和真实值之前的误差。
  2. 通过反向传播把误差传递给模型的参数,从而调整神经网络参数,缩小预测值和真实值之间的误差。
  3. 反向传播算法是利用链式法则进行梯度求解,然后进行参数更新

案例介绍原理

  • 网络结构:
    • 输入层:两个神经元
    • 隐藏层:一共两层,每一层两个神经元
    • 输出层:输出两个值
  • 输入:
    • i1:0.05,i2:0.10
  • 目标值:
    • 0.01,0.99
  • 初始化权重:
    • w1: 0.15
    • w2: 0.20
    • w3: 0.25
    • w4:0.30
    • w5:0.40
    • w6:0.45
    • w7:0.50
    • w8:0.55

在这里插入图片描述

  1. 由下向上看,最下层绿色的两个圆代表两个输入值
  2. 右侧的8个数字,最下面4个表示 w1、w2、w3、w4 的参数初始值,最上面的4个数字表示 w5、w6、w7、w8 的参数初始值
  3. b1 值为 0.35,b2 值为 0.60
  4. 预测结果分别为: 0.7514、0.7729

接下来就是梯度更新

方向传播中梯度跟新流程

w5、w7为例:

在这里插入图片描述

计算出偏导后,运用梯度下降法进行更新,这里学习率为0.5:

在这里插入图片描述

接下来跟新w1:

在这里插入图片描述

梯度更新:

在这里插入图片描述

代码展示

该代码就是将上面的神经网络代码实现,牢记对神经网络理解很有用

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        # 定义神经网络结构
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 2)
        
        # 神经网络参数初始化,设置权重和偏置
        self.linear1.weight.data = torch.tensor([[0.15, 0.20], [0.25, 0.30]])  # 权重
        self.linear2.weight.data = torch.tensor([[0.40, 0.45], [0.50, 0.55]])
        self.linear1.bias.data = torch.tensor([0.35, 0.35])   # 偏置
        self.linear2.bias.data = torch.tensor([0.60, 0.60])
        
    def forward(self, x):
        
        x = self.linear1(x)     # 经过第一次线性变换
        x = torch.sigmoid(x)    # 经过激活函数,进行非线性变换
        x = self.linear2(x)     # 经过第二层线性变换
        x = torch.sigmoid(x)    # 经过激活函数,进行非线性变换
        
        return x
        
if __name__ == '__main__':
    
    # 输入变量和目标变量
    inputs = torch.tensor([0.05, 0.10])
    target = torch.tensor([0.01, 0.99])
    
    # 创建神经网络
    net = Net()
    # 训练
    outputs = net(inputs)
    
    # 计算误差,梯度下降法,*********  公式实现  ***********
    loss = torch.sum((target - outputs) ** 2) / 2
    
    # 计算梯度下降
    optimizer = optim.SGD(net.parameters(), lr=0.5)
    
    # 清除梯度
    optimizer.zero_grad()
    
    # 反向传播
    loss.backward()
    
    # 打印 w5、w7、w1 的梯度值
    print(net.linear1.weight.grad.data)
    print(net.linear2.weight.grad.data)
    
    # 更新权重,梯度下降更新
    optimizer.step()
    
    # 打印网络参数
    print(net.state_dict())

输出:

tensor([[0.0004, 0.0009],
        [0.0005, 0.0010]])
tensor([[ 0.0822,  0.0827],
        [-0.0226, -0.0227]])
OrderedDict([('linear1.weight', tensor([[0.1498, 0.1996],
        [0.2498, 0.2995]])), ('linear1.bias', tensor([0.3456, 0.3450])), ('linear2.weight', tensor([[0.3589, 0.4087],
        [0.5113, 0.5614]])), ('linear2.bias', tensor([0.5308, 0.6190]))])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2235676.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【大模型系列】Grounded-VideoLLM(2024.10)

Paper:https://arxiv.org/pdf/2410.03290Github:https://github.com/WHB139426/Grounded-Video-LLMHuggingface:https://huggingface.co/WHB139426/Grounded-Video-LLMAuthor:Haibo Wang et al. 加州大学,复旦 动机&a…

IDEA2024下安装kubernetes插件并配置进行使用

【1】安装插件 其实2024.2.3下默认已经安装了kubernetes插件,如果你发现自己IDEA中没有,在市场里面检索并下载即可。 【2】kubernetes配置 ① 前置工作 首先你要准备一个config文件和一个kubectl.exe 。 config文件类似如下: apiVersi…

onnx-web + yolov8n 在视频流里做推理

顺着我上一篇文章 使用onnxruntime-web 运行yolov8-nano推理 继续说,有朋友在问能不能接入 视频流动,实时去识别物品。 首先使用 getUserMedia 获取摄像头视频流 getUserMedia API 可以访问设备的摄像头和麦克风。你可以使用这个 API 获取视频流&#…

Python练习11

Python日常练习 题目: 编写一个石头剪刀布游戏,该程序要求完成如下功能: (1) 显示游戏规则,提醒用户输入一个1-3的整数或者直接回车。 用户输入回车时游戏结束。 用户输入不合法(包括输入的…

航展畅想:从F35机载软件研发来看汽车车载软件研发

两款经典战机的机载软件 F-22和F-35战斗机的研制分别始于1980年代和1990年代末,F-22项目在1981年启动,主要由洛克希德马丁(Lockheed Martin)和波音公司(Boeing)合作开发,以满足美军“先进战术战…

实践出真知:MVEL表达式empty的坑

目录标题 背景为什么呢?验证下empty的含义case1case2case3 结论具体解释: 背景 //是否白名单 if(goodInfo.?isWhite ! empty){showList.add(["label": "是否白名单","value":["text":(goodInfo.?isWhite tr…

RPC核心实现原理

目录 一、基本原理 二、详细步骤 三、额外考虑因素 RPC(Remote Procedure Call,远程过程调用)是一种计算机通信协议,也是一种用于实现分布式系统中不同节点之间进行通信和调用的技术。其实现原理主要可以分为以下几个步骤&…

Kaggle生物信息学挑战:酶稳定性预测大赛

背景介绍 酶的稳定性是影响其实际应用的关键因素之一。通过定点突变可以改善酶的稳定性,但实验筛选稳定性突变体的成本较高。预测突变对酶稳定性的影响,加速筛选稳定性更高的酶突变体。 概念解释 X 残基:假设 它用 红色表示 , Y 残基:假设…

【开发工具——依赖管理工具——Maven】

1. Maven介绍 Apache Maven 的本质是一个软件项目管理和理解工具。基于项目对象模型 (Project Object Model,POM) 的概念,Maven 可以从一条中心信息管理项目的构建、报告和文档。 对于开发者来说,Maven 的主要作用主要有 3 个: …

vue3+vite搭建脚手架项目本地运行electron桌面应用

1.搭建脚手架项目 搭建Vue3ViteTs脚手架-CSDN博客 2.创建完项目后,安装所需依赖包 npm i vite-plugin-electron electron26.1.0 3.根目录下创建electron/main.ts electron/main.ts /** electron/main.ts */import { app, BrowserWindow } from "electron&qu…

鸿蒙ArkTS中的获取网络数据

一、通过web组件加载网页 在C/S应用程序中,都有网络组件用于加载网页,鸿蒙ArkTS中也有类似的组件。   web组件,用于加载指定的网页,里面有很多的方法可以调用,虽然现在用得比较少,了解还是必须的。   演…

无人车之路径规划篇

无人车的路径规划是指在一定的环境模型基础上,给定无人车起始点和目标点后,按照性能指标规划出一条无碰撞、能安全到达目标点的有效路径。 一、路径规划的重要性 路径规划对于无人车的安全、高效运行至关重要。它不仅能够提高交通效率,减少交…

C语言心型代码解析

方法一 心型极坐标方程 爱心代码你真的理解吗 笛卡尔的心型公式&#xff1a; for (y 1.5; y > -1.5; y - 0.1) for (x -1.5; x < 1.5; x 0.05) 代码里面用了二个for循环&#xff0c;第一个代表y轴&#xff0c;第二个代表x轴 二个增加的单位不同&#xff0c;能使得…

11月7日(内网横向移动(二))

利用系统服务 SCShell SCShell是一款利用系统服务的无文件横向移动工具。与传统的创建远程服务的方法不同&#xff0c;SCShell利用提供的用户凭据&#xff0c;通过ChangeServiceConfigA API修改远程主机上的服务配置&#xff0c;将服务的二进制路径名修改为指定的程序或攻击载…

【YOLOv11[基础]】目标检测OD | 导出ONNX模型 | ONN模型推理以及检测结果可视化 | python

本文将导出YOLO11.pt模型对应的ONNX模型,并且使用ONNX模型推理以及结果的可视化。话不多说,先看看效果图吧!!! 目录 一 导出ONNX模型 二 推理及检测结果可视化 1 代码 2 效果图

力扣—不同路径(路径问题的动态规划)

文章目录 题目解析算法原理代码实现题目练习 题目解析 算法原理 状态表示 对于这种「路径类」的问题&#xff0c;我们的状态表示⼀般有两种形式&#xff1a; i. 从[i, j] 位置出发。 ii. 从起始位置出发&#xff0c;到[i, j] 位置。 这⾥选择第⼆种定义状态表⽰的⽅式&#xf…

传统RAG流程;密集检索器,稀疏检索器:中文的M3E

目录 传统RAG流程 相似性搜索中:神经网络的密集检索器,稀疏检索器 密集检索器 BGE系列模型 text-embedding-ada-002模型 M3E模型 稀疏检索器 示例一:基于TF-IDF的稀疏检索器 示例二:基于BM25的稀疏检索器 稀疏检索器的特点与优势 传统RAG流程 相似性搜索中:神经…

Javascript 获取设备信息 工具

JS获取设备信息(操作系统信息、地理位置、UUID、横竖屏状态、设备类型、网络状态、浏览器信息、生成浏览器指纹、日期、生肖、周几等) Get Device Info Online GitHub - skillnull/DeviceJs: JS获取设备信息(操作系统信息、地理位置、UUID、横竖屏状态、设备类型、网络状态、浏…

【数据仓库】

1、概述 数据仓库&#xff0c;英文名称为Data Warehouse&#xff0c;可简写为DW或DWH。 数据仓库是企业中用于集中存储和管理来自多个源的经过处理和组织的数据的系统。它为复杂的查询和分析提供了一个优化的环境&#xff0c;使得用户能够执行高级数据分析&#xff0c;以支持…

成都栩熙酷网络科技有限公司抖音小店探索

在数字经济的浪潮中&#xff0c;电商行业正以前所未有的速度蓬勃发展&#xff0c;而短视频平台的崛起更是为这一领域注入了新的活力。成都栩熙酷网络科技有限公司&#xff08;以下简称“栩熙酷”&#xff09;&#xff0c;作为这股浪潮中的佼佼者&#xff0c;凭借其敏锐的市场洞…