Contents
- 1. Configuration Files
- 2. Loading and Saving LazyConfig
- 3. Recursive Instantiation
- 4. Training with LazyConfig
- 4.1 Import Dependencies
- 4.2 Logger Initialization
- 4.3 Training
- 4.4 Evaluation
- 4.5 Training Flow
- 4.6 Entry Point
- 5. Tips
Detectron2 is an open-source computer-vision framework from Facebook AI Research (FAIR), built on PyTorch. It focuses mainly on object detection and segmentation, and additionally supports panoptic segmentation, keypoint detection, rotated bounding-box detection, and other tasks. Detectron2 is the successor of Detectron and maskrcnn-benchmark.
Detectron2 is highly flexible and extensible, and supports fast single-GPU training, multi-GPU parallel training, and multi-node distributed training.
1. Configuration Files
Detectron2 originally used a simple key-value config system based on YAML. YAML, however, is a very restricted language that cannot express complex data structures, so Detectron2 introduced a more powerful alternative: the LazyConfig system.
YAML (Wikipedia) is a human-readable data-serialization language (a config can be saved to a file and loaded back later); any valid JSON is also valid YAML. Plain YAML only encodes scalars (strings, integers, floats), sequences (lists), and associative arrays (maps, dictionaries, hash tables). The recommended file extension is .yaml.
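For example, a quick round-trip with PyYAML (used here only for illustration; it is not required by detectron2) shows the data model YAML is limited to:
import yaml  # PyYAML

text = """
name: exp1          # string scalar
lr: 0.01            # float scalar
sizes: [640, 800]   # sequence
model:              # associative array (map / dictionary)
  depth: 50
"""
print(yaml.safe_load(text))
# {'name': 'exp1', 'lr': 0.01, 'sizes': [640, 800], 'model': {'depth': 50}}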
2. Loading and Saving LazyConfig
LazyConfig uses plain Python scripts as configuration files, so a config can be manipulated directly with Python code: it supports rich data types, can run simple functions, and can be composed and reused through Python's import syntax.
config_test.py
inputs = [1024, 960]  # input size
batch_size = 128
train_dict = {"input": inputs, "batch_size": batch_size}
The configuration file is loaded through the API provided by detectron2. Attribute-style access is convenient, but IDE code completion does not work on the loaded config.
from detectron2.config import LazyConfig

cfg = LazyConfig.load("config_test.py")
print(cfg.train_dict.batch_size)  # convenient attribute access, but no code completion
LazyConfig.save(cfg, "test.yaml")  # export the config to a YAML file; values that cannot be serialized (e.g. numpy arrays) are not saved
test.yaml
train_dict:
  batch_size: 128
  input: [1024, 960]
3. Recursive Instantiation
LazyConfig is built around recursive instantiation: a call to a function or class is represented as a dictionary. Wrapping a call does not execute the target immediately; it only returns a dictionary that describes the call, and the target is actually executed at instantiation time.
from detectron2.config import instantiate, LazyCall
import torch.nn as nn

layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32, kernel_size=3)  # describe a call to nn.Conv2d and its arguments (kernel_size is required by Conv2d)
layer_cfg.out_channels = 64  # the arguments can still be edited afterwards
layer = instantiate(layer_cfg)  # instantiation actually creates the 2D convolution layer
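Lazy calls can also be nested, which is where recursive instantiation pays off. A minimal sketch (make_block is a hypothetical helper, not part of detectron2):
from detectron2.config import instantiate, LazyCall
import torch.nn as nn

def make_block(conv, act):  # hypothetical helper, for illustration only
    return nn.Sequential(conv, act)

block_cfg = LazyCall(make_block)(
    conv=LazyCall(nn.Conv2d)(in_channels=3, out_channels=16, kernel_size=3),
    act=LazyCall(nn.ReLU)(),
)
block = instantiate(block_cfg)  # the inner calls are instantiated first, then passed to make_block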
LazyCall
from dataclasses import is_dataclass
from omegaconf import DictConfig

class LazyCall:
    def __init__(self, target):
        self._target = target  # the function or class to be called lazily

    def __call__(self, **kwargs):
        if is_dataclass(self._target):
            # omegaconf object cannot hold dataclass type
            # https://github.com/omry/omegaconf/issues/784
            target = _convert_target_to_string(self._target)
        else:
            target = self._target
        kwargs["_target_"] = target
        # store the call as a DictConfig; "allow_objects" lets it hold
        # arbitrary Python objects such as classes and functions
        return DictConfig(content=kwargs, flags={"allow_objects": True})
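The returned object is therefore just an omegaconf DictConfig whose "_target_" entry records what to call. A quick check (assuming the imports from the snippets above):
from detectron2.config import LazyCall
import torch.nn as nn

cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32, kernel_size=3)
print(cfg["_target_"])   # <class 'torch.nn.modules.conv.Conv2d'>
print(cfg.in_channels)   # 32 -- still plain config; no layer has been created yet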
instantiate
import dataclasses
import logging
from collections import abc

def instantiate(cfg):
    """
    Recursively instantiate objects defined in dictionaries
    """
    from omegaconf import ListConfig, DictConfig, OmegaConf

    if isinstance(cfg, ListConfig):
        lst = [instantiate(x) for x in cfg]  # recurse into every element
        return ListConfig(lst, flags={"allow_objects": True})
    if isinstance(cfg, list):
        # Specialize for list, because many classes take
        # list[objects] as arguments, such as ResNet, DatasetMapper
        return [instantiate(x) for x in cfg]

    if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(cfg._metadata.object_type):
        return OmegaConf.to_object(cfg)

    if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
        # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
        # but faster: https://github.com/facebookresearch/hydra/issues/1200
        cfg = {k: instantiate(v) for k, v in cfg.items()}  # instantiate the arguments first
        cls = cfg.pop("_target_")
        cls = instantiate(cls)

        if isinstance(cls, str):
            cls_name = cls
            cls = locate(cls_name)  # resolve a dotted import path to the actual object
            assert cls is not None, cls_name
        else:
            try:
                cls_name = cls.__module__ + "." + cls.__qualname__
            except Exception:
                # target could be anything, so the above could fail
                cls_name = str(cls)
        assert callable(cls), f"_target_ {cls} does not define a callable object"
        try:
            return cls(**cfg)  # call the target with the instantiated arguments
        except TypeError:
            logger = logging.getLogger(__name__)
            logger.error(f"Error when instantiating {cls_name}!")
            raise

    return cfg  # return as-is if don't know what to do
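Because instantiate only looks for a mapping with a "_target_" key, a hand-written dict with a dotted import path works the same way as a config produced by LazyCall. A small sketch based on the source above:
from detectron2.config import instantiate

layer = instantiate({
    "_target_": "torch.nn.Conv2d",  # a dotted path is resolved with locate()
    "in_channels": 3,
    "out_channels": 8,
    "kernel_size": 1,
})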
4. Training with LazyConfig
4.1 Import Dependencies
import logging

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine import (
    AMPTrainer,  # automatic mixed-precision trainer
    SimpleTrainer,
    default_argument_parser,
    default_setup,  # default setup (logging, environment info, config backup)
    default_writers,
    hooks,
    launch,  # distributed training launcher
)
from detectron2.engine.defaults import create_ddp_model
from detectron2.evaluation import inference_on_dataset, print_csv_format
from detectron2.utils import comm
4.2 Logger Initialization
logger = logging.getLogger("detectron2")
4.3 Training
def do_train(args, cfg):
    model = instantiate(cfg.model)  # build the model
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)  # build the optimizer

    train_loader = instantiate(cfg.dataloader.train)  # build the training dataloader

    model = create_ddp_model(model, **cfg.train.ddp)  # wrap the model for distributed training
    # use the AMP trainer when mixed-precision training is enabled
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(  # checkpoint management
        model,
        cfg.train.output_dir,
        trainer=trainer,
    )
    trainer.register_hooks(  # register hooks (callbacks)
        [
            hooks.IterationTimer(),  # iteration timer
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process()  # only the main process saves checkpoints periodically
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),  # periodic evaluation
            hooks.PeriodicWriter(  # write training logs periodically
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)  # load initial weights or resume training
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
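do_train reads a number of keys from the config: cfg.model, cfg.optimizer, cfg.dataloader.train, cfg.lr_multiplier, and the cfg.train namespace. A minimal sketch of a matching config file, with illustrative values modeled on detectron2's common configs (the model, dataloader, and lr_multiplier entries are omitted here):
# train_config.py -- hypothetical file name; values are illustrative
import torch
from detectron2.config import LazyCall
from detectron2.solver.build import get_default_optimizer_params

train = dict(
    output_dir="./output",
    init_checkpoint="",
    max_iter=90000,
    amp=dict(enabled=False),            # cfg.train.amp.enabled
    ddp=dict(broadcast_buffers=False),  # kwargs for create_ddp_model
    checkpointer=dict(period=5000, max_to_keep=100),
    eval_period=5000,
    log_period=20,
    device="cuda",
)

optimizer = LazyCall(torch.optim.SGD)(
    # `params` is itself lazy; do_train fills in params.model before instantiating
    params=LazyCall(get_default_optimizer_params)(weight_decay_norm=0.0),
    lr=0.02,
    momentum=0.9,
)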
4.4 Evaluation
def do_test(cfg, model):
    # evaluation runs only if the config defines an evaluator
    if "evaluator" in cfg.dataloader:
        ret = inference_on_dataset(
            model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
        )
        print_csv_format(ret)
        return ret
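A sketch of the evaluator part of the config, modeled on detectron2's common COCO configs (the dataset name is illustrative, and dataloader.test is elided):
from omegaconf import OmegaConf
from detectron2.config import LazyCall
from detectron2.evaluation import COCOEvaluator

dataloader = OmegaConf.create()
# dataloader.test = LazyCall(build_detection_test_loader)(...) would go here
dataloader.evaluator = LazyCall(COCOEvaluator)(dataset_name="coco_2017_val")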
4.5 Training Flow
def main(args):
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)  # apply key=value overrides from the command line
    default_setup(cfg, args)  # default setup: logging, basic environment info, config backup
    if args.eval_only:
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        model = create_ddp_model(model)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)  # load weights
        print(do_test(cfg, model))
    else:
        do_train(args, cfg)
4.6 Entry Point
if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(  # launch multi-GPU / multi-node training
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,  # rank of the current node
        dist_url=args.dist_url,
        args=(args,),
    )
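Assuming the script above is saved as train_net.py and the config as train_config.py (both names are illustrative), typical invocations look like this; the trailing key=value pairs are consumed by LazyConfig.apply_overrides:
python train_net.py --config-file train_config.py --num-gpus 8 train.max_iter=180000
python train_net.py --config-file train_config.py --eval-only train.init_checkpoint=./output/model_final.pth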
5. Tips
- Treat config files like Python code: factor shared settings out into a common module and import it, instead of keeping multiple copies (see the sketch below).
- Keep config files as simple as possible; leave out anything that is not needed.
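A minimal sketch of the first tip (both file names are hypothetical):
# common_optim.py -- the shared piece
import torch
from detectron2.config import LazyCall

optimizer = LazyCall(torch.optim.SGD)(lr=0.02, momentum=0.9)

# experiment_a.py -- import instead of copy, then override only what differs
from common_optim import optimizer

optimizer.lr = 0.002  # experiment-specific change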