58 深层循环神经网络_by《李沐:动手学深度学习v2》pytorch版

news2025/1/22 16:12:21

系列文章目录


文章目录

  • 系列文章目录
  • 深度循环神经网络
      • 1. 模型复杂性增加
      • 2. 训练数据不足
      • 3. 梯度消失和爆炸
      • 4. 正则化不足
      • 5. 特征冗余
        • 总结
    • 函数依赖关系
    • 简洁实现
    • 训练与预测
    • 小结
    • 练习


深度循环神经网络

🏷sec_deep_rnn

到目前为止,我们只讨论了具有一个单向隐藏层的循环神经网络。但是RNN中一个隐藏层不能做很宽,会很过拟合 。
在 RNN(递归神经网络)中,隐藏层的宽度(即隐藏单元的数量)对模型的表现有重要影响。宽度过大的隐藏层可能导致过拟合,原因如下:

1. 模型复杂性增加

  • 隐藏层的宽度增加意味着模型的参数数量显著增加。更多的参数可以捕捉到训练数据中的复杂模式,但这也使得模型容易记住训练数据的噪声,而不是学习到通用的特征。

2. 训练数据不足

  • 如果训练数据量相对较小,宽的隐藏层可能会导致模型在训练数据上表现很好,但在未见数据(测试数据)上的泛化能力差。过拟合通常发生在模型对训练数据的依赖过强,无法适应新的数据分布。

3. 梯度消失和爆炸

  • RNN 本身就容易出现梯度消失或爆炸的问题,尤其在处理长序列时。宽的隐藏层可能加剧这一问题,使得训练过程不稳定,从而影响模型的学习效果。

4. 正则化不足

  • 如果模型过于复杂而没有适当的正则化(如 dropout、L2 正则化等),则会更容易过拟合。宽隐藏层在缺乏正则化的情况下,可能会导致对训练数据的过度拟合。

5. 特征冗余

  • 宽的隐藏层可能导致特征冗余,许多隐藏单元可能学习到相似的特征,导致模型效率低下,并增加过拟合的风险。
总结

为了避免过拟合,通常需要在模型设计中平衡隐藏层的宽度与训练数据的数量和质量。可以考虑使用正则化方法、减少隐藏单元数量、增加训练数据量或使用更复杂的模型架构(如 LSTM 或 GRU)来提高模型的泛化能力。

其中,隐变量和观测值与具体的函数形式的交互方式是相当随意的。
只要交互类型建模具有足够的灵活性,这就不是一个大问题。
然而,对一个单层来说,这可能具有相当的挑战性。
之前在线性模型中,我们通过添加更多的层来解决这个问题。
而在循环神经网络中,我们首先需要确定如何添加更多的层,以及在哪里添加额外的非线性,因此这个问题有点棘手。
事实上,我们可以将多层循环神经网络堆叠在一起,通过对几个简单层的组合,产生了一个灵活的机制。
特别是,数据可能与不同层的堆叠有关。例如,我们可能希望保持有关金融市场状况(熊市或牛市)的宏观数据可用,而微观数据只记录较短期的时间动态。下图描述了一个具有 L L L个隐藏层的深度循环神经网络,每个隐状态都连续地传递到当前层的下一个时间步和下一层的当前时间步。

在这里插入图片描述

🏷fig_deep_rnn

函数依赖关系

我们可以将深度架构中的函数依赖关系形式化,这个架构是由 :numref:fig_deep_rnn中描述了 L L L个隐藏层构成。
后续的讨论主要集中在经典的循环神经网络模型上,但是这些讨论也适应于其他序列模型。

假设在时间步 t t t有一个小批量的输入数据 X t ∈ R n × d \mathbf{X}_t \in \mathbb{R}^{n \times d} XtRn×d(样本数: n n n,每个样本中的输入数: d d d)。
同时,将 l t h l^\mathrm{th} lth隐藏层( l = 1 , … , L l=1,\ldots,L l=1,,L)的隐状态设为 H t ( l ) ∈ R n × h \mathbf{H}_t^{(l)} \in \mathbb{R}^{n \times h} Ht(l)Rn×h(隐藏单元数: h h h),输出层变量设为 O t ∈ R n × q \mathbf{O}_t \in \mathbb{R}^{n \times q} OtRn×q(输出数: q q q)。
设置 H t ( 0 ) = X t \mathbf{H}_t^{(0)} = \mathbf{X}_t Ht(0)=Xt,第 l l l个隐藏层的隐状态使用激活函数 ϕ l \phi_l ϕl,则:

H t ( l ) = ϕ l ( H t ( l − 1 ) W x h ( l ) + H t − 1 ( l ) W h h ( l ) + b h ( l ) ) , \mathbf{H}_t^{(l)} = \phi_l(\mathbf{H}_t^{(l-1)} \mathbf{W}_{xh}^{(l)} + \mathbf{H}_{t-1}^{(l)} \mathbf{W}_{hh}^{(l)} + \mathbf{b}_h^{(l)}), Ht(l)=ϕl(Ht(l1)Wxh(l)+Ht1(l)Whh(l)+bh(l)),
:eqlabel:eq_deep_rnn_H

其中,权重 W x h ( l ) ∈ R h × h \mathbf{W}_{xh}^{(l)} \in \mathbb{R}^{h \times h} Wxh(l)Rh×h W h h ( l ) ∈ R h × h \mathbf{W}_{hh}^{(l)} \in \mathbb{R}^{h \times h} Whh(l)Rh×h和偏置 b h ( l ) ∈ R 1 × h \mathbf{b}_h^{(l)} \in \mathbb{R}^{1 \times h} bh(l)R1×h都是第 l l l个隐藏层的模型参数。

最后,输出层的计算仅基于第 l l l个隐藏层最终的隐状态:

O t = H t ( L ) W h q + b q , \mathbf{O}_t = \mathbf{H}_t^{(L)} \mathbf{W}_{hq} + \mathbf{b}_q, Ot=Ht(L)Whq+bq,

其中,权重 W h q ∈ R h × q \mathbf{W}_{hq} \in \mathbb{R}^{h \times q} WhqRh×q和偏置 b q ∈ R 1 × q \mathbf{b}_q \in \mathbb{R}^{1 \times q} bqR1×q都是输出层的模型参数。

与多层感知机一样,隐藏层数目 L L L和隐藏单元数目 h h h都是超参数。也就是说,它们可以由我们调整的。
另外,用门控循环单元或长短期记忆网络的隐状态来代替 :eqref:eq_deep_rnn_H中的隐状态进行计算,可以很容易地得到深度门控循环神经网络或深度长短期记忆神经网络。

简洁实现

实现多层循环神经网络所需的许多逻辑细节在高级API中都是现成的。
简单起见,我们仅示范使用此类内置函数的实现方式。
以长短期记忆网络模型为例,该代码与之前在LSTM中使用的代码非常相似,实际上唯一的区别是我们指定了层的数量,而不是使用单一层这个默认值。
像往常一样,我们从加载数据集开始。

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

像选择超参数这类架构决策也跟 :numref:sec_lstm中的决策非常相似。
因为我们有不同的词元,所以输入和输出都选择相同数量,即vocab_size
隐藏单元的数量仍然是 256 256 256
唯一的区别是,我们现在(通过num_layers的值来设定隐藏层数)。

vocab_size, num_hiddens, num_layers = len(vocab), 256, 2
num_inputs = vocab_size
device = d2l.try_gpu()
lstm_layer = nn.LSTM(num_inputs, num_hiddens, num_layers)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)

训练与预测

由于使用了长短期记忆网络模型来实例化两个层,因此训练速度被大大降低了。

num_epochs, lr = 500, 2
d2l.train_ch8(model, train_iter, vocab, lr*1.0, num_epochs, device)
perplexity 1.0, 224250.2 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby

<Figure size 350x250 with 1 Axes>

小结

  • 在深度循环神经网络中,隐状态的信息被传递到当前层的下一时间步和下一层的当前时间步。
  • 有许多不同风格的深度循环神经网络,
    如长短期记忆网络、门控循环单元、或经典循环神经网络。
    这些模型在深度学习框架的高级API中都有涵盖。
  • 总体而言,深度循环神经网络需要大量的调参(如学习率和修剪)
    来确保合适的收敛,模型的初始化也需要谨慎。

练习

  1. 基于我们在 :numref:sec_rnn_scratch中讨论的单层实现,
    尝试从零开始实现两层循环神经网络。
  2. 在本节训练模型中,比较使用门控循环单元替换长短期记忆网络后模型的精确度和训练速度。
  3. 如果增加训练数据,能够将困惑度降到多低?
  4. 在为文本建模时,是否可以将不同作者的源数据合并?有何优劣呢?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2175237.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于大数据的亚健康人群数据分析及可视化系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码 精品专栏&#xff1a;Java精选实战项目…

在vsCode中将某个字符替换成换行符并且换行展示

将下面一行字符串中的“,”替换成换行符并且换行展示。 CTRLF键检索&#xff0c;查找和替换内容根据如下图输入即可。 点击全部替换后得到如下结果&#xff1a;

图文深入理解Oracle Network配置管理(二)

本篇图文深入介绍Oracle Network配置管理。 Oracle网络配置的目的 为了方便对Oracle 数据库进行管理&#xff0c;一般以下情况应该对Oracle进行网络配置。 • 在客户端对服务器端数据库进行管理&#xff08;网络客户端管理&#xff09; • 在一台服务器上管理多个数据库&…

性能测试学习1:性能测试的理论与目的,与功能测试的区别

一.什么是性能&#xff1f; 1&#xff09;性能&#xff1a;就是软件质量属性中的“效率”特性 2&#xff09;效率特性: ①时间特性&#xff1a;表示系统处理用户请求的响应时间【通俗来说&#xff0c;就是使用系统是否流畅】 ②资源特性&#xff1a;表示系统运行过程中&…

青动CRM-仓储云V1.1.2

多平台(微信公众号(高级授权)、微信小程序(高级授权)、H5网页(高级授权)、Android-App(高级授权)、iOS-App(高级授权))仓库管理系统&#xff0c;拥有强大的表单设计、多角色员工权限、出入库管理、仓库管理、送货管理、自定义审批流、绩效管理、客户管理、合同管理等功能。提供…

抖音支付回调验签 go 版本

序言 最近在做抖音小程序支付&#xff0c;由于抖音开放平台的文档写的较为简陋&#xff0c;让人踩了不少坑&#xff0c;在这里整理一下做小程序支付的整个过程&#xff0c;以通用交易系统为例子。 准备条件 1&#xff09;申请小程序&#xff0c;开通支付功能 这里需要明确你小…

Linux--基本指令

目录 1.ls指令 ​2.pwd指令 3.cd指令 4.tree指令 5.touch/mkdir/rmdir(rm)指令 6.cp/mv/cat/tac指令 7.head/tail/管道 8.date 9.find/which/grep 10.其它小知识 1.ls指令 补充知识&#xff1a; 2.pwd指令 补充知识&#xff1a; 3.cd指令 补充知识&#xff1a; 4.tree指令 …

Java语法-类和对象之抽象类和接口

1.抽象类 1.1 抽象类的概念 一个类中没有足够的信息来描述一个具体的对象,这样的类就是抽象类 比如: 从图中我们可以看出,只有继承了的类,我们产生的实例,调用的draw方法都是他们本身重写的draw方法,不会调用父类Shape的draw()方法,因此我们可以不管父类里面的draw()方法里面的…

MySQL 之多表设计详解

在实际应用场景中&#xff0c;我们经常需要处理包含多种数据实体及其之间复杂关系的业务逻辑&#xff0c;例如电商平台的用户、商品、订单&#xff0c;社交网络的用户、帖子、评论等等。如果将所有数据都堆砌在一张表中&#xff0c;不仅会造成数据冗余、难以维护&#xff0c;还…

MySQL 8.0.34 从C盘迁移到D盘

因为开始C盘够用&#xff0c;没注意mysql安装位置&#xff0c;如今C盘爆满&#xff0c;只能把mysql转移到D盘&#xff0c;以腾出更多的空间让我折腾。 一、关闭mysql服务 二、找到C盘MySQL安装文件和Data文件 1.找到C盘mysql bin文件目录安装文件路径&#xff1a; C:\Progra…

行为设计模式 -模板方法模式- JAVA

模板方法模式 一 .简介二. 案例2.1 抽象类&#xff08;Abstract Class&#xff09;2.2 具体子类&#xff08;Concrete Class&#xff09;2.3 测试 三. 结论3.1 优缺点3.2 适用场景3.3 要点 前言 这是我在这个网站整理的笔记,有错误的地方请指出&#xff0c;关注我&#xff0c;接…

linux从入门到精通--从基础学起,逐步提升,探索linux奥秘(六)

linux从入门到精通–从基础学起&#xff0c;逐步提升&#xff0c;探索linux奥秘&#xff08;六&#xff09; 一、linux高级指令&#xff08;1&#xff09; 1、hostname指令 1&#xff09;作用&#xff1a;操作服务器的主机名&#xff08;读取、设置&#xff09; 2&#xff0…

huggingface的transformers与datatsets的安装与使用

目录 1.安装 2.分词 2.1tokenizer.encode&#xff08;&#xff09; 2.2tokenizer.encode_plus &#xff08;&#xff09; 2.3tokenizer.batch_encode_plus&#xff08;&#xff09; 3.添加新词或特殊字符 3.1tokenizer.add_tokens&#xff08;&#xff09; 3.2 token…

Python自动收发邮件的详细步骤与使用方法?

Python自动收发邮件教程&#xff1f;Python怎么实现收发邮件&#xff1f; Python作为一种强大的编程语言&#xff0c;提供了丰富的库和工具&#xff0c;使得自动收发邮件变得简单而高效。AokSend将详细介绍如何使用Python自动收发邮件&#xff0c;帮助读者掌握这一实用技能。 …

【ASE】第四课_护盾效果(有碰撞效果)

今天我们一起来学习ASE插件&#xff0c;希望各位点个关注&#xff0c;一起跟随我的步伐 今天我们来学习护盾的效果。 思路&#xff1a; 1.添加纹理贴图和法线贴图&#xff08;这里省略&#xff09; 2.添加护盾边缘顶点扰动效果&#xff0c;也可以理解成变形效果 3.添加碰撞…

关于frp Web界面-----frp Server Dashboard 和 frp Client Admin UI

Web 界面 官方文档&#xff1a;https://gofrp.org/zh-cn/docs/features/common/ui/ 目前 frpc 和 frps 分别内置了相应的 Web 界面方便用户使用。 客户端 Admin UI 服务端 Dashboard 服务端 Dashboard 服务端 Dashboard 使用户可以通过浏览器查看 frp 的状态以及代理统计信…

59 双向循环神经网络_by《李沐:动手学深度学习v2》pytorch版

系列文章目录 文章目录 系列文章目录双向RNN推理 总结以下为理论部分双向循环神经网络隐马尔可夫模型中的动态规划双向模型定义模型的计算代价及其应用 (**双向循环神经网络的错误应用**)小结练习 双向RNN 这里理解这个图的时候&#xff0c;不要把正向和逆向认为有上下的关系&a…

计算机毕业设计 基于Python的音乐平台的设计与实现 Python+Django+Vue 前后端分离 附源码 讲解 文档

&#x1f34a;作者&#xff1a;计算机编程-吉哥 &#x1f34a;简介&#xff1a;专业从事JavaWeb程序开发&#xff0c;微信小程序开发&#xff0c;定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事&#xff0c;生活就是快乐的。 &#x1f34a;心愿&#xff1a;点…

IDE 使用技巧与插件推荐(含例说明)

在使用集成开发环境&#xff08;IDE&#xff09;进行编程时&#xff0c;掌握一些技巧和使用高效的插件可以显著提高开发效率。以下是一些通用的IDE使用技巧和插件推荐&#xff0c;适用于多种流行的IDE&#xff0c;如IntelliJ IDEA、Visual Studio Code、PyCharm等。每个技巧和插…

泳池异常检测系统源码分享

泳池异常检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vis…