RTMDet原理与代码解析

news2024/11/15 10:58:23

paper:RTMDet: An Empirical Study of Designing Real-Time Object Detectors

official implementation:https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet

本文的创新点

Backbone and Neck

  1. 在backbone的basic building block中采用large-kernel depth-wise convolution,提高了模型捕获全局上下文的能力。
  2. 直接添加depth-wise卷积会增加模型深度,减慢推理速度。因此通过减小building block的个数来减小模型深度,同时通过增加模型宽度来补偿模型容量。
  3. 作者还观察到,在neck部分设置更多参数,使其容量与backbone兼容,可以实现更好的速度-精度的平衡。

论文中提到的增加模型宽度,以及通过增加neck部分的expand ratio来增加neck部分的参数在代码中都没有看到。

block的结构设计中采用大核深度可分离卷积,而没有采用重参数化方法的原因:其它YOLO系列中结构重参数化技术被广泛使用,但有一些副作用,比如训练速度变慢,增加训练占用显存,以及在量化到较低的比特后增加了误差间隙,这需要通过重参数化优化器或量化感知训练来补偿。

Head

  1. 不同尺度的head之间共享卷积参数,但BN层独立。

Label Assignment and Loss

提出在计算matching cost时使用soft label来扩大高质量匹配和低质量配置之间的差异,从而稳定训练加速收敛。基于SimOTA进行的改动。

  1. 分类损失引入soft label,就是GFL
  2. 回归损失添加log,增大了低质量匹配的cost,增大了高质量匹配和低质量匹配之间的差异。
  3. center损失采用soft center region cost。

Data Augmentation

cross-sample数据增强比如MixUp和CutMix有两个缺点:(1)每个iteration需要load多张图片,增加了data loading cost减慢了训练速度。(2)生成的样本带有噪声有可能不属于真实数据的实际分布,影响模型的训练。

  1. 引入caching mechanism改进MixUp和CutMix。
  2. 对于第二点,YOLOX通过使用两阶段的训练策略,第一阶段使用强数据增强,第二阶段使用弱数据增强,由于第一阶段的强数据增强包括随机旋转和剪切,导致输入和转换后的box之间有错位,YOLOX在第二阶段增加L1损失来微调回归分支
    为了解耦数据增强和损失函数的使用,本文去除了这些数据增强,在280个epoch的第一个阶段将混合图片的数量增加到8个来补偿数据增强的强度。在最后20个epoch中,切换到Large Scale Jittering,从而在一个与真实数据分布更更一致的doman中对模型进行微调。

Training Strategy

  1. 为了稳定训练优化器采用AdamW,这个在卷积目标检测模型中很少使用,但在vision transformer中是default。

方法介绍 & 代码解析

骨干网络部分用depth-wise卷积增加了网络深度,因此作者减少了第2个和第3个stage中block的数量,如下表所示,block数量由9减到6延迟降低了20%,但AP也降低了0.5,为了弥补精度的所示,作者在每个stage的最后添加了一个channel attention,实现了更好的精度-速度的权衡。

以RTMDet-s为例,deepen_factor=0.33, widen_factor=0.5,因此每个stage的block数量变成了1-2-2-1。stage2的结构如下

Head部分共享卷积参数,但BN独立。官方实现中在定义head时是分开的,最后将每个head的卷积就赋值为相同。 

if self.share_conv:
    for n in range(len(self.prior_generator.strides)):
        for i in range(self.stacked_convs):
            self.cls_convs[n][i].conv = self.cls_convs[0][i].conv
            self.reg_convs[n][i].conv = self.reg_convs[0][i].conv

# print(self.cls_convs[n][0])  # 共享conv,但BN独立
ConvModule(
  (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn): SyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activate): SiLU(inplace=True)
)

标签分配的实现在dynamic_soft_label_assigner.py中,基于SimOTA,只不过对cost的计算进行了改进,具体如下,其值为分类cost、回归cost、区域cost的加权和

cost_matrix = soft_cls_cost + iou_cost + soft_center_prior

 其中分类cost采用了Generalized focal loss中的quality focal loss

pairwise_ious = self.iou_calculator(valid_decoded_bbox, gt_bboxes)
soft_label = gt_onehot_label * pairwise_ious[..., None]
scale_factor = soft_label - valid_pred_scores.sigmoid()
soft_cls_cost = F.binary_cross_entropy_with_logits(
    valid_pred_scores, soft_label,
    reduction='none') * scale_factor.abs().pow(2.0)
soft_cls_cost = soft_cls_cost.sum(dim=-1)

其中 \(Y_{soft}\) 是预测框与gt框之间的IoU,作为soft label取来原始的标签1。

当用IoU以及相关变体作为回归损失时,最佳匹配和最差匹配之间的差值小于1,这使得区分高质量匹配和低质量匹配变得困难,因此作者使用IoU的对数作为回归的代价,这增大了低质量匹配即IoU较小匹配的cost

iou_cost = -torch.log(pairwise_ious + EPS) * self.iou_weight

至于区域代价 \(C_{region}\),和FCOS、YOLOX等采用的fixed center prior方法不同,本文采用了一种soft center region cost来稳定dynamic cost的匹配

distance = (valid_prior[:, None, :2] - gt_center[None, :, :]
            ).pow(2).sum(-1).sqrt() / strides[:, None]
soft_center_prior = torch.pow(10, distance - self.soft_center_radius)

数据增强部分,Mosaic和MixUp中引入缓存机制。这一部分实现在mmdet/datasets/transforms/transforms.py中。

在mmdet中要使用Mosaic,需要同时使用MultiImageMixDataset。原本results字典中保存的是一张图的相关信息包括img、gt_bboxes、gt_labels等,在MultiImageMixDataset类中调用Mosaic类中的get_indexes方法,随机再挑出其它三张图的索引。然后将这3张图的信息放到列表中作为key 'mix_results'的value加到原始的results中,这样results就包含了4张图的信息。

而在CachedMosaic中,是维护了一个缓存列表self.results_cache,max_cached_images指定最大缓存数量,默认为40,作者指出10张缓存就可以满足随机的要求了。

def __init__(self,
             *args,
             max_cached_images: int = 40,
             random_pop: bool = True,
             **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.results_cache = []
    self.random_pop = random_pop
    assert max_cached_images >= 4, 'The length of cache must >= 4, ' \
                                   f'but got {max_cached_images}.'
    self.max_cached_images = max_cached_images

在原始mosaic中,每个iteration需要从整个训练集中随机挑选3张与当前张进行combine,而在cachedmosaic中是从cache中挑选3张,因此挑选索引的原始大小不同,如下

# mosaic
def get_indexes(self, dataset: BaseDataset) -> int:
    """Call function to collect indexes.

    Args:
        dataset (:obj:`MultiImageMixDataset`): The dataset.

    Returns:
        list: indexes.
    """

    indexes = [random.randint(0, len(dataset)) for _ in range(3)]
    return indexes

# cachedmosaic
def get_indexes(self, cache: list) -> list:
    """Call function to collect indexes.

    Args:
        cache (list): The results cache.

    Returns:
        list: indexes.
    """

    indexes = [random.randint(0, len(cache) - 1) for _ in range(3)]
    return indexes

首先进行append和pop更新缓存列表

# cache and pop images
self.results_cache.append(copy.deepcopy(results))
if len(self.results_cache) > self.max_cached_images:
    if self.random_pop:
        index = random.randint(0, len(self.results_cache) - 1)
    else:
        index = 0
    self.results_cache.pop(index)

然后根据get_indexes方法得到的索引从缓存列表中得到mix_results,其中包含3张图片的信息用于与当前图片进行组合,当前图片保存在results中。

indices = self.get_indexes(self.results_cache)
mix_results = [copy.deepcopy(self.results_cache[i]) for i in indices]

而在原始的mosaic中,results中除了当前图片还包含从整个训练集中挑选的3张图片,即mix_results包含在results中传进函数的。

assert 'mix_results' in results
results_patch = copy.deepcopy(results['mix_results'][i - 1])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1248109.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spark SQL 时间格式处理

初始化Spark Sql package pbcp_2023.clear_dataimport org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.{current_date, current_timestamp}object twe_2 {def main(args: Array[String]): Unit {val con new …

js获取时间日期

目录 Date 对象 1. 获取当前时间 2. 获取特定日期时间 Date 对象的方法 1. 获取各种日期时间组件 2. 获取星期几 3. 获取时间戳 格式化日期时间 1. 使用 toLocaleString() 方法 2. 使用第三方库 UNIX 时间戳 内部表示 时区 Date 对象 JavaScript中内置的 Date 对象…

获取DOM元素和类型判断

一、获取dom元素 <div id"one" class"one">我是第一个div</div> <div>我是第二个div</div> <div name"username">我是第三个div</div> <input type"text" name"username"> 1.qu…

【完美世界】叶倾仙强势登场,孔雀神主VS护道人,石昊重逢清漪

Hello,小伙伴们&#xff0c;我是拾荒君。 《完美世界》国漫第138集已更新。在这一集中&#xff0c;天人族的行为让人大跌眼镜&#xff0c;他们不仅没有对石昊感恩戴德&#xff0c;反而一心想要夺取他身上的所有秘法宝术。天人族的护道人登场&#xff0c;试图以强大的实力压制石…

Theta*: Any-Angle Path Planning on Grids 原文翻译

Theta*: Any-Angle Path Planning on Grids 文章目录 Theta*: Any-Angle Path Planning on Grids翻译摘要1.Introduction2. Path-Planning Problem and Notation3. Existing Terrain Discretizations4.Existing Path-Planning Algorithms4.1 A* on GridsA* with Post-Smoothed …

2023年【R1快开门式压力容器操作】考试资料及R1快开门式压力容器操作复审考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 R1快开门式压力容器操作考试资料参考答案及R1快开门式压力容器操作考试试题解析是安全生产模拟考试一点通题库老师及R1快开门式压力容器操作操作证已考过的学员汇总&#xff0c;相对有效帮助R1快开门式压力容器操作复…

python获取json所有节点和子节点

使用python获取json的所有父结点和子节点 并使用父节点加下划线命名子节点 先展示一段json代码 {"level1": {"level2": {"level3": [{"level4": "4value"},{"level4_2": "4_2value"}]},"level2_…

毅速丨3D打印随形水路为何受到模具制造追捧

在模具制造行业中&#xff0c;随形水路镶件正逐渐成为一种革命性的技术&#xff0c;其提高冷却效率、优化产品设计、降低成本等优点&#xff0c;为模具制造带来了巨大的创新价值。 随形水路是一种根据产品形状定制的冷却水路&#xff0c;其镶件可以均匀地分布在模具的表面或内部…

迪科DTC-F81收费机DTC-F82

迪科DTC-F81收费机是一款挂式收费机&#xff0c;广泛应用的学校食堂刷卡消费&#xff0c;DTC-F82收费机是台式消费机&#xff0c;常用在学校超市&#xff0c;放在桌子上使用的&#xff0c;这2款消费机是迪科畅销机型&#xff0c;如下图 机器质量可靠稳定&#xff0c;不少用户使…

vivado产生报告阅读分析19-设计收敛报告

Challenging Timing Paths “ Challenging Timing Paths ” &#xff08; 时序收敛困难的路径 &#xff09; 部分列出了“ Assessment Details ” &#xff08; 评估详情 &#xff09; 部分中未能通过检查的时序路径的关键属性。默认情况下&#xff0c; 该命令会对每个时钟组中…

2024北京林业大学计算机考研分析

24计算机考研|上岸指南 北京林业大学 特色优势 Characteristics & Advantages&#xff1a;信息学院创建于2001年&#xff0c;是一个年轻而有朝气的学院。学院秉承“结构、特色、质量、创新”的八字方针&#xff0c;坚持以“质量提升、行业融合”为核心的内涵式发展战略&am…

在Linux上搭建JavaWeb项目运行环境

文章目录 安装JDK安装Tomcat安装数据库 安装JDK 安装Oracle官方的JDK比较麻烦&#xff0c;我们在此处选择安装开源社区维护的openjdk。他们俩的差别不大且兼容。 安装Tomcat 我们把本地下载好的 tomcat.zip 包拖到Linux页面上&#xff0c;让Linux也有一个zip包&#xff0c;再…

运动鞋品牌识别

一、前期工作 1. 设置GPU from tensorflow import keras from tensorflow.keras import layers,models import os, PIL, pathlib import matplotlib.pyplot as plt import tensorflow as tfgpus tf.config.list_physical_devices("GPU")if gpus:gpu0 …

网络安全工程师究竟是什么?怎么入门?

首先啊骚年们我们必须先了解网络安全这个行业究竟是干啥的。 是打ctf的&#xff1f;一个个都像韩商言吴白那么帅刷刷敲几个代码就能轻易夺旗&#xff1f; 还是像十大黑客之一的米特尼克一样闯入了“北美空中防务指挥系统”的计算机主机内&#xff0c;还在被通缉逃跑期间控制了…

【多线程】Thread类的使用

目录 1.概述 2.Thread的常见构造方法 3.Thread的几个常见属性 4.启动一个线程-start() 5.中断一个线程 5.1通过共享的标记来进行沟通 5.2 调用 interrupt() 方法来通知 6.等待一个进程 7.获取当前线程引用 8.线程的状态 8.1所有状态 8.2线程状态和转移的意义 1.概述 …

基于java技术的社区交易二手平台

基于java技术的社区交易二手平台的设计与实现 &#xff08;一&#xff09;开发背景 随着因特网的日益普及与发展&#xff0c;更多的人们开始通过因特网来寻求便利。但是&#xff0c;许多人都觉得网上商店里的东西不贵。所以&#xff0c;有些顾客宁愿去那些用二次定价建立起来的…

Relabel与Metic Relabel

Prometheus支持多种方式的自动发现目标&#xff08;targets&#xff09;&#xff0c;以下是一些常见的自动发现方式&#xff1a; 静态配置&#xff1a;您可以在Prometheus配置文件中直接列出要监测的目标。这种方式适用于目标相对稳定的情况下&#xff0c;例如固定的服务器或设…

【C++】泛型编程 ⑮ ( 类模板示例 - 数组类模板 | 自定义类中持有指针成员变量 )

文章目录 一、支持 数组类模板 存储的 自定义类1、可拷贝和可打印的自定义类2、改进方向3、改进方向 - 构造函数4、改进方向 - 析构函数5、改进方向 - 重载左移运算符6、改进方向 - 重载拷贝构造函数 和 等号运算符 二、代码示例1、Array.h 头文件2、Array.cpp 代码文件3、Test…

网络安全—自学

1.网络安全是什么 网络安全可以基于攻击和防御视角来分类&#xff0c;我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术&#xff0c;而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 2.网络安全市场 一、是市场需求量高&#xff1b; 二、则是发展相对成熟…

路径规划之Best-First Search算法

系列文章目录 路径规划之Dijkstra算法 路径规划之Best-First Search算法 路径规划之Best-First Search算法 系列文章目录前言一、Best-First Search算法1.1 起源1.2 过程 三、简单使用 前言 Best-First Search算法和Dijkstra算法类似&#xff0c;都属于BFS的扩展或改进 一、…