注册器设计模式
from detectron2.utils.registry import Registry
这里面的Registry实际上就是注册器设计模式
1.举一个简单版的小例子来理解注册器设计模式
参考:Python注册器设计模式_python 注册类-CSDN博客
# 这一行代码是从Python的 typing 模块中导入了一些类型提示(type hint)相关的工具。
# Python的 typing 模块提供了静态类型检查的支持,这对提高代码的可读性和可靠性很有帮助
from typing import Any, Callable, Dict, List, Optional, Type, Union
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-ignore-all-errors[2,3]
from typing import Any, Dict, Iterable, Iterator, Tuple
from tabulate import tabulate
class Registry(Iterable[Tuple[str, Any]]):
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name: str) -> None:
"""
Args:
name (str): the name of this registry
"""
self._name: str = name
self._obj_map: Dict[str, Any] = {}
def _do_register(self, name: str, obj: Any) -> None:
assert (
name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(
name, self._name
)
self._obj_map[name] = obj
def register(self, obj: Any = None) -> Any:
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not. See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class: Any) -> Any:
name = func_or_class.__name__
self._do_register(name, func_or_class)
return func_or_class
return deco
# used as a function call
name = obj.__name__
self._do_register(name, obj)
def get(self, name: str) -> Any:
ret = self._obj_map.get(name)
if ret is None:
raise KeyError(
"No object named '{}' found in '{}' registry!".format(name, self._name)
)
return ret
def __contains__(self, name: str) -> bool:
return name in self._obj_map
def __repr__(self) -> str:
table_headers = ["Names", "Objects"]
table = tabulate(
self._obj_map.items(), headers=table_headers, tablefmt="fancy_grid"
)
return "Registry of {}:\n".format(self._name) + table
def __iter__(self) -> Iterator[Tuple[str, Any]]:
return iter(self._obj_map.items())
def build(self, cfg: Dict[str, Any]) -> Any:
if not isinstance(cfg, dict) or 'type' not in cfg:#检查字典是否符合规定,规定中,键值要写为 type
raise TypeError('cfg must be a dict and contain the key "type"')
module_type = cfg.pop('type')#获取模块类型
module_cls = self.get(module_type)#获取模块类
if module_cls is None:#如果模块不存在,报错
raise KeyError(f'{module_type} is not registered in {self._name}')
return module_cls(**cfg)#实例化模块并且返回
# pyre-fixme[4]: Attribute must be annotated.
__str__ = __repr__
# 用法示例
if __name__ == '__main__':
# 第二步、定义一个注册表
MODELS = Registry("SPARSE_INST_ENCODER")
# 第三步、使用装饰器在注册表注册模块,比如一个类
@MODELS.register()
class ResNet:
def __init__(self, depth):
self.depth = depth
# 使用装饰器注册模块,比如一个函数
@MODELS.register()
def resnet50():
return ResNet(depth=50)
# 使用普通函数注册模块,比如一个类,先声明一个类
class MobileNet:
def __init__(self, width_multiplier):
self.width_multiplier = width_multiplier
#注册到注册表中
MODELS.register(MobileNet)#或者直接在MobileNet类声明上面用装饰器@MODELS.register_module()
# 第四步、使用build函数,构建模型(其实是各个已注册模块的实例化)
resnet = MODELS.build(dict(type='ResNet', depth=18))
print(f'ResNet: depth = {resnet.depth}') # 输出: ResNet: depth = 18
mobilenet = MODELS.build(dict(type='MobileNet', width_multiplier=1.0))
print(f'MobileNet: width_multiplier = {mobilenet.width_multiplier}') # 输出: MobileNet: width_multiplier = 1.0
resnet_instance = MODELS.build(dict(type='resnet50'))
print(
f'ResNet instance from function: depth = {resnet_instance.depth}') # 输出: ResNet instance from function: depth = 50
注册器模式实现的四个步骤:
1、定义注册器类
注册器类比较关键,需要实现了好几个功能,各种模块的
注册:内部函数_do_register负责具体注册的实现;外部函数register暴漏给编码人员,写代码的时候用
储存:将被注册的模块(类、函数、等)存在注册器类的字典中。所以一般__init__() 里会初始化一个字典
获取:使用函数get,获取已注册对象,传入类的名称,返回这个类的实际实现的引用
实例化:创建build函数,实例化被注册的模块
2、初始化注册器
注册器又称注册表,创建一个注册器:用注册器类新建一个对象。
MODELS = Registry(‘models’)
3、注册可调用对象
将模块(类、函数、等)注册到注册器中去
可以用隐式的方法:装饰器挂在类的声明实现的头上,就可以完成注册了
@MODELS.register_module()
class ResNet:
如以上代码,能把ResNet类注册到注册表中。
也可以使用显示的注册方式,如
# 使用普通函数注册模块,比如一个类,先声明一个类
class MobileNet:
def __init__(self, width_multiplier):
self.width_multiplier = width_multiplier
#注册到注册表中
MODELS.register(MobileNet)#或者直接在MobileNet类声明上面用装饰器@MODELS.register_module()
这两种写法都是可以的,底层实现是一样的。
4、使用注册器构建对象
也就是已经被注册到注册表中的模块(类、函数、等)的实例化,实例化用的是build函数
比如可以这么写:
resnet_cfg = {'type': 'ResNet', 'depth': 18}
resnet = MODELS.build(resnet_cfg)
也可
resnet = MODELS.build(dict(type='ResNet', depth=18))
2.detectron2里的注册器设计模式
detectron2第一步定义注册器类的代码:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-ignore-all-errors[2,3]
from typing import Any, Dict, Iterable, Iterator, Tuple
from tabulate import tabulate
class Registry(Iterable[Tuple[str, Any]]):
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name: str) -> None:
"""
Args:
name (str): the name of this registry
"""
self._name: str = name
self._obj_map: Dict[str, Any] = {}
def _do_register(self, name: str, obj: Any) -> None:
assert (
name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(
name, self._name
)
self._obj_map[name] = obj
def register(self, obj: Any = None) -> Any:
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not. See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class: Any) -> Any:
name = func_or_class.__name__
self._do_register(name, func_or_class)
return func_or_class
return deco
# used as a function call
name = obj.__name__
self._do_register(name, obj)
def get(self, name: str) -> Any:
ret = self._obj_map.get(name)
if ret is None:
raise KeyError(
"No object named '{}' found in '{}' registry!".format(name, self._name)
)
return ret
def __contains__(self, name: str) -> bool:
return name in self._obj_map
def __repr__(self) -> str:
table_headers = ["Names", "Objects"]
table = tabulate(
self._obj_map.items(), headers=table_headers, tablefmt="fancy_grid"
)
return "Registry of {}:\n".format(self._name) + table
def __iter__(self) -> Iterator[Tuple[str, Any]]:
return iter(self._obj_map.items())
# pyre-fixme[4]: Attribute must be annotated.
__str__ = __repr__
第二步初始化注册器的代码:
以初始化BACKBONE_REGISTRY注册器为例
第三步注册可调用对象:
第四步使用注册器构建对象:
实际上构建对象就只有一句话
model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
传入参数meta_arch的值为字字符串'SparseInst'。所以META_ARCH_REGISTRY.get(meta_arch)得到了SparseInst类名,META_ARCH_REGISTRY.get(meta_arch)(cfg)实际上就是SparseInst(cfg),也就是上面这句代码跟这句代码等价
model=SparseInst(cfg)