【目标跟踪网络训练 Market-1501 数据集】DeepSort 训练自己的跟踪网络模型

news2024/10/6 20:26:51

前言

Deepsort之所以可以大量避免IDSwitch,是因为Deepsort算法中特征提取网络可以将目标检测框中的特征提取出来并保存,在目标被遮挡后又从新出现后,利用前后的特征对比可以将遮挡的后又出现的目标和遮挡之前的追踪的目标重新找到,大大减少了目标在遮挡后,追踪失败的可能。

一、数据集简介

Market-1501 数据集在清华大学校园中采集,夏天拍摄,在 2015 年构建并公开。它包括由6个摄像头(其中5个高清摄像头和1个低清摄像头)拍摄到的 1501 个行人、32668 个检测到的行人矩形框。每个行人至少由2个摄像头捕获到,并且在一个摄像头中可能具有多张图像。训练集有 751 人,包含 12,936 张图像,平均每个人有 17.2 张训练数据;测试集有 750 人,包含 19,732 张图像,平均每个人有 26.3 张测试数据。3368 张查询图像的行人检测矩形框是人工绘制的,而 gallery 中的行人检测矩形框则是使用DPM检测器检测得到的。

该数据集提供的固定数量的训练集和测试集均可以在single-shot或multi-shot测试设置下使用。

目录结构

Market-1501-v15.09.15

  ├── bounding_box_test

       ├── 0000_c1s1_000151_01.jpg

       ├── 0000_c1s1_000376_03.jpg

       ├── 0000_c1s1_001051_02.jpg

  ├── bounding_box_train

       ├── 0002_c1s1_000451_03.jpg

       ├── 0002_c1s1_000551_01.jpg

       ├── 0002_c1s1_000801_01.jpg

  ├── gt_bbox

       ├── 0001_c1s1_001051_00.jpg

       ├── 0001_c1s1_009376_00.jpg

       ├── 0001_c2s1_001976_00.jpg

  ├── gt_query

       ├── 0001_c1s1_001051_00_good.mat

       ├── 0001_c1s1_001051_00_junk.mat

  ├── query

       ├── 0001_c1s1_001051_00.jpg

       ├── 0001_c2s1_000301_00.jpg

       ├── 0001_c3s1_000551_00.jpg

  └── readme.txt

目录介绍

(1) “bounding_box_test”——用于测试集的 750 人,包含 19,732 张图像,前缀为 0000 表示在提取这 750 人的过程中DPM检测错的图(可能与query是同一个人),-1 表示检测出来其他人的图(不在这 750 人中)

(2) “bounding_box_train”——用于训练集的 751 人,包含 12,936 张图像

(3) “query”——为 750 人在每个摄像头中随机选择一张图像作为query,因此一个人的query最多有 6 个,共有 3,368 张图像

(4) “gt_query”——matlab格式,用于判断一个query的哪些图片是好的匹配(同一个人不同摄像头的图像)和不好的匹配(同一个人同一个摄像头的图像或非同一个人的图像)

(5) “gt_bbox”——手工标注的bounding box,用于判断DPM检测的bounding box是不是一个好的box

命名规则

以 0001_c1s1_000151_01.jpg 为例

1) 0001 表示每个人的标签编号,从0001到1501;

2) c1 表示第一个摄像头(camera1),共有6个摄像头;

3) s1 表示第一个录像片段(sequece1),每个摄像机都有数个录像段;

4) 000151 表示 c1s1 的第000151帧图片,视频帧率25fps;

5) 01 表示 c1s1_001051 这一帧上的第1个检测框,由于采用DPM检测器,对于每一帧上的行人可能会框出好几个bbox。00 表示手工标注框

二、跟踪模型介绍

特征提取的模型有很多,可以替换特征提取模型网络。

下述给出的是 deep_sort/deep/model.py 里面的模型代码

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    def __init__(self, c_in, c_out, is_downsample=False):
        super(BasicBlock, self).__init__()
        self.is_downsample = is_downsample
        if is_downsample:
            self.conv1 = nn.Conv2d(
                c_in, c_out, 3, stride=2, padding=1, bias=False)
        else:
            self.conv1 = nn.Conv2d(
                c_in, c_out, 3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(c_out)
        self.relu = nn.ReLU(True)
        self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(c_out)
        if is_downsample:
            self.downsample = nn.Sequential(
                nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
                nn.BatchNorm2d(c_out)
            )
        elif c_in != c_out:
            self.downsample = nn.Sequential(
                nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
                nn.BatchNorm2d(c_out)
            )
            self.is_downsample = True

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        if self.is_downsample:
            x = self.downsample(x)
        return F.relu(x.add(y), True)


def make_layers(c_in, c_out, repeat_times, is_downsample=False):
    blocks = []
    for i in range(repeat_times):
        if i == 0:
            blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
        else:
            blocks += [BasicBlock(c_out, c_out), ]
    return nn.Sequential(*blocks)


class Net(nn.Module):
    def __init__(self, num_classes=751, reid=False):
        super(Net, self).__init__()
        # 3 128 64
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # nn.Conv2d(32,32,3,stride=1,padding=1),
            # nn.BatchNorm2d(32),
            # nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, padding=1),
        )
        # 32 64 32
        self.layer1 = make_layers(64, 64, 2, False)
        # 32 64 32
        self.layer2 = make_layers(64, 128, 2, True)
        # 64 32 16
        self.layer3 = make_layers(128, 256, 2, True)
        # 128 16 8
        self.layer4 = make_layers(256, 512, 2, True)
        # 256 8 4
        self.avgpool = nn.AvgPool2d((8, 4), 1)
        # 256 1 1
        self.reid = reid
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # B x 128
        if self.reid:
            x = x.div(x.norm(p=2, dim=1, keepdim=True))
            return x
        # classifier
        x = self.classifier(x)
        return x


if __name__ == '__main__':
    net = Net()
    x = torch.randn(4, 3, 128, 64)
    y = net(x)
    import ipdb
    ipdb.set_trace()

三、数据集处理

splitDataset.py

用于存放训练的图片

# -*- coding:utf-8 -*-
# @author: 牧锦程
# @微信公众号: AI算法与电子竞赛
# @Email: m21z50c71@163.com
# @VX:fylaicai

import os
from shutil import copyfile

# You only need to change this line to your dataset download path
download_path = 'Market-1501-v15.09.15'

if not os.path.isdir(download_path):
    print('please change the download_path')

save_path = 'pytorch'
if not os.path.isdir(save_path):
    os.mkdir(save_path)

# ------------------- query ----------------------
query_path = download_path + '/query'
query_save_path = save_path + '/query'
print("process: ", query_path)
if not os.path.isdir(query_save_path):
    os.mkdir(query_save_path)

for root, dirs, files in os.walk(query_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = query_path + '/' + name
        dst_path = query_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

# ----------------- multi-query ------------------------
query_path = download_path + '/gt_bbox'
print("process: ", query_path)
# for dukemtmc-reid, we do not need multi-query
if os.path.isdir(query_path):
    query_save_path = save_path + '/multi-query'
    if not os.path.isdir(query_save_path):
        os.mkdir(query_save_path)

    for root, dirs, files in os.walk(query_path, topdown=True):
        for name in files:
            if not name[-3:] == 'jpg':
                continue
            ID = name.split('_')
            src_path = query_path + '/' + name
            dst_path = query_save_path + '/' + ID[0]
            if not os.path.isdir(dst_path):
                os.mkdir(dst_path)
            copyfile(src_path, dst_path + '/' + name)

# ------------------- gallery ----------------------
gallery_path = download_path + '/bounding_box_test'
gallery_save_path = save_path + '/gallery'
print("process: ", gallery_path)
if not os.path.isdir(gallery_save_path):
    os.mkdir(gallery_save_path)

for root, dirs, files in os.walk(gallery_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = gallery_path + '/' + name
        dst_path = gallery_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

# ------------------ train ---------------------
train_path = download_path + '/bounding_box_train'
train_save_path = save_path + '/train'
val_save_path = save_path + '/test'
if not os.path.isdir(train_save_path):
    os.mkdir(train_save_path)
    os.mkdir(val_save_path)

print("process: ", train_path)
for root, dirs, files in os.walk(train_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = train_path + '/' + name
        dst_path = train_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
            # first image is used as val image
            dst_path = val_save_path + '/' + ID[0]
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

四、模型训练

修改数据集路径

修改 data-dir 参数为自己的数据集路径

修改数据增强

增加一个尺寸修改

修改类别数量

代码中通过dataloader来获取,因此可以不进行修改

num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))

修改保存模型名称

这里的修改是为例与原始的模型进行区分,可以做对比

训练结果

查看精度

先运行test.py,生成 features.pth

在运行 evaluate.py,得到如下精度:

五、链接作者

欢迎关注我的公众号:@AI算法与电子竞赛

硬性的标准其实限制不了无限可能的我们,所以啊!少年们加油吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1806322.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【BUG】已解决:ModuleNotFoundError: No module named ‘transformers‘

已解决:ModuleNotFoundError: No module named ‘transformers‘ 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是博主英杰,211科班出身,就职于医疗科技公司…

360数字安全:2024年4月勒索软件流行态势分析报告

勒索软件传播至今,360 反勒索服务已累计接收到数万勒索软件感染求助。随着新型勒索软件的快速蔓延,企业数据泄露风险不断上升,勒索金额在数百万到近亿美元的勒索案件不断出现。勒索软件给企业和个人带来的影响范围越来越广,危害性…

OrangePi AIpro小试牛刀-目标检测(YoloV5s)

非常高兴参加本次香橙派AI Pro,香橙派联合华为昇腾打造的一款AI推理开发板评测活动,以前使用树莓派Raspberry Pi4B 8G版本,这次有幸使用国产嵌入式开发板。 一窥芳容 这款开发板搭载的芯片是和华为昇腾的Atlas 200I DK A2同款的处理器&#…

插卡式仪器模块:数字万用表模块(插卡式)

• 6 位数字表显示 • 24 位分辨率 • 250 KSPS 采样率 • 电源和数字 I/O 均采用隔离抗噪技术 • 电压、电流、电阻、电感、电容的高精度测量 • 二极管/三极管测试 通道122输入 阻抗 电压10 MΩHigh-Z, 10 MΩ电流10 Ω50 mΩ / 2 Ω / 2 KΩ输入范围电压 5 V0–60 V电流…

钉钉统一授权登录第三方网站

开发流程 配置回调域名。 进入已创建的应用详情页,在基础信息页面可以查看到应用的SuiteKey/SuiteSecret(第三方企业应用)或AppKey/AppSecret(企业内部应用)。 在应用详情页,然后单击钉钉登录与分享,添加应用回调的URL,以http或…

一文学习yolov5 实例分割:从训练到部署

一文学习yolov5 实例分割:从训练到部署 1.模型介绍1.1 YOLOv5结构1.2 YOLOv5 推理时间 2.构建数据集2.1 使用labelme标注数据集2.2 生成coco格式label2.3 coco格式转yolo格式 3.训练3.1 整理数据集3.2 修改配置文件3.3 执行代码进行训练 4.使用OpenCV进行c部署参考文…

安全专业的硬件远控方案 设备无网也能远程运维

在很多行业中,企业的运维工作不仅仅局限在可以联网的IT设备,不能连接外网的特种设备也需要专业的远程运维手段。 这种特种设备在能源、医疗等行业尤其常见,那么我们究竟如何通过远程控制,对这些无网设备实施远程运维,…

单灯双控开关原理

什么是单灯双控?顾名思义,指的是一个灯具可以通过两个不同的开关或控制器进行控制。 例如客厅的主灯可能会设置成单灯双控,一个开关位于门口,另一个位于房间内的另一侧,这样无论你是从门口进入还是从房间内出来&#x…

Java数据结构准备工作---常用类

文章目录 前言1.包装类1.1.包装类基本知识1.2.包装类的用途1.3.装箱和拆箱1.3.1.装箱:1.3.2.拆箱 1.4 包装类的缓存问题 2.时间处理类2.1.Date 时间类(java.util.Date)2.2.DateFormat 类和 SimpleDateFormat 类2.3.Calendar 日历类 3.其他常用类3.1.Math类3.2.Rando…

(论文翻译)Coordinate Attention for Efficient Mobile Network Design(坐标注意力 CVPR2021)

Coordinate Attention for Efficient Mobile Network Design(CVPR2021) 文章目录 Coordinate Attention for Efficient Mobile Network Design(CVPR2021)摘要1.引言2.相关工作3.方法:Coordinate Attention3.1.Revisit …

(二)JSX基础

什么是JSX 概念:JSX是JavaScript和XML(HTML)的缩写,表示在JS代码中编写HTML模版结构,它是React中编写UI模板的方式。 优势:1.HTML的声明式模版方法;2.JS的可编程能力 JSX的本质 JSX并不是标准…

RabbitMQ-topic exchange使用方法

RabbitMQ-默认读、写方式介绍 RabbitMQ-发布/订阅模式 RabbitMQ-直连交换机(direct)使用方法 目录 1、概述 2、topic交换机使用方法 2.1 适用场景 2.2 解决方案 3、代码实现 3.1 源代码实现 3.2 运行记录 4、小结 1、概述 topic 交换机是比直连交换机功能更加强大的…

强!推荐一款开源接口自动化测试平台:AutoMeter-API !

在当今软件开发的快速迭代中,接口自动化测试已成为确保代码质量和服务稳定性的关键步骤。 随着微服务架构和分布式系统的广泛应用,对接口自动化测试平台的需求也日益增长。 今天,我将为大家推荐一款强大的开源接口自动化测试平台: AutoMete…

备战 清华大学 上机编程考试-冲刺前50%,倒数第5天

T1:多项式求和 小K最近刚刚习得了一种非常酷炫的多项式求和技巧,可以对某几类特殊的多项式进行运算。非常不幸的是,小K发现老师在布置作业时抄错了数据,导致一道题并不能用刚学的方法来解,于是希望你能帮忙写一个程序…

SpringBoot+Vue体育馆管理系统(前后端分离)

技术栈 JavaSpringBootMavenMySQLMyBatisVueShiroElement-UI 角色对应功能 学生管理员 功能截图

计算机网络 期末复习(谢希仁版本)第4章

路由器:查找转发表,转发分组。 IP网的意义:当互联网上的主机进行通信时,就好像在一个网络上通信一样,看不见互连的各具体的网络异构细节。如果在这种覆盖全球的 IP 网的上层使用 TCP 协议,那么就…

144、二叉树的前序递归遍历

题解: 递归书写三要素: 1)确定递归函数的参数和返回值。要确定每次递归所要用到的参数以及需要返回的值 2)确定终止条件。操作系统也是用栈的方式实现递归,那么如果不写终止条件或者终止条件写的不对,都…

从哲学层面谈稳定性建设

背景 我(姓名:黄凯,花名:兮之)在阿里工作了五年,一直在一个小团队从事电商的稳定性工作。看了很多稳定性相关的文档,很少有能把稳定性说明白的文档。也有一些文档也能把涉及的方方面面说清楚&a…

办理公司诉讼记录删除行政处罚记录删除

企业行政处罚记录是可以做到撤销消除的,一直被大多数企业忽略,如果相关诉讼记录得不到及时删除,不仅影响企业招投标,还影响企业的贷款申请,严重的让企业资金链断裂,影响企业长远发展和企业形象。行政处罚是…

配置网页版的SQL Developer : Oracle Database Actions

我们知道SQL Developer有三种形式: 桌面版,这个最常用命令行版,即SQLcl网页版,即SQL Developer Web,最新的名字叫Oracle Database Actions, 本文讲述3,如何配置SQL Developer网页版。 第一步…