【深度学习 | 反向传播】释放反向传播的力量: 让训练神经网络变得简单

news2024/10/7 18:30:40

在这里插入图片描述

🤵‍♂️ 个人主页: @AI_magician
📡主页地址: 作者简介:CSDN内容合伙人,全栈领域优质创作者。
👨‍💻景愿:旨在于能和更多的热爱计算机的伙伴一起成长!!🐱‍🏍
🙋‍♂️声明:本人目前大学就读于大二,研究兴趣方向人工智能&硬件(虽然硬件还没开始玩,但一直很感兴趣!希望大佬带带)

在这里插入图片描述

该文章收录专栏
[✨— 《深入解析机器学习:从原理到应用的全面指南》 —✨]

反向传播算法

反向传播算法是一种用于训练神经网络的常用优化算法。它通过计算损失函数对每个参数的梯度,然后根据这些梯度更新参数值,以使得神经网络能够逐步调整和改进其预测结果。

下面是一个详细解释反向传播算法的步骤:

  1. 前向传播:从输入开始,通过神经网络进行前向传播。每个节点都会将输入加权求和,并应用非线性激活函数(如ReLU、Sigmoid等),生成输出。

  2. 定义损失函数:选择合适的损失函数来衡量模型预测与实际标签之间的差异。例如,在分类问题中可以使用交叉熵损失或均方误差损失。

  3. 反向传播:从输出层开始,计算每个节点对于最终预测结果的贡献程度,并将该信息沿着网络进行反向传播(在最后一层输出开始,以计算损失函数)。

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bzEcrRdT-1691847459841)(classical algorithm.assets/image-20230812141415318.png)]

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-G1Af45cB-1691847459842)(classical algorithm.assets/image-20230812142623880.png)]

    这里我们以三个全连接神经元为例。 整体导数通过链式法则链接,公式如下:
    ∂ C 0 ∂ w ( L ) = ∂ z ( L ) ∂ w ( L ) ∂ a ( L ) ∂ z ( L ) ∂ C 0 ∂ a ( L ) \frac{\partial C_{0}}{\partial w^{(L)}}=\frac{\partial z^{(L)}}{\partial w^{(L)}} \frac{\partial a^{(L)}}{\partial z^{(L)}} \frac{\partial C_{0}}{\partial a^{(L)}} w(L)C0=w(L)z(L)z(L)a(L)a(L)C0
    这是损失函数与最后一个神经元 W W W参数的偏导数(偏置 b i a s bias bias同样样),其中我们可以看到“一同激活的神经元联系在一起”,上一个神经元的激活值就是下一个神经元的导数
    C 0 = ( a ( L ) − y ) 2 z ( L ) = w ( L ) a ( L − 1 ) + b ( L ) a ( L ) = σ ( z ( L ) ) ∂ C 0 ∂ a ( L ) = 2 ( a ( L ) − y ) ∂ a ( L ) ∂ z ( L ) = σ ′ ( z ( L ) ) ∂ z ( L ) ∂ w ( L ) = a ( L − 1 ) ∂ C 0 ∂ w ( L ) = 2 ( a ( L ) − y ) ∗ σ ′ ( z ( L ) ) ∗ a ( L − 1 ) \begin{array}{rlrl} C_{0} & =\left(a^{(L)}-y\right)^{2} \\ z^{(L)} & =w^{(L)} a^{(L-1)}+b^{(L)} \\ a^{(L)} & =\sigma\left(z^{(L)}\right) \\ \frac{\partial C_0}{\partial a^{(L)}} & =2\left(a^{(L)}-y\right) & \\ \frac{\partial a^{(L)}}{\partial z^{(L)}} & =\sigma^{\prime}\left(z^{(L)}\right) \\ \frac{\partial z^{(L)}}{\partial w^{(L)}} & =a^{(L-1)} & \\ \frac{\partial C_{0}}{\partial w^{(L)}} & = 2\left(a^{(L)}-y\right) * \sigma^{\prime}\left(z^{(L)}\right) * a^{(L-1)} \end{array} C0z(L)a(L)a(L)C0z(L)a(L)w(L)z(L)w(L)C0=(a(L)y)2=w(L)a(L1)+b(L)=σ(z(L))=2(a(L)y)=σ(z(L))=a(L1)=2(a(L)y)σ(z(L))a(L1)
    再反向一个神经元,公式如下:
    ∂ a ( L − 1 ) ∂ z ( L − 1 ) = σ ′ ( z ( L ) ) ∂ z ( L − 1 ) ∂ w ( L − 1 ) = a ( L − 2 ) ∂ C 0 ∂ w ( L ) = σ ′ ( z ( L ) ) ∗ a ( L − 2 ) \begin{array}{rlrl} \frac{\partial a^{(L-1)}}{\partial z^{(L-1)}} & =\sigma^{\prime}\left(z^{(L)}\right) \\ \frac{\partial z^{(L-1)}}{\partial w^{(L-1)}} & =a^{(L-2)} & \\ \frac{\partial C_{0}}{\partial w^{(L)}} & = \sigma^{\prime}\left(z^{(L)}\right) * a^{(L-2)} \end{array} z(L1)a(L1)w(L1)z(L1)w(L)C0=σ(z(L))=a(L2)=σ(z(L))a(L2)
    此时该神经元的梯度就是上一个神经元的激活值与该神经元输入与激活输出的局部梯度相乘,一直反向传播到最开始的神经元就可以得到最早期的神经元输出。

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-b87ffXCM-1691847459843)(classical algorithm.assets/image-20230812151645267.png)]

    这是三个单个神经元的过程,我们把他推广到多个神经元全连接:(其实只不过多了很多下标,整体流程是一致的,参数是矩阵形式,损失函数同时由多个神经元共同影响累加,整体以层为单位累加求和)

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-JXVf4tKZ-1691847459843)(classical algorithm.assets/image-20230812160056339.png)]

    我们精炼成两个关键步骤:

    • 计算局部梯度:针对每个节点,计算其相对于加权输入和输出之间关系(即激活函数)的偏导数(参数)
    • 链式规则:利用链式规则(也称为复合函数求导法则),将局部梯度(激活函数梯度)乘以上游节点对该节点的贡献(加权输入梯度),以计算上游节点的梯度。(参数 w w w与激活输出的梯度)
  4. 计算参数梯度:根据反向传播过程中计算得到的梯度信息,对每个参数进行偏导数计算。这可以通过将网络中各层的局部梯度与输入值(或前一层输出)相乘来实现。

  5. 更新参数:使用优化器(如随机梯度下降)根据参数的负梯度方向和学习率大小来更新模型中的权重和偏置项。

通过迭代执行以上步骤,不断调整神经网络的参数,使其能够更好地拟合训练数据,并在测试数据上表现出良好泛化能力。到这里,你就弄懂神经网络重要的部分,反向传播,以下图片有两种数学公式形式表示损失函数的导数,一个三个导函数的累积,一个是MSE的求导

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OYy5oXSh-1691847459844)(classical algorithm.assets/image-20230812161006209.png)]

如果你希望进一步了解反向传播算法及其相关细节,推荐以下资源:

  1. 视频教程: Backpropagation in Neural Networks (https://www.youtube.com/watch?v=Ilg3gGewQ5U) 3Blue1Brown !!
  2. 博客文章: A Gentle Introduction to Backpropagation (LSTM) (https://machinelearningmastery.com/gentle-introduction-backpropagation/)
  3. 课程笔记: CS231n Convolutional Neural Networks for Visual Recognition (http://cs231n.github.io/optimization-2/)

我们可以思考以下,如果在LSTM中等特殊改进神经单元,反向传播又是如何运行的呢?

答案是一样的: 我们的输出是 细胞状态的正切激活 * 输入数据和隐藏状态拼接的激活函数, 由此根据每一个时间步链式求导每一个权重矩阵,在每一个矩阵中再次通过累加求和导数,以此类推得到梯度,通过偏导求和得到整体矩阵,参数更新)

下一章我们将会讲解梯度消失和爆炸,通过了解反向传播,我们可以更加清楚其原理
在这里插入图片描述

						  🤞到这里,如果还有什么疑问🤞
					🎩欢迎私信博主问题哦,博主会尽自己能力为你解答疑惑的!🎩
					 	 🥳如果对你有帮助,你的赞是对博主最大的支持!!🥳

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/870659.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

快速上手PyCharm指南

PyCharm简介 PyCharm是一种Python IDE(Integrated Development Environment,集成开发环境),带有一整套可以帮助用户在使用Python语言开发时提高其效率的工具,比如调试、语法高亮、项目管理、代码跳转、智能提示、自动…

为什么String要设计成不可变的

文章目录 一、前言二、缓存hashcode缓存 三、性能四、安全性五、线程安全 一、前言 为什么要将String设计为不可变的呢?这个问题一直困扰着许多人,甚至有人直接向Java的创始人James Gosling提问过。在一次采访中,当被问及何时应该使用不可变…

Axure RP9小白安装教程

第一步: 打开:Axure中文学习网 第二步: 鼠标移动软件下载,点击Axure RP 9下载既可 第三步: 注意:Axure RP 9 MAC正式版为苹果版本,Axure RP 9 WIN正式版为Windows版本 中文汉化包&#xff…

春秋云镜 CVE-2022-0410

春秋云镜 CVE-2022-0410 WordPress plugin The WP Visitor Statistics SQLI 靶标介绍 WordPress plugin The WP Visitor Statistics (Real Time Traffic) 5.6 之前存在SQL注入漏洞,该漏洞源于 refUrlDetails AJAX 不会清理和转义 id 参数。 登陆账户:u…

windows环境下打印机无法打印的解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

js 使用 AES对字节数组进行加密

AES 分组密码算法和所有常见操作模式(CBC、CFB、CTR、ECB 和 OFB。 js实现 aes 对字符串进行加密解密,网上有很多实现demo,但是对字节数组进行加密,找了很久都没找到合适的代码。我这次开发使用的场景是微信小程序直接解析ble协议…

力求超越ChatGPT,谷歌加入人工智能新项目

受到新贵OpenAI的威胁,谷歌承诺快速跟踪新的人工智能项目。 ChatGPT吓坏了谷歌。周五,纽约时报据报告的谷歌创始人拉里佩奇和谢尔盖布林与公司高管就OpenAI召开了几次紧急会议新聊天机器人谷歌认为这可能威胁到其价值1490亿美元的搜索业务。 由OpenAI创…

String为什么设计成不可变的?

为什么要把 String 设计成不可变的呢?有什么好处呢? 这个问题,困扰过很多人,甚至有人直接问过 Java 的创始人 James Gosling。 在一次采访中 James Gosling 被问到什么时候应该使用不可变变量,他给出的回答是&#xff…

服务器安全维护注意事项有哪些?

服务器的安全关系着公司整个网络以及所有数据的安全,我们该如何做好服务器后续的安全维护呢?河南亿恩科技股份有限公司,专注服务器托管23年,不仅是国内专业的互联网基础应用服务提供商之一,还是国家工信部认定的综合电信服务运营…

C语言 野指针

目录 一、野指针 (一)概念 (二)野指针的分类 (三)指针未初始化 (四) 指针越界访问 (五)指针指向的空间释放 二、避免野指针 (一&#xff0…

MATLAB图论合集(一)基本操作基础

本帖总结一些经典的图论问题,通过MATLAB如何计算答案。近期在复习考研,以此来巩固一下相关知识——虽然考研肯定不能用MATLAB代码哈哈,不过在实际应用中解决问题还是很不错的,比C易上手得多~ 图论中的图(Graph&#xf…

【C/C++】用return返回一个函数

2023年8月13日&#xff0c;周日早上 我的第一篇使用了动态图的博客 #include<iostream> #include<windows.h>int loop(){int i0;while(1){Sleep(1000);std::cout<<i<<std::endl;}return 1; }int main(){std::cout<<"程序开始"<<…

【Bert101】最先进的 NLP 模型解释【01/4】

0 什么是伯特&#xff1f; BERT是来自【Bidirectional Encoder Representations from Transformers】变压器的双向编码器表示的缩写&#xff0c;是用于自然语言处理的机器学习&#xff08;ML&#xff09;模型。它由Google AI Language的研究人员于2018年开发&#xff0c;可作为…

【Bert101】变压器模型背后的复杂数学【02/4】

一、说明 众所周知&#xff0c;变压器架构是自然语言处理&#xff08;NLP&#xff09;领域的突破。它克服了 seq-to-seq 模型&#xff08;如 RNN 等&#xff09;无法捕获文本中的长期依赖性的局限性。变压器架构被证明是革命性架构&#xff08;如 BERT、GPT 和 T5 及其变体&…

Java经典OJ题 回文

OJ题 回文 1.题目2.判断范围是否合理2.1 普通代码2.2 优化代码 3.判断回文的关系代码4.总代码 1.题目 如果在将所有大写字符转换为小写字符、并移除所有非字母数字字符之后&#xff0c;短语正着读和反着读都一样。则可以认为该短语是一个 回文串 。 字母和数字都属于字母数字字…

Multipass虚拟机设置局域网固定IP同时实现快速openshell的链接

本文只介绍在windows下实现的过程&#xff0c;Ubuntu采用22.04 安装multipass后&#xff0c;在卓面右下角Open shell 就可以链接默认实例Primary&#xff0c;当然如果你有多个虚拟机&#xff0c;可以针对不同内容单独建立终端的链接&#xff0c;而本文仅仅用Primary来说明。 …

孤儿进程与僵尸进程

进程退出 关于进程退出有两个函数 exit和 _exit&#xff1a;其主要差别是在于是否直接退出。 其流程主要区别如下&#xff1a; 孤儿进程&#xff08;不存在危害&#xff09; 父进程运行结束&#xff0c;但子进程还在运行&#xff08;未运行结束&#xff09;&#xff0c;这…

VS中.cu文件属性中项目类型没有cuda

问题 VS中.cu文件属性中项目类型没有cuda 解决办法 右键项目“自定义” ![请添加图片描述](https://img-blog.csdnimg.cn/9717093332604b5982e67b15108c9ec8.png 再回到cu文件右键属性就会出现cuda选项了 请添加图片描述

SQL 基础查询

msyql 不区分大小写 DDL 数据定义语言 查询 show databases create database db01 创建数据库 create database if not exists db01 创建数据库 删除数据库 drop database if exists db01 使用数据库 use 数据库名 CREATE TABLE tb_user(id int PRIMARY KEY COMMENT i…