深度学习--------------------长短期记忆网络(LSTM)

news2024/12/30 2:12:43

目录

  • 长短期记忆网络
    • 候选记忆单元
    • 记忆单元
    • 隐状态
  • 长短期记忆网络代码从零实现
    • 初始化模型参数
    • 初始化
    • 实际模型
    • 训练
  • 简洁实现

长短期记忆网络

忘记门:将值朝0减少
输入门:决定要不要忽略掉输入数据
输出门:决定要不要使用隐状态。



在这里插入图片描述

在这里插入图片描述




候选记忆单元

在这里插入图片描述




记忆单元

记忆单元会把上一个时刻的记忆单元作为状态放进来,所以LSTM和RNN跟GRU不一样的地方是它的状态里面有两个独立的。
如果: F t F_t Ft等于0的话,就是希望不要记住 C t − 1 C_{t-1} Ct1
如果: I t I_t It是1的话,就是希望尽量的去用它,如果 I t I_t It等于0的话,就是把现在的记忆单元丢掉。

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述




隐状态

在这里插入图片描述

在这里插入图片描述


在这里插入图片描述




长短期记忆网络代码从零实现

import torch
from torch import nn
from d2l import torch as d2l

# 设置批量大小为32,时间步数为35
batch_size, num_steps = 32, 35
# 使用d2l库中的load_data_time_machine函数加载时间机器数据集,
# 并设置批量大小为32,时间步数为35,将加载的数据集赋值给train_iter和vocab变量
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)



初始化模型参数

def get_lstm_params(vocab_size, num_hiddens, device):
    # 将词汇表大小赋值给num_inputs和num_outputs
    num_inputs = num_outputs = vocab_size
    
    # 定义一个辅助函数normal,用于生成具有特定形状的正态分布随机数,并将其初始化为较小的值
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01
    
    # 定义一个辅助函数three,用于生成三个参数:输入到隐藏状态的权重矩阵、隐藏状态到隐藏状态的权重矩阵和隐藏状态的偏置项
    def three():
        return (normal(
            (num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))
    
    # 调用three函数获取输入到隐藏状态的权重矩阵W_xi、隐藏状态到隐藏状态的权重矩阵W_hi和隐藏状态的偏置项b_i
    W_xi, W_hi, b_i = three()
    # 调用three函数获取输入到隐藏状态的权重矩阵W_xf、隐藏状态到隐藏状态的权重矩阵W_hf和隐藏状态的偏置项b_f
    W_xf, W_hf, b_f = three()
    # 调用three函数获取输入到隐藏状态的权重矩阵W_xo、隐藏状态到隐藏状态的权重矩阵W_ho和隐藏状态的偏置项b_o
    W_xo, W_ho, b_o = three()
    # 调用three函数获取输入到隐藏状态的权重矩阵W_xc、隐藏状态到隐藏状态的权重矩阵W_hc和隐藏状态的偏置项b_c
    W_xc, W_hc, b_c = three()
    # 生成隐藏状态到输出的权重矩阵W_hq
    W_hq = normal((num_hiddens, num_outputs))
    # 生成输出的偏置项b_q
    b_q  = torch.zeros(num_outputs, device=device)
    # 将所有参数组合成列表params
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]    
    
    # 变量所有参数
    for param in params:
        # 将所有参数的requires_grad属性设置为True,表示需要计算梯度
        param.requires_grad_(True)
        
    # 返回所有参数
    return params



初始化

def init_lstm_state(batch_size, num_hiddens, device):
    # 返回一个元组,包含两个张量:一个全零张量表示初始的隐藏状态(即:H要有个初始化),和一个全零张量表示初始的记忆细胞状态(即:C要有个初始化)。
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))



实际模型

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params
    # 解包状态元组state,分别赋值给隐藏状态H和记忆细胞状态C
    (H, C) = state
    # 创建一个空列表用于存储每个时间步的输出
    outputs = []
    # 对于输入序列中的每个时间步
    for X in inputs:
        # 输入门的计算:使用输入、隐藏状态和偏置项,通过线性变换和sigmoid函数计算输入门
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        # 遗忘门的计算:使用输入、隐藏状态和偏置项,通过线性变换和sigmoid函数计算遗忘门
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        # 输出门的计算:使用输入、隐藏状态和偏置项,通过线性变换和sigmoid函数计算输出门
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        # 新的记忆细胞候选值的计算:使用输入、隐藏状态和偏置项,通过线性变换和tanh函数计算新的记忆细胞候选值
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        # 更新记忆细胞状态:将旧的记忆细胞状态与遗忘门和输入门的乘积相加,再与新的记忆细胞候选值的乘积相加,得到新的记忆细胞状态
        C = F * C + I * C_tilda
        # 更新隐藏状态:将输出门和经过tanh函数处理的记忆细胞状态的乘积作为新的隐藏状态
        H = O * torch.tanh(C)
        # 输出的计算:使用新的隐藏状态和偏置项,通过线性变换得到输出
        Y = (H @ W_hq) + b_q
        # 将当前时间步的输出添加到列表中
        outputs.append(Y)
    # 将所有时间步的输出在维度0上拼接起来,作为最终的输出结果;
    # 返回最终的输出结果和更新后的隐藏状态和记忆细胞状态的元组
    return torch.cat(outputs, dim=0), (H, C)



训练

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
# 使用d2l库中的RNNModelScratch类创建一个基于LSTM的模型对象,
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

在这里插入图片描述

在这里插入图片描述




简洁实现

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
num_inputs = vocab_size
# 使用nn.LSTM创建一个LSTM层,输入特征数量为num_inputs,隐藏单元数量为num_hiddens
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
# 使用d2l库中的RNNModel类创建一个基于LSTM的模型对象,传入LSTM层和词汇表大小
model = d2l.RNNModel(lstm_layer, len(vocab))
mode = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2183019.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

这4款专业的Windows录屏工具,帮你解决多样的录屏的问题。

像Xbox 录制,步骤记录器等工具都是Windows系统里面自带的录屏工具,如果时想要更多功能的录屏工具,可以下载一些专业录屏软件,我可以给大家推荐几款,实用稳定,专业高效的录屏软件。 1、福昕多效录屏 直达&a…

【Java基础】Java面试基础知识QA(上)

Java面试基础知识Q&A(上) 面向对象编程( OOP) Java 是一个支持并发、基于类和面向对象的计算机编程语言。面向对象软件开发的优点: 代码开发模块化,更易维护和修改。代码复用。增强代码的可靠性和灵活性…

springboot系列--web相关知识探索二

映射 指的是与请求处理方法关联的URL路径,通过在Spring MVC的控制器类(使用RestController注解修饰的类)上使用注解(如 RequestMapping、GetMapping)来指定请求映射路径,可以将不同的HTTP请求映射到相应的处…

【PRISMA卫星有关简介】

PRISMA卫星是一颗小型超光谱成像卫星,以下是对其的详细介绍: 一、基本信息 英文全称:Prototype Research Instruments and Space Mission technology Advancement Main,或简化为PRISMA。发射时间:PRISMA卫星于2019年…

今日指数项目项目集成RabbitMQ与CaffienCatch

今日指数项目项目集成RabbitMQ与CaffienCatch 一. 为什么要集成RabbitMQ 首先CaffeineCatch 是作为一个本地缓存工具 使用CaffeineCatch 能够大大较少I/O开销 股票项目 主要分为两大工程 --> job工程(负责数据采集) , backend(负责业务处理) 由于股票的实时性也就是说 ,…

【Redis】Redis中的 AOF(Append Only File)持久化机制

目录 1、AOF日志 2、AOF 的执行顺序与潜在风险 3、如何优化 AOF?(写入策略) 4、AOF重写机制(防止日志文件无限增长) 1、AOF日志 想象一下,Redis 每次执行写操作的时候,都把这些操作以追加的…

SpringBoot项目 | 瑞吉外卖 | 短信发送验证码功能改为免费的邮箱发送验证码功能 | 代码实现

0.前情提要 之前的po已经说了单独的邮箱验证码发送功能怎么实现: https://blog.csdn.net/qq_61551948/article/details/142641495 这篇说下如何把该功能整合到瑞吉项目里面,也就是把原先项目里的短信发送验证码的功能改掉,改为邮箱发送验证…

World of Warcraft [CLASSIC][80][Grandel] /console cameraDistanceMaxZoomFactor 2

学习起来!!! 调整游戏界面镜头距离,默认值为:2 /console cameraDistanceMaxZoomFactor 2 大于4,效果不明显了,鼠标滚轮向后滚,拉起来镜头 World of Warcraft [CLASSIC][80][Grandel…

Another redis desktop manager使用说明

Another redis desktop manager使用说明 概述界面介绍图示说明连接界面设置界面查看操作日志主界面信息进入redis-cli控制台更多 概述 Another Redis Desktop Manager是一个开源的跨平台 Redis 客户端,提供了简洁易用的图形用户界面(GUI)&am…

第5篇:勒索病毒自救指南----应急响应篇

经常会有一些小伙伴问:中了勒索病毒,该怎么办,可以解密吗? 第一次遇到勒索病毒是在早几年的时候,客户因网站访问异常,进而远程协助进行排查。登录服务器,在站点目录下发现所有的脚本文件及附件…

【JaveEE】——多线程中使用顺序表,队列,哈希表

阿华代码,不是逆风,就是我疯 你们的点赞收藏是我前进最大的动力!! 希望本文内容能够帮助到你!! 目录 一:多线程环境使用ArrayList 引入: 1:顺序表使用同步机制 2&…

Linux服务器配置anaconda3,下载torch

如图,vscode连接远程服务器后,如下所示: 下载 Anaconda 下载及安装 进入下载官网,点击linux, 下载方式有两种, 直接下载安装包,下载完上传服务器,并安装,安装执行b…

【算法系列-链表】移除链表元素

【算法系列-链表】移除链表元素 欢迎来到【算法系列】第二弹 🏆 链表,接下来我们将围绕链表这类型的算法题进行解析与练习!一起加油吧!!( •̀ ω •́ )✧✨ 文章目录 【算法系列-链表】移除链表元素1. 算法分析&am…

Spring Data(学习笔记)

JPQL语句???(Query括号中的就是JPQL语句) 怎么又会涉及到连表查询呢? 用注解来实现表间关系。 分页是什么?为什么什么都有分页呢 ? 继承,与重写方法的问题 Deque是什么 ?…

线程池:线程池的实现 | 日志

🌈个人主页: 南桥几晴秋 🌈C专栏: 南桥谈C 🌈C语言专栏: C语言学习系列 🌈Linux学习专栏: 南桥谈Linux 🌈数据结构学习专栏: 数据结构杂谈 🌈数据…

C++容器之vector模拟实现(代码纯享版!!!)

目录 前言 一、头文件 .h文件 总结 前言 本文是模拟实现vector部分功能的代码&#xff0c;可以直接拿去使用 一、头文件 .h文件 #include<assert.h> #include<iostream> using namespace std; namespace zz {template<class T>class vector{public:typedef…

C++ set,multiset与map,multimap的基本使用

1. 序列式容器和关联式容器 string、vector、list、deque、array、forward_list等STL容器统称为序列式容器&#xff0c;因为逻辑结构为线性序列的数据结构&#xff0c;两个位置存储的值之间一般没有紧密的关联关系&#xff0c;比如交换一下&#xff0c;他依旧是序列式容器。顺…

STM32器件支持包安装,STLINK/JLINK驱动安装

一、支持包安装 1、离线安装 先下载支持包之后&#xff0c;再进行安装。如下图要安装STM32F1系列&#xff0c;双击 出现如下&#xff0c;会自动锁定安装路径&#xff0c;然后点击下一步&#xff0c;直接安装。 2、在线安装 首先需要电脑联网。如下。先点击第一个红框绿色按钮…

常见的VPS或者独立服务器的控制面板推荐

随着越来越多的企业和个人转向VPS和独立服务器以获得更高的性能和灵活性&#xff0c;选择合适的控制面板变得尤为重要。一个好的控制面板可以大大简化服务器管理&#xff0c;提高工作效率。本篇文章将介绍2024年最值得推荐的VPS控制面板&#xff0c;帮助您做出明智的选择。 1.…

STL容器适配器

欢迎来到本期节目- - - STL容器适配器 适配器模式&#xff1a; 在C中&#xff0c;适配器是一种设计模式&#xff0c;有时也称包装样式&#xff1b; 通过将类自己的接口包裹在一个已存在的类中&#xff0c;使得因接口不兼容而不能在一起工作的类能在一起工作&#xff1b; 也就…