FoveaBox原理与代码解析

news2025/1/10 16:33:48

paper:FoveaBox: Beyond Anchor-based Object Detector

code:https://github.com/taokong/FoveaBox

背景

基于anchor的检测模型需要仔细设计anchor,常用方法之一是根据特定数据集的统计结果确定anchor的number、scale、ratio等,但这种针对特定数据集的设计并不总能适用于其它数据集,泛化性较差。另外训练阶段anchor-based的模型通常根据和GT的IoU来定义正负样本,这又引入了额外的计算和超参。

本文的创新点

受到人眼中心凹(fovea)区域的启发:视野中心区域的视觉灵敏度最高,本文提出了一种新的anchor-free的目标检测方法FoveaBox,FoveaBox联合预测对象中心区域可能存在的位置以及每个有效位置处的边界框。在FoveaBox中,每个目标对象通过中心区域的类别得分进行预测,同时预测bounding box,训练阶段不需要使用anchor或是IoU匹配来生成训练目标,训练目标是根据GT box直接生成的。

方法介绍

给定一个GT box \((x_{1},y_{1},x_{2},y_{2})\),首先将其映射到特征金字塔的目标层 \(P_{l}\)

其中 \(s_{l}\) 是下采样步长。定义输出特征图上对应GT box中心区域为正样本区域 \(R^{pos}\)

其中 \(\sigma\) 是收缩系数,文中 \(\sigma = 0.4\)。训练阶段,正样本区域内的每个像素位置都标为对应的目标类别标签,整个特征图上,除了正样本区域其它都是负样本区域。如下图右灰色区域所示

在标签分配中,除了按上述对正负样本区域进行了限制,还对FPN每层负责预测的目标大小即scale进行了限制。对于FPN的输出层 \(P_{3}-P_{7}\),每一层的basic scale \(r_{l}\) 为32至512。\(l\) 层的有效scale区间按下式计算得到

其中 \(\eta \) 是超参,文中 \(\eta =2\)。注意和之前一个目标只会由特征金字塔中的某一层负责预测的方法不同,FoveaBox中一个目标可能会由FPN的多层负责预测。将目标分配给多个相邻的FPN层有两个优点:(1)相邻的特征金字塔层通常具有相似语义表示能力,因此FoveaBox可以同时优化这些相邻层的特征。(2)FPN每一层的训练样本数量增大,使得训练过程更加稳定。

对于一个GT box \(G=(x_{1},y_{1},x_{2},y_{2})\),\(R_{pos}\) 区域中某一点 \((x,y)\) 的回归target即到四条边界的归一化的偏移按下式得到

FoveaBox的结构如下图所示,整体结构和anchor数量为1的RetinaNet是一样的,只不过在样本分配和定义上又区别。

代码解析

这里以mmdet中的实现为例,代码文件在https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/fovea_head.py,foveabox相对于retinanet的创新点就在于anchor-free以及对应的标签分配部分,这里的核心代码在函数_get_target_single()中,这个函数的作用就是就算FPN输出层中的单层分类和回归的target,完整代码如下

def _get_target_single(self,
                   gt_bboxes_raw,
                   # (2,4), tensor([[52.5, 46.8, 235.7, 274.4], [101.7, 29.6, 221.7, 175.8]], device='cuda:0')
                   gt_labels_raw,  # (2), tensor([12, 14], device='cuda:0')
                   featmap_size_list=None,
                   point_list=None):
gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
                      (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))  # torch.Size([2])
label_list = []
bbox_target_list = []
# for each pyramid, find the cls and box target
# self.base_edge_list=[16, 32, 64, 128, 256]
# self.scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)), 注意收尾本来分别为16和1024,这里改为了1和2048
# self.strides=[8, 16, 32, 64, 128]
for base_len, (lower_bound, upper_bound), stride, featmap_size, \
    points in zip(self.base_edge_list, self.scale_ranges,
                  self.strides, featmap_size_list, point_list):
    # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
    points = points.view(*featmap_size, 2)  # (1444,2) -> (38,38,2)
    x, y = points[..., 0], points[..., 1]  # (38,38),(38,38)
    labels = gt_labels_raw.new_zeros(featmap_size) + self.num_classes  # (38,38), 值全为self.num_classes
    bbox_targets = gt_bboxes_raw.new(featmap_size[0], featmap_size[1],
                                     4) + 1  # (38,38,4),值全为1
    # scale assignment
    hit_indices = ((gt_areas >= lower_bound) &
                   (gt_areas <= upper_bound)).nonzero().flatten()  # torch.Size([1]), tensor([1], device='cuda:0')
    if len(hit_indices) == 0:
        label_list.append(labels)
        bbox_target_list.append(torch.log(bbox_targets))
        continue
    _, hit_index_order = torch.sort(-gt_areas[hit_indices])
    hit_indices = hit_indices[hit_index_order]  # 按面积从大到小排列
    gt_bboxes = gt_bboxes_raw[hit_indices, :] / stride
    gt_labels = gt_labels_raw[hit_indices]
    half_w = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0])
    half_h = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1])
    # valid fovea area: left, right, top, down
    pos_left = torch.ceil(
        gt_bboxes[:, 0] + (1 - self.sigma) * half_w - 0.5).long(). \
        clamp(0, featmap_size[1] - 1)
    pos_right = torch.floor(
        gt_bboxes[:, 0] + (1 + self.sigma) * half_w - 0.5).long(). \
        clamp(0, featmap_size[1] - 1)
    pos_top = torch.ceil(
        gt_bboxes[:, 1] + (1 - self.sigma) * half_h - 0.5).long(). \
        clamp(0, featmap_size[0] - 1)
    pos_down = torch.floor(
        gt_bboxes[:, 1] + (1 + self.sigma) * half_h - 0.5).long(). \
        clamp(0, featmap_size[0] - 1)
    for px1, py1, px2, py2, label, (gt_x1, gt_y1, gt_x2, gt_y2) in \
            zip(pos_left, pos_top, pos_right, pos_down, gt_labels,
                gt_bboxes_raw[hit_indices, :]):
        labels[py1:py2 + 1, px1:px2 + 1] = label
        bbox_targets[py1:py2 + 1, px1:px2 + 1, 0] = \
            (x[py1:py2 + 1, px1:px2 + 1] - gt_x1) / base_len
        bbox_targets[py1:py2 + 1, px1:px2 + 1, 1] = \
            (y[py1:py2 + 1, px1:px2 + 1] - gt_y1) / base_len
        bbox_targets[py1:py2 + 1, px1:px2 + 1, 2] = \
            (gt_x2 - x[py1:py2 + 1, px1:px2 + 1]) / base_len
        bbox_targets[py1:py2 + 1, px1:px2 + 1, 3] = \
            (gt_y2 - y[py1:py2 + 1, px1:px2 + 1]) / base_len
    bbox_targets = bbox_targets.clamp(min=1. / 16, max=16.)  # 文中有这个限制吗?
    label_list.append(labels)
    bbox_target_list.append(torch.log(bbox_targets))
return label_list, bbox_target_list

下面是根据FPN某一层对应的尺度限制取出该层负责预测的GT box的index,即上面的式(3),代码如下

hit_indices = ((gt_areas >= lower_bound) &
               (gt_areas <= upper_bound)).nonzero().flatten() 

下面是按式(2)计算 \(R^{pos}\) 区域的坐标,self.sigma是收缩系数 \(\sigma\),式(3)中是以gt box的中心坐标为基准计算的,而下面的实现是以gt box的左上角坐标为基准计算的。

# valid fovea area: left, right, top, down
pos_left = torch.ceil(
    gt_bboxes[:, 0] + (1 - self.sigma) * half_w - 0.5).long(). \
    clamp(0, featmap_size[1] - 1)
pos_right = torch.floor(
    gt_bboxes[:, 0] + (1 + self.sigma) * half_w - 0.5).long(). \
    clamp(0, featmap_size[1] - 1)
pos_top = torch.ceil(
    gt_bboxes[:, 1] + (1 - self.sigma) * half_h - 0.5).long(). \
    clamp(0, featmap_size[0] - 1)
pos_down = torch.floor(
    gt_bboxes[:, 1] + (1 + self.sigma) * half_h - 0.5).long(). \
    clamp(0, featmap_size[0] - 1)

下面是按式(4)计算回归target,其中base_len即这一层对应的basic scale \(r_{l}\)。

for px1, py1, px2, py2, label, (gt_x1, gt_y1, gt_x2, gt_y2) in \
        zip(pos_left, pos_top, pos_right, pos_down, gt_labels,
            gt_bboxes_raw[hit_indices, :]):
    labels[py1:py2 + 1, px1:px2 + 1] = label
    bbox_targets[py1:py2 + 1, px1:px2 + 1, 0] = \
        (x[py1:py2 + 1, px1:px2 + 1] - gt_x1) / base_len
    bbox_targets[py1:py2 + 1, px1:px2 + 1, 1] = \
        (y[py1:py2 + 1, px1:px2 + 1] - gt_y1) / base_len
    bbox_targets[py1:py2 + 1, px1:px2 + 1, 2] = \
        (gt_x2 - x[py1:py2 + 1, px1:px2 + 1]) / base_len
    bbox_targets[py1:py2 + 1, px1:px2 + 1, 3] = \
        (gt_y2 - y[py1:py2 + 1, px1:px2 + 1]) / base_len

这里有一些疑问,一是下面这行对回归target进行大小的限制论文中好像没有提到

bbox_targets = bbox_targets.clamp(min=1. / 16, max=16.)

二是mmdet中对FPN每一层的basic scale \(r_{l}\) 以及负责预测目标的valid scale range和论文中有些差异,如下

其中base_edge_list就是每一层的 \(r_{l}\),如果按照文中计算方式,实际的valid scale range应该如下

如果以设定的scale_ranges为准,则实际的 \(r_{l}\) 应该是[32, 64, 128, 256, 512],并且第一个值由16改为了1,最后一个值由1024改为2048。

实验结果

Comparision with SOTA

下面是FoveaBox和当时的一些SOTA方法的对比,可以看出FoveaBox取得了最优的精度,而且好于当时刚刚提出的其它anchor-free方法比如CornerNet和ExtremeNet。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/337688.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

elasticsearch8.3.2搭建部署

Elasticsearch8.3.2搭建部署详细步骤 0.过往文章 ES-6文章&#xff1a; Elasticsearch6.6.0部署、原理和使用介绍: https://blog.csdn.net/wt334502157/article/details/119515730 ES-7文章&#xff1a; Elasticsearch7.6.1部署、原理和使用介绍: https://blog.csdn.net/wt…

堆排序

章节目录&#xff1a;一、相关概述1.1 基本介绍1.2 排序思想二、基本应用2.1 步骤说明2.2 代码示例三、结束语一、相关概述 1.1 基本介绍 堆排序是利用堆这种数据结构而设计的一种排序算法&#xff0c;堆排序是一种选择排序。它的最坏最好平均时间复杂度均为 O(nlogn)&#x…

(深度学习快速入门)第五章第二节:GAN变体

文章目录一&#xff1a;CycleGAN&#xff08;1&#xff09;概述&#xff08;2&#xff09;双判别器&#xff08;3&#xff09;损失函数二&#xff1a;StyleGAN&#xff08;1&#xff09;解耦表征学习&#xff08;2&#xff09;概述三&#xff1a;DCGAN一&#xff1a;CycleGAN …

4.5.8 Set接口与HashSet

文章目录1.概述2.Set集合的特点3.常用方法4.HashSet4.1 概述4.2 练习: Set相关测试一4.3 练习: Set相关测试二1.概述 Set是一个不包含重复数据的CollectionSet集合中的数据是无序的(因为Set集合没有下标)Set集合中的元素不可以重复 – 常用来给数据去重 2.Set集合的特点 数据…

排序算法学习

文章目录前言一、直接插入排序算法二、折半插入排序算法三、2路插入排序算法四、快速排序算法学习前言 算法是道路生涯的一个巨大阻碍。今日前来解决这其中之一&#xff1a;有关的排序算法&#xff0c;进行实现以及性能分析。 一、直接插入排序算法 插入排序算法实现主要思想…

Kubernetes_从Linux的cgroup配置到Kubernetes中的cgroup配置

系列文章目录 文章目录系列文章目录前言一、Linux层面的cgroup二、Kubernetes层面的cgroup driver2.1 kubelet和docker的Cgroup Driver不同导致kubelet开启失败2.1.1 命令2.1.2 演示总结前言 一、Linux层面的cgroup cgroup是控制组&#xff0c;用来控制进程对资源的分配&…

Cesium-数字仿真-你总要了解

Cesium&#xff08;专注于时空数据的实时可视化) cesium是一款三维地球开源框架&#xff08;可以多平台、跨平台使用&#xff09;cesium隶属于美国AGI公司&#xff08;Analytical Graphics Incorporation&#xff09;&#xff0c;美国通用公司宇航部的工程师创始开源 周边产…

微信小程序的优化方案之主包与分包的研究

什么是分包&#xff1f; 某些情况下&#xff0c;开发者需要将小程序划分成不同的子包&#xff0c;在构建时打包成不同的分包&#xff0c;用户在使用时按需进行加载。 在构建小程序分包项目时&#xff0c;构建会输出一个或多个分包。每个使用分包小程序必定含有一个主包。所谓的…

错误代码0xc0000001要怎么解决?如何修复错误

出现错误代码0xc0000001这个要怎么解决&#xff1f;其实这个的蓝屏问题还是非常的简单的&#xff0c;有多种方法可以实现 解决方法一 1、首先使用电脑系统自带的修复功能&#xff0c;首先长按开机键强制电脑关机。 注&#xff1a;如果有重要的资料请先提前备份好&#xff0c;…

【C++】C++11 ~ 包装器解析

&#x1f308;欢迎来到C专栏~~包装器解析 (꒪ꇴ꒪(꒪ꇴ꒪ )&#x1f423;,我是Scort目前状态&#xff1a;大三非科班啃C中&#x1f30d;博客主页&#xff1a;张小姐的猫~江湖背景快上车&#x1f698;&#xff0c;握好方向盘跟我有一起打天下嘞&#xff01;送给自己的一句鸡汤&a…

Java 内存结构解密

程序计数器 物理上被称为寄存器&#xff0c;存取速度很快。 作用 记住下一条jvm指令的执行地址。 特点 线程私有&#xff0c;和线程一块出生。 不存在内存溢出。 虚拟机栈 每个线程运行时所需要的内存&#xff0c;称为虚拟机栈。 每个栈由多个栈帧组成&#xff0c;…

C/C++ 中的宏 (macros) 与宏展开的可视化显示

C/C 中的宏 (macros) 与宏展开的可视化显示1. Replacing text macros (替换文本宏) https://en.cppreference.com/w/cpp/preprocessor/replace https://www.codecademy.com/resources/docs/cpp/macros A macro is a label defined in the source code that is replaced by it…

dll修复工具哪个比较好?修复工具介绍

DLL&#xff08;动态链接库&#xff09;是Windows操作系统中非常重要的一部分&#xff0c;它们存储了各种软件应用程序所需的公共代码和数据。然而&#xff0c;随着时间的推移&#xff0c;电脑上的DLL文件可能会因为各种原因而损坏或丢失&#xff0c;导致系统出现错误。因此&am…

PyTorch自定义损失函数实现

在机器学习中&#xff0c;损失函数是衡量预测输出与实际输出之间差异的关键组成部分。 它在模型训练中起着至关重要的作用&#xff0c;因为它通过指示模型应该改进的方向来指导优化过程。 损失函数的选择取决于具体的任务和数据类型。 在本文中&#xff0c;我们将以用于手写数字…

VHDL语言基础-时序逻辑电路-概述

目录 时序逻辑电路-概述: 时序逻辑电路: 时序逻辑电路——有记忆功能: 时序电路的分类: 按照触发器的动作特点: 按照输出信号的特点: 同步时序逻辑电路: 异步时序逻辑电路: 时序逻辑电路-概述: 数字电路按其完成逻辑功能的不同特点&#xff0c;划分为组合逻辑电路和时序…

福利篇1——嵌入式软件行业与公司汇总

前言 汇总嵌入式软件行业与公司,供参考。 文章目录 前言一、嵌入式软件行业和公司汇总1、芯片行业代表性公司2、人工智能代表性公司1)智能驾驶方向代表性公司2)机器人方向代表性公司3、消费电子领域代表性公司4、传统电子电器领域代表性公司5、国企和军工领域代表性公司6、网…

嵌入式系统那些事——aarch64 backtrace嵌入式汇编实现

0 背景 在aarch64嵌入式应用开发中&#xff0c;经常会遇到段错误(segmentation fault)&#xff0c;但是通常情况下系统报错后直接退出&#xff0c;没有异常调用打印信息&#xff0c;定位出错原因十分困难。经确认&#xff0c;该问题是由于没有设置捕获段错误&#xff0c;并调用…

推荐3dMax三维设计十大插件

3dMax是一款功能非常强大的三维设计软件&#xff0c;但无论它的功能多么强大&#xff0c;也不可能包含所有三维方面的功能&#xff0c;这时候&#xff0c;第三方插件可以很好的弥补和增强3dMax的基本功能&#xff0c;下面就给大家介绍十款非常不错的3dMax插件。 森林包&#xf…

Unsupervised Question Answering 简单综述

Unsupervised Question Answering by Cloze Translation, ACL 2019 随机从文本中抽取noun phrases或者named entity作为答案将答案部分mask掉&#xff0c;生成cloze question利用无监督翻译&#xff0c;将cloze question转化为natural question 缺点&#xff1a; 直接利用原句…

Android 进阶——Framework核心 之Binder Native成员类详解(二)

文章大纲引言一、Native 家族核心成员关系图二、Native 家族核心成员源码概述1、IInterface1.1、DECLARE_META_INTERFACE 宏1.2、IMPLEMENT_META_INTERFACE(INTERFACE, NAME) 宏1.3、sp< IInterface > BnInterface< INTERFACE >::queryLocalInterface(const String…