[杂记]mmdetection3.x中的数据流与基本流程详解(数据集读取, 数据增强, 训练)

news2024/11/24 20:53:14

之前跑了一下mmdetection 3.x自带的一些算法, 但是具体的代码细节总是看了就忘, 所以想做一些笔记, 方便初学者参考. 其实比较不能忍的是, 官网的文档还是空的…

在这里插入图片描述

这次想写其中的数据流是如何运作的, 包括从读取数据集的样本与真值, 到数据增强, 再到模型的forward当中.


0. MMDetection整体组成部分

让我们首先回顾一下C++的标准模板库(STL)是怎样设计的. STL的三个核心组件是容器, 算法与迭代器. 容器, 例如vector, queue等等, 他们是负责存储数据的, 算法是负责进行一些操作, 例如排序, 查找等等. 而迭代器是容器与算法之间的桥梁, 也就是算法可以通过迭代器去访问容器, 使得算法可以独立于容器的类型进行操作. 三个部分相辅相成, 就达到了泛型编程的理念.

再让我们回顾一下一套深度学习的代码包含什么部分. 从大的方面来说, 需要有数据的读取与增强(DataLoader), 模型的定义, 损失函数的计算, 负责梯度传播的优化器, 在验证(测试)集上的评估等. 同理, MMDetection也是按照这种方式来的, 并且每个部分接口相通, 就可以实现更广义的模型定义和训练方式.

mmengine/registry/__init__.py中, 我们可以看到, MMEngine(或者说MMDetection)总体有这些类型的模块:

from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS, HOOKS,
                   INFERENCERS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS,
                   MODELS, OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                   OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
                   TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
                   WEIGHT_INITIALIZERS)

那么以上这么多模块可以分成几类, 分别负责什么呢? 按照我个人的理解, MMDetection的整体组成部分可以表示为下图:

在这里插入图片描述

为了节省空间, 优化器相关并未画出

1. 认识config文件

mmdetection设计的核心思想是通过字典来配置整个的训练过程和模型定义, 这些字典放在一个.py的config文件中. 一般来说,config文件最重要的就是数据加载(train_dataloader, val_dataloader和test_dataloader), 模型定义(model)和训练与测试过程(train_pipeline, test_pipeline). 除此之外, 还有一些训练, 测试配置(train_cfg, test_cfg)等等. 具体config的例子可以参照官网Learn about configs.

需要注意的是, mmdetection中字典定义class的方式, 往往是键type表示类的名字, 之后的其他键都是类初始化需要的参数. 例如, 如果我想自定义一个模型, 叫做MyModel, 定义在当前目录下的./models/my_model.py中, 定义方式如下:


from mmdet.registry import MODELS  # 自定义模型, 需要在模型库中"注册", 初始化时才能找到定义
from mmdet.models.mot.base import BaseMOTModel  # 一个模型基类

@MODELS.register_module()  # 装饰器 在模型库中"注册"
class MyModel(BaseMOTModel):
    def __init__(self, 
                 arg1=..., 
                 arg2=..., 
                 arg3=...):
        ...
	def loss(self, inputs, data_samples):  # 前向传播, inputs是输入tensor, data_samples是包含标签的列表
		...

如果按上述方式定义了模型, 那么在我们的配置文件中, 就是这个样子:


# 必须将自定义类的py文件导入 这样可以自动register自定义模型 否则模型初始化时找不到

custom_imports = dict(
    imports=['models.my_model'],
    allow_failed_imports=False)

# 现在就可以愉快的传参了
models=dict(
    type='MyModel', 
    arg1=1, 
    arg2=[16, 128], 
    arg3=dict(channel=256), 
    ...
)

同样, 我们可以自定义DataLoader, Loss, 等等.

此外, dict是可以嵌套的, 例如mmdetection将检测模型分成了backbone, neck和head三部分, 那么如果我们又自定义了一个Head, 叫MyHead:


from mmdet.registry import MODELS  # 自定义模型, 需要在模型库中"注册", 初始化时才能找到定义
from mmengine.model import BaseModule  # 一个模型基类

@MODELS.register_module()  # 装饰器 在模型库中"注册"
class MyHead(BaseModule):
    def __init__(self, 
                 arg4=...):
        ...

这样, 如果MyModel的前向传播过程中需要一个head, 则代码大致是这个样子:


from mmdet.registry import MODELS  # 自定义模型, 需要在模型库中"注册", 初始化时才能找到定义
from mmdet.models.mot.base import BaseMOTModel  # 一个模型基类

@MODELS.register_module()  # 装饰器 在模型库中"注册"
class MyModel(BaseMOTModel):
    def __init__(self, 
                 arg1=..., 
                 arg2=..., 
                 arg3=...,
                 head=...):
        self.head = MODELS.build(head)  # 建立Head的模型, 类型是nn.Module
        ...
	def loss(self, inputs, data_samples):  # 前向传播, inputs是输入tensor, data_samples是包含标签的列表
		...  # 一些其他过程
		ret = self.head(inputs)  # forward
		...  # 后处理

配置文件中对应更改为:

如果按上述方式定义了模型, 那么在我们的配置文件中, 就是这个样子:


custom_imports = dict(
    imports=['models.my_model', '自定义HEAD所在的py文件'],
    allow_failed_imports=False)

models=dict(
    type='MyModel', 
    arg1=1, 
    arg2=[16, 128], 
    arg3=dict(channel=256), 
    head=dict(  # 定义head
        type='MyHead',
        arg4=256,
        ...
	)
    ...
)

篇幅所限, 自定义损失函数, 数据增强之类的就不一一列举了.

2. 数据流

我们接下来以检测与跟踪任务为例, 看看数据到底是如何被读入的. 我们以训练过程说明.

在训练过程中, 我们会初始化一个RUNNER类, 其读入我们的config文件并依次完成各种(模型, 数据加载, 优化器, 钩子等等)的初始化. 我们以官方提供的train.py为例:

runner = Runner.from_cfg(cfg)

from_cfg()是一个类方法(classmethod), 在其中我们实例化了Runner类.

随后, 我们调用Runnertrain()方法进行训练. 首先, 我们实例化训练循环:

        self._train_loop = self.build_train_loop(
            self._train_loop)  # type: ignore

训练循环就属于LOOP类型.

在这里, 我们以最常用的EpochBasedTrainLoop为例. 在EpochBasedTrainLoop的初始化函数中, 根据config文件中的train_dataloader字典实例化出torchDataLoader类():
在这里插入图片描述

        data_loader = DataLoader(
            dataset=dataset,
            sampler=sampler if batch_sampler is None else None,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn,
            worker_init_fn=init_fn,
            **dataloader_cfg)
        return data_loader

当然, 我们知道torch的DataLoader类在调用的时候, 会调用到dataset(类别是torch.utils.data.Dataset)的__getitem__方法. 因此, 我们从__getitem__入手来探索数据流.

在MMDetection的设计中, 数据集的类都是继承于MMengine中的BaseDataset, 其中的__getitem__是这样写的:
在这里插入图片描述

    def __getitem__(self, idx: int) -> dict:
        if not self._fully_initialized:
            print_log(
                'Please call `full_init()` method manually to accelerate '
                'the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()

        if self.test_mode:
            data = self.prepare_data(idx)
            if data is None:
                raise Exception('Test time pipline should not get `None` '
                                'data_sample')
            return data

        for _ in range(self.max_refetch + 1):
            data = self.prepare_data(idx)
            # Broken images or random augmentations may cause the returned data
            # to be None
            if data is None:
                idx = self._rand_another()
                continue
            return data

        raise Exception(f'Cannot find valid image after {self.max_refetch}! '
                        'Please check your image path and pipeline')

我们可以看到, 在__getitem__中最核心的是self.prepare_data(idx). 按照这种思路一级一级向上查找, 我们就可以总结出如下图的数据读取流程:

在这里插入图片描述
其中, 数据增强pipeline是一系列类型为TRANSFORMS类的列表, 再每经过一次数据增强时, 字典都会被更新.

我们以较为常用的随机便宜(RandomShift)来说, 其是这样定义的:


@TRANSFORMS.register_module()
class RandomShift(BaseTransform):

    def __init__(self,
        ...

    @autocast_box_type()
    def transform(self, results: dict) -> dict:  # transform方法, 更新字典, 图像与对应的边界框等都需要被更新
        """Transform function to random shift images, bounding boxes.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Shift results.
        """
        if self._random_prob() < self.prob:
            img_shape = results['img'].shape[:2]

            random_shift_x = random.randint(-self.max_shift_px,
                                            self.max_shift_px)
            random_shift_y = random.randint(-self.max_shift_px,
                                            self.max_shift_px)
            new_x = max(0, random_shift_x)
            ori_x = max(0, -random_shift_x)
            new_y = max(0, random_shift_y)
            ori_y = max(0, -random_shift_y)

            # TODO: support mask and semantic segmentation maps.
            bboxes = results['gt_bboxes'].clone()
            bboxes.translate_([random_shift_x, random_shift_y])

            # clip border
            bboxes.clip_(img_shape)

            # remove invalid bboxes
            valid_inds = (bboxes.widths > self.filter_thr_px).numpy() & (
                bboxes.heights > self.filter_thr_px).numpy()
            # If the shift does not contain any gt-bbox area, skip this
            # image.
            if not valid_inds.any():
                return results
            bboxes = bboxes[valid_inds]
            results['gt_bboxes'] = bboxes
            results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
                valid_inds]

            if results.get('gt_ignore_flags', None) is not None:
                results['gt_ignore_flags'] = \
                    results['gt_ignore_flags'][valid_inds]

            # shift img
            img = results['img']
            new_img = np.zeros_like(img)
            img_h, img_w = img.shape[:2]
            new_h = img_h - np.abs(random_shift_y)
            new_w = img_w - np.abs(random_shift_x)
            new_img[new_y:new_y + new_h, new_x:new_x + new_w] \
                = img[ori_y:ori_y + new_h, ori_x:ori_x + new_w]
            results['img'] = new_img

        return results

需要注意的是, 经过pipeline后, 字典最终会被更新成如下形式:

dict = {'inputs': torch.Tensor, 
        'data_samples': DetDataSample或TrackDataSample等}

其中'inputs'键对应的值就是转换为tensor的图片, 而'data_samples'键对应的值是表示样本的类, 在检测任务中, 是DetDataSample, 跟踪任务中, 是TrackDataSample. DetDataSample类有许多成员, 包括该样本(图片)的目标的边界框真值, 分割真值等:

在这里插入图片描述

class DetDataSample(BaseDataElement):
    """A data structure interface of MMDetection. They are used as interfaces
    between different components.

    The attributes in ``DetDataSample`` are divided into several parts:

        - ``proposals``(InstanceData): Region proposals used in two-stage
            detectors.
        - ``gt_instances``(InstanceData): Ground truth of instance annotations.
        - ``pred_instances``(InstanceData): Instances of detection predictions.
        - ``pred_track_instances``(InstanceData): Instances of tracking
            predictions.
        - ``ignored_instances``(InstanceData): Instances to be ignored during
            training/testing.
        - ``gt_panoptic_seg``(PixelData): Ground truth of panoptic
            segmentation.
        - ``pred_panoptic_seg``(PixelData): Prediction of panoptic
           segmentation.
        - ``gt_sem_seg``(PixelData): Ground truth of semantic segmentation.
        - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.

以上过程可以借用MMEngine文档里的一个图说明:

在这里插入图片描述

最终, 模型的forward, loss, predict等方法都是接收inputs: torch.Tensordata_samples作为输入, 例如:

在这里插入图片描述

def loss(self, inputs: Tensor, data_samples: TrackSampleList,
             **kwargs) -> Union[dict, tuple]:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1457396.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

线性规划单纯形法原理及实现

欢迎关注更多精彩 关注我&#xff0c;学习常用算法与数据结构&#xff0c;一题多解&#xff0c;降维打击。 本期话题&#xff1a;线性规划单纯形法原理及实现 标准化及单纯形方法 相关学习资料 https://www.bilibili.com/video/BV168411j7XL/?spm_id_from333.788&vd_so…

用于将Grafana默认数据库sqlite3迁移到MySQL数据库

以下是一个方案&#xff0c;用于将Grafana数据迁移到MySQL数据库。 背景: grafana 默认采用的是sqlite3&#xff0c;当我们要以集群形式部署的时使用mysql较为方便&#xff0c;试了很多sqlite转mysql的方法要么收费,最后放弃。选择自己动手风衣足食。 目标: 迁移sqlite3切换…

【深圳游戏业:腾讯引领小型公司创新求发展】

深圳游戏业&#xff1a; 腾讯引领小型公司创新求发展 一 深圳游戏公司主要类型 腾讯集团 作为中国最大的游戏公司&#xff0c;腾讯在游戏领域可以说是第一强者。2022年&#xff0c;腾讯的游戏业务营收高达1707亿元&#xff0c;约占了中国整个游戏市场总收入的64%。 刚开始时&…

【机器学习算法】KNN鸢尾花种类预测案例和特征预处理。全md文档笔记(已分享,附代码)

本系列文章md笔记&#xff08;已分享&#xff09;主要讨论机器学习算法相关知识。机器学习算法文章笔记以算法、案例为驱动的学习&#xff0c;伴随浅显易懂的数学知识&#xff0c;让大家掌握机器学习常见算法原理&#xff0c;应用Scikit-learn实现机器学习算法的应用&#xff0…

基于四足机器人和机械臂的运动控制系统(一)

文章目录 一、项目框架二、设计内容与功能需求1. 导航与路径规划2. 视觉感知3. 运动控制4. 精准遥控5. 环境探测6. 云端监控与数据分析7. 人机协同8. 充电桩9. 紧急响应与救援 三、硬件设计1. 四足机器人2. 机械臂3. 机器主控板4. 遥控器板5. 舵机驱动板 四、软件设计1. 环境2.…

【机器学习笔记】14 关联规则

关联规则概述 关联规则&#xff08;Association Rules&#xff09;反映一个事物与其他事物之间的相互依存性和关联性。如果两个或者多个事物之间存在一定的关联关系&#xff0c;那么&#xff0c;其中一个事物就能够通过其他事物预测到。 关联规则可以看作是一种IF-THEN关系。…

Sora:最强文生视频工具

Sora是什么 Sora&#xff0c;是一款能够根据文本创建出逼真的、富有想象力场景的AI模型。Sora能够娴熟地创造出高达一分钟的高清视频&#xff0c;其视觉内容丰富多样&#xff0c;分辨率精准无误。Sora的强大之处在于&#xff0c;它通过在视频和图像的压缩潜在空间中进行训练&a…

[ai笔记10] 关于sora火爆的反思

欢迎来到文思源想的ai空间&#xff0c;这是技术老兵重学ai以及成长思考的第10篇分享&#xff01; 最近sora还持续在技术圈、博客、抖音发酵&#xff0c;许多人都在纷纷发表对它的看法&#xff0c;这是一个既让人惊喜也感到焦虑的事件。openai从2023年开始&#xff0c;每隔几个…

SpringSecurity + OAuth2 详解

SpringSecurity入门到精通 ************************************************************************** SpringSecurity 介绍 **************************************************************************一、入门1.简介与选择2.入门案例-默认的登录和登出接口3.登录经过了…

笑营宝课后延时服务选课报名管理系统简介

课后延时服务是在“双减”政策背景下推向全国的校园服务。开展丰富多彩的课后服务&#xff0c;既解决家长负担&#xff0c;又能在校内提供作业辅导及素质提升课程&#xff0c;实现教育公平。是解决孩子三点半放学之后的校园服务&#xff0c;但也需要最大限度的降低学校老师的工…

基于java的企业校园招聘平台的设计与实现

分享一个自己的毕业设计&#xff0c;想要获取源码的同学加V&#xff1a;qq2056908377 链接&#xff1a;https://pan.baidu.com/s/1It0CnXUvc9KVr1kDcHWvEw 提取码&#xff1a;1234 摘要&#xff1a; 摘要&#xff1a;本毕业设计旨在设计和实现一个企业校园招聘平台&#xf…

【详细流程】vue+Element UI项目中使用echarts绘制圆环图 折线图 饼图 柱状图

vueElement UI项目中数据分析功能需要用到圆环图 折线图 饼图 柱状图等&#xff0c;可视化图形分析 安装流程及示例 1.安装依赖 npm install echarts --save2.在main.js中引入并挂载echarts import echarts from echarts Vue.prototype.$echarts echarts3.在需要使用echart…

代码随想录刷题笔记-Day20

1. 二叉树的最近公共祖先 236. 二叉树的最近公共祖先https://leetcode.cn/problems/lowest-common-ancestor-of-a-binary-tree/ 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为&#xff1a;“对于有根树 T 的两个节点 p、q&#x…

RecombiMAb anti-mouse CD40,FGK4.5-CP133单克隆抗体

FGK4.5-CP133单克隆抗体是原始FGK4.5单克隆抗体的重组嵌合型抗体。可变结构域序列与原始FGK4.5克隆号相同&#xff0c;但是恒定区序列已经从大鼠IgG2a变为小鼠IgG2a。FGK4.5-CP133抗体像原始大鼠IgG2a抗体一样&#xff0c;不包含Fc突变。 FGK4.5-CP133单克隆抗体能与小鼠CD40(也…

压缩感知(Compressed Sensing,CS)的基础知识

压缩感知&#xff08;Compressed Sensing&#xff0c;CS&#xff09;是一种用于信号处理的技术&#xff0c;旨在以少于奈奎斯特采样定理所要求的样本频率来重构信号。该技术利用信号的稀疏性&#xff0c;即信号可以用较少的非零系数表示。压缩感知在图像获取中的应用使得在采集…

阿里云个人建站笔记

导航 一、购买ECS服务器二、配置mysql&#xff08;一&#xff09;安装Mysql步骤一&#xff1a;安装mysql步骤二&#xff1a;配置MySQL步骤三&#xff1a;远程访问MySQL数据库 &#xff08;二&#xff09;给实例配置安全组策略&#xff08;三&#xff09;设置防火墙 一、购买ECS…

防御保护——综合实验

拓扑图 实验需求&#xff1a; 1.Fw1和Fw2组成主备模式的双机热备 2.DMZ区存在两台服务器&#xff0c;现在要求生产区的设备仅能在办公时间&#xff08;9:00-18:00&#xff09;访问&#xff0c;办公区的设备全天都可以访问。 3.办公区设备可以通过电信链路和移动链路上网(多对多…

Linux 实例常用内核参数介绍—容器访问外部网络之ip_forward数据包转发

文章目录 1 问题解决1.1 问题1.2 原因1.3 解决临时打开永久打开 下面为扩展内容Linux 实例常用内核参数介绍:[https://cloud.tencent.com/document/product/213/46400](https://cloud.tencent.com/document/product/213/46400) 2 net.ipv4.ip_forward内核参数通俗解释3 在Linux…

[office] EXCEL怎么制作大事记图表- #学习方法#其他

EXCEL怎么制作大事记图表? 在宣传方面&#xff0c;经常会看到一些记录历史事件、成长历程的图&#xff0c;非常的直观、好看(如下图所示)。那么是怎么做到呢呢?这里我们介绍一下用EXCEL表格快速做出事件记录图的方法。 1、首先&#xff0c;做出基础表格(如下图一所示)。表格…

nacos部署

简介 Nacos 阿里巴巴推出来的开源项目&#xff0c;是更易于构建云原生应用的动态服务发现、配置管理和服务管理平台 Nacos 致力于发现、配置和管理微服务&#xff0c;并提供简单易用的特性集&#xff0c;能够快速实现动态服务发现、服务配置、服务元数据及流量管理。 Nacos 更…