长短期记忆(LSTM)与RNN的比较:突破性的序列训练技术

news2025/1/11 8:15:28

长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

Why

LSTM提出的动机是为了解决「长期依赖问题」

长期依赖(Long Term Dependencies)

在深度学习领域中(尤其是RNN),“长期依赖“问题是普遍存在的。长期依赖产生的原因是当神经网络的节点经过许多阶段的计算后,之前比较长的时间片的特征已经被覆盖,例如下面例子

eg1: The cat, which already ate a bunch of food, was full.
      |   |     |      |     |  |   |   |   |     |   |
     t0  t1    t2      t3    t4 t5  t6  t7  t8    t9 t10
eg2: The cats, which already ate a bunch of food, were full.
      |   |      |      |     |  |   |   |   |     |    |
     t0  t1     t2     t3    t4 t5  t6  t7  t8    t9   t10

我们想预测'full'之前系动词的单复数情况,显然full是取决于第二个单词’cat‘的单复数情况,而非其前面的单词food。根据RNN的结构,随着数据时间片的增加,RNN丧失了学习连接如此远的信息的能力。

alt

LSTM vs. RNN

alt

相比RNN只有一个传递状态 ,LSTM有两个传输状态,一个 (cell state),和一个 (hidden state)。(Tips:RNN中的 对于LSTM中的

其中对于传递下去的 改变得很慢,通常输出的 是上一个状态传过来的 加上一些数值。

则在不同节点下往往会有很大的区别。

Model 详解

状态计算

首先使用LSTM的当前输入 和上一个状态传递下来的 拼接训练得到四个状态。

alt

其中, 是由拼接向量乘以权重矩阵之后,再通过一个 激活函数转换成0到1之间的数值,来作为一种门控状态。而 则是将结果通过一个 激活函数将转换成-1到1之间的值(这里使用 是因为这里是将其做为输入数据,而不是门控信号)。

计算过程

alt

⊙ 是Hadamard Product,也就是操作矩阵中对应的元素相乘,因此要求两个相乘矩阵是同型的。 ⊕ 则代表进行矩阵加法。

LSTM内部主要有三个阶段:

  1. 「忘记阶段」。这个阶段主要是对上一个节点传进来的输入进行 「选择性」忘记。简单来说就是会 “忘记不重要的,记住重要的”。

具体来说是通过计算得到的 (f表示forget)来作为忘记门控,来控制上一个状态的 哪些需要留哪些需要忘。

  1. 「选择记忆阶段」。这个阶段将这个阶段的输入有选择性地进行“记忆”。主要是会对输入 进行选择记忆。哪些重要则着重记录下来,哪些不重要,则少记一些。当前的输入内容由前面计算得到的 表示。而选择的门控信号则是由 (i代表information)来进行控制。

将上面两步得到的结果相加,即可得到传输给下一个状态的 。也就是上图中的第一个公式。

  1. 「输出阶段」。这个阶段将决定哪些将会被当成当前状态的输出。主要是通过 来进行控制的。并且还对上一阶段得到的 进行了放缩(通过一个tanh激活函数进行变化)。

与普通RNN类似,输出 往往最终也是通过 变化得到。

Code

现在,我们从零开始实现长短期记忆网络。 与 8.5节中的实验相同, 我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 3235
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
  • 初始化模型参数

定义和初始化模型参数。 如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差0.01的高斯分布初始化权重,并将偏置项设为0。

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
  • 定义模型
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)
  • 训练和预测
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 5001
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

# perplexity 1.3, 17736.0 tokens/sec on cuda:0
# time traveller for so it will leong go it we melenot ir cove i s
# traveller care be can so i ngrecpely as along the time dime
alt

总结

  • 长短期记忆网络有三种类型的门:输入门、遗忘门和输出门。
  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。
  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

Ref

  1. https://zhuanlan.zhihu.com/p/32085405
  2. https://zhuanlan.zhihu.com/p/42717426
  3. https://zh.d2l.ai/chapter_recurrent-modern/lstm.html

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1220861.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows使用lcx端口转发登陆远程主机

1.编译lcx源码: GitHub - UndefinedIdentifier/LCX: 自修改免杀lcx端口转发工具 2.在win7上安装vs2010并编译生成lcx.exe 3.在要被控制主机上运行: lcx -slave 192.168.31.248 51 192.168.31.211 3389 192.168.31.248为远程主控制主机,51为远程主机端口 192.168.31.211为被…

ZYNQ7000---FLASH读写

提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、Flash是什么?二、Flash的分类1、内部结构(接口)区分:2、外部接口区分:SPIQPSI Flash: QSPI 控制…

Leetcode刷题详解——猜数字大小 II

1. 题目链接:375. 猜数字大小 II 2. 题目描述: 我们正在玩一个猜数游戏,游戏规则如下: 我从 1 到 n 之间选择一个数字。你来猜我选了哪个数字。如果你猜到正确的数字,就会 赢得游戏 。如果你猜错了,那么我…

分发糖果(贪心算法)

题目描述 n 个孩子站成一排。给你一个整数数组 ratings 表示每个孩子的评分。 你需要按照以下要求,给这些孩子分发糖果: 每个孩子至少分配到 1 个糖果。相邻两个孩子评分更高的孩子会获得更多的糖果。 请你给每个孩子分发糖果,计算并返回…

操作系统:输入输出管理(二)磁盘调度算法

一战成硕 5.3 磁盘固态硬盘5.3.1 磁盘5.3.2 磁盘的管理5.3.3 磁盘调度算法 5.3 磁盘固态硬盘 5.3.1 磁盘 磁盘是表面涂有磁性物质的物理盘片,通过一个称为磁头的导体线圈从磁盘存取数据。在读写操作中,磁头固定,磁盘在下面高速旋转。磁盘盘…

Linux安装DMETL5与卸载

Linux安装DMETL5与卸载 环境介绍1 DM8数据库配置1.1 DM8数据库安装1.2 初始化达梦数据库1.3 创建DMETL使用的数据库用户 2 配置DMETL52.1 解压DMETL5安装包2.2 安装调度器2.3 安装执行器2.4 安装管理器2.5 启动dmetl5 调度器2.6 启动dmetl5 执行器2.7 启动dmetl5 管理器2.8 查看…

LangGPT作者教你编写高质量提示词

CoT和ToT能够提升表现,但是会使得模型的使用变复杂。在对话的场景下容易消耗人的耐心;实际应用的场景下,比较消耗人的token。 还有一点需要说明的是,我们在写自己的prompt的时候,不应该盲目地追求和堆砌提示词技巧&am…

Unity 预制体放在场景中可见,通过代码复制出来不可见的处理

首先我制作了一个预制体,在场景中是可见的,如下图 无论是Scene视图,还是Game视图都正常。 我把预制体放到Resources里面,然后我通过如下代码复制到同个父物体下。 GameObject obj1 Instantiate(Resources.Load("Butcon&quo…

Django+vue前后端分离实战--vue后台管理系统--vue环境安装项目创建

Djangovue前后端分离实战--vue后台管理系统 安装nodejsvue clivue-cli创建项目 安装nodejsvue cli 1、下载nodejs并安装 https://nodejs.org/dist/v20.9.0/node-v20.9.0-x64.msi 2、修改npm 默认仓库地址,要修改成taobao的镜像npm 仓库地址 cmd下命令&#xff1a…

sql注入 [极客大挑战 2019]LoveSQL 1

打开题目 几次尝试,发现输1 1",页面都会回显NO,Wrong username password!!! 只有输入1,页面报错,说明是单引号的字符型注入 那我们万能密码试试能不能登录 1 or 11 # 成功登录 得到账号…

【Redis】springboot整合redis(模拟短信注册)

要保证redis的服务器处于打开状态 上一篇: 基于session的模拟短信注册 https://blog.csdn.net/m0_67930426/article/details/134420531 整个流程是,前端点击获取验证码这个按钮,后端拿到这个请求,通过RandomUtil 工具类的方法生…

微服务实战系列之Token

前言 什么是“Token”? 它是服务端生成的一串字符串,以作客户端进行请求的一个令牌,当第一次登录后,服务器生成一个Token便返回给客户端;以后客户端只携带此Token请求数据即可。 简言之,Token其实就是用户身…

VBA_MF系列技术资料1-222

MF系列VBA技术资料 为了让广大学员在VBA编程中有切实可行的思路及有效的提高自己的编程技巧,我参考大量的资料,并结合自己的经验总结了这份MF系列VBA技术综合资料,而且开放源码(MF04除外),其中MF01-04属于定…

Rxswift(1)

基础用法 数据绑定核心Observerable 可监听序列 数据绑定 平常的写法 let image: UIImage UIImage(named: ...) imageView.image image绑定的写法 //可监听序列 let image: Observable<UIImage> ... //imageView.rx.image 观察者 image.bind(to: imageView.rx.image…

python爬取穷游网景点评论

爬取穷游网的景点评论数据&#xff0c;使用selenium爬取edge浏览器的网页文本数据。 同程的评论数据还是比较好爬取&#xff0c;不像大众点评需要你登录验证杂七杂八的&#xff0c;只需要找准你想要爬取的网页链接就能拿到想要的文本数据。 这里就不得不提一下爬取过程中遇到的…

echarts双轴右边的轴刻度不显示

图表单轴的时候&#xff0c;yAxis 和 series 是一个对象&#xff0c;但是当双轴显示的时候&#xff0c;yAxis 和 series 就都是一个数组里面包含两个对象&#xff0c;如果是多轴&#xff0c;就是多个对象 看下代码&#xff0c;关键代码 yAxisIndex: 1, 多轴的时候需要指定ind…

【2023春李宏毅机器学习】快速了解机器学习基本原理

文章目录 机器学习约等于机器自动找一个函数 机器学习分类 regression&#xff1a;输出为连续值classification&#xff1a;输出为一个类别structured learning&#xff1a;又叫生成式学习generative learning 生成有结构的物件&#xff08;如&#xff1a;影像、句子&#xf…

matplotlib 绘制双纵坐标轴图像

效果图&#xff1a; 代码&#xff1a; 由于使用了两组y axis&#xff0c;如果直接使用ax.legend绘制图例&#xff0c;会得到两个图例。而下面的代码将两个图例合并显示。 import matplotlib.pyplot as plt import numpy as npdata np.random.randint(low0,high5,size(3,4)) …

【Unity】XML文件的解析和生成

目录 使用XPath路径语法解析 使用xml语法解析 XML文件的生成 XML文件是一种常用的数据交换格式&#xff0c;它以文本形式存储数据&#xff0c;并使用标签来描述数据。解析和生成XML文件是软件开发中常见的任务。 解析XML文件是指从XML文件中读取数据的过程。在.NET中&#…

【入门篇】1.2 Redis 客户端之 Jedis 详解和示例

文章目录 1. 简介2. Jedis的依赖下载Jedis导入Jedis jar包配置Redis服务器的地址和端口 3. Jedis 的基本操作连接 Redis 服务器设置和获取字符串类型的键值对判断键是否存在删除键设置键的过期时间 4. Jedis 的数据类型操作字符串类型列表类型集合类型哈希类型有序集合类型 5. …