【动手学深度学习-pytorch】9.2长短期记忆网络(LSTM)

news2024/11/29 12:48:01

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。 解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM) (Hochreiter and Schmidhuber, 1997)。 它有许多与门控循环单元( 9.1节)一样的属性。 有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年.

门控记忆元 cell

  • 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)
  • 为了控制记忆元,我们需要许多门。输入门 输出门 遗忘门
  • 其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。 另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。 我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。 让我们看看这在实践中是如何运作的。

输入门、忘记门和输出门

就如在门控循环单元中一样, 当前时间步的输入和前一个时间步的隐状态 作为数据送入长短期记忆网络的门中, 如 图9.2.1所示。 它们由三个具有sigmoid激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在
的范围内。
在这里插入图片描述
在这里插入图片描述

候选记忆元

在这里插入图片描述

记忆元

在这里插入图片描述

隐状态

在这里插入图片描述

只有隐状态会传递到输出层,而记忆元完全属于内部信息

从零开始实现

现在,我们从零开始实现长短期记忆网络。 与 8.5节中的实验相同, 我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差
的高斯分布初始化权重,并将偏置项设为0.

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

定义模型

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元不直接参与输出计算。

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

训练和预测

让我们通过实例化 8.5节中 引入的RNNModelScratch类来训练一个长短期记忆网络, 就如我们在 9.1节中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

总结

  • 长短期记忆网络,包含三个门:输入门、忘记门和遗忘门。其中遗忘门用于重置单元的内容,通过专用的机制决定什么时候记忆或者忽略状态中的输入。

  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。

  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1556152.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【学习笔记】java项目—苍穹外卖day04

文章目录 1. 新增套餐1.1 需求分析和设计1.2 代码实现1.2.1 DishController1.2.2 DishService1.2.3 DishServiceImpl1.2.4 DishMapper1.2.5 DishMapper.xml1.2.6 SetmealController1.2.7 SetmealService1.2.8 SetmealServiceImpl1.2.9 SetmealMapper1.2.10 SetmealMapper.xml1.…

HarborCDN技术分析

一、介绍 简要介绍 ​​Harbor​​ 是由VMware公司开源的企业级的Docker Registry管理项目,它包括权限管理(RBAC)、LDAP、日志审核、管理界面、自我注册、镜像复制和中文支持等功能。Harbor 的所有组件都在 Dcoker 中部署,所以 Harbor 可使用 Docker C…

NC269391 炸鸡块哥哥的粉丝题

题目描述 智乃作为炸鸡块哥哥的粉丝,做了一场炸鸡块哥哥的比赛后得出一个结论,那就是炸鸡块哥哥的话,最多只能信半句。 现在给你一个长度为N的字符串S,请输出前 个字符,表示只能相信半句话。 例如当炸鸡块哥哥说&…

既有理论深度又有技术细节——深度学习计算机视觉

推荐序 我曾经试图找到一本既有理论深度、知识广度,又有技术细节、数学原理的关于深度学习的书籍,供自己学习,也推荐给我的学生学习。虽浏览文献无数,但一直没有心仪的目标。两周前,刘升容女士将她的译作《深度学习计…

python安装删除以及pip的使用

目录 你无法想象新手到底会在什么地方出问题——十二个小时的血泪之言! 问题引入 python modify setup 隐藏文件夹 环境变量的配置 彻底删除python 其他零碎发现 管理员终端 删不掉的windous应用商店apps 发现问题 总结 你无法想象新手到底会在什么地方…

(学习日记)2024.03.27:UCOSIII第二十四节:任务状态

写在前面: 由于时间的不足与学习的碎片化,写博客变得有些奢侈。 但是对于记录学习(忘了以后能快速复习)的渴望一天天变得强烈。 既然如此 不如以天为单位,以时间为顺序,仅仅将博客当做一个知识学习的目录&a…

Java毕业设计 基于SSM新闻管理系统

Java毕业设计 基于SSM新闻管理系统 SSM jsp 新闻管理系统 功能介绍 用户:首页 图片轮播 查询 登录 注册 新闻正文 评论 广告 社会新闻 天下新闻 娱乐新闻 个人中心 个人收藏 管理员:登录 用户管理 新闻管理 新闻类型管理 角色:用户 管理员…

笔记本电脑上部署LLaMA-2中文模型

尝试在macbook上部署LLaMA-2的中文模型的详细过程。 (1)环境准备 MacBook Pro(M2 Max/32G); VMware Fusion Player 版本 13.5.1 (23298085); Ubuntu 22.04.2 LTS; 给linux虚拟机分配8*core CPU 16G RAM。 我这里用的是16bit的量化模型,…

python实战之进阶篇(一)

定义类 1. 构造方法 2. 实例方法 3. 类方法 类似于Java中的静态方法, 使用方式: 类名.类方法 4. 私有变量 5. 私有方法 6. 使用属性set和get

stm32再实现感应开关盖垃圾桶

一、项目需求 检测靠近时,垃圾桶自动开盖并伴随滴一声,2秒后关盖 发生震动时,垃圾桶自动开盖并伴随滴一声,2秒后关盖 按下按键时,垃圾桶自动开盖并伴随滴一声,2秒后关盖 硬件清单 SG90 舵机,…

MySQL生产环境常见故障及解决方案汇总

MySQL生产环境常见故障及解决方案汇总 1. MySQL主从同步异常故障1.1. 情景说明1.2. 排查过程1.3. 数据同步2. MySQL慢查询故障1. MySQL主从同步异常故障 1.1. 情景说明 MySQL主库网卡需要更换IP地址,并将原IP地址配置为MySQL集群的VIP地址,上层应用程序其实不需要更改连接My…

牛客练习赛123 A~C

A.炸鸡块哥哥的粉丝题 输出字符串的前 ⌈ n 2 ⌉ \lceil \frac{n}{2} \rceil ⌈2n​⌉ 个字符 void solve() {int n;string s;cin >> n >> s;cout << s.substr(0, (n 1) / 2); }B.智乃想考一道鸽巢原理 当小球总个数为奇数时&#xff0c;贪心的留下 1 个…

(C++笔试题)选择题+编程题

个人主页&#xff1a;Lei宝啊 愿所有美好如期而遇 选择题 第一道 下面对析构函数的正确描述是&#xff08;&#xff09; A. 系统不能提供默认的析构函数B. 析构函数必须由用户定义C. 析构函数没有参数D. 析构函数可以设置默认参数 解析&#xff1a; 正确描述析构函数的…

【独立开发前线】Vol.27 为什么独立开发者需要一个网站?

现在很多内容创造者都把主要平台放在了第三方平台上&#xff0c;包括像知乎、B站、头条等等&#xff0c;但即使在2024年&#xff0c;我依然建议你做一个完全属于你的网站。 为什么呢&#xff1f; 你有没有在微信或知乎看到过这种拦截页面&#xff1f; 你花了好大的精力写了一…

关于github提交失败的问题

问题描述 Username for https://github.com: LAL-Better Password for https://LAL-Bettergithub.com: remote: Support for password authentication was removed on August 13, 2021. remote: Please see https://docs.github.com/get-started/getting-started-with-git/abo…

线程的通信

1.需求(为什么需要线程通信) 当我们需要多个线程完成同一任务时&#xff0c;并且希望他们有规律的执行&#xff0c;那么多线程之间需要一些通信机制&#xff0c;并且可以协调他们的工作&#xff0c;以此实现多个线程共同操作共享数据. 例 : A做包子&#xff0c;B吃包子&#…

SAP Fiori开发中的JavaScript基础知识9 - 代码注释,严格模式,JSON

1 背景 本文将介绍JavaScript编程中的三个小知识点&#xff1a;也即代码注释&#xff0c;严格模式&#xff0c;JSON文件。 2 代码注释 JavaScript的代码注释方式如下&#xff1a; // Single line comment/* Multi line comment */3 严格模式 JavaScript的"strict mod…

vue3封装Element表格自适应

表格高度自适应 分页跟随表格之后 1. 满屏时出现滚动条 2. 不满屏时不显示滚动条 坑 表格设置maxHeight后不出现滚动条 解决方案 表格外层元素设置max-height el-table–fit 设置高度100% .table-box {max-height: calc(100% - 120px); } .el-table--fit {height: 100%; }示例代…

rust使用Command库调用cmd命令或者shell命令,并支持多个参数和指定文件夹目录

想要在不同的平台上运行flutter doctor命令&#xff0c;就需要知道对应的平台是windows还是linux&#xff0c;如果是windows就需要调用cmd命令&#xff0c;如果是linux平台&#xff0c;就需要调用sh命令&#xff0c;所以可以通过cfg!实现不同平台的判断&#xff0c;然后调用不同…

[flask]http请求//获取请求头信息+客户端信息

在网站中查询请求头信息&#xff0c;可以通过以下操作进行 右键然后选择检查 进入改页面后选择文档&#xff0c;刷新一下页面就好了 获取所有的请求头信息 print(request.headers, type(request.headers)) 在flask模块中&#xff0c;使用上面的输出函数就可以查看到有关于请求…