长短期记忆网络LSTM

news2025/1/21 5:54:46

目录

  • 一、LSTM提出的背景:
    • 1.RNN存在的问题:
    • 2.LSTM的思想:
      • 2.1回顾GRU的提出:
      • 2.2LSTM在GRU上的改进:
  • 二、遗忘门、输入门、输出门:
  • 三、LSTM网络架构:
    • 1.候选记忆单元C~t:
    • 2.遗忘门、输入门、输出门如何发挥作用:
      • 2.1记忆单元Ct:
      • 2.2隐藏状态Ht:
    • 3.LSTM:
  • 四、训练过程举例******:
  • 五、预测过程举例******:
  • 六、底层源码:
  • 七、Pytorch版代码:

一、LSTM提出的背景:

1.RNN存在的问题:

循环神经网络讲解文章

由于RNN的隐藏状态ht用于记录每个句子之前的所有序列信息,而对于长序列问题来说ht会记录太多序列信息导致序列时序特征区分度很差(最前面的序列特征因为进行了太多轮迭代往往不太好从ht中提取),并且RNN默认当前时间步的token单词和该句子的隐藏状态ht中所有序列信息都有同等的相关度,因此一些比较靠前但与当前时间步输入的token相关性高的序列特征在ht中可能就不太被重视,而一些比较靠后但与当前时间步输入的token相关性低的序列特征在ht中被过于关注。

2.LSTM的思想:

2.1回顾GRU的提出:

门控循环单元GRU讲解文章

GRU的提出就是为了解决RNN默认序列内所有token之间的相关性相等问题。
GRU的思想是对于每个时间步的输入token,使用门的控制将隐藏状态ht中与当前token相关性高的序列信息拿来参与计算,而ht中与当前token相关性低的序列信息作为噪音不参与计算。

  • 对于需要关注的序列信息,使用更新门来提高关注度
  • 对于需要遗忘的序列信息,使用遗忘门来降低关注度

2.2LSTM在GRU上的改进:

LSTM可以理解成GRU的变体,保留了重置门(遗忘门)用来对过去所有时间步的序列信息进行选择、更新门(输入门)用来对当前一个时间步的序列信息进行选择。在此基础上增加了记忆单元Ot用来保存长序列的序列特征,而Ht仅需要保存短序列的序列特征即可,解决了Ht不能很好的保存长序列信息的缺点。除此之外还增加了输出门的概念来控制Ot中分配多少个时间步的序列信息给Ht。

二、遗忘门、输入门、输出门:

在这里插入图片描述
遗忘门、输入门、输出门可以分别看做一个全连接层的隐藏层,这样的话上图就等价于三个并排的隐藏层,其中:

  • 每个隐藏层都接收之前时间步的隐藏状态Ht-1和当前时间步的输入token。
  • 遗忘门、输入门、输出门有各自的可学习权重参数和偏置值,公式含义类似传统RNN。
  • Ft、It、Ot 都是根据过去的隐藏状态 Ht-1 和当前输入 Xt 计算得到的 [0,1] 之间的量(激活函数)。

三、LSTM网络架构:

1.候选记忆单元C~t:

在这里插入图片描述
候选记忆单元的计算公式类似于RNN计算Ht的公式,用来记录当前时间步token的序列信息和前t-1个时间步的序列信息

2.遗忘门、输入门、输出门如何发挥作用:

2.1记忆单元Ct:

LSTM 3D视频讲解链接
首先LSTM在Ht的基础上加入了Ct,其中Ht仅需记录短期序列信息,Ct负责记录长期序列信息。
并且LSTM主要是对Ct更新,而GRU和RNN主要是对Ht更新。
在这里插入图片描述

在这里插入图片描述
(1)因为Ft是一个[0,1] 之间的量,所以Ft×Ct-1是对之前的长期序列信息Ot-1进行一次选择:Ft 在某个位置的值越趋近于0,则表示Ot-1这个位置的序列信息越倾向于被丢弃,反之保留。

综上,遗忘门的作用是对过去的长序列信息Ot-1进行选择,Ot-1中哪些序列信息当前的Ct是有用的,应该被保存下来,而哪些序列信息是不重要的,应该被遗忘。

(2)因为It是一个[0,1] 之间的量,如果It全为0,则当前记忆单元Ot为上一个时间步的记忆单元Ot-1;如果It全为1,则当前记忆单元Ot为上一个时间步的记忆单元Ot-1和候选记忆单元O~t(候选记忆单元记录当前时间步token的序列信息和前t-1个时间步的序列信息)的和(感觉这里Ot-1和O~t中记录的过去序列信息重复了,设计好像有冗余问题,没有GRU那么完美)。

综上,输入门的作用是决定当前一个时间步的序列信息是否保留,如果It全为1,则说明当前时间步token的序列信息是有用的(候选记忆单元记录当前时间步token的序列信息和前t-1个时间步的序列信息),保留下来加入到记忆单元Ot中;如果It全为0,则说明当前时间步token的序列信息是没有用的,丢弃当前token的序列信息,直接使用上一个时间步的记忆单元Ot-1作为当前的记忆单元Ot(记录迄今t为止长序列信息)。(Ot-1仅包含之前的长序列信息,不包含当前一个时间步的序列信息)

下图形象的展示了遗忘门和输入门的作用:
LSTM 3D视频讲解链接
在这里插入图片描述

注意GRU中的遗忘门和输入门是对Ht的修改,而LSTM中的遗忘门和输入门是对Ct的修改。

注意记忆单元Ct输出范围是[-2,2]

2.2隐藏状态Ht:

在这里插入图片描述
因为Ct是一个[-2,2] 之间的量,为了保证Ht的输出范围,所以需要取tan将Ct变为[-1,1]的范围内

因为Ot是一个[0,1]之间的量,所以输出门的主要作用是控制当前隐藏状态Ht(记录短期序列信息)的输出,即决定从记忆单元Ot(记录长期序列信息)​中传递多久的序列信息给Ht。

3.LSTM:

LSTM网络架构如下,可以看做是四个隐藏层并排的架构。

LSTM不仅循环隐藏状态Ht,还循环记忆单元Ct,其中Ct和Ht分别保存长、短期序列信息,这也是长短期记忆网络的由来。
在这里插入图片描述

四、训练过程举例******:

以下文预测问题为例,一次epoch训练过程如下。
1.对整个文本进行数据预处理,获得数据字典,这里假设字典中有vocab_size条字典序,这样就转换成了一个vocab_size分类的序列问题。
2.将每个单词token值使用独热编码转换成1×vocab_size的一维向量,作为特征,表示各分类上的概率。
3.每轮epoch输入格式为batch_num×batch_size×num_steps×vocab_size,其中batch_num表示该轮压迫训练多少个batch,batch_size表示每个batch中有多少个句子序列,每个句子有num_steps个单词token,即该batch要训练多少个时间步,即循环time_step次传统神经网络,每个单词为一个一维向量,表示在字典序上的概率。每次训练一个batch,每个时间步t使用该batch中所有batch_size个序列的第t个token集合Xt进行训练(num_steps=t的token),batch尺寸为batch_size×num_steps×vocab_size,Xt尺寸为batch_size×vocab_size
4.隐藏层参数Whh维度为num_hiddens×num_hiddens,表示隐藏层关于序列信息(隐藏状态)的权重矩阵;Whx维度为vocab_size×num_hiddens,表示隐藏层关于输入特征的权重矩阵;参数bh维度为1×num_hiddens
5.四个并行隐藏层各自的参数Whi、Whf、Who、Whc维度计算为num_hiddens×num_hiddens,表示隐藏层关于序列信息(隐藏状态)的权重矩阵;四个并行隐藏层各自的参数Wxi、Wxf、Wxo、Wxc维度计算为vocab_size×num_hiddens,表示隐藏层关于输入特征的权重矩阵;参数bi、bf、bo、bc维度计算为1×num_hiddens。这里由于四个隐藏层输出维度相同,所以隐藏内的神经元数目都是相同的=num_hiddens。
6.对于第一个batch,训练过程如下:
6.1.初始化0时刻短序列信息h0,尺寸为(batch_size,神经元个数num_hiddens)
6.2.初始化0时刻长序列信息C0,尺寸为(batch_size,神经元个数num_hiddens)
6.3.t1时间步num_steps=1,取该batch所有序列样本的第一个token组成x0,尺寸batch_size×vocab_size,每个vocab一维向量并行放入神经网络学习,首先x0中每个token和ho同时进入遗忘门隐藏层、输入门隐藏层、输出门隐藏层和候选记忆单元隐藏层,输入门隐藏层输出I1=sigmoid(Whi×h0+Wxi×x0+bi)、遗忘门隐藏层输出F1=sigmoid(Whf×h0+Wxf×x0+bf)、输出门隐藏层输出O1=sigmoid(Who×h0+Wxo×x0+bo)、候选记忆单元隐藏层C~1=sigmoid(Whc×h0+Wxc×x0+bc),四个隐藏层分别用来筛选Ct和Ht序列信息,输出维度均为batch_size×num_hiddens。
6.4.F1、I1、C~1和记忆单元C0联合计算,使用遗忘门对过去的序列信息进行筛选、使用输入门对当前的序列信息进行筛选,计算出当前时间步的记忆单元C1。
6.5.O1、当前时间步记忆单元C1联合计算,使用输出门对长序列信息C1进行筛选,计算出当前时间步的隐藏状态h1,隐藏层输出维度batch_size×num_hiddens,h1作为t1时间步的输出层输入、t2时间步的隐藏层输入序列信息(隐藏状态)。
6.6.此时两个操作并行执行:t1时间步的输出层计算、t2时间步的隐藏层计算。
6.6.1首先h1和C1作为t1时间步的输出层输入,输出层有vocab_size个神经元,会执行多分类预测,可学习参数为Woh(num_hiddens×vocab_size)和bo(1×vocab_size),每个token输出维度1×vocab_size,输出层输出维度batch_size×vocab_size,表示各个token在各个分类上的预测。
6.6.2其次,t2时间步num_steps=2,取batch中num_steps=2的token集合为x1,维度为batch_size×vocab_size,并行将每个token一维向量放入神经网络学习,隐藏层输出h2=…,记忆单元C2=…,隐藏层输出维度batch_size×num_hiddens,h2和C2作为t2时间步的输出层输入、t3时间步的隐藏层输入序列信息。
6.7.如此反复每个时间步取一个数据点token集合进行训练,并更新隐藏层输出ht和Ct作为下一个时间步的输入,直到完成所有num_steps个时间步的训练任务,整个batch就训练完成了。
6.8.对于每个时间步上的预测batch_size×vocab_size,num_steps个时间步上总的预测为(num_steps×batch_size,vocab_size),这是该batch的训练总输出。
6.9.使用损失函数计算batch中各个句子中每个token的概率损失,并取均值。
6.10.反向传播算法计算各个参数关于损失函数的梯度。
6.11.梯度裁剪修改梯度。
6.12.梯度下降算法修改参数值。
7.该batch训练完成。进行下一个batch训练,初始化隐藏状态h0、C0…。

五、预测过程举例******:

背景定义同训练过程,模型的预测过程如下。
1.输入prefix长度的前缀,来预测接下来num_preds个token。
2.首先还是将prefix转换成字典序并进行独热编码,尺寸为1×prefix×vocab_size,其中prefix=num_steps。
3.加载模型,初始化时序信息h0\、C0。
4.batch_size为1,在每个时间步上对句子长度每个token一维向量依次作为模型一个时间步的输入,输入维度1×vocab_size,总共计算prefix个时间步,循环计算prefix个时间步后的时序信息hp、Cp,hp和、Cp尺寸为1×num_hiddens(batch_size=1)。
5.将prefix最后一个token和hp、Cp作为模型输入,来预测num_preds个token的第一个token,输出预测结果pred1和时序信息hp1、Cp1,然后将pred1和hp1、Cp1作为输入预测pred2和hp2、Cp2(即使用预测值来预测下一个预测值),直到预测num_preds个预测值。(等价于batch=1,num_steps=num_preds的训练过程)
6.将预测值使用字典转为字符串输出。

六、底层源码:

代码中num_hiddens表示隐藏层神经元个数,由于遗忘门、输入门、输出门的输出维度相同,所以三个隐藏层的神经元个数也是一样的=num_hiddens。

并且除了初始化隐藏状态Ht外,还需要初始化记忆单元Ct。

import torch
from torch import nn
from d2l import torch as d2l

# 数据预处理,获取datalodaer和字典
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)


# 初始化可学习参数
def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        return (normal(
            (num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()
    W_xf, W_hf, b_f = three()
    W_xo, W_ho, b_o = three()
    W_xc, W_hc, b_c = three()
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    params = [
        W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
        W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

# 初始化隐藏状态Ht和记忆单元Ct
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

# 定义LSTM模型
def lstm(inputs, state, params):
    [
        W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
        W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)


#训练
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

七、Pytorch版代码:

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1978552.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

顶刊TPAMI 2024!无需全标注,仅用少量涂鸦标注即可获得确定和一致的语义分割预测结果...

本文介绍了山东大学,北京大学和纽约州立大学石溪分校合作开展的一项工作。该工作面向图像涂鸦弱标注语义分割任务,重点关注采用涂鸦弱标注时语义分割网络的不确定性和不一致性问题。 作者提出最小化熵损失函数和网络嵌入的随机游走过程来分别改善分割网络…

Altera之FPGA器件系列简介

目录 一、前言 二、命名规则 2.1 MAX V系列 2.2 Cyclone 系列 2.3 Arria 系列 2.4 Stratix 系列 2.5 Agilex 系列 三、器件划分 3.1 工艺制程 3.2 使用领域 四、参考 一、前言 Altera是作为FPGA领域的头部企业,是一家老牌的技术公司,成立于19…

【一图学技术】7.削峰与限流防刷技术解决方案及限流算法图解

削峰与限流防刷技术 一、削峰技术 ✈解决问题:解决流量大的问题,限制单机流量 🚀核心技术: 秒杀令牌:颁发给用户令牌,给予操作特权 秒杀大闸:限制令牌数量 队列泄洪:队列增加缓…

4_损失函数和优化器

教学视频:损失函数与反向传播_哔哩哔哩_bilibili 损失函数(Loss Function) 损失函数是衡量模型预测输出与实际目标之间差距的函数。在监督学习任务中,我们通常希望模型的预测尽可能接近真实的目标值。损失函数就是用来量化模型预…

神经网络基础--激活函数

🕹️学习目标 🕹️什么是神经网络 1.神经网络概念 2.人工神经网络 🕹️网络非线性的因素 🕹️常见的激活函数 1.sigmoid激活函数 2.tanh激活函数 3.ReLU激活函数 4.softmax激活函数 🕹️总结 &#x1f57…

计算机基础(Windows 10+Office 2016)教程 —— 第5章 文档编辑软件Word 2016(上)

第5章 文档编辑软件Word 2016 5.1 Word 2016入门5.1.1 Word 2016 简介5.1.2 Word 2016 的启动5.1.3 Word 2016 的窗口组成5.1.4 Word 2016 的视图方式5.1.5 Word 2016 的文档操作5.1.6 Word 2016 的退出 5.2 Word 2016的文本编辑5.2.1 输入文本5.2.3 插入与删除文本5.2.4 复制与…

二进制与进制转换与原码、反码、补码详解--内含许多超详细图片讲解!!!

前言 今天给大家分享一下C语言操作符的详解,但在此之前先铺垫一下二进制和进制转换与原码、反码、补码的知识点,都非常详细,也希望这篇文章能对大家有所帮助,大家多多支持呀! 操作符的内容我放在我的下一篇文章啦&am…

基于人工智能的口试模拟、LLM将彻底改变 STEM 教育

概述 STEM教育是一种整合科学(Science)、技术(Technology)、工程(Engineering)和数学(Mathematics)的教育方法。这种教育模式旨在通过跨学科的方式培养学生的创新能力、问题解决能力…

MySQL 高级 - 第十四章 | 事务基础知识

目录 第十四章 事务基础知识14.1 数据库事务概述14.1.1 存储引擎支持情况14.1.2 基本概念14.1.3 事务的 ACID 特性14.1.4 事务的状态 14.2 如何使用事务14.2.1 显示事务14.2.2 隐式事务14.2.3 隐式提交数据的情况14.2.4 使用举例14.2.4.1 提交与回滚14.2.4.2 测试不支持事务的 …

Yarn:一个快速、可靠且安全的JavaScript包管理工具

(创作不易,感谢有你,你的支持,就是我前行的最大动力,如果看完对你有帮助,还请三连支持一波哇ヾ(@^∇^@)ノ) 目录 一、Yarn简介 二、Yarn的安装 1. 使用npm安装Yarn 2. 在macOS上…

11.redis的客户端-Jedis

1.Jedis 以redis命令作为方法名称,学习成本低,简单使用。但是jedis实例是不安全的,多线程环境下需要基于连接池来使用。 2.Lettuce lettuce是基于Netty实现的,支持同步,异步和响应式编程方式,并且是线程…

EmEditor 打开文档后光标如何默认定位到文档最后一行?

1、录制宏 (1)、点击工具栏上的红色录制宏按钮,开始录制宏。如图: (2)、按住快捷键Ctrl End快捷键,使光标跳转到文档末尾 (3)、完成录制后,再次点击录制按钮…

Hive SQL ——窗口函数源码阅读

前言 使用Starrocks引擎中的窗口函数 row_number() over( )对10亿的数据集进行去重操作,BE内存溢出问题频发(忘记当时指定的BE内存上限是多少了.....),此时才意识到,开窗操作,如果使用 不当,反而…

stm32工程配置

目录 STM32F103 start:启动文件、内核寄存器文件、外设寄存器文件、时钟配置文件 library:标准库函数(内核及外设驱动) user:用户文件、库函数配置文件、中断程序文件 添加宏定义 STM32F407 start目录 启动文件…

实战:使用Certbot签发免费ssl泛域名证书(主域名及其它子域名共用同一套证书)-2024.8.4(成功测试)

1、使用Certbot签发免费ssl泛域名证书 | One实战:使用Certbot签发免费ssl泛域名证书(主域名及其它子域名共用同一套证书)-2024.8.4(成功测试)https://wiki.onedayxyy.cn/docs/docs/Certbot-install/

Transformer相关介绍

1 Transformer 介绍 Transformer的本质上是一个Encoder-Decoder的结构。 1.1 编码器 在Transformer模型中,编码器(Encoder) 的主要作用是将输入序列(例如文本、语音等)转换为隐藏表示(或者称为特征表示…

24军dui文职联勤保障部报名照规格要求

24军dui文职联勤保障部报名照规格要求 #军队文职 #文职 #文职备考 #联勤保障部队 #文职考试 #文职上岸 #2024军队文职

python-查找元素3(赛氪OJ)

[题目描述] 有n个不同的数&#xff0c;从小到大排成一列。现在告诉你其中的一个数x&#xff0c;x不一定是原先数列中的数。你需要输出最后一个<x的数在此数组中的下标。输入&#xff1a; 输入共两行第一行为两个整数n、x。第二行为n个整数&#xff0c;代表a[i]。输出&#x…

练习2.30

2.29题目没有理解,暂时没有做出来,先把2.30做了 上代码 (defn square [x](* x x)) ;第一版,直接定义 (defn square-tree[tree](cond (not (seq? tree)) (square tree)(empty? tree) nil:else (cons (square-tree (first tree)) (square-tree (rest tree)))) ) ;使用map …

LeetCode刷题笔记 | 283 | 移动零 | 双指针 |Java | 详细注释

&#x1f64b;大家好&#xff01;我是毛毛张! &#x1f308;个人首页&#xff1a; 神马都会亿点点的毛毛张 原地移除元素2 LeetCode链接&#xff1a;283. 移动零 1.题目描述 给定一个数组 nums&#xff0c;编写一个函数将所有 0 移动到数组的末尾&#xff0c;同时保持非零元…