【动手学深度学习】--长短期记忆网络LSTM

news2025/1/11 17:56:02

文章目录

  • 长短期记忆网络LSTM
    • 1.门控记忆元
      • 1.1输入门、忘记门、输出门
      • 1.2候选记忆元
      • 1.3记忆元
      • 1.4隐状态
    • 2.从零实现
      • 2.1加载数据集
      • 2.2初始化模型参数
      • 2.3定义模型
      • 2.4 训练与预测
    • 3.简洁实现

长短期记忆网络LSTM

学习视频:长短期记忆网络(LSTM)【动手学深度学习v2】

官方笔记:长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题,解决这一问题的最早方法之一是长短期存储器(LSTM),它有许多与GRU一样的属性,有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年。

1.门控记忆元

可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。 有些文献认为记忆元是隐状态的一种特殊类型, 它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。 为了控制记忆元,我们需要许多门。 其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。 另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。 我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。

1.1输入门、忘记门、输出门

  • 忘记门:将值朝0减少
  • 输入门:决定不是忽略掉输入数据
  • 输出门:决定是不是使用隐状态

就如在门控循环单元中一样, 当前时间步的输入和前一个时间步的隐状态 作为数据送入长短期记忆网络的门中,如下图所示, 它们由三个具有sigmoid激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在(0,1)的范围内。

image-20230909153058460

1.2候选记忆元

image-20230909153136156

1.3记忆元

image-20230909153206499

1.4隐状态

image-20230909153227773

2.从零实现

2.1加载数据集

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

2.2初始化模型参数

接下来,我们需要定义和初始化模型参数。 如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差0.01的高斯分布初始化权重,并将偏置项设为0。

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

2.3定义模型

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元 C t C_t Ct不直接参与输出计算。

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

2.4 训练与预测

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

image-20230909153543972

3.简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

image-20230909153613692

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/993378.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DP-modeler建模

1、打开软件,新建工程,导入模型,如下: 2、建立一个立体模型,结果如下图:

jmeter安装了插件,但是添加时无插件选项

想用阶梯加压,然后需要安装插件,按照网上教程,下载插件管理器,使用插件管理器安装好jpgc后(如图一,已勾选,说明已安装), 然后重启打开jmeter,添加线程组下一级…

python知识:有效使用property装饰器

一、说明 Python是唯一有习语(idioms)的语言。这增强了它的可读性,也许还有它的美感。装饰师遵循Python的禅宗,又名“Pythonic”方式。装饰器从 Python 2.2 开始可用。PEP318增强了它们。下面是一个以初学者为中心的教程&#xff…

Jdk1.7之ConcurrentHashMap源码总结

文章目录 一、常见属性1. 初始化容量2. 加载因子3. 并发级别 二、重要方法1. 构造方法2. ConcurrentHashMap#put方法2.1 ConcurrentHashMap#put#ensureSegment2.2 ConcurrentHashMap#Segment#put2.2.1 Segment#put#scanAndLockForPut2.2.2 Segment#put#rehash 3. ConcurrentHas…

linux内核如何根据文件名索引到文件内容

https://zhuanlan.zhihu.com/p/78724124 根据文件名索引到文件内容 表面上,用户通过文件名,打开文件。实际上,系统内部这个过程分成三步:首先,系统找到这个文件名对应的inode号码;其次,通过in…

迅为RK3568开发板驱动指南第六篇-平台总线

文档教程更新至第六篇 第1篇 驱动基础篇 第2篇 字符设备基础 第3篇 并发与竞争 第4篇 高级字符设备进阶 第5篇 中断 第6篇 平台总线 未完待续,持续更新中... 视频教程更新至十一期 第一期_驱动基础 第二期_字符设备基础 第三期_并发与竞争 第四期_高级字…

解释模块化开发及其优势,并介绍常用的模块化规范。

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 模块化开发⭐ 模块化开发的优势⭐ 常用的模块化规范⭐ 写在最后 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅!这个专栏是…

公众号hanniman往期精选

目录 一、AI产品分析(10篇) 二、AI产品经理(10篇) 三、AI技术(10篇) 四、AI行业及个人成长(10篇) 一、AI产品分析 1、【重点】深度 | 关于AIGC商业化的13个非共识认知(80…

华为OD机试 - 滑动窗口最大和

滑动窗口的经典题型&#xff0c;重复题目 #include <stdio.h> #include <string.h> #include <stdlib.h>#define MAX(a,b) ((a) > (b) ? (a) : (b)) int main() {int n;scanf("%d", &n);int *list malloc(sizeof(int) * n);for (int i …

[学习笔记]DeepWalk图神经网络论文精读

参考资料&#xff1a;DeepWalk【图神经网络论文精读】 word2vec 相关论文&#xff1a; Efficient Estimation of Word Representations in Vector Space Distributed Representations of Words and Phrases and their Compositionality 随机游走Ramdom Walk简述 通过随机游…

LLMs之Falcon 180B:Falcon 180B的简介、安装、使用方法之详细攻略

LLMs之Falcon 180B&#xff1a;Falcon 180B的简介、安装、使用方法之详细攻略 导读&#xff1a;Falcon-180B是一个由TII发布的模型&#xff0c;它是Falcon系列的升级版本&#xff0c;是一个参数规模庞大、性能优越的开放语言模型&#xff0c;适用于各种自然语言处理任务&#x…

Net跨平台UI框架Avalonia入门-样式详解

设计器的使用 设计器预览 在window和usercontrol中&#xff0c;在代码中修改了控件&#xff0c;代码正确情况下&#xff0c;设计器中就可以实时看到变化&#xff0c;但是在样式&#xff08;Styles&#xff09;文件中&#xff0c;无法直接看到&#xff0c;需要使用设计器预览D…

uni-app 前面项目(vue)部署到本地win系统Nginx上

若依移动端的项目&#xff1a;整合了uview开源ui框架&#xff0c; 配置后端请求接口基本路径地址&#xff1a; 打包复现到nginx下&#xff1a; 安装个稳定版本的&#xff1a;nginx-1.24.0 部署配置&#xff1a; 增加了网站&#xff1a;8083端口的&#xff0c; 网站目录在ngi…

Reactor

1.epoll底层工作原理 creat: 红黑树 就绪队列 回调机制 control: 用户告诉内核做什么事情&#xff0c;就是操作红黑树 wait: 操作就绪队列 2.LT ET模式 3.Reactor 4.前摄式

懒汉式逆向APK

作者&#xff1a;果然翁 通过各方神仙文档,以及多天调试,整理了这篇极简反编译apk的文档(没几个字,吧).轻轻松松对一个apk(没壳的)进行逆向分析以及调试.其实主要就是4个命令. 准备 下载apktool &#xff1a;https://ibotpeaches.github.io/Apktool/install/下载Android SDK …

被问到: http 协议和 https 协议的区别怎么办?别慌,这篇文章给你答案

前言 作为软件测试师&#xff0c;大家都知道一些常用的网络协议是我们必须要了解和掌握的&#xff0c;比如 HTTP 协议&#xff0c;HTTPS 协议就是两个使用非常广泛的协议&#xff0c;所以也是面试官问的面试的时候问的比较多的两个协议&#xff1b;因为这两个协议有相似和关联的…

CSP 202209-1 如此编码

答题 题目就是字多 #include<iostream>using namespace std;int main() {int n,m;cin>>n>>m;int a[n],c[n1];c[0]1;for(int i0;i<n;i){cin>>a[i];c[i1]c[i]*a[i];}for(int i0;i<n;i){cout<<(m%c[i1]-m%c[i])/c[i]<< ;} }

基于 Flink CDC 高效构建入湖通道

摘要&#xff1a;本文整理自阿里云 Flink 数据通道负责人、Flink CDC 开源社区负责人&#xff0c; Apache Flink PMC Member & Committer 徐榜江&#xff08;雪尽&#xff09;&#xff0c;在 Streaming Lakehouse Meetup 的分享。内容主要分为四个部分&#xff1a; 1. Flin…

2.4 选择结构语句

选择结构语句根据是否满足某个条件确定执行哪些操作&#xff0c;分为if条件语句和switch条件语句。 1. 单分支if语句 &#xff08;1&#xff09;if语句是指如果满足某种条件&#xff0c;就进行某种处理&#xff0c;格式如下。 if(条件) {// 执行语句 } 根据上述格式中&…

有限状态机的概念

一、有限状态机的概念 有限状态机简称状态机&#xff0c;是表示有限个状态&#xff0c;以及在状态之间的转移和动作等行为的数学模型。状态机的要素有状态和状态转移两个。 在Unity中&#xff0c;动画状态机最重要的属性就是节点和连线&#xff0c;其中每个节点都是一个动画片…