记录MMDetection研究过程
0.前言
参考:
1.MMDetection框架入门教程(完全版)
2.
1.框架概述
MMDetection是商汤和港中文大学针对目标检测任务推出的一个开源项目,它基于Pytorch实现了大量的目标检测算法,把数据集构建、模型搭建、训练策略等过程都封装成了一个个模块,通过模块调用的方式,我们能够以很少的代码量实现一个新算法,大大提高了代码复用率
整个MMLab家族除了MMDetection,还包含针对目标跟踪任务的MMTracking,针对3D目标检测任务的MMDetection3D等开源项目,他们都是以Pytorch和MMCV以基础。
Pytorch不需要过多介绍,
MMCV是一个面向计算机视觉的基础库,最主要作用是提供了基于Pytorch的通用训练框架,比如我们常提到的Registry、Runner、Hook等功能都是在MMCV中支持的。
另外,MMCV还提供了通用IO接口、多种CNN网络结构、高质量实现的常见CUDA算子。
2. 框架流程
2.1 Pytorch算法开发流程
我们使用Pytorch构建一个新算法时,通常包含如下几步:
构建数据集:新建一个类,并继承Dataset类,重写__getitem__()方法实现数据和标签的加载和遍历功能,并以pipeline的方式定义数据预处理流程
构建数据加载器:传入相应的参数实例化DataLoader
构建模型:新建一个类,并继承Module类,重写forward()函数定义模型的前向过程
定义损失函数和优化器:根据算法选择合适和损失函数和优化器
训练和验证:循环从DataLoader中获取数据和标签,送入网络模型,计算loss,根据反传的梯度使用优化器进行迭代优化
其他操作:在主调函数里可以任意穿插训练Tricks、日志打印、检查点保存等操作
2.2 MMDetection算法开发流程
使用Pytorch构建一个新算法时,通常包含如下几步:
- 注册数据集:CustomDataset是MMDetection在原始的Dataset基础上的再次封装,其__getitem__()方法会根据训练和测试模式分别重定向到prepare_train_img()和prepare_test_img()函数。用户以继承CustomDataset类的方式构建自己的数据集时,需要重写load_annotations()和get_ann_info()函数,定义数据和标签的加载及遍历方式。完成数据集类的定义后,还需要使用DATASETS.register_module()进行模块注册。
- 注册模型:模型构建的方式和Pytorch类似,都是新建一个Module的子类然后重写forward()函数。唯一的区别在于MMDetection中需要继承BaseModule而不是Module,BaseModule是Module的子类,MMLab中的任何模型都必须继承此类。另外,MMDetection将一个完整的模型拆分为backbone、neck和head三部分进行管理,所以用户需要按照这种方式,将算法模型拆解成3个类,分别使用BACKBONES.register_module()、NECKS.register_module()和HEADS.register_module()完成模块注册。
- 构建配置文件:配置文件用于配置算法各个组件的运行参数,大体上可以包含四个部分:datasets、models、schedules和runtime。完成相应模块的定义和注册后,在配置文件中配置好相应的运行参数,然后MMDetection就会通过Registry类读取并解析配置文件,完成模块的实例化。另外,配置文件可以通过_base_字段实现继承功能,以提高代码复用率。
- 训练和验证:在完成各模块的代码实现、模块的注册、配置文件的编写后,就可以使用./tools/train.py和./tools/test.py对模型进行训练和验证,不需要用户编写额外的代码。
2.3 流程对比
虽然从步骤上看MMDetection相比Pytorch的算法实现步骤存在挺大差异,但底层的逻辑实现和Pytorch本质上还是一样的,可以参考下图对照着进行理解,其中蓝色部分表示Pytorch流程,橙色部分表示MMDetection流程,绿色部分表示和算法框架无关的通用流程。
在开始接触MMDetection的算法实现流程之前,必须要先对注册机制和Hook机制有一个大致的了解,推荐先快速阅读,对注册机制和Hook机制先有一个大体上的了解,看完第五章后再回过头来看注册机制和Hook机制的细节部分会有更深的体会。
3.注册机制
3.1 Registry类
MMDetection作为MMCV的下游项目,继承了MMCV的模块管理方式——注册机制。简单来说,注册机制就是维护几张查询表,key是模块的名称,value是模块的句柄,每张查询表都管理一批功能相似的不同模块。我们每新建一个模块,都要根据模块实现的功能将对应的key-value查询对保存到对应的查询表中,这个保存的过程就称为“注册”。当我们想要调用某个模块时,只需要根据模块名称从查询表中找到对应的模块句柄,然后就能完成模块初始化或方法调用等操作。MMCV通过Registry类来实现字符串(key)到类(value)的映射。
Registry的构造函数如下所示,变量self._module_dict就是上面提到的“查询表”,注册的模块都会存到这个字典类型的变量里,新建一个Registry实例就是新建一张查询表。另外,Registry还支持继承机制。
from mmcv.utils import Registry
class Registry:
# 构造函数
def __init__(self, name, build_func=None, parent=None, scope=None):
# 注册器的名称
self._name = name
# 使用module_dict管理字符串到类的映射
self._module_dict = dict()
# 使用children管理注册器的子类
self._children = dict()
# build_func按照如下优先级初始化:
# 1. build_func: 优先使用指定的函数
# 2. parent.build_func: 其次使用父类的build_func
# 3. build_from_cfg: 默认从config dict中实例化对象
if build_func is None:
if parent is not None:
self.build_func = parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func
# 设置父类-子类的从属关系
if parent is not None:
assert isinstance(parent, Registry)
parent._add_children(self)
self.parent = parent
else:
self.parent = None
模块的注册通过Registry的成员函数register_module()来实现,register_module()内部又会调用另一个私有函数_register_module(),模块注册的核心功能其实是在_register_module()中实现的。核心代码也很简单,就是将传入的module_name和module_class保存到字典self._module_dict中。
def _register_module(self, module_class, module_name=None, force=False):
# 如果未指定模块名称则使用默认名称
if module_name is None:
module_name = module_class.__name__
# 为了支持在nn.Sequentail中构建pytorch模块, module_name为list形式
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
# 如果force=False, 则不允许注册相同名称的模块
# 如果force=True, 则用后一次的注册覆盖前一次
if not force and name in self._module_dict:
raise KeyError(f'{
name} is already registered in {
self.name}')
# 将当前注册的模块加入到查询表中
self._module_dict[name] = module_class
在我们通过字符串获取到一个模块的句柄后,可以通过self.build_func函数句柄来实例化这个模块。build_func可以人为指定,也可以从父类继承,一般来说都是默认使用build_from_cfg()函数,即使用配置参数cfg来初始化该模块。配置参数cfg是一个字典,里面的type字段是模块名称的字符串,其他字段则对应模块构造函数的输入参数。
def build_from_cfg(cfg, registry, default_args=None):
args = cfg.copy()
# 将cfg以外的外部传入参数也合并到args中
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
# 获取模块名称
obj_type = args.pop('type')
if isinstance(obj_type, str):
# get函数返回registry._module_dict中obj_type对应的模块句柄
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(f'{
obj_type} is not in the {
registry.name} registry')
elif inspect.isclass(obj_type):
# type值是模块本身
obj_cls = obj_type
else:
raise TypeError(f'type must be a str or valid type, but got {
type(obj_type)}')
# 模块初始化, 返回模块实例
try:
return obj_cls(**args)
except Exception as e:
raise type(e)(f'{
obj_cls.__name__}: {
e}')
考虑到registry参数需要指向当前注册器本身,我们一般是调用Registry类的build()方法而不是self.build_func。
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
下面是一个小例子,模拟了网络模型的注册和调用过程。注意一下,我们打印Registry对象时,实际上打印的是self._module_dict中的values。
# 实例化一个注册器用来管理模型
MODELS = Registry('myModels')
# 方式1: 在类的创建过程中, 使用函数装饰器进行注册(推荐)
@MODELS.register_module()
class ResNet(object):
def __init__(self, depth):
self.depth = depth
print('Initialize ResNet{}'.format(depth))
# 方式2: 完成类的创建后, 再显式调用register_module进行注册(不推荐)
class FPN(object):
def __init__(self, in_channel):
self.in_channel= in_channel
print('Initialize FPN{}'.format(in_channel))
MODELS.register_module(name='FPN', module=FPN)
print