temporal shift module(TSM)

news2025/1/10 11:41:54

【官方】Paddle2.1实现视频理解经典模型 — TSM - 飞桨AI Studio本项目将带大家深入理解视频理解领域经典模型TSM。从模型理论讲解入手,深入到代码实践。实践部分基于TSM模型在UCF101数据集上从训练到推理全流程实现行为识别任务。 - 飞桨AI Studiohttps://aistudio.baidu.com/aistudio/projectdetail/2310889?channelType=0&channel=0视频理解:基于TSM实现UCF101视频理解 - 飞桨AI Studio基于飞桨开源框架构建TSM,并实现对数据集UCF101的视频理解。 - 飞桨AI Studiohttps://aistudio.baidu.com/aistudio/projectdetail/4114499?channelType=0&channel=0

最近一直在做视频相关的项目,后续会陆续出一些视频理解和视频场景运动的案例,视频这块主推paddlevideo,里面应用层面的东西很丰富,paddle在应用侧一直做的比较好,模型训练这块可以结合mmaction2来,其实从实际应用角度来说,我觉得用paddle和pytorch训练都无所谓,部署的话可能以往我的经验更多是onnx,tensort服务侧的,目前来看,主要也就是服务器,端侧和页面侧的部署这三块,我看paddle分别有paddle inference、lite、js,国产框架中确实是首屈一指的,但是我自己的感觉是从我以前训练gan的结果看,paddle貌似要比pytorch的结果,一样的数据,一样的参数配置,好像要差一点。本文主要介绍一下tsm模块,利用2dcnn来模拟时序信息。视频中核心是视频动作识别,本质就是视频分类,可以用作特征提取,视频时序提取是输入一段长视频获取其中的时序片段,时空定位是同时获取视频中的人物物体的空间位置,核心三大任务,除此之外视频特征提取embedding,这块主要是结合多模态去做,视频,音频和文本侧特征的综合利用和提取。

1.时序信息维度 

上述这个视频序列从左向右播放和从右向左播放表达的意思是不同的,视频理解对视频顺序是强依赖的。

2.temporal shift module

这个模块是核心,其实tsm是可插拔模块,是可以很好的嵌入到resnet等模型中,上述图中,一种颜色是一帧,按照时序T上,一共是四帧,同一帧横向是一个channel,在cnn中channel是统一做cnn的,在a图中是没有shift的,在b中是离线shift操作,可见将channel中第一个向下移动,第二个向上移动,其实至于上下移动几个channel并没有很严的的限制,通常是分成几等分去移动,这样上下移动之后,则第一个channel会向下突出一帧,第二个channel会向上突出一帧,突出帧直接截断,空缺帧直接补0,这样在横向做cnn时,统一channel维度变引入不同色的帧,tsm正是通过这种平移的方式,TSM在特征图中引入 temporal 维度上的上下文交互,通过通道移动操作可以使得在当前帧中包含了前后两帧的通道信息,这样再进2D卷积操作就能像3D卷积一样直接提取视频的时空信息,提高了模型在时间维度上的建模能力。而online模式用于对视频类型的实时预测,在这种情况下,无法预知下一秒的图像,因此只能将channel维度由过去向现在移动,而不能从未来向现在移动。

3.缺点和改进

虽然时间位移的原理很简单,但作者发现直接将空间位移策略应用于时间维度并不能提供高性能和效率。具体来说,如果简单的转移所有通道,则会带来两个问题:

  1. 由于大量数据移动而导致的效率下降问题。位移操作不需要计算但是会涉及数据移动,数据移动增加了硬件上的内存占用和推理延迟,作者观察到在视频理解网络中,当使用naive shift策略时,CPU延迟增加13.7%,GPU延迟增加12.4%,使整体推理变慢。
  2. 空间建模能力变差导致性能下降,由于部分通道被转移到相邻帧,当前帧不能再访问通道中包含的信息,这可能会损失2D CNN主干的空间建模能力。与TSN基线相比,使用naive shift会降低2.6%的准确率。

为了解决naive shift的两个问题,TSM给出了相应的解决方法。

  1. 减少数据移动。 为了研究数据移动的影响,作者测量了TSM模型在不同硬件设备上的推理延迟,作者移动了不同比例的通道数并测量了延迟,位移方式分为无位移、部分位移(位移1/8、1/4、1/2的通道)和全部位移,使用ResNet-50主干和8帧输入测量模型。作者观察到,如果移动所有的通道,那么延迟开销将占CPU推理时间的13.7%,如果只移动一小部分通道,如1/8,则可将开销限制在3%左右。       
  2. 保持空间特征学习能力。 一种简单的TSM使用方法是将其直接插入到每个卷基层或残差模块前,如 所示,这种方法被称为 in-place shift,但是它会损失主干模型的空间特征学习能力,尤其当我们移动大量通道时,存储在通道中的当前帧信息会随着通道移动而丢失。为解决这个问题,作者提出了另一种方法,即将TSM放在残差模块的残差分支中,这种方法被称为 residual TSM,如所示,它可以解决退化的空间特征学习问题,因为原始的激活信息在时间转移后仍可通过identity映射访问。

 4.mmaction2中的代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import NonLocal3d
from torch.nn.modules.utils import _ntuple

from ..builder import BACKBONES
from .resnet import ResNet


class NL3DWrapper(nn.Module):
    """3D Non-local wrapper for ResNet50.

    Wrap ResNet layers with 3D NonLocal modules.

    Args:
        block (nn.Module): Residual blocks to be built.
        num_segments (int): Number of frame segments.
        non_local_cfg (dict): Config for non-local layers. Default: ``dict()``.
    """

    def __init__(self, block, num_segments, non_local_cfg=dict()):
        super(NL3DWrapper, self).__init__()
        self.block = block
        self.non_local_cfg = non_local_cfg
        self.non_local_block = NonLocal3d(self.block.conv3.norm.num_features,
                                          **self.non_local_cfg)
        self.num_segments = num_segments

    def forward(self, x):
        x = self.block(x)

        n, c, h, w = x.size()
        x = x.view(n // self.num_segments, self.num_segments, c, h,
                   w).transpose(1, 2).contiguous()
        x = self.non_local_block(x)
        x = x.transpose(1, 2).contiguous().view(n, c, h, w)
        return x


class TemporalShift(nn.Module):
    """Temporal shift module.

    This module is proposed in
    `TSM: Temporal Shift Module for Efficient Video Understanding
    <https://arxiv.org/abs/1811.08383>`_

    Args:
        net (nn.module): Module to make temporal shift.
        num_segments (int): Number of frame segments. Default: 3.
        shift_div (int): Number of divisions for shift. Default: 8.
    """

    def __init__(self, net, num_segments=3, shift_div=8):
        super().__init__()
        self.net = net
        self.num_segments = num_segments
        self.shift_div = shift_div

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        x = self.shift(x, self.num_segments, shift_div=self.shift_div)
        return self.net(x)

    @staticmethod
    def shift(x, num_segments, shift_div=3):
        """Perform temporal shift operation on the feature.

        Args:
            x (torch.Tensor): The input feature to be shifted.
            num_segments (int): Number of frame segments.
            shift_div (int): Number of divisions for shift. Default: 3.

        Returns:
            torch.Tensor: The shifted feature.
        """
        # 假设当前feature map的通道是256,shift_div=3,
        # 那么就有256/3的特征进行shift left,256/3的特征进行shift right,其他一部分特征不动
        # num_segments每个视频采样的帧数
        # 每帧有c个通道,
        # [
        # [0_1,0_2,0_3,1_1,1_2,3_5,3_6,3_7]  第一帧,8个通道,但是shift_div表示这个通道维度被切分成3个等分
        # []  第二帧
        # []  第三帧
        # ]
        # [N, C, H, W]
        n, c, h, w = x.size()

        # [N // num_segments, num_segments, C, H*W]
        # can't use 5 dimensional array on PPL2D backend for caffe
        x = x.view(-1, num_segments, c, h * w)

        # get shift fold
        fold = c // shift_div

        # split c channel into three parts:
        # left_split, mid_split, right_split
        left_split = x[:, :, :fold, :]
        mid_split = x[:, :, fold:2 * fold, :]
        right_split = x[:, :, 2 * fold:, :]

        # can't use torch.zeros(*A.shape) or torch.zeros_like(A)
        # because array on caffe inference must be got by computing

        # shift left on num_segments channel in `left_split`
        zeros = left_split - left_split
        blank = zeros[:, :1, :, :]
        left_split = left_split[:, 1:, :, :]
        left_split = torch.cat((left_split, blank), 1)

        # shift right on num_segments channel in `mid_split`
        zeros = mid_split - mid_split
        blank = zeros[:, :1, :, :]
        mid_split = mid_split[:, :-1, :, :]
        mid_split = torch.cat((blank, mid_split), 1)

        # right_split: no shift

        # concatenate
        out = torch.cat((left_split, mid_split, right_split), 2)

        # [N, C, H, W]
        # restore the original dimension
        return out.view(n, c, h, w)


@BACKBONES.register_module()
class ResNetTSM(ResNet):
    """ResNet backbone for TSM.

    Args:
        num_segments (int): Number of frame segments. Default: 8.
        is_shift (bool): Whether to make temporal shift in reset layers.
            Default: True.
        non_local (Sequence[int]): Determine whether to apply non-local module
            in the corresponding block of each stages. Default: (0, 0, 0, 0).
        non_local_cfg (dict): Config for non-local module. Default: ``dict()``.
        shift_div (int): Number of div for shift. Default: 8.
        shift_place (str): Places in resnet layers for shift, which is chosen
            from ['block', 'blockres'].
            If set to 'block', it will apply temporal shift to all child blocks
            in each resnet layer.
            If set to 'blockres', it will apply temporal shift to each `conv1`
            layer of all child blocks in each resnet layer.
            Default: 'blockres'.
        temporal_pool (bool): Whether to add temporal pooling. Default: False.
        **kwargs (keyword arguments, optional): Arguments for ResNet.
    """

    def __init__(self,
                 depth,
                 num_segments=8,
                 is_shift=True,
                 non_local=(0, 0, 0, 0),
                 non_local_cfg=dict(),
                 shift_div=8,
                 shift_place='blockres',
                 temporal_pool=False,
                 **kwargs):
        super().__init__(depth, **kwargs)
        self.num_segments = num_segments
        self.is_shift = is_shift
        self.shift_div = shift_div
        self.shift_place = shift_place
        self.temporal_pool = temporal_pool
        self.non_local = non_local
        self.non_local_stages = _ntuple(self.num_stages)(non_local)
        self.non_local_cfg = non_local_cfg

    def make_temporal_shift(self):
        """Make temporal shift for some layers."""
        if self.temporal_pool:
            num_segment_list = [
                self.num_segments, self.num_segments // 2,
                                   self.num_segments // 2, self.num_segments // 2
            ]
        else:
            num_segment_list = [self.num_segments] * 4
        if num_segment_list[-1] <= 0:
            raise ValueError('num_segment_list[-1] must be positive')

        if self.shift_place == 'block':

            def make_block_temporal(stage, num_segments):
                """Make temporal shift on some blocks.

                Args:
                    stage (nn.Module): Model layers to be shifted.
                    num_segments (int): Number of frame segments.

                Returns:
                    nn.Module: The shifted blocks.
                """
                blocks = list(stage.children())
                for i, b in enumerate(blocks):
                    blocks[i] = TemporalShift(
                        b, num_segments=num_segments, shift_div=self.shift_div)
                return nn.Sequential(*blocks)

            self.layer1 = make_block_temporal(self.layer1, num_segment_list[0])
            self.layer2 = make_block_temporal(self.layer2, num_segment_list[1])
            self.layer3 = make_block_temporal(self.layer3, num_segment_list[2])
            self.layer4 = make_block_temporal(self.layer4, num_segment_list[3])

        elif 'blockres' in self.shift_place:
            n_round = 1
            if len(list(self.layer3.children())) >= 23:
                n_round = 2

            def make_block_temporal(stage, num_segments):
                """Make temporal shift on some blocks.

                Args:
                    stage (nn.Module): Model layers to be shifted.
                    num_segments (int): Number of frame segments.

                Returns:
                    nn.Module: The shifted blocks.
                """
                blocks = list(stage.children())
                for i, b in enumerate(blocks):
                    if i % n_round == 0:
                        blocks[i].conv1.conv = TemporalShift(
                            b.conv1.conv,
                            num_segments=num_segments,
                            shift_div=self.shift_div)
                return nn.Sequential(*blocks)

            self.layer1 = make_block_temporal(self.layer1, num_segment_list[0])
            self.layer2 = make_block_temporal(self.layer2, num_segment_list[1])
            self.layer3 = make_block_temporal(self.layer3, num_segment_list[2])
            self.layer4 = make_block_temporal(self.layer4, num_segment_list[3])

        else:
            raise NotImplementedError

    def make_temporal_pool(self):
        """Make temporal pooling between layer1 and layer2, using a 3D max
        pooling layer."""

        class TemporalPool(nn.Module):
            """Temporal pool module.

            Wrap layer2 in ResNet50 with a 3D max pooling layer.

            Args:
                net (nn.Module): Module to make temporal pool.
                num_segments (int): Number of frame segments.
            """

            def __init__(self, net, num_segments):
                super().__init__()
                self.net = net
                self.num_segments = num_segments
                self.max_pool3d = nn.MaxPool3d(
                    kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))

            def forward(self, x):
                # [N, C, H, W]
                n, c, h, w = x.size()
                # [N // num_segments, C, num_segments, H, W]
                x = x.view(n // self.num_segments, self.num_segments, c, h,
                           w).transpose(1, 2)
                # [N // num_segmnets, C, num_segments // 2, H, W]
                x = self.max_pool3d(x)
                # [N // 2, C, H, W]
                x = x.transpose(1, 2).contiguous().view(n // 2, c, h, w)
                return self.net(x)

        self.layer2 = TemporalPool(self.layer2, self.num_segments)

    def make_non_local(self):
        # This part is for ResNet50
        for i in range(self.num_stages):
            non_local_stage = self.non_local_stages[i]
            if sum(non_local_stage) == 0:
                continue

            layer_name = f'layer{i + 1}'
            res_layer = getattr(self, layer_name)

            for idx, non_local in enumerate(non_local_stage):
                if non_local:
                    res_layer[idx] = NL3DWrapper(res_layer[idx],
                                                 self.num_segments,
                                                 self.non_local_cfg)

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        super().init_weights()
        if self.is_shift:
            self.make_temporal_shift()
        if len(self.non_local_cfg) != 0:
            self.make_non_local()
        if self.temporal_pool:
            self.make_temporal_pool()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/25283.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2. Object中equals和toStirng 源码分析

文章目录1.equals方法2.重写equals方法为何一定要重写hashCode方法&#xff1f;2.1 反例演示3.toString方法4. 整型转二进制我们都知道Object是所有类的父类&#xff0c;那么它里面的一些方法你是否真的理解了呢&#xff1f; 下面我们就以源码为基础来学习这些看似简单的方法吧…

谷歌浏览器无法使用翻译功能的解决方案,谷歌浏览器无法翻译怎么办?谷歌浏览器右键翻译失效了?

如果你发现网站别的方案无效&#xff0c;请参考我的方案&#xff0c; 绝对有效&#xff01; 2022年起&#xff0c;突然发现谷歌浏览器的翻译功能无法使用了&#xff0c;既然发现问题&#xff0c;就要解决问题&#xff0c;按照下面的步骤一步一步来操作 首先下载最新版谷歌浏览…

[附源码]java毕业设计校园出入管理系统

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

【Lua基础 第4章】Lua的流程控制、#的作用、table的创建方式、table表常用方法、函数、多返回值、可变长参数

文章目录&#x1f4a8;更多相关知识&#x1f447;一、Lua 的流程控制&#x1f538;if语句&#x1f31f;代码演示&#x1f538;if...else 语句&#x1f31f;代码演示&#x1f538;if...elseif...else 语句二、#的作用三、table的创建方式四、table表的常用方法使用&#x1f539;…

标记肽Suc-AAPI-pNA、72682-77-0

标记肽Suc-AAPI-对硝基苯胺编号: 184433 中文名称: 标记肽Suc-AAPI-对硝基苯胺 英文名: Suc-Ala-Ala-Pro-Ile-pNA CAS号: 72682-77-0 单字母: Suc-AAPI-pNA 三字母: Suc-Ala-Ala-Pro-Ile-pNA 氨基酸个数: 4 分子式: C27H38O9N6 平均分子量: 590.63 精确分子量: 590.27 等电点(P…

使用MobaXterm tunneling访问集群(服务器)jupyter notebook

应用场景 想要在本地计算机C上使用高性能服务器上的计算节点运行jupyter notebook相关的代码。 高性能服务器上通常只有一个公网ip用于账户登陆管理&#xff0c;但有多个计算节点&#xff0c;需要使用公网IP通过SSH方式登入管理节点A&#xff0c;并使用SSH二次登陆计算节点B&…

CKKS同态加密方案初步学习

如论文标题所示&#xff0c;CKKS允许复数和实数运算&#xff0c;是一个近似精度计算的方案&#xff0c;也就是解密出来的明文和加密之前的明文不会完全一致。也就是采用丢失部分精度来换取较高的效率。 CKKS的核心是把加密噪声视为近似计算误差的一部分&#xff0c;也就是解密出…

Python项目一:pygname

1.安装pip install pygame 2.加载模块初始化&#xff1a;开始 import sys import pygamepygame.init() #初始化3.创建窗口 3.1pygame .display模块 作用&#xff1a;创建游戏窗口 常见的内置方法&#xff1a; 方法作用 pygame。display.init() 初始化display模块p…

C++11标准模板(STL)- 算法(std::partial_sort)

定义于头文件 <algorithm> 算法库提供大量用途的函数&#xff08;例如查找、排序、计数、操作&#xff09;&#xff0c;它们在元素范围上操作。注意范围定义为 [first, last) &#xff0c;其中 last 指代要查询或修改的最后元素的后一个元素。 排序一个范围的前 N 个元素…

阿里巴巴最新总结「百亿级别并发设计手册」GitHub收获70K标星

随着淘宝购物节和抖音直播平台带货的火热&#xff0c;大批促销活动涌现&#xff0c;「秒杀」这个词也越来越频繁地出现在我们的生活里。 除了那些头部的电商公司&#xff0c;某多、某东&#xff0c;还有各种街、某会、某品等&#xff0c;甚至是一些老牌的传统企业&#xff0c;…

Android持久化技术,好内存不如烂存储

Android持久化技术&#xff0c;好内存不如烂存储前言六、Android持久化技术&#xff0c;好内存不如烂存储6.1 持久化技术介绍6.2 简单文件存储方案6.3 SharedPreferences存储方案6.3.1 获取SharedPreferences对象的三种方式6.3.2 使用SharedPreferences对象存储和读取数据6.4 S…

Model Fusion of Heterogeneous Neural Networks via Cross-Layer Alignment论文阅读

论文地址点这里 一. 介绍 本文是针对异构的网络融合技术&#xff0c;是基于上一篇OTFusion的论文进行的工作&#xff0c;解决了神经元关联问题。当所有的网络都具有相同的架构时&#xff0c;OTFusion比普通平均算法有明显的改进。与其他基于平均的模型融合方法相比&#xff0…

如何制作一个实时在线显示评论

通过循环容器及数据表功能&#xff0c;制作一个发送评论实时显示的功能 效果展示 具体步骤 制作评论背景 制作评论样式 制作一个发送评论输入框 制作一个发送按钮 创建评论数据表 添加获取评论事件 创建发送评论触发器 数据绑定与设置 步骤分解 制作评论背景 拖拽 循环容器 到…

Go基础学习【2】

文章目录一&#xff1a;数组二&#xff1a;map集合三&#xff1a;包四&#xff1a;结构体一&#xff1a;数组 1.命名 var arrAge [5]int{1,2,3,4,5} var arrAge […]int{1,2,4,5,6} var arrAge [5]string{3:“sfd”,5:“asdf”} 2.传递 通过传递数组的指针 和 使用数组的切片…

[go学习笔记.第十六章.TCP编程] 2.项目-海量用户即时通讯系统

一.项目介绍 1.项目开发流程 需求分析->设计阶段->编码实现->测试阶段->实施阶段 2.需求分析 (1).用户注册 (2).用户登录 (3).显示在线用户列表 (4).群聊(广播) (5).点对点聊天 (6).离线留言 3.示意图 4.项目开发前技术准备 项目要保存用户信息和消息数据,因此需…

【Vue】vue项目用qrcodejs2生成带log的二维码图片,vue生成二维码图片中间带log,自定义log

系列文章目录 提示&#xff1a;这里可以添加系列文章的所有文章的目录&#xff0c;目录需要自己手动添加 提示&#xff1a;写完文章后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录系列文章目录前言一、安装qrcodejs2二、在页面中使用1.引入…

英国Paper写作思路和精髓如何了解?

第一学期即将结束&#xff0c;为了能帮助更多英国留学生顺利完成Paper&#xff0c;增加对英国Paper写作的理解&#xff0c;取得高分。本文小编为大家分享英国Paper写作的思路和精髓&#xff0c;帮助自己修改提升Paper质量。 The first semester is coming to an end.In order t…

flutter AnimatedSwitcher 动画切换过渡组件 跑马灯动画封装

flutter AnimatedSwitcher 动画切换过渡组件前言一、AnimatedSwitcher 简介二、AnimatedSwitcher 的简单使用三、AnimatedSwitcher 自定义跑马灯动画四、SlideTransitionX 的封装总结前言 本篇文章将记录 AnimatedSwitcher 过渡组件&#xff0c;这个组件动画是一个新的小部件来…

制作一个简单HTML宠物猫网页(HTML+CSS)

&#x1f389;精彩专栏推荐 &#x1f4ad;文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业&#xff1a; 【&#x1f4da;毕设项目精品实战案例 (10…

在 Spring Boot中配置日志

Spring Boot 在引擎盖下使用Apache Commons Logging。但是&#xff0c;它允许您选择所需的日志记录库。让我们来看看使用 Spring Boot 时的一些配置和最佳实践。 目录 概述简单日志记录示例配置日志记录 更改日志级别将日志写入文件在 Spring 引导中更改日志记录模式对日志条…