【Detectron2】代码库学习-4. LazyConfig 配置文件

news2024/12/27 11:39:28

(https://detectron2.readthedocs.io/en/latest/tutorials/lazyconfigs.html)

目录

    • 1. 配置文件
    • 2. LazyConfig 导入导出
    • 3. 递归实例化
    • 4. 基于LazyConfig的训练步骤
      • 4.1 导入依赖库
      • 4.2 日志初始化
      • 4.3 训练
      • 4.4 评估
      • 4.5 训练流程
      • 4.6 主函数入口
      • 5. Tips

Detectron2是Facebook AI Research(FAIR)推出的基于Pytorch的视觉算法开源框架,主要聚焦于目标检测和分割任务等视觉算法,此外还支持全景分割,关键点检测,旋转框检测等任务。Detectron2继承自Detectron 和mask-rcnn。
Detectron2具有较强的灵活性和可扩展性,支持快速的单GPU训练,多GPU并行训练和多节点分布式训练。

1. 配置文件

Detectron2 原本采用的是基于一种 key-value的基础config 系统, 采用 YAML格式。但是YAML是一种非常受限制的语言,不能存储复杂的数据结构,因此转而使用 一种更强大的 配置文件系统 LazyConfig system。

YAML-维基百科 是一种人类可读、数据可序列化(可保持成文件和重新加载恢复)的语言, JSON 格式也是一种合法的YAML。原始 的YAML只支持编码 标量(字符串,整数,浮点数)和关系数组(map, 字典,hash表)。YAML推荐的后缀名为.yaml

2. LazyConfig 导入导出

直接采用 python 脚本作为 配置文件载体,可以通过 python代码快速操作。支持丰富的数据类型。可以运行简单的函数。通过python的import语法导入导出。
config_test.py

inputs = [1024, 960]  # 输入大小
batch_size = 128
train_dict = {"input": inputs, "batch_size": batch_size}

通过 detectron2 提供的API 加载配置文件。方便获取属性和配置, 但是代码无法补全

from detectron2.config import LazyConfig
cfg=LazyConfig.load("config_test.py")
print(cfg.train_dict.batch_size)  # 方便获取属性和配置, 但是代码无法补全,
LazyConfig.save(cfg, "test.yaml") # 导出配置到yaml文件, 部分无法序列化的数据类型不能保存,如numpy 数组

test.yaml

train_dict:
  batch_size: 128
  input: [1024, 960]

3. 递归实例化

LazyConfig 采用递归实例化 特性,将函数和类的调用表示为字典。在调用时并不会立即执行 对应的函数,只返回一个字典 描述这个 call, 只有在实例化时才真正执行。

from detectron2.config import instantiate, LazyCall
import torch.nn as nn
layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)  # 调用nn.Conv2d, 并配置参数
layer_cfg.out_channels = 64   # can edit it afterwards , 修改 参数
layer = instantiate(layer_cfg)  # 实例化对象,创建一个2维卷积层

LazyCall

class LazyCall:
    def __init__(self, target):
        self._target = target
    def __call__(self, **kwargs):
        if is_dataclass(self._target):
            # omegaconf object cannot hold dataclass type
            # https://github.com/omry/omegaconf/issues/784
            target = _convert_target_to_string(self._target)
        else:
            target = self._target
        kwargs["_target_"] = target
        return DictConfig(content=kwargs, flags={"allow_objects": True})

instantiate

def instantiate(cfg):
    """
    Recursively instantiate objects defined in dictionaries 
    """
    from omegaconf import ListConfig, DictConfig, OmegaConf
    if isinstance(cfg, ListConfig):
        lst = [instantiate(x) for x in cfg]  # 递归调用
        return ListConfig(lst, flags={"allow_objects": True})
    if isinstance(cfg, list):
        # Specialize for list, because many classes take
        # list[objects] as arguments, such as ResNet, DatasetMapper
        return [instantiate(x) for x in cfg]

    if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(cfg._metadata.object_type):
        return OmegaConf.to_object(cfg)

    if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
        # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
        # but faster: https://github.com/facebookresearch/hydra/issues/1200
        cfg = {k: instantiate(v) for k, v in cfg.items()}
        cls = cfg.pop("_target_")
        cls = instantiate(cls)
        if isinstance(cls, str):
            cls_name = cls
            cls = locate(cls_name)
            assert cls is not None, cls_name
        else:
            try:
                cls_name = cls.__module__ + "." + cls.__qualname__
            except Exception:
                # target could be anything, so the above could fail
                cls_name = str(cls)
        assert callable(cls), f"_target_ {cls} does not define a callable object"
        try:
            return cls(**cfg)  ## 根据c
        except TypeError:
            logger = logging.getLogger(__name__)
            logger.error(f"Error when instantiating {cls_name}!")
            raise
    return cfg  # return as-is if don't know what to do

4. 基于LazyConfig的训练步骤

4.1 导入依赖库

import logging
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine import (
    AMPTrainer, # 自动混合精度训练
    SimpleTrainer, 
    default_argument_parser,
    default_setup, # 默认配置参数
    default_writers,
    hooks,
    launch, # 分布式训练启动器
)
from detectron2.engine.defaults import create_ddp_model
from detectron2.evaluation import inference_on_dataset, print_csv_format
from detectron2.utils import comm

4.2 日志初始化

logger = logging.getLogger("detectron2")

4.3 训练

def do_train(args, cfg):
    model = instantiate(cfg.model)  # 获取模型
    logger = logging.getLogger("detectron2") 
    logger.info("Model:\n{}".format(model)) 
    model.to(cfg.train.device)

    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer) # 获取优化器

    train_loader = instantiate(cfg.dataloader.train)# 获取训练dataloader

    model = create_ddp_model(model, **cfg.train.ddp) # 并行模型
    # 混合精度训练
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(  # checkpoint 管理
        model,
        cfg.train.output_dir,
        trainer=trainer,
    )
    trainer.register_hooks(  # 注册回调函数
        [
            hooks.IterationTimer(), # 计时器
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process() # 主进程 周期保存 checkpoint
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)), # 评估
            hooks.PeriodicWriter( # 保存训练日志
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)  # 初始化或者恢复训练
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)

4.4 评估

def do_test(cfg, model):
    if "evaluator" in cfg.dataloader:
        ret = inference_on_dataset(
            model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
        )
        print_csv_format(ret)
        return ret

4.5 训练流程

def main(args):
    cfg = LazyConfig.load(args.config_file) 
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args) # 默认日志,日志记录基础信息,备份配置文件

    if args.eval_only:
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        model = create_ddp_model(model)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint) # 加载权重
        print(do_test(cfg, model))
    else:
        do_train(args, cfg)

4.6 主函数入口

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch( # 启动多GPU训练
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank, # 当前节点ID 
        dist_url=args.dist_url,
        args=(args,),
    )

5. Tips

  • 像python代码一样操作配置文件,将相同的配置独立出来,导入进来,而不是复制多份
  • 尽可能的保存配置文件的简洁,不需要的不写

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/14262.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣160 - 相交链表【双指针妙解】

链表也能相交~一、题目描述二、思路分析与罗列三、整体代码展示四、总结与提炼一、题目描述 原题传送门 示例 1: 输入:intersectVal 8, listA [4,1,8,4,5], listB [5,6,1,8,4,5], skipA 2, skipB 3 输出:Intersected at ‘8’ 解释&…

MySQL索引

索引索引的相关概念索引分类索引的底层数据结构及其原理主键索引&二级索引聚集和非聚集索引哈西索引&&自适应哈西索引索引和慢查询日志索引优化索引的相关概念 什么是索引?索引其实就是一个数据结构。当表中的数据量到达几十万甚至上百万的时候&#x…

每个 Flutter 开发者都应该知道的一些原则

“仅仅让代码起作用是不够的。有效的代码经常被严重破坏。仅满足于工作代码的程序员表现得不专业。他们可能担心没有时间改进代码的结构和设计,但我不同意。没有什么比糟糕的代码对开发项目产生更深远、更长期的影响了。” ― Robert C. Martin,Clean Code:敏捷软件工艺手册…

fpga nvme 寄存器

图1所示的NVMe多队列,每个队列支持64K命令,最多支持64K队列。这些队列的设计使得IO命令和对命令的处理不仅可以在同一处理器内核上运行,也可以充分利用多核处理器的并行处理能力。每个应用程序或线程可以有自己的独立队列,因此不需…

基于Nacos的注册中心与配置中心

基于Nacos的注册中心与配置中心 Nacos简介 概述 Nacos全称是动态命名和配置服务,Nacos是一个更易于构建云原生应用的动态服务发现、配置管理和服务管理平台。Nacos主要用于发现、配置和管理微服务。 什么是Nacos Nacos支持几乎所有主流类型的服务的发现、配置和…

同花顺_代码解析_技术指标_A

本文通过对同花顺中现成代码进行解析,用以了解同花顺相关策略设计的思想 目录 ABI AD ADL ADR ADTM ADVOL AMV ARBR ARMS ASI ATR ABI 绝对幅度指标 算法:上涨家数减去下跌家数所得的差的绝对值。 该指标只适用于大盘日线。 行号 1 aa…

题目7飞机票订票系统

题目7飞机票订票系统问题描述:某公司每天有10航班(航班号、价格),每个航班的飞机,共有80个座位, 20排,每排4个位子。编号为A,BCD。如座位号:10D表示10排D座。 运行界面如下: 1)能从键盘录入订票信息:乘客的…

[Games 101] Lecture 13-16 Ray Tracing

Ray Tracing Why Ray Tracing 光栅化不能得到很好的全局光照效果 软阴影光线弹射超过一次(间接光照) 光栅化是一个快速的近似,但是质量较低 光线追踪是准确的,但是较慢 Rasterization: real-time, ray tracing: offline生成一帧…

狗屎一样的面试官,你遇到过几个?

做了几年软件开发,我们都或多或少面试过别人,或者被别人面试过。大家最常吐槽的就是面试造火箭,进厂拧螺丝。今天就来吐槽一下那些奇葩(gou)一样的面试官 A 那是在我刚工作1年的时候,出去面试前端开发。 那…

分布式开源存储架构Ceph概述

概述 k8s的后端存储中ceph应用较为广泛,当前的存储市场仍然是由一些行业巨头垄断,但在开源市场还是有一些不错的分布式存储,其中包括了Ceph、Swift、sheepdog、glusterfs等 什么是ceph? Ceph需要具有可靠性(reliab…

C++11标准模板(STL)- 算法(std::partition_point)

定义于头文件 <algorithm> 算法库提供大量用途的函数&#xff08;例如查找、排序、计数、操作&#xff09;&#xff0c;它们在元素范围上操作。注意范围定义为 [first, last) &#xff0c;其中 last 指代要查询或修改的最后元素的后一个元素。 定位已划分范围的划分点 …

线上崩了?一招教你快速定位问题。

&#x1f44f; 背景 正浏览着下班后去哪家店撸串&#xff0c;结果隔壁组同事囧着脸过来问我&#xff1a;大哥&#xff0c;赶紧过去帮忙看个问题&#xff01;客户反馈很多次了&#xff0c;一直找不出问题出在哪里&#xff01;&#xff01;&#xff01; 我&#xff1a;能不能有…

利用WPS功能破解及本地恢复密码

利用WPS功能破解及本地恢复密码 认识WPS功能 ​ WPS&#xff08;Wi-Fi Protected Setup&#xff09;是Wi-Fi保护设置的英文缩写。WPS是由Wi-Fi联盟组织实施的认证项目&#xff0c;主要致力于简化无线局域网安装及安全性能的配置工作。WPS并不是一项新增的安全性能&#xff0c;它…

数据结构之链表(单链表)

文章目录前言一、链表二、链表的八种结构1.单向或者双向2.带头或者不带头&#xff08;头&#xff1a;哨兵位&#xff09;3.循环或者不循环三、单链表1.接口2.接口的实现1.开辟一个新的节点1.打印单链表2.头插3.尾插4.头删5.尾删6.单链表的查找7.在pos位置之前插入数据8.在pos位…

MySQL8.0概述及新特性

文章目录学习资料常见的数据库管理系统排名&#xff08;DBMS&#xff09;SQL的分类DDL&#xff1a;数据定义语言DML&#xff1a;数据操作语言DCL&#xff1a;数据控制语言MySQL8.0新特性性能优化默认字符集DDL的原子化计算列宽度属性窗口函数公用表表达式索引新特性支持降序索引…

面试了20+前端大厂,整理出的面试题

事件是什么&#xff1f;事件模型&#xff1f; 事件是用户操作网页时发生的交互动作&#xff0c;比如 click/move&#xff0c; 事件除了用户触发的动作外&#xff0c;还可以是文档加载&#xff0c;窗口滚动和大小调整。事件被封装成一个 event 对象&#xff0c;包含了该事件发生…

RabbitMQ Windows 安装、配置、使用 - 小白教程

1、配套文件 下载erlang&#xff1a;http://www.erlang.org/downloads/ 下载RabbitMQ&#xff1a;http://www.rabbitmq.com/download.html 2、RabbitMQ服务端代码是使用并发式语言Erlang编写的&#xff0c;安装Rabbit MQ的前提是安装Erlang&#xff0c;双击otp_win64_21.1.ex…

计算机毕业设计springboot+vue+elementUI汽车车辆充电桩管理系统

项目介绍 随着我国汽车行业的不断发展&#xff0c;电动汽车已经开始逐步的领导整个汽车行业&#xff0c;越来越多的人在追求环保和经济实惠的同时开始使用电动汽车&#xff0c;电动汽车和燃油汽车最大的而不同就是 需要充电&#xff0c;同时我国的基础充电桩也开始遍及了大多数…

Java 异常处理

目录 一、异常的基本概念 二 、为何需要异常处理 三 、异常的处理 四 、异常类的继承架构 五 、抛出异常 5.1、程序中抛出异常 5.2、指定方法抛出异常 六 、自定义异常 不管使用的那种语言进行程序设计&#xff0c;都会产生各种各样的错误。 Java 提供有强大的异常处理…

商业银行普惠金融可持续发展综合能力呈现梯队化,专项领域各有所长

易观分析&#xff1a;普惠金融有别于传统的金融体系&#xff0c;强调构建包容性、公平性的金融服务生态&#xff0c;商业银行提升可持续发展的综合能力需关注五个方面的因素&#xff1a;获客能力上以普惠客群的金融需求为锚点&#xff0c;增强银行服务生态的多样性&#xff0c;…