Reid strong baseline 代码详解

news2025/1/18 20:10:07

本项目是对Reid strong baseline代码的详解。项目暂未加入目标检测部分,后期会不定时更新,请持续关注。

本相比Reid所用数据集为Markt1501,支持Resnet系列作为训练的baseline网络。训练采用表征学习+度量学习的方式。

目录

训练参数

训练代码

create_supervised_trainer(创建训练函数)

create_supervised_evaluator(创建测试函数)

do_train代码 

训练期权重的保存

获得Loss和acc

获取初始epoch 

学习率的调整

loss和acc的打印

时间函数的打印

测试结果的打印以及权重的保存

完整代码 

测试

Reid相关资料学习链接

项目代码:

后期计划更新


训练参数

--last stride:作为Resnet 最后一层layer的步长,默认为1;

--model_path:预训练权重

--model_name:模型名称,支持Resnet系列【详见readme】

--neck:bnneck

--neck_feat:after

--INPUT_SIZE:[256,128],输入大小

--INPUT_MEAN:

--INPUT_STD:

--PROB: 默认0.5

--padding:默认10

--num_workers:默认4【根据自己电脑配置来】

--DATASET_NAME:markt1501,数据集名称

--DATASET_ROOT_DIR:数据集根目录路径

--SAMPLER: 现仅支持softmax_triplet

--IMS_PER_BATCH:训练时的batch size

--TEST_IMS_PER_BATCH:测试时的batch size

--NUM_INSTANCE:一个batch中每个ID用多少图像,默认为4

--OPTIMIZER_NAME:优化器名称,默认为Adam,支持SGD

--BASE_LR:初始学习率,默认0.00035

--WEIGHT_DECAY:权重衰减

--MARGIN:用于tripletloss,默认0.3

--IF_LABELSMOOTH:标签平滑

--OUTPUT_DIR:权重输出路径

--DEVICE:cuda or cpu

--MAX_EPOCHS:训练迭代次数,默认120

训练代码

def train(args):
    # 数据集
    train_loader, val_loader, num_query, num_classes = make_data_loader(args)
    # model
    model = build_model(args, num_classes)
    # 优化器
    optimizer = make_optimizer(args, model)
    # loss
    loss_func = make_loss(args, num_classes)
    start_epoch = 0
    scheduler = WarmupMultiStepLR(optimizer, args.STEPS, args.GAMMA, args.WARMUP_FACTOR,
                                  args.WARMUP_ITERS, args.WARMUP_METHOD)
    print('ready train~')
    do_train(args,
             model,
             train_loader,
             val_loader,
             optimizer,
             scheduler,
             loss_func,
             num_query,
             start_epoch)

上述代码中所用处理数据集函数make_data_loader可以参考我另一篇文章:

Reid数据集处理代码详解

在看do_train前需要先看以下内容

log_period表示为打印Log的周期,默认为1;

checkpoint_period:表示为保存权重周期,默认为1;

output_dir:输出路径

device:cuda or cpu

epochs:训练迭代轮数

代码中的create_supervised_trainercreate_supervised_evaluator两个函数,是分别是用来创建监督训练和测试的,是对ignite.engine内训练和测试方法的重写。

create_supervised_trainer(创建训练函数)

规则是在内部实现一个def _update(engine,batch)方法,最后返回Engine(_update)。代码如下。

'''
ignite是一个高级的封装训练和测试库
'''
def create_supervised_trainer(model, optimizer, loss_fn, device=None):
    """
    :param model:  (nn.Module) reid model to train
    :param optimizer:Adam or SGD
    :param loss_fn: loss function
    :param device: gpu or cpu
    :return: Engine
    """
    if device:
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        img, target = batch
        img = img.to(device) if torch.cuda.device_count() >= 1 else img
        target = target.to(device) if torch.cuda.device_count() >= 1 else target
        score, feat = model(img)  # 采用表征+度量
        loss = loss_fn(score, feat, target)  # 传入三个值,score是fc层后的(hard),feat是池化后的特征,target是标签
        loss.backward()
        optimizer.step()
        # compute acc
        acc = (score.max(1)[1] == target).float().mean()
        return loss.item(), acc.item()
    return Engine(_update)

create_supervised_evaluator(创建测试函数)

同理,测试代码也是一样,如下,其中metrics是我们需要计算的评价指标

# 重写create_supervised_evaluator,传入model和metrics,metrics是一个字典用来存储需要度量的指标
def create_supervised_evaluator(model, metrics, device=None):
    if device:
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.to(device)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            data, pids, camids = batch
            data = data.to(device) if torch.cuda.is_available() else data
            feat = model(data)
            return feat, pids, camids
    engine = Engine(_inference)
    for name, metric in metrics.items():
        metric.attach(engine, name)
    return engine

do_train代码 

然后看一下do_train中的代码。

训练期权重的保存

这里的trainer就是我们前面创建的监督训练的函数,给该实例添加事件,事件为在每次epoch结束的时候保存一次权重[注意这里保存的权重是将模型的完整结构以及优化器权重都保存下来了]

trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
                                                                     'optimizer': optimizer})

获得Loss和acc

# average metric to attach on trainer
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

获取初始epoch 

训练前获取开始的epoch,默认为0;

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

学习率的调整

在训练期间每个epoch开始的时候,会调整学习率

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

loss和acc的打印

该事件发生在每个iteration完成时,而不是epoch完成时。

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1

        if ITER % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                        .format(engine.state.epoch, ITER, len(train_loader),
                                engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
                                scheduler.get_lr()[0]))
        if len(train_loader) == ITER:
            ITER = 0

时间函数的打印

该函数是用来在每个epoch完成的时候打印一下用了多长时间

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

测试结果的打印以及权重的保存

该函数用来打印测试结果,比如mAP,Rank,测试后的权重会保存在logs下。命名形式为mAP_xx.pth。【注意这里保存我权重和上面保存的权重是不一样的,这里仅仅保存权重,不包含网络结构和优化器权重】

     @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            text = "mAP:{:.1%}".format(mAP)
            # logger.info("mAP: {:.1%}".format(mAP))
            logger.info(text)
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
            torch.save(state_dict, 'logs/mAP_{:.1%}.pth'.format(mAP))
            return cmc, mAP

完整代码 

def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        loss_fn,
        num_query,
        start_epoch
):
    log_period = 1
    checkpoint_period = 1
    eval_period = 1
    output_dir = cfg.OUTPUT_DIR
    device = cfg.DEVICE
    epochs = cfg.MAX_EPOCHS
    print("Start training~")
    trainer = create_supervised_trainer(model, optimizer, loss_fn, device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm='yes')},
                                            device=device)
    checkpointer = ModelCheckpoint(output_dir, cfg.model_name, checkpoint_period, n_saved=10, require_empty=False)
    state_dict = model.state_dict()
    timer = Timer(average=True)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
                                                                     'optimizer': optimizer})
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')


    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1

        if ITER % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                        .format(engine.state.epoch, ITER, len(train_loader),
                                engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
                                scheduler.get_lr()[0]))
        if len(train_loader) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            text = "mAP:{:.1%}".format(mAP)
            # logger.info("mAP: {:.1%}".format(mAP))
            logger.info(text)
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
            torch.save(state_dict, 'logs/mAP_{:.1%}.pth'.format(mAP))
            return cmc, mAP

    trainer.run(train_loader, max_epochs=epochs)

训练命令如下: 

python tools/train.py --model_name resnet50_ibn_a --model_path weights/ReID_resnet50_ibn_a.pth --IMS_PER_BATCH 8 --TEST_IMS_PER_BATCH 4 --MAX_EPOCHS 120

会出现如下形式: 

 

=> Market1501 loaded
Dataset statistics:
  ----------------------------------------
  subset   | # ids | # images | # cameras
  ----------------------------------------
  train    |   751 |    12936 |         6
  query    |   750 |     3368 |         6
  gallery  |   751 |    15913 |         6
  ----------------------------------------
  
2023-05-15 14:30:55.603 | INFO     | engine.trainer:log_training_loss:119 - Epoch[1] Iteration[227/1484] Loss: 6.767, Acc: 0.000, Base Lr: 3.82e-05
2023-05-15 14:30:55.774 | INFO     | engine.trainer:log_training_loss:119 - Epoch[1] Iteration[228/1484] Loss: 6.761, Acc: 0.000, Base Lr: 3.82e-05
2023-05-15 14:30:55.946 | INFO     | engine.trainer:log_training_loss:119 - Epoch[1] Iteration[229/1484] Loss: 6.757, Acc: 0.000, Base Lr: 3.82e-05
2023-05-15 14:30:56.134 | INFO     | engine.trainer:log_training_loss:119 - Epoch[1] Iteration[230/1484] Loss: 6.760, Acc: 0.000, Base Lr: 3.82e-05
2023-05-15 14:30:56.305 | INFO     | engine.trainer:log_training_loss:119 - Epoch[1] Iteration[231/1484] Loss: 6.764, Acc: 0.000, Base Lr: 3.82e-05

每个epoch训练完成后会测试一次mAP:

我这里第一个epoch的mAP达到75.1%,Rank-1:91.7%, Rank-5:97.2%, Rank-10:98.2%。

测试完成后会在log文件下保存一个pth权重,名称为mAPxx.pth,也是用该权重进行测试。

 

2023-05-15 14:35:59.753 | INFO     | engine.trainer:print_times:128 - Epoch 1 done. Time per batch: 261.820[s] Speed: 45.4[samples/s]
2023-05-15 14:35:59.755 | INFO     | engine.trainer:print_times:129 - ----------
The test feature is normalized
2023-05-15 14:39:51.025 | INFO     | engine.trainer:log_validation_results:137 - Validation Results - Epoch: 1
2023-05-15 14:39:51.048 | INFO     | engine.trainer:log_validation_results:140 - mAP:75.1%
2023-05-15 14:39:51.051 | INFO     | engine.trainer:log_validation_results:142 - CMC curve, Rank-1  :91.7%
2023-05-15 14:39:51.051 | INFO     | engine.trainer:log_validation_results:142 - CMC curve, Rank-5  :97.2%
2023-05-15 14:39:51.052 | INFO     | engine.trainer:log_validation_results:142 - CMC curve, Rank-10 :98.2%

测试

测试代码在tools/test.py中,代码和train.py差不多,这里不再细说,该代码是可对评价指标进行测试复现。

命令如下:其中TEST_IMS_PER_BATCH是测试时候的batch size,model_name是网络名称,model_path是你训练好的权重路径。

python tools/test.py --TEST_IMS_PER_BATCH 4 --model_name [your model name] --model_path [your weight path]

Reid相关资料学习链接

Reid损失函数理论讲解:Reid之损失函数理论学习讲解_爱吃肉的鹏的博客-CSDN博客

Reid度量学习Triplet loss代码讲解:Reid度量学习Triplet loss代码解析。_爱吃肉的鹏的博客-CSDN博客

yolov5 reid项目(支持跨视频检索):yolov5_reid【附代码,行人重识别,可做跨视频人员检测】_yolov5行人重识别_爱吃肉的鹏的博客-CSDN博客

yolov3 reid项目(支持跨视频检索):ReID行人重识别(训练+检测,附代码),可做图像检索,陌生人检索等项目_爱吃肉的鹏的博客-CSDN博客


预权重链接:

链接:百度网盘 请输入提取码 提取码:yypn


项目代码:

GitHub - YINYIPENG-EN/reid_strong_baselineContribute to YINYIPENG-EN/reid_strong_baseline development by creating an account on GitHub.https://github.com/YINYIPENG-EN/reid_strong_baseline

后期计划更新

​ 1.引入知识蒸馏训练

​ 2.加入YOLOX进行跨视频检测

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/532550.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端开发之this.$options.data的使用

前端开发之this.$options.data的使用 前言效果图vue2中使用vue3中使用 前言 this.$options.data:初始化对象 效果图 vue2中使用 this.$options这是一个Vue的特性&#xff0c;它可以让你访问组件的选项对象。你可以使用this.$options.data.call(this)来获取组件的初始数据&am…

pytorch模型转ONNX

目录 1. ONNX 2. pytorch 转 ONNX 3. 加载 ONNX 文件 4. Netron 1. ONNX 一般来说&#xff0c;pytorch训练好的模型是不能够直接用于生产环境&#xff0c;有很多的地方没有优化 而ONNX 格式可以兼顾不同框架的模型&#xff0c;相当于一个中间人的角色。这样部署到不同的环…

msvcr120.dll丢失怎样修复,学这三招就可以修复好

年前才买的新电脑&#xff0c;今天在打开软件ps软件的时候&#xff0c;电脑就提升msvcr120.dll文件丢失&#xff0c;无法执行此代码。刚刚开始以为是电脑的系统没有装好&#xff0c;经过我一下午时间的研究&#xff0c;原来是电脑msvcr120.dll文件丢失一般都是下载到垃圾软件&a…

软件测试后浪太强了,前浪有点顶不住啊,真难受...

想和大家说的话 8年前军哥刚进入到IT行业&#xff0c;现在发现学习软件测试的人越来越多&#xff0c;今天想根据军哥的行业经验再结合自己的一些看法给大家提一些建议。 最近聊到软件测试的行业内卷&#xff0c;越来越多的转行和大学生进入测试行业&#xff0c;导致软件测试已…

证件照片如何换背景底色,3个免费制作证件照的方法,简单易学

在日常生活中&#xff0c;我们经常需要用到证件照&#xff0c;比如&#xff1a;找工作需要简历上附带有证件照&#xff0c;还有办理学生证、身份证也需要提交证件照。 不同的平台有时候提交的要求&#xff08;背景底色、大小等&#xff09;也不一样&#xff0c;如果你不想每次…

缺少dll文件怎么办?修复dll文件的多种方法

缺少dll文件怎么办&#xff1f;当您试图启动某个应用程序或游戏时&#xff0c;可能会遇到“缺少DLL文件”的错误提示。DLL文件是动态链接库文件的缩写&#xff0c;它们包含在计算机上的许多应用程序和游戏中&#xff0c;并且是确保这些应用程序和游戏正常运行的重要部分。当出现…

秒懂!项目安全问题-SM4加解密

项目安全问题一直被人们研究&#xff0c;当前端路径上通过?status这种拼接参数时&#xff0c;参数的值在浏览器路径栏上非常醒目&#xff0c;是很容易被人恶意修改的&#xff0c;比如该用户并没有编辑权限&#xff0c;但有心之人却可以通过修改参数status的值把see改成edit&am…

【DataX】将hive表数据导入ES

目录 一、环境 二、创建hive测试表 三、Es写入插件包 四、配置json 五、数据同步 1、执行命令 2、查看es结果 一、环境 DataX&#xff1a;windows安装 Es版本&#xff1a;7.9.0 二、创建hive测试表 CREATE TABLE teacher(name string,age int )row format del…

抖音未来的发展趋势|成都欢蓬信息

抖音未来的发展趋势&#xff0c;近年来随着互联网技术的发展&#xff0c;小视频app也逐渐走入大家的日常生活中&#xff0c;闲着的时候打开手机抖音APP&#xff0c;就可以刷到世界各地人们分享的视频和直播&#xff0c;下面一起看看抖音未来的发展趋势 一、抖音的现状   据权…

【 五子棋对战平台(java_gobang) 】

文章目录 一、核心功能及技术二、效果演示三、创建项目扩展&#xff1a;WebSocket 框架知识 四、需求分析和概要设计五、数据库设计与配置 Mybatis六、实现用户模块功能6.1 数据库代码编写6.2 前后端交互接口6.3 服务器开发6.4 客户端开发 七、实现匹配模块功能7.1 前后端交互接…

构建新一代智慧园区移动应用以推动数字转型

随着智慧城市的建设和智慧园区的崛起&#xff0c;智慧园区数字一体化建设成为园区发展的重心&#xff0c;当然数字转型离不开移动应用的整合服务。 在过去的几年中&#xff0c;智慧园区移动应用已经发展成为园区管理和服务的重要手段之一&#xff0c;为企业和员工提供了更加便…

知行之桥EDI系统2023版功能介绍——概览页面

登录知行之桥EDI系统2023版&#xff0c;即可看到概览页面。默认情况下&#xff0c;会显示过去7天的各项指标。用户可以在页面右上角&#xff1a;显示过去的数据 下拉列表中手动选择需要的时间段&#xff0c;如&#xff1a;24小时、3天、7天等。 关键指标的自定义配置 概览页面…

如何将exe注册为windows服务,直接从后台运行

如何将exe注册为windows服务&#xff0c;直接从后台运行 使用instsrvsrvanywindow64位系统安装配置 window32位系统安装 使用instsrvsrvany 这是地址&#xff1a;链接: 网盘地址 提取码: h2za 复制这段内容后打开百度网盘手机App&#xff0c;操作更方便哦 window64位系统 安…

[OOD设计] - 电梯系统设计

明确主要需求 首先需要设计电梯系统的基本工作流程&#xff0c;一个简单电梯系统主要就是两个主要功能&#xff1a; 乘客在电梯外按下按钮时&#xff0c;电梯系统会驱动一个电梯来接人乘客在电梯内部按下楼层按钮时&#xff0c;电梯系统会驱动该电梯到达指定楼层 根据需求来…

泰克Tektronix AFG31021 任意波函数发生器产品资料

AFG31021是一款高质量、多功能的任意波形发生器&#xff0c;可以生成高精度、高分辨率的波形信号。该产品的主要特点包括&#xff1a; 可以生成任意波形信号&#xff0c;内置多种标准波形&#xff0c;如正弦波、方波、三角波、锯齿波等&#xff0c;也可以通过用户自定义来生成…

支付从业者转型路在何方?

近来&#xff0c;整个支付行业&#xff0c;已经“卷”出了新高度。 营销上电销卷地推&#xff0c;工单卷电销&#xff0c;POS机具则是退押金卷不退押金&#xff0c;无押金卷退押金”&#xff0c;互相“卷”得不亦乐乎。 与此同时&#xff0c;支付圈子里聊的永远是“成本上升”…

微信小程序是怎么做的?

微信小程序是一种轻量级的应用&#xff0c;它可以在微信内部直接使用&#xff0c;无需下载和安装。那么&#xff0c;微信小程序是怎么做的呢&#xff1f; 微信小程序制作的大概步骤 微信小程序制作主要包括以下几个步骤&#xff1a; ①注册小程序账号 ②在小程序制作工具创…

5th-Generation Mobile Communication Technology(一)

目录 一、5G/NR 1、 快速参考&#xff08;Quick Reference&#xff09; 2、5G Success 3、5G Challenges 4、Qualcomm Videos 二、PHY and Protocol 1、Frame Structure 2、Numerology 3、Waveform 4、Frequency Band 5、BWP 6、Synchronization 7、Beam Management 8、CSI Fra…

matmul/mm 函数用法介绍

介绍torch.matmul之前先介绍torch.mm函数, mm和matmul都是torch中矩阵乘法函数&#xff0c;mm只能作用于二维矩阵&#xff0c;matmul可以作用于二维也能作用于高维矩阵 mm函数使用 x torch.rand(4, 9) y torch.rand(9, 8) print(torch.mm(x,y).shape)torch.Size([4, 8]) m…

Linux Shell 实现一键部署VMware Workstation

VMware Workstation 前言 VMware Workstation Pro 是业界标准的桌面 Hypervisor&#xff0c;用于在 Linux 或 Windows PC 上运行虚拟机 download VMware_Workstation VMware_Workstation WindowsVMware_Workstation linux文档downloaddownload参考 Linux 各系统下载使用参…