深度学习——长短期记忆网络LSTM(笔记)

news2024/11/27 4:34:25

 

长短期记忆网络LSTM:

①隐变量模型存在长期信息保存和短期输入缺失问题,解决方法是LSTM

②发明于90年代

③使用效果和GRU差别不大,但是实现起来复杂

1.长短期记忆网络

①忘记门Ft:将值朝0减少

②输入门It:是否忽略输入数据

③输出门Ot:是否使用隐状态

2.门

 

 类似于GRU,当前时间步的输入和前一个时间步的隐状态作为数据送入LSTM中。由三个具有sigmoid激活函数的全连接层处理,计算输入门,遗忘门,输出门的值(三个门的值在0~1)

3.候选记忆单元

 

 使用tanh函数作为激活函数,函数值在-1~1之间

4.记忆单元

 

 

①在LSTM中,通过输入门和遗忘门控制输入和遗忘:输入门lt控制采用多少来自Ct的新数据,遗忘门控制保留过去多少记忆元C(t-1)的内容

②遗忘门是1输入门是0,过去的记忆元C(t-1)传递当前时间步

③上一时刻的记忆单元作为状态输入到模型

③LSTM由两个状态:H和C

5.隐状态

 

 

①在LSTM中,tanh的作用是将Ct的值限制在-1~1之间

②Ot为1的时候有效地将所有记忆信息传递给预测部分。Ot为0丢弃当前的 Xt 和过去所有的信息,只保留记忆元内的所有信息,而不需要更新隐状态

【总结】

 

1   LSTM 和 GRU 所想要实现的效果是差不多的,但是结构更加复杂

C :一个数值可能比较大的辅助记忆单元

C 中包含两项:  当前的 Xt 和过去的状态(在 GRU 中只能二选一,这里可以实现两个都选)

2  长短期记忆网络包含三种类型的门:输入门、遗忘门和输出门

3  长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息

4  长短期记忆网络可以缓解梯度消失和梯度爆炸

5  长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型

【代码】

# 加载《时光机器》数据集
import torch
from torch import nn
from d2l import torch as d2l

# 加载《时光机器》数据集
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
# 初始化模型参数
def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01  # 标准差为的高斯分布中提取权重

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
# 定义模型
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device)
            )


def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)
# 训练和预测
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()
# 简洁实现

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/167007.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

最容易理解的并查集详解

并查集 并查集,在一些有N个元素的集合应用问题中,我们通常是在开始时让每个元素构成一个单元素的集合,然后按一定顺序将属于同一组的元素所在的集合合并,其间要查找一个元素在哪个集合中。 比如下面这幅图,总共有 10 …

MySQL之存储过程

MySQL存储过程1、基本介绍1.1、介绍存储过程:1.2、特点1.3、基本语法1.3.1、delimiter1.3.1、创建存储过程1.3.2、调用存储过程1.3.3、查看存储过程1.3.4、删除存储过程2、变量2.1、系统变量2.1.1、查询(会话、全局、模糊、精确)2.1.2、设置系统变量2.2、用户定义变…

IB学生必须具备的三大特质

以往的专栏亦提及过,修读IB课程要面对几大挑战。而要应对这些挑战,IB学生须具备以下三大条件: 时间管理能力 IBDP 首先,要对时间分配掌握得很好。两年的IB预科课程非常紧凑,不但每科都有其内部评核(Interna…

VMware17虚拟机安装Ubuntu最新版本(Ubuntu22.04LTS)详细步骤

目录 一、概述 二、下载Ubuntu 22.04.1 LTS 三、在VMware虚拟机下安装Ubuntu22.04 四、配置网络 一、概述 Ubuntu是基于Linux内核开发的,免费下载,使用和分享的开源系统。如果需要在Linux下开发程序,这是一个很好的选择。本文介绍了Ubuntu最…

【问题解决】Tomcat启动服务时提示Filter初始化或销毁出现java.lang.AbstractMethodError错误

问题背景 最近在开发项目接口,基于SpringBoot 2.6.8,最终部署到外置Tomcat 8.5.85 下,开发过程中写了一个CookieFilter,实现javax.servlet.Filter接口,代码编译期正常。部署到外置Tomcat 8.5.85 下,在控制…

【Java寒假打卡】Java基础-类加载器

【Java寒假打卡】Java基础-类加载器概述类加载时机类加载的过程-加载类加载的过程-链接类加载的过程-初始化类加载器的分类类加载器-双亲委派模型类加载器-常用方法概述 负责将字节码文件加载到内存中 类加载时机 创建类的实例对象调用类的类方法访问类或者接口的类变量&am…

SymPy符号运算库与latex数学公式

SymPy符号运算库与latex数学公式sympylatexsympy SymPy是一个用于以符号运算为主的符号数学的Python库。它的目标是成为一个全功能的计算机代数系统(CAS),同时保持代码尽可能的简单,以便易于理解和易于扩展。SymPy完全是用Python编写的。 官网地址:http…

【linux kernel】Linux设备驱动模型 | bus

文章目录一、导读二、与总线相关的数据结构(2-1)struct bus_type(2-2)struct subsys_private三、总线的初始化四、总线的操作接口(4-1)总线的注册(4-2)总线的注销(4-3&am…

Linux的基本使用在Linux上部署程序

linux概述 Linux严格意义来说只是一个"操作系统内核",一个完整的操作系统 操作系统内核 配套的应用程序 由于 Linux 是一个完全开源免费的内核,因此有些公司/开源组织又基于 Linux 内核,提供了不同的配套程序,这就构…

GAN“家族”又添新成员——EditGAN,不但能自己修图,还修得比你我都好

导语:从风格迁移到特征解耦、语言概念解耦,研究人员正通过数学和语言逐步改善GAN的功能。作者 | 莓酊编辑 | 青暮首先想让大家猜一猜,这四张图中你觉得哪张是P过的?小编先留个悬念不公布答案,请继续往下看。生成对抗网…

【蓝桥杯】历届真题 时间显示(省赛)Java

【问题描述】 小蓝要和朋友合作开发一个时间显示的网站。在服务器上,朋友已经获取了当前的时间,用一个整数表示,值为从1970年1月1日O0:00:00到当前时刻经过的毫秒数。 现在,小蓝要在客户端显示出这个时间。小蓝不用显示出年月日&a…

Allegro如何灌铜操作指导

Allegro如何灌铜操作指导 在做PCB设计平面层的铜皮时候,会需要用到灌铜的操作,如下图 灌铜可以让铜皮自动沿着Antietch画指定网络的铜皮 具体操作如下 点击Add Line命令选择Anti Etch的层面,比如Anti Etch画在L2层,线宽设置为40mil

TCP通信的三次握手和四次挥手详解

TCP通信的三次握手和四次挥手详解 计算机网络参考模型: 应用层:例如Modbus、Http、FTP 传输层:TCP、UDP 网络层:IP 数据链路层:MAC 物理层:RS485、RS232、以太网 TCP的包头: TCP包头为至少20字节 TCP包头解释  源端口号、目的端口号,用于建立连接时,确认源端口(本机…

2.Spring 等框架简单入门了解

1.Spring 1.什么是spring? 一个轻量级Java开发框架,目的是为了解决企业级应用开发 的业务逻辑层和其他各层的耦合问题. 两个核心特性,也就是依赖注入(dependency injection,DI)和面向切面编程(aspect- oriented programming,AOP) 2.IOC(控制…

一文带你秒懂十大排序

目录 一、排序的概述 二、插入排序 1、直接插入排序 2、希尔排序 二、选择排序 1、直接选择排序 2、堆排序 三、交换排序 1、冒泡排序 2、快速排序 四、归并排序 五、计数排序 六、基数排序 七、桶排序 八、排序总结 一、排序的概述 排序就是将一组…

pod私有库

私有库制作步骤 1、在gitlab上创建一个空项目,并用source tree导到本地,便于后面代码更新上传 2、cd 到项目下 执行pod lib create 【组件名】如:pod lib create TDAlertView 输入命令后会显示下载模板,会有几秒钟等待 Cloni…

一文搞懂 python 中的 classmethod、staticmethod和普通的实例方法的使用场景

什么是类方法(classmethod)/静态方法(staticmethod)和普通成员方法? 首先看这样一个例子: class A(object):def m1(self, n):# 属于实例对象,self 指代实例对象,print("self:…

Allegro如何更改钻孔孔符以及大小操作指导

Allegro如何更改钻孔孔符以及大小操作指导 PCB设计完成时,需要放出整板的钻孔表来,有的钻孔孔符以及大小并不是需要的,Allegro支持更改钻孔符以及大小,如下图 需要更改孔符以及大小, 具体操作如下 选择Manufacture选择NC

aws parallelcluster 理解 parallelcluster 集群的配置和使用

参考资料 Setup AWS ParallelCluster 3.0 with AWS Cloud9 200 HPC For Public Sector Customers 200 HPC pcluster workshop 200 Running CFD on AWS ParallelCluster at scale 400 Tutorial on how to run CFD on AWS ParallelCluster 400 Running CFD on AWS ParallelC…

CSS 伪元素也可以被用于反爬案例?来学习一下。26

先说一下什么是 CSS 中的伪元素,CSS 伪元素的概念是指在 CSS 中使用的一些特殊的元素,它们不存在于 HTML 文档中,而是由浏览器生成的元素,用于提供额外的样式控制。这些伪元素在 HTML 代码中不存在,但可以在 CSS 中通过…