Pytorch--3.使用CNN和LSTM对数据进行预测

news2025/1/12 17:41:33

这个系列前面的文章我们学会了使用全连接层来做简单的回归任务,但是在现实情况里,我们不仅需要做回归,可能还需要做预测工作。同时,我们的数据可能在时空上有着联系,但是简单的全连接层并不能满足我们的需求,所以我们在这篇文章里使用CNN和LSTM来对时间上有联系的数据来进行学习,同时来实现预测的功能。

1.数据集:使用的是kaggle上一个公开的气象数据集(CSV)

有需要的可以去kaggle下载,也可以在评论区留下mail,题主发送过去
在这里插入图片描述

2.导入我们所需要的库和完成前置工作

2.1导入相关的库

torch为人工智能的库,pandas用于数据读取,numpy为张量处理的库,matplotlib为画图库

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import warnings
import torch.nn as nn
import torch.optim as optim
import random

2.2设置相关配置

我们设置随机种子(方便代码的复现)和警告的忽律(防止出现太多警告看不到代码运行的效果)

warnings.filterwarnings('ignore')
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(99)
np.random.seed(99)
random.seed(99)
print ("随机种子")

2.3数据的读入

pd.read_csv里面的参数为相对位置,即代码和文件要在同一个文件夹下面。使用.head()函数来读一下数据的前几行,保证数据是存在的

train_data = pd.read_csv("LSTM-Multivariate_pollution.csv")
train_data.head()

请添加图片描述
我们来看一下各个值的前2048个数据分布情况(方便挑选数据进行代码测试)
代码里面的pollution可以换成dew,temp等值(也就是上图里面的值),用于观看分布情况。

train_use = train_data["pollution"].values
plt.plot([i for i in range(2048)], pollution[:2048])

pollution:
请添加图片描述
dew:
请添加图片描述
temp:
请添加图片描述
我们可以看到temp属性里面的数据整体呈现上升的趋势,所以我们使用属性为temp的值来进行学习和预测。
首先对数据进行归一化操作(因为值过大的话会导致神经网络损失不降低,同时神经网络难以达到收敛),我们使用minmax归一化后将其打印出来可以看到代码显示的效果

from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
train_use = scaler.fit_transform(train_use.reshape(-1, 1))
print ((train_use))                                                                     
print ("归一化处理")

可以看到归一化后的结果如下图所示:
在这里插入图片描述
我们将数据进行处理,默认使用30天的数据对第31天的数据进行预测,同时将数据进行升维处理,使得输入的训练数据为3维度,分别为batchsize,每次所需要的数据(30个数据),和数据的输入维度(1维度)

def split_data(data, time_step = 30):
    dataX = []
    dataY = []
    for i in range(len(data) - time_step):
        dataX.append(data[i:i + time_step])
        dataY.append(data[i + time_step])
    dataX = np.array(dataX).reshape(len(dataX), time_step, -1)
    dataY = np.array(dataY)
    return dataX, dataY

进行数据处理后,获得了可以训练的数据和标签

datax,datay = split_data(train_use, 30)
print ((datay))

结果如下:
请添加图片描述

紧接着我们划分训练集和测试集,默认为80%的数据用于做训练集,20%的数据用于做测试集,shuffle表示是否要将数据进行打乱,以此来测试训练效果

def train_test_split(dataX,datay,shuffle = True,percentage = 0.8):
    if shuffle:
        random_num = [i for i in range(len(dataX))]
        np.random.shuffle(random_num)
        dataX = dataX[random_num]
        datay = datay[random_num]
    split_num = int(len(dataX)*percentage)
    train_X = dataX[:split_num]
    train_y = datay[:split_num]
    testX = dataX[split_num:]
    testy = datay[split_num:]
    return train_X, train_y, testX, testy

获取我们的训练数据和测试数据,同时把源数据保存到X_train和y_train里面,方便以后对网络的性能进行评比。

train_X, train_y, testx,testy = train_test_split(datax,datay,False,0.8)
print (type(testx))
print("datax的形状为{},dataY的形状为{}".format(train_X.shape, train_y.shape))
X_train = train_X
y_train = train_y

定义我们的自定义网络

class CNN_LSTM(nn.Module):
    def __init__(self, conv_input, input_size, hidden_size, num_layers, output_size):
        super(CNN_LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.conv = nn.Conv1d(conv_input, conv_input, 1)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first = True)
        self.fc = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        x = self.conv(x)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _= self.lstm(x,(h0,c0))
        out = self.fc(out[:,-1,:])
        return out

设置我们网络训练所需要的参数

test_X1 = torch.Tensor(testx)
test_y1 = torch.Tensor(testy)

input_size = 1
conv_input = 30
hidden_size = 64
num_layers = 2

output_size = 1

model = CNN_LSTM(conv_input, input_size, hidden_size, num_layers,output_size)


num_epoch = 1000
batch_size = 4

optimizer = optim.Adam(model.parameters(), lr = 0.0001, betas=(0.5, 0.999))


criterion = nn.MSELoss()
#print ((torch.Tensor(train_X[:batch_size])))

开始运行代码:

train_losses = []
test_losses = []
for epoch in range(num_epoch):
    random_num = [i for i in range(len(train_X))]
    np.random.shuffle(random_num)

    train_X = train_X[random_num]
    train_y = train_y[random_num]

    train_x1 = torch.Tensor(train_X[:batch_size])
    train_y1 = torch.Tensor(train_y[:batch_size])
    model.train()

    optimizer.zero_grad()
    output = model(train_x1)

    train_loss = criterion(output, train_y1)

    train_loss.backward()
    optimizer.step()

    if epoch%50 == 0 :
        model.eval()
        with torch.no_grad():
            output = model(test_X1)
            test_loss = criterion(output, test_y1)
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        print("epoch{},train_loss:{},test_loss:{}".format(epoch, train_loss, test_loss))

在这里插入图片描述

自己手写一个mse计算函数(直接调库也可以),什么是mse?(均方误差,均方误差越小说明模型拟合的越好)

def mse(pred_y, true_y):
    return np.mean((pred_y - true_y) **2)

然后我们对模型进行测试,观察mse的值

train_X1 = torch.Tensor(X_train)
train_pred = model(train_X1).detach().numpy()
test_pred = model(test_X1).detach().numpy()

pred_y = np.concatenate((train_pred, test_pred))
pred_y = scaler.inverse_transform(pred_y).T[0]

true_y = np.concatenate((y_train, testy))
#print (true_y)
true_y = scaler.inverse_transform(true_y).T[0]
#print (true_y)
print (f"mse(pred_y, true_y):{mse(pred_y, true_y)}")
##print (pred_y)

在这里插入图片描述

我们取前2048个值来看我们的预测的情况(因为数据有几万条,为了避免图形太过密集难以看出效果,所以我们只采用前2048个值来进行展示)

plt.title("CNN_LSTM")
x = [i for i in range(2048)]
plt.plot(x, pred_y[:2048], marker = "o", markersize =1, label="pred_y",color=(1, 0, 0))
plt.plot(x, true_y[:2048], marker = "x", markersize=1, label="true_y",color=(0, 0, 1))
plt.legend()
plt.show()

可以看出来,已经学习到了基本的上升趋势的
在这里插入图片描述
我们将两个图拆开来看,看到前8192个点的值,可以看到已经获得到了相对应的趋势。
请添加图片描述
在这里插入图片描述

码字不易,写代码不易,点个赞再走把

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1127896.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《数字图像处理-OpenCV/Python》连载(26)绘制椭圆和椭圆弧

《数字图像处理-OpenCV/Python》连载(26)绘制椭圆和椭圆弧 本书京东优惠购书链接:https://item.jd.com/14098452.html 本书CSDN独家连载专栏:https://blog.csdn.net/youcans/category_12418787.html 第 4 章 绘图与鼠标交互 本章…

在keil中debug分析单片机数据和函数调用过程(c51为例),使用寄存器组导致错误原因分析

寄存器参考 参考2 [寄存器组使用using参考](https://blog.csdn.net/weixin_46720928/article/details/110221835) keil中的using关键字参考 官方文档里关于using的说明可参阅2个地方,(1)keil软件菜单栏->Help->…

被邀请为期刊审稿时,如何做一个合格的审稿人?官方版本教程来喽

审稿是学术研究中非常重要的环节,它可以确保研究的科学性和严谨性。审稿人的任务是检查文章是否符合学术规范,是否具有创新性,是否具有科学价值,以及是否符合期刊的定位和风格。因此,审稿人需要具有扎实的学术背景和丰…

SHELL编程基础2

文章目录 if语句if单分支应用案例 if多分支案例 for循环while循环正则表达式基本正则Perl兼容的正则 if语句 if单分支 if单分支的语法组成: 方式一: if 条件测试;then 命令序列 fi方式二 if 条件测试then 命令序列 fi应用案例 [rootsom day01]# vim user_v2.…

WPF中的绑定知识详解(含案例源码分享)

✅作者简介:2022年博客新星 第八。热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 💞当前专栏…

带你深入了解git

目录 1、Git介绍1.1git是什么工具1.2git起到的作用1.2.1 个人开发:1.2.2 多人开发: 2、Git安装与下载项目代码2.1 下载安装git2.2 从仓库下载代码 3、创建仓库及提交代码3.1 创建仓库3.2 将本地代码以及文件提交到远程仓库3.2.1git全局配置3.2.2 远程仓库…

领域驱动设计:基于DDD的微服务设计实例

文章目录 项目基本信息战略设计战术设计后续的工作 用一个项目来了解 DDD 的战略设计和战术设计,走一遍从领域建模到微服务设计的全过程,一起掌握 DDD 的主要设计流程和关键 点。 项目基本信息 项目的目标是实现在线请假和考勤管理。功能描述如下&…

【数据中台建设系列之二】数据中台-数据采集

​ 【数据中台建设系列之二】数据中台-数据采集 上篇文章介绍了数据中台的元数据管理,相信大家对元数据模块的设计和开发有了一定的了解,本编文章将介绍数据中台另一个重要的模块—数据采集。 一、什么是数据采集 数据采集简单来说就是从各种数据源中抓…

美颜滤镜SDK,企业技术解决方案

企业越来越注重提升用户体验,而美颜滤镜SDK正是满足这一需求的强大工具。美摄美颜滤镜SDK是一款专为企业级应用打造的高效、稳定的美颜滤镜解决方案,能够帮助您的企业在瞬息万变的市场中保持竞争力。 一、强大的美颜滤镜功能 美摄美颜滤镜SDK拥有丰富的…

【wvp】wvp设备上可以开启tcp被动模式

目录 开启了 tcp被动模式 开启UDP模式 地平线不支持这种tcp情况 开启了 tcp被动模式 我的理解是zlm就会是tcp被动收流模式 tcpdump -i any host 10.1.3.7 and tcp 而wvp->浏览器,是SRTP,其实还是基于zlm8000的udp端口出来的 开启UDP模式 tcpdump -i any host…

面试算法40:矩阵中的最大矩形

题目 请在一个由0、1组成的矩阵中找出最大的只包含1的矩形并输出它的面积。例如,在图6.6的矩阵中,最大的只包含1的矩阵如阴影部分所示,它的面积是6。 分析 直方图是由排列在同一基线上的相邻柱子组成的图形。由于题目要求矩形中只包含数字…

解析一个月销售额过千万的商业模式——七人拼团

在当今的商业环境中,营销策略的运用对于企业的成功至关重要。其中,拼团模式作为一种以社交为核心的营销方式,正逐渐受到越来越多企业的关注。本文将探讨七人拼团模式,分析其奖励机制和特点,为企业家提供新的营销思路。…

如何设计出优秀的虚拟展厅,设计虚拟展厅有哪些步骤

引言: 虚拟展厅已经成为了当今数字时代的重要组成部分,无论是展示产品、推广服务,还是展示艺术品和文化遗产,虚拟展厅为用户提供了一个全新的互动体验。如何设计虚拟展厅成了很多人关注的焦点。 一.虚拟展厅设计的基本原则 虚拟…

5G RedCap工业智能网关

5G RedCap工业智能网关是当前工业智能化发展领域的重要技术之一。随着物联网和工业互联网的迅速发展,企业对于实时数据传输和高速通信需求越来越迫切。在这种背景下,5G RedCap工业智能网关以其卓越的性能和功能,成为众多企业的首选。 5G RedC…

双11电视盒子什么牌子好?数码达人测评25款整理电视盒子排名

双11买电视盒子什么牌子好?为了推荐更客观,这段时间我进行了25款主流电视盒子的深度测评,从芯片、内存、网络、散热、系统、广告、流畅度等多方面进行对比,整理了电视盒子排名,双十一想买电视盒子不知道怎么选可以参考…

应用程序无法正常启动0xc000007b的解决策略,多种解决方法分享

当我们在使用特定的软件或游戏时,我们可能会遇到一个特别令人头疼的问题—那就是"应用程序无法正常启动0xc000007b"的错误。但是,为何会出现这类情况和如何解决呢?接下来的内容,将会详细地为你阐释。 一.0xc000007b错误…

【小程序】实现一个定制的音乐播放器

应用地址:https://spacexcode.com/player 介绍 这是为自己制作的一个在线 Web 版的音乐播放器。众所周知,现在市面上的所有的音乐平台都是会员制。而免费的资源却分散在网络上的各个角落,为此,我收集了自己 喜欢的音乐&#xff0…

代码签名证书到期了怎么续费?

我们都知道代码签名证书最长期限可以申请3年,但有的首次申请也会申请1年,这种情况下证书到期了就意味着要重新办理,同样的实名验证步骤还需要再走一遍,尤其目前无论是哪种类型的代码签名证书都会有物理硬件,即使交钱实…

将本地代码上传至码云具体步骤

前言:假如我们在本地创建了一个新项目,现在想将这个项目上传至码云 第一步:码云上创建仓库 第二步:点击创建完成仓库 到这就已经完成了码云仓库的创建!!! 第三步:打开cmd命令输入这…

微信小程序5

一、什么是后台交互? 在小程序中,与后台交互指的是小程序前端与后台服务器之间的数据通信和请求处理过程。通过与后台交互,小程序能够获取服务器端的数据、上传用户数据、发送请求等。 与后台交互可以通过以下方式实现: 发起网络请…