ConvGRU原理与开源代码

news2025/1/21 15:21:09

ConvGRU

    • 1. 算法简介与应用场景
    • 2. 算法原理
      • 2.1 GRU基础
      • 2.2 ConvGRU原理
        • 2.2.1 ConvGRU的结构
        • 2.2.2 卷积操作的优点
      • 2.3 GRU与ConvGRU的对比分析
      • 2.4 ConvGRU的应用
    • 3. PyTorch代码

仅需要网络源码的可以直接跳到末尾即可
需要ConvLSTM的可以参考我的另外一篇博客:小白也能读懂的ConvLSTM!(开源pytorch代码)

1. 算法简介与应用场景

ConvGRU(卷积门控循环单元)是一种结合了卷积神经网络(CNN)和门控循环单元(GRU)的深度学习模型。与ConvLSTM类似,ConvGRU也主要用于处理时空数据,特别适用于需要考虑空间特征和时间依赖关系的任务,如视频分析、气象预测和交通流量预测等。

在视频分析中,ConvGRU可以帮助识别和预测视频中的动态行为,利用时间序列的连续性和空间信息进行更准确的分析。在气象预测中,ConvGRU能够根据过去的气象数据(如降水、云图等)预测未来的天气情况。

2. 算法原理

2.1 GRU基础

在介绍ConvGRU之前,首先让我们回顾一下什么是门控循环单元(GRU)。GRU是一种特殊的循环神经网络(RNN),它通过引入门控机制来解决传统RNN在长序列训练中面临的梯度消失和爆炸问题。GRU单元主要包含两个门:重置门和更新门。这些门控制着信息在单元中的流动,从而有效地记住或遗忘信息。

GRU的核心公式如下:

  • 重置门
    r t = σ ( W r ⋅ [ h t − 1 , x t ] + b r ) r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) rt=σ(Wr[ht1,xt]+br)

  • 更新门
    z t = σ ( W z ⋅ [ h t − 1 , x t ] + b z ) z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) zt=σ(Wz[ht1,xt]+bz)

  • 候选状态
    h ~ t = tanh ⁡ ( W h ⋅ [ r t ∗ h t − 1 , x t ] + b h ) \tilde{h}_t = \tanh(W_h \cdot [r_t * h_{t-1}, x_t] + b_h) h~t=tanh(Wh[rtht1,xt]+bh)

  • 最终状态
    h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t h_t = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_t ht=(1zt)ht1+zth~t

这里, h t h_t ht 是当前的隐藏状态, x t x_t xt 是当前的输入。

2.2 ConvGRU原理

ConvGRU在GRU的基础上引入了卷积操作。与ConvLSTM类似,ConvGRU使用卷积层来处理空间数据,从而能够更好地捕捉输入数据中的空间特征。

ConvGRU结构图

没找到ConvGRU的图,和LSTM道理一样的

2.2.1 ConvGRU的结构

ConvGRU的单元结构与GRU非常相似,但是在每个门的计算中使用了卷积操作。具体来说,ConvGRU的每个门的公式可以表示为:

z t = σ ( W z ∗ X t + U z ∗ H t − 1 + b z ) z_t = \sigma (W_{z} * X_t + U_{z} * H_{t-1} + b_z) zt=σ(WzXt+UzHt1+bz)
r t = σ ( W r ∗ X t + U r ∗ H t − 1 + b r ) r_t = \sigma (W_{r} * X_t + U_{r} * H_{t-1} + b_r) rt=σ(WrXt+UrHt1+br)
h ~ t = tanh ⁡ ( W h ∗ X t + U h ∗ ( r t ∗ H t − 1 ) + b h ) \tilde{h}_t = \tanh(W_{h} * X_t + U_{h} * (r_t * H_{t-1}) + b_h) h~t=tanh(WhXt+Uh(rtHt1)+bh)
h t = ( 1 − z t ) ∗ H t − 1 + z t ∗ h ~ t h_t = (1 - z_t) * H_{t-1} + z_t * \tilde{h}_t ht=(1zt)Ht1+zth~t

这里的所有 W W W U U U都是卷积权重, b b b是偏置项, σ \sigma σ 是 sigmoid 函数, tanh ⁡ \tanh tanh 是双曲正切函数。

ConvGRU结构图

2.2.2 卷积操作的优点
  1. 空间特征提取:卷积操作能够有效提取输入数据中的空间特征。对于图像数据,卷积操作可以捕捉局部特征,例如边缘、纹理等,这在时间序列数据中同样适用。

  2. 参数共享:卷积操作通过使用相同的卷积核在不同位置计算特征,从而减少了模型参数的数量,降低了计算复杂度。

  3. 平移不变性:卷积网络对输入数据的平移具有不变性,即相同的特征在不同位置都会被检测到,这对于时空序列数据来说是非常重要的。

2.3 GRU与ConvGRU的对比分析

特性GRUConvGRU
输入类型一维序列三维数据(时序的图像数据)
处理方式全连接层卷积操作
空间特征捕捉较弱较强
应用场景自然语言处理、时间序列预测图像序列预测、视频分析

2.4 ConvGRU的应用

ConvGRU在多个领域中表现出色,特别适合处理具有时空特征的数据。以下是一些主要的应用场景:

  • 气象预测:利用历史气象数据(如温度、湿度、降水等)来预测未来的天气情况。
  • 视频分析:对视频中的动态场景进行建模,识别和预测视频中的活动。
  • 交通流量预测:基于历史交通数据预测未来的交通流量,帮助城市交通管理。
  • 医学影像分析:分析医学影像序列(如CT、MRI)中的变化,辅助疾病诊断。

3. PyTorch代码

以下是一个简单的ConvGRU的网络完整代码:

import os
import torch
from torch import nn
from torch.autograd import Variable


class ConvGRUCell(nn.Module):
    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias, dtype):
        """
        初始化卷积 GRU 单元。

        :param input_size: (int, int)
            输入张量的高度和宽度作为 (height, width)。
        :param input_dim: int
            输入张量的通道数。
        :param hidden_dim: int
            隐藏状态的通道数。
        :param kernel_size: (int, int)
            卷积核的大小。
        :param bias: bool
            是否添加偏置项。
        :param dtype: torch.cuda.FloatTensor 或 torch.FloatTensor
            是否使用 CUDA。
        """
        super(ConvGRUCell, self).__init__()
        self.height, self.width = input_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.hidden_dim = hidden_dim
        self.bias = bias
        self.dtype = dtype

        # 定义用于计算更新门和重置门的卷积层
        self.conv_gates = nn.Conv2d(in_channels=input_dim + hidden_dim,
                                    out_channels=2 * self.hidden_dim,  # 用于更新门和重置门
                                    kernel_size=kernel_size,
                                    padding=self.padding,
                                    bias=self.bias)

        # 定义用于计算候选神经记忆的卷积层
        self.conv_can = nn.Conv2d(in_channels=input_dim + hidden_dim,
                                  out_channels=self.hidden_dim,  # 用于候选神经记忆
                                  kernel_size=kernel_size,
                                  padding=self.padding,
                                  bias=self.bias)

    def init_hidden(self, batch_size):
        """
        初始化隐藏状态。

        :param batch_size: int
            批次大小。
        :return: Variable
            隐藏状态。
        """
        return Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).type(self.dtype)

    def forward(self, input_tensor, h_cur):
        """
        前向传播函数。

        :param input_tensor: (b, c, h, w)
            输入张量实际上是目标模型。
        :param h_cur: (b, c_hidden, h, w)
            当前的隐藏状态。
        :return: h_next
            下一个隐藏状态。
        """
        combined = torch.cat([input_tensor, h_cur], dim=1)
        combined_conv = self.conv_gates(combined)

        # 分割卷积输出以获取更新门和重置门
        gamma, beta = torch.split(combined_conv, self.hidden_dim, dim=1)
        reset_gate = torch.sigmoid(gamma)
        update_gate = torch.sigmoid(beta)

        # 使用重置门乘以当前隐藏状态
        combined = torch.cat([input_tensor, reset_gate * h_cur], dim=1)
        cc_cnm = self.conv_can(combined)
        cnm = torch.tanh(cc_cnm)

        # 更新隐藏状态
        h_next = (1 - update_gate) * h_cur + update_gate * cnm
        return h_next


class ConvGRU(nn.Module):
    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
                 dtype, batch_first=False, bias=True, return_all_layers=False):
        """
        初始化卷积 GRU 模型。

        :param input_size: (int, int)
            输入张量的高度和宽度作为 (height, width)。
        :param input_dim: int
            输入张量的通道数。
        :param hidden_dim: int
            隐藏状态的通道数。
        :param kernel_size: (int, int)
            卷积核的大小。
        :param num_layers: int
            卷积 GRU 层的数量。
        :param dtype: torch.cuda.FloatTensor 或 torch.FloatTensor
            是否使用 CUDA。
        :param batch_first: bool
            如果数组的第一个位置是批次。
        :param bias: bool
            是否添加偏置项。
        :param return_all_layers: bool
            是否返回所有层的隐藏状态。
        """
        super(ConvGRU, self).__init__()

        # 确保 kernel_size 和 hidden_dim 的长度与层数一致
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('不一致的列表长度。')

        self.height, self.width = input_size
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.dtype = dtype
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            # 确定当前层的输入维度
            cur_input_dim = input_dim if i == 0 else hidden_dim[i - 1]
            # 创建并添加卷积 GRU 单元到列表
            cell_list.append(ConvGRUCell(input_size=(self.height, self.width),
                                         input_dim=cur_input_dim,
                                         hidden_dim=self.hidden_dim[i],
                                         kernel_size=self.kernel_size[i],
                                         bias=self.bias,
                                         dtype=self.dtype))

        # 将 Python 列表转换为 PyTorch 模块
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        前向传播函数。

        :param input_tensor: (b, t, c, h, w) 或 (t, b, c, h, w)
            从 AlexNet 提取的特征。
        :param hidden_state:
            初始隐藏状态。
        :return: layer_output_list, last_state_list
            各个层的输出列表以及最后一个状态列表。
        """
        if not self.batch_first:
            # 如果不是按批次优先,则重新排列维度
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        # 实现状态化的卷积 GRU
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # 初始化隐藏状态
            hidden_state = self._init_hidden(batch_size=input_tensor.size(0))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):
            h = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                # 计算当前层的下一个隐藏状态
                h = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                              h_cur=h)
                output_inner.append(h)

            # 将序列内的隐藏状态堆叠起来
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h])

        if not self.return_all_layers:
            # 如果不需要返回所有层,则只返回最后一层的输出和状态
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size):
        """
        初始化隐藏状态。

        :param batch_size: int
            批次大小。
        :return: list
            每一层的初始化隐藏状态列表。
        """
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        """
        检查 kernel_size 的一致性。

        :param kernel_size: tuple 或 list of tuples
            卷积核大小。
        """
        if not (isinstance(kernel_size, tuple) or
                    (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` 必须是 tuple 或 list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        """
        扩展参数以适应多层结构。

        :param param: int 或 list
            参数。
        :param num_layers: int
            层数。
        :return: list
            扩展后的参数列表。
        """
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1957648.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Halcon Blob分析

斑点分析的思路:在图像中,相关对象的像素可以通过其灰度值来识别。例如下图的组织颗粒。这些颗粒是凉的,而液体是暗的,通过选择明亮像素(阈值),可以很容易地检测到颗粒。在需要应用中,这种简单的暗像素和亮…

成像光学:LCD的工作原理与结构图解

一、主流显示面板技术:LCD,OLED,MicroLED 二、主流显示屏的发展趋势 三、LCD堆叠结构(以比较流行的TFT-LCD为例) 沿光路方向介绍:背光,下偏光片(polarizer),…

python实现图像分割算法2

python实现随机步行算法 随机步行算法数学模型Python 实现详细解释优缺点应用领域随机步行算法是一种常用于图像分割和图像分析的算法。它通过模拟随机游走来确定图像中每个像素的标签或类别。随机步行算法特别适合用于解决有种子标记的图像分割问题,其中用户提供一些初始标记…

【Python】基础语法(上)

本篇文章讲解以下知识: (1)初始编码 (2)输出 (3)初识数据类型 一:初识编码 在计算机中所有的数据本质上都是以0和1的组合来存储。 比如:在一个文件中有以下内容&am…

力扣SQL50 上级经理已离职的公司员工 一题双解

Problem: 1978. 上级经理已离职的公司员工 Code -- 方法 1 -- select e1.employee_id -- from employees e1 -- left join employees e2 -- on e1.manager_id e2.employee_id -- where e1.salary < 30000 -- and e1.manager_id is not null -- and e2.employee_id is…

SpringBoot 整合 Redis 实现验证码登录功能

一、整合Redis 在pom.xml中添加Redis相关依赖&#xff1b; <!--Spring Data Redis依赖配置--> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId> </dependency>…

103.qt qml-最全Table新增下拉复制功能

在上篇文章102.qt qml-最全Table交互之多列固定、行列拖拽、自定义委托、标题交互使用教程_qt 表格控件 拖动列-CSDN博客 我们实现了大部分功能,所以本章实现下拉复制功能。 demo截图如下所示: 支持跨界复制,如果下拉的位置大于Table则会动画向下移动,具体可以参考视频链接…

颠覆未来计算!CRAM技术摒弃冯·诺依曼模型,20年研究终迎突破

未来科技&#xff1a;AI计算需求激增&#xff0c;数据中心耗电量堪比派对狂饮&#xff01;明尼苏达大学研究团队或携革命性设备&#xff0c;以惊人能效解决AI能耗难题&#xff01; 研究人员设计了一种新型的"计算随机存取存储器"&#xff08;CRAM&#xff09;原型芯…

查看路由表 netstat -r

“Kernel IP routing table” 是Linux系统中用于展示和配置IP路由的表。它告诉操作系统如何将数据包从一个网络接口发送到另一个网络或主机。下面是对您给出的路由表条目的解释&#xff1a; Destination&#xff1a;目的地地址&#xff0c;可以是具体的IP地址&#xff0c;也可…

Codeforces 962 div3 A-F

A 题目分析 签到 C代码 #include<iostream> using namespace std; int main(){int t;cin>>t;while(t--){int n;cin>>n;cout<<n/4n%4/2<<endl;} } B 题目分析 将n*n的方格分成若干个k*k的方格&#xff0c;每个k*k的方格中所有的数都相同 遍历…

小主机SSD固态硬盘选购攻略,希捷酷鱼 530 SSD固态硬盘表现优秀【附系统无损迁移教程】

小主机SSD固态硬盘选购攻略&#xff0c;希捷酷鱼 530 SSD固态硬盘表现优秀【附系统无损迁移教程】 哈喽小伙伴们好&#xff0c;我是Stark-C~ 这几年随着以零刻为首的小主机市场的兴起&#xff0c;小主机相关的配置周边需求也是越来越大&#xff0c;就比如说SSD固态硬盘就是其…

爬虫程序在采集亚马逊站点数据时如何绕过验证码限制?

引言 在电商数据分析中&#xff0c;爬虫技术的应用日益广泛。通过爬虫技术&#xff0c;我们可以高效地获取大量的电商平台数据&#xff0c;这些数据对于市场分析、竞争情报、价格监控等有着极其重要的意义。亚马逊作为全球最大的电商平台之一&#xff0c;是数据采集的重要目标…

Nacos-微服务注册中⼼(Nacos简介 Nacos配置管理)

目录 一、 微服务的注册中⼼ 1. 注册中⼼的主要作⽤ 2. 常⻅的注册中⼼ 二、Nacos简介 nacos实战⼊⻔ 1. 搭建nacos环境 2.将订单微服务注册到nacos 2.1 在pom.xml中添加nacos的依赖 2.2 在主类上添加EnableDiscoveryClient注解 2.3 在application.yml中添加nacos服…

如何在Linux上构建Raspberry Pi虚拟环境

目录 前置环境需求 Older Version 新版本启动 下面我们来讲讲如何使用QEMU来仿照树莓派环境。这里首先先分成两大类。第一类是跑比较老的&#xff0c;安全性较低的老树莓派&#xff0c;主要指代的是22年4月份发布之前的版本&#xff0c;这个版本当中&#xff0c;树莓派镜像自…

Layui表格合并、表格折叠树

1、核心代码&#xff1a; let tableMerge layui.tableMerge; // 引入合并的插件&#xff0c;插件源文件在最后let tableData [{pid: 0,cid: 111,sortNum: 1, // 序号pName: 数据父元素1,name: 数据1,val: 20,open: true, // 子树是否展开hasChild: true, // 有子数据opt: 数据…

昇思25天学习打卡营第1天 | 快速入门教程

昇思大模型平台&#xff0c;就像是AI学习者和开发者的超级基地&#xff0c;这里不仅提供丰富的项目、模型和大模型体验&#xff0c;还有一大堆经典数据集任你挑。 AI学习有时候就像找不到高质量数据集的捉迷藏游戏&#xff0c;而且本地跑大数据集训练模型简直是个折磨&#xf…

react css module 不生效问题记录

背景&#xff1a;自己使用webpackreactcssless配置的项目框架&#xff0c;在使用过程中发现css module引入不生效。 import React from react import styles from ./index.module.less console.log(styles)//输出 undefinedwebpack配置了css-loader,less-loader,webpack默认cs…

Linux系统之dns服务配置

要求&#xff1a;DNS服务器域解析 www. 11zzj.com为192.168.11.1; ftp.11zzj.com 为192.168.11.2; mail.11zzj.com 为172.16.11.20; 1.打开Linux6&#xff08;服务器&#xff09;和Linux5&#xff08;客户端&#xff09; 配置IP地址和DNS 地址&#xff0c;并ping通。…

PSINS工具箱函数介绍——kfinit

kfinit是kf的参数初始化函数&#xff0c;用于初始化滤波参数 本文所述的代码需要基于PSINS工具箱&#xff0c;工具箱的讲解&#xff1a; PSINS初学指导基于PSINS的相关程序设计&#xff08;付费专题&#xff09; 使用方法 kfinit这个函数的字面意思是&#xff1a;kf的初始化…

游戏制作中没想明白的事情

当一个备忘录&#xff0c;有的是还没有时间去深入研究&#xff0c;或者没有从头了解 什么是建模绑定&#xff1f;为什么人物建模&#xff0c;初始化都是双手打开的&#xff1f;平着放武器&#xff0c;但运行的时候武器会自动竖起来&#xff0c;这是怎么做到的&#xff1f; 思…