关于PointHeadBox类的理解

news2024/12/23 1:57:54

forward函数

 def forward(self, batch_dict):
        """
        Args:
            batch_dict:
                batch_size:
                point_features: (N1 + N2 + N3 + ..., C) or (B, N, C)
                point_features_before_fusion: (N1 + N2 + N3 + ..., C)
                point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
                point_labels (optional): (N1 + N2 + N3 + ...)
                gt_boxes (optional): (B, M, 8)
        Returns:
            batch_dict:
                point_cls_scores: (N1 + N2 + N3 + ..., 1)
                point_part_offset: (N1 + N2 + N3 + ..., 3)
        """
        if self.model_cfg.get('USE_POINT_FEATURES_BEFORE_FUSION', False):
            point_features = batch_dict['point_features_before_fusion']
        else:
            point_features = batch_dict['point_features']
            #通过全连接层128-->256-->256-->3生成类别信息
        point_cls_preds = self.cls_layers(point_features)  # (total_points, num_class)
        #通过全连接层128-->256-->256-->8生成回归框信息
        point_box_preds = self.box_layers(point_features)  # (total_points, box_code_size)

       #在预测的3个类别中求出最大可能的类别作为标签信息,并经过sigmod函数
        point_cls_preds_max, _ = point_cls_preds.max(dim=-1)
        batch_dict['point_cls_scores'] = torch.sigmoid(point_cls_preds_max)

        ret_dict = {'point_cls_preds': point_cls_preds,
                    'point_box_preds': point_box_preds}
        if self.training:
           #主要是生成每个点对应的真实的标签信息
           #以及真实框G相对于预测G_hat的框的参数偏移,每个点对应是1*8维向量
            targets_dict = self.assign_targets(batch_dict)
            ret_dict['point_cls_labels'] = targets_dict['point_cls_labels']
            ret_dict['point_box_labels'] = targets_dict['point_box_labels']

        if not self.training or self.predict_boxes_when_training:
           #求出每个点对应的预测的标签信息
           #以及P相对于预测的框G_hat的参数偏移,每个点对应是1*8维向量
            point_cls_preds, point_box_preds = self.generate_predicted_boxes(
                points=batch_dict['point_coords'][:, 1:4],
                point_cls_preds=point_cls_preds, point_box_preds=point_box_preds
            )
            batch_dict['batch_cls_preds'] = point_cls_preds
            batch_dict['batch_box_preds'] = point_box_preds
            batch_dict['batch_index'] = batch_dict['point_coords'][:, 0]
            batch_dict['cls_preds_normalized'] = False

        self.forward_ret_dict = ret_dict

        return batch_dict

注意:对于每一个point,point_box_preds是1×8维向量,8维分别表示[xt, yt, zt, dxt, dyt, dzt, cost, sint],[xt, yt, zt]为中心点偏移量,[dxt, dyt, dzt]为长宽高偏移量,[cost, sint]为角度偏移量。

在这里插入图片描述

forward函数得到了每个前景点对应的真实标签值以及标注框信息;(self.assign_targets--------->self.assign_stack_targets-----> self.box_coder.encode_torch调用了PointResidualCoder类中的encode_torch函数)

得到了从G_hat到G的1*8维参数


每个前景点对应的预测标签值以及预测框信息;(self.generate_predicted_boxes--------->self.box_coder.decode_torch调用了PointResidualCoder类中的decode_torch函数)

得到了从P到G_hat的1*8维参数

得到这两组参数后用于后续计算损失时计算的box损失,采用的是L1回归损失

 point_loss_box_src = F.smooth_l1_loss(
            point_box_preds[None, ...], point_box_labels[None, ...], weights=reg_weights[None, ...]
        )

边框回归(Bounding Box Regression)详解

PointResidualCoder

class PointResidualCoder(object):
    def __init__(self, code_size=8, use_mean_size=True, **kwargs):
        super().__init__()
        self.code_size = code_size
        self.use_mean_size = use_mean_size
        if self.use_mean_size:
            self.mean_size = torch.from_numpy(np.array(kwargs['mean_size'])).cuda().float()
            assert self.mean_size.min() > 0

    def encode_torch(self, gt_boxes, points, gt_classes=None):
        """
        Args:
            gt_boxes: (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
            points: (N, 3) [x, y, z]
            gt_classes: (N) [1, num_classes]
        Returns:
            box_coding: (N, 8 + C)
        """
        gt_boxes[:, 3:6] = torch.clamp_min(gt_boxes[:, 3:6], min=1e-5)

        xg, yg, zg, dxg, dyg, dzg, rg, *cgs = torch.split(gt_boxes, 1, dim=-1)
        xa, ya, za = torch.split(points, 1, dim=-1)

        if self.use_mean_size:
            assert gt_classes.max() <= self.mean_size.shape[0]
            point_anchor_size = self.mean_size[gt_classes - 1]
            dxa, dya, dza = torch.split(point_anchor_size, 1, dim=-1)
            diagonal = torch.sqrt(dxa ** 2 + dya ** 2)
            xt = (xg - xa) / diagonal
            yt = (yg - ya) / diagonal
            zt = (zg - za) / dza
            dxt = torch.log(dxg / dxa)
            dyt = torch.log(dyg / dya)
            dzt = torch.log(dzg / dza)
        else:
            xt = (xg - xa)
            yt = (yg - ya)
            zt = (zg - za)
            dxt = torch.log(dxg)
            dyt = torch.log(dyg)
            dzt = torch.log(dzg)

        cts = [g for g in cgs]
        return torch.cat([xt, yt, zt, dxt, dyt, dzt, torch.cos(rg), torch.sin(rg), *cts], dim=-1)

    def decode_torch(self, box_encodings, points, pred_classes=None):
        """
        Args:
            box_encodings: (N, 8 + C) [x, y, z, dx, dy, dz, cos, sin, ...]
            points: [x, y, z]
            pred_classes: (N) [1, num_classes]
        Returns:

        """
        xt, yt, zt, dxt, dyt, dzt, cost, sint, *cts = torch.split(box_encodings, 1, dim=-1)
        xa, ya, za = torch.split(points, 1, dim=-1)

        if self.use_mean_size:
            assert pred_classes.max() <= self.mean_size.shape[0]
            point_anchor_size = self.mean_size[pred_classes - 1]
            dxa, dya, dza = torch.split(point_anchor_size, 1, dim=-1)
            diagonal = torch.sqrt(dxa ** 2 + dya ** 2)
            xg = xt * diagonal + xa
            yg = yt * diagonal + ya
            zg = zt * dza + za

            dxg = torch.exp(dxt) * dxa
            dyg = torch.exp(dyt) * dya
            dzg = torch.exp(dzt) * dza
        else:
            xg = xt + xa
            yg = yt + ya
            zg = zt + za
            dxg, dyg, dzg = torch.split(torch.exp(box_encodings[..., 3:6]), 1, dim=-1)

        rg = torch.atan2(sint, cost)

        cgs = [t for t in cts]
        return torch.cat([xg, yg, zg, dxg, dyg, dzg, rg, *cgs], dim=-1)

decode_torch:如何通过point_box_preds的8维向量得到proposal的7维坐标?将每一个point原始xyz坐标加上坐标偏移量[xt, yt, zt]即可得到proposal中心点坐标,利用作者预设的point_anchor_size乘上长宽高偏移量[dxt, dyt, dzt]得到proposal长宽高,利用atan2函数计算角度heading。

思考:

在这里插入图片描述
代码这里的Pw和Py参数直接是用的diagonal

diagonal = torch.sqrt(dxa ** 2 + dya ** 2)

个人的理解是觉得这样可以同时优化生成的anchor大小并且可以调节中心坐标的偏移。

assign_targets

 def assign_targets(self, input_dict):
        """
        Args:
            input_dict:
                point_features: (N1 + N2 + N3 + ..., C)
                batch_size:
                point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
                gt_boxes (optional): (B, M, 8)
        Returns:
            point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignored
            point_part_labels: (N1 + N2 + N3 + ..., 3)
        """
        point_coords = input_dict['point_coords']
        gt_boxes = input_dict['gt_boxes']
        assert gt_boxes.shape.__len__() == 3, 'gt_boxes.shape=%s' % str(gt_boxes.shape)
        assert point_coords.shape.__len__() in [2], 'points.shape=%s' % str(point_coords.shape)

        batch_size = gt_boxes.shape[0]
        extend_gt_boxes = box_utils.enlarge_box3d(
            gt_boxes.view(-1, gt_boxes.shape[-1]), extra_width=self.model_cfg.TARGET_CONFIG.GT_EXTRA_WIDTH
        ).view(batch_size, -1, gt_boxes.shape[-1])
        targets_dict = self.assign_stack_targets(
            points=point_coords, gt_boxes=gt_boxes, extend_gt_boxes=extend_gt_boxes,
            set_ignore_flag=True, use_ball_constraint=False,
            ret_part_labels=False, ret_box_labels=True
        )

        return targets_dict

extend_gt_boxes 主要是将groud truth boxex在长、宽、高方向上扩展

在这里插入图片描述

在这里插入图片描述

assign_stack_targets

#此函数传入的都是对应点的真实预测值和真实标注框
 def assign_stack_targets(self, points, gt_boxes, extend_gt_boxes=None,
                             ret_box_labels=False, ret_part_labels=False,
                             set_ignore_flag=True, use_ball_constraint=False, central_radius=2.0):
        """
        Args:
            points: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
            gt_boxes: (B, M, 8)
            extend_gt_boxes: [B, M, 8]
            ret_box_labels:
            ret_part_labels:
            set_ignore_flag:
            use_ball_constraint:
            central_radius:

        Returns:
            point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignored
            point_box_labels: (N1 + N2 + N3 + ..., code_size)

        """
        assert len(points.shape) == 2 and points.shape[1] == 4, 'points.shape=%s' % str(points.shape)
        assert len(gt_boxes.shape) == 3 and gt_boxes.shape[2] == 8, 'gt_boxes.shape=%s' % str(gt_boxes.shape)
        assert extend_gt_boxes is None or len(extend_gt_boxes.shape) == 3 and extend_gt_boxes.shape[2] == 8, \
            'extend_gt_boxes.shape=%s' % str(extend_gt_boxes.shape)
        assert set_ignore_flag != use_ball_constraint, 'Choose one only!'
        #将数据分批次处理
        batch_size = gt_boxes.shape[0]
        bs_idx = points[:, 0]
        point_cls_labels = points.new_zeros(points.shape[0]).long()
        point_box_labels = gt_boxes.new_zeros((points.shape[0], 8)) if ret_box_labels else None
        point_part_labels = gt_boxes.new_zeros((points.shape[0], 3)) if ret_part_labels else None
        #将数据分批次处理
        for k in range(batch_size):
            bs_mask = (bs_idx == k)
            #这里以*_single应该是中间缓存变量,作为每一批次处理的变量存储数据
            #points_single取出对应批次的点云的坐标信息
            points_single = points[bs_mask][:, 1:4]
            point_cls_labels_single = point_cls_labels.new_zeros(bs_mask.sum())
             #将每一个点云数据分配到真实标注框上
            box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(         
                points_single.unsqueeze(dim=0), gt_boxes[k:k + 1, :, 0:7].contiguous()
            ).long().squeeze(dim=0)
            
             #box_idxs_of_pts是每个点对应分配的标注框索引值,没有匹配的赋值为-1
            box_fg_flag = (box_idxs_of_pts >= 0) 
            #根据之前扩展的3D框计算被忽略的点
            if set_ignore_flag:
             #将每一个点云数据分配到扩展后的标注框上
                extend_box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
                    points_single.unsqueeze(dim=0), extend_gt_boxes[k:k+1, :, 0:7].contiguous()
                ).long().squeeze(dim=0)
                fg_flag = box_fg_flag
                #异或运算,未扩展前没有包括,扩展后包含到的框,即被忽略的框
                ignore_flag = fg_flag ^ (extend_box_idxs_of_pts >= 0)
                point_cls_labels_single[ignore_flag] = -1
            elif use_ball_constraint:
                box_centers = gt_boxes[k][box_idxs_of_pts][:, 0:3].clone()
                box_centers[:, 2] += gt_boxes[k][box_idxs_of_pts][:, 5] / 2
                ball_flag = ((box_centers - points_single).norm(dim=1) < central_radius)
                fg_flag = box_fg_flag & ball_flag
            else:
                raise NotImplementedError

           #记录前景点信息,可以理解为论文中所说的前景点分割
            gt_box_of_fg_points = gt_boxes[k][box_idxs_of_pts[fg_flag]]
             #最后一维代表的是标注框对应的类别信息,对应前景点的类别信息
            point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, -1].long()
            #记录一次批处理流程中所有点的类别信息
            point_cls_labels[bs_mask] = point_cls_labels_single

            if ret_box_labels and gt_box_of_fg_points.shape[0] > 0:
                point_box_labels_single = point_box_labels.new_zeros((bs_mask.sum(), 8))
                #记录每一个前景点从G_hat到G的参数偏移,每个前景点最后输出是1*8维向量
                fg_point_box_labels = self.box_coder.encode_torch(
                    gt_boxes=gt_box_of_fg_points[:, :-1], points=points_single[fg_flag],
                    gt_classes=gt_box_of_fg_points[:, -1].long()
                )
                point_box_labels_single[fg_flag] = fg_point_box_labels
                point_box_labels[bs_mask] = point_box_labels_single

            if ret_part_labels:
                point_part_labels_single = point_part_labels.new_zeros((bs_mask.sum(), 3))
                transformed_points = points_single[fg_flag] - gt_box_of_fg_points[:, 0:3]
                transformed_points = common_utils.rotate_points_along_z(
                    transformed_points.view(-1, 1, 3), -gt_box_of_fg_points[:, 6]
                ).view(-1, 3)
                offset = torch.tensor([0.5, 0.5, 0.5]).view(1, 3).type_as(transformed_points)
                point_part_labels_single[fg_flag] = (transformed_points / gt_box_of_fg_points[:, 3:6]) + offset
                point_part_labels[bs_mask] = point_part_labels_single

        targets_dict = {
            'point_cls_labels': point_cls_labels,
            'point_box_labels': point_box_labels,
            'point_part_labels': point_part_labels
        }
        return targets_dict

经典框架解读 | 论文+代码 | 3D Detection | OpenPCDet | PointRCNN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1061148.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Folium笔记:HeatMap

在地图上生成热力图 0 举例 import folium from folium.plugins import HeatMap# 创建一个地图对象 m folium.Map(location(1.34084, 103.83637), zoom_start13)# 创建一个坐标点的数据集 data [(1.431656, 103.827896),(1.424789, 103.789902),(1.325781, 103.860446),(1.…

【算法训练-搜索算法 一】【DFS网格搜索框架】岛屿数量、岛屿的最大面积、岛屿的周长

废话不多说&#xff0c;喊一句号子鼓励自己&#xff1a;程序员永不失业&#xff0c;程序员走向架构&#xff01;本篇Blog的主题是【搜索算法】&#xff0c;使用【数组】这个基本的数据结构来实现&#xff0c;这个高频题的站点是&#xff1a;CodeTop&#xff0c;筛选条件为&…

多卡片效果悬停效果

效果展示 页面结构 从页面的结构上看&#xff0c;在默认状态下毛玻璃卡片是有层次感的效果叠加在一起&#xff0c;并且鼠标悬停在卡片区域后&#xff0c;卡片整齐排列。 CSS3 知识点 transform 属性的 rotate 值运用content 属性的 attr 值运用 实现页面整体布局 <div …

代码随想录算法训练营第五十七天 | 动态规划 part 15 | 392.判断子序列、115.不同的子序列

目录 392.判断子序列思路代码 115.不同的子序列思路代码 392.判断子序列 Leetcode 思路 dp[i][j] 表示以下标i-1为结尾的字符串s&#xff0c;和以下标j-1为结尾的字符串t&#xff0c;相同子序列的长度为dp[i][j]递推公式&#xff1a; 初始化&#xff1a;为0遍历顺序&#xff…

日常工作报告生成器微信小程序源码 支持日报,周报,月报,年终终结

相信大家上班都会有做工作报告的情况吧 那么这款小程序就是大家的福音了 只要输入你的工作内容或者岗位自动生成你的工作报告 支持报,周报,月报,年终终结 源码下载&#xff1a;https://download.csdn.net/download/m0_66047725/88391810 源码下载2&#xff1a;评论留言或私信…

JVM篇---第二篇

系列文章目录 文章目录 系列文章目录一、简述一下JVM的内存模型二、说说堆和栈的区别三、什么时候会触发FullGC一、简述一下JVM的内存模型 1.JVM内存模型简介 JVM定义了不同运行时数据区,他们是用来执行应用程序的。某些区域随着JVM启动及销毁,另外一 些区域的数据是线程性独…

自动化测试框架详解

一、什么是自动化测试框架 在了解什么是自动化测试框架之前&#xff0c;先了解一下什么叫框架&#xff1f;框架是整个或部分系统的可重用设计&#xff0c;表现为一组抽象构件及构件实例间交互的方法;另一种定义认为&#xff0c;框架是可被应用开发者定制的应用骨架。前者是从应…

知识图谱-Neo4j使用详解

neo4j应用场景 知识图谱欺诈检测实时推荐引擎反洗钱主数据管理供应链管理增强网络和IT运营管理能力数据谱系身份和访问管理材料清单 图数据库neo4j简介 关系查询&#xff1a;mysql和neo4j性能对比 neo4j的特性和优点&#xff1a; Neo4j-CQL简介 neo4j的Cypher语言是为处理图…

96.qt qml-http之XMLHttpRequest介绍详解使用

在QML中我们可以通过XMLHttpRequest 来实现http/https访问网络接口,接下来我们先来学习XMLHttpRequest类的常用部分、 由于QML的XMLHttpRequest少部分参数是没有的,所以本章来单独讲解下。下章我们来实现旋转请求按钮以及通用的JSON请求模板方法 1.XMLHttpRequest初步使用 …

计算机网络(六):应用层

参考引用 计算机网络微课堂-湖科大教书匠计算机网络&#xff08;第7版&#xff09;-谢希仁 1. 应用层概述 应用层是计算机网络体系结构的最顶层&#xff0c;是设计和建立计算机网络的最终目的&#xff0c;也是计算机网络中发展最快的部分 早期基于文本的应用 (电子邮件、远程登…

Boost程序库完全开发指南:1.1-C++基础知识点梳理

主要整理了N多年前&#xff08;2010年&#xff09;学习C的时候开始总结的知识点&#xff0c;好长时间不写C代码了&#xff0c;现在LLM量化和推理需要重新学习C编程&#xff0c;看来出来混迟早要还的。 1.shared_ptr 解析&#xff1a;shared_ptr是一种计数指针&#xff0c;当引…

【Linux】【网络】工具:httplib 库的安装与简单使用

文章目录 1. 下载 httplib 库2. 从 Win 传输文件到 Linux3. 解压缩 httplib 库1. struct Request 结构体源码展示2. struct Reponse 结构体源码展示3. httplib 库 Server 类4. httplib 库 Client 类5. 搭建简易 server 服务器6. 搭建简易 client 客户端 1. 下载 httplib 库 要求…

结构体和联合体

目录 一.结构体 1.1什么是结构体&#xff1f; 1.2结构的声明 1.3特殊声明 1.4结构的自引用 1.5 结构体变量的定义和初始化 1.6 结构体内存对齐 1.7修改默认对齐数 二.联合体 2.1联合类型的定义 2.2联合的特点 2.3联合大小的计算 一.结构体 1.1什么是结构体&#xf…

【milkv】max98357a驱动添加speaker

文章目录 一、电路1.1 duo音频接口1.2 I2S2连接 二、I2S2介绍2.1 参考cv182x的dts实现2.2 参考cv1835_fpga2.3 cv180x2.4 改动——保留i2s2.5 I2S小结 三、参考资料3.1 文章3.2 手册 四、驱动路径五、添加codec驱动——max98357a5.1 config5.2 dtsi5.3 dai_driver 六、查看plat…

VUE3照本宣科——内置指令与自定义指令及插槽

VUE3照本宣科——内置指令与自定义指令及插槽 前言一、内置指令1.v-text2.v-html3.v-show4.v-if5.v-else6.v-else-if7.v-for8.v-on9.v-bind10.v-model11.v-slot12.v-pre13.v-once14.v-memo15.v-cloak 二、自定义指令三、插槽1.v-slot2.useSlots3.defineSlots() 前言 &#x1f…

Chrome之解决DevTools: Failed to load data:No resource with given identifier found

问题 使用DevTools抓包时候, 有些跨域请求无法在加载出来, 提示 Failed to load data:No resource with given identifier found 解决办法 换其他浏览器 下断点 打开DevTools, 选择源代码/来源/Sources,找到事件监听器断点/Event Listener Breakpoints, 找到加载/Load下面…

C++树详解

树 树的定义 树&#xff08;Tree&#xff09;是n&#xff08;n≥0&#xff09;个结点的有限集。n0时称为空树。在任意一颗非空树中&#xff1a;①有且仅有一个特定的称为根&#xff08;Root&#xff09;的结点&#xff1b;②当n>1时&#xff0c;其余结点可分为m&#xff08…

十天学完基础数据结构-第七天(图(Graph))

图的基本概念 图是一种数据结构&#xff0c;用于表示对象之间的关系。它由两个基本组件构成&#xff1a; 顶点&#xff08;Vertex&#xff09;&#xff1a;也被称为节点&#xff0c;代表图中的对象或实体。 边&#xff08;Edge&#xff09;&#xff1a;连接两个顶点的线&…

【setxattr+userfaultfd】SECCON2020-kstack

这个题主要还是练习 userfaultfd 的利用。说实话&#xff0c;userfaultfd 的利用还是挺多的&#xff0c;虽然在新的内核版本已经做了相关保护。 老规矩&#xff0c;看下启动脚本 #!/bin/sh qemu-system-x86_64 \-m 128M \-kernel ./bzImage \-initrd ./rootfs.cpio \-append …

第三章、运输层

文章目录 3.1 概述和运输层服务3.1.1 运输层和网络层的关系3.1.2 因特网运输层概述 3.2 多路复用与多路分解3.3 无连接运输&#xff1a;UDP3.4 可靠数据传输原理3.4.1构造可靠数据传输协议rdt1.0rdt2.xrdt3.0 3.4.2 流水线可靠数据传输协议3.4.3 回退N步3.4.4选择重传 3.5 面向…