自监督行为识别-时空线索解耦(论文复现)

news2025/1/13 13:29:45

自监督行为识别-时空线索解耦(论文复现)

本文所涉及所有资源均在传知代码平台可获取

文章目录

    • 自监督行为识别-时空线索解耦(论文复现)
      • 引言
      • 论文概述
      • 核心创新点
      • 双向解耦编码器
      • 跨域对比损失的构建
      • 结构化数据增强
      • 项目部署
        • 准备工作
        • 数据准备
      • 生成数据
        • 训练&测试
          • 训练
          • 测试
      • bug修改

引言

自监督骨架行为识别是一种利用未标记的骨架数据进行行为识别的方法。传统的行为识别方法通常需要大量标记好的数据进行训练,但标记数据的获取成本高昂。自监督学习通过设计自动生成标签的任务,可以在缺乏标记数据的情况下进行训练

在这里插入图片描述

在自监督骨架行为识别中,骨架数据可以通过传感器或深度摄像头等设备获取。这些数据包含了人体关节的位置和运动信息。自监督学习任务的关键是设计一种能够从未标记的骨架数据中自动生成标签的方法。

在训练过程中,使用未标记的骨架数据进行自监督学习,生成伪标签。然后,将生成的伪标签用于监督骨架行为识别模型的训练。通过这种方式,自监督学习可以在缺乏标记数据的情况下,提供一种有效的方法进行骨架行为识别。

那么目前自监督骨架行为还面临哪些挑战呢?

  • 挑战1. 时空信息的混淆

编码器负责将输入映射到可以进行对比的潜在空间。而之前的大多数方法专注于通过常用的时空建模网络获得统一的信息。他们的设计导致了时间、空间信息的纠缠,无法为随后的对比措施提供明确的指示。

  • 挑战2.数据增强的局限性

此外,现有技术往往局限于规模转换(常见的增强策略,比如裁剪、旋转),这导致无法充分利用数据增强的潜力。

  • 挑战3. 未考虑方法的可迁移性

优化过程中,大多数方法都专注于在相同的表示水平上构建对比对;忽略域之间的差距(同一任务下或数据集中)

论文概述

SCD-NET(SCD-Net: Spatio temporal Clues Disentanglement Network for
Self-Supervised Skeleton-Based Action Recognition AAAI2024)引入了一种新的对比学习框架,即时空线索解耦网络(SCD-Net)。
  具体来说,将解耦模块与特征提取器相结合,分别从空间和时间域获得明确的线索。对于SCD-Net的训练,构建了一个全局锚点,鼓励锚点与提取的线索相互作用。此外,本文提出了一种具有结构约束的新的掩码策略,以加强上下文关联,利用掩码图像建模到所提出的SCD-Net。
  从实验结果来看,在NTU-RGB+D(60&120)和PKUMMD (I&II)数据集进行了广泛的评估,涵盖了各种下游任务,如动作识别、动作检索、迁移学习和半监督学习。实验结果证明了该方法的有效性,显著优于现有的最先进(SOTA)方法

核心创新点

为了解决自监督在面临的三个挑战,该文分别提出三种方法分别应对。首先在时空信息混淆的问题上,作者提出双向接口编码器;数据增强方面,分别在时间、空间上分设置不同的数据增强策略;方法的可迁移性方面设置了跨越对比损失,详细架构可见下文。
  SCD-NET整体架构如下所示:骨架数据->数据增强(data augmentation)后,分别送入编码器层(encoder)以及动量编码器层(Momentum encoder).每个编码器都使用了双向解耦编码器,在经过特征抽取器(feature extractor)后,分别对空间解耦(spatial decoupling)、时间解耦(temporal decoupling)操作,获取不同维度的特征。动量编码器得到的输出作为键向量,正常编码器得到的输出作为查询向量,最后将键向量、查询向量进行对比学习

在这里插入图片描述

双向解耦编码器

一般来说,从骨架序列中提取的特征被描述为描述动作的复杂时空关联。然而,本文认为这种范式并不适用于对比学习。由于信息的纠缠性很大,很难为后续的比较提供明确的指导。在SCD-Net中,本文提倡一种双路解耦编码器,从复杂的序列信息中分别提取出时间、空间信息以获得更好的判别性表示。
  双向解耦编码器构造如下图:分为建模(projection)和细化(refinement)阶段,空间部分对CT维度进行合并,保留V(代表骨骼关节)维度,而后进行嵌入操作得到骨架图->序列化–>transformer 编码器->空间池化->空间特征;时间部分对CC维度进行合并,保留T(代表视频帧)维度,而后进行嵌入操作得到关节序列->序列化–>transformer 编码器->时间池化->时间特征

在这里插入图片描述

# 双向解耦编码器
        vt = self.gcn_t(x)
        vt = rearrange(vt, '(B M) C T V -> B T (M V C)', M=2)
        vt = self.channel_t(vt)
        vs = self.gcn_s(x)
        vs = rearrange(vs, '(B M) C T V -> B (M V) (T C)', M=2)
        vs = self.channel_s(vs)
        vt = self.t_encoder(vt) # B T C
        vs = self.s_encoder(vs)
        # implementation using amax for the TMP runs faster than using MaxPool1D
        # not support pytorch < 1.7.0
        vt = vt.amax(dim=1)
        vs = vs.amax(dim=1)
        return vt, vs

跨域对比损失的构建

在这里插入图片描述

 # 正负样本以及损失函数的设计
    def forward(self, q_input, k_input):
        三种查询向量的定义
        qt, qs, qi = self.encoder_q(q_input)  # queries: NxC

        qt = nn.functional.normalize(qt, dim=1)
        qs = nn.functional.normalize(qs, dim=1)
        qi = nn.functional.normalize(qi, dim=1)

        # 计算key特征
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            kt, ks, ki = self.encoder_k(k_input)  # keys: NxC

            kt = nn.functional.normalize(kt, dim=1)
            ks = nn.functional.normalize(ks, dim=1)
            ki = nn.functional.normalize(ki, dim=1)
        # 正负样本
        l_pos_ti = torch.einsum('nc,nc->n', [qt, ki]).unsqueeze(1)
        l_pos_si = torch.einsum('nc,nc->n', [qs, ki]).unsqueeze(1)
        l_pos_it = torch.einsum('nc,nc->n', [qi, kt]).unsqueeze(1)
        l_pos_is = torch.einsum('nc,nc->n', [qi, ks]).unsqueeze(1)

        l_neg_ti = torch.einsum('nc,ck->nk', [qt, self.i_queue.clone().detach()])
        l_neg_si = torch.einsum('nc,ck->nk', [qs, self.i_queue.clone().detach()])
        l_neg_it = torch.einsum('nc,ck->nk', [qi, self.t_queue.clone().detach()])
        l_neg_is = torch.einsum('nc,ck->nk', [qi, self.s_queue.clone().detach()])
        # 损失函数
        logits_ti = torch.cat([l_pos_ti, l_neg_ti], dim=1)
        logits_si = torch.cat([l_pos_si, l_neg_si], dim=1)
        logits_it = torch.cat([l_pos_it, l_neg_it], dim=1)
        logits_is = torch.cat([l_pos_is, l_neg_is], dim=1)

        logits_ti /= self.T
        logits_si /= self.T
        logits_it /= self.T
        logits_is /= self.T

结构化数据增强

本位在空间、时间部分分别提出了不同的增强策略,空间部分提出结构引导的空间掩码,时间部分提出基于管道的时间掩码

  • 结构引导的空间掩码

考虑到骨架的物理结构,当选择某个关节进行掩码时,模型可能通过周围的点学习到相关信息,掩码效果不佳。通过施加结构约束,本文的方法在当前随机选择的关节或框架周围的局部区域内应用掩码操作,而不是仅依赖于孤立的点本文同时对其相邻区域的进行掩码。让本文用矩阵p来表示邻接关系,如果关节i和j连通,则Pij = 1,否则Pij= 0。令D = Pn。为了施加结构约束,当节点i被选中时,本文对Dij != 0的所有节点j执行相同的增强操作。

  • 基于管道的时间掩码

基于管道的时间掩码的核心思想是通过将时间序列数据分为多个管道,并为每个管道生成对应的时间掩码,来提取关键的行为特征。具体而言,时间掩码是一种二进制序列,用于指示时间序列中的重要时间段。通过对时间序列数据进行分割,并根据具体的行为任务和特征需求,选择性地将时间掩码应用于每个管道。结构图图下:
  这种方法的优势在于,它可以将注意力集中在对行为识别最有用的时间段上,从而提高模型对关键动作的感知能力。时间掩码的生成可以根据不同的策略进行,如基于阈值、基于能量或基于模式识别等方法。生成的时间掩码可以作为输入数据的权重,用于调整模型对不同时间段的重视程度

在这里插入图片描述

项目部署

准备工作
  • Pytorch环境
  • 安装依赖包,运行以下命令即可
pip install -r requirements.txt
数据准备

生成数据

下载数据集

  • NTU-RGB+D
  • PKU-MMD

数据处理

  • 用以下代码处理数据: python ntu_gendata.py
训练&测试
训练

训练 NTU-RGB+D 60数据集 在 Cross-Subject 评价标准下的预训练模型, 运行以下命令

    python ./pretraining.py --lr 0.01 --batch-size 64 --encoder-t 0.2   --encoder-k 8192 \
                --checkpoint-path ./checkpoints/pretrain/ \
                --schedule 351  --epochs 451  --pre-dataset ntu60 \
                --protocol cross_subject --skeleton-representation joint
测试

测试论文给出模型在行为识别任务下NTU-RGB+D 60 数据集 Cross-Subject 评价标准上的结果, 运行以下命令

    python ./action_classification.py --lr 2 --batch-size 1024 \
                --pretrained ./checkpoints/pretrain/checkpoint.pth.tar \
                --finetune-dataset ntu60 --protocol cross_subject --finetune_skeleton_representation joint

测试论文给出模型在行为检索任务下NTU-RGB+D 60 数据集 Cross-Subject 评价标准上的结果, 运行以下命令

    python ./action_retrieval.py --knn-neighbours 1 \
                --pretrained ./checkpoints/pretrain/checkpoint.pth.tar \
                --finetune-dataset ntu60 --protocol cross_subject --finetune-skeleton-representation joint

bug修改

在复现原文代码的时候,直接运行运行./pretraining文件,会出现如下错误

在这里插入图片描述

  • 错误如下:

TypeError: init() got an unexpected keyword argument ‘batch_first’

  • 原因分析:

init方法没有参数‘batch_first’,所以将batch_first删去即可。在更改batch_first参数的请务必同时输入数据进行同步调整。具体来说,batch_first=True时的输入维度为 (batch, seq, feature),否则对应的输入维度需要调整为(seq, batch, feature)

  • 解决方法:

将报错代码中encoder_layer部分替换为如下代码,即可正常运行

        encoder_layer = TransformerEncoderLayer(self.d_model, num_head, self.d_model)
        self.t_encoder = TransformerEncoder(encoder_layer, num_layer)
        self.s_encoder = TransformerEncoder(encoder_layer, num_layer)

    def forward(self, x):
        
        vt = self.gcn_t(x)

        vt = rearrange(vt, '(B M) C T V -> B T (M V C)', M=2)
        vt = self.channel_t(vt)

        vs = self.gcn_s(x)
        
        vs = rearrange(vs, '(B M) C T V -> B (M V) (T C)', M=2)
        vs = self.channel_s(vs)

        # batch_first=True时的输入维度为 (batch, seq, feature),
        # 否则对应的输入维度需要调整为(seq, batch, feature)
        # 通过transpose函数调整维度
        vt = vt.transpose(0, 1)  # 调整为 (T, B, M*V*C)
        vs = vs.transpose(0, 1)  # 调整为 (T, B*M*V, C)

        vt = self.t_encoder(vt) # 
        vs = self.s_encoder(vs) #

        vt = vt.amax(dim=1)
        vs = vs.amax(dim=1)

        return vt, vs

文章代码资源点击附件获取

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2215136.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GDB基本使用指南

什么是 GDB&#xff1f; GDB&#xff08;GNU Debugger&#xff09;是一个强大的调试工具&#xff0c;主要用于调试 C、C 和其他语言编写的程序。 它让我们可以监控、控制程序的执行&#xff0c;从而查找并修复错误。 安装 GDB ubuntu上一条命令可以搞定&#xff1a; sudo …

STM32-ADC模数转换

一、概述 ADC&#xff08;Analog-Digital Converter&#xff09;模拟-数字转换器 ADC可以将引脚上连续变化的模拟电压转换为内存中存储的数字变量&#xff0c;建立模拟电路到数字电路的桥梁12位逐次逼近型ADC&#xff0c;1us转换时间输入电压范围&#xff1a;0~3.3V&#xff…

使用Mockaroo生成测试数据

使用Mockaroo生成测试数据 最近在学习【Spring Boot & React】Spring Boot和React教程视频的P51.Generating 1000 students一课中&#xff0c;看到了https://www.mockaroo.com/网站可以用来模拟生成测试数据&#xff0c;觉得还不错&#xff0c;特此记录一下。感觉每次看老…

基于SSM+微信小程序的宠物管理系统1

&#x1f449;文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1、项目介绍 基于SSM微信小程序的宠物管理系统实现了管理员、店主、用户。 管理员实现了店主管理、附件宠物店、管理员、用户管理、猫狗查询、猫狗宠物社区、商品信息等、店主实现了商品信息管理。用户…

高翔【自动驾驶与机器人中的SLAM技术】学习笔记(十一)ESKF中融合速度观测量;发散的原因;如何解决发散;以及对slam的理解

带着问题去学习: 1、slam发散的原因? 2、如何解决/限制发散? 3、如何在已经有观察值和预测值的ESKF中,再引入一个其他其他观察量? 一、多传感器融合的思考——轮速计 反思为何需要融合多个传感器? 我认为根本上的原因,是因为有些传感器在某些场景下会失灵、效果不佳…

[AWS云]kafka调用和创建

背景:因为因为公司的项目需要使用AWS的kafka&#xff0c;但是在创建和使用过程中都遇到了一些报错和麻烦&#xff0c;毕竟老外的东西&#xff0c;和阿里云、华为使用起来还是不一样。 一、创建&#xff08;创建的配置过程就略了&#xff0c;就是配置一下可用区、型号&#xff0…

1. 到底什么是架构

1. 什么是架构 定义&#xff1a;架构&#xff0c;又名软件架构&#xff0c;是有关软件整体结构与组件的抽象描述&#xff0c;用于指导大型软件系统各个方面的设计优秀架构的特点&#xff1a;优秀的性能、超强的TPS/QPS的承载能力、高可用决定了你能够支撑多少PV的流量 2. 什么…

AUTOSAR_EXP_ARAComAPI的5章笔记(12)

☞返回总目录 5.4.6 方法 骨架侧的服务方法是抽象方法&#xff0c;必须由继承骨架的服务实现子类进行重写。让我们来看一下我们服务示例中的 Adjust 方法&#xff1a; /*** 对于所有输出和非空返回参数* 生成一个包含非空返回值和/或输出参数的封装结构。*/ struct AdjustOu…

智能之眼:如何用监督学习教机器看懂世界

智能之眼&#xff1a;如何用监督学习教机器看懂世界 智能之眼&#xff1a;如何用监督学习教机器看懂世界前言什么是监督学习&#xff1f;监督学习的工作流程监督学习的类型 监督学习的常用算法1. 线性回归&#xff08;Linear Regression&#xff09;线性回归的优缺点 2. 逻辑回…

ui入门

一、QWidget类 QWidget是Qt中所有用户界面对象的基类&#xff0c;即可视化组件和窗口的基类都是此类&#xff0c;因此QWidget类内部包含了大量的与UI相关的基础特性。 最最基础的属性&#xff1a; width : const int 宽度&#xff0c;单位像素&#xff0c;不计算边框。属性在文…

房屋租赁系统(论文+源码)-kaic

摘 要 社会的发展和科学技术的进步&#xff0c;互联网技术越来越受欢迎。网络计算机的生活方式逐渐受到广大人民群众的喜爱&#xff0c;也逐渐进入了每个用户的使用。互联网具有便利性&#xff0c;速度快&#xff0c;效率高&#xff0c;成本低等优点。 因此&#xff0c;构建符…

下载Edge/Chrome浏览器主题的背景图片

当我们为Edge安装了心仪的主题后&#xff0c;希望把对应的背景图片下载保存要怎么做呢&#xff0c;以下图的“湖心小屋”主题为例。如下图&#xff0c;我们已经在应用商店中按照了该主题。 当打开新标签页后&#xff0c;可以欣赏这个主题内置的背景图片。 如果想要下载这个背景…

安装macOS Sequoia注意事项

随着macOS Sequoia的发布&#xff0c;许多Mac用户开始计划升级到这一最新版本。然而&#xff0c;升级系统并非简单点击“升级”按钮即可。在安装新系统之前&#xff0c;有一些关键的注意事项可以帮助你避免潜在的问题&#xff0c;确保顺利过渡到macOS Sequoia。本文将详细介绍在…

《深度学习》【项目】自然语言处理——情感分析 <上>

目录 一、项目介绍 1、项目任务 2、评论信息内容 3、待思考问题 1&#xff09;目标 2&#xff09;输入字词格式 3&#xff09;每一次传入的词/字的个数是否就是评论的长度 4&#xff09;一条评论如果超过32个词/字怎么处理&#xff1f; 5&#xff09;一条评论如果…

源码编译 FunASR for windows on arm

源码编译 FunASR for windows on arm 这里有编译好的&#xff0c;直接下载使用 https://github.com/turingevo/FunASR-build/releases 编译 1 下载 onnxruntime-win-arm64&#xff1a; https://github.com/microsoft/onnxruntime/releases/download/v1.16.1/onnxruntime-win…

最优化方法-Goldstein准则学习记录(matlab代码实现)

目录 一、前言 二、定义 三、代码实现 四、改良后 五、总结 一、前言 作为非精确线性搜索方法的一种&#xff0c;旨在降低计算量&#xff0c;提高算法效率。在迭代过程中没有必要把线性搜索搞得十分精确&#xff0c;因此我们可以放松对的精度要求&#xff0c;只要求每一步…

葵花卫星影像数据NC转tif

数据介绍 葵花8号卫星(Himawari-8)是日本发射的静止轨道气象卫星,由日本气象厅(JMA)运营。该卫星自2015年7月7日开始正式启用,主要用于观测东亚和西太平洋区域的天气情况。葵花8号卫星搭载了先进的光学仪器,能够提供高分辨率的气象数据。 卫星分辨率 葵花8号卫星的主要…

Python学习-注释,输入,运算符

python中的注释 单行注释以#开头多行注释 这是一段多行注释。 你可以在这里写很多行注释&#xff0c; 这些内容都不会被Python解释器执行。 中文编码注释#coding:utf-8按住ctrl\ 多行注释 输入函数 input() 输入值的类型为str 基本使用 presentinput(输入的提示) print(pre…

STL.string(中)

string 迭代器findswapsubstrrfindfind_first_of&#xff08;用的很少&#xff09;find_last_of&#xff08;用的很少&#xff09;find_first_not_of&#xff08;用的很少&#xff09; 迭代器 int main() {//正向迭代器string s1("hello world!");string::iterator i…

PCL 渐进式形态学滤波

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 如果不太了解点云数学形态学的基本理论,可以先阅读这篇文章:https://blog.csdn.net/dayuhaitang1/article/details/123172437。形态学中的窗口结构一直存在着这样的问题:如果窗口结构元尺寸过小,则无法去除一些…