LSTM网络:一种强大的时序数据建模工具

news2025/1/20 1:11:52

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

LSTM

(封面图由ERNIE-ViLG AI 作画大模型生成)

LSTM网络:一种强大的时序数据建模工具

在日常生活和工作中,我们经常会遇到各种类型的时序数据,如股票价格、天气数据、心电图、语音识别、自然语言处理等。这些数据具有时间依赖性,不同时间点的数据之间存在关联性。而LSTM网络是一种非常适合处理时序数据的神经网络,已经被广泛应用于各种任务中。本文将介绍LSTM网络的原理、优势和劣势,并结合代码和案例进行实践演示。

1. LSTM网络原理

LSTM(Long Short-Term Memory)网络是一种循环神经网络(Recurrent Neural Network, RNN)的变种。相比于传统的RNN,LSTM网络有着更强的长时记忆和远距离依赖处理能力,能够有效地避免梯度消失和梯度爆炸问题。

LSTM网络包括三个门控单元,分别是输入门(input gate)、遗忘门(forget gate)和输出门(output gate)。这些门控单元可以选择性地控制信息的流动,以达到记忆和遗忘的目的。除此之外,LSTM网络还有一个记忆单元(memory cell),用来存储长期的信息。

LSTM网络的计算过程可以分为以下几个步骤:

  • 输入门的计算:输入门决定了当前输入的信息在多大程度上被传递到记忆单元中。输入门的输出值为 i t i_t it,计算公式如下:
    i t = σ ( W i x t + U i h t − 1 + b i ) i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i) it=σ(Wixt+Uiht1+bi)
    其中, x t x_t xt 表示当前时刻的输入, h t − 1 h_{t-1} ht1 表示上一个时刻的隐藏状态, W i W_i Wi U i U_i Ui b i b_i bi 是可学习的参数, σ \sigma σ 是sigmoid函数。

  • 遗忘门的计算:遗忘门决定了哪些历史信息需要被遗忘。遗忘门的输出值为 f t f_t ft,计算公式如下:
    f t = σ ( W f x t + U f h t − 1 + b f ) f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f) ft=σ(Wfxt+Ufht1+bf)
    其中, W f W_f Wf U f U_f Uf b f b_f bf 是可学习的参数。

  • 记忆单元的更新:根据输入门的输出值和遗忘门的输出值,可以计算出当前时刻的记忆单元 C t C_t Ct,计算公式如下:
    tanh ⁡ ( W c x t + U c h t − 1 + b c ) \tanh(W_c x_t + U_c h_{t-1} + b_c) tanh(Wcxt+Ucht1+bc)
    其中, ⊙ \odot 表示逐元素乘积, W c W_c Wc U c U_c Uc b c b_c bc 是可学习的参数, tanh ⁡ \tanh tanh 是双曲正切函数。

  • 输出门的计算:输出门决定了当前时刻的输出值。输出门的输出值为 o t o_t ot,计算公式如下:
    o t = σ ( W o x t + U o h t − 1 + b o ) o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o) ot=σ(Woxt+Uoht1+bo)
    其中, W o W_o Wo U o U_o Uo b o b_o bo 是可学习的参数。

  • 隐藏状态的计算:根据当前时刻的记忆单元和输出门的输出值,可以计算出当前时刻的隐藏状态 h t h_t ht,计算公式如下:
    h t = o t ⊙ tanh ⁡ ( C t ) h_t = o_t \odot \tanh(C_t) ht=ottanh(Ct)
    LSTM网络通过这些门控单元的选择性连接,实现了对时序数据的长期依赖性建模。同时,由于LSTM网络中的梯度可以通过记忆单元从一层传递到另一层,可以有效地避免梯度消失和梯度爆炸问题,提高了训练效率和模型的准确性。

2. LSTM网络的优势和劣势

  • 优势:
    (1)长期依赖性建模能力强:LSTM网络具有很好的长期依赖性建模能力,能够很好地处理时序数据中的长期依赖关系。
    (2)避免梯度消失和梯度爆炸问题:LSTM网络中的梯度可以通过记忆单元从一层传递到另一层,可以有效地避免梯度消失和梯度爆炸问题。
    (3)可适应不同长度的时序数据:LSTM网络中的记忆单元可以自适应地存储不同长度的时序数据,不需要事先指定固定长度。

  • 劣势:
    (1)计算复杂度高:LSTM网络中有多个门控单元和记忆单元,计算复杂度较高,需要更多的计算资源。
    (2)需要大量的数据训练
    :LSTM网络具有很多可调参数,需要大量的数据进行训练,否则容易出现过拟合现象。

3. 案例演示

为了更好地理解LSTM网络的应用,本文选取了一个经典的时序数据建模问题:股票价格预测。我们将使用Keras深度学习框架,使用LSTM网络对股票价格进行预测。

(1)数据预处理

首先,我们需要对股票价格数据进行预处理。我们选择了纽约证券交易所上市的Apple公司(AAPL)的历史股票价格数据,该数据包含了从1980年到2021年的日交易数据。我们将使用前70%的数据作为训练集,后30%的数据作为测试集。

在预处理数据之前,我们需要导入相关的库:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler

接下来,我们加载股票价格数据,并按照训练集和测试集的比例进行拆分:

df = pd.read_csv('AAPL.csv')
df.head()

我们可以看到,数据包含日期、开盘价、最高价、最低价、收盘价、成交量和股票调整后的收盘价。我们只需要使用调整后的收盘价作为特征进行建模。

# 只使用调整后的收盘价作为特征
data = df.filter(['Adj Close']).values

# 拆分训练集和测试集
training_data_len = int(len(data) * 0.7)
train_data = data[0:training_data_len]
test_data = data[training_data_len:]

接下来,我们需要对数据进行归一化处理,使得所有特征都在0到1之间。

# 归一化处理
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_train_data = scaler.fit_transform(train_data)
scaled_test_data = scaler.transform(test_data)

(2)创建LSTM模型

接下来,我们需要创建LSTM模型。在Keras中,我们可以使用LSTM层来创建LSTM模型。首先,我们需要指定LSTM层中的参数,包括LSTM单元的数量、输入序列的长度和输出序列的长度。

from keras.models import Sequential
from keras.layers import LSTM, Dense

# 指定LSTM模型参数
lstm_units = 50
input_seq_len = 60
output_seq_len = 30

接下来,我们创建LSTM模型。模型包含一个LSTM层和一个全连接层。在LSTM层中,我们使用50个LSTM单元,输入序列的长度为60,输出序列的长度为30。在全连接层中,我们使用一个神经元作为输出层。

model = Sequential()
model.add(LSTM(units=lstm_units, input_shape=(input_seq_len, 1)))
model.add(Dense(units=1))
model.compile(loss='mean_squared_error', optimizer='adam')
model.summary()

(3)训练模型

接下来,我们需要训练模型。在训练模型之前,我们需要将训练数据划分成输入序列和输出序列。

def create_sequences(data, input_seq_len, output_seq_len):
    x = []
    y = []
    for i in range(len(data)-input_seq_len-output_seq_len+1):
        x.append(data[i:i+input_seq_len])
        y.append(data[i+input_seq_len:i+input_seq_len+output_seq_len])
   return np.array(x), np.array(y)
train_x, train_y = create_sequences(scaled_train_data, input_seq_len, output_seq_len)
test_x, test_y = create_sequences(scaled_test_data, input_seq_len, output_seq_len)

接下来,我们可以使用train_x和train_y训练模型:

history = model.fit(train_x, train_y, epochs=50, batch_size=32, validation_split=0.1, verbose=1)

我们可以使用Matplotlib绘制训练和验证损失的曲线:

```python
# 绘制训练和验证损失的曲线
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()

(4)模型预测

训练完成后,我们可以使用模型对测试集中的股票价格进行预测。由于我们使用了30个股票价格作为输出序列的长度,因此每次预测时,我们需要使用前60个价格作为输入序列。

def predict_future(model, data, input_seq_len, output_seq_len):
    predicted_data = []
    for i in range(len(data)-input_seq_len-output_seq_len+1):
        input_data = data[i:i+input_seq_len]
        predicted_seq = []
        for j in range(output_seq_len):
            predicted_price = model.predict(input_data.reshape((1, input_seq_len, 1)))[0][0]
            predicted_seq.append(predicted_price)
            input_data = np.append(input_data[1:], [[predicted_price]], axis=0)
        predicted_data.append(predicted_seq)
    return np.array(predicted_data)

predicted_data = predict_future(model, scaled_test_data, input_seq_len, output_seq_len)
predicted_data = scaler.inverse_transform(predicted_data.reshape((-1, output_seq_len)))
test_data = scaler.inverse_transform(test_y.reshape((-1, output_seq_len)))

接下来,我们可以绘制预测结果和实际结果的图表:

# 绘制预测结果和实际结果的图表
plt.figure(figsize=(10, 6))
plt.plot(test_data, label='Actual')
plt.plot(predicted_data.flatten(), label='Predicted')
plt.legend()
plt.show()

完整的代码如下所示:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import LSTM, Dense

# 加载股票价格数据
df = pd.read_csv('AAPL.csv')

# 只使用调整后的收盘价作为特征
data = df.filter(['Adj Close']).values

# 拆分训练集和测试集
training_data_len = int(len(data) * 0.7)
train_data = data[0:training_data_len]
test_data = data[training_data_len:]

# 归一化处理
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_train_data = scaler.fit_transform(train_data)
scaled_test_data = scaler.transform(test_data)

# 指定LSTM模型参数
lstm_units = 50
input_seq_len = 60
output_seq_len = 30

# 创建LSTM模型
model = Sequential()
model.add(LSTM(units=lstm_units, input_shape=(input_seq_len, 1)))
model.add(Dense(units=1))
model.compile(loss='mean_squared_error', optimizer='adam')
model.summary()
# 训练模型
history = model.fit(train_x, train_y, epochs=50, batch_size=32, validation_split=0.1, verbose=1)

# 绘制训练和验证损失的曲线
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()

# 使用模型预测股票价格
predicted_data = predict_future(model, scaled_test_data, input_seq_len, output_seq_len)
predicted_data = scaler.inverse_transform(predicted_data.reshape((-1, output_seq_len)))
test_data = scaler.inverse_transform(test_y.reshape((-1, output_seq_len)))
# 绘制预测结果和实际结果的图表
plt.figure(figsize=(10, 6))
plt.plot(test_data, label='Actual')
plt.plot(predicted_data.flatten(), label='Predicted')
plt.legend()
plt.show()

4. 公式推导

LSTM模型中的关键部分是门控单元,它能够控制信息的流动,从而实现长期依赖关系的捕捉。门控单元由三个部分组成:遗忘门、输入门和输出门。

遗忘门用于控制前一时刻的记忆细胞中的信息是否需要被遗忘,其公式为:

f t = σ ( W f [ h t − 1 , x t ] + b f ) f_t=\sigma(W_f[h_{t-1},x_t]+b_f) ft=σ(Wf[ht1,xt]+bf)

其中, h t − 1 h_{t-1} ht1为前一时刻的隐藏状态, x t x_t xt为当前时刻的输入, W f W_f Wf b f b_f bf为遗忘门的权重和偏置, σ \sigma σ为sigmoid函数。

输入门用于控制当前时刻输入信息的权重,其公式为:

i t = σ ( W i [ h t − 1 , x t ] + b i ) i_t=\sigma(W_i[h_{t-1},x_t]+b_i) it=σ(Wi[ht1,xt]+bi)

其中, W i W_i Wi b i b_i bi为输入门的权重和偏置。

记忆细胞的更新通过下面的公式实现:

C t = f t ⊙ C t − 1 + i t ⊙ tanh ⁡ ( W c [ h t − 1 , x t ] + b c ) C_t=f_t\odot C_{t-1}+i_t\odot \tanh(W_c[h_{t-1},x_t]+b_c) Ct=ftCt1+ittanh(Wc[ht1,xt]+bc)

其中, ⊙ \odot 表示元素乘积, W c W_c Wc b c b_c bc为更新记忆细胞的权重和偏置, tanh ⁡ \tanh tanh表示双曲正切函数。

输出门用于控制输出信息的权重,其公式为:

o t = σ ( W o [ h t − 1 , x t ] + b o ) o_t=\sigma(W_o[h_{t-1},x_t]+b_o) ot=σ(Wo[ht1,xt]+bo)

h t h_t ht为当前时刻的隐藏状态,其计算公式为:

h t = o t ⊙ tanh ⁡ ( C t ) h_t=o_t\odot \tanh(C_t) ht=ottanh(Ct)

最终的预测结果通过连接输出层实现。


❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/397934.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

32位Ubuntu系统安装visual studio code

Step.01 下载vscode安装包 vscode自1.36版本后停止支持32位linux系统&#xff0c;所以要使用<1.36版本。1.33版本下载地址&#xff1a; Visual Studio Code March 2019See what is new in the Visual Studio Code March 2019 Release (1.33)https://code.visualstudio.com…

nvm的使用

nvm工具 nvm是什么nvm下载与安装nvm的基本使用 1、nvm介绍 1.1、基于node的开发 在介绍nvm之前&#xff0c;先介绍下前端开发中关于node的使用。目前前端不管是基于vue或者react框架的开发&#xff0c;都是基于node环境下&#xff0c;进行包的管理与开发的。而不同项目组&a…

work-notes(23):结合typora、git、gitee实现云存储笔记完成的操作过程

时间&#xff1a;2023-03-07 文章目录摘要一、下载 typora二、安装 Git三、创建连接远程仓库四、使用 Git 上传到远程仓库五、到gitee上查看总结摘要 由于很想找一个好用&#xff0c;又有云存储的笔记软件。之前用过 有道笔记&#xff08;还行&#xff0c;量大了难找&#xff…

「MySQL进阶」为什么MySQL用B+树做索引而不用二叉查找树、平衡二叉树、B树

「MySQL进阶」为什么MySQL用B树做索引而不用二叉查找树、平衡二叉树、B树 文章目录「MySQL进阶」为什么MySQL用B树做索引而不用二叉查找树、平衡二叉树、B树一、概述二、二叉查找树三、平衡二叉树四、B树五、B树六、聚集索引和非聚集索引七、利用聚集索引和非聚集索引查找数据利…

剑指 Offer 67 把字符串转换成整数

摘要 面试题67. 把字符串转换成整数 一、字符串解析 根据题意&#xff0c;有以下四种字符需要考虑&#xff1a; 首部空格&#xff1a; 删除之即可&#xff1b;符号位&#xff1a;三种情况&#xff0c;即 , − , 无符号"&#xff1b;新建一个变量保存符号位&#xff0…

螯合剂p-SCN-Bn-TCMC,282097-63-6,双功能配体化合物应用于光学成像应用

p-SCN-Bn-TCMC 反应特点&#xff1a;p-SCN-Bn-TCMC属于双功能配体是螯合剂&#xff0c;也具有共价连接到生物靶向载体&#xff08;如抗体、肽和蛋白质&#xff09;的反应位点。应用于核医学、MRI和光学成像应用。西安凯新生物科技有限公司供应的杂环化合物及其衍生物可制作为具…

消息队列理解

为什么使用消息队列 使⽤消息队列主要是为了&#xff1a; 减少响应所需时间和削峰。降低系统耦合性&#xff08;解耦/提升系统可扩展性&#xff09;。 当我们不使⽤消息队列的时候&#xff0c;所有的⽤户的请求会直接落到服务器&#xff0c;然后通过数据库或者 缓存响应。假…

GPU是什么

近期ChatGPT十分火爆&#xff0c;随之而来的是M国开始禁售高端GPU显卡。M国想通过禁售GPU显卡的方式阻挡中国在AI领域的发展。 GPU是什么&#xff1f;GPU&#xff08;英语&#xff1a;Graphics Processing Unit&#xff0c;缩写&#xff1a;GPU&#xff09;是显卡的“大脑”&am…

给比特币“雕花” 增值还是累赘?

比特币网络也能发NFT了&#xff0c;大玩家快速入场。3月6日&#xff0c;Yuga Labs开启了TwelveFold拍卖会&#xff0c;该项目是Yuga Labs在比特币区块链网络上发行的首个NFT合集&#xff0c;内含300个艺术品。 在没有智能合约的比特币网络造NFT&#xff0c;没那么友好。但Web3…

Jmeter+Ant+Jenkins自动化搭建之报告优化

平台简介一个完整的接口自动化测试平台需要支持接口的自动执行&#xff0c;自动生成测试报告&#xff0c;以及持续集成。Jmeter支持接口的测试&#xff0c;Ant支持自动构建&#xff0c;而Jenkins支持持续集成&#xff0c;所以三者组合在一起可以构成一个功能完善的接口自动化测…

概率论与数理统计相关知识

本博客为《概率论与数理统计&#xff0d;&#xff0d;茆诗松&#xff08;第二版&#xff09;》阅读笔记&#xff0c;目的是查漏补缺前置知识数学符号连乘符号&#xff1a;&#xff1b;总和符号&#xff1a;&#xff1b;“任意”符号&#xff1a;∀&#xff1b;“存在”符号&…

IDEA项目中配置Maven镜像源(下载源)

目录前言一、IDEA中Maven的位置二、修改Maven的配置文件2.1 配置文件2.2 修改镜像源三、在IDEA中使配置文件生效四、配置文件和本地仓库迁移前言 在使用IDEA搭建项目的过程中&#xff0c;我们发现框架的jar包下载非常缓慢&#xff0c;这是因为国内访问Maven仓库速度较低&#…

构建GRE隧道打通不同云商的云主机内网

文章目录1. 环境介绍2 GRE隧道搭建2.1 华为云 GRE 隧道安装2.2 阿里云 GRE 隧道安装3. 设置安全组4. 验证GRE隧道4.1 在华为云上 ping 阿里云云主机内网IP4.2 在阿里云上 ping 华为云云主机内网IP5. 总结1. 环境介绍 华为云上有三台云主机&#xff0c;内网 CIDR 是 192.168.0.0…

TensoRT8.4_cuda11.6 sampleOnnxMNIST运行生成

1、版本信息 win10电脑环境&#xff1a; TensorRT:8.4.1.5CUDA: 11.6VS: 2019 环境安装成功后&#xff0c;使用sampleOnnxMNIST测试 2、VS2019环境配置 用vs打开sampleOnnxMNIST项目&#xff0c;位置在 D:\TensorRT-8.4.1.5\samples\sampleOnnxMNIST &#xff08;1&#xf…

创建SpringBoot工程详细步骤

new新建一个项目选择Spring Initializr, 然后配置一下地址, 可以如下图使用阿里云的,(因为国外的Spring官网可能不稳定) 下面这三个地址(选一个)能用的用上就行 https://start.spring.io(默认) https://start.springboot.io https://start.aliyun.com 然后 然后点击Finish…

HarmonyOS/OpenHarmony应用开发-dataUriUtils的使用

模块导入接口详情 dataUriUtils.getId getId(uri: string): number 获取附加到给定uri的路径组件末尾的ID。 参数&#xff1a; 名称 类型 必填 描述 uri string 是 指示要从中获取ID的uri对象。 dataUriUtils.attachId attachId(uri: string, id: number): string …

上班三年,薪资还赶不上应届程序员的一半奖金?

工资的鸿沟&#xff0c;始于社会分工的出现和细化。打工人行走职场&#xff0c;你是否也经历过&#xff1a;卷也卷不赢&#xff0c;躺也躺不平的45人生&#xff01;不同打工人分工提升了社会生产的效率&#xff0c;也加速了社会财富的积累&#xff0c;更提高了人们的收入水平。…

Zookeeper特性和节点数据类型详解

什么是ZK&#xff1f; zk,分布式应用协调框架&#xff0c;Apache Hadoop的一个子项目&#xff0c;解决分布式应用中遇到的数据管理问题。 可以理解为存储少量数据基于内存的数据库。两大核心&#xff1a;文件系统存储结构 和 监听通知机制。 文件系统存储结构 文件目录以 / …

Pytorch深度学习与入门实战

Pytorch深度学习入门与实战Pytorch简介Pytorch特点PyTorch安装环境要求PyTorch兼容的Python版本搭建开发环境下载Miniconda![下载miniconda](https://img-blog.csdnimg.cn/adace1a2f7ae476aa883b53203477c92.pnPytorch官网地址GPU版本安装检查显卡驱动依赖库安装机器学习基础与…

【备战面试】TCP的三次握手与四次挥手

本篇总结的是计算机网络知识相关的面试题&#xff0c;后续也会更新其他相关内容 文章目录1、TCP头部结构2、三次握手3、四次挥手4、为什么TCP连接的时候是三次&#xff1f;两次是否可以&#xff1f;5、为什么TCP连接的时候是三次&#xff0c;关闭的时候却是四次&#xff1f;6、…