Time-distributed 的理解

news2024/9/23 11:28:32

前言

今天看到论文中用到 Time-distributed CNN,第一次见到 Time-distributed,不理解是什么含义,看到代码实现也很懵。不管什么网络结构,外面都能套一个TimeDistributed。看了几个博客,还是不明白,问了问ChatGPT,终于明白了一点。

解释

直观理解

以下是ChatGPT的解释

“Time-distributed” 是一种用于深度学习处理序列数据的技术,它将神经网络中的层或网络独立地应用于序列的每个时间步长。在典型的前馈神经网络中,输入数据会被馈送到网络中,并且相同的权重会被应用于所有的输入特征。但是,当处理序列数据,如时间序列或自然语言时,我们需要在每个时间步长上应用相同的权重来捕捉时间信息。

“Time-distributed” 技术允许我们独立地应用每个时间步长上的层或网络。这可以通过在 Keras 或 PyTorch 中将层或网络包装在 “TimeDistributed” 层中来实现。

例如,如果我们有一个形状为 (batch_size, timesteps, input_dim) 的三维张量作为输入,应用一个具有 10 个单位的 “TimeDistributed” 密集层将产生一个形状为 (batch_size, timesteps, 10) 的三维张量作为输出。这个包装器可以用于任何模块,例如卷积层、循环神经网络层、全连接层等。 “Time-distributed” 层将相同的密集层应用于每个时间步长,从而使网络能够学习数据中的时间模式。

“Time-distributed” 层通常用于序列到序列模型中,如语言翻译或语音识别,其中输入和输出都是序列。

代码实现角度理解

考虑这样一个问题,将原来代码中的 TimeDistributed 去掉会发生什么?

全连接层

对于全连接层,如果没有 TimeDistributed,代码照样能跑。

import torch
import torch.nn as nn

input = torch.randn(5, 3, 10)  # 时间步数是5,batch_size是3,每个时间步的特征维度是10
model = nn.Linear(10, 5)
output = model(input)
print(output.shape)

输出:torch.Size([5, 3, 5])

如果将输入改为 input = torch.randn(5, 3, 2, 2, 10)
输出 torch.Size([5, 3, 2, 2, 5])

可以看到,不管输入有多少维度,都能正常输出。

在这里插入图片描述
从官方文档也可以看到,输入 * 可以是任意维度。

卷积层

对于全连接层,如果没有 TimeDistributed,代码就会报错。

import torch
import torch.nn as nn

input = torch.randn(5, 3, 3, 256, 256)  # 时间步数是5,batch_size是3,通道数是3,图片高宽都是256
model = nn.Conv2d(3, 16, kernel_size=3)  # 输入通道是3,输出通道是16,kernel_size=3
output = model(input)
print(output.shape)

报错信息

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [5, 3, 3, 256, 256]

可以看到维度不匹配。如果把时间维度去掉,则可以正常输出。

import torch
import torch.nn as nn

input = torch.randn(3, 3, 256, 256)  # batch_size是3,通道数是3,图片高宽都是256
model = nn.Conv2d(3, 16, kernel_size=3)  # 输入通道是3,输出通道是16,kernel_size=3
output = model(input)
print(output.shape)

输出:torch.Size([3, 16, 254, 254])

因此如果我想用带时间步数的图片做卷积,那就无法实现了,如何解决这个问题呢?就要用到 Time-distributed。

增加 TimeDistributed 的代码

import torch
import torch.nn as nn

input = torch.randn(5, 3, 3, 256, 256)  # 时间步数是5,batch_size是3,通道数是3,图片高宽都是256
model = TimeDistributed(nn.Conv2d(3, 16, kernel_size=3))  # 输入通道是3,输出通道是16,kernel_size=3
output = model(input)
print(output.shape)

输出:torch.Size([5, 3, 16, 254, 254])

具体做法就是让每个时间步的数据也做卷积。

TimeDistributed 代码实现

在 keras 中,自带 TimeDistributed。在 pytorch 中没有实现,需要自己手动实现。

以下是 pytorch 版本的 TimeDistributed 代码实现。

class TimeDistributed(nn.Module):
    def __init__(self, module, batch_first=False):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def __multi_time(self, size):
        size_temp = list(size)
        size_temp = [size_temp[0] * size_temp[1]] + size_temp[2:]
        return tuple(size_temp)

    def __dist_time(self, size, batch, time_dim):
        size_temp = list(size)
        size_temp = [batch, time_dim] + size_temp[1:]
        return tuple(size_temp)

    def forward(self, x):
        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().view(self.__multi_time(x.size()))  # (samples * timesteps, input_size)

        y = self.module(x_reshape)

        y = y.contiguous().view(self.__dist_time(y.size(), x.size(0), x.size(1)))  # (samples, timesteps, output_size)

        return y

chatGPT 对上面代码的解释

这段代码实现了一个 PyTorch 的自定义模块 TimeDistributed,可以将一个其它模块(module)作为输入,将其应用于输入张量(x)的每个时间步上,输出一个相同形状的张量。

具体来说,该模块的 forward 方法接收一个形状为 (samples, timesteps, input_size) 的输入张量 x,其中 samples 表示样本数量,timesteps 表示时间步数,input_size 表示每个时间步的特征维度。然后,该方法将输入张量进行形状变换,将 samples 和 timesteps 两个维度合并为一个维度,变成形状为 (samples * timesteps, input_size) 的张量,即将样本和时间步的信息压缩在了一个维度中。接下来,该方法将这个压缩后的张量输入到 module 模块中,并获取输出张量 y。最后,该方法再将输出张量进行形状变换,将第一个维度还原成 samples,将第二个维度还原成 timesteps,即将样本和时间步的信息重新拆分开,最终输出形状为 (samples, timesteps, output_size) 的张量,其中 output_size 表示每个时间步的输出特征维度,应该与 module 模块的输出维度相同。

此外,该模块还包含两个辅助方法 __multi_time 和 __dist_time,分别用于实现输入和输出张量形状的转换。其中,__multi_time 方法将形状为 (samples, timesteps, input_size) 的张量转换为形状为 (samples * timesteps, input_size) 的张量,即将样本和时间步的信息压缩在一个维度中;__dist_time 方法则将形状为 (samples * timesteps, output_size) 的张量转换为形状为 (samples, timesteps, output_size) 的张量,即将样本和时间步的信息重新拆分开。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/357060.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python数据挖掘基础

一、Matplotlib 画二维图表的python库,实现数据可视化 , 帮助理解数据,方便选择更合适的分析方法1、折线图1.1引入matplotlibimport matplotlib.pyplot as plt %matplotlib inlineplt.figure() plt.plot([1, 0, 9], [4, 5, 6]) plt.show()1.2…

知识探索项目测试报告

⭐️前言⭐️ 本篇文章是博主基于知识探索项目所做的测试报告,主要涉及到的测试知识有设计测试用例、自动化测试等测试知识。 🍉欢迎点赞 👍 收藏 ⭐留言评论 📝私信必回哟😁 🍉博主将持续更新学习记录收获…

基于springboot+vue的药物咨询平台

基于springbootvue的药物咨询平台 ✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取项目下载方式🍅 一、项目背景介绍&…

二阶段提交事务的实现和缺点

背景 说起分布式事务,我们最绕不开的一个话题就是该不该使用分布式事务,而要理解为什么做出使用与否的决定,就必须要提到分布式事务中的最经典的实现:两阶段提交事务,本文我们就简答介绍下这个两阶段提交事务以及它的优缺点 技术…

【Opencv 系列】 第6章 人脸检测(Haar/dlib) 关键点检测

本章内容 1.人脸检测,分别用Haar 和 dlib 目标:确定图片中人脸的位置,并画出矩形框 Haar Cascade 哈尔级联 核心原理 (1)使用Haar-like特征做检测 (2)Integral Image : 积分图加速特征计算 …

SpringSecurity的权限校验详解说明(附完整代码)

说明 SpringSecurity的权限校是基于SpringSecurity的安全认证的详解说明(附完整代码) (https://blog.csdn.net/qq_51076413/article/details/129102660)的讲解,如果不了解SpringSecurity是怎么认证,请先看下【SpringSecurity的安…

【1792. 最大平均通过率】

来源:力扣(LeetCode) 描述: 一所学校里有一些班级,每个班级里有一些学生,现在每个班都会进行一场期末考试。给你一个二维数组 classes ,其中 classes[i] [passi, totali] ,表示你…

0xL4ugh 2023

这回跟着个队伍跑,不过还是2X以后的成绩,前边太卷了。自己会的部分,有些是别人已经提交了的。记录一下。Cryptocrypto 1给了一些数据,像这样就没有别的了ct [0, 1, 1, 2, 5, 10, 20, 40, 79, 159, 317, 635, 1269, 2538, 5077, 1…

2023.02.19 学习周报

文章目录摘要文献阅读1.题目2.摘要3.介绍4.本文贡献5.方法5.1 Local Representation Learning5.2 Global Representation Learning5.3 Item Similarity Gating6.实验6.1 数据集6.2 结果7.结论深度学习1.对偶问题1.1 拉格朗日乘数法1.2 强对偶性2.SVM优化3.软间隔3.1 解决问题3.…

尚医通 (十八)微信登录

目录一、生成微信登录二维码1、准备工作2、后端开发service_user3、前端显示登录二维码4、二维码出现不了进行调试二、开发微信扫描回调1、准备工作2、后台开发3、前台开发三、分析代码四、bug一、生成微信登录二维码 1、准备工作 1、注册 2、邮箱激活 3、完善开发者资料 4、…

JSP中http与内置对象学习笔记

本博文讲述jsp客户端与服务器端的http、jsp内置对象与控制流和数据流实现 1.HTTP请求响应机制 HTTP协议是TCP/IP协议中的一个应用层协议,用于定义客户端与服务器之间交换数据的过程 1.1 HTTP请求 HTTP请求由请求行、消息报头、空行和请求数据4部分组成。 请求行…

ThreeJS 之界面控制

文章目录参考描述界面自适应问题resize 事件修改画布大小修改视锥体的宽高比全屏显示dblclick 事件检测全屏显示状态进入全屏显示状态退出全屏显示状态尾声参考 项目描述ThreeJS官方文档哔哩哔哩老陈打码搜索引擎BingMDN 文档document.mozFullScreenElementMDN 文档Element.re…

LeetCode题目笔记——6359. 替换一个数字后的最大差值

文章目录题目描述题目链接题目难度——简单方法一:替换代码/Python代码优化总结题目描述 给你一个整数 num 。你知道 Danny Mittal 会偷偷将 0 到 9 中的一个数字 替换 成另一个数字。 请你返回将 num 中 恰好一个 数字进行替换后,得到的最大值和最小值…

CTK学习:(一)编译CTK

CTK插件框架简介 CTK Plugin Framework是用于C++的动态组件系统,以OSGi规范为模型。在此框架下,应用程序由不同的组件组成,遵循面向服务的方法。 ctk是一个开源项目,Github 地址:https://github.com/commontk。 源码地址commontk/CTK: A set of common support code for…

信小程序点击按钮绘制定制转发分享图

1. 说明 先上代码片断分享链接: https://developers.weixin.qq.com/s/vl3ws9mA72GG 使用 painter 画图 按钮传递定制化信息 效果如下: 2. 关键代码说明 文件列表如下: {"usingComponents": {"painter": "/com…

基于springboot的停车场管理系统(程序+文档)

大家好✌!我是CZ淡陌。将再这里为大家分享优质的实战项目,本人在Java毕业设计领域有多年的经验,陆续会更新更多优质的Java实战项目,希望你能有所收获,少走一些弯路。 🍅更多优质项目👇&#x1f…

Android实例仿真之二

目录 三 从无入手 第一阶段 第二阶段 第三阶段 第四阶段 第五阶段 第六阶段 第七阶段 八 举两个典型例子: 九 逆向工程 三 从无入手 这节标题叫从无入手,什么意思呢?如果没有Android这个实例存在,你要做一个类似Android…

Mysql数据库事务

数据库事务 数据库事务由一组sql语句组成。 所有sql语句执行成功则事务整体成功;任一条sql语句失败则事务整体失败,数据恢复到事务之前的状态。 Mysql 事务操作 开始事务 start transaction;- 或 begin;事务开始后,对数据的增删改操作不…

MySQL最佳实践

一、MySQL查询执行过程 1.MySQL分层结构 MySQL8.0没有查询缓存的功能了,如果频繁修改缓存,将会损耗性能查询流程就按照分层结构就可以清楚,只要了解各个组件的各自功能就行分析器主要分析语法和词法是否正确优化器主要优化SQL语句 二、MySQL更新执行过程 更新主要涉及两个重…

SpringCloud - Ribbon负载均衡

目录 负载均衡流程 负载均衡策略 Ribbon加载策略 负载均衡流程 Ribbon将http://userservice/user/1请求拦截下来,帮忙找到真实地址http://localhost:8081LoadBalancerInterceptor类对RestTemplate的请求进行拦截,然后从Eureka根据服务id获取服务列表&…