BEVFormer组件分析

news2024/10/6 12:29:19

BEVFormerEncoder中的get_reference_points


@staticmethod
    def get_reference_points(H, W, Z=8, num_points_in_pillar=4, dim='3d', bs=1, device='cuda', dtype=torch.float):
        """Get the reference points used in SCA and TSA.
        Args:
            H, W: spatial shape of bev.
            Z: hight of pillar.
            D: sample D points uniformly from each pillar.
            device (obj:`device`): The device where
                reference_points should be.
        Returns:
            Tensor: reference points used in decoder, has \
                shape (bs, num_keys, num_levels, 2).
        """

        # reference points in 3D space, used in spatial cross-attention (SCA)
        if dim == '3d':
            zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar, dtype=dtype,
                                device=device).view(-1, 1, 1).expand(num_points_in_pillar, H, W) / Z
            xs = torch.linspace(0.5, W - 0.5, W, dtype=dtype,
                                device=device).view(1, 1, W).expand(num_points_in_pillar, H, W) / W
            ys = torch.linspace(0.5, H - 0.5, H, dtype=dtype,
                                device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H
            ref_3d = torch.stack((xs, ys, zs), -1)
            ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
            ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
            return ref_3d

        # reference points on 2D bev plane, used in temporal self-attention (TSA).
        elif dim == '2d':
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H - 0.5, H, dtype=dtype, device=device),
                torch.linspace(
                    0.5, W - 0.5, W, dtype=dtype, device=device)
            )
            ref_y = ref_y.reshape(-1)[None] / H
            ref_x = ref_x.reshape(-1)[None] / W
            ref_2d = torch.stack((ref_x, ref_y), -1)
            ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2)
            return ref_2d

根据上面的代码可以看出来,如果输入的是3d, 则是
按照:

  • X方向: 从0.5, 到W-0.5分成W份.
  • Y方向: 从0.5, 到H-0.5分成H份.
  • Z方向: 从0.5, 到Z-0.5, 分成 num_points_in_pillar份.
    其中num_points_in_pillar 默认给的是4.

配置文件里面给的其实也是4.
在这里插入图片描述

BEVFormerEncoder中的point_sampling

  # This function must use fp32!!!
    @force_fp32(apply_to=('reference_points', 'img_metas'))
    def point_sampling(self, reference_points, pc_range,  img_metas):
        lidar2img = []
        for img_meta in img_metas:
            lidar2img.append(img_meta['lidar2img'])
        lidar2img = np.asarray(lidar2img)
        lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4)
        reference_points = reference_points.clone()

        # 变换到点云的范围内. 这也是为何get_reference_points中会/H, /W, /Z, 先化到[0, 1]变成ratio.
        reference_points[..., 0:1] = reference_points[..., 0:1] * \
            (pc_range[3] - pc_range[0]) + pc_range[0]
        reference_points[..., 1:2] = reference_points[..., 1:2] * \
            (pc_range[4] - pc_range[1]) + pc_range[1]
        reference_points[..., 2:3] = reference_points[..., 2:3] * \
            (pc_range[5] - pc_range[2]) + pc_range[2]

        # 由(x, y, z) 变成(x, y, z, 1) 便于与4*4的参数矩阵相乘.
        reference_points = torch.cat(
            (reference_points, torch.ones_like(reference_points[..., :1])), -1)
        # 此时reference_points可以当成是点云的点了.

        reference_points = reference_points.permute(1, 0, 2, 3)
        # num_query等于H*W*Z. 等于grid_points的数量.
        D, B, num_query = reference_points.size()[:3]
        num_cam = lidar2img.size(1)

        # 要往每个相机上去投影. 因此先申请num_cam份.
        # reference_points的shape就变成了, (D, b, num_cam, num_query, 4, 1) 便于和4*4的矩阵做matmul.
        reference_points = reference_points.view(
            D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1)

        # 相机参数由(b,num_cam, 4, 4) 变成(1, b, num_cam, 1, 4, 4) 再变成(D,b,num_cam,num_query,4,4)
        lidar2img = lidar2img.view(
            1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)

        reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
                                            reference_points.to(torch.float32)).squeeze(-1)
        eps = 1e-5

        # 把每个相机后面的点mask掉. 因为相机后面的点投过来之后第三位是负的.
        bev_mask = (reference_points_cam[..., 2:3] > eps)
        # 再做齐次化. 得到像素坐标.
        reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
            reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)

        # 由像素坐标转成相对于图像的ratio..
        # NOTE 这里如果不同相机size不一样的话.要除以对应的相机的size
        reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
        reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]

        # 再把超出图像fov范围的点给去掉.
        bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)
                    & (reference_points_cam[..., 1:2] < 1.0)
                    & (reference_points_cam[..., 0:1] < 1.0)
                    & (reference_points_cam[..., 0:1] > 0.0))
        if digit_version(TORCH_VERSION) >= digit_version('1.8'):
            bev_mask = torch.nan_to_num(bev_mask)
        else:
            bev_mask = bev_mask.new_tensor(
                np.nan_to_num(bev_mask.cpu().numpy()))

        # 由(D, b, num_cam, num_query, 2) 变成 (num_cam, b, num_query, D, 2)
        reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
        bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)

        # 至此. reference_points_cam代表的就是像素点相对于各个相机的ratio.
        # bev_mask就代表哪些点是有效的
        return reference_points_cam, bev_mask

SpatialCrossAttention

个人理解SpatialCrossAttention其实就是正常的Deformable Attention, 只不过原始Deformable Attention中的
refer points是由网络产生的,
而现在的refer points 是由 虚拟的grid points往图像上投影得到的. 在相机参数固定的情况下, 此时的refer points是固定的.

下面是 SpatialCrossAttention这个模块的forward函数的部分代码

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

问题: 给固定的这些refer points 的收益是多大? 文章好像并没有提. 这一块儿感觉不充分.

另外, 显然这样虚拟的grid points 是不合理的, 因为有些地方可能就没有点, 但是还是能够投影到图像上的. 这里用真值的点应该会更好,
比如用lidar的points. 但是BEVFormer paper里面没有对比加入lidar后的效果.

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/609073.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

让你的代码动起来:Python进度条神器tqdm详解及应用实例

各位Python高手&#xff0c;今天我要给大家介绍一个好用的库&#xff0c;它就是&#xff1a;tqdm tqdm在阿拉伯语中的意思是 "进展"&#xff0c;所以这个库也被称为 "快速进展条"。不得不说&#xff0c;这个名字真的很有创意&#xff01; 让我们想象一下&a…

蒙特卡洛积分——采样方法

蒙特卡洛积分 目的&#xff1a; 通过计算机进行采样近似求解复杂的积分理论基础&#xff1a; 大数定律&#xff0c;当 n n n足够大时&#xff0c; X X X的均值将收敛于它的期望&#xff08;不严谨表述&#xff09;一般形式&#xff1a; θ ∫ a b f ( x ) d x ∫ a b f ( x…

瑞云科技CTO赵志杰出席广州广告数字创意峰会并发表演讲

3月23日下午&#xff0c;广州广告数字创意峰会暨穗广协企业家大讲堂年度巡礼活动在广州图书馆圆满举行。本次峰会由广州市人民政府统筹&#xff0c;中共广州市委宣传部、广州市文化广电旅游局、中共广州市天河区委、广州市天河区人民政府主办。作为第六届“文创产业大会天河峰会…

go调试工具-delve

go调试工具-delve 简介 go debug工具&#xff0c;专门为go开发的调试工具&#xff0c;并且采用go语言开发&#xff0c;支持多平台。 官网&#xff1a;https://github.com/go-delve/delve 官网有详细的手册&#xff0c;学习起来很方便 快速开始 安装 我本地的go版本 官方…

python的基本知识与面试问题的汇总1

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下python的基本知识与面试问题的汇总&#xff0c;看完之后会对python巩固有很大的帮助哦。 Python中的多线程&#xff1a; 多线程是指在一个程序中同时运行多个线程以提高程序的执行效率。Python中的threading模块…

MongoDB基础实战:CRUD

1 缘起 后台项目使用的数据库是MongoDB&#xff0c; 在一次测试联调过程中&#xff0c;测试同事在测试数据的准确性时&#xff0c; 问我这些数据该怎么查&#xff0c;如何计算验证数据的结果&#xff0c; 我当时&#xff0c;对MongoDB数据操作不熟悉&#xff0c;请教了其他有经…

2. 两数相加解题思路

文章目录 题目解题思路 题目 给你两个 非空 的链表&#xff0c;表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的&#xff0c;并且每个节点只能存储 一位 数字。 请你将两个数相加&#xff0c;并以相同形式返回一个表示和的链表。 你可以假设除了数字 0 之外&am…

C++11之异常处理

文章目录 一、异常处理的概念二、异常编写的步骤&#xff08;来自图论教育&#xff09;三、栈展开和异常捕获四、C11中noexcep关键字 一、异常处理的概念 异常是程序可能检测到的&#xff0c;运行时不正常的情况&#xff0c;如存储空间耗尽&#xff0c;数组越界&#xff0c;被…

PG提示could not determine data type of parameter $4

目录 场景&#xff1a; 现象&#xff1a; 版本&#xff1a; 分析&#xff1a; 解决方式&#xff1a; 场景&#xff1a; 今天遇到现场环境连接Postgre数据库&#xff0c;日志提示could not determine data type of parameter $4&#xff0c;通过日志复制出完整sql&#xff…

SpringCloudAlibaba:分布式事务之Seata学习

目录 一、分布式事务基础 &#xff08;一&#xff09;事务 &#xff08;二&#xff09;本地事务 &#xff08;三&#xff09;分布式事务 二、Seata概述 1.Seata 的架构包含: 2.其工作原理为: 3.如果需要在 Spring Boot 应用中使用 Seata 进行分布式事务管理,主要步骤为…

Android Jetpack Compose实现轮播图效果

Android Jetpack Compose实现轮播图效果 在最近思索如何使用Compose方式改进我的开源TMDB电影列表应用程序的主屏幕时&#xff0c;一个激动人心的概念浮现在我的脑海中——为什么不整合一个吸引人的轮播图来展示即将上映的电影呢&#xff1f;在本文中&#xff0c;我将分享我的开…

旧改快讯--星河操刀,龙华稳健工业园项目专规获批

龙华街道稳健工业园城市更新单元原列入《2019年深圳市龙华区城市更新单元计划第五批计划》&#xff0c;现已列入《2022年深圳市龙华区城市更新单元计划第三批计划》&#xff0c;现该更新单元规划已经深圳市城市规划委员会法定图则委员会2023年第16次会议审议并获原则通过&#…

python环境安装

测试电脑环境有无安装python&#xff1a; winR&#xff0c;输入cmd&#xff0c;打开窗口&#xff0c;输入pyhton&#xff0c;查看是否有版本号&#xff0c;没有则是没有安装python环境 找到python-3.7.0-amd64的安装包&#xff0c;直接双击启动。上面是快速安装&#xff0c;我…

【Linux驱动】字符设备驱动相关宏 / 函数介绍(module_init、register_chrdev)

驱动运行有两种方式&#xff1a; 方式一&#xff1a;直接编译到内核&#xff0c;Linux内核启动时自动运行驱动程序方式二&#xff1a;编译成模块&#xff0c;使用 insmod 命令加载驱动模块 我们在调试的时候&#xff0c;采用第二种方式是最合适的&#xff0c;每次修改驱动只需…

八大排序之图文详解

前言 在数据结构中&#xff0c;排序是非常重要的内容&#xff0c;也是未来面试和笔试的重点。 本文代码是Java 目录 前言 一、插入排序 &#xff08;一&#xff09;直接插入排序 &#xff08;二&#xff09;希尔排序 二、选择排序 &#xff08;一&#xff09;选择排序 …

【CSS3系列】第六章 · 2D和3D变换

写在前面 Hello大家好&#xff0c; 我是【麟-小白】&#xff0c;一位软件工程专业的学生&#xff0c;喜好计算机知识。希望大家能够一起学习进步呀&#xff01;本人是一名在读大学生&#xff0c;专业水平有限&#xff0c;如发现错误或不足之处&#xff0c;请多多指正&#xff0…

通义千问预体验,如何让 AI 模型应用“奔跑”在函数计算上?

立即体验基于函数计算部署通义千问预体验&#xff1a; https://developer.aliyun.com/topic/aigc_fc AIGC 浪潮已来&#xff0c;从文字生成到图片生成&#xff0c;AIGC 的创造力让人惊叹&#xff0c;更多人开始探索如何使用 AI 提高生产效率&#xff0c;激发更多创作潜能&…

android jetpack Room的基本使用(java)

数据库的基本使用 添加依赖 //roomdef room_version "2.5.0"implementation "androidx.room:room-runtime:$room_version"annotationProcessor "androidx.room:room-compiler:$room_version"创建表 Entity表示根据实体类创建数据表&#xff0c…

Linux基础篇 Ubuntu 22.04的环境安装-02

目录 一、资料的获取 二、安装虚拟机 三、安装Ubuntu过程 四、注意事项 一、资料的获取 1.通过官方网站下载 Ubuntu系统下载 | Ubuntuhttps://cn.ubuntu.com/download2.下载桌面板即可 3.选择下载的版本 二、安装虚拟机 1.创建新的虚拟机 2.选择自定义安装 3.硬件兼容性选…

Zinx框架学习 - 请求与路由模块实现

Zinx - V0.3 请求与路由模块实现 在zinxV0.2中链接只封装了套接字&#xff0c;而请求是封装了链接和用户传输的数据&#xff0c;后续通过请求来识别具体要实现什么功能&#xff0c;然后通过路由来完成对应的功能处理。conn链接的业务处理HandleFunc是固定写死的&#xff0c;接…