【PyTorch】基于LSTM网络的气温预测模型实现

news2024/11/24 12:44:27

假设CSV文件名为temperature_data.csv,其前五行和标题如下:
在这里插入图片描述

这里,我们只使用Temperature列进行单步预测。以下是整合的代码示例:

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt

# 加载数据
data = pd.read_csv('temperature_data.csv')

# 选择Temperature列
temperatures = data['Temperature'].values.reshape(-1, 1)

# 数据归一化
scaler = MinMaxScaler()
temperatures = scaler.fit_transform(temperatures)

# 创建数据集函数
def create_dataset(data, time_step=1):
    dataX, dataY = [], []
    for i in range(len(data)-time_step-1):
        a = data[i:(i+time_step), 0]
        dataX.append(a)
        dataY.append(data[i + time_step, 0])
    return np.array(dataX), np.array(dataY)

# 定义时间步长
time_step = 5
X, y = create_dataset(temperatures, time_step)

# 重塑输入数据为[samples, time_step, features]
X = X.reshape(X.shape[0], X.shape[1], 1)

# 转换为PyTorch张量
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

# 定义LSTM模型
class LSTMModel(nn.Module):
    def __init__(self):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=1, batch_first=True)
        self.fc = nn.Linear(50, 1)

    def forward(self, x):
        output, (hn, cn) = self.lstm(x)
        y_pred = self.fc(output[:, -1, :])
        return y_pred

# 初始化模型、损失函数和优化器
model = LSTMModel()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
epochs = 100
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y.unsqueeze(2))
    loss.backward()
    optimizer.step()
    
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')

# 测试模型
model.eval()
with torch.no_grad():
    predicted = model(X)
    predicted = predicted.numpy().squeeze()

# 反归一化
predicted = scaler.inverse_transform(predicted.reshape(-1, 1))

# 绘制实际值和预测值
plt.figure(figsize=(10, 6))
plt.plot(y.numpy().squeeze(), label='Actual')
plt.plot(predicted, label='Predicted')
plt.title('Temperature Prediction')
plt.xlabel('Time Step')
plt.ylabel('Temperature')
plt.legend()
plt.show()

请注意,这段代码假设你已经有了一个名为temperature_data.csv的CSV文件,并且该文件位于你的工作目录中。此外,代码中的time_step变量可以根据需要调整,以改变模型预测的时间范围。这个例子中的模型是一个简单的LSTM网络,它可以根据前time_step个时间步的气温数据来预测下一个时间步的气温。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1942079.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【深度学习】yolov8-seg分割训练,拼接图的分割复原

文章目录 项目背景造数据训练 项目背景 在日常开发中,经常会遇到一些图片是由多个图片拼接来的,如下图就是三个图片横向拼接来的。是否可以利用yolov8-seg模型来识别出这张图片的三张子图区域呢,这是文本要做的事情。 造数据 假设拼接方式有…

Qt+OpenCascade开发笔记(一):occ的windows开发环境搭建(一):OpenCascade介绍、下载和安装过程

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/140604141 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…

OpenStack Yoga版安装笔记(八)glance练习补充2

1、openstack image list数据流回顾 OpenStack Yoga版安装笔记(七)通过Wireshark抓包、Mermaid绘图,解析了执行openstack image list的数据流,图示如下: 数据流1-4:user admin认证,并获得admin…

ros2--中间件--rmw

rmw robot middleware 什么是中间件 一套位于操作系统之上,引用程序之下的软件。 在ros2中理解就是:中间件就是介于某两个或者多个节点中间的组件 中间件的作用 就是提供多个节点中间通信用的。 教程 ROS2中间件DDS架构 ros2从入门到精通

使用puma部署ruby on rails的记录

之前写过一篇《记录一下我的Ruby On Rails的systemd服务脚本》的记录,现在补上一个比较政治正确的Ruby On Rails的生产环境部署记录。使用Puma部署项目。 创建文件 /usr/lib/systemd/system/puma.service [Unit] DescriptionPuma HTTP Server DocumentationRuby O…

在Linux、Windows和macOS上释放IP地址并重新获取新IP地址的方法

文章目录 LinuxWindowsmacOS 在Linux、Windows和macOS上释放IP地址并重新获取新IP地址的方法各有不同。以下是针对每种操作系统的详细步骤: Linux 使用DHCP客户端:大多数Linux发行版都使用DHCP(动态主机配置协议)来自动获取IP地址…

RT-Thread全球嵌入式电子设计大赛入选名单发布!

目录 概述 ​1 瑞萨 RA8D1 Vision Board 2 英飞凌 Psoc6-EvaluationKit-062S2 WIFI模块 3 恩智浦 FRDM-MCXN947 4 STM32 星火一号 STM32F407 5 先楫 HPM5300EVK (RISC-V) 6 自带开发板 概述 RT-Thread全球嵌入式电子设计大赛入选名单发布啦,如下名单的小…

数学建模学习(3)——模拟退火算法

一、模拟退火算法解TSP问题 import random import numpy as np from math import e, exp import matplotlib.pyplot as plt# 31个城市的坐标 city_loc [(1304, 2312), (3639, 1315), (4177, 2244), (3712, 1399), (3488, 1535),(3326, 1556), (3238, 1229), (4196, 1004), (4…

FPGA开发在verilog中关于阻塞和非阻塞赋值的区别

一、概念 阻塞赋值:阻塞赋值的赋值号用“”表示,对应的是串行执行。 对应的电路结构往往与触发沿没有关系,只与输入电平的变化有关系。阻塞赋值的操作可以认为是只有一个步骤的操作,即计算赋值号右边的语句并更新赋值号左边的语句…

如何将mp4格式的视频压缩更小 mp4格式视频怎么压缩最小 工具软件分享

在数字化时代,视频内容成为信息传播的重要载体。然而,高清晰度的视频往往意味着较大的文件体积,这给存储和分享带来了一定的困扰。MP4格式作为目前最流行的视频格式之一,其压缩方法尤为重要。下面,我将为大家详细介绍如…

力扣高频SQL 50题(基础版)第六题

文章目录 1378. 使用唯一标识码替换员工ID题目说明思路分析实现过程结果截图总结 1378. 使用唯一标识码替换员工ID 题目说明 Employees 表: ---------------------- | Column Name | Type | ---------------------- | id | int | | name | varchar | ------…

自监督学习在言语障碍及老年语音识别中的应用

近几十年来针对正常言语的自动语音识别(ASR)技术取得了快速进展,但准确识别言语障碍(dysarthric)和老年言语仍然是一项极具挑战性的任务。言语障碍是一种由多种运动控制疾病引起的常见言语障碍类型,包括脑瘫…

Elasticsearch基础(六):使用Kibana Lens进行数据可视化

文章目录 使用Kibana Lens进行数据可视化 一、进入Kibana Lens 二、基础可视化 1、指标可视化 2、垂直堆积条形图 3、表格 三、高级可视化 1、多图层和索引 2、子桶 3、树状图 使用Kibana Lens进行数据可视化 一、进入Kibana Lens 在Kibana主页,单击页面…

中文分词库 jieba 详细使用方法与案例演示

1 前言 jieba 是一个非常流行的中文分词库,具有高效、准确分词的效果。 它支持3种分词模式: 精确模式全模式搜索引擎模式 jieba0.42.1测试环境:python3.10.9 2 三种模式 2.1 精确模式 适应场景:文本分析。 功能&#xff1…

OpenAI从GPT-4V到GPT-4O,再到GPT-4OMini简介

OpenAI从GPT-4V到GPT-4O,再到GPT-4OMini简介 一、引言 在人工智能领域,OpenAI的GPT系列模型一直是自然语言处理的标杆。随着技术的不断进步,OpenAI推出了多个版本的GPT模型,包括视觉增强的GPT-4V(GPT-4 with Vision&…

【接口自动化_07课_Pytest+Excel+Allure完整框架集成_下】

目标:优化框架场景 1. 生成对应的接口关联【重点】 2. 优化URL基础路径封装【理解】 3. 利用PySQL操作数据库应用【理解】--- 怎么用python连接数据库、mysql 4. 通过数据库进行数据库断言【重点】 5. 通过数据库进行关联操作【重点】 一、接口关联&#xff1a…

深入浅出mediasoup—协议交互

本文主要分析 mediasoup 一对一 WebRTC 通信 demo 的协议交互,从协议层面了解 mediasoup 的设计与实现,这是深入阅读 mediasoup 源码的重要基础。 1. 时序图 下图是 mediasoup 客户端-服务器协议交互的总体架构,服务器是一个 Node.js 进程加…

Django学习第一天(如何创建和运行app)

前置知识: URL组成部分详解: 一个url由以下几部分组成: scheme://host:port/path/?query-stringxxx#anchor scheme:代表的是访问的协议,一般为http或者ftp等 host:主机名,域名,…

高翔【自动驾驶与机器人中的SLAM技术】学习笔记(三)基变换与坐标变换;微分方程;李群和李代数;雅可比矩阵

一、基变换与坐标变换 字小,事不小。 因为第一反应:坐标咋变,坐标轴就咋变呀。事实却与我们想象的相反。这俩互为逆矩阵。 第一次读没有读明白,后面到事上才明白。 起因是多传感器标定:多传感器,就代表了多个坐标系,多个基底。激光雷达和imu标定。这个标定程序,网上,…

秒杀优化: 记录一次bug排查

现象 做一人一单的时候,为了提升性能,需要将原来的业务改造成Lua脚本加Stream流的方式实现异步秒杀。 代码改造完成,使用Jmeter进行并发测试,发现redis中的数据和预期相同,库存减1,该用户也成功添加了进去…