【论文复现】LSTM长短记忆网络

news2024/11/16 8:29:33

LSTM

  • 前言
  • 网络架构
    • 总线
    • 遗忘门
    • 记忆门
    • 记忆细胞
    • 输出门
  • 模型定义
    • 单个LSTM神经元的定义
    • LSTM层内结构的定义
  • 模型训练
  • 模型评估
  • 代码细节
    • LSTM层单元的首尾的处理
    • 配置Tensorflow的GPU版本

前言

LSTM作为经典模型,可以用来做语言模型,实现类似于语言模型的功能,同时还经常用于做时间序列。由于LSTM的原版论文相关版权问题,这里以colah大佬的博客为基础进行讲解。之前写过一篇Tensorflow中的LSTM详解,但是原理部分跟代码部分的联系并不紧密,实践性较强但是如果想要进行更加深入的调试就会出现原理性上面的问题,因此特此作文解决这个问题,想要用LSTM这个有趣的模型做出更加好的机器学习效果😊。

网络架构

LSTM框架图
这张图展示了LSTM在整体结构,下面就开始分部分介绍中间这个东东。

总线

在这里插入图片描述
这条是总线,可以实现神经元结构的保存或者更改,如果就是像上图一样一条总线贯穿不做任何改变,那么就是不改变细胞状态。那么如果想要改变细胞状态怎么办?可以通过来实现,这里的门跟高中生物中学的神经兴奋阈值比较像,用数学来表示就是sigmoid函数或者其他的激活函数,当门的输入达到要求,门就会打开,允许当前门后面的信息“穿过”门改变主线上面传递的信息,如果把每一个神经元看成一个时间节点,那么从上一个时间节点传到下一个时间节点过程中的门的开启与关闭就实现了时间序列数据的信息传递。
在这里插入图片描述

遗忘门

在这里插入图片描述
首先是遗忘门,这个门的作用是决定从上一个神经元传输到当前神经元的数据丢弃的程度,如果经过sigmoid函数以后输出0表示全部丢弃,输出1表示全部保留,这个层的输入是旧的信息和当前的新信息。

σ \sigma σ:sigmoid函数
W f W_f Wf:权重向量
b f b_f bf:偏置项,决定丢弃上一个时间节点的程度,如果是正数,表示更容易遗忘,如果是负数,表示比较容易记忆
h t − 1 h_{t-1} ht1:上一个时刻的输入
x t x_t xt:当前层的输入

记忆门

在这里插入图片描述
接下来是记忆门,这个门决定要记住什么信息,同时决定按照什么程度记住上一个状态的信息。

i t i_t it:在时间步t时刻的输入门激活值,计算方法跟上面的遗忘门是一样的,只是目的不一样,这里是记忆
C ~ t \tilde{C}_{t} C~t:表示上一个时刻的信息和当前时刻的信息的集合,但是是规则化到[-1,1]这个范围内了的

记忆细胞

在这里插入图片描述
有了上面的要记忆的信息和要丢弃的信息,记忆细胞的功能就可以得到实现,用 f t f_t ft这个标量决定上一个状态要遗忘什么,用 i t i_t it这个标量决定上一个状态要记住什么以及当前状态的信息要记住什么。这样就形成了一个记忆闭环了。

输出门

在这里插入图片描述
最后,在有了记忆细胞以后不仅仅不要将当前细胞状态记住,还要将当前的信息向下一层继续传输,实现公式中的状态转移。

o t o_t ot:跟前面的门公式都一样,但是功能是决定输出的程度
h t h_t ht:将输出规范到[-1,1]的区间,这里有两个输出的原因是在构建LSTM网络的时候需要有纵向向上的那个 h t h_t ht,然而在当前层的LSTM的神经元之间还是首尾相接的😍。

模型定义

单个LSTM神经元的定义


# 定义单个LSTM单元
# 定义单个LSTM单元
class My_LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(My_LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # 初始化门的权重和偏置,由于每一个神经元都有自己的偏置,所以在定义单元内部定义
        self.Wf = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))
        self.bf = nn.Parameter(torch.Tensor(hidden_size))
        self.Wi = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))
        self.bi = nn.Parameter(torch.Tensor(hidden_size))
        self.Wo = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))
        self.bo = nn.Parameter(torch.Tensor(hidden_size))
        self.Wg = nn.Parameter(torch.Tensor(input_size + hidden_size, hidden_size))
        self.bg = nn.Parameter(torch.Tensor(hidden_size))
        # 初始化输出层的权重和偏置
        self.W = nn.Parameter(torch.Tensor(hidden_size, output_size))
        self.b = nn.Parameter(torch.Tensor(output_size))
        
    # 用于计算每一种权重的函数
    def cal_weight(self, input, weight, bias):
        return F.linear(input, weight, bias)
    # x是输入的数据,数据的格式是(batch, seq_len, input_size),包含的是batch个序列,每个序列有seq_len个时间步,每个时间步有input_size个特征
    def forward(self, x):
        # 初始化隐藏层和细胞状态
        h = torch.zeros(1, 1, self.hidden_size).to(x.device)
        c = torch.zeros(1, 1, self.hidden_size).to(x.device)
        # 遍历每一个时间步
        for i in range(x.size(1)):
            input = x[:, i, :].view(1, 1, -1) # 取出每一个时间步的数据
            # 计算每一个门的权重
            f = torch.sigmoid(self.cal_weight(input, self.Wf, self.bf)) # 遗忘门
            i = torch.sigmoid(self.cal_weight(input, self.Wi, self.bi)) # 输入门
            o = torch.sigmoid(self.cal_weight(input, self.Wo, self.bo)) # 输出门
            C_ = torch.tanh(self.cal_weight(input, self.Wg, self.bg)) # 候选值
            # 更新细胞状态
            c = f * c + i * C_
            # 更新隐藏层
            h = o * torch.tanh(c) # 将输出标准化到-1到1之间
        output = self.cal_weight(h, self.W, self.b) # 计算输出
        return output

LSTM层内结构的定义

class My_LSTMNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(My_LSTMNetwork, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = My_LSTM(input_size, hidden_size)  # 使用自定义的LSTM单元
        self.fc = nn.Linear(hidden_size, output_size)  # 定义全连接层

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.lstm(x, (h0, c0))  # LSTM层的前向传播
        out = self.fc(out[:, -1, :])  # 全连接层的前向传播
        return out


模型训练

history = model.fit(trainX, trainY, batch_size=64, epochs=50, 
                    validation_split=0.1, verbose=2)
print('compilation time:', time.time()-start)

模型评估

为了更加直观展示,这里用画图的方法进行结果展示。

fig3 = plt.figure(figsize=(20, 15))
plt.plot(np.arange(train_size+1, len(dataset)+1, 1), scaler.inverse_transform(dataset)[train_size:], label='dataset')
plt.plot(testPredictPlot, 'g', label='test')
plt.ylabel('price')
plt.xlabel('date')
plt.legend()
plt.show()

代码细节

LSTM层单元的首尾的处理

  • 首部:由于第一个节点不用接受来自上一个节点的输入,不需要有输入,当然也有一些是添加标识。

  • 尾部:由于已经进行到当前层的最后一个节点,因此输出只需要向下一层进行传递而不用向下一个节点传递,添加标识也是可以的。

配置Tensorflow的GPU版本

这一篇写的比较好,我自己的硬件环境如下图所示,需要的可以借鉴一下,当然也可以在我提供的代码链接直接用我给的environment.yml一键构建环境😃。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1698938.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构:树(3)【二叉树链式结构实现】【二叉树的前序,中序,后序遍历】【求二叉树全部结点个数】【求二叉树叶子结点个数】【求二叉树的深度】【单值二叉树】

一.二叉树链式结构的实现 二叉树的链式结构的实现相对于顺序结构的实现就没有那么多的讲究了。就是普通的链表,只不过多了一个指向的指针。 具体结构如下: typedef int BTDataType; typedef struct BinaryTreeNode {BTDataType data;struct BinaryTre…

云计算和大数据处理

文章目录 1.云计算基础知识1.1 基本概念1.2 云计算分类 2.大数据处理基础知识2.1 基础知识2.3 大数据处理技术 1.云计算基础知识 1.1 基本概念 云计算是一种提供资源的网络,使用者可以随时获取“云”上的资源,按需求量使用,并且可以看成是无…

一.架构设计

架构采用 ddd 架构,不同于传统简单的三层的架构,其分层的思想对于大家日后都是很有好处的,会给大家的思想层级,提高很多。 传统的项目 现有的架构 采取ddd架构,给大家在复杂基础上简化保留精髓,一步步进行…

以太坊钱包

以太坊钱包是你通往以太坊系统的门户。它拥有你的密钥,并且可以代表你创建和广播交易。选择一个以太坊钱包可能很困难,因为有很多不同功能和设计选择。有些更适合初学者,有些更适合专家。即使你现在选择一个你喜欢的,你可能会决定…

游戏找不到d3dcompiler_43.dll怎么办,教你5种可靠的修复方法

在电脑使用过程中,我们经常会遇到一些错误提示,其中之一就是“找不到d3dcompiler43.dll”。这个问题通常出现在游戏或者图形处理软件中,它会导致程序无法正常运行。为了解决这个问题,我经过多次尝试和总结,找到了以下五…

粤嵌—2024/5/23—不同路径 ||(✔)

代码实现&#xff1a; int uniquePathsWithObstacles(int **obstacleGrid, int obstacleGridSize, int *obstacleGridColSize) {int x obstacleGridSize, y obstacleGridColSize[0];int dp[x][y];memset(dp, 0, sizeof(int) * x * y);for (int j 0; j < y && obs…

OpenUI 可视化 AI:打造令人惊艳的前端设计!

https://openui.fly.dev/ai/new 可视化UI的新时代&#xff1a;通过人工智能生成前端代码 许久未更新, 前端时间在逛github&#xff0c;发现一个挺有的意思项目&#xff0c;通过口语化方式生成前端UI页面&#xff0c;能够直观的看到效果&#xff0c;下面来给大家演示下 在现代…

智能家居完结 -- 整体设计

系统框图 前情提要: 智能家居1 -- 实现语音模块-CSDN博客 智能家居2 -- 实现网络控制模块-CSDN博客 智能家居3 - 实现烟雾报警模块-CSDN博客 智能家居4 -- 添加接收消息的初步处理-CSDN博客 智能家居5 - 实现处理线程-CSDN博客 智能家居6 -- 配置 ini文件优化设备添加-CS…

复习java5.26

面向对象和面向过程 面向过程&#xff1a;把一个任务分成一个个的步骤&#xff0c;当要执行这个任务的时候&#xff0c;只需要依次调用就行了 面向对象&#xff1a;把构成任务的事件构成一个个的对象&#xff0c;分别设计这些对象&#xff08;属性和方法&#xff09;、然后把…

国内最受欢迎的7大API供应平台对比和介绍||电商API数据采集接口简要说明

本文将介绍7款API供应平台&#xff1a;聚合数据、百度APIStore、Apix、数说聚合、通联数据、HaoService、datasift 。排名不分先后&#xff01; 免费实用的API接口 第一部分 1、聚合数据&#xff08;API数据接口_开发者数据定制&#xff09; 2、百度API Store(API集市_APIStore…

数据安全不容小觑:.hmallox勒索病毒的防范与应对

一、引言 随着网络技术的飞速发展&#xff0c;网络安全问题日益凸显&#xff0c;其中勒索病毒作为一种极具破坏性的网络攻击手段&#xff0c;已在全球范围内造成了巨大的经济损失和社会影响。在众多勒索病毒中&#xff0c;.hmallox勒索病毒以其狡猾的传播方式和强大的加密能力…

27【Aseprite 作图】盆栽——拆解

1 橘子画法拆解 (1)浅色3 1 0;深色0 2 3 就可以构成一个橘子 (2)浅色 2 1;深色1 0 (小个橘子) (3)浅色 2 1 0;深色1 2 3 2 树根部分 (1)底部画一条横线 (2)上一行 左空2 右空1 【代表底部重心先在右】 (3)再上一行,左空1,右空1 (4)再上一行,左突出1,…

L01_JVM 高频知识图谱

这些知识点你都掌握了吗&#xff1f;大家可以对着问题看下自己掌握程度如何&#xff1f;对于没掌握的知识点&#xff0c;大家自行网上搜索&#xff0c;都会有对应答案&#xff0c;本文不做知识点详细说明&#xff0c;只做简要文字或图示引导。 类的生命周期 类加载器 JVM 的内存…

汽车IVI中控开发入门及进阶(二十):显示技术之LCDC

TFT LCD=Thin Film Transistor Liquid Crystal Display LCDC=LCD Controller 薄膜晶体管液晶显示器(TFT LCD)控制器在驱动现代显示技术的功能和性能方面起着关键作用。它们充当屏幕后面的大脑,仔细处理数字信号,并将其转化为精确的命令,决定每个像素的行为,决定它们的…

数据库系统概论(第5版)复习笔记

笔记的Github仓库地址 &#x1f446;这是笔记的gihub仓库&#xff0c;内容是PDF格式。 因为图片和代码块太多&#xff0c;放到CSDN太麻烦了&#xff08;比较懒&#x1f923;&#xff09; 如果感觉对各位有帮助的话欢迎点一个⭐\^o^/

Unity 生成物体的几种方式

系列文章目录 unity工具 文章目录 系列文章目录前言&#x1f449;一、直接new的方式创建生成1-1.代码如下1-2. 效果图 &#x1f449;二、使用Instantiate创建生成&#xff08;GameObject&#xff09;2-1.代码如下2-2.效果如下图 &#x1f449;三.系统CreatePrimitive创建生成3…

C++编程揭秘:虚表机制与ABI兼容性的实例剖析

前言&#xff1a; 假设你的应用程序引用的一个库某天更新了&#xff0c;虽然 API 和调用方式基本没变&#xff0c;但你需要重新编译你的应用程序才能使用这个库&#xff0c;那么一般说这个库是源码兼容&#xff08;Source compatible&#xff09;&#xff1b;反之&#xff0c;如…

BGP策略实验(路径属性和选路规则)

要求&#xff1a; 1、使用preval策略&#xff0c;确保R4通过R2到达192.168.10.0/24 2、使用AS Path策略&#xff0c;确保R4通过R3到达192.168.11.0/24 3、配置MED策略&#xff0c;确保R4通过R3到达192.168.12.0/24 4、使用Local Preference策略&#xff0c;确保R1通过R2到达19…

不同网段的通信过程

这里的AA和HH指的是mac地址&#xff0c;上面画的是路由器 底下的这个pc1&#xff0c;或者其他的连接在这里的pc&#xff0c;他们的默认网关就是路由器的这个192.168.1.1/24这个接口 来看看通信的过程 1、先判断&#xff08;和之前一样&#xff09; 2、去查默认网关&#xf…

【MySQL】库的基础操作

&#x1f30e;库的操作 文章目录&#xff1a; 库的操作 创建删除数据库 数据库编码集和校验集 数据库的增删查改       数据库查找       数据库修改 备份和恢复 查看数据库连接情况 总结 前言&#xff1a;   数据库操作是软件开发中不可或缺的一部分&#xff0…