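Decorator functions in Python
A function decorator wraps a function in another layer of function.
A decorator can accept functions with different parameter signatures and add processing around them.
To apply it, write @decorator_name on the line above the function to be decorated.
The benefit is code reuse: when different algorithms need the same pre-processing or post-processing, a decorator avoids copying and pasting that code.
In my view this differs in concept and purpose from virtual functions in C++. A virtual function keeps the same function name while different objects call different implementations.
A Python decorator, by contrast, applies the same wrapping to different functions passed into it.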
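def tankFactory(dec: "function") -> "function":
    # powerTank accepts any positional/keyword arguments, so it can wrap
    # functions with different signatures.
    def powerTank(*params: tuple, **kvParams: dict):
        dec(*params, **kvParams)
        print("Reactive armor installed")
        print("Anti-missile system installed")
        print("Infrared jammer installed")
    return powerTank

@tankFactory
def tank(player1: 'driver', player2: 'loader', player3: 'commander', player4: 'gunner'):
    print("This is a bare tank")
    print("This tank has a crew of four")

@tankFactory
def modernTank(player1, player2, player3):
    print("This is a modern tank with a crew of 3")

tank("tom", "jerry", "robot1", "robot2")
# output start-----
"""
This is a bare tank
This tank has a crew of four
Reactive armor installed
Anti-missile system installed
Infrared jammer installed
"""
# output end-----
modernTank("tom", "jerry", "robot")
# output start-----
"""
This is a modern tank with a crew of 3
Reactive armor installed
Anti-missile system installed
Infrared jammer installed
"""
# output end-----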
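The @ syntax is only shorthand for rebinding a name to whatever the decorator returns. The sketch below shows the two forms side by side; add_armor, light_tank, and heavy_tank are hypothetical names made up for this illustration and are not part of the example above.

```python
def add_armor(func):
    """Decorator: run func, then append one extra step to its result."""
    def wrapper(*args, **kwargs):
        steps = func(*args, **kwargs)
        return steps + ["reactive armor installed"]
    return wrapper


@add_armor
def light_tank(crew):
    return [f"tank with {crew} crew members"]


def heavy_tank(crew):
    return [f"tank with {crew} crew members"]


# Equivalent to writing @add_armor above the definition of heavy_tank.
heavy_tank = add_armor(heavy_tank)

print(light_tank(3))
print(heavy_tank(3))
```

Both calls go through wrapper, so the two functions behave identically.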
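In addition, because every function in Python is an object, applying a decorator rebinds the original name to the inner wrapper function. After decoration, tank actually refers to powerTank, so metadata such as __name__, __doc__, and the annotations describe the wrapper rather than the original function. This confuses help(), debugging output, and any code that relies on the function name. We therefore need to copy the original function's metadata onto the wrapper explicitly.
Doing so is simple:
Import wraps from the functools module and apply it to the inner wrapper inside the decorator, as in the following example: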
from functools import wraps

def tankFactory(dec: "function") -> "function":
    @wraps(dec)  # copy __name__, __doc__, etc. from dec onto powerTank
    def powerTank(*params: tuple, **kvParams: dict):
        dec(*params, **kvParams)
        print("Reactive armor installed")
        print("Anti-missile system installed")
        print("Infrared jammer installed")
    return powerTank

@tankFactory
def tank(player1: 'driver', player2: 'loader', player3: 'commander', player4: 'gunner'):
    print("This is a bare tank")
    print("This tank has a crew of four")

@tankFactory
def modernTank(player1, player2, player3):
    print("This is a modern tank with a crew of 3")

tank("tom", "jerry", "robot1", "robot2")
modernTank("tom", "jerry", "robot")
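To make the effect of wraps visible, the following sketch decorates two otherwise identical functions, once without wraps and once with it, and then inspects their metadata. The names plain_decorator, wrapped_decorator, tank_a, and tank_b are hypothetical and chosen only for this comparison.

```python
from functools import wraps


def plain_decorator(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def wrapped_decorator(func):
    @wraps(func)  # copies __name__, __doc__, and more from func to wrapper
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


@plain_decorator
def tank_a():
    """A bare tank."""


@wrapped_decorator
def tank_b():
    """A bare tank."""


print(tank_a.__name__, tank_a.__doc__)  # wrapper None
print(tank_b.__name__, tank_b.__doc__)  # tank_b A bare tank.
```

This matters for any code that uses a function's __name__ as a key: without wraps, every decorated function would report the same name, wrapper.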
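The registry mechanism in mmdetection
mmdet's registry mechanism is built on decorators. The Registry class provides a decorator method, register_module(), which stores a class or function in the Registry member variable _module_dict. For each registry, _module_dict maintains the mapping from a string name to the registered class or function. When a target instance is built, the type string given in the config selects the class to instantiate or the function to call.
The benefit of the registry mechanism is that functions/classes are grouped by role into different modules (that is, into different registries). When the overall pipeline is assembled, specifying type in the config is enough to build the whole model, so changing one small module of the pipeline only requires editing the config.
build_func is either user-defined, inherited from the parent registry, or falls back to build_from_cfg.
build_from_cfg reads the type field from cfg, looks up the corresponding class or function in _module_dict, and passes the remaining fields to it as keyword arguments, returning the instantiated object. The fallback order for build_func is set in Registry.__init__: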
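# self.build_func will be set with the following priority:
# 1. build_func
# 2. parent.build_func
# 3. build_from_cfg
if build_func is None:
    if parent is not None:
        self.build_func = parent.build_func
    else:
        self.build_func = build_from_cfg
else:
    self.build_func = build_func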
The instantiation part of the code is as follows:
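obj_type = args.pop('type')  # take out the type field
if isinstance(obj_type, str):
    obj_cls = registry.get(obj_type)
    if obj_cls is None:  # check whether the name has been registered
        raise KeyError(
            f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
    obj_cls = obj_type
else:
    raise TypeError(
        f'type must be a str or valid type, but got {type(obj_type)}')
try:
    return obj_cls(**args)  # instantiate with the remaining fields
except Exception as e:
    # Normal TypeError does not print class name.
    raise type(e)(f'{obj_cls.__name__}: {e}')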
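The idea can be condensed into a few lines. The following is a minimal sketch of the pattern, not the mmcv implementation: MiniRegistry, BACKBONES, and TinyNet are hypothetical names, and features such as scopes, parent registries, and default_args are deliberately left out.

```python
class MiniRegistry:
    """Toy registry: maps a name string to a class or function."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, name=None):
        """Return a decorator that records the decorated object."""
        def _register(module):
            key = name or module.__name__
            if key in self._module_dict:
                raise KeyError(f"{key} is already registered in {self.name}")
            self._module_dict[key] = module
            return module  # hand the object back unchanged
        return _register

    def build(self, cfg):
        args = dict(cfg)  # copy so the caller's config is not mutated
        obj_type = args.pop("type")
        if obj_type not in self._module_dict:
            raise KeyError(f"{obj_type} is not in the {self.name} registry")
        return self._module_dict[obj_type](**args)


BACKBONES = MiniRegistry("backbone")


@BACKBONES.register_module()
class TinyNet:
    def __init__(self, depth):
        self.depth = depth


net = BACKBONES.build(dict(type="TinyNet", depth=18))
print(type(net).__name__, net.depth)  # TinyNet 18
```

Unlike the tank decorator earlier, register_module does not wrap anything: it records the class as a side effect and returns it untouched.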
The Registry source code is as follows:
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import warnings
from functools import partial

from .misc import deprecated_api_warning, is_seq_of


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict when it is a class configuration, or
    call a function from config dict when it is a function configuration.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='Resnet'), MODELS)
        >>> # Returns an instantiated object
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
        >>> # Return a result of the calling function

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')


class Registry:
    """A registry to map strings to classes or functions.

    Registered object could be built from registry. Meanwhile, registered
    functions could be called from registry.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = MODELS.build(dict(type='resnet50'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given,  ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The name of the package where registry is defined will be returned.

        Example:
            >>> # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            str: The inferred scope name.
        """
        # We access the caller using inspect.currentframe() instead of
        # inspect.stack() for performance reasons. See details in PR #1844
        frame = inspect.currentframe()
        # get the frame where `infer_scope()` is called
        infer_scope_caller = frame.f_back.f_back
        filename = inspect.getmodule(infer_scope_caller).__name__
        split_filename = filename.split('.')
        return split_filename[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            tuple[str | None, str]: The former element is the first scope of
            the key, which can be ``None``. The latter is the remaining key.
        """
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # goto root
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    @deprecated_api_warning(name_dict=dict(module_class='module'))
    def _register_module(self, module, module_name=None, force=False):
        if not inspect.isclass(module) and not inspect.isfunction(module):
            raise TypeError('module must be a class or a function, '
                            f'but got {type(module)}')

        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.',
            DeprecationWarning)
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class or function to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a walkaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            self._register_module(module=module, module_name=name, force=force)
            return module

        return _register
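One detail of build_from_cfg that is easy to overlook is how default_args is merged: setdefault only fills in keys that cfg does not already contain, so values in cfg always win. The sketch below isolates that merge step in a hypothetical helper named merge_args (it does not exist in mmcv), and the config values are made up for illustration.

```python
def merge_args(cfg, default_args=None):
    """Mirror the merge step of build_from_cfg: cfg wins over default_args."""
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return args


cfg = dict(type="FasterRCNN", test_cfg=dict(score_thr=0.05))
defaults = dict(train_cfg=None, test_cfg=dict(score_thr=0.5))

merged = merge_args(cfg, defaults)
print(merged)
```

Here test_cfg keeps the value from cfg, train_cfg is filled in from the defaults, and the original cfg dict gains no new keys because the merge works on a copy.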
mmdet's registry.py declares all of the registries that are needed:
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
RUNNERS = Registry(
    'runner', parent=MMENGINE_RUNNERS, locations=['mmdet.engine.runner'])
# manage runner constructors that define how to initialize runners
RUNNER_CONSTRUCTORS = Registry(
    'runner constructor',
    parent=MMENGINE_RUNNER_CONSTRUCTORS,
    locations=['mmdet.engine.runner'])
# manage all kinds of loops like `EpochBasedTrainLoop`
LOOPS = Registry(
    'loop', parent=MMENGINE_LOOPS, locations=['mmdet.engine.runner'])
# manage all kinds of hooks like `CheckpointHook`
HOOKS = Registry(
    'hook', parent=MMENGINE_HOOKS, locations=['mmdet.engine.hooks'])

# manage data-related modules
DATASETS = Registry(
    'dataset', parent=MMENGINE_DATASETS, locations=['mmdet.datasets'])
DATA_SAMPLERS = Registry(
    'data sampler',
    parent=MMENGINE_DATA_SAMPLERS,
    locations=['mmdet.datasets.samplers'])
TRANSFORMS = Registry(
    'transform',
    parent=MMENGINE_TRANSFORMS,
    locations=['mmdet.datasets.transforms'])

# manage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmdet.models'])
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry(
    'model_wrapper',
    parent=MMENGINE_MODEL_WRAPPERS,
    locations=['mmdet.models'])
# manage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry(
    'weight initializer',
    parent=MMENGINE_WEIGHT_INITIALIZERS,
    locations=['mmdet.models'])

# manage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry(
    'optimizer',
    parent=MMENGINE_OPTIMIZERS,
    locations=['mmdet.engine.optimizers'])
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry(
    'optim_wrapper',
    parent=MMENGINE_OPTIM_WRAPPERS,
    locations=['mmdet.engine.optimizers'])
# manage constructors that customize the optimization hyperparameters.
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
    'optimizer constructor',
    parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
    locations=['mmdet.engine.optimizers'])
# manage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry(
    'parameter scheduler',
    parent=MMENGINE_PARAM_SCHEDULERS,
    locations=['mmdet.engine.schedulers'])
# manage all kinds of metrics
METRICS = Registry(
    'metric', parent=MMENGINE_METRICS, locations=['mmdet.evaluation'])
# manage evaluator
EVALUATOR = Registry(
    'evaluator', parent=MMENGINE_EVALUATOR, locations=['mmdet.evaluation'])

# manage task-specific modules like anchor generators and box coders
TASK_UTILS = Registry(
    'task util', parent=MMENGINE_TASK_UTILS, locations=['mmdet.models'])

# manage visualizer
VISUALIZERS = Registry(
    'visualizer',
    parent=MMENGINE_VISUALIZERS,
    locations=['mmdet.visualization'])
# manage visualizer backend
VISBACKENDS = Registry(
    'vis_backend',
    parent=MMENGINE_VISBACKENDS,
    locations=['mmdet.visualization'])

# manage logprocessor
LOG_PROCESSORS = Registry(
    'log_processor',
    parent=MMENGINE_LOG_PROCESSORS,
    # TODO: update the location when mmdet has its own log processor
    locations=['mmdet.engine'])
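We can then register the required functions/classes into the _module_dict of the corresponding registry. Note that the snippet below uses a DETECTORS registry, which does not appear in the list above; it appears to come from the older mmdet 2.x API, whereas with the registries listed above a detector would be registered under MODELS.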
@DETECTORS.register_module()
class FasterRCNN(TwoStageDetector):
    """Implementation of `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_"""

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None,
                 init_cfg=None):
        super(FasterRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
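The registry's build function then constructs the corresponding instance. Only cfg needs to be passed in, and the function/class registered under that name in the registry's table is instantiated: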
def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build detector."""
    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '
    return DETECTORS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
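In summary:
The registry mechanism essentially does two things:
1. It groups modules by function into different registries.
2. Each registry maintains a mapping from string names to functions/classes, so that an object can be instantiated from a string alone. Changing the config file is therefore enough to swap out a small module in the pipeline.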
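The second point can be illustrated with a toy example in which two configs differ only in one type string. Everything here is hypothetical: MODULES stands in for a registry's _module_dict, and ToyResNet, ToyMobileNet, and ToyHead are placeholder classes, not mmdet components.

```python
# Toy name-to-class table standing in for a registry's _module_dict.
MODULES = {}


def register(cls):
    MODULES[cls.__name__] = cls
    return cls


@register
class ToyResNet:
    def __init__(self, depth=50):
        self.depth = depth


@register
class ToyMobileNet:
    def __init__(self, width=1.0):
        self.width = width


@register
class ToyHead:
    def __init__(self, num_classes):
        self.num_classes = num_classes


def build(cfg):
    """Instantiate the class named by cfg['type'] with the remaining fields."""
    args = dict(cfg)
    return MODULES[args.pop("type")](**args)


def build_model(model_cfg):
    """Build every sub-module of a flat model config."""
    return {part: build(cfg) for part, cfg in model_cfg.items()}


config_a = dict(
    backbone=dict(type="ToyResNet", depth=101),
    head=dict(type="ToyHead", num_classes=80),
)
# Only the backbone entry differs; no Python code has to change.
config_b = dict(config_a, backbone=dict(type="ToyMobileNet", width=0.5))

model_a = build_model(config_a)
model_b = build_model(config_b)
print(type(model_a["backbone"]).__name__, type(model_b["backbone"]).__name__)
```

The build code never names a concrete class, which is what lets a config edit replace one component while the rest of the pipeline stays untouched.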
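The runner mechanism in mmengine
Training, validation, and testing of deep learning algorithms usually follow similar workflows, so MMEngine abstracts a runner that handles the generic training, testing, and inference tasks for a model. Users can usually use MMEngine's default runner directly, or modify it to meet custom requirements.
mmengine takes the view that any deep learning workflow can be reduced to steps such as building the model, reading data, and iterating in a loop.
MMEngine contains the modules needed to train a model (model and dataset construction and loading, optimizer setup, learning-rate scheduling, logging, visualization, and so on), together with hooks and Loops. Hooks are attached at call points inside each Loop, typically around pre-processing and post-processing steps and around each model iteration. The overall flow is therefore: create the working directory and environment => prepare each module => register the hooks => run the loop.
mmengine defines four commonly used Loops: the iteration-based training Loop, the epoch-based training Loop, the validation Loop, and the test Loop.
Users can also subclass a loop class to define their own loop. The loop must take the runner instance and the dataloader as inputs, and may take additional inputs.
These two inputs are required because a Loop serves the runner instance and cannot iterate without a dataloader. Any logic can be defined inside the Loop, including defining, adding, or removing hook call points. The following example is taken from the official documentation: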
from mmengine.registry import LOOPS, HOOKS
from mmengine.runner import BaseLoop
from mmengine.hooks import Hook


# Custom validation loop
@LOOPS.register_module()
class CustomValLoop(BaseLoop):
    def __init__(self, runner, dataloader, evaluator, dataloader2):
        super().__init__(runner, dataloader, evaluator)
        self.dataloader2 = runner.build_dataloader(dataloader2)

    def run(self):
        self.runner.call_hooks('before_val_epoch')
        for idx, data_batch in enumerate(self.dataloader):
            self.runner.call_hooks(
                'before_val_iter', batch_idx=idx, data_batch=data_batch)
            outputs = self.run_iter(idx, data_batch)
            self.runner.call_hooks(
                'after_val_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs)
        metric = self.evaluator.evaluate()

        # Add an extra validation loop over the second dataloader
        for idx, data_batch in enumerate(self.dataloader2):
            # Extra hook call point
            self.runner.call_hooks(
                'before_valloader2_iter', batch_idx=idx, data_batch=data_batch)
            # Keep the result so the hook below sees this batch's outputs
            # rather than the stale value left over from the first loop.
            outputs = self.run_iter(idx, data_batch)
            # Extra hook call point
            self.runner.call_hooks(
                'after_valloader2_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs)
        metric2 = self.evaluator.evaluate()

        ...

        self.runner.call_hooks('after_val_epoch')


# Hook class that implements the extra call points
@HOOKS.register_module()
class CustomValHook(Hook):
    def before_valloader2_iter(self, batch_idx, data_batch):
        ...

    def after_valloader2_iter(self, batch_idx, data_batch, outputs):
        ...
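Then add the fields for the custom loop to the config file:
# Custom validation loop
val_cfg = dict(type='CustomValLoop', dataloader2=dict(dataset=dict(type='ValDataset2'), ...))
# Hook for the extra call points
custom_hooks = [dict(type='CustomValHook')]
In the same way, the runner itself can be customized: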
from mmengine.registry import RUNNERS
from mmengine.runner import Runner

@RUNNERS.register_module()
class CustomRunner(Runner):

    def setup_env(self):
        ...
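An mmengine Loop is organized around Hooks: it calls them at fixed mount points. Each Hook inherits from the Hook base class and overrides before_train, before_train_iter, or other mount-point methods. Every Hook has a priority setting; Hooks with the same priority run in the order in which they were registered.
To add logic to the training Loop, inherit from the Hook base class or from another Hook class (for example EMAHook), override the methods for the desired mount points, and set the priority.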
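The interaction between mount points and priorities can be sketched in plain Python. This is a toy model under stated assumptions, not mmengine code: ToyHook, ToyRunner, AverageHook, and LogHook are hypothetical names, only two mount points are modeled, and the sketch assumes that a smaller priority number runs earlier.

```python
class ToyHook:
    """Base hook: every mount point is a no-op unless overridden."""
    priority = 50  # assumed convention: smaller number runs earlier

    def before_train_iter(self, runner, batch_idx):
        pass

    def after_train_iter(self, runner, batch_idx):
        pass


class ToyRunner:
    def __init__(self):
        self.hooks = []
        self.events = []

    def register_hook(self, hook):
        self.hooks.append(hook)
        # sort() is stable: equal priorities keep their registration order.
        self.hooks.sort(key=lambda h: h.priority)

    def call_hook(self, fn_name, **kwargs):
        for hook in self.hooks:
            getattr(hook, fn_name)(self, **kwargs)

    def train(self, num_iters):
        for idx in range(num_iters):
            self.call_hook("before_train_iter", batch_idx=idx)
            self.events.append(f"step {idx}")  # stands in for the model step
            self.call_hook("after_train_iter", batch_idx=idx)


class AverageHook(ToyHook):
    priority = 50

    def after_train_iter(self, runner, batch_idx):
        runner.events.append("average")


class LogHook(ToyHook):
    priority = 90

    def after_train_iter(self, runner, batch_idx):
        runner.events.append("log")


runner = ToyRunner()
runner.register_hook(LogHook())      # registered first, but runs later
runner.register_hook(AverageHook())
runner.train(num_iters=2)
print(runner.events)
```

Although LogHook is registered first, AverageHook runs before it at every after_train_iter call because of its smaller priority value. The loop body itself stays unchanged when hooks are added or removed.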
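Overall structure
apis: defines the wrapper interfaces for inference.
configs:
The _base_ folder holds the datasets, models, and schedules configuration files. The datasets configs mainly contain the train/val dataloaders and the transform pipelines. The models configs contain the model definition, training config, and test config for the mainstream baseline detectors (Cascade R-CNN, Faster R-CNN, SSD, and so on). The schedules configs cover the learning rate and the optimizer. The other folders generally inherit one of the configs in _base_ and add, remove, or modify fields to obtain a new model.
Next, the engine module:
engine has four subfolders: hooks, optimizers, runner, schedulers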