基于Pytorch框架的LSTM算法(一)——单维度单步滚动预测(2)

news2025/1/26 15:51:01

#项目说明:
在这里插入图片描述

说明:1time_steps=

滚动预测代码

y_norm = scaler.fit_transform(y.reshape(-1, 1))
y_norm = torch.FloatTensor(y_norm).view(-1)

# 重新预测
window_size = 12
future = 12
L = len(y)

首先对模型进行训练;
然后选择所有数据的后window_size个数据,通过训练,每次通过前window_size个数据预测未来一个数据,之后新预测的一个数据append加到preds中,这样经过循环future次后,通过滚动循环的方式,预测未来future=12天的数据,完成滚动预测。
最后通过可视化查看预测结果。
preds = y_norm[-window_size:].tolist() 

model.eval()
for i in range(future):  
    seq = torch.FloatTensor(preds[-window_size:])
    with torch.no_grad():

        model.hidden = (torch.zeros(1,1,model.hidden_size),
                        torch.zeros(1,1,model.hidden_size))  
        preds.append(model(seq).item())


true_predictions = scaler.inverse_transform(np.array(preds).reshape(-1, 1))


x = np.arange('2019-02-01', '2020-02-01', dtype='datetime64[M]').astype('datetime64[D]')

plt.figure(figsize=(12,4))
plt.grid(True)
plt.plot(df['S4248SM144NCEN'])
plt.plot(x,true_predictions[window_size:])
plt.show()

完整代码解读:

import torch
import torch.nn as nn

from sklearn.preprocessing import MinMaxScaler
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()


# 导入酒精销售数据
df = pd.read_csv('data\Alcohol_Sales.csv',index_col=0,parse_dates=True)
len(df)

df.head()  # 观察数据集,这是一个单变量时间序列

plt.figure(figsize=(12,4))
plt.grid(True)
plt.plot(df['S4248SM144NCEN'])
plt.show()


y = df['S4248SM144NCEN'].values.astype(float)

# print(len(y))  #325条数据

test_size = 12

# 划分训练和测试集,最后12个值作为测试集
train_set = y[:-test_size]  #323条数据
test_set = y[-test_size:] #12条数据

# print(train_set.shape) #(313,) 一位数组

# 归一化至[-1,1]区间,为了获得更好的训练效果
scaler = MinMaxScaler(feature_range=(-1, 1))
#scaler.fit_transform输入必须是二维的,但是train_set却是一个一维,所有实验reshape(-1,1)
train_norm = scaler.fit_transform(train_set.reshape(-1, 1)) #np.reshape(-1, 1)=1,行未知

# print(train_norm.shape) #(313, 1) 这里将一维数据转化为二维

# 转换成 tensor
train_norm = torch.FloatTensor(train_norm).view(-1) 
print(train_norm.shape) #torch.Size([313])


# 定义时间窗口,注意和前面的test size不是一个概念
window_size = 12

# 这个函数的目的是为了从原时间序列中抽取出训练样本,也就是用第一个值到第十二个值作为X输入,预测第十三个值作为y输出,这是一个用于训练的数据点,时间窗口向后滑动以此类推
def input_data(seq,ws):  
    out = []
    L = len(seq)
    for i in range(L-ws):
        window = seq[i:i+ws]
        label = seq[i+ws:i+ws+1]
        out.append((window,label))  #将x和y以tensor格式放入到out列表当中,    
    return out

train_data = input_data(train_norm,window_size)
len(train_data)  # 等于325(原始数据集长度)-12(测试集长度)-12(时间窗口)
   
class LSTMnetwork(nn.Module):
    def __init__(self,input_size=1,hidden_size=100,output_size=1):
        super().__init__()
        self.hidden_size = hidden_size
        # 定义LSTM层
        self.lstm = nn.LSTM(input_size,hidden_size)
        # 定义全连接层
        self.linear = nn.Linear(hidden_size,output_size)
        # 初始化h0,c0
        self.hidden = (torch.zeros(1,1,self.hidden_size),
                       torch.zeros(1,1,self.hidden_size))

    def forward(self,seq):
        # 前向传播的过程是输入->LSTM层->全连接层->输出

        # https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html?highlight=lstm#torch.nn.LSTM
        # 在观察查看LSTM输入的维度,LSTM的第一个输入input_size维度是(L, N, H_in), L是序列长度,N是batch size,H_in是输入尺寸,也就是变量个数
        # LSTM的第二个输入是一个元组,包含了h0,c0两个元素,这两个元素的维度都是(D∗num_layers,N,H_out),D=1表示单向网络,num_layers表示多少个LSTM层叠加,N是batch size,H_out表示隐层神经元个数
        
        '''
        pytorch中LSTM输入为[time_step,batch,feature],这里窗口time_step=12,feature=1[1维数据],batch我们这里设置为1
        所以使用seq.view(len(seq),1,-1)将tensor[12]数据转化为tensor[12,1,1]
        '''
        lstm_out, self.hidden = self.lstm(seq.view(len(seq),1,-1), self.hidden)        
        # print(lstm_out) #torch.Size([12, 1, 100]) [time_step,batch,hidden]        
        # print(lstm_out.view(len(seq),-1))  #[12,100]
        pred = self.linear(lstm_out.view(len(seq),-1)) 
        
        # print(pred) #torch.Size([12, 1])
        return pred[-1]  # 输出只用取最后一个值
    
torch.manual_seed(101)
model = LSTMnetwork()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

epochs = 100
start_time = time.time()
for epoch in range(epochs):

    for seq, y_train in train_data:
        
        # 每次更新参数前都梯度归零和初始化
        optimizer.zero_grad()
        model.hidden = (torch.zeros(1,1,model.hidden_size),
                        torch.zeros(1,1,model.hidden_size))

        y_pred = model(seq)
        loss = criterion(y_pred, y_train)
        loss.backward()
        optimizer.step()

    print(f'Epoch: {epoch+1:2} Loss: {loss.item():10.8f}')

print(f'\nDuration: {time.time() - start_time:.0f} seconds')

future = 12

# 选取序列最后12个值开始预测
preds = train_norm[-window_size:].tolist()

# 设置成eval模式
model.eval()
# 循环的每一步表示向时间序列向后滑动一格
for i in range(future):
    
    seq = torch.FloatTensor(preds[-window_size:]) #第下一次循环的时候,seq总是能取到后12个数据,因此及时后面用pred.append()也还是每次用到最新的预测数据完成下一次的预测。
    
    with torch.no_grad():
        model.hidden = (torch.zeros(1,1,model.hidden_size),
                        torch.zeros(1,1,model.hidden_size))
        """
        item理解:取出张量具体位置的元素元素值,并且返回的是该位置元素值的高精度值,保持原元素类型不变;必须指定位置
            即:原张量元素为整形,则返回整形,原张量元素为浮点型则返回浮点型,etc.
        """
        # print(model(seq),model(seq).item())  #tensor([0.1027]), tensor([0.1026])
        preds.append(model(seq).item())  #每循环一次,这里会将新的预测值添加到pred中,

# 逆归一化还原真实值
true_predictions = scaler.inverse_transform(np.array(preds[window_size:]).reshape(-1, 1))

# 对比真实值和预测值
plt.figure(figsize=(12,4))
plt.grid(True)
plt.plot(df['S4248SM144NCEN'])
x = np.arange('2018-02-01', '2019-02-01', dtype='datetime64[M]').astype('datetime64[D]')

plt.plot(x,true_predictions)
plt.show()

# 放大看
fig = plt.figure(figsize=(12,4))
plt.grid(True)
fig.autofmt_xdate()

plt.plot(df['S4248SM144NCEN']['2017-01-01':])
plt.plot(x,true_predictions)
plt.show()


# 重新开始训练
epochs = 100
# 切回到训练模式
model.train()
y_norm = scaler.fit_transform(y.reshape(-1, 1))
y_norm = torch.FloatTensor(y_norm).view(-1)
all_data = input_data(y_norm,window_size)


start_time = time.time()

for epoch in range(epochs):

    for seq, y_train in all_data:  

        optimizer.zero_grad()
        model.hidden = (torch.zeros(1,1,model.hidden_size),
                        torch.zeros(1,1,model.hidden_size))

        y_pred = model(seq)

        loss = criterion(y_pred, y_train)
        loss.backward()
        optimizer.step()

    print(f'Epoch: {epoch+1:2} Loss: {loss.item():10.8f}')

print(f'\nDuration: {time.time() - start_time:.0f} seconds')

# 重新预测
window_size = 12
future = 12
L = len(y)

preds = y_norm[-window_size:].tolist()

model.eval()
for i in range(future):  
    seq = torch.FloatTensor(preds[-window_size:])
    with torch.no_grad():

        model.hidden = (torch.zeros(1,1,model.hidden_size),
                        torch.zeros(1,1,model.hidden_size))  
        preds.append(model(seq).item())


true_predictions = scaler.inverse_transform(np.array(preds).reshape(-1, 1))


x = np.arange('2019-02-01', '2020-02-01', dtype='datetime64[M]').astype('datetime64[D]')

plt.figure(figsize=(12,4))
plt.grid(True)
plt.plot(df['S4248SM144NCEN'])
plt.plot(x,true_predictions[window_size:])
plt.show()

代码说明:代码中包含了训练、测试和预测。但没有对该模型进行评估。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1176509.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

“利用义乌购API揭秘跨境贸易商机:一键获取海量优质商品列表!“

义乌购API可以根据关键词取商品列表。通过调用义乌购API的item_search接口,传入关键词参数,可以获取到符合该关键词的商品列表。 以下是使用义乌购API根据关键词取商品列表的步骤: 注册义乌购开发者账号并获取授权码和密钥。在代码中导入义…

SAP-MM-定义计量单位组

业务场景: 有些物料的计量单位是相同的,为了快速维护物料的计量单位的转换关系,可以创建计量单位组,输入转换关系时,输入组就可以直接转换,不需要单个维护 SPRO-后勤常规-物料主数据-设置关键字段-定义计…

享搭低代码平台:快速构建符合需求的进销存管理系统应用

本文介绍了享搭低代码平台如何赋予用户快速构建进销存管理系统应用的能力。通过在应用商店安装费用进销存管理模板,并通过拖拽方式对模板进行自定义扩充,用户可以快速搭建符合自身需求的进销存管理系统,从而提高管理效率和优化运营。 介绍低代…

shopee、亚马逊卖家如何安全给自己店铺测评?稳定测评环境是关键

大家都知道通过测评可以提升产品的转化率,提升产品的销量,那么做跨境平台的卖家如何安全的给自己店铺测评呢? 无论是亚马逊、拼多多Temu、shopee、Lazada、wish、速卖通、敦煌网、Wayfair、雅虎、eBay、Newegg、乐天、美客多、阿里国际、沃尔…

进销存管理系统如何提高供应链效率?

供应链和进销存系统之间有着密切的联系。进销存系统是供应链管理的一部分,用于跟踪和管理产品的采购、库存和销售。进销存管理是供应链管理的核心流程之一,它有助于提高效率、降低成本、增加盈利,同时确保客户满意度,这对于企业的…

HackTheBox-Starting Point--Tier 1---Ignition

文章目录 一 题目二 实验过程 一 题目 Tags Web、Common Applications、Magento、Reconnaissance、Web Site Structure Discovery、Weak Credentials译文:Web、常见应用、Magento、侦察、网站结构发现、凭证薄弱Connect To attack the target machine, you must …

Docker安装Minio(稳定版)

1、安装 docker pull minio/minio:RELEASE.2021-06-17T00-10-46Z docker run -p 9000:9000 minio/minio:RELEASE.2021-06-17T00-10-46Z server /data 2、访问测试 3、MinIO自定义Access和Secret密钥 要覆盖MinIO的自动生成的密钥,您可以将Access和Secret密钥设为…

干货分享 | 3D WEB轻量化引擎HOOPS Communicator如何读取复杂大模型文件?

HOOPS Communicator是一款简单而强大的工业级高性能3D Web可视化开发包,其主要应用于Web领域,主要加载其专有的SCS、SC、SCZ格式文件;HOOPS还拥有另一个桌面端开发包HOOPS Visualize,主要加载HSF、HMF轻量化格式文件。 两者虽然同…

面向有连接型和面向无连接型

文章目录 面向有连接型和面向无连接型面向有连接型面向无连接型 面向有连接型和面向无连接型 通过网络发送数据,大致可以分为面向有连接与面向无连接两种类型,如下图: 面向有连接型 面向有连接型中,在发送数据之前&#xff0c…

Python Slice函数:数据处理利器详解

引言: 在Python编程中,处理数据是一个非常常见且重要的任务。为了更高效地处理数据,Python提供了许多内置函数和方法。其中,slice()函数是一个非常强大且常用的工具,它可以帮助我们轻松地提取、操作和处理数据。无论是…

无效的标记: --release

一、错误提示: 无效的标记: --release 二、原因 使用的 jdk 版本与所需 jdk 版本不符 三、解决: 1、先排除是否是JDK与SpringBoot的版本不一致导致的:如JDK1.8和SpringBoot3.1.5冲突; 2、检查调整Java编译版本 3、检查Maven环…

MTK 拨打紧急电话接通时间过长问题分析

1、问题分析 从Log视频来看,通话接通时间过长,但是Modem Log来看,进行多两次拨号。 查看AP代码确实进行了两次拨号 AP界面查看确实只有一路通话 查看MTK原始代码,发现当紧急拨号失败后,上层换卡重试,界面不…

yolov8模型训练、目标跟踪

一、准备条件 1.下载yolov8 https://github.com/ultralytics/ultralytics2.安装python https://www.python.org/ftp/python/3.8.0/python-3.8.0-amd64.exe3.安装依赖 进入ultralytics-main,执行: pip install -r requirements.txt pip install -U ul…

chrome 防止http自动转https的方法

1. 左上角,单击地址栏左边 2. 然后点击网站设置 3. 不安全内容改为【允许】 4. 然后以后访问此网站时,就不会再自动跳转为https了

比特币生态的下一个千倍币——CHAX

这两天突然火起来的chax大家关注到了吗? 比特币生态的 如果你羡慕别人ordi赚万倍的收入,你羡慕0风险参与sats铭刻获得10倍以上收益的人,请耐心看完的接下来的内容,如果你不想赚钱,请滑走,原创不易&#x…

Windows如何搭建WebDAV服务并实现公网远程访问内网本地文件?

文章目录 1. 安装IIS必要WebDav组件2. 客户端测试3. cpolar内网穿透3.1 打开Web-UI管理界面3.2 创建隧道3.3 查看在线隧道列表3.4 浏览器访问测试 4. 安装Raidrive客户端4.1 连接WebDav服务器4.2 连接成功4.2 连接成功总结: 自己用Windows Server搭建了家用NAS主机&…

pytorch与cudatoolkit,cudnn对应关系及安装相应的版本

文章目录 一.cuda安装二、nvidia 驱动和cuda runtime 版本对应关系三、安装cudatoolkit,cudnn对应版本四、cuda11.2版本的对应安装的pytorch版本及安装五、相关参考 一.cuda安装 1.确定当前平台cuda可以安装的版本 安装好显卡驱动后,使用nvidia-smi命令可以查看这个…

antv/x6 自定义html节点并且支持动态更新节点内容

antv/x6 自定义html节点 效果图定义一个连接桩公共方法注册图形节点创建html节点动态更新节点内容 效果图 定义一个连接桩公共方法 const ports {groups: {top: {position: top,attrs: {circle: {r: 4,magnet: true,stroke: #cf1322,strokeWidth: 1,fill: #fff,style: {visib…

算法通关村第十四关白银挑战——堆的经典算法题

关注微信公众号:怒码少年。 回复关键词:【电子书】,领取多本计算机相关电子书 大家好,我是怒码少年小码。 今天开始进入新的篇章——堆!这里我默认了大家都知道堆的基本知识了,我们来看看关于堆的两道高频…

基于Matlab的策动点阻抗快速综合库函数-微带线综合

基于Matlab的策动点阻抗快速综合库函数-微带线综合 参考书籍: MICROWAVE AMPLIFIER AND ACTIVE CIRCUIT DESIGN USING THE REAL FREQUENCY TECHNIQUE 1、环境安装 下载RFPLSynth包,链接:https://github.com/Grant-Giesbrecht/RFPLSynth。在…