长短期记忆网络(LSTM)

news2025/1/13 16:50:18
  • 长短期记忆网络有三种类型的门:输入门、遗忘门和输出门。

  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。

  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

  • 由于序列的长距离依赖性,训练长短期记忆网络 和其他序列模型(例如门控循环单元)的成本是相当高的

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。 解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM) (Hochreiter and Schmidhuber, 1997)。 它有许多与门控循环单元(门控循环单元(GRU)_流萤数点的博客-CSDN博客)一样的属性。 有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年。

目录

1.门控记忆元

1.1输入门、忘记门和输出门

1.2候选记忆元 

1.3记忆元 

1.4隐状态 

​2.从零开始实现

2.1初始化模型参数

2.2定义模型

2.3训练和预测

3.简洁实现


1.门控记忆元

可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。 有些文献认为记忆元是隐状态的一种特殊类型, 它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。 为了控制记忆元,我们需要许多门。 其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。 另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。 我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。 让我们看看这在实践中是如何运作的。

1.1输入门、忘记门和输出门

就如在门控循环单元中一样, 当前时间步的输入和前一个时间步的隐状态 作为数据送入长短期记忆网络的门中, 如 图9.2.1所示。 它们由三个具有sigmoid激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在(0,1)的范围内。

1.2候选记忆元 

1.3记忆元 

1.4隐状态 

2.从零开始实现

 现在,我们从零开始实现长短期记忆网络。 与 8.5节中的实验相同, 我们首先加载时光机器数据集。

pip install mxnet==1.7.0.post1
pip install d2l==0.15.0
from mxnet import np, npx
from mxnet.gluon import rnn
from d2l import mxnet as d2l

npx.set_np()

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

2.1初始化模型参数

接下来,我们需要定义和初始化模型参数。 如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差0.01的高斯分布初始化权重,并将偏置项设为0。

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return np.random.normal(scale=0.01, size=shape, ctx=device)

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                np.zeros(num_hiddens, ctx=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = np.zeros(num_outputs, ctx=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.attach_grad()
    return params

2.2定义模型

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):
    return (np.zeros((batch_size, num_hiddens), ctx=device),
            np.zeros((batch_size, num_hiddens), ctx=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元Ct不直接参与输出计算。

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = npx.sigmoid(np.dot(X, W_xi) + np.dot(H, W_hi) + b_i)
        F = npx.sigmoid(np.dot(X, W_xf) + np.dot(H, W_hf) + b_f)
        O = npx.sigmoid(np.dot(X, W_xo) + np.dot(H, W_ho) + b_o)
        C_tilda = np.tanh(np.dot(X, W_xc) + np.dot(H, W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * np.tanh(C)
        Y = np.dot(H, W_hq) + b_q
        outputs.append(Y)
    return np.concatenate(outputs, axis=0), (H, C)

2.3训练和预测

让我们通过实例化 8.5节中 引入的RNNModelScratch类来训练一个长短期记忆网络, 就如我们在 9.1节中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

3.简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

lstm_layer = rnn.LSTM(num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

 长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。 多年来已经提出了其许多变体,例如,多层、残差连接、不同类型的正则化。 然而,由于序列的长距离依赖性,训练长短期记忆网络 和其他序列模型(例如门控循环单元)的成本是相当高的。 在后面的内容中,我们将讲述更高级的替代模型,如Transformer。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/124873.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

27移除元素--双指针(快慢指针)

27移除元素–双指针(快慢指针) 移除元素这道题看起来很简单,但其蕴含的快慢指针的思想十分重要。 双for循环(暴力法)-- O(n2n^2n2) 使用第1个for循环 i 遍历数组所有元素 使用第2个for循环从 i 开始进行数组元素的前移…

骨传导耳机伤耳朵吗、骨传导耳机适合适用的人群有哪些?

事实上,骨传导耳机是对耳朵最健康的一种耳机了,下面就来详细说说这种耳机。 骨传导耳机是以人的骨骼为介质,不经过外耳道和耳膜,将声音传递给听觉器官的耳机。他对人的耳朵损害相比起传统的耳机损害更小,因为听力受损…

JavaSE笔记——Lambda表达式

文章目录前言一、第一个Lambda表达式二、如何辨别Lambda表达式三、引用值,而不是变量四、函数接口五、类型推断总结前言 Java 8 的最大变化是引入了 Lambda 表达式——一种紧凑的、传递行为的方式。 一、第一个Lambda表达式 Swing 是一个与平台无关的 Java 类库&a…

redhat7.6+grid 11.2.0.4部署遇到各种问题

一、add cluster node时,卡住 两个节点时间不同步,设置时间同步即可 二、部署Redhat7.6oracle11g部署中的bug Oracle 11.2.0.4 部署rac过程中,需要运行root.sh脚本报错。提示: ohasd集群无法启动。该补丁修改ohasd无法启动的问题…

红外成像系统测试

通常人们把红外辐射称为红外光、红外线。实际上其波段是指其波长约在0.75μm到1000μm的电磁波。人们将其划分为近、中、远红外三部分。近红外指波长为0.75-3.0μm;中红外指波长为3.0-20μm;远红外则指波长为20-1000μm。由于大气对红外辐射的吸收,只留下三个重要的“窗口”…

一把巴枪,和被改变的菜鸟驿站站长们

成立9年的菜鸟物流一直在答题。如果说之前这张答卷更多的标签是面向物流前端的配送和分拣等,那么如今,它的更多答案已经不单纯是前端的流通和连接,更有最末端基于科技对人的温度和赋能。 作者|丰兰 出品|产业家 数字化,正在…

少儿Python每日一题(6):角谷猜想

原题解答 本次的题目如下所示(原题出处:NOC): 角谷猜想:以一个正整数n为例,如果n为偶数,就将它变为n/2;如果除后变成奇数,则将它乘3加1(即3n1)。…

latex常用语法速查

本文针对overleaf在线使用latex的情况编写。 文章目录文档结构要点导入图片使用表格添加引用参考资料文档结构 文档类型设置 \documentclass[12pt,article]{book} % []中设置文档格式,文档字体大小默认为10pt,article指定文档用纸类型,其他…

【金猿人物展】龙盈智达首席数据科学家王彦博:量子科技为AI大数据创新发展注入新动能...

‍王彦博本文由龙盈智达首席数据科学家王彦博撰写并投递参与“数据猿年度金猿策划活动——2022大数据产业趋势人物榜单及奖项”评选。‍数据智能产业创新服务媒体——聚焦数智 改变商业回顾2022年大数据行业发展,令人感触最深的是数字经济时代对“数据安全”和“数…

基于自主可控的新型基础测绘与实景三维中国建设

实景三维中国作为强赋能、稳基底、重应用的新型基础设施,是打造数字中国、数字经济、数字政府的核心资源,其关键技术的掌握已经成为撬动社会生产,促进行业良性内循环,引发国家数字资源合理分配的重中之重。 ▲实景三维工程技术研究…

小程序入门01

目录 1.什么是小程序 2.小程序可以干什么? 3.相关资料 4.入门 4.1 申请账号 4.2 安装第一个小程序 4.3 了解程序 1.什么是小程序 2017年度百度百科十大热词之一 微信小程序(wei xin xiao cheng xu),简称小程序,英文…

FreeSWITCH在视频会议中的实践经验

点击上方“LiveVideoStack”关注我们▲扫描图中二维码或点击阅读原文▲了解音视频技术大会更多信息// 编者按:视频会议已成为日常办公不可或缺的一部分,为远程交流的人们提供了许多便利。本次RTSCon 2022会议,由RTS社区和LiveVideoStack音视…

Win10的两个实用技巧系列之设置鼠标指针、红警玩不了怎么办?

win10系统怎么设置鼠标指针在打字时隐藏? win10隐藏鼠标指针的方法 win10系统怎么设置鼠标指针在打字时隐藏?win10系统输入文字的时候,想要隐藏鼠标指针,该怎么操作呢?下面我们就来看看win10隐藏鼠标指针的方法 win10如何隐藏鼠…

Android中的属性动画

在属性动画出来之前,Android系统提供的动画只有帧动画和View动画。View动画大家可能知道,它提供了AlphaAnimation(透明度),RotateAnimation(负责旋转),TranslateAnimation(负责移动),ScaleAnimation(负责缩放)这4种动画…

2022年广西建筑安全员考试真题题库及答案

百分百题库提供建筑安全员考试试题、安全员证考试真题、安全员证考试题库等,提供在线做题刷题,在线模拟考试,助你考试轻松过关。 100.《中华人民共和国建筑法》规定,建设单位申请领取施工许可证,应当具备下列条件有() A.已经办理该建筑工程用地批准手续…

手绘图说电子元器件-晶体管

晶体二极管与单结晶体管 晶体二极管是电子电路中最重要的半导体器件,包括一般二极管和特殊二极管两大类。 晶体二极管 晶体二极管简称二极管,是一种常用的具有一个PN结的半导体器件。 晶体二极管的极性 晶体二极管两引脚有正、负极之分 晶体二极管的参数 晶体二极管的…

HOW POWERFUL ARE GRAPH NEURAL NETWORKS? 论文/GIN学习笔记

对GNN的评估 GNN 通用表达式 聚合: av(k)AGGREAGTE(k)({hu(k−1):u∈N(v)})a_v^{(k)}AGGREAGTE^{(k)}(\{ h_u^{(k-1)} : u \in \mathcal{N}(v) \}) av(k)​AGGREAGTE(k)({hu(k−1)​:u∈N(v)}) 更新: hv(k)COMBINE(k)(hv(k−1),av(k))h_v^{(k)} COMB…

【JavaSE成神之路】流程控制语句

哈喽,我是兔哥呀,今天就让我们继续这个JavaSE成神之路! 这一节啊,咱们要学习的内容是流程控制语句。 先来看概念 Java的流程控制语句是指用来控制程序执行流程的语句,它们可以改变程序的执行顺序,使程序更…

javaee之SpringMVC1

三层架构与MVC设计模式介绍 一张图介绍 之前在写软件设计的三层架构的时候,有一张图直接拿过来 springMVC的一些简单介绍 入门案例 一、入门案例之需求分析 二、搭建开发环境 1.利用骨架创建一个maven项目 因为这个项目要部署到服务器,所以采用如下骨…

代码质量与安全 | 如何将清洁代码标准扩展到整个企业,促进业务上的成功?

清洁代码能够让软件开发工作变得更简单、更有趣。因为如果代码不够清洁,开发人员将花费很多时间在解决编码问题上,使他们无法将精力投入开发新代码、解决其他更有趣的问题上。 那么,该如何将清洁代码标准扩展到整个企业呢?阅读本…