【深度学习数学基础之线性代数】研究使用链式法则进行反向传播的求导算法

news2024/11/16 18:41:04

链式法则

简单的说链式法则就是原本y对x求偏导,但是由于过程较为复杂,我们需要将函数进行拆分,通过链式进行分别求导,这样会使整个计算更为简单。

假设f = k ( a + b c ) f = k(a + bc)f=k(a+bc)

通俗来说,链式法则表明,知道z相对于y的瞬时变化率和y相对于x的瞬时变化率,就可以计算z相对于x的瞬时变化率作为这两个变化率的乘积。其实就是求复合函数导数的过程。
在这里插入图片描述
用链式法则(将这些梯度表达式链接起来相乘。)分别对变量a、b、c进行求导:
在这里插入图片描述

前向传播

前向传播(forward propagation或forward pass) 指的是:按顺序(从输入层到输出层)计算和存储神经网络中每层的结果。

对于中间变量:
在这里插入图片描述
W为参数权重,b为函数偏置,函数结果经过激活函数C(常见的激活函数有Sigmoid、tanh、ReLU)
在这里插入图片描述
假设损失函数为l,真实值为h,我们可以计算单个数据样本的损失项,
在这里插入图片描述
在不考虑优化函数,单个神经元从输入到输出结束,后面需要对误差进行反向传播,更新权值,重新计算输出。

反向传播

反向传播(backward propagation或backpropagation)指的是计算神经网络参数梯度的方法。 简言之,该方法根据微积分中的链式规则,按相反的顺序从输出层到输入层遍历网络。 该算法存储了计算某些参数梯度时所需的任何中间变量(偏导数)。

(1)梯度下降
在说反向传播算法前,先简单了解一些梯度下降,对于损失函数(这里假设损失是MSE,即均方误差损失)
在这里插入图片描述
在这里插入图片描述
除此以外还有一些再次基础上优化的其他梯度下降方法: 小批量样本梯度下降(Mini Batch GD)、随机梯度下降(Stochastic GD)等。
(2)反向传播
反向传播计算损失函数相对于单个输入-输出示例的网络权重的梯度,为了说明这个过程,使用了具有2个输入和1个输出的2层神经网络,如下图所示:
在这里插入图片描述
不考虑优化算法,单个神经结构如下图所示,第一个单元将权重系数和输入信号的乘积相加。第二单元为神经元激活函数(反向传播需要在网络设计时激活函数可微的),如下图所示:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
根据链式法则求出所有的更新后的权重W WW梯度,偏值使用同样的方法。通过反向传播,计算损失函数与当前神经元权重的最陡下降方向。然后,可以沿最陡下降方向修改权重,并以有效的方式降低损失。
(3)反向传播代码

def optimize(w, b, X, Y, num_iterations, learning_rate):
    costs = []

    for i in range(num_iterations):

        # 梯度更新计算函数
        grads, cost = propagate(w, b, X, Y)

        # 取出两个部分参数的梯度
        dw = grads['dw']
        db = grads['db']

        # 按照梯度下降公式去计算
        w = w - learning_rate * dw
        b = b - learning_rate * db

        if i % 100 == 0:
            costs.append(cost)
        if i % 100 == 0:
            print("损失结果 %i: %f" % (i, cost))
            print(b)
    params = {"w": w,
              "b": b}
              
    grads = {"dw": dw,
             "db": db}
    return params, grads, costs

def propagate(w, b, X, Y):
    m = X.shape[1]

    # 前向传播
    A = basic_sigmoid(np.dot(w.T, X) + b)
    cost = -1 / m * np.sum(Y * np.log(A) + (1 - Y) * np.log(1 - A))

    # 反向传播
    dz = A - Y
    dw = 1 / m * np.dot(X, dz.T)
    db = 1 / m * np.sum(dz)

    grads = {"dw": dw,
             "db": db}

    return grads, cost

参考

https://blog.csdn.net/Peyzhang/article/details/125479563

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/173498.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

宝贝代码部署笔记

记录前后端分离项目部署到云服务器 文章目录1. 启动数据库2. 创建数据库3. 阿里云开放后端项目端口4. 运行SQL文件5. 打包前端文件6. 服务端创建文件夹7. 打包后端jar包8. 安装配置Nginx服务器9. 启动Tomcat10. 项目文件上传部署1. 启动数据库 使用命令cd /opt/mysql/support-…

Generative Adversarial Network (GANs) 对抗神经网络 基础 第一部分

Generative Adversarial Network (GANs) 对抗神经网络 基础 第一部分 定义 Definition Discriminative model: Classifier 判别器Generative model: (random set of value , class) as input -> Create new features X 生成器 对抗神经网络模型主要就是通过判…

android的system域解耦

google很早在为此做准备,要求所有设备能够刷GSI(通用系统镜像),并跑过XTS测试。动态分区解耦方案如上图。一、分区描述单一系统映像 (SSI)。包含system和system_ext图像的新概念图像。当这些分区对于一组目标设备是通用的时&#…

二叉树(一)

先简单了解一下树的概念,从而进一步了解二叉树,最后进行代码测试。树概念及结构(了解)在认识而二叉树之前我们首先了解一下树的概念。树是一种非线性的数据结构,它是由n(n>0)个有限结点组成一个具有层次关系的集合。…

图扑喜获第十一届中国创新创业大赛全国赛优秀奖!

在近期结束的第十一届中国创新创业大赛全国赛(新一代信息技术)比赛中,图扑软件喜获成长组优秀奖。这是继“创客中国”创新创业大赛优胜奖荣誉后,再一次对图扑软件在新一代信息技术领域专业的认可!大赛围绕新一代信息技…

电机行业EDI案例分析

项目背景 J公司需要与国内某知名电机品牌Z公司建立EDI对接,J公司选择通过知行EDI系统与Z公司建立AS2连接,通过AS2接收Z公司发送过来的ORDERS(采购订单)和ORDCHG(采购订单变更),并根据发接收到的…

Linux常见命令 15 - 权限管理命令 chmod

1. chmod 语法 chmod为修改文件/文件夹权限,有以下两种操作,其中-R表示递归修改 chmod {ugoa} {-} {rwx} [文件或目录] -Rchmod [mode421] [文件或目录] -R 2. chmod {ugoa} {-} {rwx} [文件或目录] -R u:文件或目录的所有者,g…

C++设计模式(5)——观察者模式

观察者模式 亦称: 事件订阅者、监听者、Event-Subscriber、Listener、Observer 意图 观察者模式是一种行为设计模式, 允许你定义一种订阅机制, 可在对象事件发生时通知多个 “观察” 该对象的其他对象。 问题 假如你有两种类型的对象&a…

概论第6章_正态总体的抽样分布_卡方分布_F分布_t分布

一 卡方分布 定义 设X1,X2,...,XnX_1, X_2,..., X_nX1​,X2​,...,Xn​ 独立同分布于标准正态分布N(0, 1), 则χ2X12...Xn2\chi^2X_1^2 ... X_n^2χ2X12​...Xn2​的分布称为 自由度为 n 的χ2\chi^2χ2分布, 记为χ2\chi^2χ2 ~ χ2(n)\chi^2(n)χ2(n) χ2\chi…

Python爬虫序章---爬取csdn作者排行榜

上篇文章介绍了requests库获取数据的基本方法,本篇文章利用自动化测试工具selenium进行数据抓取,也会对代码部分进行详细解释,以便小伙伴们能够更加理解和上手。 一.selenium技术介绍 Selenium是最广泛使用的开源 Web UI(用户界面…

windows11远程连接Ubuntu桌面

如何通过Windows 11远程连接Ubuntu桌面 在日常开发过程中,很多时候是这样一种情形:一台装了Ubuntu系统的计算机作为远程服务器,开发人员则使用带Windows系统的计算机去连服务器进行开发。 连接服务器的方式有很多种,最简单的就是…

图扑软件荣获第十一届中国创新创业大赛全国赛优秀奖!

在近期结束的第十一届中国创新创业大赛全国赛(新一代信息技术)比赛中,图扑软件喜获成长组优秀奖。这是继“创客中国”创新创业大赛优胜奖荣誉后,再一次对图扑软件在新一代信息技术领域专业的认可!大赛围绕新一代信息技…

DW动手学数据分析Task4:数据可视化

目录1 了解matplotlib2 可视化图案3 matplotlib用法4 了解Seaborn1 了解matplotlib Matplotlib: 是 Python 的绘图库, 它可与 NumPy 一起使用,提供了一种有效的 MatLab 开源替代方案。 2 可视化图案 基本可视化团及场景使用 柱状图 场景&am…

如何实现机械臂的正解计算?

1. 机械臂运动学介绍 机械臂运动学 机器人运动学就是根据末端执行器与所选参考坐标系之间的几何关系,确定末端执行器的空间位置和姿态与各关节变量之间的数学关系。包括正运动学(Forward Kinematics)和逆运动学(Inverse Kinematic…

在线支付系列【3】支付安全之对称和非对称加密

有道无术,术尚可求,有术无道,止于术。 文章目录前言信息安全加密机制核心概念对称加密非对称加密JCE对称加解密1. 创建密钥2. 加密3. 解密非对称加解密1. 创建密钥2. 公钥加密3. 私钥解密前言 支付和金钱挂钩,支付安全显得尤为重…

域名被封的解决方案

如果您的域名被封,可能是域名下网站存在非法信息或敏感内容,导致被GFW屏蔽。 封禁原因及解决方案如下: 1. 域名解析的IP纳入黑名单 这种情况只需更换IP即可恢复正常,但换IP也只能解除一时的燃眉之急,一旦又被GFW发现很…

MySQL进阶——视图(view)

1. 视图 1.1 视图介绍 视图(View)是一种虚拟存在的表。视图中的数据并不在数据库中实际存在,行和列数据来自定义视图的查询中使用的表,并且是在使用视图时动态生成的。 通俗的讲,视图只保存了查询的SQL逻辑&#xf…

MySQL详细教程,2023年硬核学习路线

文章目录前言1. 数据库的相关概念1.1 数据1.2 数据库1.3 数据库管理系统1.4 数据库系统1.5 SQL2. MySQL数据库2.1 MySQL安装2.2 MySQL配置2.2.1 添加环境变量2.2.2 新建配置文件2.2.3 初始化MySQL2.2.4 注册MySQL服务2.2.5 启动MySQL服务2.3 MySQL登录和退出2.4 MySQL卸载2.5 M…

【Python】如何为Matplotlib图像添加标签?

一、添加文本标签 plt.text() 用于在绘图过程中,在图像上指定坐标的位置添加文本。需要用到的是plt.text()方法。 其主要的参数有三个: plt.text(x, y, s)其中x、y表示传入点的x和y轴坐标。s表示字符串。 需要注意的是,这里的坐标&#x…

基于Springboot+Mybatis+mysql+vue电影院在线售票系统

基于SpringbootMybatismysqlvue电影院在线售票系统一、系统介绍二、所用技术三、功能展示1.主页(普通用户)2.影院管理员相关功能(影院管理员)3.系统管理权限(管理员)四、获取源码一、系统介绍 电影院网上售票系统拥有三种角色,用户、工作人员…