【Code Reading】Transformer in vision and video

news2024/10/6 14:29:51

文章目录

  • 1. vit
  • 2. Swin-t
  • 3. vit_3D
  • 4. TimeSformer First🚀🚀
  • 5. vivit

1. vit

详细解释

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在论文的Table1中有给出三个模型(Base/ Large/ Huge)的参数,在源码中除了有Patch Size为16x16的外还有32x32的。其中的Layers就是Transformer Encoder中重复堆叠Encoder Block的次数Hidden Size就是对应通过Embedding层后每个token的dim(向量的长度),MLP size是Transformer Encoder中MLP Block第一个全连接的节点个数(是Hidden Size的四倍),Heads代表Transformer中Multi-Head Attention的heads数。

2. Swin-t

在这里插入图片描述

在这里插入图片描述

3. vit_3D

您将需要传递两个额外的超参数:
(1) 帧数frames 和(2) 沿帧维度的patch大小frame_patch_size

class ViT3D(nn.Module):
    def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size) # 128 128 
        patch_height, patch_width = pair(image_patch_size) # 16 16

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size' # 16 2

        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size) # 每一个frame块中有多少块patch
        patch_dim = channels * patch_height * patch_width * frame_patch_size # 3*16*16*2=1536

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim), # -->1024 token dim(Hidden Size)
            nn.LayerNorm(dim),
        )
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

只有在num_patches和patch_embedding部分,多了一些关于frame的操作。将视频输入数据经过patch_embedding转化为token后,就没什么区别了。(包括后面加【cls】和position embedding)

4. TimeSformer First🚀🚀

第一篇使用纯Transformer结构在视频识别上的文章。

初探Video Transformer(一):抛弃CNN的纯Transformer视频理解框架—TimeSformer

在这里插入图片描述

在这里插入图片描述
5种方法的可视化:
在这里插入图片描述

self-attention和rnn计算复杂度的对比

自注意力的计算复杂度和 输入token的数量平方 成正比。
在这里插入图片描述

5. vivit

初探Video Transformer(二):谷歌开源更全面、高效的无卷积视频分类模型ViViT

为了让模型表现力更强,ViViT讨论了两方面的设计和思考(TimeSformer只考虑了模型结构设计):
Embedding video clips 和 Transformer Models for Video

在这里插入图片描述
TimeSformer:通过reshape,可以分别对2D和时间维度分别进行embedding。

  • Embedding video clips:

在这里插入图片描述
如何将输入的Video数据转化为tokens。(token_dim=N*d)
在这里插入图片描述
在这里插入图片描述


在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
与"Uniform frame sampling"相比,这种方法融合了时空信息。


  • Transformer Models for Video

在这里插入图片描述
这种实现方法可以理解为vit-3D。和TimeSformer的joint attention相似。
在这里插入图片描述

  • Model 2: Factorised encoder

在这里插入图片描述
在这里插入图片描述

  • Model 3: Factorised self-attention
    在这里插入图片描述
    在这里插入图片描述

在这里插入图片描述

  • Model 4: Factorised dot-product attention
    在这里插入图片描述

在这里插入图片描述
在这里插入图片描述


在这里插入图片描述


  • Model-1
    model1就是标准的VIT结构,除了patchembeeding以外没有任何的改变,直接看vit-3d代码就可以了。

  • Model-2
    第一个Transformer:先对同一帧下的token进行交互interaction,产生每个时间索引下的latent representation。
    第二个Transformer:对time stepinx交互。相当于时间空间信息后融合。

self.to_patch_embedding = nn.Sequential(
      Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
      nn.LayerNorm(patch_dim),
      nn.Linear(patch_dim, dim),
      nn.LayerNorm(dim)
  )

相比于VIT 3D的b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c),这里把时间维度单独抽出来。

在forward中:相比于之前的patch_embedding、cls_tokens 、pos_embedding。

    def forward(self, video):
        x = self.to_patch_embedding(video)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

model2中:
如果Transformer输入有【cls】,则分类。若无,则全局平均池化输出。

def forward(self, video):
    x = self.to_patch_embedding(video) # b f (h w) (p1 p2 pf c)
    b, f, n, _ = x.shape # n为一帧内有多少个patches,_为token—dim

    x = x + self.pos_embedding[:, :f, :n] # 先pos_embedding

    if exists(self.spatial_cls_token):
        spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
        x = torch.cat((spatial_cls_tokens, x), dim = 2) # 后对spatial_cls_tokens 进行添加

    x = self.dropout(x)

    x = rearrange(x, 'b f n d -> (b f) n d') # 融合时间维度

    # attend across space

    x = self.spatial_transformer(x)  # 进行空间的变换 输入输出维度不变(b f) n d

    x = rearrange(x, '(b f) n d -> b f n d', b = b)

    # excise out the spatial cls tokens or average pool for temporal attention

    x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')# [4, 8, 1024] 提取第一个cls,当作当前帧所有token的全局表示

    # append temporal CLS tokens

    if exists(self.temporal_cls_token):
        temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)

        x = torch.cat((temporal_cls_tokens, x), dim = 1)

    # attend across time

    x = self.temporal_transformer(x)

    # excise out temporal cls token or average pool

    x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean') # [4, 1024] 提取第一个cls,当作跨帧所有token的全局表示

    x = self.to_latent(x)
    return self.mlp_head(x)

  • Model-3
    model3的实现和TimeSformer的实现是一样的,去掉cls-token即可,可以参考TimeSformer的文章。

  • Model-4
    model4的实现与model1不同之处在于,transformer是有两个不同维度的attention 来进行计算的。
    代码以后填吧。


效果对比:

在这里插入图片描述

在这里插入图片描述
比较模型性能,这里Model2的temporal-transformer设置4层。model1的性能最好,但是FLOPs最大,运行时间最长。Model4没有额外的参数量,计算量比model1少很多,但是性能不高。Model3相比与其他的变体,需要更高的计算量和参数量。Model2表现最佳,精度尚可,计算量和运行时比较低。最后一行是单独做的实验,去掉了Model2的temporal transformer,直接在帧上做了pooling,EK上的精度下降很多,对于时序强的数据集需要用temporal transformer来做时序信息交互。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1164640.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于MIMO通信系统的球形译码算法matlab性能仿真,对比PSK检测,SDR检测

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 MATLAB2022A 3.部分核心程序 ................................................................ for i1:length(SNR) Bit…

SpringCloud(六) Nacos配置管理

Nacos除了可以做注册中心,同样可以做配置管理来使用; Nacos做注册中心的使用和注意事项详看:SpringCloud(四) Nacos注册中心-CSDN博客 目录 一, 统一配置管理 1. 在Nacos中添加配置文件 2. 从微服务拉取配置 (1) 引入nacos-config依赖 (2) 添加bootstrap.yaml (3) 读取n…

相机存储卡被格式化了怎么恢复?数据恢复办法分享!

随着时代的发展,相机被越来越多的用户所使用,这也意味着更多的用户面临着相机数据丢失的问题,很多用户在使用相机的过程中,都出现过不小心格式化相机存储卡的情况,里面的数据也将一并消失,相机存储卡被格式…

Android13充电动画实现

充电动画 以MTK平台为例,实现充电动画 效果图 修改文件清单 system/vendor/mediatek/proprietary/packages/apps/SystemUI/src/com/android/systemui/charging/BubbleBean.javasystem/vendor/mediatek/proprietary/packages/apps/SystemUI/src/com/android/system…

2023年【道路运输企业安全生产管理人员】考试资料及道路运输企业安全生产管理人员考试技巧

题库来源:安全生产模拟考试一点通公众号小程序 道路运输企业安全生产管理人员考试资料根据新道路运输企业安全生产管理人员考试大纲要求,安全生产模拟考试一点通将道路运输企业安全生产管理人员模拟考试试题进行汇编,组成一套道路运输企业安…

浅谈新能源汽车充电桩的选型与安装

叶根胜 安科瑞电气股份有限公司 上海嘉定201801 摘要:电动汽车的大力发展和推广是国家为应对日益突出的燃油供需矛盾和环境污染,加强生态环境保护和治理而开发新能源和清洁能源的措施之一,加快了电动汽车的发展。如今,电动汽车已…

Pure-Pursuit 跟踪双移线 Gazebo 仿真

Pure-Pursuit 跟踪双移线 Gazebo 仿真 主要参考学习下面的博客和开源项目 自动驾驶规划控制(A*、pure pursuit、LQR算法,使用c在ubuntu和ros环境下实现) https://github.com/NeXTzhao/planning Pure-Pursuit 的理论基础见今年六月…

如何在 Endless OS 上安装 ONLYOFFICE 桌面编辑器 7.5

ONLYOFFICE 桌面编辑器是一款基于依据 AGPL v.3 许可进行分发的开源办公套件。使用这款应用,您无需保持网络连接状态即可处理存储在计算机上的文档。本指南会向您介绍,如何在 Endless OS 上安装 ONLYOFFICE 桌面编辑器 7.5。 ONLYOFFICE 桌面版是什么 O…

Ansible中的任务执行控制

循环 简单循环 {{item}} 迭代变量名称 loop: - value1 - value2 - ... //赋值列表{{item}} //迭代变量名称循环散列或字典列表 - name: create filehosts: host1tasks:- name: file moudleservice:name: "{{ item.name }}"state: "{{…

实验记录之——git push

平时做开发的时候经常push代码不成功,如下图 经好友传授经验,有如下方法 Win cmd使用Clash(端口是7890)代理操作,在cmd中输入: set http_proxy127.0.0.1:7890 set https_proxy127.0.0.1:7890Linux export …

防火墙日志记录和分析

防火墙监控进出网络的流量,并保护部署防火墙的网络免受恶意流量的侵害。它是一个网络安全系统,它根据一些预定义的规则监控传入和传出的流量,它以日志的形式记录有关如何管理流量的信息,日志数据包含流量的源和目标 IP 地址、端口…

可视化流程编排(Bpmn.js)介绍及实践

作者:罗强 为将内部系统打通并规范流程定义,基于统一的平台实现工单自动化流转,从而使用无界山工单系统帮助公司内部人员统一管理和处理来自企业内部提交的工单需求。而在系统中流程编排及节点设计主要是使用bpmn.js实现精细化配置。从而满足各种复杂的业务需求。目…

MapStruct的用法

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、依赖导入二、简单使用2.1、定义转换的接口2.2、创建实体类2.3、测试2.4、底层实现 三、常用的注解3.1、Mapping(target "xxx1",source "x…

你还在用System.currentTimeMillis()打印代码执行时间?

文章目录 前言一、开发环境二、使用步骤1. 创建Springboot项目2. 引入hutool3. 使用TimeInterval 总结 前言 Hutool是一个小而全的Java工具类库,里面集成了很多实用的工具类,比如文件、流、加密解密、转码、正则、线程、XML等,通过这些工具类…

探索 Java 8 中的 Stream 流:构建流的多种方式

人嘛,要懂得避嫌… 开篇引入 Java 8引入了Stream流作为一项新的特性,它是用来处理集合数据的一种函数式编程方式。Stream流提供了一种更简洁、高效和易于理解的方法来操作集合数据,同时也能够实现并行处理,以提高性能。 以下是St…

在 GORM 中定义模型

为实现与数据库的无缝交互而打造有效模型的全面指南 在使用 GORM 进行数据库管理时,定义模型是基础。模型是您的应用程序的面向对象结构与数据库的关系世界之间的桥梁。本文深入探讨了在 GORM 中打造有效模型的艺术,探讨如何创建结构化的 Go 结构体&…

第十二章,集合类例题

例题1 package 例题;import java.util.*;public class 例题 {public static void main(String[] args) {// TODO Auto-generated method stub//实例化集合类对象Collection<String> list new ArrayList<>();//调用方法&#xff0c;向集合添加数据list.add("…

Java——java.time包使用方法详解

Java——time包使用方法详解 java.time 包是 Java 8 引入的新日期和时间 API&#xff08;JSR 310&#xff09;&#xff0c;用于替代旧的 java.util.Date 和 java.util.Calendar 类。它提供了一组全新的类来处理日期、时间、时间间隔、时区等&#xff0c;具有更好的设计和易用性…

什么是消息队列

什么是消息队列 消息队列是一种通信机制&#xff0c;用于在不同的应用程序或组件之间传递消息。它允许应用程序之间异步地发送和接收消息&#xff0c;而无需直接依赖彼此的可用性或性能。消息队列通常用于解耦不同组件&#xff0c;提高系统的可伸缩性和可维护性&#xff0c;以…

【Python入门一】Python及PyCharm安装教程

Python及PyCharm安装教程 1 Python简介1.1 Python下载及安装 2 PyCharm简介2.1 PyCharm下载及安装 参考 1 Python简介 Python是一种开源的高级编程语言&#xff0c;由Guido van Rossum于1991年创建。Python易于学习、阅读和编写&#xff0c;具有丰富的标准库和第三方模块&…