OpenPCDet系列 | 5.4.3 DenseHead中的AxisAlignedTargetAssigner正负样本分配模块

news2024/11/23 12:03:23

文章目录

  • AxisAlignedTargetAssigner模块
  • assign_targets处理流程
    • 1. 提取有效gt信息
    • 2. 提取需要处理的类别信息
    • 3. 帧信息整合
    • 4. 批信息整合
  • assign_targets_single处理流程
    • 1. 构建每个anchor的正负样本label分配
    • 2. 构建每个anchor的正负样本编码信息bbox_targets分配
    • 3. 构建每个anchor的回归权重

AxisAlignedTargetAssigner模块

TargetAssigner处理,也就是正负样本的分配问题,算是整个检测算法中一个比较核心的问题,也比较重要。

在AnchorHeadSingle模块总,对分类、回归、方向预测三个特征矩阵预测完之后,随机就调用self.assign_targets函数来对在基类中生成的anchor进行正负样本的匹配,构建出box_cls_labels、box_reg_targets、reg_weights来进行损失函数的构建。而这里的调用self.assign_targets函数其实是调用基类的assign_targets函数,最后再跳转到AxisAlignedTargetAssigner模块的assign_targets函数中。调用关系如下所示:
在这里插入图片描述


assign_targets处理流程

1. 提取有效gt信息

对于当前一个batch的数据,gt信息是进行填充后再拼接在一起的,所以存在0填充的部分。那么,在进行后续处理的前提是,先对gt的填充信息进行去除。提取每个点云帧的有效gt信息以及有效类别信息,保留非零项。

cur_gt = gt_boxes[k]    # 提取第k个点云帧gt,然后提取非零信息,去除非0无效信息 [44, 7] -> [38, 7]
cnt = cur_gt.__len__() - 1      # 43
while cnt > 0 and cur_gt[cnt].sum() == 0:
    cnt -= 1
cur_gt = cur_gt[:cnt + 1]   # 提取当前第k点云帧的有效gt信息,保留非零项
cur_gt_classes = gt_classes[k][:cnt + 1].int()   # 提取当前第k点云帧有效gt类别

2. 提取需要处理的类别信息

由于当前的gt信息包含了3个类别: [‘Car’, ‘Pedestrian’, ‘Cyclist’],现在需要对着三个类别进行分别处理。也就是利用掩码矩阵,分别获取每个当前需要处理的类别,同时单独获取当前需要处理的gt信息,然后传入到assign_targets_single函数中进行处理,这个函数丶作用是针对某一个点云帧中的每一个类别anchors和gt信息,计算前景和背景的anchor类别,box编码以及回归的权重。

for anchor_class_name, anchors in zip(self.anchor_class_names, all_anchors):    # 对每个类别及其配置anchor进行依次处理
    if cur_gt_classes.shape[0] > 1:
        mask = torch.from_numpy(self.class_names[cur_gt_classes.cpu() - 1] == anchor_class_name)    # 获取类别为'Car'的掩码矩阵:[True, True, ..., False]
    else:
        mask = torch.tensor([self.class_names[c - 1] == anchor_class_name
                             for c in cur_gt_classes], dtype=torch.bool)

    if self.use_multihead:  # False
        anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
        selected_classes = cur_gt_classes[mask]
    else:
        feature_map_size = anchors.shape[:3]   # zyx: (1, 248, 216)
        anchors = anchors.view(-1, anchors.shape[-1])   # (107136,7) 107136=1x248x216x1x2
        selected_classes = cur_gt_classes[mask]     # 被选择的类别 (14, )

    single_target = self.assign_targets_single(
        anchors,        # reshape后的anchor矩阵 (107136,7)
        cur_gt[mask],   # 根据当前类别的掩码矩阵选择当前处理的类别gt信息  (38, 7) -> (14, 7)
        gt_classes=selected_classes,    # 当前处理的类别信息 (14, )  [1,1,1, ..., 1,1]]
        matched_threshold=self.matched_thresholds[anchor_class_name],       # 当前处理类别的正样本阈值
        unmatched_threshold=self.unmatched_thresholds[anchor_class_name]    # 当前处理类别的负样本阈值
    )
    target_list.append(single_target)

也就是说,根据掩码矩阵来挑选出对应类别的anchor设置,以及对应类别的gt信息,然后使用assign_targets_single函数进行后续处理。在assign_targets_single函数中完成对当前类别的anchor进行类别赋值以及与其匹配的gt编码信息赋值,还设置了其回归的权重为1。由于这里设置了需要预测3个类别,所以这个函数对应每个点云帧场景会运行3次,依次处理好每个类别。

3. 帧信息整合

对某个点云帧场景的类别信息提取之后,列表信息如下所示:
在这里插入图片描述

对其进行合并并拼接起来:
在这里插入图片描述

先reshape,再合并,然后再reshape。最后得到如下的结果:
在这里插入图片描述

对于每个点云帧都进行信息的整合处理,然后将当前点云场景处理结果分别追加到对应列表中

bbox_targets.append(target_dict['box_reg_targets'])
cls_labels.append(target_dict['box_cls_labels'])
reg_weights.append(target_dict['reg_weights'])

4. 批信息整合

对于整个batch的点云帧信息都处理完之后,其数据结构如下所示:
在这里插入图片描述

现在分别对其进行堆叠层一个tensor矩阵处理,随后保存在键值value中:

# 将每个点云帧处理的结果进行stack堆叠
bbox_targets = torch.stack(bbox_targets, dim=0)    # (16, 321408, 7)
cls_labels = torch.stack(cls_labels, dim=0)        # (16, 321408)
reg_weights = torch.stack(reg_weights, dim=0)      # (16, 321408)

all_targets_dict = {
    'box_cls_labels': cls_labels,
    'box_reg_targets': bbox_targets,
    'reg_weights': reg_weights

}

返回的数据维度如下所示:
在这里插入图片描述

自此,完成了对每个点云帧的anchor正样本分配。最后,这个字典的信息会返回到AnchorHeadSingle函数中,保存在self.forward_ret_dict这个字典中,后续就会利用这个字典来进行损失的计算。
在这里插入图片描述

在这之后就是进行损失函数的计算:self.get_training_loss()


assign_targets_single处理流程

1. 构建每个anchor的正负样本label分配

首先,对于传进来的当前类别的gt信息以及当前类别的生成的anchor,可以进行一个iou3d的计算。也就是先计算anchor和gt之间的iou。

# 1.计算gt和anchors之间的overlap
anchor_by_gt_overlap = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) \
    if self.match_height else box_utils.boxes3d_nearest_bev_iou(anchors[:, 0:7], gt_boxes[:, 0:7])  # 计算anchor和gt之间的iou (107136, 14)

根据这个anchor和gt之间的iou矩阵,可以分别获得与每个anchor最匹配的gt索引以及数值,也可以获得与每个gt最匹配的anchor索引以及数值。其中有可能出现某个gt没有找到与之有任何重叠的anchor,那么最匹配的iou数值为0,此时将其赋值为-1.

# 找到每个anchor最匹配的gt的索引和iou
anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)    # (107136,)找到每个anchor最匹配的gt的索引
anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_anchors, device=anchors.device), anchor_to_gt_argmax]  # (107136,)找到每个anchor最匹配的gt的iou

# 提取最匹配的anchor,避免没有anchor满足索设定的阈值
# gt_to_anchor_argmax = torch.from_numpy(anchor_by_gt_overlap.cpu().numpy().argmax(axis=0)).cuda()
gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0)     # (14,) 找到每个gt最匹配anchor的索引
gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, torch.arange(num_gt, device=anchors.device)]   # (14,)找到每个gt最匹配anchor的iou
empty_gt_mask = gt_to_anchor_max == 0   # 如果最匹配iou为0,表示某个gt没有与之匹配的anchor
gt_to_anchor_max[empty_gt_mask] = -1    # 没有与之匹配的anchor在iou值中设置为-1

接着,根据这一系列最大iou匹配的值,可以找到满足这个最大iou的每个anchor。具体来说,以gt为基础,逐个anchor对应,比如第一个gt的最大iou为0.9,则在所有anchor中找iou为0.9的anchor。对于这些anchor,在labels的对应位置中为其分配类别信息(此类别信息就是当前处理的类别信息),同时也记录为其分配的gt索引。

nchors_with_max_overlap = (anchor_by_gt_overlap == gt_to_anchor_max).nonzero()[:, 0]   # 找到满足最大iou的每个anchor
gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]   # 找到最大iou的gt索引
labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]    # 将gt的类别赋值到对应的anchor的label中 (107136,)
gt_ids[anchors_with_max_overlap] = gt_inds_force.int()          # 将gt的索引赋值到对应的anchor的gt_id中 (107136,)

以上的操作是处理最值以避免没有符合满足阈值设定的iou。接下来还会对满足正样本阈值的anchor进行label和gt的分配。

对于每个anchor与gt的最大iou值,如果超过设定的正样本阈值范围,比如这里的0.6,都会根据这个阈值掩码将符合的anchor挑选出来。然后在labels中对应符合阈值的anchor设置其类别,同时也设置其分配的gt索引。一般情况下,符合最值的anchor的iou匹配只能设置十来个,然后满足阈值的anchor的iou匹配只有百来个。

# 这里应该对labels和gt_ids的操作应该包含了上面的anchors_with_max_overlap
pos_inds = anchor_to_gt_max >= matched_threshold        # 找到最匹配的anchor中iou大于给定阈值的mask (107136,)
gt_inds_over_thresh = anchor_to_gt_argmax[pos_inds]     # 找到最匹配的anchor中iou大于给定阈值的gt的索引 (104,)
labels[pos_inds] = gt_classes[gt_inds_over_thresh]      # 将pos anchor对应gt的类别赋值到对应的anchor的label中 (107136,)
gt_ids[pos_inds] = gt_inds_over_thresh.int()            # 将pos anchor对应gt的索引赋值到对应的anchor的gt_id中 (107136,)

此时,由于已经设置了label的anchor就是正样本,就可以分别找到前景anchor和背景anchor的索引。

bg_inds = (anchor_to_gt_max < unmatched_threshold).nonzero()[:, 0]  # 找到背景anchor索引 (106874,)
fg_inds = (labels > 0).nonzero()[:, 0]      # 找到前景点的索引 (104,)

最后,将labels中的背景anchor类别设置为0.

labels[bg_inds] = 0     # 将背景点的label赋值为0

这时候,就完成了anchor的类别分配操作,一般前景anchor的数量还是比较少的。一个类别中只有百个这样的数量级。相比之下,anchor的数量级是十万。

2. 构建每个anchor的正负样本编码信息bbox_targets分配

对于负样本的anchor预测编码信息设置为0,只对正样本的anchor进行赋值处理。

首先,基于上述的操作已经获取到了最大值iou以及满足设定阈值的正样本anchor的索引,根据真个anchor索引可以获得与其iou最大的gt索引,那么根据gt的索引就可以获取对应的gt信息。同时,根据这个anchor索引也可以获取对应的生成的anchor信息。那么,根据anchor和其对应的gt信息,就可以进行所需预测的编码处理,所获得的编码存储在bbox_targets的前景点索引位置。这里的编码操作是通过self.box_coder.encode_torch实现的。代码如下所示:

# 2. 构建正样本anchor需要预测拟合的编码信息(负样本anchor全部设置为0)
bbox_targets = anchors.new_zeros((num_anchors, self.box_coder.code_size))   # (107136,7)
if len(gt_boxes) > 0 and anchors.shape[0] > 0:
    fg_gt_boxes = gt_boxes[anchor_to_gt_argmax[fg_inds], :]     # 提取前景对应的gt box信息 (104, 7)
    fg_anchors = anchors[fg_inds, :]    # 提取前景anchor (104, 7)
    bbox_targets[fg_inds, :] = self.box_coder.encode_torch(fg_gt_boxes, fg_anchors)    # 编码gt和前景anchor,并赋值到bbox_targets的对应位置

论文中的回归编码方式如下:
在这里插入图片描述

这部分的具体代码见ResidualCoder模块,

3. 构建每个anchor的回归权重

这里的回归权重只针对前景anchor,赋值为1.背景的anchor赋值为0.

# 3. 构建正负样本回归权重,其中背景anchor权重为0,前景anchor权重为1
reg_weights = anchors.new_zeros((num_anchors,))    # 回归权重 (107136,)
if self.norm_by_num_examples:   # False
    num_examples = (labels >= 0).sum()
    num_examples = num_examples if num_examples > 1.0 else 1.0
    reg_weights[labels > 0] = 1.0 / num_examples
else:
    reg_weights[labels > 0] = 1.0    # 将前景anchor的权重赋1

最后,将这构建的3个Tensor进行字典保存,返回到assign_targets函数中。

ret_dict = {
    'box_cls_labels': labels,           # 背景anchor的label是0,前景的anchor的label是当前处理的类别1  (107136,)
    'box_reg_targets': bbox_targets,    # 编码后待模型预测拟合的结果,背景anchor的编码信息也是0 (107136,7)
    'reg_weights': reg_weights,         # 背景anchor权重为0,前景anchor权重为1  (107136,)
}
return ret_dict

总结:本质上这个函数就是根据iou选择出前景anchor,然后对其进行类别赋值以及与其匹配的gt编码信息赋值,还设置了其回归的权重为1.


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/540627.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NMEA 2000总线连接器组成类别及功能

NMEA 2000总线介绍 连接的所有设备都能彼此通信 NMEA 2000总线系统是一种数字网络&#xff0c;可连接船上各种系统和设备&#xff0c;例如雷达、GPS、深度传感器、气象环境和船体控制系统。 NMEA 2000总线的工作方式 NMEA 2000总线通过总线连接器连接各个设备和系统&#xf…

HANTS时间序列滤波算法的MATLAB实现

本文介绍在MATLAB中&#xff0c;实现基于HANTS算法&#xff08;时间序列谐波分析法&#xff09;的长时间序列数据去噪、重建、填补的详细方法。 HANTS&#xff08;Harmonic Analysis of Time Series&#xff09;是一种用于时间序列分析和插值的算法。它基于谐波分析原理&#x…

linux 常用命令awk

AWK 是一种处理文本文件的语言&#xff0c;是一个强大的文本分析工具。之所以叫 AWK 是因为其取了三位创始人 Alfred Aho&#xff0c;Peter Weinberger, 和 Brian Kernighan 的 Family Name 的首字符。 AWK用法 awk 用法&#xff1a;awk pattern {action} files 1.RS, ORS, F…

CFS三层内网靶机渗透

目录 一、靶场框架 靶场搭建&#xff1a; 二、渗透过程 三、总结 靶场介绍&#xff1a; 三层内网靶场&#xff0c;共有三个网段&#xff0c;分别为75网段&#xff08;公网网段&#xff09;、22网段&#xff08;内网&#xff09;、33网段&#xff08;内网&#xff09; 靶…

ChatGPT+ “剪映 or 百度AIGC” 快速生成短视频

&#x1f34f;&#x1f350;&#x1f34a;&#x1f351;&#x1f352;&#x1f353;&#x1fad0;&#x1f951;&#x1f34b;&#x1f349;&#x1f95d; ChatGPT快速生成短视频 文章目录 &#x1f350;1.ChatGPT 剪映&#x1f34f;2.ChatGPT 百度AIGC平台&#x…

【ThinkPHP6系列学习-2】多应用模式配置

上一篇&#xff1a;【ThinkPHP6系列学习-1】下载并部署ThinkPHP6 这里写一写TP6下配置多应用。因为TP6和TP5有所差异&#xff0c;TP6默认是单应用模式&#xff08;单模块&#xff09;&#xff0c;而我们实际项目中往往是多应用的&#xff08;多个模块&#xff09;&#xff0c…

1.Buffer_Overflow-2.Stack_Overflow / 写入字符串

这道题虽然简单 但是却给我了另一个解题的看法 我们先进行运行 我们看看保护 发现只有NX保护 反汇编看看 发现有shellcode 但是我们没有办法执行shellcode 因为v5 不会等于后面的 这里我原本没有想法 后面进行看看他的汇编 这里其实就很清楚了 .text:00000000004011BB …

( 动态规划) 674. 最长连续递增序列 / 718. 最长重复子数组——【Leetcode每日一题】

题目一&#xff08;贪心&#xff09; ❓674. 最长连续递增序列 难度&#xff1a;简单 给定一个未经排序的整数数组&#xff0c;找到最长且 连续递增的子序列&#xff0c;并返回该序列的长度。 连续递增的子序列 可以由两个下标 l 和 r&#xff08;l < r&#xff09;确定…

vcruntime140.dll如何修复?这个修复方法很简单,适合电脑小白

今天打开photoshop软件工作的时候&#xff0c;突然间就打不开&#xff0c;电脑报错由于找不到vcruntime140.dll&#xff0c;无法继续执行此代码&#xff0c;然后我就把photoshop卸载了&#xff0c;再重新安装&#xff0c;依然还是报错。这个可怎么办&#xff1f;vcruntime140.d…

CentOS 安装MongoDB 6.0

一、安装依赖 yum install libcurl openssl xz-libs 二、下载安装包 安装包下载地址https://www.mongodb.com/try/download/community这里我选择的是 选择RedHat / CentOS 7.0平台的原因是我的操作系统使用的是CentOS 7.0的&#xff0c;需要下载与操作系统匹配的安装包 三、…

ChatGPT插件权限给Plus用户放开了

大家好&#xff0c;我是章北海mlpy ChatGPT插件权限给Plus用户放开了 我稍微测试了俩&#xff0c;感觉还行&#xff0c;后续我会对一些热门插件深入测测&#xff0c;敬请期待。 官方对插件的介绍如下&#xff1a; 1、插件由非由OpenAI控制的第三方应用程序提供动力。在安装…

Spring Cloud Alibaba-Sentinel熔断降级

Sentinel: 轻量级的流量控制、熔断降级Java库&#xff0c; 分布式系统的流量防卫兵。 文章目录 一、Sentinel 是什么&#xff1f;二、安装Sentinel控制台三、Sentinel 实战3.1、准备工作3.2、流控规则快速失败Warm Up匀速排队 3.3、热点key限流3.4、降级规则3.5、系统规则 四、…

自定义组件3-behaviors

1、behaviors是小程序中用于实现组件间代码共享的特性&#xff0c;类似于Vue.js中的mixins 2、behaviors的工作方式 每个behaviors可以包含一组属性、数据、生命周期函数和方法。组件引用它时&#xff0c;它的数据和属性和方法会被 合并到组件中。 每个组件可以引用多个behav…

Python采集知某专栏文章保存成pdf

前言 嗨喽&#xff0c;大家好呀~这里是爱看美女的茜茜呐 环境使用: Python 3.8 Pycharm wkhtmltopdf 软件 文末名片获取 模块使用: requests >>> pip install requests 数据请求 parsel >>> pip install parsel 数据解析 re >>> 内置模块 不…

【年度最强超级Ai让你体验真正的人工智能】

年度最强超级Ai让你体验真正的人工智能&#xff08;破解版&#xff09; 登录就是永久会员&#xff08;先上链接&#xff0c;资源来源于网络&#xff0c;如有侵权&#xff0c;联系删除&#xff09; 网易邮箱&#xff1a; 23402001163163.com 我用蓝奏浏览器分享了[GPT-AI助手V1…

100 个 Go 错误以及如何避免:5~8

协议&#xff1a;CC BY-NC-SA 4.0 译者&#xff1a;飞龙 本文来自【OpenDocCN 饱和式翻译计划】&#xff0c;采用译后编辑&#xff08;MTPE&#xff09;流程来尽可能提升效率。 真相一旦入眼&#xff0c;你就再也无法视而不见。——《黑客帝国》 五、字符串 本章涵盖 理解GO中…

Oracle执行计划管理 - SPM

https://blog.51cto.com/lhrbest/3246884 目录 Oracle优化器辅助手段的发展 SPM需求背景 SPM重要构成 SQL计划基准捕获 如何创建SQL计划基准 如何查看SQL计划基准 SQL计划基准选择 执行计划的三个属性 如何选择SQL计划 SQL计划基准发展 SQL计划基准发展的三种选择 …

在外包干了三年,我废了……不吹不黑!

没错&#xff0c;我也干过外包&#xff0c;一干就是三年&#xff0c;三年后&#xff0c;我废了…… 虽说废的不是很彻底&#xff0c;但那三年我几乎是出差了三年、玩了三年、荒废了三年&#xff0c;那三年&#xff0c;我的技术能力几乎是零成长的。 说起这段三年的外包经历&a…

CTFShow-电子取证篇Writeup

CTFShow-电子取证篇Writeup 套的签到题&#xff1a;JiaJia-CP-1&#xff1a;JiaJia-CP-2&#xff1a;JiaJia-CP-3&#xff1a; CTFShow 平台&#xff1a;https://ctf.show/ 套的签到题&#xff1a; JiaJia-CP-1&#xff1a; 这是部分人熟知的刘佳佳同学的电脑&#xff0c;她…

JavaWeb-Servlet的学习

Servlet 简介 Servlet是JavaWeb最为核心的内容&#xff0c;它是Java提供的一门动态web资源开发技术。 使用Servlet就可以实现&#xff0c;根据不同的登录用户在页面上动态显示不同内容。 Servlet是JavaEE规范之一&#xff0c;其实就是一个接口&#xff0c;将来我们需要定义S…