【Python/Pytorch - 网络模型】-- 手把手搭建E3D LSTM网络

news2025/4/25 10:46:42

在这里插入图片描述
文章目录

文章目录

  • 00 写在前面
  • 01 基于Pytorch版本的E3D LSTM代码
  • 02 论文下载

00 写在前面

测试代码,比较重要,它可以大概判断tensor维度在网络传播过程中,各个维度的变化情况,方便改成适合自己的数据集。

需要github上的数据集以及可运行的代码,可以私聊!

01 基于Pytorch版本的E3D LSTM代码

# 库函数调用
from functools import reduce
from src.utils import nice_print, mem_report, cpu_stats
import copy
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F

# E3DLSTM模型代码
class E3DLSTM(nn.Module):
    def __init__(self, input_shape, hidden_size, num_layers, kernel_size, tau):
        super().__init__()

        self._tau = tau
        self._cells = []

        input_shape = list(input_shape)
        for i in range(num_layers):
            cell = E3DLSTMCell(input_shape, hidden_size, kernel_size)
            # NOTE hidden state becomes input to the next cell
            input_shape[0] = hidden_size
            self._cells.append(cell)
            # Hook to register submodule
            setattr(self, "cell{}".format(i), cell)

    def forward(self, input):
        # NOTE (seq_len, batch, input_shape)

        batch_size = input.size(1)
        c_history_states = []
        h_states = []
        outputs = []

        for step, x in enumerate(input):
            for cell_idx, cell in enumerate(self._cells):
                if step == 0:
                    c_history, m, h = self._cells[cell_idx].init_hidden(
                        batch_size, self._tau, input.device
                    )
                    c_history_states.append(c_history)
                    h_states.append(h)

                # NOTE c_history and h are coming from the previous time stamp, but we iterate over cells
                c_history, m, h = cell(
                    x, c_history_states[cell_idx], m, h_states[cell_idx]
                )
                c_history_states[cell_idx] = c_history
                h_states[cell_idx] = h
                # NOTE hidden state of previous LSTM is passed as input to the next one
                x = h

            outputs.append(h)

        # NOTE Concat along the channels
        return torch.cat(outputs, dim=1)


class E3DLSTMCell(nn.Module):
    def __init__(self, input_shape, hidden_size, kernel_size):
        super().__init__()

        in_channels = input_shape[0]
        self._input_shape = input_shape
        self._hidden_size = hidden_size

        # memory gates: input, cell(input modulation), forget
        self.weight_xi = ConvDeconv3d(in_channels, hidden_size, kernel_size)
        self.weight_hi = ConvDeconv3d(hidden_size, hidden_size, kernel_size, bias=False)

        self.weight_xg = copy.deepcopy(self.weight_xi)
        self.weight_hg = copy.deepcopy(self.weight_hi)

        self.weight_xr = copy.deepcopy(self.weight_xi)
        self.weight_hr = copy.deepcopy(self.weight_hi)

        memory_shape = list(input_shape)
        memory_shape[0] = hidden_size

        # self.layer_norm = nn.LayerNorm(memory_shape)
        self.group_norm = nn.GroupNorm(1, hidden_size) # wzj

        # for spatiotemporal memory
        self.weight_xi_prime = copy.deepcopy(self.weight_xi)
        self.weight_mi_prime = copy.deepcopy(self.weight_hi)

        self.weight_xg_prime = copy.deepcopy(self.weight_xi)
        self.weight_mg_prime = copy.deepcopy(self.weight_hi)

        self.weight_xf_prime = copy.deepcopy(self.weight_xi)
        self.weight_mf_prime = copy.deepcopy(self.weight_hi)

        self.weight_xo = copy.deepcopy(self.weight_xi)
        self.weight_ho = copy.deepcopy(self.weight_hi)
        self.weight_co = copy.deepcopy(self.weight_hi)
        self.weight_mo = copy.deepcopy(self.weight_hi)

        self.weight_111 = nn.Conv3d(hidden_size + hidden_size, hidden_size, 1)

    def self_attention(self, r, c_history):
        batch_size = r.size(0)
        channels = r.size(1)
        r_flatten = r.view(batch_size, -1, channels)
        # BxtaoTHWxC
        c_history_flatten = c_history.view(batch_size, -1, channels)

        # Attention mechanism
        # BxTHWxC x BxtaoTHWxC' = B x THW x taoTHW
        scores = torch.einsum("bxc,byc->bxy", r_flatten, c_history_flatten)
        attention = F.softmax(scores, dim=2)

        return torch.einsum("bxy,byc->bxc", attention, c_history_flatten).view(*r.shape)

    def self_attention_fast(self, r, c_history):
        # Scaled Dot-Product but for tensors
        # instead of dot-product we do matrix contraction on twh dimensions
        scaling_factor = 1 / (reduce(operator.mul, r.shape[-3:], 1) ** 0.5)
        scores = torch.einsum("bctwh,lbctwh->bl", r, c_history) * scaling_factor

        attention = F.softmax(scores, dim=0)
        return torch.einsum("bl,lbctwh->bctwh", attention, c_history)

    def forward(self, x, c_history, m, h):
        # Normalized shape for LayerNorm is CxT×H×W
        normalized_shape = list(h.shape[-3:])

        def LR(input):
            # return F.layer_norm(input, normalized_shape)
            return self.group_norm(input, normalized_shape) # wzj

        # R is CxT×H×W
        r = torch.sigmoid(LR(self.weight_xr(x) + self.weight_hr(h)))
        i = torch.sigmoid(LR(self.weight_xi(x) + self.weight_hi(h)))
        g = torch.tanh(LR(self.weight_xg(x) + self.weight_hg(h)))

        recall = self.self_attention_fast(r, c_history)
        # nice_print(**locals())
        # mem_report()
        # cpu_stats()

        c = i * g + self.group_norm(c_history[-1] + recall) # wzj

        i_prime = torch.sigmoid(LR(self.weight_xi_prime(x) + self.weight_mi_prime(m)))
        g_prime = torch.tanh(LR(self.weight_xg_prime(x) + self.weight_mg_prime(m)))
        f_prime = torch.sigmoid(LR(self.weight_xf_prime(x) + self.weight_mf_prime(m)))

        m = i_prime * g_prime + f_prime * m
        o = torch.sigmoid(
            LR(
                self.weight_xo(x)
                + self.weight_ho(h)
                + self.weight_co(c)
                + self.weight_mo(m)
            )
        )
        h = o * torch.tanh(self.weight_111(torch.cat([c, m], dim=1)))

        # TODO is it correct FIFO?
        c_history = torch.cat([c_history[1:], c[None, :]], dim=0)
        # nice_print(**locals())

        return (c_history, m, h)

    def init_hidden(self, batch_size, tau, device=None):
        memory_shape = list(self._input_shape)
        memory_shape[0] = self._hidden_size
        c_history = torch.zeros(tau, batch_size, *memory_shape, device=device)
        m = torch.zeros(batch_size, *memory_shape, device=device)
        h = torch.zeros(batch_size, *memory_shape, device=device)

        return (c_history, m, h)


class ConvDeconv3d(nn.Module):
    def __init__(self, in_channels, out_channels, *vargs, **kwargs):
        super().__init__()

        self.conv3d = nn.Conv3d(in_channels, out_channels, *vargs, **kwargs)
        # self.conv_transpose3d = nn.ConvTranspose3d(out_channels, out_channels, *vargs, **kwargs)

    def forward(self, input):
        # print(self.conv3d(input).shape, input.shape)
        # return self.conv_transpose3d(self.conv3d(input))
        return F.interpolate(self.conv3d(input), size=input.shape[-3:], mode="nearest")

class Out(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride=1, padding=1)

    def forward(self, x):
        return self.conv(x)

class E3DLSTM_NET(nn.Module):
    def __init__(self, input_shape, hidden_size, num_layers, kernel_size, tau, time_steps, output_shape):
        super().__init__()

        self.input_shape = input_shape
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.kernel_size = kernel_size
        self.tau = tau
        self.time_steps = time_steps
        self.output_shape = output_shape
        self.dtype = torch.float32

        self.encoder = E3DLSTM(
            input_shape, hidden_size, num_layers, kernel_size, tau
        ).type(self.dtype)
        self.decoder = nn.Conv3d(
            hidden_size * time_steps, output_shape[0], kernel_size, padding=(0, 2, 2)
        ).type(self.dtype)
        self.out = Out(4, 1)

    def forward(self, input_seq):
        return self.out(self.decoder(self.encoder(input_seq)))

# 测试代码
if __name__ == '__main__':
    input_shape = (16, 4, 16, 16)
    output_shape = (16, 1, 16, 16)

    tau = 2
    hidden_size = 64
    kernel = (3, 5, 5)
    lstm_layers = 4
    time_steps = 29

    x = torch.ones([29, 2, 16, 4, 16, 16])
    model = E3DLSTM_NET(input_shape, hidden_size, lstm_layers, kernel, tau, time_steps, output_shape)
    print('finished!')
    f = model(x)
    print(f)

02 论文下载

Eidetic 3D LSTM: A Model for Video Prediction and Beyond
Eidetic 3D LSTM: A Model for Video Prediction and Beyond
Github链接:e3d_lstm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1823275.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue37-非单文件组件

一、组件的两种编写形式: 非单文件组件;单文件组件。 二、创建一个组件 2-1、组件中的el 组件中不写el,不说为谁服务。 2-2、组件中的data 因为对象形式,多处复用的话,有引用关系,改一处,另一…

6月14日 Qtday2

#include "widget.h" #include "ui_widget.h" #include <QTimer> using namespace std; Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget), lab1(new QLabel(this)) //初始化一个标签显示登录状态 {//设置华清远见的标签图…

基于Django、Bootstrap的电影推荐系统,算法基于用户的协同过滤算法,有爬虫有可视化后台

背景 基于Django和Bootstrap的电影推荐系统结合了用户协同过滤算法&#xff0c;通过爬虫技术获取电影数据&#xff0c;并在可视化后台展示推荐结果。该系统旨在提供个性化的电影推荐服务&#xff0c;帮助用户发现符合其喜好的电影。 用户协同过滤算法是一种常用的推荐算法&am…

JavaSE---类和对象(上)

1. 面向对象的初步认知 1.1 什么是面向对象 Java是一门纯面向对象的语言(Object Oriented Program&#xff0c;简称OOP)&#xff0c;在面向对象的世界里&#xff0c;一切皆为对象。 面向对象是解决问题的一种思想&#xff0c;主要依靠对象之间的交互完成一件事情。用面向对象…

linux的UDP广播测试:C语言代码

测试代码 #include <stdio.h> #include <stdlib.h> #include <string.h> #include <unistd.h> #include <sys/types.h> #include <sys/socket.h> #include <netinet/in.h> #include <arpa/inet.h> #include <netdb.h>#…

信息系统架构风格-系统架构师(十)

1、信息系统架构风格是描述特定应用领域中系统组织方式的惯用模式。架构风格定义了一个系统家族&#xff0c;即一个架构定义&#xff08;&#xff09;。 A一组设计原则 B一组模式 C一个词汇表和一组约束 D一组最佳实践 解析&#xff1a; 信息系统架构风格是描述某一特定 应…

014基于SSM+Jsp的网络视频播放器

开发语言&#xff1a;Java框架&#xff1a;ssm技术&#xff1a;JSPJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包…

【教学类-36-08】20240612动物面具(通义万相)-A4大小2图扇子

背景需求&#xff1a; 【教学类-36-07】20240608动物面具&#xff08;通义万相&#xff09;-A4大小7图&15手工纸1图-CSDN博客文章浏览阅读1.1k次&#xff0c;点赞45次&#xff0c;收藏27次。【教学类-36-07】20240608动物面具&#xff08;通义万相&#xff09;-A4大小7图&…

【TB作品】MSP430 G2553 单片机 口袋板 日历 时钟 闹钟 万年历 电子时钟 秒表显示

文章目录 功能介绍操作方法部分流程图代码录制了一个演示视频可以下载观看 功能介绍 时间与日期显示&#xff1a; 实时显示当前时间&#xff08;小时、分钟、秒&#xff09;和日期&#xff08;年、月、日&#xff09;。 闹钟功能&#xff1a; 设置闹钟时间&#xff08;小时、分…

全面解说Facebook代投菲律宾真金游戏pwa广告全流程

全面解说Facebook代投菲律宾真金游戏pwa广告全流程 随着数字营销的不断发展&#xff0c;社交媒体平台如Facebook已成为广告主们争相投放的热门渠道。对于希望拓展菲律宾市场的真金游戏企业来说&#xff0c;了解并掌握在Facebook上投放广告的具体流程显得尤为重要。本文将详细介…

每天五分钟深度学习:逻辑回归算法完成m个样本的梯度下降

本文重点 上节课程我们学习了单样本逻辑回归算法的梯度下降,实际使用中我们肯定是m个样本的梯度下降,那么m个样本的如何完成梯度下降呢? m个样本的损失函数定义为: 我们定义第i个样本的dw、db为: dw和db为损失J对w和b的偏导数,因为m个样本的代价函数J是1到m个样本总损失…

Office 2021 mac/win版:智慧升级,办公新风尚

Office 2021是微软公司推出的一款高效、智能且功能丰富的办公软件套件。它集成了Word、Excel、PowerPoint等多个经典应用程序&#xff0c;旨在为用户提供更出色的办公体验。 Office 2021 mac/win版获取 Office 2021在继承了前代版本优点的基础上&#xff0c;进行了大量的优化…

VirtualBox、Centos7下安装docker后pull镜像问题

Docker安装篇(CentOS7安装)_docker 安装 centos7-CSDN博客 首先&#xff0c;安装docker可以根据这篇文章进行安装&#xff0c;安装完之后&#xff0c;我们就需要去通过docker拉取相关的服务镜像&#xff0c;然后安装相应的服务容器&#xff0c;比如我们通过docker来安装mysql,…

MySQL中的正排/倒排索引和DoubleWriteBuffer

正排/倒排索引 正排索引 文档1&#xff1a;词条A&#xff0c;词条B&#xff0c;词条C 文档2&#xff1a;词条A&#xff0c;词条D 文档3&#xff1a;词条B&#xff0c;词条C&#xff0c;词条E正排表是以文档的ID为关键字&#xff0c;表中记录文档中的每个字的位置信息&#xff…

RC滤波器及截止频率推导

1、RC滤波器 低通滤波电路 高通滤波电路 电容容抗公式 宏观分析&#xff1a;不管是高通还是低通滤波电路&#xff0c;说白了就是一个电容和一个电阻构成了一个分压电路&#xff0c;区别在于电容位置不同。低通滤波电阻在前电容在后&#xff0c;信号频率越高&#xff0c;容抗越…

观测云产品更新 | BPF 网络日志、智能监控、告警策略等

观测云更新 网络 BPF 网络日志&#xff1a;优化 BPF 网络功能&#xff0c;增强 L4/L7 网络联动。 APM/RUM APM/RUM&#xff1a;新增 【Issue 自动发现】功能。启用该配置后&#xff0c;观测云会将符合配置项规则的错误数据记录自动创建 Issue。 监控 1、智能监控&#xff1…

【数据结构初阶】--- 双链表

双链表你要了解的 双链表&#xff1a; 带头双链表循环双链表带头循环双链表 今天我们学习的&#xff0c;就是最强双链表 — 带头循环双链表 下图就是带头循环双链表&#xff0c;所谓带头&#xff0c;就是有哨兵位&#xff0c;那么最左边的节点就是哨兵位&#xff0c;是不计入…

银河麒麟系统项目部署

使用服务器信息 软件&#xff1a;VMware Workstation Pro 虚拟机&#xff1a;ubtun 内存&#xff1a;20G 虚拟机连接工具&#xff1a; MobaXterm Redis连接工具&#xff1a; RedisDesktopManager 镜像&#xff1a;F:\Kylin-Server-10-8.2-Release-Build09-20211104-X86_64…

企业都在用的镭速是如何做到高效的大文件传输的

在当前这个数字化快速发展的时代&#xff0c;文件传输速度和安全性已经成为企业运营的关键。镭速&#xff0c;作为众多企业都在用的专业大文件传输工具&#xff0c;是如何实现高效、安全的文件传输呢&#xff1f;接下来&#xff0c;我们就来深入探讨一下镭速的传输策略及其在不…

超级签名源码/超级签/ios分发/签名端本地linux服务器完成签名

该系统完全在linux下运行&#xff0c;不存在使用第三方收费工具&#xff0c;市面上很多系统都是使用的是第三方收费系统&#xff0c;例如&#xff1a;某心签名工具&#xff0c;某测侠等&#xff0c;不开源而且需要每年交费&#xff0c;这种系统只是在这些工具的基础上套了一层壳…