循环神经网络 - 参数学习之随时间反向传播算法

news2025/4/16 22:56:14

本文中,我们以同步的序列到序列模式为例来介绍循环神经网络的参数学习。

循环神经网络中存在一个递归调用的函数 𝑓(⋅),因此其计算参数梯度的方式和前馈神经网络不太相同。在循环神经网络中主要有两种计算梯度的方式:随时间反向传播(BPTT)算法和实时循环学习(RTRL)算法。

本文我们来学习随时间反向传播算法。

BPTT 算法将循环神经网络看作一个展开的多层前馈网络,其中“每一层”对 应循环网络中的“每个时刻”。这样,循环神经网络就可以按照前馈网络中的反向传播算法计算参数梯度。在“展开”的前馈网络中,所有层的参数是共享的,因此参数的真实梯度是所有“展开层”的参数梯度之和。

一、数学推导:

以随机梯度下降为例,给定一个训练样本 (𝒙, 𝒚),其中 𝒙1∶𝑇 =  (𝒙1, ⋯ , 𝒙𝑇 )为长度是𝑇的输入序列,𝑦1∶𝑇 =(𝑦1,⋯,𝑦𝑇)是长度为𝑇的标签序列。即在每个时 刻 𝑡,都有一个监督信息 𝑦𝑡 ,我们定义时刻 𝑡 的损失函数为:

其中 𝑔(𝒉𝑡) 为第 𝑡 时刻的输出,L 为可微分的损失函数,比如交叉熵。那么整个序列的损失函数为:

整个序列的损失函数 L 关于参数 𝑼 的梯度为

即每个时刻损失 L𝑡 对参数 𝑼 的偏导数之和。

基于参数 𝑼 和隐藏层在每个时刻 𝑘(1 ≤ 𝑘 ≤ 𝑡) 的净输入有关,通过数学推导(推导过程比较复杂,这里略过,大家着重掌握公式),可以得出:

计算偏导数

得到整个序列的损失函数 L 关于参数 𝑼 的梯度:

同理可得,L 关于权重 𝑾 和偏置 𝒃 的梯度为:

其中,类似前馈神经网络中的误差项为:

由上可以看出误差项,时刻k的𝛿𝑡,𝑘可以由时刻k+1的𝛿𝑡,𝑘+1得出,即所谓的反向传播。

下图给出了误差项随时间进行反向传播算法的示例:

二、进一步理解随时间反向传播算法BPTT

BPTT 的具体实现核心在于将 RNN 在时间维度上“展开”,从而把整个循环网络视作一个深层的前馈网络,然后利用反向传播算法计算每个时间步的梯度。以下是关键步骤:

(一)前向传播(Forward Pass)

在前向传播阶段,模型从初始隐藏状态开始,按时间顺序依次处理输入序列的每个时间步。在每个时间步 t 中,RNN 会计算出当前隐藏状态 ht:

同时,根据隐藏状态产生输出:

这些中间状态和输出都被存储下来,供之后的反向传播使用。

(二)损失计算

对于整个序列的输出,我们会计算一个总体损失 LL,它通常是所有时间步损失 LtL_t 的求和或平均:

例如,在一个回归或分类任务中,可能使用均方误差交叉熵作为每个时间步的损失。

(三)时间展开(Unrolling)

为了使反向传播适用于循环结构,我们把 RNN 展开成一个由 T 个层组成的前馈网络,每一层对应一个时间步。虽然这些层共享同一组参数,但在展开的过程中,各个时间步之间的依赖关系(主要是隐藏状态 ht)得以显现。

这样做的目的是为了使我们能够用标准的反向传播算法计算梯度,从而更新整个序列中共享的参数。下面通过一个简单的例子说明这一过程。

假设情景

假设我们有一个 RNN 模型,用来处理一个长度为 3 的输入序列 [x1,x2,x3](例如数值 1、2、3),初始隐藏状态 h0​ 设为零。模型的前向计算公式为:

将 RNN 展开成前馈网络

原始的 RNN 是通过循环实现的,即使用同一组参数不断将隐藏状态从前一步传递到下一步。为了直观地理解反向传播的过程,我们将其在时间轴上展开,即把每个时间步看作网络中的一层,这些层之间按照时间顺序相连。

对于我们的序列,有如下展开:

  1. 时间步 1

  2. 时间步 2

  3. 时间步 3

这整个过程就像一个前馈网络,共有 3 层(不包括初始状态),每层的输出 h_t​ 都依赖于前一层的输出 h_{t-1}​ 和当前输入 x_t 。注意,虽然在展开过程中每一层对应一个不同的时间步,但所有层共享同一组权重和偏置。

为什么这样展开?

这种展开方式将时序依赖“展开”到层级结构中,使得整个序列可以看成一个深层网络。这样有两个好处:

  1. 便于反向传播计算
    我们可以像对普通前馈神经网络那样,基于链式法则逐层计算梯度,并且由于参数共享,每层计算的梯度会累积在同一组权重上。

  2. 捕捉长距离依赖
    通过展开,我们能直观地理解误差如何从最后一层传回到第一层,反映长距离依赖问题,以及梯度消失或爆炸的问题。

总结

  • 展开过程:将 RNN 从时间步 1 到 T 展开,每个时间步视为一层前馈网络,所有层使用同一组参数。

  • 前向传播:依次计算每层隐藏状态和输出。

  • 反向传播:从最后一层开始反向传播,逐层累积梯度,更新共享参数。

这种展开不仅使得梯度计算过程清晰,而且方便我们理解如何利用 BPTT 解决时间依赖问题,确保模型能够捕捉序列中长期和短期的信息。

(四)反向传播(Backward Pass Through Time,BPTT)

从展开后的最后一个时间步 T 开始,依次向前计算梯度:

  • 局部梯度计算:在每个时间步,根据当前输出与目标之间的误差,首先计算当前时间步的输出层梯度,然后通过当前隐藏状态对损失的贡献,计算激活函数(如 tanh⁡)的导数。

  • 梯度传递与累积:由于隐藏状态 hth_t 不仅直接影响当前输出,还间接影响后续所有时间步的输出,因而需要将来自未来时间步传回的梯度(往往称为 “dh_next”)与当前时间步的梯度相加,形成一个总的梯度

  • 参数梯度更新:利用链式法则,通过隐藏状态梯度计算出对输入到隐藏权重 Wxh 和隐藏到隐藏权重 Whh 以及偏置的梯度。由于这些参数在每个时间步都是共享的,每一步计算出的梯度都会被累加起来。

  • 时间传递:在完成当前时间步梯度计算后,再将梯度通过 传递回前一个时间步,继续重复这一过程直到第一个时间步。

下面以一个简单的 RNN 模型展开一个长度为 3 的序列的反向传播过程,来详细说明 BPTT 中的四个关键步骤:局部梯度计算、梯度传递与累积、参数梯度更新、和时间传递。假设模型的前向传播计算如下(激活函数采用 tanh):

  • 隐藏状态更新

  • 输出计算

假设我们的损失函数 L 是所有时间步损失的加和:

其中每个时间步的损失 Lt 是模型输出 yt 和目标 (y_t)^{target} 之间的误差(例如均方误差)。

下面分步详细说明 BPTT 的反向传播过程。

1. 局部梯度计算

在反向传播时,我们需要先计算每个时间步在输出端的局部梯度,然后再传回隐藏层。具体来说:

  • 对于时间步 t,我们先计算输出层的梯度:

  • 接着,通过输出层将梯度传递给隐藏状态:

  • 由于隐藏状态经过 tanh⁡ 激活,,其局部梯度部分需乘上激活函数的导数:

    因此,得到当前时间步的局部梯度:

    其中“⊙”表示元素级相乘。

这部分称为“局部梯度计算”,即对当前时刻输出误差先求到隐藏层(通过 Why​),再结合激活函数求出对 zt 的梯度。

2. 梯度传递与累积

由于 RNN 中隐藏状态间存在依赖,当前时刻 ht 不仅受当前时间步损失 Lt 影响,还间接受到后续时间步的反馈。因此,反向传播时需要将未来时间步传回来的梯度累积在当前时刻。设我们计算总的梯度 dht​ 对隐藏状态的偏导,其计算方式为:

即当前时刻的总梯度等于当前局部梯度加上由下一时间步传回来的梯度经过隐藏层的权重传递后的结果。

3. 参数梯度更新

有了每个时间步的梯度 δt​ 和从后续传来的梯度累积 ,我们可以对各层参数求导。具体来说:

  • 对于 输入到隐藏层权重 Wxh​

    其中 ⊗ 表示外积,此处对每个时间步将 δt 与相应的输入 xt 外积,然后累加。

  • 对于 隐藏到隐藏层权重 Whh​

    同样每个时间步累加当前梯度与前一隐藏状态的外积。

  • 对于 隐藏层偏置 bh​

  • 对于 隐藏到输出层权重 Why​,输出层的梯度已经在局部步骤中计算:

    .
  • 对于 输出偏置 by​

这些梯度在反向传播过程中在每个时间步内计算完毕后,通过累加得到整个序列上的梯度,接着就可以用常规优化方法更新参数。

4. 时间传递(从未来到过去)

在反向传播过程中,必须将未来时间步的梯度传递到当前时间步,这就是“时间传递”。具体步骤为:

这种梯度传递过程在整个序列反向迭代中重复执行,从时间步 T 逐层传回到时间步 1。

综合一个详细例子

假设我们有一个时间序列长度为 3 的 RNN,且以时间步 t=3 开始反向传播。简化起见,以下给出各步描述:

  1. 时间步 3

  2. 时间步 2

  3. 时间步 1

总结

  • 局部梯度计算:在每个时间步,根据输出误差乘以输出层权重和激活函数导数,得到对当前隐藏单元输入的梯度(δt​)。

  • 梯度传递与累积:从后向前逐步将未来时刻的梯度通过隐藏层(乘以 和激活导数)传递给前一时间步,累加成当前时刻的总梯度

  • 参数梯度更新:利用每个时间步局部梯度与输入(或前一时刻隐藏状态)的外积,累积得到 Wxh​、Whh​ 和 bh​ 的梯度;输出层参数的梯度也由对应输出误差累积。

  • 时间传递:通过计算隐藏状态之间的依赖(即 ​​),将梯度从后续传递给当前,直至序列首端。

这种详细步骤体现了 BPTT 如何让 RNN 捕捉序列中长距离依赖,以及如何利用链式求导从序列的末端逐步将梯度传回并更新共享参数。

(五)参数更新

在累积了整个序列上各时间步的梯度后,使用如梯度下降、Adam 等优化算法对共享参数进行更新,从而使整体损失下降,模型逐步学会捕捉时序依赖关系。

总体来说,BPTT 的实现流程可总结为:

  • 先在时间上前向传播:依次计算每个时间步的隐藏状态和输出,并存储中间结果。

  • 计算整个序列的总损失:对每一时间步的输出和目标计算损失。

  • 从后向前反向传播:将误差信息沿时间展开的网络逐层反向传递,每一步既考虑当前的局部误差,也考虑来自未来时间步的反馈,累积梯度。

  • 更新共享参数:利用累积的梯度,通过优化算法更新各个权重和偏置。

这一过程确保了即使序列较长,模型也能捕捉到早期输入对后续输出的影响,从而在学习长距离依赖关系方面发挥关键作用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2330987.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

球类(继承和多态)

父类Ball,设置为抽象类,调用get和set方法创建对象,将子类重写的功能函数抽象化。 // 抽象球类 abstract class Ball {private String name;private double radius; // 半径private double weight; // 重量private double price; // 价格// 构…

DFS和BFS的模版

dfs dfs金典例题理解就是走迷宫 P1605 迷宫 - 洛谷 dfs本质上在套一个模版&#xff1a; ///dfs #include<bits/stdc.h> using namespace std; int a[10][10]{0}; int m,n,t,ans0; int ex,ey; int v[10][10]{0}; int dx[4]{-1,0,1,0}; int dy[4]{0,1,0,-1}; void dfs(in…

Ansible Playbook 进阶探秘:Handlers、变量、循环及条件判断全解析

192.168.60.100ansible.com192.168.60.110 client-1.com 192.168.60.120client-2.com192.168.60.130client-1.com 一、Handlers 介绍&#xff1a;在发生改变时执行的操作(类似puppet通知机制) 示例&#xff1a; 当apache的配置文件发生改变时&#xff0c;apache服务才会重启…

MATLAB的24脉波整流器Simulink仿真与故障诊断

本博客来源于CSDN机器鱼&#xff0c;未同意任何人转载。 更多内容&#xff0c;欢迎点击本专栏目录&#xff0c;查看更多内容。 目录 0 引言 1 故障数据采集 2 故障特征提取 3 故障诊断分类 4 结语 本博客内容是在MATLAB2023下完成。 0 引言 对于电力电子电路的故障诊断…

linux第三次作业

1、将你的虚拟机的网卡模式设置为nat模式&#xff0c;给虚拟机网卡配置三个主机位分别为100、200、168的ip地址 2、测试你的虚拟机是否能够ping通网关和dns&#xff0c;如果不能请修改网关和dns的地址 3、将如下内容写入/etc/hosts文件中&#xff08;如果有多个ip地址则写多行&…

国标GB28181视频平台EasyCVR顺应智慧农业自动化趋势,打造大棚实时视频监控防线

一、方案背景 近年来&#xff0c;温室大棚种植技术凭借其显著的优势&#xff0c;在提升农作物产量和质量、丰富农产品供应方面发挥了重要的作用&#xff0c;极大改善了人们的生活水平&#xff0c;得到了广泛的推广和应用。大棚内的温度、湿度、光照度和二氧化碳浓度等环境因素…

HOOPS Visualize:跨平台、高性能的三维图形渲染技术解析

在当今数字化时代&#xff0c;三维可视化技术已成为众多行业的核心竞争力。HOOPS Visualize作为一款功能强大的三维图形渲染引擎&#xff0c;凭借其卓越的渲染能力、跨平台支持、丰富的交互功能、高度定制化以及快速部署等特性&#xff0c;为开发人员提供了构建高质量、高性能3…

蓝桥杯 C/C++ 组历届真题合集速刷(一)

一、1.单词分析 - 蓝桥云课 &#xff08;模拟、枚举&#xff09;算法代码&#xff1a; #include <bits/stdc.h> using namespace std;int main() {string s;cin>>s;unordered_map<char,int> mp;for(auto ch:s){mp[ch];}char result_charz;int max_count0;fo…

多类型医疗自助终端智能化升级路径(代码版.上)

大型医疗自助终端的智能化升级是医疗信息化发展的重要方向,其思维链一体化路径需要围绕技术架构、数据流协同、算法优化和用户体验展开: 一、技术架构层:分布式边缘计算与云端协同 以下针对技术架构层的分布式边缘计算与云端协同模块,提供具体编程实现方案: 一、边缘节点…

区间 DP 详解

文章目录 区间 DP分割型合并型环形合并 区间 DP 区间 DP&#xff0c;就是在对一段区间进行了若干次操作后的最小代价&#xff0c;一般是合并和拆分类型。 分割型 分割型&#xff0c;指把一个区间内的几项分开拆成一份一份的&#xff0c;再全部合起来就是当前答案&#xff0c…

QAM 信号的距离以及能量归一化

QAM星座图平均功率能量_星座图功率计算-CSDN博客 正交幅度调制(QAM) - Vinson88 - 博客园 不同阶QAM调制星座图中&#xff0c;符号能量的归一化计算原理_qpsk的星座图归一化-CSDN博客 https://zhuanlan.zhihu.com/p/690157236

Reactive编程框架与工具

文章目录 6.2 后端 Reactive 框架6.2.1 Spring WebFlux核心架构核心组件实际应用高级特性性能优化适用场景与限制 6.2.2 Akka&#xff08;Actor模型&#xff09;Actor模型基础基本用法高级特性响应式特性实现性能优化实际应用场景优势与挑战 6.2.3 Vert.x&#xff08;事件驱动&…

Python爬虫第7节-requests库的高级用法

目录 前言 一、文件上传 二、Cookies 三、会话维持 四、SSL证书验证 五、代理设置 六、超时设置 七、身份认证 八、Prepared Request 前言 上一节&#xff0c;我们认识了requests库的基本用法&#xff0c;像发起GET、POST请求&#xff0c;以及了解Response对象是什么。…

Maven的安装配置-项目管理工具

各位看官&#xff0c;大家早安午安晚安呀~~~ 如果您觉得这篇文章对您有帮助的话 欢迎您一键三连&#xff0c;小编尽全力做到更好 欢迎您分享给更多人哦 今天我们来学习&#xff1a;Maven的安装配置-项目管理工具 目录 1.什么是Maven&#xff1f;Maven用来干什么的&#xff1f…

智能 SQL 优化工具 PawSQL 月度更新 | 2025年3月

&#x1f4cc; 更新速览 本月更新包含 21项功能增强 和 9项问题修复&#xff0c;重点提升SQL解析精度与优化建议覆盖率。 一、SQL解析能力扩展 ✨ 新增SQL语法解析支持 SELECT...INTO TABLE 语法解析&#xff08;3/26&#xff09; ALTER INDEX RENAME/VISIBLE 语句解析&#…

Ubuntu虚拟机编译安装部分OpenCV模块方法实现——保姆级教程

Ubuntu虚拟机的安装过程可以查看另一篇文章&#xff1a;VMware安装Ubuntu虚拟机实现COpenCV代码在虚拟机下运行教程-CSDN博客 目前我们已经下载好了OpenCV&#xff0c;这里以OpenCV4.5.2为例。 在内存要求尽可能小的情况下&#xff0c;可以尝试只编译安装代码中使用到的OpenC…

spring mvc @ResponseBody 注解转换为 JSON 的原理与实现详解

ResponseBody 注解转换为 JSON 的原理与实现详解 1. 核心作用 ResponseBody 是 Spring MVC 的一个注解&#xff0c;用于将方法返回的对象直接序列化为 HTTP 响应体&#xff08;如 JSON 或 XML&#xff09;&#xff0c;而不是通过视图解析器渲染为视图&#xff08;如 HTML&…

skynet.rawcall使用详解及应用场景

目录 核心特性函数原型使用场景场景 1&#xff1a;高性能二进制传输&#xff08;如文件转发&#xff09;场景 2&#xff1a;自定义序列化协议&#xff08;如 Protocol Buffers&#xff09;场景 3&#xff1a;跨服务共享内存&#xff08;避免拷贝&#xff09; 配套接收方实现与 …

使用SpringSecurity下,发生重定向异常

使用SpringSecurity下&#xff0c;发生空转异常 环境信息&#xff1a; Spring Boot 3.4.4 &#xff0c; jdk 17 &#xff0c; springSecurity 6.4.4 问题背景&#xff1a; 没有自定义controller &#xff0c;改写了login 页面&#xff0c;并且进行了成功后的跳转处理&#xf…

Elasticsearch | ES索引模板、索引和索引别名的创建与管理

关注&#xff1a;CodingTechWork 引言 在使用 Elasticsearch (ES) 和 Kibana 构建数据存储和分析系统时&#xff0c;索引模板、索引和索引别名的管理是关键步骤。本文将详细介绍如何通过 RESTful API 和 Kibana Dev Tools 创建索引模板、索引以及索引别名&#xff0c;并提供具…