RT-DETR代码详解(官方pytorch版)——参数配置(1)

news2025/1/13 2:48:06

前言

RT-DETR虽然是DETR系列,但是它的代码结构和之前的DETR系列代码不一样。

它是通过很多的yaml文件进行参数配置,和之前在train.py的parser = argparse.ArgumentParser()去配置所有参数不同,所以刚开始不熟悉代码的时候可能不知道在哪儿修改参数。

RT-DETR有官方版和ultralytics版两个版本代码,可以参考以下链接,分别使用两种方法对代码进行复现:
详解RT-DETR网络结构/数据集获取/环境搭建/训练/推理/验证/导出/部署_rt-dert-CSDN博客

下述内容主要是针对参数配置的代码实现进行解读,因为刚开始我拿着代码都不知道是怎么运行的,模型在哪儿加载参数都找不到

一、train.py文件

在RT-DETR中,train.py文件需要配置的内容很少,因为需要的参数配置全都放在了rtdetr_rxxvd_6x_coco.yml(骨干网络可选)文件中。在这个文件中又包含了其他所有的文件,可以依需修改:

左边是可以选择的backbone骨干网络,后续以ResNet18为例。

二、rtdetr_r18vd_6x_coco.yaml文件


__include__: [
  '../dataset/coco_detection.yml',  # 数据集
  '../runtime.yml', # 运行参数配置
  './include/dataloader.yml', # 定义数据加载器参数
  './include/optimizer.yml', # 定义优化器通用设置
  './include/rtdetr_r50vd.yml', # 定义 RT-DETR 模型的结构参数(如 backbone 和解码器层数等
]


output_dir: ./output/rtdetr_r18vd_6x_coco  # 输出的文件地址

PResNet:
  depth: 18
  freeze_at: -1 # 不冻结任何层(如果设置为正数,则冻结 ResNet 的前几层)
  freeze_norm: False # 不冻结归一化层(如 BatchNorm)
  pretrained: True # 加载预训练权重(通常是基于 ImageNet 数据集的权重)

HybridEncoder:
  in_channels: [128, 256, 512] # 编码器的输入特征通道数,分别对应 ResNet-18 不同尺度的特征图输出
  hidden_dim: 256
  expansion: 0.5 # 特征通道扩展比例


RTDETRTransformer:
  eval_idx: -1 # 指定在哪一层解码器输出进行评估(-1 表示最后一层)
  num_decoder_layers: 3 # 解码器的层数
  num_denoising: 100  # 去噪查询的数量



optimizer:
  type: AdamW # 该优化器改进了 Adam,支持权重衰减以减轻过拟合
  params:  # 参数分组,针对不同模块的参数设置不同的学习率和权重衰减
    - 
      params: '^(?=.*backbone)(?=.*norm).*$'      # 匹配骨干网络中的归一化层参数,设置较低学习率和无权重衰减
      lr: 0.00001
      weight_decay: 0.
    - 
      params: '^(?=.*backbone)(?!.*norm).*$'      # 匹配骨干网络中非归一化参数
      lr: 0.00001
    - 
      params: '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bias)).*$'   # 匹配 Transformer 中归一化层或偏置参数
      weight_decay: 0.

  lr: 0.0001
  betas: [0.9, 0.999] # Adam 优化器的 beta 参数
  weight_decay: 0.0001 # 权重衰减值

上面的注释只是为了解释各行代码意思,但是运行代码过程中,yaml文件不能有注释,否则会报错:

三、yaml_config.py文件

 在train.py文件中,实际是通过YAMLConfig()这个类读取rtdetr_r18vd_6x_coco.yaml中的配置信息。通过加载 YAML 配置文件,将不同的模型、优化器、数据加载器等组件以模块化的方式创建

 主要功能

1. 动态加载 YAML 配置文件

  • 使用 load_config 函数加载 YAML 文件,读取其中的配置数据。
  • 支持通过 merge_dict 将命令行或其他来源的参数覆盖 YAML 文件中的默认配置。

2. 组件动态创建

  • 根据 YAML 文件的配置,动态创建模型(model)、损失函数(criterion)、优化器(optimizer)、学习率调度器(lr_scheduler)和数据加载器(dataloader)等。

3. 参数分组和正则匹配

  • 支持为优化器指定不同模块的参数组,并通过正则表达式选择分组的参数。

4. 支持扩展功能

  • 支持 EMA(Exponential Moving Average,指数滑动平均) 和 AMP(Automatic Mixed Precision,自动混合精度)
  • 自动处理模型参数的冻结、梯度裁剪等功能。

5. 模块化设计

  • 配置组件通过 create 函数动态实例化,便于扩展和自定义。

3.1 类初始化与加载配置

class YAMLConfig(BaseConfig):
    def __init__(self, cfg_path: str, **kwargs) -> None:
        super().__init__()
        cfg = load_config(cfg_path)  # 加载 YAML 配置文件
        merge_dict(cfg, kwargs)  # 合并外部输入的参数(高优先级)

        self.yaml_cfg = cfg  # 保存解析后的 YAML 配置

        # 一些常见配置的提取
        self.log_step = cfg.get('log_step', 100)
        self.checkpoint_step = cfg.get('checkpoint_step', 1)
        self.epoches = cfg.get('epoches', -1)
        self.resume = cfg.get('resume', '')
        self.tuning = cfg.get('tuning', '')
        self.sync_bn = cfg.get('sync_bn', False)
        self.output_dir = cfg.get('output_dir', None)
        self.use_ema = cfg.get('use_ema', False)
        self.use_amp = cfg.get('use_amp', False)
        self.autocast = cfg.get('autocast', dict())
        self.find_unused_parameters = cfg.get('find_unused_parameters', None)
        self.clip_max_norm = cfg.get('clip_max_norm', 0.0)
  • 功能
    • 从 YAML 配置文件中加载配置,初始化训练流程中常用的参数。
    • cfg_path:YAML 配置文件路径。
    • kwargs:支持通过外部传入参数(如命令行参数)覆盖 YAML 中的默认配置
    • 使用 get 方法设置默认值,避免配置文件缺失某些字段时程序报错。

 3.1.1 yaml_config.py文件

  通过cfg = load_config(cfg_path)已经将所有的配置信息传递给cfg了

尽管传入的只有一个rtdetr_r18vd_6x_coco.yaml文件,但它里面包含了其他的配置文件地址:

load_config()函数在yaml_utils.py文件中


def load_config(file_path, cfg=dict()):
    """
    加载 YAML 配置文件,并支持递归加载包含的其他 YAML 文件。
    Args:
        file_path (str): 要加载的 YAML 文件路径。
        cfg (dict): 全局配置字典,默认为空字典。
    Returns:
        dict: 加载并合并后的配置字典。
    """
    # 获取文件扩展名并确保是 YAML 文件
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "仅支持 YAML 文件(.yml 或 .yaml)"

    # 打开并加载 YAML 文件
    with open(file_path, 'r') as f:
        file_cfg = yaml.load(f, Loader=yaml.Loader)
        if file_cfg is None:
            return {}  # 如果文件为空,则返回空字典

    # 检查是否需要加载包含的 YAML 配置(递归加载)
    if INCLUDE_KEY in file_cfg:
        # 提取 'include' 键的值,通常是其他 YAML 文件路径的列表
        base_yamls = list(file_cfg[INCLUDE_KEY])
        for base_yaml in base_yamls:
            # 将路径展开为完整路径(支持用户目录 ~ 和相对路径)
            if base_yaml.startswith('~'):
                base_yaml = os.path.expanduser(base_yaml)
            if not base_yaml.startswith('/'):  # 如果是相对路径
                base_yaml = os.path.join(os.path.dirname(file_path), base_yaml)

            # 递归加载被包含的 YAML 文件
            base_cfg = load_config(base_yaml, cfg)
            # 合并当前加载的配置到全局配置中
            merge_config(base_cfg, cfg)

    # 最终合并当前文件的配置到全局配置中
    return merge_config(file_cfg, cfg)

  • 通过 include 字段,可以将配置拆分成多个 YAML 文件,便于管理和维护。
  • 支持递归加载多个 YAML 文件,并通过 merge_config 实现配置合并,确保最终配置完整。

  

 3.2 动态加载组件(如模型、优化器等)

 通 @property 装饰器延迟加载组件,仅在实际使用时创建对象

@property装饰器

是 Python 的一个内置装饰器,常用于定义一个类的方法,并将其伪装成“属性”。

  1. 保护类的封装特性
  2. 让开发者可以使用“对象.属性”的方式操作操作类属性

通过 @property 装饰器,可以直接通过方法名来访问方法,不需要在方法名后添加一对“()”小括号。

语法格式:

@property
def 方法名(self)
    代码块

更多@property装饰器内容可看,其中包含延时加载的应用:@property装饰器-CSDN博客

 3.2.1 模型加载

@property
def model(self) -> torch.nn.Module:
    if self._model is None and 'model' in self.yaml_cfg:
        merge_config(self.yaml_cfg)  # 合并全局配置
        self._model = create(self.yaml_cfg['model'])  # 动态创建模型
    return self._model
  • 检查 _model 是否已经创建,若未创建且配置中包含 model 字段,则动态创建模型。(self.yaml_cfg已经存储了所有的配置信息,见3.1.1 图,提取model键的值)
  • 使用 create 函数按照 yaml_cfg['model'] 中的定义实例化模型。

在rtdetr_r18vd_6x_coco.yml--->./include/rtdetr_r50vd.yml中 :

3.2.2 优化器延迟加载

@property
def optimizer(self):
    if self._optimizer is None and 'optimizer' in self.yaml_cfg:
        merge_config(self.yaml_cfg)  # 合并全局配置
        params = self.get_optim_params(self.yaml_cfg['optimizer'], self.model)  # 获取参数分组
        self._optimizer = create('optimizer', params=params)  # 动态创建优化器
    return self._optimizer
  • 获取优化器参数分组(get_optim_params),根据配置动态创建优化器实例。

3.2.3  学习率调度器加载

@property
def lr_scheduler(self):
    if self._lr_scheduler is None and 'lr_scheduler' in self.yaml_cfg:
        merge_config(self.yaml_cfg)
        self._lr_scheduler = create('lr_scheduler', optimizer=self.optimizer)
        print('Initial lr: ', self._lr_scheduler.get_last_lr())
    return self._lr_scheduler
  • 动态创建学习率调度器对象,并与优化器绑定

在rtdetr_r18vd_6x_coco.yml--->./include/optimizer.yml中 :

基于MultiStepLR生成对应的学习率调度器

  • MultiStepLR 是 PyTorch 中 torch.optim.lr_scheduler 提供的一种学习率调度器
  • 它会在指定的训练步骤(milestones)调整学习率

根据配置,初始学习率为 0.1在第 1000 步时,学习率会乘以 gamma=0.1,变为 0.01。输出如下:

Step 0: Learning Rate = 0.1
Step 500: Learning Rate = 0.1
Step 1000: Learning Rate = 0.01
Step 1500: Learning Rate = 0.01

3.3 数据加载器

@property
def train_dataloader(self):
    if self._train_dataloader is None and 'train_dataloader' in self.yaml_cfg:
        merge_config(self.yaml_cfg)
        self._train_dataloader = create('train_dataloader')
        self._train_dataloader.shuffle = self.yaml_cfg['train_dataloader'].get('shuffle', False)
    return self._train_dataloader
  • 动态加载训练数据加载器,并根据配置调整 shuffle 参数

3.4 参数分组(正则表达式匹配)

@staticmethod
def get_optim_params(cfg: dict, model: nn.Module):
    '''
    E.g.:
        ^(?=.*a)(?=.*b).*$         means including a and b
        ^((?!b.)*a((?!b).)*$       means including a but not b
        ^((?!b|c).)*a((?!b|c).)*$  means including a but not (b | c)
    '''
    assert 'type' in cfg, ''
    cfg = copy.deepcopy(cfg)

    if 'params' not in cfg:
        return model.parameters()  # 如果未定义参数分组,返回默认模型参数

    assert isinstance(cfg['params'], list), ''

    param_groups = []
    visited = []
    for pg in cfg['params']:
        pattern = pg['params']
        params = {k: v for k, v in model.named_parameters() if v.requires_grad and len(re.findall(pattern, k)) > 0}
        pg['params'] = params.values()
        param_groups.append(pg)
        visited.extend(list(params.keys()))

    names = [k for k, v in model.named_parameters() if v.requires_grad]

    if len(visited) < len(names):
        unseen = set(names) - set(visited)
        params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
        param_groups.append({'params': params.values()})
        visited.extend(list(params.keys()))

    assert len(visited) == len(names), ''
    return param_groups
  • 根据正则表达式匹配模型中的参数(named_parameters 方法返回 <参数名, 参数> 的映射)。
  • 支持按模块或特定规则分组优化器参数(如设置不同学习率、权重衰减)。
  • 未匹配的参数会自动归为默认组。

  • ^(?=.*backbone)(?=.*norm).*$:匹配键名中包含 backbone 和 norm 的参数。
  • ^(?=.*encoder)(?!.*bias).*$:匹配键名中包含 encoder 且不包含 bias 的参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2275763.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

细说STM32F407单片机以DMA方式读写外部SRAM的方法

目录 一、工程配置 1、时钟、DEBUG、GPIO、CodeGenerator 2、USART3 3、NVIC 4、 FSMC 5、DMA 2 &#xff08;1&#xff09;创建MemToMem类型DMA流 &#xff08;2&#xff09;开启DMA流的中断 二、软件设计 1、KEYLED 2、fsmc.h、fsmc.c、dma.h、dma.c 3、main.h…

Proteus-8086调试汇编格式的一点心得

这阵子开始做汇编的微机实验&#xff08;微机原理与接口技术题解及实验指导&#xff0c;吴宁版本13章&#xff09;&#xff0c;中间出了挺多问题&#xff0c;解决后记录下。 先上电路图 用子电路来仿真发现仿真的时候子电路这块根本没有高低电平输出&#xff0c;只好把子电路拿…

FreeROTS学习 内存管理

内存管理是一个系统基本组成部分&#xff0c;FreeRTOS 中大量使用到了内存管理&#xff0c;比如创建任务、信号量、队列等会自动从堆中申请内存&#xff0c;用户应用层代码也可以 FreeRTOS 提供的内存管理函数来申请和释放内存 FreeRTOS 内存管理简介 FreeRTOS 创建任务、队列…

【西北工业大学主办 | EI检索稳定 | 高H值专家与会报告】2025年航天航空工程与材料技术国际会议(AEMT 2025)

2025 年航天航空工程与材料技术国际会议&#xff08;AEMT 2025&#xff09;将于2025年2月28日至3月2日在中国天津召开。本届会议由西北工业大学主办&#xff0c;由北京航空航天大学、北京理工大学作为支持单位加入&#xff0c;AEIC 学术交流中心协办。 AEMT 2025 旨在汇聚来自全…

目标检测跟踪中的Siamese孪生网络与普通卷积网络(VGG、ResNet)有什么区别?

1、什么是Siamese网络&#xff1f; Siamese网络又叫孪生网络&#xff0c;是一种特殊的神经网络架构&#xff0c;由一对&#xff08;或多对&#xff09;共享参数的子网络组成&#xff0c;用于学习输入样本之间的相似性或关系。最早在 1994 年由 Bromley 等人提出&#xff0c;最…

网络攻击行为可视化分析系统【数据分析 + 可视化】

一、系统背景 随着信息技术的快速发展&#xff0c;网络已成为现代社会不可或缺的一部分。然而&#xff0c;与此同时&#xff0c;网络攻击手段也日益多样化和复杂化&#xff0c;给企业和个人的信息安全带来了极大的威胁。传统的网络攻击分析方法往往依赖于人工分析和处理大量的…

一个运行在浏览器中的开源Web操作系统Puter本地部署与远程访问

文章目录 前言1.关于Puter2.本地部署Puter3.Puter简单使用4. 安装内网穿透5.配置puter公网地址6. 配置固定公网地址 &#x1f4a1; 推荐 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。【点击跳转到网站…

C语言 操作符_位操作符、赋值操作符、单目操作符

1.位操作符 & - 按&#xff08;2进制&#xff09;位与 | - 按&#xff08;2进制&#xff09;位或 ^ - 按&#xff08;2进制&#xff09;位异或 只适用于整型 例&#xff1a;实现交换两个变量的值&#xff0c;要求不能新建变量 //3^3 0 -> a^a 0 //011 //011 //000 …

图像处理 | 图像二值化

在图像处理领域&#xff0c;图像二值化是一个重要的操作&#xff0c;它将彩色或灰度图像转换为只有两种颜色&#xff08;通常是黑白&#xff09;的图像。二值化广泛应用于文字识别、图像分割、边缘检测等领域&#xff0c;尤其在处理简洁和高对比度的图像时非常有效。本文将深入…

IP 地址与蜜罐技术

基于IP的地址的蜜罐技术是一种主动防御策略&#xff0c;它能够通过在网络上布置的一些看似正常没问题的IP地址来吸引恶意者的注意&#xff0c;将恶意者引导到预先布置好的伪装的目标之中。 如何实现蜜罐技术 当恶意攻击者在网络中四处扫描&#xff0c;寻找可入侵的目标时&…

Web基础之什么是HTTP协议

Q&#xff1a;什么是HTTP协议&#xff1f; 概念&#xff1a;Hyper Text Transfer Protocol&#xff0c;超文本传输协议&#xff0c;规定了浏览器和服务器之间数据传输的规则。 特点&#xff1a; 1&#xff0e;基于TCP协议&#xff1a;面向连接&#xff0c;安全 2&#xff0e;基…

#渗透测试#谷歌扩展学习#编写一个属于自己的谷歌扩展

目录 一、Chrome扩展程序是什么 二、如何自己编写一个简单谷歌扩展 1. 创建项目文件夹 2. 创建 manifest.json 文件 3. 创建 popup.html 文件 4. 创建 popup.js 文件 5. 加载扩展程序到Chrome浏览器 6. 测试扩展程序 三、Chrome插件图标设计技巧 1. 简洁明了 2. 独特…

LayerNorm的思考

文章目录 1. LayerNorm2. 图解3. softmax4. python 代码 1. LayerNorm y x − E [ x ] v a r ( x ) ϵ ∗ γ β \begin{equation} y\frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{var}(x)\epsilon}}*\gamma\beta \end{equation} yvar(x)ϵ ​x−E[x]​∗γβ​​ 2. 图解 矩阵A …

ExplaineR:集成K-means聚类算法的SHAP可解释性分析 | 可视化混淆矩阵、决策曲线、模型评估与各类SHAP图

集成K-means聚类算法的SHAP可解释性分析 加载数据集并训练机器学习模型 SHAP 分析以提取特征对预测的影响 通过混淆矩阵可视化模型性能 决策曲线分析 模型评估&#xff08;多指标和ROC曲线的目视检查&#xff09; 带注释阈值的 ROC 曲线 加载 SHAP 结果以进行下游分析 与…

Kafka 会丢消息吗?

目录 01 生产者(Producer) 02 消息代理(Broker) 03 消费者(Consumer) 来源:Kafka 会丢消息吗? Kafka 会丢失信息吗? 许多开发人员普遍认为,Kafka 的设计本身就能保证不会丢失消息。然而,Kafka 架构和配置的细微差别会导致消息的丢失。我们需要了解它如何以及何时…

Open FPV VTX开源之第一次出图

Open FPV VTX开源之第一次出图 1. 源由2. 连线2.1 飞控2.2 调试 3. serial3.1 启动log - uboot3.2 登录版本 - linux3.3 获取有线IP 4. ssh - linux5. PixelPilot出图6. 总结7. 参考资料8. 补充 - 8812AU网卡 1. 源由 在《Open FPV VTX开源之硬件规格及组成》章节中&#xff0…

仓颉笔记——写一个简易的web服务并用浏览器打开

创建一个web服务端&#xff0c;同时创建一个客户端去读取这个服务端。 也满足浏览器打开web的需求。 直接上代码。 import net.http.* import std.time.* import std.sync.* import std.log.LogLevel// 1. 构建 Server 实例 let server ServerBuilder().addr("127.0.0.1&…

Trie树算法

Trie树&#xff0c;也称为前缀树或字典树&#xff0c;是一种特殊的树型数据结构。它用于存储一组字符串&#xff0c;使得查找、插入和删除字符串的操作非常高效。类似这种&#xff0c; 模板&#xff1a; 这是用数组来模拟上图中的树的结构&#xff0c;逻辑上和上图结构一致。 …

03-51单片机定时器和串口通信

一、51单片机定时器 1.定时器介绍 1.1为什么要使用定时器 在前面的学习中&#xff0c;用到了 Delay 函数延时&#xff0c;这里学习定时器以后&#xff0c;就可以通过定时器来完成&#xff0c;当然定时器的功能远不止这些&#xff1a; 51 单片机的定时器既可以定时&#xff…

搭建docker私有化仓库Harbor

Docker私有仓库概述 Docker私有仓库介绍 Docker私有仓库是个人、组织或企业内部用于存储和管理Docker镜像的存储库。Docker默认会有一个公共的仓库Docker Hub,而与Docker Hub不同,私有仓库是受限访问的,只有授权用户才能够上传、下载和管理其中的镜像。这种私有仓库可以部…