《动手学深度学习》-57长短期记忆网络LSTM

news2024/11/22 14:00:06

沐神版《动手学深度学习》学习笔记,记录学习过程,详细的内容请大家购买书籍查阅。

b站视频链接
开源教程链接

长短期记忆网络(LSTM)

长期以来,隐变量模型存在长期信息保存和短期输入缺失的问题。解决这一问题的最早方法之一是长短期记忆网络。

长短期记忆网络的设计灵感来自于计算机的逻辑门。

在这里插入图片描述

门控记忆元:
遗忘门:将值向0减少
输入门:决定是不是忽略掉输入数据
输出门:决定是不是使用隐状态

在这里插入图片描述

由3个带有 s i g m o i d sigmoid sigmoid激活函数的全连接层处理:

在这里插入图片描述

候选记忆单元,与上面描述的3个门类似,但使用tanh函数作为激活函数,函数值范围为(-1,1):

在这里插入图片描述

记忆单元:输入门 I t I_t It控制采用多少来自 C ~ t \tilde{C}_t C~t的新数据,而遗忘门 F i F_i Fi控制保留多少过去的记忆元 C t − 1 C_{t-1} Ct1的内容。

如果遗忘门始终为1且输入门始终为0,则过去的记忆元 C t − 1 C_{t-1} Ct1将随时间被保存并传递到当前时间步。引入这种设计是为了缓减梯度消失问题,并更好地捕获序列中的长距离依赖关系。

在这里插入图片描述

最后定义隐状态 H t H_t Ht的计算,这也是输出门发挥作用的地方。在长短期记忆网络中,它仅仅是记忆元的tanh的门控版本。这就确保了 H t H_t Ht的值始终在区间(-1,1)内。

只要输出门接近1,我们就能有效地将所有记忆信息传递给预测部分,而对于输出门接近0,我们只保留记忆元内的所有信息,而不需要更新隐状态。

有些文献认为记忆元是隐状态的一种特殊类型,它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。

在这里插入图片描述

总结

长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。然后由于序列的长距离依赖性,训练长短期记忆网络和其他序列模型(如门控循环单元)的成本是相当高的。Transformer是其高级替代模型。

LSTM可以缓解梯度爆炸和梯度消失。

只有隐状态会传递到输出层(Y),而记忆元完全属于内部信息。

在这里插入图片描述

动手学

长短期记忆网络-LSTM

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数

    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    
    return params
# H 和 C 的初始化
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))
def lstm(inputs, state, params):

    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    
    outputs = []
    for X in inputs:

        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i) # 输入门
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f) # 遗忘门
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o) # 输出门
        
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c) # 候选记忆单元
        
        C = F * C + I * C_tilda # 记忆元
        
        H = O * torch.tanh(C) # 隐状态

        Y = (H @ W_hq) + b_q # 输出
        outputs.append(Y)
        
    return torch.cat(outputs, dim=0), (H, C)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

在这里插入图片描述

简洁实现

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/950682.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

bazel入门学习笔记

简介 Bazel Google开源的,是一款与 Make、Maven 和 Gradle 类似的开源构建和测试工具。 它使用人类可读的高级构建语言。Bazel 支持多种语言的项目,可为多个平台构建输出。Bazel支持任意大小的构建目标,并支持跨多个代码库和大量用户的大型代…

webpack loader和plugins的区别

在Webpack中,Loader和Plugin是两个不同的概念,用于不同的目的。 Loader是用于处理非JavaScript模块的文件的转换工具。它们将文件作为输入,并将其转换为Webpack可以处理的模块。例如,当您在Webpack配置中使用Babel Loader时&…

深入浅出AXI协议(3)——握手过程

一、前言 在之前的文章中我们快速地浏览了一下AXI4协议中的接口信号,对此我们建议先有一个简单的认知,接下来在使用到的时候我们还会对各种信号进行一个详细的讲解,在这篇文章中我们将讲述AXI协议的握手协议。 二、握手协议概述 在前面的文章…

文具寄到墨西哥可以走墨西哥专线吗?

文具寄到墨西哥可以选择墨西哥专线进行运输。墨西哥专线是一种专门为墨西哥进口货物提供的物流服务,其优势在于能够提供快速、高效和可靠的运输服务,以及专业的清关和包税服务。 1.墨西哥专线可以提供快速的运输服务。 一般而言,墨西哥专线的…

Mysql中九种索引失效场景分析

表数据: 索引情况: 其中a是主键,对应主键索引,bcd三个字段组成联合索引,e字段为一个索引 情况一:不符合最左匹配原则 去掉b1的条件后就不符合最左匹配原则了,导致索引失效 情况二&#xff…

从LeakCanary看内存快照解析

在从LeakCanary看内存快照生成一节中,我们已经了解了hprof的生成,并且将生成的hprof文件通过Android Studio进行解析,确实发现了内存泄漏对象MainActivity,但是在实际开发中,要求开发者自己去手动pull hprof文件进行解…

应急物资管理系统|智物资DW-S300提升应急响应能力

项目背景 智慧应急物资管理系统(智装备DW-S300)是一套成熟系统,依托互3D技术、云计算、大数据、RFID技术、数据库技术、AI、视频分析技术对RFID智能仓库进行统一管理、分析的信息化、智能化、规范化的系统。 本项目采用东识智慧应急物资管理…

ubuntu系统安装tensorRT-8.6.1版本(2023-8月最新版)

目录 前言pip安装可能出现的报错: tar.gz安装 前言 看了无数教程和b站视频,啊啊啊啊啊啊啊啊啊啊啊tensorRT要我狗命啊。我要写全网tensorRT最全的博客!!! 总体来说成功安装方式有两种,pip安装和tar.gz安装(其实官网安装方式居多…

d3dx9_29.dll丢失如何修复?dll修复工具下载方法

大家好!今天,我将为大家介绍一个与我们日常生活息息相关的话题——电脑d3dx9_29.dll丢失的6种修复方法。作为一名计算机专业的学生,我深知这个文件对我们电脑运行的重要性。在接下来的时间里,我将带领大家了解d3dx9_29.dll的作用、…

C#关于WebService中File.Exists()处理远程路径的异常记录

目录 前言方案一打开网站对应的程序池的高级设置按下图步骤设置凭据重启网站若方案一未能解决,请继续尝试方案二👇 方案二从控制面板进入到 凭据管理器为windows凭据添加凭据点击**Windows凭据**,并点击**添加Windows凭据**键入远程路径的地址…

Java之API详解之Biginteger类的详解

6 BigInteger类 6.1 引入 平时在存储整数的时候,Java中默认是int类型,int类型有取值范围:-2147483648 ~ 2147483647。如果数字过大,我们可以使用long类型,但是如果long类型也表示不下怎么办呢? 就需要用…

DC-DC 升压电路、 升压模块原理

一、什么是 DC-DC 转换器? DC-DC 转换器是一种电力电子电路,可有效地将直流电从一个电压转换为另一个电压。 DC-DC 转换器在现代电子产品中扮演着不可或缺的角色。这是因为与线性稳压器相比,它们具有多项优势。尤其是线性稳压器会散发大量热量…

全民健康生活方式行动日,天猫健康联合三诺生物推出“15天持续测糖计划”

糖尿病是全球高发慢性病中患病人数增长最快的疾病,是导致心血管疾病、失明、肾衰竭以及截肢等重大疾病的主要病因之一。目前中国有近1.4亿成人糖尿病患者,科学的血糖监测和健康管理对于糖尿病患者来说至关重要。 在9月1日全民健康生活方式行动日前夕&am…

Shell编程之运算符

目录 算数运算符 关系运算符 文件运算符 逻辑运算符 算数运算符 注意: 原生bash不支持简单的数学运算,但是可以通过其他命令来实现, 例如 expr 常用算数运算符 加-减*乘/除%取余 示例如下: A2 B3 expr $[$A$B] expr $[$A-$…

C语言的数据类型简介

一、基本类型 (1)六种基本类型 **字符串常量和字符常量的不同 1)‘a’为字符常量,”a”为字符串常量 2)每个字符串的结尾,编译器会自动添加一个结束标志位‘\0’ “a”包含两个字符’a’和’\0’ &#x…

FPGA优质开源项目 – UDP万兆光纤以太网通信

本文开源一个FPGA项目:UDP万兆光通信。该项目实现了万兆光纤以太网数据回环传输功能。Vivado工程代码结构和之前开源的《UDP RGMII千兆以太网》类似,只不过万兆以太网是调用了Xilinx的10G Ethernet Subsystem IP核实现。 下面围绕该IP核的使用、用户接口…

Qt6.5安装教程——国内源

为什么离线包没了? Qt6开始非商业授权下,不再提供离线安装方式的exe,但源码安装费时费力,所以推荐安装方式已经为在线组件安装方式,包括vs2022、Qt在线安装工具已经成为开发工具新的安装趋势。 Qt是不是要放弃开源&…

服务器日志出现大量NTLM(NT LAN Manager)攻击

日志名称:Security 来源: Microsoft-Windows-Security-Auditing 日期: 2023/8/30 20:57:40 事件 ID:4625 任务类别:登录 级别: 信息 关键字: 审核失败 用户: 暂缺 计算机: WIN-QBJ3ORTR0CF 描述: 帐户登录失败。 主题: 安全 ID:NULL SID 帐户名:- 帐户域:- …

【软考】系统集成项目管理工程师(二)信息系统服务管理【2分】

一、信息技术服务标准 1、组成要素 组成要素描述人员提供IT服务所需的人员及其知识、经验和技能要求【正确选人】流程提供IT服务时,合理利用必要的资源,将输入转化为输出的一组相互关联和结构化的活动【正确做事】技术交付满足质量要求的IT服务应使用的…

搭建个人备忘录中心服务memos、轻量级笔记服务

目录 一、源码 二、官网 三、搭建 四、使用 一、源码 GitHub - usememos/memos: A privacy-first, lightweight note-taking service. Easily capture and share your great thoughts. 二、官网 memos - Easily capture and share your great thoughts 三、搭建 docke…