用Python实现时间序列模型实战——Day 23: LSTM 与 RNN 模型的深入学习

news2024/12/23 8:15:42
一、学习内容
1. 深入理解 LSTM 和 RNN 模型的工作原理

LSTM 和 RNN 模型都擅长处理时间序列数据,但它们在处理长序列时遇到了一些问题,比如 梯度消失梯度爆炸。LSTM 通过 门控机制 改进了传统 RNN 的缺陷,但在处理非常长的序列时仍可能遇到效率和性能问题。

2. 常见问题及解决方法
  • 梯度消失:随着序列长度增加,反向传播时梯度逐渐变小,模型难以学习远端依赖关系。
  • 长序列建模:LSTM 可以捕捉较长序列的依赖关系,但如果序列过长,LSTM 也会遇到性能瓶颈。
3. 高级技巧优化 LSTM 和 RNN 模型
  • 双向 LSTM (Bidirectional LSTM)

    双向 LSTM 是一种改进的模型,它不仅考虑过去的状态,还同时考虑未来的状态。通过双向遍历序列,双向 LSTM 更好地捕捉全局信息。

  • 堆叠 LSTM (Stacked LSTM)

    堆叠 LSTM 是指将多层 LSTM 堆叠在一起,增强模型的表达能力,适合处理复杂的时间序列问题。

  • 注意力机制 (Attention Mechanism)

    注意力机制通过赋予输入序列中不同位置的权重,使得模型能够更加关注关键时间步长,特别适合处理长序列任务。

二、实战案例

我们将使用双向 LSTM 和堆叠 LSTM 来对时间序列数据进行建模。数据仍然使用航空乘客数据集。Python 代码如下:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM, Bidirectional
from tensorflow.keras.layers import Attention, TimeDistributed, RepeatVector

# 1. 数据加载与预处理
url = 'https://raw.githubusercontent.com/jbrownlee/Datasets/master/airline-passengers.csv'
data = pd.read_csv(url, header=0, parse_dates=['Month'], index_col='Month')

# 数据归一化处理
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(data[['Passengers']])

# 生成输入序列
def create_sequences(data, seq_length):
    X, y = [], []
    for i in range(len(data) - seq_length):
        X.append(data[i:i+seq_length])
        y.append(data[i+seq_length])
    return np.array(X), np.array(y)

seq_length = 10
X, y = create_sequences(scaled_data, seq_length)

# 将数据集分为训练集和测试集
train_size = int(len(X) * 0.8)
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]

# 2. 双向 LSTM 模型
bidirectional_model = Sequential()
bidirectional_model.add(Bidirectional(LSTM(50, activation='relu'), input_shape=(X_train.shape[1], X_train.shape[2])))
bidirectional_model.add(Dense(1))
bidirectional_model.compile(optimizer='adam', loss='mse')

# 训练双向 LSTM 模型
bidirectional_model.fit(X_train, y_train, epochs=100, batch_size=32, verbose=0)

# 进行预测
bi_preds = bidirectional_model.predict(X_test)
bi_preds_rescaled = scaler.inverse_transform(bi_preds)

# 3. 堆叠 LSTM 模型
stacked_lstm_model = Sequential()
stacked_lstm_model.add(LSTM(50, activation='relu', return_sequences=True, input_shape=(X_train.shape[1], X_train.shape[2])))
stacked_lstm_model.add(LSTM(50, activation='relu'))
stacked_lstm_model.add(Dense(1))
stacked_lstm_model.compile(optimizer='adam', loss='mse')

# 训练堆叠 LSTM 模型
stacked_lstm_model.fit(X_train, y_train, epochs=100, batch_size=32, verbose=0)

# 进行预测
stacked_preds = stacked_lstm_model.predict(X_test)
stacked_preds_rescaled = scaler.inverse_transform(stacked_preds)

# 4. 结果可视化
plt.figure(figsize=(12, 6))
plt.plot(data.index[-len(y_test):], scaler.inverse_transform(y_test.reshape(-1, 1)), label='Actual Passengers')
plt.plot(data.index[-len(y_test):], bi_preds_rescaled, label='Bidirectional LSTM Predictions')
plt.plot(data.index[-len(y_test):], stacked_preds_rescaled, label='Stacked LSTM Predictions')
plt.title('Bidirectional LSTM vs Stacked LSTM Predictions')
plt.xlabel('Date')
plt.ylabel('Number of Passengers')
plt.legend()
plt.grid(True)
plt.show()
三、代码解释
3.1 数据预处理
  • 使用航空乘客数据集,生成了时间步长为10的输入序列。
  • 数据集被分为训练集和测试集,80%用于训练,20%用于测试。
3.2 双向 LSTM 模型
  • 双向 LSTM 模型通过同时从过去和未来两个方向来学习时间序列数据,从而获得更好的预测效果。
3.3 堆叠 LSTM 模型
  • 堆叠 LSTM 模型使用了两层 LSTM,第一层设置 return_sequences=True,以保证输出的序列可以传递到下一层 LSTM。
3.4 预测与可视化
  • 预测结果使用 inverse_transform 还原到原始数据范围,并与真实值进行对比。通过可视化图表,可以直观地比较双向 LSTM 和堆叠 LSTM 的预测效果。
四、结果输出

五、结果分析
5.1 双向 LSTM 预测结果
  • 双向 LSTM 模型通过从序列的两个方向进行学习,可以更好地捕捉到全局的模式,因此在一些复杂的时间序列任务中可能具有优势。
5.2 堆叠 LSTM 预测结果
  • 堆叠 LSTM 模型通过多层 LSTM 的堆叠,增强了模型的表达能力,可以处理更加复杂的时间依赖关系。
六、总结

通过本次案例,我们深入了解了 LSTM 和 RNN 模型的高级优化技巧。双向 LSTM 模型通过从过去和未来两个方向同时进行学习,增强了模型的全局感知能力,而堆叠 LSTM 模型则通过多层堆叠提升了模型的复杂性和表达能力。实际预测效果根据数据和任务的不同可能有所变化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2141592.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java浅,深拷贝;内,外部类的学习了解

目录 浅拷贝 深拷贝 内部类 匿名内部类 实例内部类 静态内部类 外部类 浅拷贝 简单理解:定义了A,A里面有age和num,拷贝成为B,B里面有age和num package demo1浅克隆和深克隆;//interfaces 是定义了一个接口//implements是使…

火语言RPA流程组件介绍--浏览器页面操作

🚩【组件功能】:浏览器页面前进,后退,刷新及停止等操作 配置预览 配置说明 丨操作类型 后退/前进/刷新 丨超时时间 支持T或# 输入仅支持整型 页面操作超时时间 丨执行后后等待时间(ms) 支持T或# 当前组件执行完成后继续等待…

Spring框架常见漏洞

文章目录 SpEL注入攻击Spring H2 Database Console未授权访问Spring Security OAuth2远程命令执行漏洞(CVE-2016-4977)Spring WebFlow远程代码执行漏洞(CVE-2017-4971)Spring Data Rest远程命令执行漏洞(CVE-2017-8046)Spring Messaging远程命令执行漏洞(CVE-2018-1270)Spring …

Python酷库之旅-第三方库Pandas(119)

目录 一、用法精讲 526、pandas.DataFrame.head方法 526-1、语法 526-2、参数 526-3、功能 526-4、返回值 526-5、说明 526-6、用法 526-6-1、数据准备 526-6-2、代码示例 526-6-3、结果输出 527、pandas.DataFrame.idxmax方法 527-1、语法 527-2、参数 527-3、…

C语言刷题日记(附详解)(5)

一、选填部分 第一题: 下面代码在64位系统下的输出为( ) void print_array(int arr[]) {int n sizeof(arr) / sizeof(arr[0]);for (int i 0; i < n; i)printf("%d", arr[i]); } int main() {int arr[] { 1,2,3,4,5 };print_array(arr);return 0; } A . 1…

vi | vim基本使用

vim三模式&#xff1a;① 输入模式 ②命令模式 ③末行模式&#xff08;编辑模式&#xff09; vim四模式&#xff1a;① 输入模式 ②命令模式 ③末行模式&#xff08;编辑模式&#xff09; ④V模式 一、命令模式进入输入模式方法&#xff1a; 二、命令模式基…

Hybrid接口的基础配置

Hybrid模式是交换机端口的一种配置模式&#xff0c;它允许端口同时携带多个VLAN&#xff08;虚拟局域网&#xff09;的流量。Hybrid端口可以指定哪些VLAN的数据帧被打上标签&#xff08;tagged&#xff09;和哪些VLAN的数据帧在发送时去除标签&#xff08;untagged&#xff09;…

828华为云征文|部署知识库问答系统 MaxKB

828华为云征文&#xff5c;部署知识库问答系统 MaxKB 一、Flexus云服务器X实例介绍1.1 云服务器介绍1.2 核心竞争力1.3 计费模式 二、Flexus云服务器X实例配置2.1 重置密码2.2 服务器连接2.3 安全组配置 三、部署 MaxKB3.1 MaxKB 介绍3.2 Docker 环境搭建3.3 MaxKB 部署3.4 Max…

Leetcode—322. 零钱兑换【中等】(memset(dp,0x3f, sizeof(dp))

2024每日刷题&#xff08;159&#xff09; Leetcode—322. 零钱兑换 算法思想 dp实现代码 class Solution { public:int coinChange(vector<int>& coins, int amount) {int m coins.size();int n amount;int dp[m 1][n 1];memset(dp, 0x3f, sizeof(dp));dp[0][…

基于springboot+vue+uniapp的驾校报名小程序

开发语言&#xff1a;Java框架&#xff1a;springbootuniappJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#…

使用随机森林模型在digits数据集上执行分类任务

程序功能 使用随机森林模型对digits数据集进行手写数字分类任务。具体步骤如下&#xff1a; 加载数据&#xff1a;从digits数据集中获取手写数字图片的特征和对应的标签。 划分数据&#xff1a;将数据集分为训练集和测试集&#xff0c;测试集占30%。 训练模型&#xff1a;使用…

鸿蒙开发笔记_电商严选02_登录页面跳转到我的页面、并传值

鸿蒙开发笔记整理,方便以后查阅! 由于上班较忙,只能抽空闲暇时间,快速整理更新中。。。 登录页面跳转到我的页面、并传值 效果图 我的设置页面 /*** 我的设置页面*/ import CommonConstants from ./CommonConstants import ItemData from ./ItemData import DataModel fr…

某个图形商标驳回,不建议做驳回复审!

近日一四川的网友联系到普推知产商标老杨&#xff0c;咨询看驳回的商标可以做驳回复审不&#xff0c;是个纯图形商标&#xff0c;这个一看是一标多类&#xff0c;就是在一个商标名称是申请两个类别&#xff0c;42类部分通过&#xff0c;35类全部驳回。 35类和42类引用的近似商标…

07_Python数据类型_集合

Python的基础数据类型 数值类型&#xff1a;整数、浮点数、复数、布尔字符串容器类型&#xff1a;列表、元祖、字典、集合 集合 集合&#xff08;set&#xff09;是Python中一个非常强大的数据类型&#xff0c;它存储的是一组无序且不重复的元素&#xff0c;集合中的元素必须…

SpringBoot 消息队列RabbitMQ死信交换机

介绍 生产者发送消息时指定一个时间&#xff0c;消费者不会立刻收到消息&#xff0c;而是在指定时间之后才收到消息。 死信交换机 当一个队列中的消息满足下列情况之一时&#xff0c;就会成为死信(dead letter) 消费者使用basic.reject或 basic.nack声明消费失败&#xff0…

LidarView之定制版本

介绍 LidarView软件定制开发需要关注几点&#xff1a;1.应用程序名称&#xff1b;2.程序logo&#xff1b;3.Application版本号&#xff1b;4.安装包版本号 应用程序名称 在项目的顶层cmake里边可以指定程序名称 project(LidarView)需要指定跟Superbuild一样的编译类型 set…

英语学习之fruit

目录 不熟悉熟悉 不熟悉 breadfruit 面包果 date 椰枣 raspberry 覆盆子 blackberry 黑莓 blackcurrant 黑加仑&#xff0c;黑醋栗 plum 李子 熟悉 apple 苹果&#x1f34e; coconut 椰子&#x1f965; banana 香蕉&#x1f34c; tomato 西红柿 pear 梨子 watermelon 西瓜…

30款免费好用的工具,打工人必备!

免费工具软件&#xff0c;办公人必备&#xff0c;提升工作效率 启动盘制作&#xff1a;Ventoype工具&#xff1a;微PEwindows/office jh工具&#xff1a;HEU KMS Activator桌面资料转移工具&#xff1a;个人资料专业工具右键菜单管理&#xff1a;ContextMenuManager驱动安装&a…

【面试八股总结】GMP模型

GMP概念 G&#xff08;Goroutine&#xff09;&#xff1a;代表Go协程&#xff0c;是参与调度与执行的最小单位。 存储Goroutine执行栈信息、状态、以及任务函数等。G的数量无限制&#xff0c;理论上只受内存的影响。Goroutines 是并发执行的基本单位&#xff0c;相比于传统的线…

虽难必学系列:Netty

Netty 是一个基于 Java 的高性能、异步事件驱动的网络应用框架&#xff0c;广泛用于构建各类网络应用&#xff0c;尤其是在高并发、低延迟场景下表现出色。作为一个开源项目&#xff0c;Netty 提供了丰富的功能&#xff0c;使得开发者可以轻松构建协议服务器和客户端应用程序。…