TadTR(TIP 2022)视频动作检测方法详解

news2024/10/11 5:01:41

前言

论文:End-to-end Temporal Action Detection with Transformer
代码:TadTR

  从论文题目可以看出 TadTR 是基于 Transformer 的端到端的方法,TAD 在视频动作分类任务上更进一步,不仅对动作分类,还要检测动作发生的时间段。本文使用的代码为 OpenTAD,其中包含了多种 TAD 方法,在正式解析网络细节之前,看一下总框架。

原代码仓库框架图(与论文原图略有不同)

在这里插入图片描述
  图像序列会经过 CNN 提取特征,并用 Transformer Encoder 编码,Encoder 的输出经过 Decoder 和 FFN 得到 Segment 分割时间段和 Class Prob 动作分类得分。时间段和 Encoder 特征再经过 Actionness Regression 计算时间段是否存在动作。

  整体的思路与目标检测类似,Segment 对应检测框 box,Class Prob 对应目标分类 class,Actionness 对应目标框是否存在目标 conf。

Backbone

在这里插入图片描述
  TadTR 使用 SlowFast 作为 Backbone。SlowFast 本身做视频动作分类,相关工作请看 视频理解调研笔记 | 2021年前视频动作分类发展脉络(类似使用图像分类的网络作为目标检测的 Backbone 来抽特征)。下面对输入数据、模型结构和参数的细节进行说明。

输入

  • Input fast 是从原始视频中每隔 3 帧取 1 帧共 256 帧图像,通过间隔较短的大量帧作为运动动态信息;
  • Input slowInput fast 每隔 8 帧取 1 帧共 32 帧,通过间隔较长的少量帧作为静态图像信息;
  • 4 - BatchSize,3 - 图像 RGB 三通道,32/256 - 图像帧数, 96 × 96 96\times96 96×96 - 图像分辨率(保比缩放+中心裁剪)。
  • 对于一个较长的视频来说会取出多段的 256 帧最终基本覆盖整个视频,具体的取帧逻辑就略过了,直接以某个视频的划分结果来感受一下,这里一个视频划分出了 3 个 Batch 的数据,对应原始视频帧的索引分别为 0~765、192~957、246~1011。

结构和参数

  • SlowFast 包含 2 个 3D-ResNet 对应了快慢分支;
    • 慢分支的图像特征会多次融合快分支的运动特征,具体方式是用 3D 卷积对运动特征在时序上做下采样,然后和图像特征做 Concat;
    • 空间维度上,模型下采样 32 倍后通过一个平均池化变为 1 × 1 1\times1 1×1
    • 时间维度上全程保持不变,末尾处以线性插值的方式将两个分支的维度对齐为 128;
    • 最后两个分支的特征 Concat 到一起作为 Backbone 的输出。
  • Input fast 后的第一个 CBR 为例对结构图进行说明,CBR 代表 Conv + BatchNorm + RuLU 的模块,上方的两组参数分别为时序和空间上的 size + stride + pad,如果只有一组参数代表在时序维度的参数为 1,1,0
  • res_layer 上方参数代表 Bottleneck 的数量,Bottleneck 的结构以 slow_res_layer1 中的第一个为例进行说明,具体如下图。每个 res_layer 的第一个 Bottleneck 的残差结构都需要一个 downsample 卷积对特征的维度或空间分辨率做调整,空间参数为 1,1,01,2,0,时间参数为 1,1,0;后续的 Bottleneck 因为输入输出形状相同,残差结构去掉 downsample 卷积直接和输入相加。

在这里插入图片描述

  • Bottleneck 每个卷积核的参数可以看下表,空间上都保持一致,3 个卷积核的参数分别为 1,1,03,1,11,1,0;时间上,conv1 中除了慢分支的 layer1layer2 的为 1,1,0 其余都是 3,1,1conv2conv3 全是 1,1,0

在这里插入图片描述

Projection

在这里插入图片描述

  Projection 的输入为 Backbone 的输出,用一维卷积做特征降维,torch.nn.GroupNorm 按特征维度 32 一组划分为 8 组做 Normalization。这里的 Mask 是在对原始视频取帧时得到的,大致意味着是否是真实的帧(例如视频帧不够了需要 padding 到 256 帧那么对应 Mask 的值为 False),本文默认 Mask 全为 True 进行说明。

Transformer

1. Encoder

  Encoder 包含 TDA 和 FFN,先看 FFN 的计算流程。
FFN

  难点在于 TDA,出自 Deformable-DETR,原本用于做目标检测,受可变形卷积 DCN 启发对 Self-Attention 做改进,核心在于不再计算所有特征之间的注意力权重,而是选取几个位置的特征,且位置由网络学习得到。

在这里插入图片描述

输入

(1)query & value

  对 Projection 的输出做维度调整,即 query = value = Projection_output.permute(0, 2, 1)

(2)enPos

  图中位置编码 enPos 按下方代码由两个部分组成。其中 position 和原版 Transformer 类似,由三角函数生成;level 是可学习参数,对应特征的层级(此处只有 1 个 level),在原版 Deformable-DETR 会取出不同层级的特征图,position 体现的是空间上的位置差异,level 则体现层级的差异。

self.level_embeds = nn.Parameter(torch.Tensor(self.encoder.num_feature_levels, self.encoder.embed_dim))
pos_embed = self.position_embedding(masks) + self.level_embeds[0].view(1, 1, -1)

(3)enPoints

  其数值为 0.5 ~ 127.5 除以 128 做归一化,对应时序的位置坐标。

计算流程

B = 1 = BatchSize
N = 128
H = 8 = Head
L = 1 = Level
P = 4 = Point
D = 32 = Dimension

(1)attention

  与原始 Transformer 不同,注意力权重直接通过一个全连接层加 Softmax 得到。

(2)offset

  用一个全连接层得到坐标偏移量,与原坐标 enPoints 相加得到坐标 Locations(范围大致0~1), × 2 − 1 \times2-1 ×21 得到 Grids (范围大致-1~1),Stack 的数值全为 -1,具体意义在 value 分支一同说明。

(3)value

sampling_value_l_ = F.grid_sample(
            value_l_.unsqueeze(-1),
            sampling_grid_l_,
            mode="bilinear",
            padding_mode="zeros",
            align_corners=False,
        )

  直接看 grid_sample 操作,原图为 32 × 32 × 128 × 1 32\times32\times128\times1 32×32×128×1,第一个 32 = B × H 32=B\times H 32=B×H,第二个 32 32 32 为特征维度, 128 128 128 对应了时序的长度也可以看作词向量的个数,在这里可以把 128 × 1 128\times1 128×1 看作 h × w h\times w h×w 的图像分辨率,grid 中最后的维度 2 = ( x , y ) 2=(x,y) 2=(x,y) 坐标,其中 x = − 1 x=-1 x=1 y y y 在 offset 分支计算得到。输出可以看作 128 个时间点,每个时间点关注 4 个特定时间点的特征,特征的维度是 32。
  grid_sample 后的 Stack 操作是用于合并每个 level,后续就是常规的加权求和。这里返回看 Softmax 的维度是 L × P L\times P L×P,与原版 Transformer 对比,原版输入是每个单词的特征,经过 self-attention 输出每个单词新的特征,而新的特征是对所有单词特征加权求和所得;在 TDA 中输入是 128 个时间点的特征,输出的每个特征由 4 个时间点的特征加权求和所得。

2. Decoder

在这里插入图片描述

  输入 query 先经过经典多头注意力 nn.MultiheadAttention,其输出和 Encoder 的输出分别作为 TDA 的输入 queryvalue,得到输出 Output

  Output 通过 bbox embeddePoints 一同更新 dePointsbbox embed 具体细节如下图;最后一个 Decoder 的输出 dePoints 代表检测动作时间段 box,Class 代表动作分类。
在这里插入图片描述
  最后,box 与 Encoder 输出特征做 RoIAlign 与 FFN 得到每个片段是否存在动作的置信度,对应总框架图右上角部分。
  box 中的 2 个数值分别代表时间的中心点和时间长度,RoI 的 3 个数值分别代表 batch 索引、时间的起始和结束点(尺度为128)。

在这里插入图片描述

输入

  Decoder 的输入由 nn.Embedding 生成,query 对应起始的输入,dePosdePoints 会代替 Encoder 中 TDA 的 enPosenPoints(注 enPoints 会不断更新,而 enPos 保持不变)。40 代表 proposals 个数,意味着这段视频内最多检测出 40 个动作。

self.query_embedding = nn.Embedding(two_stage_num_proposals, self.encoder.embed_dim * 2)

在这里插入图片描述

Postprocess

  后处理根据分类得分和动作置信度综合排序得到结果,最终输出时间段(秒)、动作类别、置信度。每个 Batch 输出 200 个结果,此处 3 个 Batch 属于同一个视频,那么一个视频就得到 600 个结果。

prob = sigmoid(Class) * actionness	# [4,40,20]
score, index = topk(prob.view(bs, -1), 200, dim=1)	# [4,200]

在这里插入图片描述

Train

1. 输入

return self.losses(output, masks, gt_segments, gt_labels)

  Loss 计算使用的 output 如下所示,前三项和推理阶段相同,最后一项 aux_outputs 是 Decoder 阶段的中间输出。Decoder 包含 4 个相同的模块,前三个输出的 dePoints 和 Class 构成了 aux_outputs

在这里插入图片描述

2. 正样本选取

indices = self.matcher(outputs_without_aux, gt_segments, gt_labels)

在这里插入图片描述
  此处针对 Outputs 的前三项,即最终输出进行计算,4 个元组对应 4 个 batch,每个 batch 中第一个张量对应 proposal 的索引,第二个张量对应 gt 索引。

3. Loss 计算

  这里简单介绍每个部分 Loss 的核心计算函数,具体计算细节涉及不少代码和各种参数意义不大,如下图所示,前四项为最终输出对应的 Loss,后续为 aux 部分的 Loss,计算方式是一样的。

在这里插入图片描述

(1)class

  F.binary_cross_entropy_with_logits 计算交叉熵,正样本 → \to 1,负样本 → \to 0

(2)bbox

  F.l1_loss 计算 box 数值的 L1Loss

(3)iou

   1 − G I o U 1-\mathrm{GIoU} 1GIoU

(4)actionness

  F.l1_loss 计算 actionness 和 IoU 的 L1Loss

代码中的细枝末节

(1)dePoints 迭代与 Locations 计算

  在 Decoder 中 TDA 的 dePoints 初始形状为 4 × 40 × 1 4\times 40\times 1 4×40×1 而后续为 4 × 40 × 2 4\times 40\times 2 4×40×2,结合源码看 Locations 的计算方式。

"当 points 为 [4,40,1] 时直接与 offsets 相加"
"normalizer=128=时序长度, 对 offsets 缩放"
sampling_locations = (
	reference_points[:, :, None, :, None]
	+ sampling_offsets / offset_normalizer[None, None, None, :, None]
)

"当 points 为 [4,40,2] 时第二个值用于对 offsets 缩放, num_points=4"
sampling_locations = (
	reference_points[:, :, None, :, None, 0]
	+ sampling_offsets / self.num_points * reference_points[:, :, None, :, None, 1] * 0.5
)

(2)Mask

  在 Projection 部分提到了 Mask,实际上 Mask 在许多地方都会使用,但本文默认 Mask 全为 True 而将其忽略了。Mask 随 DataLoader 获取,记录序列帧中每一帧是否是 padding 得到的,对比图像就是记录图像的像素是否是 padding 的。代码中会根据真实帧所占的比例 ratio = sum(Mask) / len(Mask) 作为一个系数在一些地方使用。

(3)正样本选取的代码实现

  这里的代码实现比较有意思,直接看最后一行,c[i] 代表每个 Batch 的 proposals 和对应 gt 的 cost 矩阵大小为 [40,n],n 为 gt 的数量。利用匈牙利算法 scipy.optimize.linear_sum_assignment 直接得到正样本的分配,在此任务中就是为每个 gt 分配一个不重复的 proposal,最终 cost 总和最低。这里的 cost 矩阵由分类得分、检测框的绝对值与 IoU 综合构成,cost 越低代表与 gt 越匹配。

# alpha=0.25, gamma=2.0
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())	# [160,20]
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())		# [160,20]
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]				# [160,6]
cost_bbox = torch.cdist(out_bbox, tgt_segment, p=1)									# [160,6]
cost_giou = -compute_iou_torch(proposal_cw_to_se(tgt_segment), proposal_cw_to_se(out_bbox))		# [160,6]

# self.cost_bbox=5.0, self.cost_class=6.0, self.cost_giou=2.0
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou		# [160,6]
C = C.view(bs, num_queries, -1).cpu()	# [4,40,6]

indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2204124.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣21~30题

21题(简单): 分析: 按要求照做就好了,这种链表基本操作适合用c写,python用起来真的很奇怪 python代码: # Definition for singly-linked list. # class ListNode: # def __init__(self, v…

每日一题:单例模式

每日一题:单例模式 ❝ 单例模式是确保一个类只有一个实例,并提供一个全局访问点 1.饿汉式(静态常量) 特点:在类加载时就创建了实例。优点:简单易懂,线程安全。缺点:无论是否使用&…

uni-app如何搭建项目(一步一步教程)

来来来,看这里 uni-app新建项目教程uni-app项目结构 首先我们要有一个HBuilder这个软件,然后我们来搭建uni-app项目 uni-app新建项目教程 首先我们打开这个HBuilder软件,好我们就出现这个界面,我们点击新建项目   然后我们选择…

Github优质项目推荐 - 第六期

文章目录 Github优质项目推荐 - 第六期一、【WiFiAnalyzer】,3.4k stars - WiFi 网络分析工具二、【penpot】,33k stars - UI 设计与原型制作平台三、【Inpaint-Anything】,6.4k stars - 修复图像、视频和3D 场景中的任何内容四、【Malware-P…

小猿口算APP脚本(协议版)

小猿口算是一款专注于数学学习的教育应用,主要面向小学阶段的学生。它提供多种数学练习和测试,包括口算、速算、应用题等。通过智能化的题目生成和实时批改功能,帮助学生提高数学计算能力。此外,它还提供详细的学习报告和分析,帮助家长和教师了解学生的学习进度和薄弱环节…

批量处理vue2中文硬编码转i18n国际化(保姆级)

文章目录 背景技术选型使用软件与插件插件使用补充 背景 公司的项目需要适应国际化的需求,但是因为代码是一个成品的项目,也就导致,代码量巨大,连带着需要转国际化的硬编码中文也很多,如果一点点纯手工改动&#xff0…

RelationGraph实现工单进度图——js技能提升

直接上图: 从上图中可以看到整个工单的进度是从【开始】指向【PCB判责】【完善客诉】【PCBA列表】,同时【完善客诉】又可以同时指向【PCB判责】【PCBA列表】,后续各自指向自己的进度。 直接上代码: 1.安装 1.1 Npm 方式 npm …

“探索端智能,加速大模型应用” 火山引擎边缘智能x扣子技术沙龙圆满落幕!

9月21日,火山引擎边缘智能扣子技术沙龙在上海圆满落地,沙龙以“探索端智能,加速大模型应用”为主题,边缘智能、扣子、地瓜机器人以及上海交通大学等多位重磅嘉宾出席,从多维视角探讨 AI、 AIoT、端侧大模型等技术与发展…

嵌入式数据结构中线性表的具体实现

大家好,今天主要给大家分享一下,如何使用数据结构中的线性表以及具体的实现。 第一:线性表的定义和表示方法 线性表的定义 – 线性表就是零个或多个相同数据元素的有限序列。 • 线性表的表示方法 – 线性表记为: L=(a0,∙∙∙∙∙∙∙∙ai-1aiai+1 ∙∙∙∙∙∙an-1) •…

HTTP的工作原理

HTTP(Hypertext Transfer Protocol)是一种用于在计算机网络上传输超文本数据的应用层协议。它是构成万维网的基础之一,被广泛用于万维网上的数据通信。(超文本(Hypertext)是用超链接的方法,将各种不同空间的文字信息组…

数据交换的金钟罩:合理利用安全数据交换系统,确保信息安全

政府单位为了保护网络不受外部威胁和内部误操作的影响,通常会进行网络隔离,隔离成内网和外网。安全数据交换系统是专门设计用于在不同的网络环境(如内部不同网络,内部网络和外部网络)之间安全传输数据的解决方案。 使用…

Redis 其他类型 渐进式遍历

我们之前已经学过了Redis最常用的五个类型了,然而Redis还有一些在特定场景下比较好用的类型 Redis最关键的五个数据类型: 上面的类型是非常常用,很重要的类型。 除此之外的其他类型不常用,只是在特定的场景能够发挥用处&#…

澳鹏干货 | 大语言模型的上下文窗口 (Context Windows)

大语言模型(LLMs)极大地提升了人工智能在理解和生成文本方面的能力。其中一个影响其效用的重要方面是“上下文窗口”(Context Windows)—— 这个概念直接影响着模型接收和生成语言的有效性。 本期澳鹏干货将深入探讨上下文窗口对…

微软确认Word离奇Bug 命名不当会导致文件被删

微软近日确认Word应用中存在一个Bug,该漏洞可能导致用户在特定情况下错误地删除文件。该问题主要出现在文件命名过程中,如果用户在保存Word文件时采用特定的命名方式,文件可能会被移动到回收站。 根据微软支持中心的消息,如果用户…

MVS海康工业相机达不到标称最大帧率

文章目录 一、相机参数设置1、取消相机帧率限制2、修改相机图像格式3、调整相机曝光时间4、检查相机数据包大小(网口相机特有参数)5、 恢复相机默认参数6、 相机 ADC 输出位深调整 二、系统环境设置1、 网口相机设置2、 USB 相机设置 一、相机参数设置 …

java对接GPT 快速入门

统一对接GPT服务的Java说明 当前,OpenAI等GPT服务厂商主要提供HTTP接口,这使得大部分Java开发者在接入GPT时缺乏标准化的方法。 为解决这一问题,Spring团队推出了Spring AI ,它提供了统一且标准化的接口来对接不同的AI服务提供商…

毕设分享 大数据用户画像分析系统(源码分享)

文章目录 0 前言2 用户画像分析概述2.1 用户画像构建的相关技术2.2 标签体系2.3 标签优先级 3 实站 - 百货商场用户画像描述与价值分析3.1 数据格式3.2 数据预处理3.3 会员年龄构成3.4 订单占比 消费画像3.5 季度偏好画像3.6 会员用户画像与特征3.6.1 构建会员用户业务特征标签…

Linux查看下nginx及使用的配置文件

1、查到nginx进程 ps -aef | grep nginx2、通过进行pid查到nginx路径 pwdx <pid>3、根据路径得到配置文件 path***/nginx -t如下&#xff1a;

Unity网络开发基础 —— 实践小项目

概述 接Unity网络开发基础 导入基础知识中的代码 需求分析 手动写Handler类 手动书写消息池 using GamePlayer; using System; using System.Collections; using System.Collections.Generic; using UnityEngine;/// <summary> /// 消息池中 主要是用于 注册 ID和消息类…

(五)Proteus仿真STM32单片机串口数据流收发

&#xff08;五&#xff09;Protues仿真STM32单片机串口数据流收发 – ARMFUN 1&#xff0c;打开STM32CubeMX&#xff0c;找到USART1,配置模式Asynchronous&#xff0c;此时PA9、PA10自动变成串口模式 串口默认参数:115200bps 8bit None 1stop 2&#xff0c;NVIC Settings使能…