82.长短期记忆网络(LSTM)以及代码实现

news2024/10/5 18:27:04

1. 长短期记忆网络

  • 忘记门:将值朝0减少
  • 输入门:决定不是忽略掉输入数据
  • 输出门:决定是不是使用隐状态

2. 门

在这里插入图片描述

3. 候选记忆单元

在这里插入图片描述

4. 记忆单元

在这里插入图片描述

5. 隐状态

在这里插入图片描述

6. 总结

在这里插入图片描述

7. 从零实现的代码

我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

7.1 初始化模型参数

接下来,我们需要定义和初始化模型参数。 如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差 0.01 的高斯分布初始化权重,并将偏置项设为 0 。

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

7.2 定义模型

初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

# C和H都要初始化
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元 𝐂𝑡 不直接参与输出计算

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

7.3 训练和预测

让我们通过实例化rnn_scratch中 引入的RNNModelScratch类来训练一个长短期记忆网络, 就如我们在gru中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

运行结果:

在这里插入图片描述

8. 简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

运行结果:

在这里插入图片描述

实际情况下,LSTM和GRU用哪个都可以,性能差不多。

长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。 多年来已经提出了其许多变体,例如,多层、残差连接、不同类型的正则化。 然而,由于序列的长距离依赖性,训练长短期记忆网络 和其他序列模型(例如门控循环单元)的成本是相当高的。 在后面的内容中,我们将讲述更高级的替代模型,如Transformer

9. Q&A

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/170028.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于python手撕实现BP 神经网络实现手写数字识别(不调库,附完整版本代码)

本项目使用python实现全连接网络和梯度优化 方向传播并且实现了 手写数字识别项目: 神经网络 model 先介绍个三层的神经网络,如下图所示输入层(input layer)有三个 units( 为补上的 bias,通常设为 1)

安卓影像飞升时刻:vivo X90 Pro+打通HDR任督二脉

在手机产业中,大多数人会有一种刻板印象:一项新技术/功能,苹果发布会上展示意味着已经成熟,具有很高的产品完成度,好用且有效;而安卓厂商在发布会上展示出的一些炫酷技术,往往还需要时间观望&am…

多目标建模算法PLE

1. 概述 在现如今的推荐系统或者搜索中,都存在多个目标,多目标的算法在现如今的系统中已然成为了标配。在多目标的建模过程中,如果不同的学习任务之间较为相关时,多个任务之间可以共享一部分的信息,这样最终能够提升整…

Vue7-el和data的两种写法

1.el的两种写法 1创建Vue实例的时候通过el指定属性 2. 创建Vue实例之后,通过vm.$mount(#demo)进行挂载 console.log(v):此处的v是Vue的实例对象 在往下看__proto__属性,这里是Vue构造类的方法,其中的方法vue实例都可以使用,比如$…

Spring cache整合Redis详解 动态设置失效时间

文章目录1.spring cache简介2.spring cache集成redis3.spring cache与redisTemple统一格式4.SpEL标签5.Cacheable注解实现6.CachePut注解实现7.CacheEvict注解实现8.Caching注解实现9.自定义key生成器KeyGenerator10.自定义前缀CacheKeyPrefix11.多个CacheManager实现不同失效时…

【微信小程序】收藏功能的实现(条件渲染、交互反馈)

🏆今日学习目标:第十九期——收藏功能的实现(条件渲染、交互反馈) 😃创作者:颜颜yan_ ✨个人主页:颜颜yan_的个人主页 ⏰预计时间:35分钟 🎉专栏系列:我的第一个微信小程序 文章目录…

django框架

目录简介MVC与MTV模型MVCMTV创建项目目录生命周期静态文件配置(无用)启动django[启动](https://www.cnblogs.com/xiaoyuanqujing/articles/11902303.html)路由分组无名分组有名分组路由分发反向解析反向解析结合分组名称空间re_path与path自定义转换器视…

vue3项目怎么写路由 + 浅析vue-router4源码

在SPA项目里,路由router基本是前端侧处理的,那么vue3项目中一般会怎么去写router呢,本文就来讲讲vue-router4的一些常用写法,以及和Composition API的结合使用,同时简单讲讲实现原理,让你轻松理解前端route…

【04】FreeRTOS的任务挂起与恢复

目录 1.任务的挂起与恢复的API函数 1.1任务挂起函数介绍 1.2任务恢复函数介绍(任务中恢复) 1.3任务恢复函数介绍(中断中恢复) 2.任务挂起与恢复实验 3.任务挂起和恢复API函数“内部实现”解析 3.1vTaskSuspend() 3.2&#…

Prometheus基础

一、何为Prometheus Prometheus受启发于Google的Brogmon监控系统(相似的Kubernetes是从Google的Brog系统演变而来),从2012年开始由前Google工程师在Soundcloud以开源软件的形式进行研发,并且于2015年早期对外发布早期版本。2016年…

【基础】Netty 的基础概念及使用

Netty基本概念理解阻塞与非阻塞同步与异步BIO 与 NIOReactor 模型Netty 基本概念Netty 的执行流程Netty 的模块组件Netty 工作原理Netty 的基本使用Netty ServerNetty Client参考文章基本概念理解 阻塞与非阻塞 阻塞与非阻塞是进程访问数据时的处理方式,根据数据是…

系分 - 案例分析 - 系统维护与设计模式

个人总结,仅供参考,欢迎加好友一起讨论 文章目录系分 - 案例分析 - 系统维护与设计模式典型例题 1题目描述参考答案典型例题 2题目描述参考答案系分 - 案例分析 - 系统维护与设计模式 典型例题 1 题目描述 某企业两年前自主研发的消防集中控制软件系统…

05-requests添加Cookies与正则表达式

第5讲 requests添加Cookies与正则表达式 整体课程知识点查看 :https://blog.csdn.net/j1451284189/article/details/128713764 本讲总结 request代理使用 request SSL request添加Cookies 数据解析方法简介 数据解析:正则表达式讲解 一、requests 代理 …

【23种设计模式】学习汇总(未完结+思维导图)

获取思维导图翻至底部底部,基本概览博客内容(暂未完全完善,期待你的持续关注) 写作不易,如果您觉得写的不错,欢迎给博主来一波点赞、收藏~让博主更有动力吧! 一.相关内容 在软件工程中&#xf…

关系型数据库RDBMS | 字节青训营笔记

一、经典案例 1、红包雨案例 每年春节,抖音都会有红包雨获得 2、事务 事务(Transaction): 是由一组SQL语句组成的一个程序执行单元(Unit),它需要满足ACID特性 BEGIN; UPDATE account table SET balance balance - 小目标 WHERE name “抖音; UPDATE…

指数加权平均、动量梯度下降法

目录1.指数加权平均(exponentially weighted averages)这里有一年的温度数据。如果想计算温度的趋势,也就是局部平均值(local average),或者说移动平均值(moving average),怎么做?:当天的温度,:…

交换机的基本原理(特别是动态ARP、静态ARP、代理ARP)

第六章:交换机的基本配置 二层交换设备工作在OSI模型的第二层,即数据链路层,它对数据包的转发是建立在MAC(Media Access Control )地址基础之上的。二层交换设备不同的接口发送和接收数据独立,各接口属于不…

esxi宿主机进入维护模式虚拟机不会自动释放【不会自动迁移出去】解决方法、查看辨别宿主机本地空间和存储池、esxi进入存储内部清理空间

文章目录说明虚拟机不自动释放处理过程报错说明宿主机进入维护模式说明手动迁移报错说明直接启动虚拟机报错说明解决方法报错原因分析解决方法查看辨别宿主机本地空间esxi进入存储内部清理空间进入存储池内存储内部空间清理及原则存储空间说明说明 我当前的esxi主机版本为5.5 …

7亿人养活的眼镜行业,容不下一家县城小店

文|螳螂观察 作者| 青月 如果要盘点那些被暴利眷顾的行业,眼镜零售肯定榜上有名。 从上市企业的财报数据来看,国内眼镜零售行业的首家上市公司——博士眼镜,2021年前三季度的平均毛利率超过60%;国内镜片第一股明月眼镜在2021年…

【C进阶】文件操作

⭐博客主页:️CS semi主页 ⭐欢迎关注:点赞收藏留言 ⭐系列专栏:C语言进阶 ⭐代码仓库:C Advanced 家人们更新不易,你们的点赞和关注对我而言十分重要,友友们麻烦多多点赞+关注,你们…