4.6.长短期记忆网络(LSTM)

news2024/11/23 15:06:34

长短期记忆网络(LSTM)

​ 长短期记忆网络的设计灵感来自于计算机的逻辑门。 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。 有些文献认为记忆元是隐状态的一种特殊类型, 它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。

1.门控记忆元

​ 为了控制记忆元,我们需要许多门。 其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。 另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。 我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。

1.1.输入门、忘记门和输出门

在这里插入图片描述

  • 忘记门:将值朝0减少
  • 输入门:决定是否忽略掉输入数据
  • 输出门:决定是否使用隐状态

计算方法如下:
I t = σ ( X t W x i + H t − 1 W h i + b i ) F t = σ ( X t W x f + H t − 1 W h f + b f ) O t = σ ( X t W x o + H t − 1 W h o + b o ) I_t = \sigma(X_tW_{xi}+H_{t-1}W_{hi}+b_i)\\ F_t = \sigma(X_tW_{xf}+H_{t-1}W_{hf}+b_f)\\ O_t = \sigma(X_tW_{xo}+H_{t-1}W_{ho}+b_o) It=σ(XtWxi+Ht1Whi+bi)Ft=σ(XtWxf+Ht1Whf+bf)Ot=σ(XtWxo+Ht1Who+bo)

1.2.候选记忆元

在这里插入图片描述

​ 与上述三道门计算方法类似,但使用tang函数作为激活函数:
C ^ t = t a n h ( X t W x c + H t − 1 W h c + b c ) \hat{C}_t = tanh(X_t W_{xc}+H_{t-1}W_{hc}+b_c) C^t=tanh(XtWxc+Ht1Whc+bc)

​ 决定当前输入的重要性

1.3. 记忆元

在这里插入图片描述

C t = F t ⊙ C t − 1 + I t ⊙ C ^ t C_t = F_t\odot C_{t-1}+I_t\odot \hat{C}_t Ct=FtCt1+ItC^t

​ 遗忘门控制保留多少过去的记忆元 C t − 1 C_{t-1} Ct1的内容,按元素乘法

1.4 隐状态

在这里插入图片描述

H t = O t ⊙ t a n h ( C t ) H_t = O_t\odot tanh(C_t) Ht=Ottanh(Ct)
​ 输出门发挥作用,决定是否将当前信息传递出去

2.代码实现

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params


def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))


def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

'''简洁实现'''
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1987316.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

萱仔求职系列——1.1 机器学习基础知识复习

由于我最近拿到offer还是想再找找更好的机会,目前有很多的面试,面试的时候很多面试官会问一些机器学习的基础知识,由于我上一段实习的时候主要是机器学习和部分深度学习的内容,为了避免在面试的时候想不起来自己学习的内容&#x…

MPU6050的STM32数据读取

目录 1. 概述2. STM32G030对MPU6050的读取3. STM32F1xx对MPU6050的读取 1. 概述 项目中,往往需要根据不同的环境使用不同的芯片处理某些数据,当使用不同的芯片对六轴陀螺仪芯片MPU6050进行数据处理中,硬件的连接、I/O口的设置往往需要根据相…

【HarmonyOS NEXT星河版开发学习】小型测试案例05-得物列表项

个人主页→VON 收录专栏→鸿蒙开发小型案例总结​​​​​ 基础语法部分会发布于github 和 gitee上面(暂未发布) 前言 鸿蒙操作系统通过其先进的分布式架构和开发工具,以及灵活的界面布局和样式控制,为开发者提供了丰富的开发资源…

设计模式- 数据源架构模式

活动记录(Active Record) 一个对象,它包装数据库表或视图中的某一行,封装数据库访问,并在这些数据上增加了领域逻辑 对象中既有数据又有行为。这些数据大多是持久数据、并且需要保存到数据库。 运行机制 活动记录的…

Iris for mac 好用的录屏软件

Iris 是一款高性能屏幕录像机,可录制到 h.264。Iris 在可用时利用板载 GPU 加速。它可以选择包括来自摄像头和最多两个麦克风的视频。 兼容性 所有功能在macOS 11.0-14上完全支持,包括macOS Sonoma。 简单编码 直接录制为h.264、h.265、ProRes或Motion…

WPF学习(10)-Label标签+TextBlock文字块+TextBox文本框+RichTextBox富文本框

Label标签 Label控件继承于ContentControl控件,它是一个文本标签,如果您想修改它的标签内容,请设置Content属性。我们曾提过ContentControl的Content属性是object类型,意味着Label的Content也是可以设置为任意的引用类型的。 案…

游戏ID统一管理器DEMO

一般游戏的角色ID、名字,工会ID、名字,等最好统一创建,方便合服处理,可以以此基础,动态配置生成ID 这个也可以用openresty 作个,可能更专业点, 1:go1.20 最后一版支持win7的 mongod…

微信小程序乡村医疗系统,源码、部署+讲解

目录 摘 要 Abstract 1 绪论 1.1 研究背景及意义 1.2 研究现状 1.3 研究内容 2 相关技术介绍 2.1 Java 语言 2.2 MySQL 数据库 2.3 Spring Boot 框架 2.4 B/S 结构 2.5 微信小程序 3 系统分析 3.1 可行性分析 3.1.1 经济可行性 3.1.2 技术可行性…

4.MySQL数据类型

目录 数据类型 ​编辑数值类型 tinyint类型 bit类型 float类型 decimal类型 字符串类型 char类型 varchar varchar和char的区别 日期和时间类型 数据类型 数值类型 说明一下:MySQL本身是不支持bool类型的,当把一个数据设置成bool类型时&#x…

【ThreadLocal总结】

文章目录 为什么使用ThreadLocalThreadLocal核心ThreadLocal内部结构ThreadLocal内存泄漏解决内存泄漏 为什么使用ThreadLocal 在并发编程中,多个线程同时访问和修改共享变量是一个常见的场景。这种情况下,可能会出现线程安全问题,即多个线程…

AWS生成式AI项目的全生命周期管理

随着人工智能技术的迅速发展,生成式 AI 已成为当今最具创新性和影响力的领域之一。生成式 AI 能够创建新的内容,如文本、图像、音频等,具有广泛的应用前景,如自然语言处理、计算机视觉、创意设计等。然而,构建一个成功…

【Python】pandas:计算,统计,比较

pandas是Python的扩展库(第三方库),为Python编程语言提供 高性能、易于使用的数据结构和数据分析工具。 pandas官方文档:User Guide — pandas 2.2.2 documentation 帮助:可使用help(...)查看函数说明文档&#xff0…

文本编辑器小型架构

C字体库开发之字体列表设计七-CSDN博客 创作不易,小小的支持一下吧!

odoo from样式更新

.xodoo_form {.o_form_sheet {padding-bottom: 0 !important;border-style: solid !important;border-color: white;}.o_inner_group {/* 线框的样式 *//*--line-box-border: 1px solid #666;*//*box-shadow: 0 1px 0 #e6e6e6;*/margin: 0;}.grid {display: grid;gap: 0;}.row …

【数据结构】排序 —— 归并排序(mergeSort)、计数排序、基数排序

Hi~!这里是奋斗的明志,很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~~ 🌱🌱个人主页:奋斗的明志 🌱🌱所属专栏:数据结构、LeetCode专栏 📚本系…

【数据结构】哈希应用-STL-位图

目录 1、位图的概念 2、位图的设计与实现 2.1 set 2.2 reset 2.3 test 3、C库中的位图 4、位图的优缺点 5、位图相关题目 1、位图的概念 面试题:给40亿个不重复的无符号整数,没排过序。给一个无符号整数,如何快速判断一个数是否在这4…

【Material-UI】按钮组件中的实验性API:Loading按钮详解

文章目录 一、LoadingButton 组件概述1. 组件介绍2. 基本用法 二、LoadingButton 组件的高级用法1. 自定义加载指示器2. 图标与加载位置 三、已知问题与解决方法1. Chrome 翻译工具与 LoadingButton 的兼容性问题 四、实用性与未来展望1. 应用场景2. 未来展望 五、总结 Materia…

共享内存的原理及初识线程

char *str"hello world"; *str-H; 运行时报错,RWX只有R权限。 外设和内存交互以4KB为单位。 虚拟地址32位的划分为10 10 12 前10位对应页表的页目录。 在10位即为页表,页表中存放指定页框的起始物理地址虚拟地址的低12位作为页内偏移。 共…

RedLock算法分析

Redis分布式锁-RedLock算法 手写分布式锁的缺点 Redlock算法设计理念 Redis也提供了Redlock算法,用来实现基于多个实例的分布式锁。 锁变量由多个实例维护,即使有实例发生了故障,锁变量仍然是存在的,客户端还是可以完成锁操作。…

第一篇Linux介绍

目录 1、操作系统 2、Windows和Linux操作系统的区别 3、 Linux 的发行版本 4、 linux 分支 5、 Linux 的含义 6、Linux 特点 1、操作系统 常见操作系统有:Windows、MacOS、Unix/Linux。 类 UNIX Windows:其是微软公司研发的收费操作系统&#xff…