深度学习Week13-火灾温度预测(LSTM)

news2024/9/21 12:35:15
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:第R2周:LSTM-火灾温度预测(训练营内部可读)
  • 🍖 作者:K同学啊

 任务说明:数据集中提供了火灾温度(Tem1)、一氧化碳浓度(CO 1)、烟雾浓度(Soot 1)随着时间变化数据,我们需要根据这些数据对未来某一时刻的火灾温度做出预测(本次任务仅供学习)

🍺要求:
1了解LSTM是什么,并使用其构建一个完整的程序
2R2达到0.83

🍻拔高:
1使用第1~8个时刻的数据预测第9~10个时刻的温度数据

一句话介绍LSTM,它是RNN的进阶版,如果说RNN的最大限度是理解一句话,那么LSTM的最大限度则是理解一段话,详细介绍如下:
LSTM,全称为长短期记忆网络(Long Short Term Memory networks),是一种特殊的RNN,能够学习到长期依赖关系。LSTM由Hochreiter & Schmidhuber (1997)提出,许多研究者进行了一系列的工作对其改进并使之发扬光大。LSTM在许多问题上效果非常好,现在被广泛使用。
所有的循环神经网络都有着重复的神经网络模块形成链的形式。在普通的RNN中,重复模块结构非常简单,其结构如下:
 

 LSTM避免了长期依赖的问题。可以记住长期信息!LSTM内部有较为复杂的结构。能通过门控状态来选择调整传输的信息,记住需要长时间记忆的信息,忘记不重要的信息,其结构如下:

一.前期准备工作

1.导入数据

数据地址:🔗百度网盘

import tensorflow as tf
import pandas     as pd
import numpy      as np

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")
print(gpus)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

顺便先导入数据

df_1 = pd.read_csv("woodpine2.csv")
df_1.head()

 2.数据可视化

import matplotlib.pyplot as plt
import seaborn as sns

plt.rcParams['savefig.dpi'] = 500 #图片像素
plt.rcParams['figure.dpi']  = 500 #分辨率

fig, ax =plt.subplots(1,3,constrained_layout=True, figsize=(14, 3))

sns.lineplot(data=df_1["Tem1"], ax=ax[0])
sns.lineplot(data=df_1["CO 1"], ax=ax[1])
sns.lineplot(data=df_1["Soot 1"], ax=ax[2])
plt.show()

 二、构建数据集

dataFrame=df_1.iloc[:,1:]
print(dataFrame)

 1.设置X,y

width_X=8
width_y=2

 取前8个时间段的Tem1、CO 1、Soot 1为X,而第9,10个时间段的Tem1为y。

X = []
y = []

in_start = 0

for _, _ in df_1.iterrows():
    in_end = in_start + width_X
    out_end = in_end + width_y

    if out_end < len(dataFrame):
        X_ = np.array(dataFrame.iloc[in_start:in_end, ])

        X_ = X_.reshape((len(X_) * 3))
        y_ = np.array(dataFrame.iloc[in_end:out_end, 0])

        X.append(X_)
        y.append(y_)

    in_start += 1

X = np.array(X)
y = np.array(y)

print(X.shape, y.shape)

 ((5938, 24), (5938, 2))

 2.归一化

from sklearn.preprocessing import MinMaxScaler

#将数据归一化,范围是0到1
sc       = MinMaxScaler(feature_range=(0, 1))
X_scaled = sc.fit_transform(X)
X_scaled.shape

(5939, 24)

X_scaled=X_scaled.reshape(len(X_scaled),width_X,3)
X_scaled.shape

 (5938, 8, 3)

 3.划分数据集

取5000之前的数据为训练集,5000之后的为验证集

X_train=X_scaled[:5000]
y_train=y[:5000]

X_test=X_scaled[5000:,]
y_test=y[5000:,]

X_train.shape

(5000, 8, 3)

 三.构建模型

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,Bidirectional
from tensorflow.keras        import Input

model_lstm = Sequential()
model_lstm.add(LSTM(units=64, activation='relu', return_sequences=True,
               input_shape=(X_train.shape[1], 3)))
model_lstm.add(LSTM(units=64, activation='relu'))

model_lstm.add(Dense(width_y))

WARNING:tensorflow:Layer lstm_8 will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU
WARNING:tensorflow:Layer lstm_9 will not use cuDNN kernel since it doesn't meet the cuDNN kernel criteria. It will use generic GPU kernel as fallback when running on GPU

四.模型训练

1.模型编译

#只观察loss数值,不观察准确率,所以删去metrics选项
model_lstm.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
                  loss='mean_squared_error')

from tensorflow.keras.callbacks import ModelCheckpoint

ModelCheckPointer=ModelCheckpoint('best_model.h5',
                                  monitor='val_loss',
                                  save_best_only=True,
                                  save_weights_only=True,               
                )
print(X_train.shape,y_train.shape)

(5000, 8, 3) (5000, 2)

history_lstm=model_lstm.fit(X_train,y_train,
                            batch_size=64,
                            epochs=50,
                            validation_data=(X_test,y_test),
                            validation_freq=1,
                            callbacks=[ModelCheckPointer])

 然后就是训练了。和以前训练差不多,不在赘述。

五.评估

1.loss图

# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

plt.figure(figsize=(5, 3),dpi=120)

plt.plot(history_lstm.history['loss']    , label='LSTM Training Loss')
plt.plot(history_lstm.history['val_loss'], label='LSTM Validation Loss')

plt.title('Training and Validation Loss')
plt.legend()
plt.show()

2.调用方模型进行预测

model_lstm.load_weights('best_model.h5')

predicted_y_lstm=model_lstm.predict(X_test)

y_test_one=[i[0] for i in y_test]
predicted_y_lstm_one=[i[0] for i in predicted_y_lstm]

y_test_two=[i[1] for i in y_test]
predicted_y_lstm_two=[i[1] for i in predicted_y_lstm]


fig, ax =plt.subplots(1,2,constrained_layout=True, figsize=(14, 3))
#画出第9个时间段真实数据与预测数据的对比图
ax[0].plot(y_test_one[:1000],color='red',label='真实值')
ax[0].plot(predicted_y_lstm_one[:1000],color='blue',label='预测值')

#画出第10个时间段真实数据与预测数据的对比图
ax[1].plot(y_test_two[:1000],color='red',label='真实值')
ax[1].plot(predicted_y_lstm_two[:1000],color='blue',label='预测值')

ax[0].set(xlabel='X',ylabel='Y',title='第9个时间段')
ax[1].set(xlabel='X',ylabel='Y',title='第10个时间段')

from sklearn import metrics
"""
RMSE :均方根误差  ----->  对均方误差开方
R2   :决定系数,可以简单理解为反映模型拟合优度的重要的统计量
"""
RMSE_lstm  = metrics.mean_squared_error(predicted_y_lstm, y_test)**0.5
R2_lstm    = metrics.r2_score(predicted_y_lstm, y_test)

print('均方根误差: %.5f' % RMSE_lstm)
print('R2: %.5f' % R2_lstm)

 均方根误差: 6.58734
R2: 0.85471

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/134853.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

c++知识点总结

文章目录1.引用2. 重载3. extern “C”4.构造函数5.析构函数6.类和对象7.面向对象模型8.继承9.多态10.函数模板11.类模板12.STL1.引用 不要返回局部变量的引用&#xff0c;调用函数执行后局部变量会销毁 2. 重载 是C多态的特性&#xff08;静态多态&#xff09;。同一个函数名代…

正点原子STM32(基于HAL库)

正点原子B站视频地址&#xff1a;https://www.bilibili.com/video/BV1bv4y1R7dp?p1&vd_sourcecc0e43b449de7e8663ca1f89dd5fea7d 目录单片机简介Cortex-M介绍初识STM32STM32命名规则STM32 选型STM32 设计数据手册最小系统IO 分配STM32启动过程分析启动模式启动文件分析启动…

基于彩色的图像分割

图像分割就是把图像分成若干个特定的、具有独特性质的区域并提出感兴趣目标的技术和过程。它是由图像处理到图像分析的关键步骤。现有的图像分割方法主要分以下几类&#xff1a;基于阈值的分割方法、基于区域的分割方法、基于边缘的分割方法以及基于特定理论的分割方法等。从数…

小程序03/ uni-app认识目录结构 、uni-app应用生命周期 和 生命周期钩子、uni-app页面生命周期 和 生命周期钩子

一.uni-app认识目录结构 二.uni-app应用生命周期 和 生命周期钩子 位置: uni-app 在App.vue中监听 在页面监听无效 说明: App.vue是uni-app的主组件 所有页面都是在App.vue 下进行切换的、是页面入口文件 但是App.vue 本身不是页面 这里不能编写视图元素 也就是没有<tmpl…

Git(一) - Git 概述

01_尚硅谷_Git_课程介绍_哔哩哔哩_bilibili Git 是一个免费的、开源的分布式版本控制系统&#xff0c;可以高效的处理从小型到大型的各种项目。 一、何为版本控制 版本控制是一种记录文件内容变化&#xff0c;以便将来查阅特定版本修订情况的系统。 版本控制其实最主要的是可以…

java 瑞吉外卖优化day2 Nginx

Nginx概述 下载与安装 可以到Nginx官方网站下载Nginx的安装包&#xff0c;地址为&#xff1a;https://nginx.org/en/download.html 安装过程&#xff1a; 1、安装依赖包 yum -y install gcc pcre-devel zlib-devel openssl openssl-devel 先yum install wget &#xff0c;…

随机数 - 时间种子的方案与实践

1.应用场景 主要弄清楚设置随机数种子的方法&#xff0c;可用于游戏开发当中的时间种子从而产生合理的随机数&#xff0c;避免出现bug。 2.学习/操作 1.文档阅读 07 | 带你快速上手 Lua-极客时间 2.整理输出 2.1 什么是种子 现在很多朋友下载东西时都会用bt种子文件&#xff0…

5分钟带你学习 linux 收发邮件步骤详解 at命令详解 crontab命令详解 附加at crontab命令练习

linux 收发邮件步骤详解 1.安装软件yum install mailx -yyum install sendmail -y 2.启动服务sendmailsystemctl start sendmail 3.更改配置vim /etc/mail.rc at命令详解 实例&#xff1a; crontab命令详解 实例&#xff1a; linux 收发邮件步骤详解 1.安装软件 yum…

拜占庭将军问题

前言 在分布式系统中交换信息, 部分成员可能出错导致发送了错误的信息, 在这种情况下如何达成共识. 这就是拜占庭将军问题所要解决的. 问题的简要描述如下: 3个军队协同作战(为了简单易懂, 以3个军队描述)每支军队的作战策略有两种"进攻"和"撤退"每个军…

SparkLaunch提交Spark任务到Yarn集群

SparkLaunch提交任务1.提交Spark任务的方式2.SparkLaunch 官方接口3.任务提交流程及实战1.提交Spark任务的方式 通过Spark-submit 提交任务通过Yarn REST Api提交Spark任务通过Spark Client Api 的方式提交任务通过SparkLaunch 自带API提交任务基于Livy的方式提交任务&#xf…

深拷贝浅拷贝的区别?如何实现一个深拷贝

一、数据类型存储 前面文章我们讲到&#xff0c;JavaScript中存在两大数据类型&#xff1a; 基本类型引用类型 基本类型数据保存在在栈内存中 引用类型数据保存在堆内存中&#xff0c;引用数据类型的变量是一个指向堆内存中实际对象的引用&#xff0c;存在栈中 二、浅拷贝…

【2】SCI易中期刊推荐——遥感图像领域(2区)

🚀🚀🚀NEW!!!SCI易中期刊推荐栏目来啦 ~ 📚🍀 SCI即《科学引文索引》(Science Citation Index, SCI),是1961年由美国科学信息研究所(Institute for Scientific Information, ISI)创办的文献检索工具,创始人是美国著名情报专家尤金加菲尔德(Eugene Garfield…

2022年最有开创性的10篇AI论文总结

2022年随着聊天GPT和Mid - journey和Dall-E等图像生成器的流行&#xff0c;我们看到了整个人工智能领域的重大进展。在人工智能和计算机科学的时代&#xff0c;这是令人振奋的一年。本文我们总结了在2022年发表的最具开创性的10篇论文&#xff0c;无论如何你都应该看看。 1、Al…

Apache Calcite初识

Calcite原理和代码讲解(一) https://blog.csdn.net/qq_35494772/article/details/118887267quickstart:Apache Calcite精简入门与学习指导 https://blog.51cto.com/xpleaf/2639844quickstart:多源数据的关联 csv和mem数据类型 https://cloud.tencent.com/developer/article/162…

【Javassist】快速入门系列14 使用Javassist导入包路径

系列文章目录 01 在方法体的开头或结尾插入代码 02 使用Javassist实现方法执行时间统计 03 使用Javassist实现方法异常处理 04 使用Javassist更改整个方法体 05 当有指定方法调用时替换方法调用的内容 06 当有构造方法调用时替换方法调用的内容 07 当检测到字段被访问时使用语…

CSS复习(一)

CSS复习1.前言2. CSS介绍2.1 CSS的引入方式2.2 选择器2.2 颜色的赋值方式3. 补充4.display4.1 盒子模型4.1.1 盒子模型之宽高盒子模型之外边距盒子模型之边框盒子模型之内边距4.2 文本问题1.前言 首先补充一下部分相关知识&#xff1a; 分区标签自身没有显示效果&#xff0c;…

【算法】kmp、Trie、并查集、堆

文章目录1.kmp2.Trie3.并查集4.堆1.kmp KMP 的精髓就是 next 数组&#xff1a;也就是用 next[j] k;简单理解就是&#xff1a;来保存子串某个位置匹配失败后&#xff0c;回退的位置。 给定一个字符串 S&#xff0c;以及一个模式串 P&#xff0c;所有字符串中只包含大小写英文字…

大文件上传如何做断点续传

大文件上传如何做断点续传 一、是什么 不管怎样简单的需求&#xff0c;在量级达到一定层次时&#xff0c;都会变得异常复杂 文件上传简单&#xff0c;文件变大就复杂 上传大文件时&#xff0c;以下几个变量会影响我们的用户体验 服务器处理数据的能力请求超时网络波动 上…

信息安全3——数字签名和认证

1 &#xff09;签名&#xff1a;手写签名是被签文件的物理组成部分&#xff0c;而数字签名不是被签消息的物理部分&#xff0c;因而需要将签名连接到被签消息上。 2 &#xff09;验证&#xff1a;手写签名是通过将它与其它真实的签名进行比较来验证而数字签名是利用已经公开的验…

年终总结(我心飞翔向)

2022 年度个人总结&#xff08;自由向&#xff09; 前奏 其实在2021年12月底考研前就回家了&#xff0c;回家做毕设。他们考研的那几天回了中北&#xff0c;参加了党支部会议&#xff0c;见证了一批同学的转预转正&#xff1b;收拾了一大波衣服&#xff0c;因为我已经提前想到…