【项目实践】基于LSTM的一维数据扩展与预测

news2025/1/20 16:29:33

基于LSTM的一维数据拟合扩展

一、引(fei)言(hua)

我在做Sri Lanka生态系统服务价值计算时,中间遇到了一点小问题。从世界粮农组织(FAO)上获得Sri Lanka主要农作物产量和价格数据时,其中的主要作物Sorghum仅有2001-2006年的数据,而Millet只有2001-2005,2020-2021这样的间断数据。虽然说可以直接剔除这种过分缺失的数据,但这无疑会对生态因子的计算造成重大影响。所以我想要不要整个函数把他拟合一下,刚好Maize和Rice有2001-2021的完备数据,于是,这个文档就这样诞生了。


二、数据

数据来自FAO,考虑到可能有同学想要跟着尝试一下,这里给出用到的数据。

作物产量

作物价格

2.1 数据探查

我们读取数据,并进行简单的统计量查看。如果要进一步深入研究数据分布及可视化,可以看看我的这篇文章

import pandas as pd

path=r"YourPath"

yield_=pd.read_csv(path+r"\yield.csv")
pp_=pd.read_csv(path+r"\Producer Prices.csv")
yield_.head()

在这里插入图片描述

需要用到的属性只有Item,Year,Unit,Value

所以我们做这样的处理:

yield_=yield_[["Item","Year","Unit","Value"]]

可以看到有些数据是从1961年开始的,太旧了就不用了,我们从2001年开始。

yield_=yield_[yield_["Year"]>2000]

同样,我们来看看pp_的情况:

pp_.head()

在这里插入图片描述

pp_=pp_[["Item","Year","Value","Element"]]
pp_=pp_[pp_["Year"]>2000]

实际上,在这个数据里,产量已经没有问题了。我们只需要做一个简单的处理:

yield_.groupby("Item").mean()["Value"]/10 #转为千克

在这里插入图片描述

便可拿到每种作物近二十年的平均产量。

好了现在大问题出现在价值上,我们从下往上看就知道了:

pp_.tail(10)

在这里插入图片描述

高粱只有2006年的,那有没有办法利用现成的数据将其扩展呢?

实际上,这类拟合问题有很多种解决方案,但是本问题涉及到时间,之前时间段的因子,以及可能的周期性,都会增加拟合的复杂性。所以,在这里我们采用LSTM来填充数据。


三、模型构建

在本小节,我们将比较传统一维CNN与RNN在结果上的异同。

一般做一维RNN时,可以指定一个时间窗口,比如用2006,2007,2008年的数据,推理2009年的数据,用2007,2008,2009年推理2010年。

我们现在要用之前处理好的pp_c数据中的玉米产量,来预测高粱产量。所以第一步就是将其转化为torch接受的格式。

别忘记导入模块:

import torch
import torch.nn as nn
from torch.nn import functional as F
x=pp_c[pp_c['Item']=="Maize (corn)"]['Value']
x=torch.FloatTensor(x)

之前写数据迭代器的时候,除了可以继承自torch.utils.data.DataLoader,也可以是任意的可迭代对象。这里我们可以简单的设置一个类:

# 设置迭代器
class MyDataSet(object):
    def __init__(self,seq,ws=6):
        # ws是滑动窗口大小
        self.ori=[i for i in seq[:ws]]
        self.label=[i for i in seq[ws:]]
        self.reset()
        self.ws=ws

    def set(self,dpi):
        # 添加数据
        self.x.append(dpi)
        
    def reset(self):
        # 初始化
        self.x=self.ori[:]
        
    def get(self,idx):
        return self.x[idx:idx+self.ws],self.label[idx]
    
    def __len__(self):
        return len(self.x)

哦这边提一下,有两种方式,一种是用原始数据做预测,一种是用预测数据做预测,可能有点抽象,下面举个例子。

假设 A = [ a 1 , a 2 , a 3 , a 4 , a 5 , a 6 ] A=[a1,a2,a3,a4,a5,a6] A=[a1,a2,a3,a4,a5,a6],时间窗口大小为3。

用原始数据做预测,那么输入值为: a 1 , a 2 , a 3 a1,a2,a3 a1,a2,a3,得到的结果将与 a 4 a4 a4做比较。下一轮输入为 a 2 , a 3 , a 4 a2,a3,a4 a2,a3,a4,得到的结果将与 a 5 a5 a5做比较。

而用预测的数据做预测,第一轮输入值为 a 1 , a 2 , a 3 a1,a2,a3 a1,a2,a3,得到的结果是 b 4 b4 b4,在与 a 4 a4 a4做比较后,下一轮的输入为 a 2 , a 3 , b 4 a2,a3,b4 a2,a3,b4,会出现如下情况:

输入数据为 b 4 , b 5 , b 6 b4,b5,b6 b4,b5,b6

我们现在举的例子是用预测的数据做预测。当然,最后也会给出一个用原始数据做预测的版本,那个版本相对简单。

ws=6 # 全局时间窗口
train_data=MyDataSet(x,ws)

网络的架构如下:

   
class Net3(nn.Module):
    def __init__(self,in_features=54,n_hidden1=128,n_hidden2=256,n_hidden3=512,out_features=7):
        super(Net3, self).__init__()
        self.flatten=nn.Flatten()
        self.hidden1=nn.Sequential(
            nn.Linear(in_features,n_hidden1,False),
           
            nn.ReLU()
        )
        self.hidden2=nn.Sequential(
            nn.Linear(n_hidden1,n_hidden2),

            nn.ReLU()
        )
        self.hidden3=nn.Sequential(
            nn.Linear(n_hidden2,n_hidden3),

            nn.ReLU()
        )
        self.out=nn.Sequential(nn.Linear(n_hidden3,out_features))

    def forward(self,x):
        x=self.flatten(x)
        x=self.hidden2(self.hidden1(x))
        x=self.hidden3(x)
        return self.out(x)



class CNN(nn.Module):
    def __init__(self, output_dim=1,ws=6):
        super(CNN, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv1d(ws, 64, 1)
        self.lr = nn.LeakyReLU(inplace=True)
        self.conv2 = nn.Conv1d(64, 128, 1)

        self.bn1, self.bn2 = nn.BatchNorm1d(64), nn.BatchNorm1d(128)
        self.bn3, self.bn4 = nn.BatchNorm1d(1024), nn.BatchNorm1d(128)
        self.flatten = nn.Flatten()
        self.lstm1 = nn.LSTM(128, 1024)
        self.lstm2 = nn.LSTM(1024, 256)
        self.lstm3=nn.LSTM(256,512)
        self.fc = nn.Linear(512, 512)
        self.fc4=nn.Linear(512,256)
        self.fc1 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, output_dim)

    @staticmethod
    def reS(x):
        return x.reshape(-1, x.shape[-1], x.shape[-2])

    def forward(self, x):
        x = self.reS(x)
        x = self.conv1(x) 
        x = self.lr(x)

        x = self.conv2(x) 
        x = self.lr(x)

        x = self.flatten(x)

        # LSTM部分
        x, h = self.lstm1(x)
        x, h = self.lstm2(x)
        x,h=self.lstm3(x)
        x, _ = h

        x = self.fc(x.reshape(-1, ))
        x = self.relu(x)
        x = self.fc4(x)
        x = self.relu(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc3(x)

        return x

Net3主要是一维卷积,CNN加入了LSTM结构。至于名字,是随便取的…跟内容并无关系。


def Train(model,train_data,seed=1):
    device="cuda" if torch.cuda.is_available() else "cpu"
    model=model.to(device)
    Mloss=100000
    path=r"YourPath\%s.pth"%seed
    # 设置损失函数,这里使用的是均方误差损失
    criterion = nn.MSELoss()
    # 设置优化函数和学习率lr
    optimizer=torch.optim.Adam(model.parameters(),lr=1e-5,betas=(0.9,0.99),
                               eps=1e-07,weight_decay=0)
    # 设置训练周期
    epochs =3000
    criterion=criterion.to(device)
    model.train()
    for epoch in range(epochs):
        total_loss=0

        for i in range(len(x)-ws):
            # 每次更新参数前都梯度归零和初始化
            seq,y_train=train_data.get(i) # 从我们的数据集中拿出数据
            seq,y_train=torch.FloatTensor(seq),torch.FloatTensor([y_train])
            seq=seq.unsqueeze(dim=0)
            seq,y_train=seq.to(device),y_train.to(device)

            optimizer.zero_grad()
            # 注意这里要对样本进行reshape,
            # 转换成conv1d的input size(batch size, channel, series length)
            y_pred = model(seq)
            loss = criterion(y_pred, y_train)
            loss.backward()
            train_data.set(y_pred.to("cpu").item()) # 再放入预测数据
            optimizer.step()
            total_loss+=loss

        train_data.reset()
        if total_loss.tolist()<Mloss:
            Mloss=total_loss.tolist()
            torch.save(model.state_dict(),path)
            print("Saving")
        print(f'Epoch: {epoch+1:2} Mean Loss: {total_loss.tolist()/len(train_data):10.8f}')
    return model

正常训练就OK

d=CNN(ws=ws)
Train(d,train_data,4)

在这里插入图片描述

平均损失在10点左右,还有很大优化空间。当然我们这里只是举个非常简单的例子,就是个baseline

checkpoint=torch.load(r"YourPath\4.pth")
d.load_state_dict(checkpoint) # 加载最佳参数
d.to("cpu")

四、结果可视化

我们这里用到Pyechart进行可视化。

from pyecharts.charts import *
from pyecharts import options as opts
from pyecharts.globals import CurrentConfig
pre,ppre=[i.item() for i in x[:ws]],[]
# pre 是用原始数据做预测
# ppre 用预测数据做预测
for i in range(len(x)-ws+1):
    ppre.append(d(torch.FloatTensor(x[i:i+ws]).unsqueeze(dim=0)))
    pre.append(d(torch.FloatTensor(pre[-ws:]).unsqueeze(dim=0)).item())
l=Line()
l.add_xaxis([i for i in range(len(x))])
l.add_yaxis("Original Data",x.tolist())
l.add_yaxis("Pred Data(Using Raw Datas)",x[:ws].tolist()+[i.item() for i in ppre])
l.add_yaxis("Pred Data(Using Pred Datas)",pre)
l.set_series_opts(label_opts=opts.LabelOpts(is_show=False))
l.set_global_opts(title_opts=opts.TitleOpts(title='LSTM CNN'))

l.render_notebook()

根据时间窗口的不同,可以得到不同的结果。

ws=4

在这里插入图片描述

ws=5

在这里插入图片描述

ws=6

在这里插入图片描述

从结果上来看,时间窗口越大越好。但是这里我们只能到六了,再大就不礼貌了。(高粱只有六个节点的数据)。

至于验证,我们可以选Rice做验证:

x=torch.FloatTensor(pp_c[pp_c['Item']=="Rice"]['Value'].tolist())
pre,ppre=[i.item() for i in x[:ws]],[]
for i in range(len(x)-ws+1):
    ppre.append(d(torch.FloatTensor(x[i:i+ws]).unsqueeze(dim=0)))
    pre.append(d(torch.FloatTensor(pre[-ws:]).unsqueeze(dim=0)).item())
l=Line()
l.add_xaxis([i for i in range(len(x))])
l.add_yaxis("Original Data",x.tolist())
l.add_yaxis("Pred Data(Using Raw Datas)",x[:ws].tolist()+[i.item() for i in ppre])
l.add_yaxis("Pred Data(Using Pred Datas)",pre)
l.set_series_opts(label_opts=opts.LabelOpts(is_show=False))
l.set_global_opts(title_opts=opts.TitleOpts(title='LSTM CNN'))

l.render_notebook()

在这里插入图片描述

可以发现,用预测做预测的结果,基本上不会差太多,那也就意味着,我们可以对高粱进行预测啦!不过在这之前,我们可以看看用原始数据做训练的结果:

在这里插入图片描述

时间窗口一样为6,可以看到在黑线贴合的非常好,但是面对大量缺失的数据,精度就远不如用预测数据做预测的结果了。

此外,这是用CNN做的结果

在这里插入图片描述

我们可以发现LSTM的波动要比CNN好,CNN后面死水一潭,应该是梯度消失导致的,前面信息没有了,后面信息又是自个构造的,这就导致了到后面变成了线性情况。

那么最后的最后,就是预测高粱产量了:

pre_data=pp_c[pp_c['Item']=='Sorghum']['Value'].tolist()
l=pre_data[:]
for i in range(len(x)-ws+1):
    l.append(d(torch.FloatTensor(l[-ws:]).unsqueeze(dim=0)).item())
L=Line()
L.add_xaxis([i for i in range(len(x))])
L.add_yaxis("Pred",l)
L.set_series_opts(label_opts=opts.LabelOpts(is_show=False))
L.set_global_opts(title_opts=opts.TitleOpts(title='sorghum production forecasts')
                            
                             )

L.render_notebook()
l.to_csv("path")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/895675.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2011年下半年 软件设计师 上午试卷2

博主介绍&#xff1a;✌全网粉丝3W&#xff0c;全栈开发工程师&#xff0c;从事多年软件开发&#xff0c;在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战&#xff0c;博主也曾写过优秀论文&#xff0c;查重率极低&#xff0c;在这方面有丰富的经验…

新宝马M5谍照曝光,侵略感十足,将与奥迪、梅赛德斯-AMG正面竞争

报道称&#xff0c;宝马即将推出全新一代M5&#xff0c;该车的谍照最近再次曝光。早先&#xff0c;宝马 M5 Touring 旅行汽车的赛道测试图片已经在网络上流传开来&#xff0c;预计该车将与奥迪的RS6 Avant和梅赛德斯-AMG E63 Estate展开正面竞争。 从最新曝光的照片来看&#x…

区块链碎碎念

现在的区块链早已过了跑马圈地的时代&#xff0c;现在还按照以前承接项目的方式做区块链只能是越来越艰难。经过几年的技术沉淀&#xff0c;做区块链项目的公司都已经没落的七七八八了。 区块链不是一个能够快速显现盈利能力的行业&#xff0c;相反这个行业目前的模式还是处于…

【Datawhale暑期实践第三期】用户新增预测挑战赛

文章目录 赛题背景赛事任务赛题数据集评价指标赛题思路Baseline导入包并读取数据特征工程决策树模型训练和预测保存预测文件 赛题名称&#xff1a;用户新增预测挑战赛 赛题类型&#xff1a;数据挖掘、二分类 赛题链接&#x1f447;&#xff1a; https://challenge.xfyun.cn/top…

[Machine Learning] decision tree 决策树

&#xff08;为了节约时间&#xff0c;后面关于机器学习和有关内容哦就是用中文进行书写了&#xff0c;如果有需要的话&#xff0c;我在目前手头项目交工以后&#xff0c;用英文重写一遍&#xff09; &#xff08;祝&#xff0c;本文同时用于比赛学习笔记和机器学习基础课程&a…

【学习FreeRTOS】第12章——FreeRTOS时间管理

1.FreeRTOS系统时钟节拍 FreeRTOS的系统时钟节拍计数器是全局变量xTickCount&#xff0c;一般来源于系统的SysTick。在STM32F1中&#xff0c;SysTick的时钟源是72MHz/89MHz&#xff0c;如下代码&#xff0c;RELOAD 9MHz/1000-1 8999&#xff0c;所以时钟节拍是1ms。 portNV…

事物有哪些特性 ?MySQL 如何保证事物的四大特性 ?

目录 1.事物有哪些特性 2. MySQL 如何保证事物的四大特性 3. 事物的隔离级别 1.事物有哪些特性 1.1 何为事物 &#xff1f; 事物就是把一件事情的多个步骤&#xff0c;多个操作&#xff0c;打包成一个步骤&#xff0c;一个操作。其中任意一个步骤执行失败&#xff0c;都会进…

隧道广播平面波扬声器的应用

隧道广播平面波扬声器是一款高清晰定向扬声器&#xff0c;采用稀土永磁磁性材料与声波相控阵技术&#xff0c;有效的解决了声音定向问题。是远距离定向声波发射装置是一种革命性的技术&#xff0c;它具有大功、率高清晰、远距离传声特点&#xff0c;可以将声音信息清晰地传输到…

【数据结构】 链表简介与单链表的实现

文章目录 ArrayList的缺陷链表链表的概念及结构链表的分类单向或者双向带头或者不带头循环或者非循环 单链表的实现创建单链表遍历链表得到单链表的长度查找是否包含关键字头插法尾插法任意位置插入删除第一次出现关键字为key的节点删除所有值为key的节点回收链表 总结 ArrayLi…

Aurora 8B/10B

目录 1. Overview2. Feature List2. Block Diagram3. Ports Description3.1. User InterfaceFraming InterfaceStreaming InterfaceUser Flow Control&#xff08;UFC&#xff09;Native Flow Control&#xff08;NFC&#xff09; 3.2. Status and Control Ports3.3. Transceiv…

基于python+django+mysql的校园影院售票系统(可做计算机毕设)

开发柚子校园影院&#xff0c;不仅可以改善用户查看信息难的局面&#xff0c;还可以提高管理效率&#xff0c;同时也可以增强系统的竞争力。利用柚子校园影院的可以有效地提高系统的人事的效率和信息化水平&#xff0c;快速了解信息更新及服务的进度。这既可以确保系统服务的品…

RuoYi 云服务器部署系统

一.为什么要部署 关于RuoYi-Vue是一个前后端分离的Web后台管理系统。部署在云服务器上让所有人都可以访问这是Web网站很正常的一个需求,只要我们将前端静态文件暴露在公网中,自然就部署好了。当然,要求是前端的静态资源可以访问到后端的接口,网站才会正常运行。 二.云服务器…

.netcore windows app启动webserver

创建controller: using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Logging; using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Text.Json.Serialization; using System.Threading.Tasks;namespace MyWorker.…

【AI】《动手学-深度学习-PyTorch版》笔记(十九):卷积神经网络模型(GoogLeNet、ResNet、DenseNet)

AI学习目录汇总 1、GoogLeNet 1.1 介绍 发布时间:2014年 GoogLeNet的贡献是如何选择合适大小的卷积核,并将不同大小的卷积核组合使用。 之前介绍的网络结构都是串行的,GoogLeNet使用并行的网络块,称为“Inception块” “Inception块”前后进化了四次,论文链接: [1]ht…

x.view(a,b)及x = x.view(x.size(0), -1) 的理解说明

x.view()就是对tensor进行reshape&#xff1a; 我们在创建一个网络的时候&#xff0c;会在Foward函数内看到view的使用。 首先这里是一个简单的网络&#xff0c;有卷积和全连接组成。它的foward函数如下&#xff1a; class NET(nn.Module):def __init__(self,batch_size):sup…

大数据之几分钟处理完30亿个数据

写在前面 假定现在我们有一个10G的文件&#xff0c;存储的是17~70岁的年龄&#xff0c;每个年龄使用,分割&#xff0c;现需要找出出现次数最多的年龄&#xff0c;以及其出现的次数。 源码 。 1&#xff1a;数据准备 我们首先来准备一个10G大小的存储年龄信息的数据文件&#…

(五)、深度学习框架源码编译

1、源码构建与预构建&#xff1a; 源码构建&#xff1a; 源码构建是通过获取软件的源代码&#xff0c;然后在本地编译生成可执行程序或库文件的过程。这种方法允许根据特定需求进行配置和优化&#xff0c;但可能需要较长的时间和较大的资源来编译源代码。 预构建&#xff1a; 预…

JSP-学习笔记

文章目录 1.JSP介绍2 JSP快速入门3 JSP 脚本3.1 JSP脚本案例3.2 JSP缺点 4 EL表达式4.1 快速入门案例 5. JSTL标签6. MVC模式和三层架构6.1 MVC6.2 三层架构 7. 案例-基于MVC和三层架构实现商品表的增删改查 1.JSP介绍 概念 JSP&#xff08;JavaServer Pages&#xff09;是一种…

JVM——引言+JVM内存结构

引言 什么是JVM 定义: Java VirtualMachine -java 程序的运行环境 (ava 二进制字节码的运行环境) 好处: 一次编写&#xff0c;到处运行自动内存管理&#xff0c;垃圾回收功能数组下标越界检查&#xff0c;多态 比较: jvm jre jdk 学习jvm的作用 面试理解底层实现原理中…

idea设置忽略大小写

1.点击file 2.点击settings 3.点击Editor选项 4.点击general选项 5.点击code completion 6.点击左上角match case