第R4周:LSTM-火灾温度预测

news2025/1/11 21:39:44
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

    文章目录

    • 一、代码流程
      • 1、导入包,设置GPU
      • 2、导入数据
      • 3、数据集可视化
      • 4、数据集预处理
      • 5、设置X,y
      • 6、划分数据集
      • 7、构建模型
      • 8、定义训练函数
      • 9、定义测试函数
      • 10、正式训练
      • 11、模型评估- LOSS图
      • 12、调用模型进行训练

电脑环境:
语言环境:Python 3.8.0

一、代码流程

1、导入包,设置GPU

import torch.nn.functional as F
import torch.nn as nn
import torch
import numpy as np
import pandas as pd

2、导入数据

data = pd.read_csv('woodpine2.csv')
data
	Time	Tem1	CO 1	Soot 1
0	0.000	25.0	0.000000	0.000000
1	0.228	25.0	0.000000	0.000000
2	0.456	25.0	0.000000	0.000000
3	0.685	25.0	0.000000	0.000000
4	0.913	25.0	0.000000	0.000000
...	...	...	...	...
5943	366.000	295.0	0.000077	0.000496
5944	366.000	294.0	0.000077	0.000494
5945	367.000	292.0	0.000077	0.000491
5946	367.000	291.0	0.000076	0.000489
5947	367.000	290.0	0.000076	0.000487
5948 rows × 4 columns

3、数据集可视化

from os import confstr_names
import matplotlib.pyplot as plt
import seaborn as sns

plt.rcParams['figure.dpi'] = 500
plt.rcParams['savefig.dpi'] = 500

fig, ax = plt.subplots(1, 3, constrained_layout=True, figsize=(14, 3))
sns.lineplot(data=data['Tem1'], ax=ax[0])
sns.lineplot(data=data['CO 1'], ax=ax[1])
sns.lineplot(data=data['Soot 1'], ax=ax[2])
plt.show()

在这里插入图片描述

dataFrame = data.iloc[:, 1:]
dataFrame

4、数据集预处理

from sklearn.preprocessing import MinMaxScaler

dataFrame = data.iloc[:, 1:].copy()

scaler = MinMaxScaler(feature_range=(0, 1))

for i in ['CO 1', 'Soot 1', 'Tem1']:
    dataFrame[i] = scaler.fit_transform(dataFrame[i].values.reshape(-1, 1))  

dataFrame.shape

(5948, 3)

5、设置X,y

width_X = 8
width_Y = 1

X = []
y = []

in_start = 0

for _, _ in data.iterrows():
    in_end = in_start + width_X
    out_end = in_end + width_Y
    if out_end < len(dataFrame):
        X_ = np.array(dataFrame.iloc[in_start:in_end, :])
        y_ = np.array(dataFrame.iloc[in_end:out_end, :])
        X.append(X_)
        y.append(y_)
    in_start += 1 
X = np.array(X)
y = np.array(y)

X.shape, y.shape

((5939, 8, 3), (5939, 1, 1))

检查数据集中是否有空值

print(np.any(np.isnan(X)))
print(np.any(np.isnan(y)))

6、划分数据集

X_train = torch.tensor(np.array(X[:5000]), dtype=torch.float32)
y_train = torch.tensor(np.array(y[:5000]), dtype=torch. float32)

X_test = torch.tensor(np.array(X[5000:]), dtype=torch.float32)
y_test = torch.tensor(np.array(y[5000:]), dtype=torch. float32)

X_train.shape, y_train.shape

(torch.Size([5000, 8, 3]), torch.Size([5000, 1, 3]))

from torch.utils.data import TensorDataset, DataLoader
train_dl = DataLoader(TensorDataset(X_train, y_train),
                        batch_size=64,
                        shuffle=False)
test_dl = DataLoader(TensorDataset(X_test, y_test),
                        batch_size=64,
                        shuffle=False)

7、构建模型

class model_lstm(nn.Module):
    def __init__(self):
        super(model_lstm, self).__init__()

        self.lstm0 = nn.LSTM(input_size=3, hidden_size=320,
                             num_layers=1, batch_first=True)
        
        self.lstm1 = nn.LSTM(input_size=320, hidden_size=320,
                             num_layers=1, batch_first=True)
        self.fc0 = nn.Linear(320, 1)

    def forward(self, x):
        out, hidden1 = self.lstm0(x)
        out, _ = self. lstm1(out, hidden1)
        out = self.fc0(out)
        return out[:, -1:, :]
        #取2个预测值,否则经过1stm会得到8*2个预
model = model_lstm()
model

8、定义训练函数

import copy
def train(train_dl, model, loss_fn, opt, lr_scheduler=None) :
    size = len(train_dl.dataset)
    num_batches = len(train_dl)
    train_loss = 0 #初始化训练损失和正确率
    for x, y in train_dl:
        x, y = x.to(device), y.to(device)
        #计算预测误差
        pred = model(x) #网络输出
        loss = loss_fn(pred, y) #计算网络输出和真实值之间的差距
        # 反向传播
        opt.zero_grad()#grad属性归零
        loss.backward()# 反向传播
        opt.step()# 每一步自动更新

        #记录Loss
        train_loss += loss. item()
    if lr_scheduler is not None:
        lr_scheduler.step()
        print ("learning rate = {:.5f}". format(opt.param_groups[0]['lr']), end='  ')
    train_loss /= num_batches
    return train_loss

9、定义测试函数

def test (dataloader, model, loss_fn) :
    size = len(dataloader.dataset) #测试集的大小
    num_batches = len(dataloader)# 批次数目
    
    test_loss = 0
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            # 计算loss
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            test_loss += loss.item()
    test_loss /= num_batches
    return test_loss

10、正式训练

# 设置GPU训练
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

#训练模型
model = model_lstm()
model = model.to(device)
loss_fn = nn.MSELoss() #创建损失函数
learn_rate = 1e-1 #学习率
opt = torch.optim.SGD(model.parameters(),lr=learn_rate, weight_decay=1e-4)
epochs = 50
train_loss = []
test_loss = []
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, epochs, last_epoch=-1)

for epoch in range(epochs):
    model.train()
    epoch_train_loss = train(train_dl, model, loss_fn, opt, lr_scheduler)

    model.eval()
    epoch_test_loss = test(test_dl, model, loss_fn)

    train_loss.append(epoch_train_loss)
    test_loss.append(epoch_test_loss)

    template = ('Epoch: {:2d}, Train loss: {:.5f}, Test loss: {:.5f}')
    print(template.format(epoch+1, epoch_train_loss,epoch_test_loss))
print("="*20, 'Done', "="*70)

11、模型评估- LOSS图

import matplotlib.pyplot as plt
plt. figure(figsize=(5, 3), dpi=120)

plt.plot(train_loss, label='LSTM Training Loss')
plt.plot(test_loss, label='LSTM Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

在这里插入图片描述

12、调用模型进行训练

predicted_y_lstm = sc.inverse_transform(model(X_test).detach().numpy().reshape(-1,1))
y_test_1 = sc.inverse_transform(y_test.reshape(-1,1))
y_test_one = [i[0] for i in y_test_1]
predicted_y_lstm_one = [i[0] for i in predicted_y_lstm]
                                        
plt.figure(figsize=(5, 3) , dpi=120)
# 画出真实数据和预测数据的对比曲线
plt.plot(y_test_one[:2000], color='red', label='real_temp')
plt.plot(predicted_y_lstm_one[:2000], color='blue', label='prediction')
plt.title('Title')
plt.xlabel('X')
plt.ylabel('y')
plt.legend ()
plt.show( )

在这里插入图片描述

from sklearn import metrics
'''
RMSE:均方根误差--->对均方误差开方
R2:决定系数,可以简单理解为反映模型拟合优度的重要的统计量
'''
RMSE_lstm = metrics.mean_squared_error(predicted_y_lstm_one, y_test_1)**0.5
R2_lstm = metrics.r2_score(predicted_y_lstm_one, y_test_1)
print('均方根误差:%.5f' % RMSE_lstm)
print('R2: %.5f' % R2_lstm)

均方根误差:7.01314
R2: 0.82595

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2275129.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【python自写包模块的标准化方法】

目标: 自写一个包,提供关于字符串和文件的模块 要求对异常可以检测 str_tools.py: def str_reverse(s):""":param s: 传入的字符串:return: 反转后的字符串"""# i -1# j 0# s2 ""# while i > (-len(s)):# s2 s[i]# …

Win10本地部署大语言模型ChatGLM2-6B

鸣谢《ChatGLM2-6B&#xff5c;开源本地化语言模型》作者PhiltreX 作者显卡为英伟达4060 安装程序 打开CMD命令行&#xff0c;在D盘新建目录openai.wiki if not exist D:\openai.wiki mkdir D:\openai.wiki 强制切换工作路径为D盘的openai.wiki文件夹。 cd /d D:\openai.wik…

【简博士统计学习方法】第1章:1. 统计学习的定义与分类

自用笔记 1. 统计学习的定义与分类 1.1 统计学习的概念 统计学习&#xff08;Statistical Machine Learning&#xff09;是关于计算机基于数据构建概率统计模型并运用模型对数据进行预测与分析的一门学科。 以计算机和网络为平台&#xff1b;以数据为研究对象&#xff1b;以…

Unity 人体切片三维可视化,可任意裁切切割。查看不同断层的图像。

Unity 人体切片三维可视化&#xff0c;真彩色&#xff0c;可任意裁切切割。查看不同断层的图像。 点击查看效果: 视频效果

汽车基础软件AutoSAR自学攻略(四)-AutoSAR CP分层架构(3) (万字长文-配21张彩图)

汽车基础软件AutoSAR自学攻略(四)-AutoSAR CP分层架构(3) (万字长文-配21张彩图) 前面的两篇博文简述了AutoSAR CP分层架构的概念&#xff0c;下面我们来具体到每一层的具体内容进行讲解&#xff0c;每一层的每一个功能块力求用一个总览图&#xff0c;外加一个例子的图给大家进…

科研绘图系列:R语言绘制Y轴截断分组柱状图(y-axis break bar plot)

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍特点意义加载R包数据下载导入数据数据预处理画图输出总结系统信息介绍 Y轴截断分组柱状图是一种特殊的柱状图,其特点是Y轴的刻度被截断,即在某个范围内省略了部分刻度。这种图表…

慧集通(DataLinkX)iPaaS集成平台-数据流程之流程透明化调试功能简介

在线运行流程 查看运行状态 流程第一次执行状态显示 流程第二次执行状态显示&#xff08;由于订单已同步到七星ERP中&#xff0c;由于还是这些订单所以第二次同步时就报错了&#xff09; 点击查看节点组件的详细入参与出参信息 U8C销售订单读取组件执行时详情 入参-查询条件…

PostgreSQL技术内幕22:vacuum full 和 vacuum

文章目录 0.简介1.概念及使用方式2.工作原理2.1 主要功能2.2 清理流程2.3 防止事务id环绕说明 3.使用建议 0.简介 在之前介绍MVCC文章中介绍过常见的MVCC实现的两种方式&#xff0c;一种是将旧数据放到回滚段&#xff0c;一种是直接生成一条新数据&#xff08;对于删除是不删除…

kubernetes第七天

1.影响pod调度的因素 nodeName 节点名 resources 资源限制 hostNetwork 宿主机网络 污点 污点容忍 Pod亲和性 Pod反亲和性 节点亲和性 2.污点 通常是作用于worker节点上&#xff0c;其可以影响pod的调度 语法&#xff1a;key[value]:effect effect:[ɪˈfek…

【CSS】HTML页面定位CSS - position 属性 relative 、absolute、fixed 、sticky

目录 relative 相对定位 absolute 绝对定位 fixed 固定定位 sticky 粘性定位 position&#xff1a;relative 、absolute、fixed 、sticky &#xff08;四选一&#xff09; top&#xff1a;距离上面的像素 bottom&#xff1a;距离底部的像素 left&#xff1a;距离左边的像素…

Redis数据库——Redis快的原因

本文详细介绍redis为什么这么快的原因&#xff0c;这里是本系列文章的总结篇&#xff08;后面会补充一些内容&#xff0c;或者在原文上进行更新迭代&#xff09;&#xff0c;将从各方面出发解释为什么redis快&#xff0c;受欢迎的原因。 文章目录 内存内存数据库预分配内存 数据…

排序:插入、选择、交换、归并排序

排序 &#xff1a;所谓排序&#xff0c;就是使一串记录&#xff0c;按照其中的某个或某些关键字的大小&#xff0c;递增或递减的排列起来的操作。 稳定性 &#xff1a;假定在待排序的记录序列中&#xff0c;存在多个具有相同的关键字的记录&#xff0c;若经过排序&#xff0c;…

RocketMQ 和 Kafka 有什么区别?

目录 RocketMQ 是什么? RocketMQ 和 Kafka 的区别 在架构上做减法 简化协调节点 简化分区 Kafka 的底层存储 RocketMQ 的底层存储 简化备份模型 在功能上做加法 消息过滤 支持事务 加入延时队列 加入死信队列 消息回溯 总结 来源:面试官:RocketMQ 和 Kafka 有…

赛车微型配件订销管理系统(源码+lw+部署文档+讲解),源码可白嫖!

摘要 赛车微型配件行业通常具有产品多样性、需求不确定性、市场竞争激烈等特点。配件供应商需要根据市场需求及时调整产品结构和库存&#xff0c;同时要把握好供应链管理和销售渠道。传统的赛车微型配件订销管理往往依赖于人工经验和简单的数据分析&#xff0c;效率低下且容易…

公众号如何通过openid获取unionid

通过接口 https://api.weixin.qq.com/cgi-bin/user/info?access_tokenxxxxxxx&langzh_CN 返回的数据如下&#xff1a; 前提是必须绑定 微信开放平台 token如何获取呢 代码如下&#xff1a; String tokenUrl "https://api.weixin.qq.com/cgi-bin/token"; …

半导体数据分析: 玩转WM-811K Wafermap 数据集(二) AI 机器学习

一、数据集回顾 前面我们已经基本了解了WM-811K Wafermap 数据集&#xff0c;并通过几段代码&#xff0c;熟悉了这个数据集的数据结构&#xff0c;这里为了方便各位连续理解&#xff0c;让我们再回顾一下&#xff1a; WM-811K Wafermap 数据集是一个在半导体制造领域广泛使用…

协同过滤算法私人诊所系统|Java|SpringBoot|VUE|

【技术栈】 1⃣️&#xff1a;架构: B/S、MVC 2⃣️&#xff1a;系统环境&#xff1a;Windowsh/Mac 3⃣️&#xff1a;开发环境&#xff1a;IDEA、JDK1.8、Maven、Mysql5.7 4⃣️&#xff1a;技术栈&#xff1a;Java、Mysql、SpringBoot、Mybatis-Plus、VUE、jquery,html 5⃣️…

Python基于YOLOv8和OpenCV实现车道线和车辆检测

使用YOLOv8&#xff08;You Only Look Once&#xff09;和OpenCV实现车道线和车辆检测&#xff0c;目标是创建一个可以检测道路上的车道并识别车辆的系统&#xff0c;并估计它们与摄像头的距离。该项目结合了计算机视觉技术和深度学习物体检测。 1、系统主要功能 车道检测&am…

nexus搭建maven私服

说到maven私服每个公司都有&#xff0c;比如我上一篇文章介绍的自定义日志starter&#xff0c;就可以上传到maven私服供大家使用&#xff0c;每次更新只需deploy一下就行&#xff0c;以下就是本人搭建私服的步骤 使用docker安装nexus #拉取镜像 docker pull sonatype/nexus3:…

MiniMind - 从0训练语言模型

文章目录 一、关于 MiniMind &#x1f4cc;项目包含 二、&#x1f4cc; Environment三、&#x1f4cc; Quick Start Test四、&#x1f4cc; Quick Start Train0、克隆项目代码1、环境安装2、如果你需要自己训练3、测试模型推理效果 五、&#x1f4cc; Data sources1、分词器&am…