57 长短期记忆网络(LSTM)_by《李沐:动手学深度学习v2》pytorch版

news2024/9/29 6:34:17

系列文章目录


文章目录


长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。
解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM)它有许多与门控循环单元(GRU)一样的属性。有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些,却比门控循环单元早诞生了近20年。

门控记忆元

可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。
长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。
有些文献认为记忆元是隐状态的一种特殊类型,它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。
为了控制记忆元,我们需要许多门。
其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。
另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。
我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理,这种设计的动机与门控循环单元相同,能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。让我们看看这在实践中是如何运作的。

输入门、忘记门和输出门

就如在门控循环单元中一样,当前时间步的输入和前一个时间步的隐状态作为数据送入长短期记忆网络的门中,如下图所示。它们由三个具有sigmoid激活函数的全连接层处理,以计算输入门、遗忘门和输出门的值。
因此,这三个门的值都在 ( 0 , 1 ) (0, 1) (0,1)的范围内。

在这里插入图片描述label:lstm_0

我们来细化一下长短期记忆网络的数学表达。
假设有 h h h个隐藏单元,批量大小为 n n n,输入数为 d d d
因此,输入为 X t ∈ R n × d \mathbf{X}_t \in \mathbb{R}^{n \times d} XtRn×d,前一时间步的隐状态为 H t − 1 ∈ R n × h \mathbf{H}_{t-1} \in \mathbb{R}^{n \times h} Ht1Rn×h
相应地,时间步 t t t的门被定义如下:
输入门是 I t ∈ R n × h \mathbf{I}_t \in \mathbb{R}^{n \times h} ItRn×h
遗忘门是 F t ∈ R n × h \mathbf{F}_t \in \mathbb{R}^{n \times h} FtRn×h
输出门是 O t ∈ R n × h \mathbf{O}_t \in \mathbb{R}^{n \times h} OtRn×h
它们的计算方法如下:

I t = σ ( X t W x i + H t − 1 W h i + b i ) , F t = σ ( X t W x f + H t − 1 W h f + b f ) , O t = σ ( X t W x o + H t − 1 W h o + b o ) , \begin{aligned} \mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\ \mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\ \mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o), \end{aligned} ItFtOt=σ(XtWxi+Ht1Whi+bi),=σ(XtWxf+Ht1Whf+bf),=σ(XtWxo+Ht1Who+bo),

其中 W x i , W x f , W x o ∈ R d × h \mathbf{W}_{xi}, \mathbf{W}_{xf}, \mathbf{W}_{xo} \in \mathbb{R}^{d \times h} Wxi,Wxf,WxoRd×h W h i , W h f , W h o ∈ R h × h \mathbf{W}_{hi}, \mathbf{W}_{hf}, \mathbf{W}_{ho} \in \mathbb{R}^{h \times h} Whi,Whf,WhoRh×h是权重参数, b i , b f , b o ∈ R 1 × h \mathbf{b}_i, \mathbf{b}_f, \mathbf{b}_o \in \mathbb{R}^{1 \times h} bi,bf,boR1×h是偏置参数。

候选记忆元 (相当于RNN中计算 H t H_t Ht)

由于还没有指定各种门的操作,所以先介绍候选记忆元(candidate memory cell) C ~ t ∈ R n × h \tilde{\mathbf{C}}_t \in \mathbb{R}^{n \times h} C~tRn×h
它的计算与上面描述的三个门的计算类似,但是使用 tanh ⁡ \tanh tanh函数作为激活函数,函数的值范围为 ( − 1 , 1 ) (-1, 1) (1,1)
下面导出在时间步 t t t处的方程:

C ~ t = tanh ( X t W x c + H t − 1 W h c + b c ) , \tilde{\mathbf{C}}_t = \text{tanh}(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c), C~t=tanh(XtWxc+Ht1Whc+bc),

其中 W x c ∈ R d × h \mathbf{W}_{xc} \in \mathbb{R}^{d \times h} WxcRd×h W h c ∈ R h × h \mathbf{W}_{hc} \in \mathbb{R}^{h \times h} WhcRh×h是权重参数, b c ∈ R 1 × h \mathbf{b}_c \in \mathbb{R}^{1 \times h} bcR1×h是偏置参数。

候选记忆元的如下图 :numref:lstm_1所示。

在这里插入图片描述label:lstm_1

记忆元

在门控循环单元中,有一种机制来控制输入和遗忘(或跳过)。
类似地,在长短期记忆网络中,也有两个门用于这样的目的:
输入门 I t \mathbf{I}_t It控制采用多少来自 C ~ t \tilde{\mathbf{C}}_t C~t的新数据,而遗忘门 F t \mathbf{F}_t Ft控制保留多少过去的记忆元 C t − 1 ∈ R n × h \mathbf{C}_{t-1} \in \mathbb{R}^{n \times h} Ct1Rn×h的内容。
使用按元素乘法,得出:

C t = F t ⊙ C t − 1 + I t ⊙ C ~ t . \mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t. Ct=FtCt1+ItC~t.

如果遗忘门始终为 1 1 1且输入门始终为 0 0 0,则过去的记忆元 C t − 1 \mathbf{C}_{t-1} Ct1将随时间被保存并传递到当前时间步。
引入这种设计是为了缓解梯度消失问题,并更好地捕获序列中的长距离依赖关系。

这样我们就得到了计算记忆元的流程图,如 :numref:lstm_2

在这里插入图片描述label:lstm_2

隐状态

最后,我们需要定义如何计算隐状态 H t ∈ R n × h \mathbf{H}_t \in \mathbb{R}^{n \times h} HtRn×h,这就是输出门发挥作用的地方。
在长短期记忆网络中,它仅仅是记忆元的 tanh ⁡ \tanh tanh的门控版本。
这就确保了 H t \mathbf{H}_t Ht的值始终在区间 ( − 1 , 1 ) (-1, 1) (1,1)内:

H t = O t ⊙ tanh ⁡ ( C t ) . \mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t). Ht=Ottanh(Ct).

只要输出门接近 1 1 1,我们就能够有效地将所有记忆信息传递给预测部分,而对于输出门接近 0 0 0,我们只保留记忆元内的所有信息,而不需要更新隐状态(相当于重置隐状态)。

下图 :numref:lstm_3提供了数据流的图形化演示。

在这里插入图片描述label:lstm_3
在这里插入图片描述

从零开始实现

现在,我们从零开始实现长短期记忆网络。我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

接下来,我们需要定义和初始化模型参数。
如前所述,超参数num_hiddens定义隐藏单元的数量。
我们按照标准差 0.01 0.01 0.01的高斯分布初始化权重,并将偏置项设为 0 0 0

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

定义模型

在[初始化函数]中,长短期记忆网络的隐状态需要返回一个额外的记忆元,单元的值为0,形状为(批量大小,隐藏单元数)。因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

[实际模型]的定义与我们前面讨论的一样:
提供三个门和一个额外的记忆元。
请注意,只有隐状态才会传递到输出层,而记忆元 C t \mathbf{C}_t Ct不直接参与输出计算。

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y) #Y的shape是(批量大小,词表长度)只有这里输出了批量大小的预测,之后才能用来计算损失
    return torch.cat(outputs, dim=0), (H, C)

训练和预测

让我们通过实例化RNN从零实现中引入的RNNModelScratch类来训练一个长短期记忆网络,就如我们在GRU中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 14.5, 27965.3 tokens/sec on cuda:0
time traveller te at at at at at at at at at at at at at at at a
traveller te at at at at at at at at at at at at at at at a

<Figure size 350x250 with 1 Axes>

在这里插入图片描述

简洁实现

使用高级API,我们可以直接实例化LSTM模型。
高级API封装了前文介绍的所有配置细节。
这段代码的运行速度要快得多,因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 11.2, 233619.5 tokens/sec on cuda:0
time traveller the the the the the the the the the the the the t
traveller the the the the the the the the the the the the t

<Figure size 350x250 with 1 Axes>

在这里插入图片描述
长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。
多年来已经提出了其许多变体,例如,多层、残差连接、不同类型的正则化。
然而,由于序列的长距离依赖性,训练长短期记忆网络和其他序列模型(例如门控循环单元)的成本是相当高的。
在后面的内容中,我们将讲述更高级的替代模型,如Transformer。

小结

  • 长短期记忆网络有三种类型的门:输入门、遗忘门和输出门。
  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。
  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

练习

  1. 调整和分析超参数对运行时间、困惑度和输出顺序的影响。
  2. 如何更改模型以生成适当的单词,而不是字符序列?
  3. 在给定隐藏层维度的情况下,比较门控循环单元、长短期记忆网络和常规循环神经网络的计算成本。要特别注意训练和推断成本。
  4. 既然候选记忆元通过使用 tanh ⁡ \tanh tanh函数来确保值范围在 ( − 1 , 1 ) (-1,1) (1,1)之间,那么为什么隐状态需要再次使用 tanh ⁡ \tanh tanh函数来确保输出值范围在 ( − 1 , 1 ) (-1,1) (1,1)之间呢?
  5. 实现一个能够基于时间序列进行预测而不是基于字符序列进行预测的长短期记忆网络模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2176075.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

0基础学习CSS(六)字体

CSS 字体 CSS字体属性定义字体&#xff0c;加粗&#xff0c;大小&#xff0c;文字样式。 serif和sans-serif字体之间的区别 在计算机屏幕上&#xff0c;sans-serif字体被认为是比serif字体容易阅读 CSS字型 在CSS中&#xff0c;有两种类型的字体系列名称&#xff1a; 通用字体…

Java | Leetcode Java题解之第443题压缩字符串

题目&#xff1a; 题解&#xff1a; class Solution {public int compress(char[] chars) {int n chars.length;int write 0, left 0;for (int read 0; read < n; read) {if (read n - 1 || chars[read] ! chars[read 1]) {chars[write] chars[read];int num read …

解读文本嵌入:语义表达的练习

【引子】近来在探索并优化AIPC的软件架构&#xff0c;AI产品经理关于语义搜索的讨论给了自己较多的触动&#xff0c;于是重新梳理嵌入与语义的关系&#xff0c;遂成此文。 文本转换成机器可理解格式的最早版本之一是 ASCII码&#xff0c;这种方法有助于渲染和传输文本&#xff…

win10系统K8S安装教程

准备工作 电脑硬件&#xff1a;支持虚拟化的CPU&#xff0c;内存最好在32G以上&#xff0c;16G也可以操作系统&#xff1a;window10 专业版 1 开启虚拟化 1.1 BIOS 由于主板和CPU的品牌不太一样&#xff0c;这里的操作仅供参考&#xff0c;以Intel的平台为例&#xff1a; …

【刷点笔试面试题试试水】有符号变量与无符号变量的值的转换

大家好,这里是国中之林! ❥前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站。有兴趣的可以点点进去看看← 问题: 解答: 注意无符号类型与有符合类型参与计算会做类型提升,有符合的变为无符号…

加法器以及标志位

加法器的结构&#xff1a; OF&#xff08;溢出标志位&#xff09;&#xff0c;SF&#xff08;符号标志位&#xff09;&#xff0c;ZF&#xff08;0标志位&#xff09;&#xff0c;ZF&#xff08;进位/借位标志位&#xff09; 有符号数看标志位&#xff1a;OF&#xff0c;SF 无符…

ubuntu 不用每次输入sudo的四种方式

在Ubuntu系统中&#xff0c;如果不希望每次执行需要管理员权限的命令时都输入sudo&#xff0c;有几种方法可以实现这一目标。以下是一些详细的方法&#xff1a; 第一种方式: 切换root用户 (如果你有足够的权限) # 修改root密码命令(没有设置的用户需要设置一下) consolaadmin…

面试中顺序表常考的十大题目解析

在数据结构与算法的面试中&#xff0c;顺序表是一个常见的考点。它作为一种基础的数据结构&#xff0c;涵盖了多种操作和概念&#xff0c;以下将详细介绍面试中关于顺序表常考的十大题目。 &#x1f49d;&#x1f49d;&#x1f49d;如果你对顺序表的概念与理解还存在疑惑&#…

【Threejs进阶教程-着色器篇】8. Shadertoy如何使用到Threejs-基础版

【Threejs进阶教程-着色器篇】8. Shadertoy如何使用到Threejs - 基础版 前七篇地址,建议按顺序学习致谢带我入门的[X01动力装甲]大佬本文适用范围怎么样在Shadertoy中画出正圆形shadertoy中的坐标系比例转换理解Shadertoy的fragCoord理解Shadertoy中的iResolution 转移Shaderto…

SigmaStudio淡入淡出增益控件(Single SW slew vol(adjustable))延时分析

斜率范围1~23&#xff0c;参考12khz正弦波&#xff08;-17.99db,调减15.2db&#xff09;作为分析依据 一、淡入时间与斜率关系 斜率1-----淡入延时时间大概0.08毫秒 斜率2—淡入延时时间大概0.2毫秒 斜率3–淡入延时时间按大概0.5毫秒 斜率4–淡入延时时间大概1毫秒 斜率5–淡…

C++学习笔记之结构体

C学习笔记之结构体 https://www.runoob.com/cplusplus/cpp-struct.html 结构体是C中一种由用户自定义的数据类型&#xff0c;允许存储不同类型的数据项 1、定义结构体 使用struct语句定义结构体 结构体与C中的类看起来结构相似&#xff0c;同样是可以在其中定义成员变量和成员…

picgo + typora + gitee图床

Picgo打造个人图床&#xff0c;稳定又安全 解决Typora笔记上传到CSDN图片无法显示的问题 typora中

完全二叉树的节点个数 C++ 简单问题

完全二叉树 的定义如下&#xff1a;在完全二叉树中&#xff0c;除了最底层节点可能没填满外&#xff0c;其余每层节点数都达到最大值&#xff0c;并且最下面一层的节点都集中在该层最左边的若干位置。若最底层为第 h 层&#xff0c;则该层包含 1~ 2h 个节点。 示例 1&#xff…

蓝桥杯—STM32G431RBT6(RTC时钟获取时间和日期)

一、RTC是什么&#xff0c;有什么用&#xff1f; 在 STM32 中&#xff0c;RTC&#xff08;Real-Time Clock&#xff0c;实时时钟&#xff09;主要有以下作用&#xff1a; 时间保持&#xff1a;即使在系统断电情况下&#xff0c;也能持续记录时间。&#xff08;需要纽扣电池供电…

解决银河麒麟V10密码过期无法登录的问题

解决银河麒麟V10密码过期无法登录的问题 1、问题描述2、 解决方法步骤一&#xff1a;更改密码步骤二&#xff1a;调整密码策略&#xff08;可选&#xff09; 3、总结 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 在使用银河麒麟桌面操作系…

Java:选择排序

目录 直接选择排序 堆排序 基本思想&#xff1a; 每一次从待排序的数据元素中选出最小(或最大)的一个元素&#xff0c;存放在序列的起始位置&#xff0c;直到全部待排序的数据元素排完。 直接选择排序 思路1&#xff1a; 在元素集合array[i]--array[n-1]中选择关键码最大(小…

【论文阅读】视觉里程计攻击

Adversary is on the Road: Attacks on Visual SLAM using Unnoticeable Adversarial Patch 一、视觉SLAM的不安全因素 根据论文的分析&#xff0c;视觉SLAM由于完全依赖于特征&#xff0c;缺少验证机制导致算法不安全。前端在受到干扰的情况下&#xff0c;会导致误匹配增加&…

算法工程师重生之第十八天(修剪二叉搜索树 将有序数组转换为二叉搜索树 把二叉搜索树转换为累加树 总结篇 )

参考文献 代码随想录 一、修剪二叉搜索树 给你二叉搜索树的根节点 root &#xff0c;同时给定最小边界low 和最大边界 high。通过修剪二叉搜索树&#xff0c;使得所有节点的值在[low, high]中。修剪树 不应该 改变保留在树中的元素的相对结构 (即&#xff0c;如果没有被移除…

【Sentinel-2简介】

Sentinel-2简介 Sentinel-2是欧洲空间局&#xff08;European Space Agency, ESA&#xff09;全球环境和安全监视&#xff08;即哥白尼计划&#xff09;系列卫星的重要组成部分&#xff0c;由Sentinel-2A和Sentinel-2B两颗卫星组成。以下是关于Sentinel-2的详细介绍&#xff1…

信息安全工程师(27)环境安全分析与防护

前言 环境安全分析与防护是一个综合性的议题&#xff0c;涉及多个方面&#xff0c;包括环境安全的概念、分析方法、存在的安全隐患以及相应的防护措施。 一、环境安全的概念 环境安全是指人类赖以生存发展的环境&#xff0c;处于一种不受污染和破坏的安全状态&#xff0c;或者说…