Reid训练代码之数据集处理

news2024/12/29 9:38:14

本篇文章是对yolov5_reid这篇文章训练部分的详解。

该项目目录为:

.
|-- config # reid输入大小,数据集名称,损失函数等配置
|-- configs # 训练时期超参数定义
|-- data # 存储数据集和数据处理等代码,以及yolov5类别名称等
|-- engine # 训练和测试mAP,rank等相关代码
|-- layers # loss定义
|-- logs # 训练好的权重将存储在这
|-- modeling # 定义的网络
|-- output  # 输出
|-- person_search  # 人员查找
|-- readme.md # readme
|-- solver # 优化器相关代码
|-- tests
|-- tools  # 训练和测试代码
|-- utils  # logger等相关代码
`-- weights  # 存放预权重

数据集加载:

数据集加载与处理,需要调用头文件:

from data import make_data_loader

 make_data_loader

传入参数为cfg,训练中的相关配置文件。

build_transforms函数

这个函数传入函数有两个,cfg是配置文件,is_train=True表示训练。normalize_transform是计算数据集的均值和方差。均值为[0.485, 0.456, 0.406],方差为[0.229, 0.224, 0.225](可以看配置文件)。

如果is_train=True的时候,对数据集进行处理。

T.Resize:将图像调整为[256,128]大小;

T.RandomHorizontalFlip(p=cfg.INPUT.PROB):随机水平翻转,设置为0.5;

T.Pad:padding值,10;

T.ToTensor():转为tensor;

normalize_transform:图像的均值和方差;

RandomErasing:数据增强(随机擦除),将图片内的某块区域填充相同的像素值,从而将该区域的图片信息遮盖,强迫模型学习该区域外的特征进行识别,在一定程度上避免模型陷入局部最优,从而提高模型的泛化能力。

将多个变换组合在一起。

如果测试的时候,is_train=False,不用数据增强。

def build_transforms(cfg, is_train=True):
    normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
    if is_train:
        transform = T.Compose([
            T.Resize(cfg.INPUT.SIZE_TRAIN),
            T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
            T.Pad(cfg.INPUT.PADDING),
            T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
            T.ToTensor(),
            normalize_transform,
            RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN)
        ])
    else:
        transform = T.Compose([
            T.Resize(cfg.INPUT.SIZE_TEST),
            T.ToTensor(),
            normalize_transform
        ])

    return transform

 继续返回make_data_loader函数。

通过build_transforms仅仅返回的是训练和测试需要用的一些数据处理方面的"规则"。

train_transforms = build_transforms(cfg, is_train=True)
val_transforms = build_transforms(cfg, is_train=False)

 num_workers:获取进程数量,我这里是4.

num_workers = cfg.DATALOADER.NUM_WORKERS

init_dataset函数 

传入参数name:数据集的名称,我这里是mark1501;

还传入了数据集的路径:我这里是./data

该函数主要是判断支持的数据集格式。

def init_dataset(name, *args, **kwargs):
    if name not in __factory.keys():
        raise KeyError("Unknown datasets: {}".format(name))
    return __factory[name](*args, **kwargs)

继续看make_data_loader函数。

训练时分类的数量,这里是751。注意!在训练的时候是751,在测试的是1501.

num_classes = dataset.num_train_pids

 ImageDataset函数

该类基础Dataset,因此说明该类是做数据集处理的。上面我们说到的build_transforms仅仅是一些数据集处理的"规则"。

在调用该类的时候,传入两个参数,一个是dataset.train[训练数据集的图片路径],另一个就是train_transforms[处理的规则]。所以这个类就知道了,是用上面定义的“规则”来处理我们的数据集。

在下面这段代码中self.dataset[index]就是对数据集遍历(__getitem__就是迭代器),加入此时index=0,此时获得为:('./data\\Market1501\\bounding_box_train\\0002_c1s1_000451_03.jpg', 0, 0)。img_path就为数据集的路径,pid为类,camid为相机id[这个需要了解markt1501数据集]。

read_image函数就是通过PIL读取的图像。然后用transform处理。返回值有四个,img[数据增强后的图像],pid[类别],camid[相机id],img_path[图像路径]。

class ImageDataset(Dataset):
    """Image Person ReID Dataset"""

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_path, pid, camid = self.dataset[index]
        img = read_image(img_path)

        if self.transform is not None:
            img = self.transform(img)

        return img, pid, camid, img_path

下面的两个图就是增强后的效果 

 

 


接下来再回到make_data_loader函数。

下面一段代码是对处理后的数据集进行加载,这里调用的torch中DataLoader函数。传入的参数有batch,我这里是8,shuffle表示打乱,collate_fn这个很重要,就是把这些按batch处理。

    if cfg.DATALOADER.SAMPLER == 'softmax':
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
            collate_fn=train_collate_fn
        )
    else:
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
            sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
            num_workers=num_workers, collate_fn=train_collate_fn
        )

同理,验证集也是这样处理。

最终返回训练集,验证集,数据集长度(数量),类别:751

完整代码

def make_data_loader(cfg):
    train_transforms = build_transforms(cfg, is_train=True)
    val_transforms = build_transforms(cfg, is_train=False)
    num_workers = cfg.DATALOADER.NUM_WORKERS
    if len(cfg.DATASETS.NAMES) == 1:
        dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)
    else:
        # TODO: add multi dataset to train
        dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)

    num_classes = dataset.num_train_pids
    train_set = ImageDataset(dataset.train, train_transforms)
    if cfg.DATALOADER.SAMPLER == 'softmax':
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
            collate_fn=train_collate_fn
        )
    else:
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
            sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
            num_workers=num_workers, collate_fn=train_collate_fn
        )

    val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
    val_loader = DataLoader(
        val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
    )
    return train_loader, val_loader, len(dataset.query), num_classes

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/461502.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【高分论文密码】大尺度空间模拟预测与数字制图技术

大尺度空间模拟预测和数字制图技术和不确定性分析广泛应用于高分SCI论文之中,号称高分论文密码。 大尺度模拟技术可以从不同时空尺度阐明农业生态环境领域的内在机理和时空变化规律,又可以为复杂的机理过程模型大尺度模拟提供技术基础。 在本次&#x…

cocosLua 之 RichText(1)

结构 富文本主要通过RichText来实现, 其继承结构: #mermaid-svg-AHbMrHe3zp3q1wTZ {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-AHbMrHe3zp3q1wTZ .error-icon{fill:#552222;}#mermaid-svg-AHbMrHe3z…

Linux下ds18b20驱动开发获取温度

文章目录 一、修改并且编译设备树(1)修改设备树(2)修改开发板设备树进行reboot 二、硬件连接三、驱动开发与测试(1)编写设备驱动(2)编写测试代码(3)Makefile&…

第四章——数学知识1

质数 质数&#xff1a;在大于1的整数中&#xff0c;如果只包含1和本身这俩个约束&#xff0c;就被叫质数或素数。 质数判定试除法 质数的判定——试除法&#xff1a;如果d能整除n&#xff0c;则n/d再除n&#xff0c;结果是一个整数。 d≤n/d。 bool is_prime(int x) {if (x <…

【大数据之Hadoop】二十、Yarn基础框架及工作机制

1、Yarn基础框架 Yarn是一个资源调度平台&#xff0c;负责为运算程序提供服务器运算资源&#xff0c;相当于一个分布式的操作系统平台&#xff0c;而MapReduce等运算程序则相当于运行于操作系统之上的应用程序。 YARN主要由ResourceManager、NodeManager、ApplicationMaster和…

202303-1 田地丈量

代码 #include<iostream> #include<vector> #include<string> #include<cmath> #include<algorithm> #include<stack> using namespace std; int n, a, b;int main() {cin >> n >> a >> b;int x1, y1, x2, y2;int x, y;…

科学计算NumPy之Ndarray数组对象的创建、切片、索引、修改等操作汇总

NumPy的操作汇总 NumPy概述Ndarray对象基本使用Ndarray的属性Ndarray的类型Ndarray的形状 创建数组创建数组创建全1数组创建全1数组从已有数组创建新数组从现有数组生成创建等差数列数组创建等比数列数组创建等间隔数列数组创建随机数数组创建正态分布创建创建均匀分布 数组切片…

【JUC高并发编程】—— 再见JUC

一、读写锁 读写锁概述 1️⃣ 什么是读写锁&#xff1f; 读写锁是一种多线程同步机制&#xff0c;用于在多线程环境中保护共享资源的访问 与互斥锁不同的是&#xff0c;读写锁允许多个线程同时读取共享资源&#xff0c;但在有线程请求写操作时&#xff0c;必须将其他读写锁…

windows10 ubuntu子系统安装perf工具

文章目录 1&#xff0c;ubuntu子系统中perf工具安装不了1.1&#xff0c;查看perf版本如下所示1.2&#xff0c;网上找不到对应的版本的内核源码&#xff0c;下载别的版本后&#xff0c;编译各种报错 2&#xff0c;百度查到说是WSL1不支持perf2.1&#xff0c;查看WSL版本 2.2&…

MySQ基础知识整合

目录 模糊查询 排序 单行函数 多行函数 分组函数 having 单表查询执行顺序总结 distinct 连接查询 子查询 union limit DQL语句执行顺序 DDL语句 日期化 date和date_format区别 update table 的快速创建以及删除&#xff08;及回滚&#xff09; 约束 事务 …

HTTP基础知识汇总

伴随着云原生(Cloud Native)的兴起&#xff0c;面向服务架构(Service-Oriented Architecture&#xff0c;SOA)、微服务(Microservice)、容器(Container)等相关概念与技术正在逐渐影响CAx(CAD/CAE/CAM)软件的架构设计与开发。 在云原生CAx软件中&#xff0c;首先需要把系统按照…

MySQL锁详解及案例分析

MySQL锁详解及案例分析 一、一条update语句二、MySQL锁介绍三、全局锁全局锁演示1.环境准备2.全局锁演示 四、MySQL表级锁&#xff08;都是Server层实现&#xff09;1、表级锁介绍2、表读S、写锁X1&#xff09;表锁相关命令2&#xff09;表锁演示1、表级的共享锁(读锁)2、表级的…

VLAN实验

SW1 [sw1]int g0/0/2 [sw1-GigabitEthernet0/0/2]dis this interface GigabitEthernet0/0/2 port link-type access port default vlan 2 pc1划分到vlan 2 [sw1-GigabitEthernet0/0/3]dis t…

【C++STL】set

前言 前面的CSTL的博客&#xff0c;我们介绍了string&#xff0c;vector&#xff0c;list&#xff0c;deque&#xff0c;priority_queue还有stack和queue。 这些容器统称为序列式容器&#xff0c;因为其底层为线性序列的数据结构&#xff0c;里面存储的是元素本身。 而从本节开…

【正点原子STM32精英V2开发板体验】体验LVGL的SD NAND文件系统

目的 验证基于SD NAND卡在正点原子STM32精英V2开发板上的兼容效果 实验材料 正点原子STM32精英V2开发板 TF 卡一片 SD NAND卡一片 实验步骤 1、打开例程【正点原子】精英STM32F103开发板 V2-资料盘(A盘)\4&#xff0c;程序源码\3&#xff0c;扩展例程\4&#xff0c;LVGL…

数据库系统概论(二)关系数据库,SQL概述和数据库安全性

作者的话 前言&#xff1a;总结下知识点&#xff0c;自己偶尔看一看。 目录 一、关系模型概述 1.1关系数据结构及形式化定义 1.1.1域&#xff08;Domain&#xff09; 1.1.2笛卡尔积&#xff08;Cartesian Product&#xff09; 1.1.3关系&#xff08;Relation&#xff09; …

Android之 WebView的使用

一 简介 1.1 WebView是用来展示网页的控件&#xff0c;底层是google的WebKit的引擎。 比起苹果的WebView&#xff0c;webkit一些不足地方&#xff1a; 不能支持word等文件的预览 纯标签加载&#xff0c;并不支持所有标签的加载 不支持文件的下载&#xff0c;图片的放大&#…

MATLAB连续时间信号的实现和时域基本运算(八)更新中...

1、实验目的&#xff1a; 1&#xff09;熟悉常用连续时间信号的实现方法&#xff1b; 2&#xff09;掌握连续时间信号的时域基本运算&#xff1b; 3&#xff09;掌握实现基本函数及其运算的函数的使用方法&#xff1b; 4&#xff09;加深对信号基本运算的理解。 2、实验内容&am…

Python Selenium 关键字驱动

目录 项目目录结构 action目录 config目录 exceptionpictures目录 log目录 testCases目录 testData目录 util目录 总结 之前写过一篇Java版的关键字驱动&#xff0c;现在来写一篇Python版本的&#xff0c;网上好多教程都是虎头蛇尾的不完整~ 说下思路&#xff0c;这边没…

十、ElasticSearch 实战 - 源码运行

一、概述 想深入理解 Elasticsearch&#xff0c;了解其报错机制&#xff0c;并有针对性的调整参数&#xff0c;阅读其源码是很有必要的。此外&#xff0c;了解优秀开源项目的代码架构&#xff0c;能够提高个人的代码架构能力 阅读 Elasticsearch 源码的第一步是搭建调试环境&…