用于视觉对象跟踪的序列到序列学习

news2024/11/24 11:01:29

阅读完此论文后,对着代码过一遍思路

原文地址:https://arxiv.org/abs/2304.14394

本文将视觉跟踪建模为一个序列生成问题,以自回归的方式预测目标边界框。抛弃了设计复杂的头网络,采用encoder-decoder transformer architecture,编码器用ViT提取视觉特征(≈OSTrack),而解码器用因果转换器自回归生成一个边界框值序列。

图像表示

编码器输入模板和搜索图像。现有的跟踪器中模板图像的分辨率通常小于搜索图像的分辨率,SeqTrack使用相同的尺寸,发现在模板中添加更多的背景有助于提高跟踪性能(在其他工作里使用小尺寸模板图像都解释的是为了减少背景干扰,这里说法不是很一致)。

DATA:
  MAX_SAMPLE_INTERVAL: 400
  MEAN:
  - 0.485
  - 0.456
  - 0.406
  SEARCH:
    CENTER_JITTER: 3.5
    FACTOR: 4.0
    SCALE_JITTER: 0.5
    SIZE: 384
    NUMBER: 1
  STD:
  - 0.229
  - 0.224
  - 0.225
  TEMPLATE:
    CENTER_JITTER: 0
    FACTOR: 4.0
    SCALE_JITTER: 0
    SIZE: 384
    NUMBER: 2
  TRAIN:
    DATASETS_NAME:
    - GOT10K_train_full
    DATASETS_RATIO:
    - 1
    SAMPLE_PER_EPOCH: 30000
序列表示

边界框转换为离散序列[ x , y , w , h ] [ x , y , w , h][x,y,w,h],将每个连续坐标统一离散为[ 1 , nbins]之间的整数。使用共享词汇表V(4000),V中的每个单词对应一个可学习的嵌入,在训练过程中进行优化。如下代码所示:

class DecoderEmbeddings(nn.Module):
    def __init__(self, vocab_size, hidden_dim, max_position_embeddings, dropout):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, hidden_dim)
        self.position_embeddings = nn.Embedding(
            max_position_embeddings, hidden_dim
        )

        self.LayerNorm = torch.nn.LayerNorm(
            hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        input_embeds = self.word_embeddings(x)
        embeddings = input_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

最终使用一个带softmax的多层感知器,根据输出嵌入对V中的单词进行采样来将嵌入映射回单词。

模型架构

它的基本架构是这样:

SeqTrack的架构,a:左边编码器拿的ViT,右边解码器用的transformer里的。编码器提取视觉特征,解码器利用特征自回归生成边界框序列。b:解码器结构,最下层输入目标序列,先自注意力再与视觉特征做注意力,自回归输出生成目标序列。

解码时加入一个因果注意力掩码(NLP那边用的差不多,防止偷看后边的果)
使用了两个特殊的标记:start和end。开始令牌告诉模型开始生成,而结束令牌则表示生成的完成。

训练时,解码器的输入序列为[start,x,y,w,h] [start,x,y,w,h][start,x,y,w,h],目标序列为[ x , y , w , h , e n d ] [ x , y , w , h , end][x,y,w,h,end]。(NLP里的)

编码器
  • 去掉了分类用的cls token。
  • 在最后一层附加一个线性投影来对齐编码器和解码器的特征维度。
self.bottleneck = nn.Linear(encoder.num_channels, hidden_dim) 
  • 只有搜索图像的特征被送入解码器。
    def forward_features(self, images_list):
        num_template = self.num_template
        template_list = images_list[0:num_template]
        search_list = images_list[num_template:]
        num_search = len(search_list)

        z_list = []
        for i in range(num_template):
            z = template_list[i]
            z = self.patch_embed(z)
            z = z + self.pos_embed[:, self.num_patches_search:, :]
            z_list.append(z)
        z_feat = torch.cat(z_list, dim=1)

        x_list = []
        for i in range(num_search):
            x = search_list[i]
            x = self.patch_embed(x)
            x = x + self.pos_embed[:, :self.num_patches_search, :]
            x_list.append(x)
        x_feat = torch.cat(x_list, dim=1)
        xz_feat = torch.cat([x_feat, z_feat], dim=1)

        xz = self.pos_drop(xz_feat)

        for blk in self.blocks:   #batch is the first dimension.
            if self.use_checkpoint:
                xz = checkpoint.checkpoint(blk, xz)
            else:
                xz = blk(xz)

        xz = self.norm(xz) # B,N,C
        return xz
解码器
  • 接收来自前一个block的词嵌入并利用一个因果关系掩码保证每个序列元素的输出只依赖于其前面的序列元素。

生成因果掩码代码如下:

def generate_square_subsequent_mask(sz):
    r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
    """

    #each token only can see tokens before them
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float(
        '-inf')).masked_fill(mask == 1, float(0.0))
    return mask

训练和推理
训练

损失函数:通过交叉熵损失最大化target tokens在前一个子序列和输入视频帧上的对数似然。

推理

引入在线模板更新和窗口惩罚,在推理过程中融合先验知识,进一步提高了模型精度和鲁棒性,使用generated tokens的似然来自动选择可靠的动态模板。引入了一种新的窗口惩罚策略,当前搜索区域中心点的离散坐标为[ n_bins / 2 , n_bins / 2],即为上一帧目标中心点位置。在生成x和y时,我们根据整数(即词)与nbins / 2的差来惩罚V中整数(即词)的可能性。差值越大惩罚越大。

实验

创建seqtrack虚拟环境并且激活

conda create -n seqtrack python=3.8
conda activate seqtrack

所需要的安装包如下所示

echo "****************** Installing pytorch ******************"
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
#conda install -y pytorch=1.11 torchvision torchaudio cudatoolkit=11.3 -c pytorch

echo ""
echo ""
echo "****************** Installing yaml ******************"
pip install PyYAML

echo ""
echo ""
echo "****************** Installing easydict ******************"
pip install easydict

echo ""
echo ""
echo "****************** Installing cython ******************"
pip install cython

echo ""
echo ""
echo "****************** Installing opencv-python ******************"
pip install opencv-python

echo ""
echo ""
echo "****************** Installing pandas ******************"
pip install pandas

echo ""
echo ""
echo "****************** Installing tqdm ******************"
conda install -y tqdm

echo ""
echo ""
echo "****************** Installing coco toolkit ******************"
pip install pycocotools

echo ""
echo ""
echo "****************** Installing jpeg4py python wrapper ******************"
pip install jpeg4py

echo ""
echo ""
echo "****************** Installing tensorboard ******************"
pip install tb-nightly

echo ""
echo ""
echo "****************** Installing tikzplotlib ******************"
pip install tikzplotlib

echo ""
echo ""
echo "****************** Installing thop tool for FLOPs and Params computing ******************"
pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git

echo ""
echo ""
echo "****************** Installing colorama ******************"
pip install colorama

echo ""
echo ""
echo "****************** Installing lmdb ******************"
pip install lmdb

echo ""
echo ""
echo "****************** Installing scipy ******************"
pip install scipy

echo ""
echo ""
echo "****************** Installing visdom ******************"
pip install visdom

echo ""
echo ""
echo "****************** Installing vot-toolkit python ******************"
pip install git+https://github.com/votchallenge/vot-toolkit-python

echo ""
echo ""
echo "****************** Installing timm ******************"
pip install timm==0.5.4

echo ""
echo ""
echo "****************** Installing yacs ******************"
pip install yacs

echo ""
echo ""


echo "****************** Installation complete! ******************"

运行以下命令安装

bash install.sh

将项目路径添加到环境变量

export PYTHONPATH=<absolute_path_of_SeqTrack>:$PYTHONPATH

跟踪数据格式如下所示

${SeqTrack_ROOT}
 -- data
     -- lasot
         |-- airplane
         |-- basketball
         |-- bear
         ...
     -- got10k
         |-- test
         |-- train
         |-- val
     -- coco
         |-- annotations
         |-- images
     -- trackingnet
         |-- TRAIN_0
         |-- TRAIN_1
         ...
         |-- TRAIN_11
         |-- TEST

运行以下命令来设置此项目的路径

python tracking/create_default_local_file.py --workspace_dir . --data_dir ./data --save_dir .

训练SeqTrack

python -m torch.distributed.launch --nproc_per_node 8 lib/train/run_training.py --script seqtrack --config seqtrack_b256 --save_dir .

根据基准进行测试和评估这个部分还未全部完成,后续会持续跟进

code

SeqTrack:(SeqTrack-L256为例)

class SEQTRACK(nn.Module):
    """ This is the base class for SeqTrack """
    def __init__(self, encoder, decoder, hidden_dim,
                 bins=1000, feature_type='x', num_frames=1, num_template=1):
        """ Initializes the model.
        Parameters:
            encoder: torch module of the encoder to be used. See encoder.py
            decoder: torch module of the decoder architecture. See decoder.py
        """
        super().__init__()
        self.encoder = encoder
        self.num_patch_x = self.encoder.body.num_patches_search
        self.num_patch_z = self.encoder.body.num_patches_template
        self.side_fx = int(math.sqrt(self.num_patch_x))
        self.side_fz = int(math.sqrt(self.num_patch_z))
        self.hidden_dim = hidden_dim
        self.bottleneck = nn.Linear(encoder.num_channels, hidden_dim) # the bottleneck layer, which aligns the dimmension of encoder and decoder
        self.decoder = decoder
        self.vocab_embed = MLP(hidden_dim, hidden_dim, bins+2, 3)

        self.num_frames = num_frames
        self.num_template = num_template
        self.feature_type = feature_type

        # Different type of visual features for decoder.
        # Since we only use one search image for now, the 'x' is same with 'x_last' here.
        if self.feature_type == 'x':
            num_patches = self.num_patch_x * self.num_frames
        elif self.feature_type == 'xz':
            num_patches = self.num_patch_x * self.num_frames + self.num_patch_z * self.num_template
        elif self.feature_type == 'token':
            num_patches = 1
        else:
            raise ValueError('illegal feature type')

        # position embeding for the decocder
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_dim))
        pos_embed = get_sinusoid_encoding_table(num_patches, self.pos_embed.shape[-1], cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

encoder:(ViT)

@register_model
def vit_large_patch16(pretrained=False, pretrain_type='default',
                              search_size=384, template_size=192, **kwargs):
    patch_size = 16
    model = VisionTransformer(
        search_size=search_size, template_size=template_size,
        patch_size=patch_size, num_classes=0,
        embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    cfg_type = 'vit_large_patch16_224_' + pretrain_type
    if pretrain_type == 'scratch':
        pretrained = False
        return model
    model.default_cfg = default_cfgs[cfg_type]
    if pretrained:
        load_pretrained(model, pretrain_type, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model

decoder:(DETR)

class SeqTrackDecoder(nn.Module):

    def __init__(self, d_model=512, nhead=8,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False, bins=1000, num_frames=9):
        super().__init__()
        self.bins = bins
        self.num_frames = num_frames
        self.num_coordinates = 4 # [x,y,w,h]
        max_position_embeddings = (self.num_coordinates+1) * num_frames
        self.embedding = DecoderEmbeddings(bins+2, d_model, max_position_embeddings, dropout)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.body = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1846329.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

手机数据如何恢复?11 款最佳安卓手机恢复软件

媒体可能由于各种原因而从您的设备中删除&#xff0c;可能是意外或病毒攻击。 在这些情况下&#xff0c;照片恢复应用程序是唯一的解决方案。理想的照片恢复应用程序取决于各种因素&#xff0c;例如存储设备的损坏程度、删除照片后的持续时间以及应用程序使用的恢复算法的有效性…

玩玩大模型:总结归纳可以,策划创新拉垮

最近身边的人都在研究大模型。太深入的理解不了&#xff0c;有一些人会讲讲promt提示&#xff0c;学了几招。 比如&#xff1a; #角色 你是一个美食博主 #条件 我只有xxx元&#xff0c;在xxx.... #任务 找一家好吃的当地特色餐馆... 多试几次&#xff0c;有些结果很有参考价值…

前端学习-day10

文章目录 01-体验平面转换02-平移效果03-绝对定位元素居中04-案例-双开门06-转换旋转中心点07-案例-时钟-转换原点08-平面转换-多重转换09-缩放效果10-案例-按钮缩放11-倾斜效果12-渐变-线性13-案例-产品展示14-渐变-径向15-综合案例-喜马拉雅 01-体验平面转换 <!DOCTYPE h…

PCAtools|主成分分析

library(PCAtools) library(tidyverse) ls(package:PCAtools) iris <- as.data.frame(iris) iris <- iris %>% mutate(class str_c("a",1:dim(iris)[1],sep "")) rownames(iris) <- iris$class iris <- iris[,-6] head(iris) # 构建矩阵 …

国有企业数字化转型常见思考框架与路线图

一、国有企业数字化转型思考框架 在之前的十九届四中全会《关于坚持和完善中国特色社会主义制度推进国家治理体系和治理能力现代化若干重大问题的决定》中明确“推进国有经济布局优化和结构调整&#xff0c;发展混合所有制经济&#xff0c;增强国有经济竞争力、创新力、控制力…

麒麟移动运行环境(KMRE)——国内首个开源的商用移固融合“Android生态兼容环境”正式开源

近日&#xff0c;由麒麟软件研发的KMRE&#xff08;Kylin Mobile Runtime Environment&#xff0c;麒麟移动运行环境&#xff09;在openKylin&#xff08;开放麒麟&#xff09;社区正式发布&#xff0c;为Linux桌面操作系统产品提供了高效的Android运行环境解决方案。这也是国内…

【免费API推荐】:各类API资源免费获取【11】

欢迎来到各类API资源的免费获取世界&#xff01;幂简集成为您提供了一个集合了各种免费API接口的平台。无论您是开发者、数据分析师还是创业者&#xff0c;都可以通过我们的平台轻松免费获取所需的API资源。幂简精心整理了各类API接口&#xff0c;涵盖了不同领域的需求&#xf…

如何避免在React中的回调函数中使用箭头函数可能引起的内存泄漏?

在React中&#xff0c;箭头函数在回调函数中的使用确实可能引发性能问题&#xff0c;尤其是当这些函数在渲染方法或者组件内部被定义时。每次组件重新渲染时&#xff0c;都会创建这些函数的新实例&#xff0c;这可能导致不必要的计算和内存使用&#xff0c;甚至在某些情况下引发…

大模型RAG应用优化实战

之前体验OpenAI GPT-4o模型的时候&#xff0c;感觉到大语言模型进化的太快&#xff0c;基于AI应用做出的努力可能很快就被新一代模型降维打击&#xff0c;变得没有价值了&#xff0c;“人生苦短&#xff0c;终归尘土”&#xff0c;最终都化为虚无&#xff0c;还有什么意义呢&am…

多维表格/业务库表格大数据量性能瓶颈

先说最终结论&#xff1a;Angular 组件创建性能损耗是当下主要的性能瓶颈 理由&#xff1a; 基于以往编辑器性能优化的经验&#xff0c;编辑器在动态渲染内容时会创建很多壳子组件&#xff08;也就是Angular 组件&#xff09;&#xff0c;排查的时候就发现如果略这些壳子组件性…

探索Lazada商品数据宝库——一键获取商品详细数据信息

一、引言 在电商领域&#xff0c;Lazada凭借其广泛的商品种类和便捷的购物体验&#xff0c;成为东南亚地区备受欢迎的电商平台。然而&#xff0c;对于许多商家和数据分析师来说&#xff0c;获取商品详细数据信息却是一项繁琐而重要的任务。为了解决这个问题&#xff0c;我们精…

内容安全复习 9 - 身份认证系统攻击与防御

文章目录 基于生物特征的身份认证系统概述基于生物特征的身份认证 人脸活体检测检测方法未解决问题 基于生物特征的身份认证系统概述 作用&#xff1a;判别用户的身份、保障信息系统安全。 是识别操作者身份的过程&#xff0c;要保证其**物理身份&#xff08;现实&#xff0…

openppp2 命令行接口详解

openppp2 是一个工作在 OSI/3 Layer 网络通信层的虚拟以太网工具链的开源软件&#xff0c;在查阅本文之前&#xff0c;人们可以查阅以下资料。 开源仓库&#xff1a; liulilittle/openppp2: PPP PRIVATE NETWORK™ 2 VPN Next Generation Reliable and Secure Virtual Etherne…

一次压测引发的数据库 CPU 飙升

作者&#xff1a;昀鹤 一次压测过程中&#xff0c;当数据库的 qps 和 tps 都正常时&#xff0c;如果 cpu 利用率异常的高&#xff0c;应该如何排查&#xff1f;希望通过这篇文章&#xff0c;给你一些启发... 一、业务背景 业务需要控制频道内兑换现金的数量&#xff0c;于是在…

恭喜行云绽放,24年再度荣获国家鼓励的企业软件证书

在刚刚过去的五月份&#xff0c;行云绽放再次传来一个好消息&#xff0c;那就是2024年行云绽放再度荣获国家鼓励的企业软件证书。 什么是国家鼓励的企业软件证书&#xff1f; 国家鼓励的企业软件证书被称为“国家鼓励的软件企业证书”&#xff0c;这一证书由中国软件行业协会…

网站https逐渐普及,选择合适的SSL证书

目录 为什么实现https访问逐渐成为主流&#xff1a; 为什么要选择合适的SSL证书&#xff1a; 目前主流的三种域名证书及IP证书&#xff1a; 怎样申请SSL证书&#xff1a; 随着国内网络安全信息的逐渐普及&#xff0c;绝大部分的网站目前都配置了SSL证书用于实现https访问&a…

ModelScope联手OpenDataLab:直接调用7000+开源数据集,赋能AI模型加速研发

在人工智能的演进历程中&#xff0c;数据和模型的整合是推动技术发展的核心动力。随着AI技术的不断进步&#xff0c;整合各类关键资源&#xff0c;构建一个高效、协同的开发环境&#xff0c;已成为加速创新应用发展的关键。 基于这一理念&#xff0c;OpenDataLab浦数与ModelSc…

解锁私域电商潜力:构建与维护强大私域生态

大家好&#xff0c;我是专注于私域电商领域的技术专家&#xff0c;拥有丰富的行业经验。在今天的分享中&#xff0c;我将带大家深入理解私域流量的精髓&#xff0c;并探讨如何构建一个充满活力且高效的私域生态。在数字化浪潮下&#xff0c;如何深化用户关系并挖掘其潜在价值&a…

mybatis动态传参pgsql日期Interval

在navicat16中&#xff0c;标准写法 SELECT * FROM business_status_info WHERE create_time > (NOW() - INTERVAL 5 minutes) 在mybatis中&#xff0c;错误写法 SELECT * FROM business_status_info WHERE create_time > (NOW() - INTERVAL #{monitorTimeInterval,jdbc…

git 配置私人令牌

这里写自定义目录标题 获取私人令牌配置个人令牌 获取私人令牌 在个人设置里点击私人令牌选型&#xff0c;之后生成令牌即可。注意&#xff1a;令牌只会出现一次&#xff0c;务必保存好。 配置个人令牌 个人令牌&#xff1a;3c15c866fa61066212a83c66fd8133ba # 进入项目文…