之前跑了一下mmdetection 3.x自带的一些算法, 但是具体的代码细节总是看了就忘, 所以想做一些笔记, 方便初学者参考. 其实比较不能忍的是, 官网的文档还是空的…
这次想写其中的数据流是如何运作的, 包括从读取数据集的样本与真值, 到数据增强, 再到模型的forward当中.
0. MMDetection整体组成部分
让我们首先回顾一下C++的标准模板库(STL)是怎样设计的. STL的三个核心组件是容器, 算法与迭代器. 容器, 例如vector, queue等等, 他们是负责存储数据的, 算法是负责进行一些操作, 例如排序, 查找等等. 而迭代器是容器与算法之间的桥梁, 也就是算法可以通过迭代器去访问容器, 使得算法可以独立于容器的类型进行操作. 三个部分相辅相成, 就达到了泛型编程的理念.
再让我们回顾一下一套深度学习的代码包含什么部分. 从大的方面来说, 需要有数据的读取与增强(DataLoader), 模型的定义, 损失函数的计算, 负责梯度传播的优化器, 在验证(测试)集上的评估等. 同理, MMDetection也是按照这种方式来的, 并且每个部分接口相通, 就可以实现更广义的模型定义和训练方式.
在mmengine/registry/__init__.py
中, 我们可以看到, MMEngine(或者说MMDetection)总体有这些类型的模块:
from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS, HOOKS,
INFERENCERS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS,
MODELS, OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
WEIGHT_INITIALIZERS)
那么以上这么多模块可以分成几类, 分别负责什么呢? 按照我个人的理解, MMDetection的整体组成部分可以表示为下图:
为了节省空间, 优化器相关并未画出
1. 认识config文件
mmdetection设计的核心思想是通过字典来配置整个的训练过程和模型定义, 这些字典放在一个.py的config文件中. 一般来说,config文件最重要的就是数据加载(train_dataloader, val_dataloader和test_dataloader
), 模型定义(model
)和训练与测试过程(train_pipeline, test_pipeline
). 除此之外, 还有一些训练, 测试配置(train_cfg, test_cfg
)等等. 具体config的例子可以参照官网Learn about configs.
需要注意的是, mmdetection中字典定义class的方式, 往往是键type
表示类的名字, 之后的其他键都是类初始化需要的参数. 例如, 如果我想自定义一个模型, 叫做MyModel
, 定义在当前目录下的./models/my_model.py
中, 定义方式如下:
from mmdet.registry import MODELS # 自定义模型, 需要在模型库中"注册", 初始化时才能找到定义
from mmdet.models.mot.base import BaseMOTModel # 一个模型基类
@MODELS.register_module() # 装饰器 在模型库中"注册"
class MyModel(BaseMOTModel):
def __init__(self,
arg1=...,
arg2=...,
arg3=...):
...
def loss(self, inputs, data_samples): # 前向传播, inputs是输入tensor, data_samples是包含标签的列表
...
如果按上述方式定义了模型, 那么在我们的配置文件中, 就是这个样子:
# 必须将自定义类的py文件导入 这样可以自动register自定义模型 否则模型初始化时找不到
custom_imports = dict(
imports=['models.my_model'],
allow_failed_imports=False)
# 现在就可以愉快的传参了
models=dict(
type='MyModel',
arg1=1,
arg2=[16, 128],
arg3=dict(channel=256),
...
)
同样, 我们可以自定义DataLoader, Loss, 等等.
此外, dict是可以嵌套的, 例如mmdetection将检测模型分成了backbone, neck和head
三部分, 那么如果我们又自定义了一个Head, 叫MyHead
:
from mmdet.registry import MODELS # 自定义模型, 需要在模型库中"注册", 初始化时才能找到定义
from mmengine.model import BaseModule # 一个模型基类
@MODELS.register_module() # 装饰器 在模型库中"注册"
class MyHead(BaseModule):
def __init__(self,
arg4=...):
...
这样, 如果MyModel
的前向传播过程中需要一个head, 则代码大致是这个样子:
from mmdet.registry import MODELS # 自定义模型, 需要在模型库中"注册", 初始化时才能找到定义
from mmdet.models.mot.base import BaseMOTModel # 一个模型基类
@MODELS.register_module() # 装饰器 在模型库中"注册"
class MyModel(BaseMOTModel):
def __init__(self,
arg1=...,
arg2=...,
arg3=...,
head=...):
self.head = MODELS.build(head) # 建立Head的模型, 类型是nn.Module
...
def loss(self, inputs, data_samples): # 前向传播, inputs是输入tensor, data_samples是包含标签的列表
... # 一些其他过程
ret = self.head(inputs) # forward
... # 后处理
配置文件中对应更改为:
如果按上述方式定义了模型, 那么在我们的配置文件中, 就是这个样子:
custom_imports = dict(
imports=['models.my_model', '自定义HEAD所在的py文件'],
allow_failed_imports=False)
models=dict(
type='MyModel',
arg1=1,
arg2=[16, 128],
arg3=dict(channel=256),
head=dict( # 定义head
type='MyHead',
arg4=256,
...
)
...
)
篇幅所限, 自定义损失函数, 数据增强之类的就不一一列举了.
2. 数据流
我们接下来以检测与跟踪任务为例, 看看数据到底是如何被读入的. 我们以训练过程说明.
在训练过程中, 我们会初始化一个RUNNER类, 其读入我们的config文件并依次完成各种(模型, 数据加载, 优化器, 钩子等等)的初始化. 我们以官方提供的train.py
为例:
runner = Runner.from_cfg(cfg)
from_cfg()
是一个类方法(classmethod), 在其中我们实例化了Runner
类.
随后, 我们调用Runner
的train()
方法进行训练. 首先, 我们实例化训练循环:
self._train_loop = self.build_train_loop(
self._train_loop) # type: ignore
训练循环就属于LOOP类型.
在这里, 我们以最常用的EpochBasedTrainLoop
为例. 在EpochBasedTrainLoop
的初始化函数中, 根据config文件中的train_dataloader
字典实例化出torch
的DataLoader
类():
data_loader = DataLoader(
dataset=dataset,
sampler=sampler if batch_sampler is None else None,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
worker_init_fn=init_fn,
**dataloader_cfg)
return data_loader
当然, 我们知道torch的DataLoader类在调用的时候, 会调用到dataset(类别是torch.utils.data.Dataset
)的__getitem__
方法. 因此, 我们从__getitem__
入手来探索数据流.
在MMDetection的设计中, 数据集的类都是继承于MMengine中的BaseDataset
, 其中的__getitem__
是这样写的:
def __getitem__(self, idx: int) -> dict:
if not self._fully_initialized:
print_log(
'Please call `full_init()` method manually to accelerate '
'the speed.',
logger='current',
level=logging.WARNING)
self.full_init()
if self.test_mode:
data = self.prepare_data(idx)
if data is None:
raise Exception('Test time pipline should not get `None` '
'data_sample')
return data
for _ in range(self.max_refetch + 1):
data = self.prepare_data(idx)
# Broken images or random augmentations may cause the returned data
# to be None
if data is None:
idx = self._rand_another()
continue
return data
raise Exception(f'Cannot find valid image after {self.max_refetch}! '
'Please check your image path and pipeline')
我们可以看到, 在__getitem__
中最核心的是self.prepare_data(idx)
. 按照这种思路一级一级向上查找, 我们就可以总结出如下图的数据读取流程:
其中, 数据增强pipeline是一系列类型为TRANSFORMS
类的列表, 再每经过一次数据增强时, 字典都会被更新.
我们以较为常用的随机便宜(RandomShift)来说, 其是这样定义的:
@TRANSFORMS.register_module()
class RandomShift(BaseTransform):
def __init__(self,
...
@autocast_box_type()
def transform(self, results: dict) -> dict: # transform方法, 更新字典, 图像与对应的边界框等都需要被更新
"""Transform function to random shift images, bounding boxes.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Shift results.
"""
if self._random_prob() < self.prob:
img_shape = results['img'].shape[:2]
random_shift_x = random.randint(-self.max_shift_px,
self.max_shift_px)
random_shift_y = random.randint(-self.max_shift_px,
self.max_shift_px)
new_x = max(0, random_shift_x)
ori_x = max(0, -random_shift_x)
new_y = max(0, random_shift_y)
ori_y = max(0, -random_shift_y)
# TODO: support mask and semantic segmentation maps.
bboxes = results['gt_bboxes'].clone()
bboxes.translate_([random_shift_x, random_shift_y])
# clip border
bboxes.clip_(img_shape)
# remove invalid bboxes
valid_inds = (bboxes.widths > self.filter_thr_px).numpy() & (
bboxes.heights > self.filter_thr_px).numpy()
# If the shift does not contain any gt-bbox area, skip this
# image.
if not valid_inds.any():
return results
bboxes = bboxes[valid_inds]
results['gt_bboxes'] = bboxes
results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
valid_inds]
if results.get('gt_ignore_flags', None) is not None:
results['gt_ignore_flags'] = \
results['gt_ignore_flags'][valid_inds]
# shift img
img = results['img']
new_img = np.zeros_like(img)
img_h, img_w = img.shape[:2]
new_h = img_h - np.abs(random_shift_y)
new_w = img_w - np.abs(random_shift_x)
new_img[new_y:new_y + new_h, new_x:new_x + new_w] \
= img[ori_y:ori_y + new_h, ori_x:ori_x + new_w]
results['img'] = new_img
return results
需要注意的是, 经过pipeline后, 字典最终会被更新成如下形式:
dict = {'inputs': torch.Tensor,
'data_samples': DetDataSample或TrackDataSample等}
其中'inputs'
键对应的值就是转换为tensor的图片, 而'data_samples'
键对应的值是表示样本的类, 在检测任务中, 是DetDataSample
, 跟踪任务中, 是TrackDataSample
. DetDataSample
类有许多成员, 包括该样本(图片)的目标的边界框真值, 分割真值等:
class DetDataSample(BaseDataElement):
"""A data structure interface of MMDetection. They are used as interfaces
between different components.
The attributes in ``DetDataSample`` are divided into several parts:
- ``proposals``(InstanceData): Region proposals used in two-stage
detectors.
- ``gt_instances``(InstanceData): Ground truth of instance annotations.
- ``pred_instances``(InstanceData): Instances of detection predictions.
- ``pred_track_instances``(InstanceData): Instances of tracking
predictions.
- ``ignored_instances``(InstanceData): Instances to be ignored during
training/testing.
- ``gt_panoptic_seg``(PixelData): Ground truth of panoptic
segmentation.
- ``pred_panoptic_seg``(PixelData): Prediction of panoptic
segmentation.
- ``gt_sem_seg``(PixelData): Ground truth of semantic segmentation.
- ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
以上过程可以借用MMEngine文档里的一个图说明:
最终, 模型的forward, loss, predict等方法都是接收inputs: torch.Tensor
与data_samples
作为输入, 例如:
def loss(self, inputs: Tensor, data_samples: TrackSampleList,
**kwargs) -> Union[dict, tuple]: