[拆轮子] PaddleDetection中__shared__、__inject__ 和 from_config 三者分别做了什么

news2024/7/11 11:15:08

在上一篇中,PaddleDetection Register装饰器到底做了什么
https://blog.csdn.net/HaoZiHuang/article/details/128668393

已经介绍了 __shared____inject__ 的作用:

  • __inject__ 表示引入全局字典中已经封装好的模块。如loss等。
  • __shared__为了实现一些参数的配置全局共享,这些参数可以被backbone, neck,head,loss等所有注册模块共享。

PaddleDetection 文档是这么说的,可是我还是不太懂。于是看了下源码,建议先看上边那篇文章,里边写了在哪部分 __inject__ 列表 和 __shared__列表被读取的。

标题中的三者都是在 ppdet/core/workspace.pycreate 函数使用的,create 函数用于创建已经被 Register装饰的注册过的类

1. __shared__ 部分

在 create 函数中先进行有效性检验, cls_or_name 可以是类别名称的字符串,也可以是已经写好的类,但在 PaddleDetection 当前版本内容,大概率只是字符串

    assert type(cls_or_name) in [type, str
                                 ], "should be a class or name of a class"
    name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
    if name in global_config:
        if isinstance(global_config[name], SchemaDict):
        	# 如果 cls_or_name 这个类已经注册,则 global_config.values 元素是 SchemaDict
            pass
            
        elif hasattr(global_config[name], "__dict__"):
            # support instance return directly
            # 如果有 __dict__ 则直接返回hhhh( 当前版本用的不多 )
            return global_config[name]
            
        else:
            raise ValueError("The module {} is not registered".format(name))
    else:
        raise ValueError("The module {} is not registered".format(name))

之后解析 __shared__ 列表中的内容

    # parse `shared` annoation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:
            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and \
                   not isinstance(target_key, SharedConfig):
                continue  # 如果当前当前 target_key 不是SharedConfig, 那么参数已被传入
			
			# 
            elif shared_conf.key in global_config:
                # `key` is present in config
                cls_kwargs[k] = global_config[shared_conf.key]  # 必须在全局设置! __shared__ (num_classes之类的)
            else:
                cls_kwargs[k] = shared_conf.default_value       # 否则就搞默认的

而之后的几行如果在全局配置过,比如这样:
在这里插入图片描述
则读取全局配置的内容

2. from_config 部分

之后执行:

    if getattr(cls, 'from_config', None):
        cls_kwargs.update(cls.from_config(config, **kwargs))

由于 backbone neck head 之间的配置可能存在耦合,于是部分类实例化时,可能需要之前模块的配置,所以要在 architecture 初始化时,创建 neck head 之类的

给个例子看吧,transformer 和 detr_head 创建时除了读取之前 config 的内容,也传入了来自前置模块的内容

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        # backbone
        backbone = create(cfg['backbone'])
        # transformer
        kwargs = {'input_shape': backbone.out_shape}
        transformer = create(cfg['transformer'], **kwargs)
        # head
        kwargs = {
            'hidden_dim': transformer.hidden_dim,
            'nhead': transformer.nhead,
            'input_shape': backbone.out_shape
        }
        detr_head = create(cfg['detr_head'], **kwargs)

        return {
            'backbone': backbone,
            'transformer': transformer,
            "detr_head": detr_head,
        }

3. __inject__ 部分

__inject__ 部分其实与 from_config 很像,都是将类实例化为对象,来看一小部分

在这里插入图片描述

k'loss',之前在 __inject__ 列表中
target_key'DETRLoss' 是一个字符串

	target_key = config[k]
	......
	
    elif isinstance(target_key, str):
        if target_key not in global_config:
            raise ValueError("Missing injection config:", target_key)
        target = global_config[target_key]
        if isinstance(target, SchemaDict):
            cls_kwargs[k] = create(target_key)   # 在此处将类实例化
        elif hasattr(target, '__dict__'):  # serialized object
            cls_kwargs[k] = target

可以看到 from_config 是由于组件之间存在参数耦合,要在前者创建完毕后,将部分参数传给后者,所以要借助 create API 手动实例化

__inject__ 的使用很简单,只许在 __inject__ 中指定对应的参数即可,如上图中指定了 loss 部分,而 loss 参数是 DETRLoss,于是 loss 传入后是一个 实例化的 DETRLoss 对象

4. 附录 create 函数源码

def create(cls_or_name, **kwargs):
    """
    Create an instance of given module class.

    Args:
        cls_or_name (type or str): Class of which to create instance.

    Returns: instance of type `cls_or_name`
    """
    assert type(cls_or_name) in [type, str
                                 ], "should be a class or name of a class"
    name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
    if name in global_config:
        if isinstance(global_config[name], SchemaDict):
            pass
        elif hasattr(global_config[name], "__dict__"):
            # support instance return directly
            return global_config[name]
        else:
            raise ValueError("The module {} is not registered".format(name))
    else:
        raise ValueError("The module {} is not registered".format(name))

    config = global_config[name]
    cls = getattr(config.pymodule, name)
    cls_kwargs = {}
    cls_kwargs.update(global_config[name])

    # parse `shared` annoation of registered modules
    if getattr(config, 'shared', None):
        for k in config.shared:
            target_key = config[k]
            shared_conf = config.schema[k].default
            assert isinstance(shared_conf, SharedConfig)
            if target_key is not None and not isinstance(target_key,
                                                         SharedConfig): # 如果我指定则就传入指定的
                continue  # value is given for the module
            elif shared_conf.key in global_config:
                # `key` is present in config
                cls_kwargs[k] = global_config[shared_conf.key]  # 必须在全局设置! __shared__ (num_classes之类的)
            else:
                cls_kwargs[k] = shared_conf.default_value       # 否则就搞默认的

    # parse `inject` annoation of registered modules
    if getattr(cls, 'from_config', None):
        cls_kwargs.update(cls.from_config(config, **kwargs))

    if getattr(config, 'inject', None):
        for k in config.inject:
            target_key = config[k]
            # optional dependency
            if target_key is None:
                continue

            if isinstance(target_key, dict) or hasattr(target_key, '__dict__'):
                if 'name' not in target_key.keys():
                    continue
                inject_name = str(target_key['name'])
                if inject_name not in global_config:
                    raise ValueError(
                        "Missing injection name {} and check it's name in cfg file".
                        format(k))
                target = global_config[inject_name]
                for i, v in target_key.items():
                    if i == 'name':
                        continue
                    target[i] = v
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(inject_name)
            elif isinstance(target_key, str):
                if target_key not in global_config:
                    raise ValueError("Missing injection config:", target_key)
                target = global_config[target_key]
                if isinstance(target, SchemaDict):
                    cls_kwargs[k] = create(target_key)
                elif hasattr(target, '__dict__'):  # serialized object
                    cls_kwargs[k] = target
            else:
                raise ValueError("Unsupported injection type:", target_key)
    # prevent modification of global config values of reference types
    # (e.g., list, dict) from within the created module instances
    #kwargs = copy.deepcopy(kwargs)
    return cls(**cls_kwargs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/165974.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

excel函数技巧:函数TEXT七助数据大变身

如果函数有职业,那各函数的职业会是什么呢?别的先不说,就拿TEXT而言,它可以让日期变数字、数字变日期、阿拉伯数字变大写中文数字、金额元变万元,连IF的条件判断它也可以变出来…这简直就是当之无愧的变装女皇啊&#…

从0到1完成一个Node后端(express)项目(三、写接口、发起请求)

往期 从0到1完成一个Node后端(express)项目(一、初始化项目、安装nodemon) 从0到1完成一个Node后端(express)项目(二、下载数据库、navicat、express连接数据库) 写接口 我们看ex…

关于Linux部署Tomcat的访问问题

文章目录1.问题2.排除问题2.1检查Tomcat是否启动2.2检查防火墙&端口3.其他可能的问题3.1java的配置问题3.2可能出现了端口占用问题1.问题 在CentOS7系统的主机中配置好了Tomcat后发现通过默认端口无法访问到(http://xx:xx:xx:xx:8080) 2.排除问题 …

C语言在杨氏矩阵中找一个数

这道题大家都会做,使用暴力算法遍历整个数组。但是题目要求时间复杂度小于O(n),这样做显然不合题意,所以,通过分析杨氏矩阵的特点,我们发现矩阵右上角的那个数为一行中最大的,一列中…

SAP MM 新建移动类型(Movement Type)

一、概念 物料的移动类型(Movement Type)代表了货物的移动,当一个物料做某种移动时,便开始了如下一系列事件: 1、一个物料凭证会被创建,可以被用来作为移动的证明及作为其它任何相关应用的一个信息来源&am…

Jetson nano 入手系列之6—使用qt creator 开发c++ opencv+CSI摄像头人脸检测

Jetson nano 入手系列之6—使用qt creator 开发c opencvCSI摄像头人脸检测1.创建摄像头人脸检测项目1.1 创建并配置项目1.2 编辑文件1.2.1 main.cpp1.2.2 CMakeLists.txt2.构建及编译2.1 直接使用qt creator完成2.2 使用命令行参考文献本系列针对亚博科技jetson nano开发板。 …

一篇文章带你学会MySQL数据库的基本管理

目录 前言 一、数据库的介绍 二、mariadb的安装 三、数据库的开启及安全初始化 四、数据库的基本管理 五、数据库密码更改及破解 六、用户授权 七、数据库的备份 八、phpmyadmin的安装 总结 前言 什么是数据库? 每个人家里都会有衣柜,衣柜是…

前端效果积累 | 酷炫、实用3D地球路径飞行效果实现

📌个人主页:个人主页 ​🧀 推荐专栏:前端开发成神之路 --【这是一个为想要入门和进阶前端开发专门开启的精品专栏!从个人到商业的全套开发教程,实打实的干货分享,确定不来看看? &…

【C语言进阶】自定义类型之结构体

目录一:结构体1.1:结构的基础知识: 1.2:结构的声明: 1.3:特殊声明(匿名结构体): 1.4:结构的自引用: 1.5:结构体变量的定义和初始化&am…

springboot 项目自定义log日志文件提示系统找不到指定的文件

自己尝试搭建了一个springboot项目,自定义了log日志文件,启动后报错 Logging system failed to initialize using configuration from logback-spring.xml java.io.FileNotFoundException: E:\code_demo\xxxx\logback-spring.xml (系统找不到指定的文件…

Elasticsearch(二)--Elasticsearch客户端讲解

一、前言 在上一章我们大致了解了下elasticsearch,虽说上次的内容全是八股文,但是很多东西还是非常有用的,这些哪怕往小说作为面试,往大说是可以帮你很快的理解es是个什么玩意儿,所以还是非常推荐大家去看一下上一章内容。 这一章…

【C++】map和set的使用

​🌠 作者:阿亮joy. 🎆专栏:《吃透西嘎嘎》 🎇 座右铭:每个优秀的人都有一段沉默的时光,那段时光是付出了很多努力却得不到结果的日子,我们把它叫做扎根 目录👉关联式容…

码二哥的技术专栏 总入口

已发表的技术专栏(订阅即可观看所有专栏) 0  grpc-go、protobuf、multus-cni 技术专栏 总入口 1  grpc-go 源码剖析与实战  文章目录 2  Protobuf介绍与实战 图文专栏  文章目录 3  multus-cni   文章目录(k8s多网络实现方案) 4  gr…

JVM整理笔记之测试工具JCStress的使用及其注解的应用

文章目录前言如何使用JCStress测试代码JCStress 注解说明前言 如果要研究高并发,一般会借助高并发工具来进行测试。JCStress(Java Concurrency Stress)它是OpenJDK中的一个高并发测试工具,它可以帮助我们研究在高并发场景下JVM&a…

RecyclerView 倒计时和正计时方案

本章内容一.方案制定二.设计三.编码相信不少同学都会在这里栽跟头,在思考这个问题设计了两套方案,而我的项目需求中需要根据业务是否反馈来进行倒计时和正计时的操作。一.方案制定 1.在Adapter中使用CountDownTimer 2.修改数据源更新数据 3.只修改页面展…

leetcode--各种数据结构相关的题

数据结构1.数组(1)找到所有数组中消失的数字(448)(2)旋转图像(48)(3)搜索二维矩阵 II(240)(4)最多能完成排序的块(769)2.栈和队列(1)用栈实现队列(232)&#…

“链引擎”(PBC)计划 | 太保集团长安链应用展示

引言 长安链“链引擎”计划(Powered by Chainmaker)(简称:PBC计划)是由长安链生态联盟发起的一项应用赋能计划,旨在以长安链技术体系为核心支撑,汇聚产业各方力量,为应用方提供技术、品牌、生态等支持&…

面试干货!初级软件测试面试题及答案题库一起奉上

软件测试工程师面试通常要经历技术面以及HR面,HR面一般都是日常问题,面试人可以临场发挥过去,但关乎岗位职责的技术面,可就没那么容易了,尤其是对于很多初次去面试测试岗位的没有任何测试岗位面试经验的转行人员&#…

Java并发面试题

基础知识 并发编程的优缺点 为什么要使用并发编程(并发编程的优点) 充分利用多核CPU的计算能力:通过并发编程的形式可以将多核CPU的计算能力发挥到极致,性能得到提升方便进行业务拆分,提升系统并发能力和性能&#…

【网络安全】内网介绍+windows信息收集(含命令)

目录 前言 一、内网渗透测试是什么? 1.介绍 2.内外网区别 3.工作组是什么? 4.域是什么? 5.域的知识点 6.活动目录 7.活动目录主要功能 8.域权限 二、windows信息收集 (1)系统信息 (2&#xff0…