PyTorch中利用LSTMCell搭建多层LSTM实现时间序列预测

news2024/11/28 10:32:15

前言

前面已经写过不少时间序列预测的文章:

  1. 深入理解PyTorch中LSTM的输入和输出(从input输入到Linear输出)
  2. PyTorch搭建LSTM实现时间序列预测(负荷预测)
  3. PyTorch中利用LSTMCell搭建多层LSTM实现时间序列预测
  4. PyTorch搭建LSTM实现多变量时间序列预测(负荷预测)
  5. PyTorch搭建双向LSTM实现时间序列预测(负荷预测)
  6. PyTorch搭建LSTM实现多变量多步长时间序列预测(一):直接多输出
  7. PyTorch搭建LSTM实现多变量多步长时间序列预测(二):单步滚动预测
  8. PyTorch搭建LSTM实现多变量多步长时间序列预测(三):多模型单步预测
  9. PyTorch搭建LSTM实现多变量多步长时间序列预测(四):多模型滚动预测
  10. PyTorch搭建LSTM实现多变量多步长时间序列预测(五):seq2seq
  11. PyTorch中实现LSTM多步长时间序列预测的几种方法总结(负荷预测)
  12. PyTorch-LSTM时间序列预测中如何预测真正的未来值
  13. PyTorch搭建LSTM实现多变量输入多变量输出时间序列预测(多任务学习)
  14. PyTorch搭建ANN实现时间序列预测(风速预测)
  15. PyTorch搭建CNN实现时间序列预测(风速预测)
  16. PyTorch搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
  17. PyTorch搭建Transformer实现多变量多步长时间序列预测(负荷预测)
  18. PyTorch时间序列预测系列文章总结(代码使用方法)
  19. TensorFlow搭建LSTM实现时间序列预测(负荷预测)
  20. TensorFlow搭建LSTM实现多变量时间序列预测(负荷预测)
  21. TensorFlow搭建双向LSTM实现时间序列预测(负荷预测)
  22. TensorFlow搭建LSTM实现多变量多步长时间序列预测(一):直接多输出
  23. TensorFlow搭建LSTM实现多变量多步长时间序列预测(二):单步滚动预测
  24. TensorFlow搭建LSTM实现多变量多步长时间序列预测(三):多模型单步预测
  25. TensorFlow搭建LSTM实现多变量多步长时间序列预测(四):多模型滚动预测
  26. TensorFlow搭建LSTM实现多变量多步长时间序列预测(五):seq2seq
  27. TensorFlow搭建LSTM实现多变量输入多变量输出时间序列预测(多任务学习)
  28. TensorFlow搭建ANN实现时间序列预测(风速预测)
  29. TensorFlow搭建CNN实现时间序列预测(风速预测)
  30. TensorFlow搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)

这些文章中LSTM的模型都采用以下方法搭建:

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.num_directions = 1 # 单向LSTM
        self.batch_size = batch_size
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.linear = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input_seq):
        batch_size, seq_len = input_seq.shape[0], input_seq.shape[1]
        h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
        c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
        # output(batch_size, seq_len, num_directions * hidden_size)
        output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
        pred = self.linear(output)  # (5, 30, 1)
        pred = pred[:, -1, :]  # (5, 1)
        return pred

其中LSTM模型的定义语句为:

self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True, dropout=0.5)

如果num_layers=2, hidden_size=64,那么两层LSTM的hidden_size都为64,并且最后一层也就是第二层结束后不会执行dropout策略。

如果我们需要让两层LSTM的hidden_size不一样,并且每一层后都执行dropout,就可以采用LSTMCell来实现多层的LSTM。

LSTMCell

关于nn.LSTMCell的参数,官方文档给出的解释为:
在这里插入图片描述
参数一共三个,意义和之前文章讲的一样,不再重复。

利用LSTMCell搭建一个两层的LSTM如下所示:

class LSTM(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.input_size = args.input_size
        self.output_size = args.output_size
        self.num_directions = 1
        self.batch_size = args.batch_size
        self.lstm0 = nn.LSTMCell(args.input_size, hidden_size=128)
        self.lstm1 = nn.LSTMCell(input_size=128, hidden_size=32)
        self.dropout = nn.Dropout(p=0.4)
        self.linear = nn.Linear(32, self.output_size)

    def forward(self, input_seq):
        batch_size, seq_len = input_seq.shape[0], input_seq.shape[1]
        # batch_size, hidden_size
        h_l0 = torch.zeros(batch_size, 128).to(device)
        c_l0 = torch.zeros(batch_size, 128).to(device)
        h_l1 = torch.zeros(batch_size, 32).to(device)
        c_l1 = torch.zeros(batch_size, 32).to(device)
        output = []
        for t in range(seq_len):
            h_l0, c_l0 = self.lstm0(input_seq[:, t, :], (h_l0, c_l0))
            h_l0, c_l0 = self.dropout(h_l0), self.dropout(c_l0)
            h_l1, c_l1 = self.lstm1(h_l0, (h_l1, c_l1))
            h_l1, c_l1 = self.dropout(h_l1), self.dropout(c_l1)
            output.append(h_l1)

        pred = self.linear(output[-1])

        return pred

可以发现,我们定义了两个LSTMCell,分别对应两层:

self.lstm0 = nn.LSTMCell(args.input_size, hidden_size=128)
self.lstm1 = nn.LSTMCell(input_size=128, hidden_size=32)

第一层的input_size就为初始数据的input_size,第二层的input_size应当为第一层的hidden_size,这样才能实现数据传递。

使用LSTMCell时我们需要手动对每个时间步进行计算与传递:

for t in range(seq_len):
    h_l0, c_l0 = self.lstm0(input_seq[:, t, :], (h_l0, c_l0))
    h_l0, c_l0 = self.dropout(h_l0), self.dropout(c_l0)
    h_l1, c_l1 = self.lstm1(h_l0, (h_l1, c_l1))
    h_l1, c_l1 = self.dropout(h_l1), self.dropout(c_l1)
    output.append(h_l1)

input_seq的维度为:

input_seq(batch_size, seq_len, input_size)

每次取出其中一个步长参与运算:

h_l0, c_l0 = self.lstm0(input_seq[:, t, :], (h_l0, c_l0))

第一个LSTMCell的结果将被送入第二个LSTMCell:

h_l1, c_l1 = self.lstm1(h_l0, (h_l1, c_l1))

此时得到的是一个时间步的输出,维度大小为(batch_size, hidden_size)。重复执行多次,就可以得到所有步长的输出。最后,我们再取最后一个时间步(这里不懂请看第一篇文章)的输出进行映射以得到最终的输出:

pred = self.linear(output[-1])

可以发现,在每一个LSTMCell执行结束后,我们都可以手动添加dropout层:

h_l0, c_l0 = self.dropout(h_l0), self.dropout(c_l0)

反观LSTM的执行过程:

output, _ = self.lstm(input_seq, (h_0, c_0))

此时output的shape为:

output(batch_size, seq_len, hidden_size)

实际上就是一步到位,直接得到所有seq_len(batch_size, hidden_size)

训练/测试

这里没啥可说的,与前面一模一样。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/88235.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

为什么AI距离智能越来越远?

2021年讨论了人机混合智能里的深度态势感知和人的算计与机器的计算如何结合的问题。之后有一位朋友问了我五个问题。第一,关于数学和逻辑的关系问题。这个问题是百年来数学的基础问题,迄今为止似乎没有定论。从实用主义角度说,“把数学等同于…

企业在项目中采用工时管理系统的好处

在如今疫情的影响下,不少企业面对经济形势愈发严峻的情况下,对项目员工工时的管理也是越来越注重。如何在确保企业正常运转的前提下提升企业发展空间,人员降低工作成本呢?根据目前研究表明,很多企业都选择使用项目工时…

Android Kotlin使用AspectJ进行AOP面向切面编程

前言 什么是面向切面编程?首先我们来了解下两个概念: OOP(面向对象编程):针对业务处理过程的实体及其属性和行为进行抽象封装,以获得更加清晰高效的逻辑单元划分。 AOP(面向切面编程):则是针对业务处理过程…

html好看的生日祝福,生日表白(源码)

文章目录1.设计来源1.1 主界面1.2 秘密基地1.3 甜言蜜语2.效果和源码2.1 动态效果2.2 源代码2.3 自定义背景图片代码2.4 自定义每次生日记录代码2.5 自定义背景音乐代码源码下载作者:xcLeigh 文章地址:https://blog.csdn.net/weixin_43151418/article/de…

Java实现Google第三方登录

文章目录前言一、了解OAuth2.0二、注册开发者账号1.登录开发者平台2.创建应用三、代码实现1.实现流程1.点击登录2.接受回调中的code获取accessToken3.获取用户信息2.注意事项前言 Google API 使用 OAuth 2.0 协议进行身份验证和授权。Google 支持常见的 OAuth 2.0 场景&#x…

高分子点击试剂DBCO-PEG-Hydrazide,二苯并环辛炔-聚乙二醇-酰基

一、试剂基团反应特点(Reagent group reaction characteristics): DBCO-PEG-Hydrazide属于高分子点击试剂,“点击化学"一般由叠氮化物(azide)和炔烃(alkyne)作用形共价键&#…

老港综合填埋场二期配套渗滤液工程电能管理系统的设计和应用-Susie 周

1、概述 本项目为老港综合填埋场二期配套渗滤液工程电能管理系统。根据配电系统管理的要求,需要对(老港综合填埋场二期配套渗滤液工程电能管理系统项目的配电柜进行电能管理,以保证用电的安全、可靠。 Acrel-3000电能管理系统充分利用了现代…

Mybatis源码分析(一)Mybatis 基本使用

目录一 知识回顾1.1 简介1.2 其他二 基本使用官网:mybatis – MyBatis 3 | 简介 一 知识回顾 1.1 简介 MyBatis 是一款优秀的持久层框架,它支持自定义 SQL、存储过程以及高级映射。MyBatis 免除了几乎所有的 JDBC 代码以及设置参数和获取结果集的工作…

图片怎么转换成excel文档?

当我们创建excel文档中,里面无疑是需要各种表格内容,而如果是我们一个一个编辑起来,这就会比较繁琐。而现在许多需求可以通过网络很容易地得到满足。比如有把图片转换成excel表格的需求。下载一个小工具,这就相当方便了&#xff0…

不愧是阿里资深架构师,这本“分布式架构笔记”写得如此透彻明了

前言: Mybatis 是一款优秀的持久层框架。其封装了 JDBC 操作, 免去了开发人员编写 JDBC 代码以及设置参数和获取结果集的重复性工作。通过编写简单的 XML 或 Java 注解即可映射数据库 CRUD 操作。本文介绍的是阿里资深架构师十年经验整理,My…

JAVA 中的注解可以继承吗?

前言 注解想必大家都用过,也叫元数据,是一种代码级别的注释,可以对类或者方法等元素做标记说明,比如 Spring 框架中的Service,Component等。那么今天我想问大家的是类被继承了,注解能否继承呢?…

基于springboot在线答疑系统

教师权限:首页、个人中心、疑难解答管理、试卷管理、试题管理、考试管理。 学生权限;首页、个人中心、问题发布管理、疑难解答管理、考试管理等功能模块的管理维护等操作,系统结构图如下图4-1所示。 图4-1 系统功能图 截图 目 录 摘 要 I …

[附源码]Node.js计算机毕业设计扶贫产品展销平台小程序Express

项目运行 环境配置: Node.js最新版 Vscode Mysql5.7 HBuilderXNavicat11Vue。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分离等等。 环境需要 1.运行环境:最好是Nodejs最新版,我…

matlab 的help没了

前两天还正常用,今天输入help 关键字 回复是没有相关的内容。 解决办法: 按照如下选择就行了 然后输入 help help 就会有显示了 help - Help for functions in Command Window This MATLAB function displays the help text for the functionalit…

大数据MapReduce学习案例:倒排索引

文章目录一,案例分析(一)倒排索引介绍(二)案例需求二,案例实施(一)准备数据文件(1)启动hadoop服务(2)虚拟机上创建文本文件&#xff0…

数据结构双向链表

双向链表也叫双链表,是链表的一种,它的每个数据结点中都有两个指针,分别指向直接后继和直接前驱。所以,从双向链表中的任意一个结点开始,都可以很方便地访问它的前驱结点和后继结点。一般我们都构造双向循环链表。 那…

WPF入门第三篇 ControlTemplate、Trigger与Storyboard

ControlTemplate、Trigger与Storyboard ControlTemplate通常用在Style中,Trigger通常作为ControlTemplate的一部分,StoryBoard表示动画效果,下面将通过对Button按钮设置这几项来简单说明这几项的用法。 在MainWindow中添加一个Button按钮&am…

Prometheus技术分享——如何监控宿主机和容器

这一期主要来跟大家聊一下,使用node_exporter工具来暴露主机和因公程序上的指标,利用prometheus来监控宿主机;以及通过通过Cadvisor监控docker容器。 一、部署node_exporter监控宿主机 1 下载软件包 wget https://github.com/prometheus/n…

分布式链路追踪SkyWalking

文章目录目录介绍服务端搭建注册中心启动注册中心修改持久化配置UI服务配置启动服务客户端搭建目录介绍 重要的目录结构分析如下: agent:客户端需要指定的目录,其中有一个jar,就是负责和客户端整合收集日志bin:服务端…

深入理解Linux网络技术内幕(十三)——协议处理函数

文章目录前言网络协议栈概论大蓝图Ethernet的链路层的选择(LLC和SNAP)网络协议栈的操作方式执行正确的协议处理函数特殊媒介封装协议处理函数的组织协议处理函数的注册Ethernet与IEEE 802.3帧设置封包类型设置Ethernet协议及长度逻辑链接控制&#xff08…