【Week-R2】使用LSTM实现火灾预测(tf版本)

news2024/11/18 19:51:57

【Week-R2】使用LSTM实现火灾预测(tf版本)

  • 一、 前期准备
    • 1.1 设置GPU
    • 1.2 导入数据
    • 1.3 数据可视化
  • 二、数据预处理(构建数据集)
    • 2.1 设置x、y
    • 2.2 归一化
    • 2.3 划分数据集
  • 三、模型创建、编译、训练、得到训练结果
    • 3.1 构建模型
    • 3.2 编译模型
    • 3.3 训练模型
    • 3.4 模型评估
      • 3.4.1 Loss与Accuracy图
      • 3.4.2 调用模型进行预测
      • 3.4.3 查看误差
  • 四、其他
    • 4.1 模块报错:seaborn模块导入错误
    • 4.2 图片实时显示比例不对
    • 4.3 什么是LSTM

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

在这里插入图片描述

一、 前期准备

语言环境:Python3.7.8
编译器选择:VSCode
深度学习环境:TensorFlow
数据集:本地数据集

1.1 设置GPU

本文使用CPU环境

'''
LSTM-实现火灾预测
'''

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]
    tf.config.experimental.set_memory_growth(gpu0,true)
    tf.config.set_visible_devices([gpu0],"GPU")
    print("GPU: ",gpus)
else:
    print("CPU:")

输出:
在这里插入图片描述

1.2 导入数据

下载数据集文件woodpine2.csv到本地,使用绝对路径进行访问:

# 2.1 导入数据
import  pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.read_csv("D:\\jupyter notebook\\DL-100-days\\RNN\\woodpine2.csv")
print("df:", df)

输出:
在这里插入图片描述

1.3 数据可视化

# 3.数据可视化 
plt.rcParams['savefig.dpi'] = 500
plt.rcParams['figure.dpi'] = 500
 
fig,ax = plt.subplots(1,3,constrained_layout = True , figsize = (14,3))
sns.lineplot(data=df["Tem1"],ax=ax[0])
sns.lineplot(data=df["CO 1"],ax=ax[1])
sns.lineplot(data=df["Soot 1"],ax=ax[2])
plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\3.数据可视化.png")
#plt.show()

输出:
在这里插入图片描述

二、数据预处理(构建数据集)

# 二、数据预处理(构建数据集)
dataFrame = df.iloc[:,1:]
print("dataFrame:", dataFrame)

输出:
在这里插入图片描述

2.1 设置x、y

# 1,设置x、y
# 需要实现:使用1-8时刻段预测9时刻段,则通过下述代码做好长度的确定:
width_X = 8
width_y = 1

X = []
y = []
 
in_start = 0
 
for _,_ in df.iterrows():
    in_end = in_start + width_X
    out_end = in_end + width_y
 
    if out_end < len(dataFrame):
        X_ = np.array(dataFrame.iloc[in_start:in_end,])
        X_ = X_.reshape((len(X_)*3))
        y_ = np.array(dataFrame.iloc[in_end:out_end,0])
 
        X.append(X_)
        y.append(y_)
 
    in_start += 1
 
X = np.array(X)
y = np.array(y)
 
print(X.shape,y.shape)

输出:
在这里插入图片描述

2.2 归一化

# 2,归一化
from sklearn.preprocessing import MinMaxScaler
 
sc = MinMaxScaler(feature_range=(0,1))
X_scaled = sc.fit_transform(X)
print(X_scaled.shape)

X_scaled = X_scaled.reshape(len(X_scaled),width_X,3)
print(X_scaled.shape)

输出:
在这里插入图片描述

2.3 划分数据集

取5000之前的数据作为训练集,5000之后的数据作为验证集:

# 3,划分数据集
# 取5000之前的数据作为训练集,5000之后的数据作为验证集:
X_train = np.array(X_scaled[:5000]).astype('float64')
y_train = np.array(y[:5000]).astype('float64')
 
X_test = np.array(X_scaled[5000:]).astype('float64')
y_test = np.array(y[5000:]).astype('float64')
 
print(X_train.shape)

输出:
在这里插入图片描述

三、模型创建、编译、训练、得到训练结果

3.1 构建模型

通过下面代码,构建一个包含两个LSTM层和一个全连接层的LSTM模型。这个模型将接受形状为(X_train.shape[1], 3)的输入,其中X_train.shape[1]是时间步数,3 是每个时间步的特征数。

# 三、构建模型
import keras
from keras.models import Sequential
from keras.layers import Dense,LSTM
 
model_lstm = Sequential()
model_lstm.add(LSTM(units=64,activation='relu',return_sequences=True,input_shape=(X_train.shape[1],3)))
model_lstm.add(LSTM(units=64,activation='relu'))
model_lstm.add(Dense(width_y))
# 通过上述代码,构建了一个包含两个LSTM层和一个全连接层的LSTM模型。这个模型将接受形状为 (X_train.shape[1], 3) 的输入,其中 X_train.shape[1] 是时间步数,3 是每个时间步的特征数。

输出:
在这里插入图片描述

3.2 编译模型

# 四、 编译模型 
model_lstm.compile(loss='mean_squared_error',
              optimizer=tf.keras.optimizers.Adam(1e-3))

3.3 训练模型

history = model_lstm.fit(X_train,y_train,
                    epochs = 40,
                    batch_size = 64,
                    validation_data=(X_test,y_test),
                    validation_freq= 1)

训练输出:
在这里插入图片描述

3.4 模型评估

3.4.1 Loss与Accuracy图

# 六、 模型评估
# 1.Loss与Accuracy图
import matplotlib.pyplot as plt
 
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
 
plt.figure(figsize=(5, 3),dpi=120)
 
plt.plot(history.history['loss'],label = 'LSTM Training Loss')
plt.plot(history.history['val_loss'],label = 'LSTM Validation Loss')
 
plt.title('Training and Validation Accuracy')
plt.legend()
plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\1.Loss与Accuracy图.png")
plt.show()

输出:
在这里插入图片描述

3.4.2 调用模型进行预测

# 2.调用模型进行预测
predicted_y_lstm = model_lstm.predict(X_test)
 
y_tset_one = [i[0] for i in y_test]
predicted_y_lstm_one = [i[0] for i in predicted_y_lstm]
 
plt.figure(figsize=(5,3),dpi=120)
plt.plot(y_tset_one[:1000],color = 'red', label = '真实值')
plt.plot(predicted_y_lstm_one[:1000],color = 'blue', label = '预测值')
 
plt.title('Title')
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\2.调用模型进行预测图.png")
plt.show()

输出:
在这里插入图片描述

3.4.3 查看误差


# 3. 查看误差
from  sklearn import metrics
 
RMSE_lstm = metrics.mean_squared_error(predicted_y_lstm,y_test)**0.5
R2_lstm = metrics.r2_score(predicted_y_lstm,y_test)
 
print('均方根误差:%.5f' % RMSE_lstm)
print('R2:%.5f' % R2_lstm)

输出:
在这里插入图片描述

四、其他

4.1 模块报错:seaborn模块导入错误

解决方法如下:
在这里插入图片描述
在这里插入图片描述

4.2 图片实时显示比例不对

改为保存到本地查看:【在plt.show()之前保存
line 34:

plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\3.数据可视化.png")
plt.show()

line 125:

plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\1.Loss与Accuracy图.png")
plt.show()

line 143:

plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\2.调用模型进行预测图.png")
plt.show()

输出:
在这里插入图片描述

4.3 什么是LSTM

LSTM是一种特殊的RNN,能到学习到长期的依赖关系,可以理解为升级版的RNN。
在这里插入图片描述
传统的RNN在处理长序列时存在着“梯度爆炸(/梯度消失)”和“短时记忆”的问题,向RNN中加入遗忘门、输入门及输出门使得困扰RNN的问题得到了一定的解决;
在这里插入图片描述
关于LSTM的实现流程:(1、单输出时间步)单输入单输出、多输入单输出、多输入多输出(2、多输出时间步)单输入单输出、多输入单输出、多输入多输出;
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1799404.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

虚拟机Ubuntu 22.04上搭建GitLab操作步骤

GitLab是仓库管理系统&#xff0c;使用Git作为代码管理工具。GitLab提供了多个版本&#xff0c;包括社区版(Community Edition)和企业版(Enterprise Edition)。实际应用场景中要求CPU最小4核、内存最小8GB&#xff0c;非虚拟环境。 以下是在虚拟机中安装社区版步骤&#xff1a;…

C++青少年简明教程:C++函数

C青少年简明教程&#xff1a;C函数 C函数是一段可重复使用的代码&#xff0c;用于执行特定的任务&#xff0c;可以提高代码的可读性和可维护性。函数可以接受参数&#xff08;输入&#xff09;并返回一个值&#xff08;输出&#xff09;&#xff0c;也可以没有参数和返回值。 …

应用层——HTTP协议(自己实现一个http协议)——客户端(浏览器)的请求做反序列化和请求分析,然后创建http向响应结构

应用层&#xff1a;之前我们写的创建套接字&#xff0c;发送数据&#xff0c;序列化反序列化这些都是在写应用层 我们程序员写的一个个解决我们实际问题, 满足我们日常需求的网络程序, 都是在应用层 之前的网络计算机是我们自定义的协议&#xff1a;传输的数据最终是什么样的结…

Redis缓存(笔记二:Redis常用五大数据类型)

目录 1、Redis中String字符串 1.1 常用命令解释&#xff1a; 1.2 原子性 1.3 具有原子性的常用命令 1.4 String数据结构 1、Redis中String字符串 概念 String 是 Redis 最基本的类型&#xff0c;可以理解成与 Memcached 一模一样的类型&#xff0c;一个 key对应一个 value…

Go微服务: 基于使用场景理解分布式之二阶段提交

概述 二阶段提交&#xff08;Two-Phase Commit&#xff0c;2PC&#xff09;是一种分布式事务协议&#xff0c;用于在分布式系统中确保多个参与者的操作具有原子性即所有参与者要么全部提交事务&#xff0c;要么全部回滚事务&#xff0c;以维持数据的一致性它分为两个阶段进行&…

php反序列化中的pop链

目录 一、什么是POP 二、成员属性赋值对象 例题&#xff1a; 方法一 方法二 三、魔术方法的触发规则 例题&#xff1a; 四、POC的编写 例题1&#xff1a; 例题2 [NISACTF 2022]babyserialize 今日总结&#xff1a; 一、什么是POP 在反序列化中&#xff0c;我们…

DexCap——斯坦福李飞飞团队泡茶机器人:更好数据收集系统的原理解析、源码剖析

前言 2023年7月&#xff0c;我司组建大模型项目开发团队&#xff0c;从最开始的论文审稿&#xff0c;演变成目前的两大赋能方向 大模型应用方面&#xff0c;以微调和RAG为代表 除了论文审稿微调之外&#xff0c;目前我司内部正在逐一开发论文翻译、论文对话、论文idea提炼、论…

RDMA (1)

RDMA是什么 Remote Direct Memory Access(RDMA)是用来给有高速需求的应用释放网络消耗的。 RDMA在网络的两个应用之间进行低延迟,高吞吐的内存对内存的直接数据通信。 InfiniBand需要部署独立的协议。 RoCE(RDMA over Converged Ethernet),也是由InfiniBand Trade Associat…

不要硬来!班组管理有“巧思”

班组管理&#xff0c;听起来似乎是一个充满“硬气”的词汇&#xff0c;让人联想到严肃、刻板的制度和规矩。然而&#xff0c;在实际操作中&#xff0c;我们却需要运用一些“巧思”&#xff0c;以柔克刚&#xff0c;让班组管理既有力度又不失温度。 在班组管理中&#xff0c;我们…

Istio_1.17.8安装

项目背景 按照istio官网的命令一路安装下来&#xff0c;安装好的istio版本为目前的最新版本&#xff0c;1.22.0。而我的k8s集群的版本并不支持istio_1.22的版本&#xff0c;导致ingress-gate网关安装不上&#xff0c;再仔细查看istio的发布文档&#xff0c;如果用istio_1.22版本…

Fatfs

STM32进阶笔记——FATFS文件系统&#xff08;上&#xff09;_stm32 fatfs-CSDN博客 STM32进阶笔记——FATFS文件系统&#xff08;下&#xff09;_stm32 文件系统怎样获取文件大小-CSDN博客 STM32——FATFS文件基础知识_stm32 fatfs-CSDN博客 021 - STM32学习笔记 - Fatfs文件…

Go select 语句使用场景

1. select介绍 select 是 Go 语言中的一种控制结构&#xff0c;用于在多个通信操作中选择一个可执行的操作。它可以协调多个 channel 的读写操作&#xff0c;使得我们能够在多个 channel 中进行非阻塞的数据传输、同步和控制。 基本语法&#xff1a; select {case communica…

纷享销客集成平台(iPaaS)的应用与实践

案例一 企业系统集成的产品级解决方案 概况 随着国家出台一系列鼓励LED照明产业发展与创新的规划和政策&#xff0c;以及国际市场全球演唱会、音乐会的活跃以及线上零售、商业地产等行业回暖&#xff0c;LED显示行业发展形势积极向好。深圳市艾比森光电股份有限公司&#xff…

第一周:计算机网络概述(上)

一、计算机网络基本概念 1、计算机网络通信技术计算机技术 计算机网络就是一种特殊的通信网络&#xff0c;其特殊之处就在于它的信源和信宿就是计算机。 2、什么是计算机网络 在计算机网络中&#xff0c;我们把这些计算机统称为“主机”&#xff08;上图中所有相连的电脑和服…

【动手学深度学习】softmax回归的简洁实现详情

目录 &#x1f30a;1. 研究目的 &#x1f30a;2. 研究准备 &#x1f30a;3. 研究内容 &#x1f30d;3.1 softmax回归的简洁实现 &#x1f30d;3.2 基础练习 &#x1f30a;4. 研究体会 &#x1f30a;1. 研究目的 理解softmax回归的原理和基本实现方式&#xff1b;学习如何…

开发人员必备的常用工具合集-lombok

Project Lombok 是一个 java 库&#xff0c;它会自动插入您的编辑器和构建工具&#xff0c;为您的 Java 增添趣味。 再也不用编写另一个 getter 或 equals 方法了&#xff0c;只需一个注释&#xff0c;您的类就拥有了一个功能齐全的构建器&#xff0c;自动化了您的日志记录变量…

从零开始手把手Vue3+TypeScript+ElementPlus管理后台项目实战五(引入vue-router,并给注册功能加上美丽的外衣el-form)

安装vue-router pnpm install vue-router创建router src下新增router目录&#xff0c;ruoter目录中新增index.ts import { createRouter, createWebHashHistory } from "vue-router"; const routes [{path: "/",name: "Home",component: () …

SQL语句练习每日5题(四)

题目1——查找GPA最高值 想要知道复旦大学学生gpa最高值是多少&#xff0c;请你取出相应数据 题解&#xff1a; 1、使用MAX select MAX(gpa) FROM user_profile WHERE university 复旦大学 2、使用降序排序组合limit select gpa FROM user_profile WHERE university 复…

当C++的static遇上了继承

比如我们想要统计下当前类被实例化了多少次&#xff0c;我们通常会这么写 class A { public:A() { Count_; }~A() { Count_--; }int GetCount() { return Count_; }private:static int Count_; };class B { public:B() { Count_; }~B() { Count_--; }int GetCount() { return …

LeetCode1143最长公共子序列

题目描述 给定两个字符串 text1 和 text2&#xff0c;返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 &#xff0c;返回 0 。一个字符串的 子序列 是指这样一个新的字符串&#xff1a;它是由原字符串在不改变字符的相对顺序的情况下删除某些字符&#xff08…