用Python实现时间序列模型实战——Day 22: LSTM 与 RNN 模型

news2024/9/20 15:44:56
一、学习内容
1. 长短期记忆网络 (LSTM) 的原理

LSTM(长短期记忆网络) 是一种专门用于处理时间序列数据的神经网络,它克服了传统 RNN 在处理长序列时出现的梯度消失问题。LSTM 通过引入 记忆单元门控机制(输入门、遗忘门、输出门)来选择性地保留或遗忘信息,从而更好地捕捉长期依赖性。

LSTM的主要公式

  • 遗忘门:决定当前单元状态应该遗忘多少过去的信息。

    f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)
  • 输入门:决定将多少新的信息写入到细胞状态。

    i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)
  • 候选记忆状态:将当前输入与过去的信息结合,用于更新细胞状态。

    \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)
  • 输出门:决定输出多少细胞状态中的信息作为当前时刻的隐藏状态。

    o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)
  • 细胞状态更新

    C_t = f_t \cdot C_{t-1} + i_t \cdot \tilde{C}_t
  • 隐藏状态更新

    h_t = o_t \cdot \tanh(C_t)
2. 循环神经网络 (RNN) 的应用

RNN(循环神经网络) 是处理序列数据的基本模型,它通过连接到自身的隐状态,捕捉序列中的时间依赖性。然而,RNN 存在梯度消失问题,导致其在处理长序列时效果较差。

3. LSTM 与 RNN 的超参数调优
  • 学习率:控制模型更新的步长,过大或过小都会影响模型的训练效果。
  • 隐藏层神经元数量:影响模型的表达能力,神经元越多模型越复杂。
  • 序列长度:输入到模型的时间步长,过短可能无法捕捉长依赖,过长则增加计算复杂度。
  • 优化器:常用的优化器有 Adam、RMSprop 等,选择合适的优化器能够加速训练过程。
二、实战案例

我们将使用 tensorflow 构建 LSTM 和 RNN 模型,对时间序列数据进行预测。代码如下:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM, SimpleRNN

# 1. 数据加载与预处理
url = 'https://raw.githubusercontent.com/jbrownlee/Datasets/master/airline-passengers.csv'
data = pd.read_csv(url, header=0, parse_dates=['Month'], index_col='Month')

# 数据归一化处理
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(data[['Passengers']])

# 生成输入序列
def create_sequences(data, seq_length):
    X, y = [], []
    for i in range(len(data) - seq_length):
        X.append(data[i:i+seq_length])
        y.append(data[i+seq_length])
    return np.array(X), np.array(y)

seq_length = 10
X, y = create_sequences(scaled_data, seq_length)

# 将数据集分为训练集和测试集
train_size = int(len(X) * 0.8)
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]

# 2. LSTM 模型构建
lstm_model = Sequential()
lstm_model.add(LSTM(50, activation='relu', input_shape=(X_train.shape[1], X_train.shape[2])))
lstm_model.add(Dense(1))
lstm_model.compile(optimizer='adam', loss='mse')

# 训练 LSTM 模型
lstm_model.fit(X_train, y_train, epochs=100, batch_size=32, verbose=0)

# 进行预测
lstm_preds = lstm_model.predict(X_test)
lstm_preds_rescaled = scaler.inverse_transform(lstm_preds)

# 3. RNN 模型构建
rnn_model = Sequential()
rnn_model.add(SimpleRNN(50, activation='relu', input_shape=(X_train.shape[1], X_train.shape[2])))
rnn_model.add(Dense(1))
rnn_model.compile(optimizer='adam', loss='mse')

# 训练 RNN 模型
rnn_model.fit(X_train, y_train, epochs=100, batch_size=32, verbose=0)

# 进行预测
rnn_preds = rnn_model.predict(X_test)
rnn_preds_rescaled = scaler.inverse_transform(rnn_preds)

# 4. 结果可视化
plt.figure(figsize=(12, 6))
plt.plot(data.index[-len(y_test):], scaler.inverse_transform(y_test.reshape(-1, 1)), label='Actual Passengers')
plt.plot(data.index[-len(y_test):], lstm_preds_rescaled, label='LSTM Predictions')
plt.plot(data.index[-len(y_test):], rnn_preds_rescaled, label='RNN Predictions')
plt.title('LSTM vs RNN Predictions on Airline Passengers Data')
plt.xlabel('Date')
plt.ylabel('Number of Passengers')
plt.legend()
plt.grid(True)
plt.show()
三、代码解释
3.1 数据加载与预处理
  • 使用航空乘客数据集,并对数据进行归一化处理,方便 LSTM 和 RNN 模型训练。
  • 使用 create_sequences 函数生成输入序列和对应的目标值,时间步长为10。
3.2 LSTM 模型构建
  • LSTM 模型中使用50个隐藏层神经元,激活函数为 relu,并使用 Adam 优化器和均方误差损失函数。
  • 模型训练了100个 epochs。
3.3 RNN 模型构建
  • RNN 模型中使用简单循环单元(SimpleRNN),结构与 LSTM 类似。
3.4预测与结果可视化
  • 将模型的预测结果与真实的乘客数量进行对比,绘制图形。
四、结果输出

五、结果分析
  • 5.1 LSTM 预测

    • LSTM 模型能够很好地捕捉时间序列中的长期依赖性,预测曲线与实际值较为接近。
  • 5.2 RNN 预测

    • RNN 模型虽然能够捕捉短期依赖性,但在长时间序列数据上表现较差,预测结果可能会出现偏差。
六、总结

通过本次案例,我们学习了如何构建 LSTM 和 RNN 模型进行时间序列预测。LSTM 由于其记忆门机制,能够较好地捕捉长序列中的模式,而传统的 RNN 在处理长序列时容易出现梯度消失问题。通过调节超参数(如隐藏层神经元数量、学习率等),可以进一步优化模型的预测性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2140327.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ruffle 继续在开源软件中支持 Adobe Flash Player

大多数人已经无需考虑对早已寿终正寝的 Adobe Flash 的支持,但对于那些仍有一些 Adobe Flash/SWF 格式的旧资产,或想重温一些基于 Flash 的旧游戏/娱乐项目的人来说,开源 Ruffle 项目仍是 2024 年及以后处理 Flash 的主要竞争者之一。 Ruffl…

【Hot100】LeetCode—4. 寻找两个正序数组的中位数

目录 1- 思路题目识别二分 2- 实现⭐4. 寻找两个正序数组的中位数——题解思路 3- ACM 实现 原题链接:4. 寻找两个正序数组的中位数 1- 思路 题目识别 识别1 :给定两个数组 nums1 和 nums2 ,找出数组的中位数 二分 思路 将寻找中位数 —…

Python数据分析案例59——基于图神经网络的反欺诈交易检测(GCN,GAT,GIN)

以前的数据分析案例的文章可以参考:数据分析案例 案例背景 以前二维的表格数据的机器学习模型都做烂了,[线性回归,惩罚回归,K近邻,决策树,随机森林,梯度提升,支持向量机,神经网络],还有现在常用的XGBoost,lightgbm,ca…

ffmpeg实现视频的合成与分割

视频合成与分割程序使用 作者开发了一款软件,可以实现对视频的合成和分割,界面如下: 播放时,可以选择多个视频源;在选中“保存视频”情况下,会将多个视频源合成一个视频。如果只取一个视频源中一段视频…

keil5进行stm32编程时常遇到的问题和ST-LINK在线仿真的连接问题

本文记录原因 最近一直在尝试usb的自定义键盘、无刷电机和pcb的一些东西,很久没使用stm32编写程序了。在浏览购物网站的时候发现很多便宜的小系统板。 使用小的系统板原因 1,在网上看到板子很便宜,以前很少看见,但现在网上对这…

大数据新视界 --大数据大厂之数据科学项目实战:从问题定义到结果呈现的完整流程

💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

uniapp 知识总结

1. uniapp 知识总结 uni-app是一个使用 Vue.js 开发所有前端应用的框架,开发者编写一套代码,可发布到iOS、Android、Harmony、Web(响应式)以及各种小程序(微信/支付宝/百度/头条/飞书/QQ/快手/钉钉/淘宝)、…

【webpack4系列】设计可维护的webpack4.x+vue构建配置(终极篇)

文章目录 构建配置包设计通过多个配置文件管理不同环境的 webpack 配置抽离成一个 npm 包统一管理(省略)通过 webpack-merge 组合配置 功能模块设计目录结构设计构建配置插件安装webpack、webpack-cli关联HTML插件html-webpack-plugin解析ES6解析vue、JS…

笔记本安装Linux系统向日葵远程控制

1、制作启动U盘 Ubuntu: Create a bootable USB stick with Rufus on Windows 2、安装 1、重启笔记本,出现logo后,按 f2(注:联想拯救者。其他型号参考官方文档)。按左右方向键切换到 Boot。选择 Boot Mo…

【软件测试】--xswitch将请求代理到测试桩

背景 在做软件测试的过程中,经常会遇见需要后端返回特定的响应数据,这个时候就需要用到测试桩,进行mock测试。 测试工程师在本地模拟后端返回数据时,需要将前端请求数据代理到本地,本文介绍xswitch插件代理请求到flas…

Float类型的有效位数有几位

大家好,今天我们来聊一聊C语言中的Float类型。 正如标题所说,你知道Float类型的有效位数有几位吗? 或者你知道为什么Float类型可以表示数字16777218但是却无法表示16777217吗? 如果你不是很确定那我们就一起来看看吧&#xff0…

AcWing算法基础课-789数的范围-Java题解

大家好,我是何未来,本篇文章给大家讲解《AcWing算法基础课》789 题——数的范围。本文详细解析了一个基于二分查找的算法题,题目要求在有序数组中查找特定元素的首次和最后一次出现的位置。通过使用两个二分查找函数,程序能够高效…

数据结构(Day13)

一、学习内容 内存空间划分 1、一个进程启动后,计算机会给该进程分配4G的虚拟内存 2、其中0G-3G是用户空间【程序员写代码操作部分】【应用层】 3、3G-4G是内核空间【与底层驱动有关】 4、所有进程共享3G-4G的内核空间,每个进程独立拥有0G-3G的用户空间 …

【C++】深入理解作用域和命名空间:从基础到进阶详解

🦄个人主页:小米里的大麦-CSDN博客 🎏所属专栏:C_小米里的大麦的博客-CSDN博客 🎁代码托管:C: 探索C编程精髓,打造高效代码仓库 (gitee.com) ⚙️操作环境:Visual Studio 2022 目录 一、前言 二、域的概念 1. 类域 2. 命名空间…

Redis——常用数据类型string

目录 常用数据结构(类型)Redis单线程模型Reids为啥效率这么高?速度这么快?(参照于其他数据库) stringsetgetMSET 和 MGETSETNX,SETEX,PSETEXincr,incrby,decr…

sshj使用代理连接服务器

之前我是用jsch连接服务器的,但是没办法使用私钥连接,搜了一下似乎是不支持新版的SSH-rsa,并且jsch很久没更新了,java - "com.jcraft.jsch.JSchException: Auth fail" with working passwords - Stack Overflow 没办法…

mybatis的基本使用与配置

注释很详细,直接上代码 项目结构 源码 UserMapper package com.amoorzheyu.mapper;import com.amoorzheyu.pojo.User; import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Select;import java.util.List;Mapper //在运行时生成代…

从数据仓库到数据中台再到数据飞轮:金融行业的数据技术进化史

前言​ 大家好,我是一名大数据开发工程师,在金融行业深耕多年,其实数据技术的演进不仅是技术层面的革新,更是业务模式与决策方式的深刻变革。从最开始的数据仓库兴起,到数据中台的普及,再到数据飞轮的出现…

MFEA/D-DRA--基于分解和动态资源分配的多目标多任务优化

MFEA/D-DRA–基于分解和动态资源分配的多目标多任务优化 title: A Multiobjective multifactorial optimization algorithm based on decomposition and dynamic resource allocation strategy author: Shuangshuang Yao, Zhiming Dong, Xianpeng Wang…

跨界融合,GIS如何赋能游戏商业——以《黑神话:悟空》为例

在数字化时代,地理信息系统(GIS)技术正以其独特的空间分析和可视化能力,为游戏产业带来革命性的变革。《黑神话:悟空》作为中国首款3A级别的动作角色扮演游戏,不仅在游戏设计和技术上取得了突破&#xff0c…