小白也能读懂的ConvLSTM!(开源pytorch代码)

news2024/11/16 20:22:44

ConvLSTM

    • 1. 算法简介与应用场景
    • 2. 算法原理
      • 2.1 LSTM基础
      • 2.2 ConvLSTM原理
        • 2.2.1 ConvLSTM的结构
        • 2.2.2 卷积操作的优点
      • 2.3 LSTM与ConvLSTM的对比分析
      • 2.4 ConvLSTM的应用
    • 3. PyTorch代码
    • 参考文献

仅需要网络源码的可以直接跳到末尾即可

1. 算法简介与应用场景

ConvLSTM(卷积长短期记忆网络)是一种结合了卷积神经网络(CNN)和长短期记忆网络(LSTM)优势的深度学习模型。它主要用于处理时空数据,特别适用于需要考虑空间特征和时间依赖关系的任务,如气象预测、视频分析、交通流量预测等。

在气象预测中,ConvLSTM可以根据过去的气象数据(如降水、温度等)预测未来的天气情况。在视频分析中,它可以帮助识别视频中的活动或事件,利用时间序列的连续性和空间信息进行更准确的分析。

2. 算法原理

2.1 LSTM基础

在介绍ConvLSTM之前,先让我们来回归一下什么是长短期记忆网络(LSTM)。LSTM是一种特殊的循环神经网络(RNN),它通过引入门控机制解决了传统RNN在长序列训练中面临的梯度消失和爆炸问题。LSTM单元主要包含三个门:输入门、遗忘门和输出门。这些门控制着信息在单元中的流动,从而有效地记住或遗忘信息。

LSTM的核心公式如下:

  • 遗忘门
    f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)

  • 输入门
    i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)
    C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC[ht1,xt]+bC)

  • 单元状态更新
    C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t \ast C_{t-1} + i_t \ast \tilde{C}_t Ct=ftCt1+itC~t

  • 输出门
    o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)
    h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t \ast \tanh(C_t) ht=ottanh(Ct)

这里, C t C_t Ct 是当前的单元状态, h t h_t ht 是当前的隐藏状态, x t x_t xt 是当前的输入。

2.2 ConvLSTM原理

ConvLSTM在LSTM的基础上引入了卷积操作。传统的LSTM使用全连接层处理输入数据,而ConvLSTM则采用卷积层来处理空间数据。这样,ConvLSTM能够更好地捕捉输入数据中的空间特征。
在这里插入图片描述

2.2.1 ConvLSTM的结构

ConvLSTM的单元结构与LSTM非常相似,但是在每个门的计算中使用了卷积操作。具体来说,ConvLSTM的每个门的公式可以表示为:

i t = σ ( W x i ∗ X t + W h i ∗ H t − 1 + W c i ∘ C t − 1 + b i ) i_t = \sigma (W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i) it=σ(WxiXt+WhiHt1+WciCt1+bi)
f t = σ ( W x f ∗ X t + W h f ∗ H t − 1 + W c f ∘ C t − 1 + b f ) f_t = \sigma (W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f) ft=σ(WxfXt+WhfHt1+WcfCt1+bf)
C t = f t ∘ C t − 1 + i t ∘ t a n h ( W x c ∗ X t + W h c ∗ H t − 1 + b c ) C_t = f_t \circ C_{t-1} + i_t \circ tanh(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c) Ct=ftCt1+ittanh(WxcXt+WhcHt1+bc)
o t = σ ( W x o ∗ X t + W h o ∗ H t − 1 + W c o ∘ C t + b o ) o_t = \sigma (W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o) ot=σ(WxoXt+WhoHt1+WcoCt+bo)
H t = o t ∘ t a n h ( C t ) H_t = o_t \circ tanh(C_t) Ht=ottanh(Ct)

这里的 所有 W W W都是是卷积权重, b b b是偏置项, σ \sigma σ 是 sigmoid 函数, tanh ⁡ \tanh tanh 是双曲正切函数。。
在这里插入图片描述

2.2.2 卷积操作的优点
  1. 空间特征提取:卷积操作能够有效提取输入数据中的空间特征。对于图像数据,卷积操作可以捕捉局部特征,例如边缘、纹理等,这在时间序列数据中同样适用。

  2. 参数共享:卷积操作通过使用相同的卷积核在不同位置计算特征,从而减少了模型参数的数量,降低了计算复杂度。

  3. 平移不变性:卷积网络对输入数据的平移具有不变性,即相同的特征在不同位置都会被检测到,这对于时空序列数据来说是非常重要的。

2.3 LSTM与ConvLSTM的对比分析

特性LSTMConvLSTM
输入类型一维序列三维数据(时序的图像数据)
处理方式全连接层卷积操作
空间特征捕捉较弱较强
应用场景自然语言处理、时间序列预测图像序列预测、视频分析

2.4 ConvLSTM的应用

ConvLSTM在多个领域中表现出色,特别适合处理具有时空特征的数据。以下是一些主要的应用场景:

  • 气象预测:利用历史气象数据(如温度、湿度、降水等)来预测未来的天气情况。
  • 视频分析:对视频中的动态场景进行建模,识别和预测视频中的活动。
  • 交通流量预测:基于历史交通数据预测未来的交通流量,帮助城市交通管理。
  • 医学影像分析:分析医学影像序列(如CT、MRI)中的变化,辅助疾病诊断。

3. PyTorch代码

以下是ConvLSTM的完整代码,可以直接拿来用:

import torch.nn as nn
import torch


class ConvLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        初始化卷积 LSTM 单元。

        参数:
        ----------
        input_dim: int
            输入张量的通道数。
        hidden_dim: int
            隐藏状态的通道数。
        kernel_size: (int, int)
            卷积核的大小。
        bias: bool
            是否添加偏置项。
        """

        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        # 计算填充大小以保持输入和输出尺寸一致
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # 定义卷积层,输入是输入维度加上隐藏维度,输出是4倍的隐藏维度(对应i, f, o, g)
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state

        # 沿着通道轴进行拼接
        combined = torch.cat([input_tensor, h_cur], dim=1)

        combined_conv = self.conv(combined)
        # 将输出分割成四个部分,分别对应输入门、遗忘门、输出门和候选单元状态
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        # 更新单元状态
        c_next = f * c_cur + i * g
        # 更新隐藏状态
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        # 初始化隐藏状态和单元状态为零
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))


class ConvLSTM(nn.Module):

    """
    卷积 LSTM 层。

    参数:
    ----------
    input_dim: 输入通道数
    hidden_dim: 隐藏通道数
    kernel_size: 卷积核大小
    num_layers: LSTM 层的数量
    batch_first: 批次是否在第一维
    bias: 卷积中是否有偏置项
    return_all_layers: 是否返回所有层的计算结果

    输入:
    ------
    一个形状为 B, T, C, H, W 或者 T, B, C, H, W 的张量

    输出:
    ------
    元组包含两个列表(长度为 num_layers 或者长度为 1 如果 return_all_layers 为 False):
    0 - layer_output_list 是长度为 T 的每个输出的列表
    1 - last_state_list 是最后的状态列表,其中每个元素是一个 (h, c) 对应隐藏状态和记忆状态

    示例:
    >>> x = torch.rand((32, 10, 64, 128, 128))
    >>> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
    >>> _, last_states = convlstm(x)
    >>> h = last_states[0][0]  # 0 表示层索引,0 表示 h 索引
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        # 检查 kernel_size 的一致性
        self._check_kernel_size_consistency(kernel_size)

        # 确保 kernel_size 和 hidden_dim 的长度与层数一致
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('不一致的列表长度。')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        # 创建 ConvLSTMCell 列表
        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        前向传播函数。

        参数:
        ----------
        input_tensor: 输入张量,形状为 (t, b, c, h, w) 或者 (b, t, c, h, w)
        hidden_state: 初始隐藏状态,默认为 None

        返回:
        -------
        last_state_list, layer_output
        """
        if not self.batch_first:
            # 改变输入张量的顺序,如果 batch_first 为 False
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        b, _, _, h, w = input_tensor.size()

        # 实现状态化的 ConvLSTM
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # 初始化隐藏状态
            hidden_state = self._init_hidden(batch_size=b,
                                             image_size=(h, w))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                # 在每个时间步上更新状态
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)

            # 将输出堆叠起来
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            # 如果不需要返回所有层,则只返回最后一层的输出和状态
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        init_states = []
        for i in range(self.num_layers):
            # 初始化每一层的隐藏状态
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` 必须是 tuple 或者 list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

参考文献

[1]Shi, X., Chen, Z., Wang, H., Yeung, D. Y., Wong, W. K., & Woo, W. (2015). Convolutional LSTM Network: A Machine Learning [2]Approach for Precipitation Nowcasting. Advances in Neural Information Processing Systems, 28.
[3]Hochreiter, S., & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735-1780.
Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1956081.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

“手撕”MySQL的索引

目录 二、索引的作用 三、索引的缺点 四、如何使用索引 查看索引: 创建索引: ​编辑 删除索引: 五、索引的底层原理 那什么是B树,什么是B树呢? B树的好处: 总结: 一、什么是索引 索…

OpenCV 图像预处理—图像金字塔

文章目录 相关概念高斯金字塔拉普拉斯金字塔应用 构建高斯金字塔为什么要对当前层进行模糊?1. 平滑处理2. 减少混叠(Aliasing)3. 多尺度表示4. 图像降采样 举个栗子创建高斯金字塔和拉普拉斯金字塔,并用拉普拉斯金字塔恢复图像 相…

《汇编语言 基于x86处理器》- 读书笔记 - 第3章-汇编语言基础

《汇编语言 基于x86处理器》- 读书笔记 - 第3章-汇编语言基础 3.1 基本语言元素3.1.1 第一个汇编语言程序常见汇编语言调用规范 3.1.2 整数常量(基数、字面量)3.1.3 整型常量表达式3.1.4 实数常量十进制实数十六进制实数(编码实数&#xff09…

使用git命令行的方式,将本地项目上传到远程仓库

在国内的开发环境中,git的使用是必不可少的。Git 是一款分布式版本控制系统,用于有效管理和追踪文件的变更历史及协作开发。本片文章就来介绍一下怎样使用git命令行的方式,将本地项目上传到远程仓库,虽然现在的IDE中基本都配置了g…

Ubuntu安装terminator教程

Terminator 是一个高级的终端仿真器,专为 Linux 和 Unix 系统设计。它的主要特点是提供了丰富的多窗口和多标签功能,使用户能够在一个窗口中管理多个终端会话。这对于系统管理员、开发人员以及需要同时运行多个命令行任务的用户来说,极为方便。 一、安装 1、更新包 sudo a…

使用Selenium爬虫批量下载AlphaFold数据库中的PDB文件

注意:本方法使用了python,下载速度一般,如果需要更快的大批量下载可以考虑使用其他方法,例如FTP Alphafold数据库其实提供了许多物种的蛋白质组: AlphaFold Protein Structure Database 但是如果你搜索的物种不在这个…

算法面试leadcode【经典150道】

88 合并两个有序数组 方法一 使用arraycopy排序 * 思路一:将nums2合并到nums1的尾部,再直接进行排序。* 使用arraycopy(int[]nums1,int m,int[] nums2,int n)* 方法来进行排序,* 从原数组的哪个位置,移动到原数组的哪个位置&#…

xxl-job适配达梦数据库并制作镜像、源码部署xxl-job

背景:因项目需要信创,需将原本的mysql数据库,改成达梦数据库 一、部署达梦数据库 1.1 部署达梦数据库服务 可参考:Docker安装达梦数据库_达梦数据库docker镜像-CSDN博客 PS:部署达梦数据库时,需加上大小…

Java | Leetcode Java题解之第300题最长递增子序列

题目&#xff1a; 题解&#xff1a; class Solution {public int lengthOfLIS(int[] nums) {int len 1, n nums.length;if (n 0) {return 0;}int[] d new int[n 1];d[len] nums[0];for (int i 1; i < n; i) {if (nums[i] > d[len]) {d[len] nums[i];} else {int…

19. Revit API: Parameter(参数)

一、前言 我们在前面或多或少提到也用到参数了&#xff0c;这篇便细讲一下。 首先&#xff0c;我们知道好多信息都藏在参数里&#xff0c;或者说可以从参数中获取。我们还能够通过调整参数的值&#xff0c;改变模型的形态&#xff0c;即族的参变。 其次&#xff0c;有时族上…

【CAN通讯系列4】CAN通讯如何传递信号?

在【CAN通讯系列3】如何学习CAN通讯&#xff1f;中举了一个例子&#xff1a;新能源汽车要实现驱动功能&#xff0c;先需要整车控制器VCU计算目标转速或扭矩请求等信号&#xff0c;再通过CAN通讯传递给电机控制器MCU&#xff0c;就这个例子继续探讨CAN通讯的基础问题。 1 CAN数据…

入门 PyQt6 看过来(案例)08~ 页面布局

主题&#xff1a;学习页面布局控件以及布局容器的使用&#xff08;理论知识&#xff09; 1 布局控件 PyQt6的布局方式包括绝对布局、水平布局、垂直布局、网格布局和表单布局。 绝对布局&#xff1a;直接设置控件对象在参考坐标中的位置水平布局&#xff1a;对加入的控件对象从…

引用的项目“xxxx/tsconfig.node.json”可能不会禁用发出。

vue3 报错&#xff1a; 引用的项目“xxxx/tsconfig.node.json”可能不会禁用发出。 解决&#xff1a; 进入对应的 json 文件&#xff1a; 修改&#xff1a; "noEmit": false 当 noEmit 设置为 false 时&#xff0c;TypeScript 编译器将根据项目配置生成相应的输出文…

【数据结构初阶】单链表经典算法题十道(详解+图例)—得道飞升(终篇)

hi &#xff01; 目录 9、 环形链表 || 10、随机链表的复制 终章 9、 环形链表 || 【图解】 /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/typedef struct ListNode ListNode; struct ListNode *detectCy…

Live800:客户服务中的情感智能,建立深厚客户关系的秘诀

在当今竞争激烈的市场环境中&#xff0c;客户服务已成为企业脱颖而出的关键因素之一。而情感智能&#xff0c;作为客户服务中的重要组成部分&#xff0c;更是建立深厚客户关系、提升客户满意度的秘诀所在。优秀的客户服务不仅关乎问题的解决&#xff0c;更在于情感的交流与共鸣…

物联网云盒多路开关量模拟量转无线MQTT钡铼技术S275

物联网云盒多路开关量模拟量转无线MQTT技术在现代工业自动化和远程监测中扮演着关键角色。钡铼第四代RTU S275作为一款先进的物联网数据监测采集控制短信报警终端&#xff0c;集成了多种先进技术和功能&#xff0c;旨在提升远程数据采集与控制的效率和可靠性。 钡铼第四代RTU …

在Ubuntu 12.10上安装和使用tmux的方法

前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站。 关于 tmux tmux 是一个终端复用工具。它允许您使用多个虚拟终端访问 tmux 终端。 tmux 利用了客户端-服务器模型&#xff0c;这使您可…

unity3d:TabView,UGUI多标签页组件,TreeView树状展开菜单

概述 1.最外层DataForm为空壳编辑数据用。可以有多个DataForm&#xff0c;例如福利DataForm&#xff0c;抽奖DataForm 2.Menu层为左边栏层&#xff0c;每个DataForm可以使用不同样式的MenuForm预制体 3.DataForm中使用ReorderList&#xff0c;可排列配置 4.有定位功能&#xf…

网址导航系统PHP源码分享

1、采用光年全新v5模板开发后台 2、后台内置8款主题色&#xff0c;分别是简约白、炫光绿、渐变紫、活力橙、少女粉、少女紫、科幻蓝、护眼黑 3、可管理无数引导页主题并且主题内可以进行不同的自定义设置&#xff0c;目前内置16套主题 持续增加中… 4、可单独开发各种插件&a…

【OSCP系列】OSCP靶机-LemonSqueezy(原创)

【OSCP系列】OSCP靶机-LemonSqueezy 原文转载已经过授权 原文链接&#xff1a;Lusen的小窝 - 学无止尽&#xff0c;不进则退 (lusensec.github.io) 一、主机发现 二、端口扫描 1、快速扫描 2、全端口扫描 只有一个80端口 3、版本系统探测 80端口http的apache服务&#xff0…