【计算机视觉】siamfc论文复现实现目标追踪

news2024/9/22 4:17:04

什么是目标跟踪

使用视频序列第一帧的图像(包括bounding box的位置),来找出目标出现在后序帧位置的一种方法。

什么是孪生网络结构

孪生网络结构其思想是将一个训练样本(已知类别)和一个测试样本(未知类别)输入到两个CNN(这两个CNN往往是权值共享的)中,从而获得两个特征向量,然后通过计算这两个特征向量的的相似度,相似度越高表明其越可能是同一个类别。

在这里插入图片描述

给你一张我的正脸照(没有经过美颜处理的),你该如何在人群中找到我呢?一种最直观的方案就是:“谁长得最像就是谁”。但是对于计算机来说,如何衡量“长得像”,并不是个简单的问题。这就涉及一种基本的运算——互相关(cross-correlation)。互相关运算可以用来度量两个信号之间的相似性。互相关得到的响应图中每个像素的响应高低代表着每个位置相似度的高低。

在这里插入图片描述

在目标领域中,最早利用这种思想的是SiamFC,其网络结构如上图。图中的φ就是CNN编码器,上下两个分支使用的CNN不仅结构相同,参数也是完全共享的(说白了就是同一个网络,并不存在孪生兄弟那样的设定)。z和x分别是要跟踪的目标模版图像(尺寸为127x127)和新的一帧中的搜索范围(尺寸为255x255)。二者经过同样的编码器后得到各自的特征图,对二者进行互相关运算后则会同样得到一个响应图(尺寸为17x17),其每一个像素的值对应了x中与z等大的一个对应区域出现跟踪目标的概率。

互相关运算的步骤,像极了我们手里拿着一张目标的照片(模板图像),然后把这个照片按在需要寻找目标的图片上(搜索图像)进行移动,然后求重叠部分相似度,从而找到这个目标,只不过为了计算机计算的方便,使用AlexNet对图像数据进行了编码/特征提取

下面这个版本中有一些动图,还是会帮助理解的:https://github.com/rafellerc/Pytorch-SiamFC

SiamFC代码分析

我们对siamese的结构大致就讲完了,还有一些内容结合代码来讲,效果更好。

3.1 training

3.1.1图像预处理

小超up给出训练的框图如下。训练过程中,首先要获取训练数据集的所有视频序列(每个视频序列的所有帧),我采用的是GOT-10k数据集训练;获取数据集之后进行图像预处理,对每一个视频序列抽取两帧图像并作数据增强处理(包括裁剪、resize等过程),分别作为目标模板图像和搜索图像;把经过图像处理的所有图像对加载并以batch_size输入网络得到预测输出;建立标签和损失函数,损失函数的输入是预测输出,目标是标签;设置优化策略,梯度下降损失,最终得到网络模型。

在这里插入图片描述

先贴代码,再分析:

def train(data_dir, net_path=None,save_dir='pretrained'):
    #从文件中读取图像数据集
    seq_dataset = GOT10k(data_dir,subset='train',return_meta=False)
    #定义图像预处理方法
    transforms = SiamFCTransforms(  
        exemplar_sz=cfg.exemplar_sz, #127
        instance_sz=cfg.instance_sz, #255
        context=cfg.context) #0.5
    #从读取的数据集每个视频序列配对训练图像并进行预处理,裁剪等
    train_dataset = GOT10kDataset(seq_dataset,transforms)

data_dir是存放GOT-10k数据集的文件路径,GOT-10k一共有9335个训练视频序列,seq_dataset返回的是所有视频序列的图片路径列表seq_dirs及对应groundtruth列表anno_files及一些其他信息,如下:

img

接下来是定义好图像预处理方法,在GOT10kDataset方法中对每个视频序列配对两帧图像,并使用定义好的图像处理方法,接下来直接进入该方法分析代码,GOT10kDataset的代码如下:

class GOT10kDataset(Dataset): #继承了torch.utils.data的Dataset类
    def __init__(self, seqs, transforms=None,pairs_per_seq=1):
    def __getitem__(self, index): #通过_sample_pair方法得到索引返回item=(z,x,box_z,box_x),然后经过transforms处理
    def __len__(self): #返回9335*pairs_per_seq对
    def _sample_pair(self, indices): #随机挑选两个索引,这里取的间隔不超过T=100
    def _filter(self, img0, anno, vis_ratios=None): #通过该函数筛选符合条件的有效索引val_indices

这里最重要的方法就是__getitem__,该方法最终返回处理后的图像,在内部首先调用了_sample_pair方法,用于提取两帧有效图片(有效的定义是图片目标的面积和高宽等有约束条件)的索引,在得到这两帧图片和对应groundtruth之后通过定义好的transforms进行处理,transforms是SiamFCTransforms类的实例化对象,该类中主要继承了resize图片大小和各种裁剪方式等,如代码所示:

class SiamFCTransforms(object):
    def __init__(self, exemplar_sz=127, instance_sz=255, context=0.5):
        self.exemplar_sz = exemplar_sz
        self.instance_sz = instance_sz
        self.context = context
        #transforms_z/x是数据增强方法
        self.transforms_z = Compose([
            RandomStretch(),     #随机resize图片大小,变化再[1 1.05]之内
            CenterCrop(instance_sz - 8),  #中心裁剪 裁剪为255-8
            RandomCrop(instance_sz - 2 * 8),   #随机裁剪  255-8->255-8-8
            CenterCrop(exemplar_sz),   #中心裁剪 255-8-8->127
            ToTensor()])                        #图片的数据格式从numpy转换成torch张量形式
        self.transforms_x = Compose([
            RandomStretch(),                   #s随机resize图片
            CenterCrop(instance_sz - 8),      #中心裁剪 裁剪为255-8
            RandomCrop(instance_sz - 2 * 8),  #随机裁剪 255-8->255-8-8
            ToTensor()])                      #图片数据格式转化为torch张量
    
    def __call__(self, z, x, box_z, box_x): #z,x表示传进来的图像
        z = self._crop(z, box_z, self.instance_sz)       #对z(x类似)图像 1、box转换(l,t,w,h)->(y,x,h,w),并且数据格式转为float32,得到center[y,x],和target_sz[h,w]
        x = self._crop(x, box_x, self.instance_sz)       #2、得到size=((h+(h+w)/2)*(w+(h+2)/2))^0.5*255(instance_sz)/127
        z = self.transforms_z(z)                         #3、进入crop_and_resize:传入z作为图片img,center,size,outsize=255(instance_sz),随机选方式填充,均值填充
        x = self.transforms_x(x)                         #   以center为中心裁剪一块边长为size大小的正方形框(注意裁剪时的padd边框填充问题),再resize成out_size=255(instance_sz)
        return z, x

实例化对象后,直接从__call__开始运行代码,首先关注的应该是_crop函数,该函数将原始的两帧图片分别以目标为中心,裁剪一块包含上下文信息的patch,patch的边长定义如下:

在这里插入图片描述

式中,w、h分别表示目标的宽和高。下面具体讲里面的_crop函数:

    def _crop(self, img, box, out_size):
        # convert box to 0-indexed and center based [y, x, h, w]
        box = np.array([
            box[1] - 1 + (box[3] - 1) / 2,
            box[0] - 1 + (box[2] - 1) / 2,
            box[3], box[2]], dtype=np.float32)
        center, target_sz = box[:2], box[2:]

        context = self.context * np.sum(target_sz)
        size = np.sqrt(np.prod(target_sz + context))
        size *= out_size / self.exemplar_sz

        avg_color = np.mean(img, axis=(0, 1), dtype=float)
        interp = np.random.choice([
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_NEAREST,
            cv2.INTER_LANCZOS4])
        patch = ops.crop_and_resize(
            img, center, size, out_size,
            border_value=avg_color, interp=interp)
        
        return patch

因为GOT-10k里面对于目标的bbox是以ltwh(即left, top, weight, height)形式给出的,上述代码一开始就先把输入的box变成center based,坐标形式变为[y, x, h, w],结合下面这幅图就非常好理解在这里插入图片描述img

crop_and_resize:

def crop_and_resize(img, center, size, out_size,
                    border_type=cv2.BORDER_CONSTANT,
                    border_value=(0, 0, 0),
                    interp=cv2.INTER_LINEAR):
    # convert box to corners (0-indexed)
    size = round(size)  # the size of square crop
    corners = np.concatenate((
        np.round(center - (size - 1) / 2),
        np.round(center - (size - 1) / 2) + size))
    corners = np.round(corners).astype(int)

    # pad image if necessary
    pads = np.concatenate((
        -corners[:2], corners[2:] - img.shape[:2]))
    npad = max(0, int(pads.max()))
    if npad > 0:
        img = cv2.copyMakeBorder(
            img, npad, npad, npad, npad,
            border_type, value=border_value)

    # crop image patch
    corners = (corners + npad).astype(int)
    patch = img[corners[0]:corners[2], corners[1]:corners[3]]

    # resize to out_size
    patch = cv2.resize(patch, (out_size, out_size),
                       interpolation=interp)

    return patch

在裁剪过程中会出现越界的情况,需要对原始图像边缘填充,填充值固定为图像的RGB均值,填充大小根据图像边缘越界最大值作为填充值,具体实现过程由以下代码完成。

# padding操作
	#corners表示目标的[ymin,xmin,ymax,xmax]
    pads = np.concatenate((
        -corners[:2], corners[2:] - img.shape[:2]))
    npad = max(0, int(pads.max())) #得到上下左右4个越界值中最大的与0对比,<0代表无越界
    if npad > 0:
        img = cv2.copyMakeBorder(
            img, npad, npad, npad, npad,
            cv2.BORDER_CONSTANT, value=img_average)

实验结果:

img

3.1.2加载训练数据、标签及损失函数

图像预处理完成后,得到了用与训练的9335对图像,将图像加载批量加载输入网络得到输出结果作为损失函数的input,损失函数的target是制定好的labels。

#加载训练数据集
    loader_dataset = DataLoader( dataset = train_dataset,
                                 batch_size=cfg.batch_size,
                                 shuffle=True,
                                 num_workers=cfg.num_workers,
                                 pin_memory=True,
                                 drop_last=True, )
    #初始化训练网络
    cuda = torch.cuda.is_available()  #支持GPU为True
    device = torch.device('cuda:0' if cuda else 'cpu')  #cuda设备号为0
    model = AlexNet(init_weight=True)
    corr = _corr()
    model = model.to(device)
    corr = corr.to(device)
    # 设置损失函数和标签
    logist_loss = BalancedLoss()
    labels = _create_labels(size=[cfg.batch_size, 1, cfg.response_sz - 2, cfg.response_sz - 2])
    labels = torch.from_numpy(labels).to(device).float()

本小节主要讲网络输出的labels和损失函数,接下来只是小超up个人的一些理解,代码与论文理论部分形式不一致,但效果一样。先上图,论文中labels以及损失函数如下图:

img

img

然而代码中的labels值却是1和0,损失函数使用的是二值交叉熵损失函数F.binary_cross_entropy_with_logits,如下图推导所示,解释了为什么代码实现部分真正使用的labels值是1和0,而理论部分使用的是1和-1。

img

利用下面代码的这个_creat_labels方法可以得到标签。

def _create_labels(size):
    def logistic_labels(x, y, r_pos):
        # x^2+y^2<4 的位置设为为1,其他为0
        dist = np.sqrt(x ** 2 + y ** 2)
        labels = np.where(dist <= r_pos,    #r_os=2
                          np.ones_like(x),  #np.ones_like(x),用1填充x
                          np.zeros_like(x)) #np.zeros_like(x),用0填充x
        return labels
    #获取标签的参数
    n, c, h, w = size  # [8,1,15,15]
    x = np.arange(w) - (w - 1) / 2  #x=[-7 -6 ....0....6 7]
    y = np.arange(h) - (h - 1) / 2  #y=[-7 -6 ....0....6 7]
    x, y = np.meshgrid(x, y)       
    #建立标签
    r_pos = cfg.r_pos / cfg.total_stride  # 16/8
    labels = logistic_labels(x, y, r_pos)
    #重复batch_size个label,因为网络输出是batch_size张response map
    labels = labels.reshape((1, 1, h, w))   #[1,1,15,15]
    labels = np.tile(labels, (n, c, 1, 1))  #将labels扩展[8,1,15,15]
    return labels

验证结果如下图,只截取了部分labels,得到的labels对应输入,大小都是[8,1,15,15]

if __name__ == '__main__':
    labels = _create_labels([8,1,15,15])  #返回的label.shape=(8,1,15,15)

其中关于np.tile、np.meshgrid、np.where函数的使用可以去看这篇博客,最后出来的一个batch下某一个通道下的label就是下面这样的

img

3.1.3 优化策略

这里主要说一下学习率lr,随着训练次数epoch增多而减小,具体值如下公式image-20240721105017162,式中,initial为初始学习率,gamma是定义的超参,epoch为训练次数。整个优化器及学习率调整实现代码如下:

#建立优化器,设置指数变化的学习率
    optimizer = optim.SGD(
        model.parameters(),
        lr=cfg.initial_lr,              #初始化的学习率,后续会不断更新
        weight_decay=cfg.weight_decay,  #λ=5e-4,正则化
        momentum=cfg.momentum)          #v(now)=dx∗lr+v(last)∗momemtum
    gamma = np.power(                   #np.power(a,b) 返回a^b
        cfg.ultimate_lr / cfg.initial_lr,
        1.0 / cfg.epoch_num)
    lr_scheduler = ExponentialLR(optimizer, gamma)  #指数形式衰减,lr=initial_lr*(gamma^epoch)=
3.1.4 模型的训练与保存

一切准备工作就绪后,就开始训练了。代码中设定epoch_num为50次,训练时密切加上model.train(),告诉网络处于训练状态,这样,网络运行时就会利用pytorch的自动求导机制求导;在测试时,改为model.eval(),关闭自动求导。模型训练的步骤如代码所示:

# loop over epochs
for epoch in range(self.cfg.epoch_num):
    # update lr at each epoch
    self.lr_scheduler.step(epoch=epoch)

    # loop over dataloader
    for it, batch in enumerate(dataloader):
        loss = self.train_step(batch, backward=True)
        print('Epoch: {} [{}/{}] Loss: {:.5f}'.format(
            epoch + 1, it + 1, len(dataloader), loss))
        sys.stdout.flush()
    
    # save checkpoint
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    net_path = os.path.join(
        save_dir, 'siamfc_alexnet_e%d.pth' % (epoch + 1))
    torch.save(self.net.state_dict(), net_path)

至此此份repo的训练应该差不多结束了

参考文档

siameseFC论文和代码解析

SiamFC 学习(论文、总结与分析)

siamfc-pytorch代码讲解(一):backbone&head

siamfc-pytorch代码讲解(二):train&siamfc

SiamFC代码分析(architecture、training、test)

http://www.360doc.com/content/19/0801/10/32196507_852333196.shtml

视频推荐

目标跟踪零基础代码入门(一):SiamFC_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1938867.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【SRC】小程序抓包巨详细配置,一个Burp就够了,但是可以更优雅!

小程序抓包配置 文章目录 小程序抓包配置0x00 前言0x01 直接使用BurpSuite抓包0x02 配合Proxifier 0x00 前言 其实在PC端抓微信小程序的包&#xff0c;只需要一个BurpSuite就足够了&#xff0c;但是为了避免抓一些没用的包&#xff0c;减少对小程序抓包测试过程中的干扰&#…

学生处分类型管理

在智慧校园学工管理系统中&#xff0c;"处分类型"功能扮演着至关重要的角色&#xff0c;它如同一座桥梁&#xff0c;连接着校园秩序与学生行为规范的两端。这一模块的核心精髓&#xff0c;在于它以精准的违规行为界定和适当的处分措施&#xff0c;巧妙地平衡了纪律的…

Qmi8658a姿态传感器使用心得(4)linux

1.FIFO 结构与大小 FIFO 数据可以包含陀螺仪和加速度计数据&#xff0c;通过 SPI/I2C/I3C 接口以突发读模式读取。FIFO 大小可配置为 16 样本、32 样本、64 样本或 128 样本&#xff08;每个样本为 6 字节&#xff09;。 2.FIFO 模式 Bypass 模式&#xff1a;禁用 FIFO 功能。…

SpringCloud03_loadbalancer的概述、负载均衡解析、切换、原理

文章目录 ①. Ribbon进入维护模式②. loadbalancer的概述③. loadbalancer负载均衡解析④. 负载均衡案例总结⑤. 负载均衡算法原理 ①. Ribbon进入维护模式 ①. Spring Cloud Ribbon是基于Netflix Ribbon实现的一套客户端负载均衡的工具。 ②. 维护模式不再介绍,了解即可 ③.…

大语言模型-Transformer-Attention Is All You Need

一、背景信息&#xff1a; Transformer是一种由谷歌在2017年提出的深度学习模型。 主要用于自然语言处理&#xff08;NLP&#xff09;任务&#xff0c;特别是序列到序列&#xff08;Sequence-to-Sequence&#xff09;的学习问题&#xff0c;如机器翻译、文本生成等。Transfor…

【python】Numpy运行报错分析:ValueError - 数组维度不一致

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者简介&#xff1a;景天科技苑 &#x1f3c6;《头衔》&#xff1a;大厂架构师&#xff0c;华为云开发者社区专家博主&#xff0c;…

java中多态的用法

思维导图&#xff1a; 1. 多态的概念 多态通俗的讲就是多种形态&#xff0c;同一个动作&#xff0c;作用在不同对象上&#xff0c;所产生不同的形态。 例如下图&#xff1a; 2. 多态的实现条件 Java中&#xff0c;多态的实现必须满足以下几个条件&#xff1a; 1. 必须在继承…

动画革命:Lottie如何改变我们对移动应用交互的认知

在数字世界的浩瀚星空中&#xff0c;每一个像素都跃动着无限创意与想象的火花。当静态的界面遇上动态的魔法&#xff0c;一场视觉盛宴便悄然开启。今天&#xff0c;让我们一同揭开一位幕后英雄的神秘面纱——Lottie&#xff0c;这个在UI/UX设计界掀起波澜的动画利器&#xff0c…

[trick]使用生成器打破嵌套循环

原文 break用于结束循环。但是&#xff0c;如果有嵌套循环&#xff0c;如何跳出外层循环&#xff1f; def this_is_the_one(x):return x 3my_list [[1, 2], [3, 4], [5, 6]] for sublist in my_list:for element in sublist:print(f"Checking {element}")if this_…

农场驿站平台小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;农场资讯管理&#xff0c;卖家管理&#xff0c;用户分享管理&#xff0c;分享类型管理&#xff0c;商品信息管理&#xff0c;商品分类管理&#xff0c;系统管理&#xff0c;订单管…

天舟飞船可视化:直观体验太空任务全过程

利用图扑先进的 3D 可视化技术&#xff0c;实时展示天舟飞船的发射、对接和任务执行&#xff0c;为观众提供身临其境的太空探索体验。

奥比岛手游攻略:新手攻略大全!云手机辅助!

《奥比岛&#xff1a;梦想国度》是一款画风可爱的Q版休闲益智手游。在这个充满童话色彩的世界里&#xff0c;玩家们可以度过快乐的每一天&#xff0c;结交许多朋友&#xff0c;完成各种任务&#xff0c;体验丰富多彩的游戏玩法。下面将为大家带来详细的攻略大全。 游戏前瞻&…

Java 面试 | Redis

目录 1. 在项目中缓存是如何使用的&#xff1f;2. 为啥在项目中要用缓存&#xff1f;3. 缓存如果使用不当会造成什么后果&#xff1f;4. redis 和 memcached 有什么区别&#xff1f;5. redis 的线程模型是什么&#xff1f;6. 为什么单线程的 redis 比多线程的 memcached 效率要…

Python酷库之旅-第三方库Pandas(035)

目录 一、用法精讲 106、pandas.Series.iloc方法 106-1、语法 106-2、参数 106-3、功能 106-4、返回值 106-5、说明 106-6、用法 106-6-1、数据准备 106-6-2、代码示例 106-6-3、结果输出 107、pandas.Series.__iter__魔法方法 107-1、语法 107-2、参数 107-3、…

Science Robotics 一种使用导电嵌段共聚物弹性体和心理物理阈值来实现准确触觉效果的方法

速读&#xff1a;电触觉刺激作为感官替代的形式存在许多问题&#xff0c;如反应不一致、疼痛和脱敏等问题。加州大学Darren J. Lipomi教授团队研究了一种利用导电嵌段共聚物弹性体和心理物理阈值来实现准确触觉的方法。通过优化材料、设备布局和校准技术&#xff0c;他们在10名…

web服务器——虚拟主机配置实战

搭建静态网站 —— 基于 http 协议的静态网站 实验 1 &#xff1a;搭建一个 web 服务器&#xff0c;访问该服务器时显示 “hello world” 欢迎界面 。 实验 2 &#xff1a;建立两个基于 ip 地址访问的网站&#xff0c;要求如下 该网站 ip 地址的主机位为 100 &#xff0c;设置…

jupyter_contrib_nbextensions安装失败问题

目录 1.文件路径长度问题 2.jupyter不出现Nbextensions选项 1.文件路径长度问题 问题&#xff1a; could not create build\bdist.win-amd64\wheel\.\jupyter_contrib_nbextensions\nbextensions\contrib_nbextensions_help_item\contrib_nbextensions_help_item.yaml: No su…

【强化学习的数学原理】课程笔记--4(随机近似与随机梯度下降,时序差分方法)

目录 随机近似与随机梯度下降Mean estimationRobbins-Monro 算法用 Robbins-Monro 算法解释 Mean estimation用 Robbins-Monro 算法解释 Batch Gradient descent用 SGD 解释 Mean estimation SGD 的一个有趣的性质 时序差分方法Sarsa 算法一个例子 Expected Sarsa 算法n-step S…

LLM基础模型系列:Prefix-Tuning

------->更多内容&#xff0c;请移步“鲁班秘笈”&#xff01;&#xff01;<------ Prefix Tuning和Prompt Tuning最大的区别就是向每层的Transformer Block添加可训练的张量&#xff0c;而上一期的Prompt Tuning只是在输入的时候添加。 此外&#xff0c;通过全连接层&a…

【BUG】已解决:ModuleNotFoundError: No module named ‘sklearn‘

已解决&#xff1a;ModuleNotFoundError: No module named ‘sklearn‘ 目录 已解决&#xff1a;ModuleNotFoundError: No module named ‘sklearn‘ 【常见模块错误】 【解决方案】 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页&#xff0c;我是…