【BEV 视图变换】Ray-based(2): 代码复现+画图解释 基于深度估计、bev_pool(代码一键运行)

news2024/9/23 12:28:54

paper:Lift, Splat, Shoot: Encoding Images from Arbitrary Camera Rigs by Implicitly Unprojecting to 3D
code:https://github.com/nv-tlabs/lift-splat-shoot

一、完整复现代码(可一键运行)和效果图

在这里插入图片描述

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import cv2
import numpy as np

# 根据世界坐标范围和一个像素代表的世界坐标距离来计算bev_size
# dx:[0.5,0.5,20]代表单位长度,bx是[-49.75,49.75,0]代表起始网格点的中心,nx[200,200,1] 代表网格数目
xbound = [-50.0, 50.0, 0.5]  # 前后100米,1个pixel=0.5米 -> x方向: 200 pixel
ybound = [-50.0, 50.0, 0.5]  # 左右100米,1个pixel=0.5米 -> y方向: 200 pixel
zbound = [-10.0, 10.0, 20.0]  # 上下20米, 1个pixel=20米  -> z方向: 1   pixel
dbound = [4.0, 45.0, 1.0]  # 深度4~45米, 1个pixel=1米 -> d方向: 41  pixel
D_ = int((dbound[1]-dbound[0])/dbound[2])

def gen_dx_bx(xbound, ybound, zbound):
    dx = torch.Tensor([row[2] for row in [xbound, ybound, zbound]])
    bx = torch.Tensor([row[0] + row[2]/2.0 for row in [xbound, ybound, zbound]])
    nx = torch.LongTensor([(row[1] - row[0]) / row[2] for row in [xbound, ybound, zbound]])
    dx = nn.Parameter(dx, requires_grad=False)
    bx = nn.Parameter(bx, requires_grad=False)
    nx = nn.Parameter(nx, requires_grad=False)
    return dx, bx, nx

batch_size = 1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 模型输入尺寸及下采样倍数
in_H = 128
in_W = 352
scale_downsample = 16
# 模型输出尺寸
feat_W16 = in_W // scale_downsample
feat_H16 = in_H // scale_downsample
semantic_channels = 64

# 相机参数(两个相机)
num_cams = 2
rots=torch.Tensor([[[[ 8.2076e-01, -3.4144e-04,  5.7128e-01],[-5.7127e-01,  3.2195e-03,  8.2075e-01],[-2.1195e-03, -9.9999e-01,  2.4474e-03]],
         [[-9.3478e-01,  0, 0],[ 3.5507e-01,  0, -9.3477e-01],[-1.0805e-02, -9.9981e-01, 0]]]])
intrins = torch.Tensor([[[[1.2726e+03, 0.0, 0],[0.0000e+00, 1.2726e+03, 4.7975e+02],[0.0000e+00, 0.0000e+00, 1.0000e+00]],
         [[1.2595e+03, 0.0000e+00, 8.0725e+02], [0.0000e+00, 1.2595e+03, 5.0120e+02],[0.0000e+00, 0.0000e+00, 1.0000e+00]]]])
post_rots = torch.Tensor([[[[0.2200, 0.0000, 0.0000],[0.0000, 0.2200, 0.0000],[0.0000, 0.0000, 1.0000]],
         [[0.2200, 0.0000, 0.0000],[0.0000, 0.2200, 0.0000],[0.0000, 0.0000, 1.0000]]]])
post_trans =torch.Tensor([[[  0.],[  0.]], [[0.], [0.]], [[  0.],[  0.]]])
trans = torch.Tensor([[[ 1.5239,  0.4946,  1.5093], [ 1.0149, -0.4806,  1.5624]]])

def create_uvd_frustum():

    # 41米深度范围,值在[4,45]
    # 扩展至41x22x8
    distance = torch.arange(*dbound, dtype=torch.float).view(-1, 1, 1).expand(-1, feat_H16, feat_W16)
    D, _, _ = distance.shape
    # 22格,值在[0,128]
    # 再扩展至[41,8,22]
    x_stride = torch.linspace(0, in_W - 1, feat_W16, dtype=torch.float).view(1, 1, feat_W16).expand(D, feat_H16, feat_W16)
    # 8格,值在[0,352]
    # 再扩展至[41,8,22]
    y_stride = torch.linspace(0, in_H - 1, feat_H16, dtype=torch.float).view(1, feat_H16, 1).expand(D, feat_H16, feat_W16)
    # 创建视锥: [41,8,22,3]
    frustum = torch.stack((x_stride, y_stride, distance), -1)
    # 不计算梯度,不需要学习
    return nn.Parameter(frustum, requires_grad=False)

def plot_uvd_frustum(frustum): # 41 8 22 3
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Convert frustum tensor to numpy array for visualization
    frustum_np = frustum.numpy()

    # Extract x, y, d coordinates
    x = frustum_np[..., 0].flatten()
    y = frustum_np[..., 1].flatten()
    d = frustum_np[..., 2].flatten()

    # Plot the points in 3D space
    ax.scatter(x, y, d, c=d, cmap='viridis', marker='o')
    ax.set_xlabel('u')
    ax.set_ylabel('v')
    ax.set_zlabel('d')
    plt.show()
    path = f'uvd_frustum.png'
    plt.savefig(path)

def get_geometry_feat(frustum,rots, trans, intrins, post_rots, post_trans):
    B, N, _ = trans.shape
    # 视锥逆数据增强
    points = frustum - post_trans.view(B, N, 1, 1, 1, 3)
    # 加上B,N(6 cams)维度
    points = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3).matmul(points.unsqueeze(-1))
    #根据相机内外参将视锥点云从相机坐标映射到世界坐标
    points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],points[:, :, :, :, :, 2:3]), 5)
    combine = rots.matmul(torch.inverse(intrins))
    points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
    points += trans.view(B, N, 1, 1, 1, 3)
    return points

def plot_XYZ_frustum(frustum,path):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # Convert frustum tensor to numpy array for visualization
    for i in range(len(frustum)):
        frustum_np = frustum[i].numpy()
        # Extract x, y, d coordinates
        x = frustum_np[..., 0].flatten()
        y = frustum_np[..., 1].flatten()
        d = frustum_np[..., 2].flatten()
        # Plot the points in 3D space
        ax.scatter(x, y, d, c=d, cmap='viridis', marker='o')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.show()
    plt.savefig(path)

def cumsum_trick(cam_feat, geom_feat, ranks):
    # 最后一个维度累计,前缀和
    cam_feat = cam_feat.cumsum(0)
    # 过滤
    # [42162,64]->[7268,64] [42162,4]->[7268,4]
    # 将rank错位比较,找到rank中 == voxel_id == 发生变化的位置,记为kept
    kept = torch.ones(cam_feat.shape[0], device=cam_feat.device, dtype=torch.bool)
    kept[:-1] = (ranks[1:] != ranks[:-1])
    # 利用kept筛选得到x, 错位相减,从而实现将落在相同voxel特征求和
    cam_feat, geom_feat = cam_feat[kept], geom_feat[kept]
    cam_feat = torch.cat((cam_feat[:1], cam_feat[1:] - cam_feat[:-1])) # 错位相减得到的特征和
    return cam_feat, geom_feat


def plot_bev(bev, name = f'bev'):
    # ---- tensor -> array ----#
    array1 = bev.squeeze(0).cpu().detach().numpy()
    # ---- array -> mat ----#
    array1 = array1 * 255
    mat = np.uint8(array1)
    mat = mat.transpose(1, 2, 0)
    # ---- vis ----#
    cv2.imshow(name, mat)
    cv2.waitKey(0)



if __name__ == "__main__":
    # 1.创建三维tensor(2d image + depth)
    uvd_frustum = create_uvd_frustum()
    plot_uvd_frustum(uvd_frustum)

    # 2.视锥化(使用相机内外参,将三维tensor转到EGO坐标系下)
    XYZ_frustum = get_geometry_feat(uvd_frustum,rots, trans, intrins, post_rots, post_trans)
    plot_XYZ_frustum(XYZ_frustum[0],path = f'EGO_XYZ_frustum.png')

    # 3.体素化
    dx, bx, nx = gen_dx_bx(xbound, ybound, zbound)
    geom_feats = ((XYZ_frustum - (bx - dx / 2.)) / dx).long()
    plot_XYZ_frustum(geom_feats[0], path = f'voxel.png')

    # 4.bev_pool
    # 4.1. cam_feats,geom_feats 展平
    cam_feats = torch.rand(batch_size, num_cams, D_, feat_H16, feat_W16, semantic_channels)
    B, N, D, H, W, C = cam_feats.shape
    L__ = B * N * D * H * W
    cam_feats = cam_feats.reshape(L__, C)

    geom_feats = geom_feats.view(L__, 3)

    # 4.2.geom_feat增加batch维度
    batch_index = torch.cat([torch.full([L__ // B, 1], ix, device=cam_feats.device, dtype=torch.long) for ix in range(B)])
    geom_feats = torch.cat((geom_feats, batch_index), 1)

    # 4.3.filter by (X<200,Y<200,Z<1)
    kept = (geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < nx[0]) & (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < nx[1]) & (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < nx[2])
    cam_feats = cam_feats[kept]
    geom_feats = geom_feats[kept]

    # 4.4.voxel index 位置编码,排序
    ranks = (geom_feats[:, 0] * (nx[1] * nx[2] * B)  # X
             + geom_feats[:, 1] * (nx[2] * B)  # Y
             + geom_feats[:, 2] * B  # Z
             + geom_feats[:, 3])  # batch_index
    sorts = ranks.argsort()
    cam_feats, geom_feats, ranks = cam_feats[sorts], geom_feats[sorts], ranks[sorts]

    # 4.5. sum
    cam_feats, geom_feats = cumsum_trick(cam_feats, geom_feats, ranks)

    # 4.6.根据视锥获取相应的cam_feat, final:[1,64,1,200,200]
    final = torch.zeros((B, C, nx[2], nx[0], nx[1]), device=cam_feats.device)
    final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = cam_feats

    # 4.7.去掉Z维度, dim_Z维度属于dim=2, 生成bev图
    final = torch.cat(final.unbind(dim=2), 1)

    # 5.bev_encoder
    bev_encoder = nn.Conv2d(semantic_channels, 1, kernel_size=1, stride=1, padding=0,bias=False)
    bev = bev_encoder(final)
    plot_bev(bev, name = f'bev')

二、逐步代码讲解+图解

完整流程:
1.创建uv coord + depth estimation (2d image + depth)
2.视锥化(uv coord -> world coord) (根据相机内外参,构建4x3的投影矩阵)
3.体素化(world coord -> voxel coord) (会有到世界范围划分及各自维度的刻度)
4.bev_pool(voxel coord -> bev coord)(去掉Z轴)

1.创建uv coord + depth estimation (2d image + depth)

uvd_frustum = create_uvd_frustum()
plot_uvd_frustum(uvd_frustum)

在这里插入图片描述
注意
1.坐标范围,u,v范围代表模型输入尺寸(352,128),d范围为(4,45)。
2.u轴有22个柱子(pillar),22=352//16;v轴有8个柱子(pillar),8=128//16;d轴有41个刻度,41=(45-4)//1

2.视锥化(uv coord -> world coord) (根据相机内外参,构建4x3的投影矩阵)

XYZ_frustum = get_geometry_feat(uvd_frustum,rots, trans, intrins, post_rots, post_trans)
plot_XYZ_frustum(XYZ_frustum[0],path = f'EGO_XYZ_frustum.png')

在这里插入图片描述
我这里为了看起来更直观点,选了两个相机,实际在使用过程中,可以灵活使用1个,2个,4个,6个相机。

3.体素化(world coord -> voxel coord) (会有到世界范围划分及各自维度的刻度)

dx, bx, nx = gen_dx_bx(xbound, ybound, zbound)
geom_feats = ((XYZ_frustum - (bx - dx / 2.)) / dx).long()
plot_XYZ_frustum(geom_feats[0], path = f'voxel.png')

在这里插入图片描述
为什么上面和下面的形状不一样呢?因为1.相机内外参数的影响 2.因为(旋转,平移)数据增强的影响
注意观察,此时的XYZ轴的范围已经落在(200,200,1)的bev尺寸范围里了!

4.bev_pool(voxel coord -> bev coord)(去掉Z轴)

  • 4.1. cam_feats,geom_feats 展平
cam_feats = torch.rand(batch_size, num_cams, D_, feat_H16, feat_W16, semantic_channels)
B, N, D, H, W, C = cam_feats.shape
L__ = B * N * D * H * W
cam_feats = cam_feats.reshape(L__, C)

geom_feats = geom_feats.view(L__, 3)
  • 4.2.geom_feat增加batch维度
batch_index = torch.cat([torch.full([L__ // B, 1], ix, device=cam_feats.device, dtype=torch.long) for ix in range(B)])
geom_feats = torch.cat((geom_feats, batch_index), 1)
  • 4.3.filter by (X<200,Y<200,Z<1)
kept = (geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < nx[0]) & (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < nx[1]) & (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < nx[2])
cam_feats = cam_feats[kept]
geom_feats = geom_feats[kept]

在这里插入图片描述

  • 4.4.voxel index 位置编码,排序
ranks = (geom_feats[:, 0] * (nx[1] * nx[2] * B)  # X
         + geom_feats[:, 1] * (nx[2] * B)  # Y
         + geom_feats[:, 2] * B  # Z
         + geom_feats[:, 3])  # batch_index
sorts = ranks.argsort()
cam_feats, geom_feats, ranks = cam_feats[sorts], geom_feats[sorts], ranks[sorts]

可以参考我画的示意图
在这里插入图片描述

  • 4.5. sum
cam_feats, geom_feats = cumsum_trick(cam_feats, geom_feats, ranks)
  • 4.6.根据视锥获取相应的cam_feat, final:[1,64,1,200,200]
final = torch.zeros((B, C, nx[2], nx[0], nx[1]), device=cam_feats.device)
final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = cam_feats
  • 4.7.去掉Z维度, dim_Z维度属于dim=2, 生成bev图
final = torch.cat(final.unbind(dim=2), 1)

5.bev_encoder

bev_encoder = nn.Conv2d(semantic_channels, 1, kernel_size=1, stride=1, padding=0,bias=False)
bev = bev_encoder(final)
plot_bev(bev, name = f'bev')

在这里插入图片描述
bev尺寸为200x200

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2157525.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Go】Go语言切片(Slice)深度剖析与应用实战

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者简介&#xff1a;景天科技苑 &#x1f3c6;《头衔》&#xff1a;大厂架构师&#xff0c;华为云开发者社区专家博主&#xff0c;…

知乎:从零开始做自动驾驶定位; 注释详解(二)

这个个系统整体分为: 数据预处理 前端里程计 后端优化 回环检测 显示模块。首先来看一下数据预处理节点做的所有事情&#xff1a; 数据预处理节点 根据知乎文章以及代码我们知道: 节点功能输入输出数据预处理1.接收各传感器信息2.传感器数据时间同步 3.点云运动畸变补偿 4.传…

免杀对抗—python混淆算法反序列化shellcode

一、前言 内网已经学的七七八八了(主要是实验环境太麻烦了&#xff0c;累了)&#xff0c;今天就开启新的篇章——免杀。免杀我们主要是对生成的shellcode做免杀&#xff0c;而不是对生成的exe做免杀。为啥呢&#xff0c;你可以这样理解&#xff0c;exe已经是成品了&#xff0c…

Vue 内存泄漏分析:如何避免开发过程中导致的内存泄漏问题

一. 引言 Vue 作为一款流行的前端框架&#xff0c;已经在许多项目中得到广泛应用。然而&#xff0c;随着我们在 Vue 中构建更大规模的应用程序&#xff0c;我们可能会遇到一个严重的问题&#xff0c;那就是内存泄漏。内存泄漏是指应用程序在使用内存资源时未正确释放&#xff…

昇思MindSpore进阶教程-模型模块自定义

大家好&#xff0c;我是刘明&#xff0c;明志科技创始人&#xff0c;华为昇思MindSpore布道师。 技术上主攻前端开发、鸿蒙开发和AI算法研究。 努力为大家带来持续的技术分享&#xff0c;如果你也喜欢我的文章&#xff0c;就点个关注吧 基础用法示例 神经网络模型由各种层(Lay…

【AI实战攻略】保姆级教程:用AI打造治愈动画vlog,轻松打造爆款,快速涨粉!

在当今这个快节奏的社会中&#xff0c;你是否也曾在某个雨夜&#xff0c;沉浸于那些温馨而治愈的动画短视频中&#xff0c;找到片刻的宁静与放松&#xff1f; 窗外大雨滂沱&#xff0c;而你&#xff0c;刚结束一天的忙碌&#xff0c;沐浴在温暖的热水中&#xff0c;随后裹上柔…

Integer 源码记录

Integer 公共方法结构 注意&#xff1a; 通过构造函数创建一个Integer对象&#xff0c;每次都会返回一个新的对象&#xff0c;如果使用 进行对象的比较&#xff0c;那么结果是false。 public Integer(int value) {this.value value;}与之对应的是&#xff0c;valueOf 方法…

java -----泛型

泛型的理解和好处 泛型是在JDK5之后引入的一个新特性&#xff0c;可以在编译阶段约束操作的数据类型&#xff0c;并进行检查。 泛型的格式为 <数据类型> import java.util.ArrayList;SuppressWarnings({"all"}) public class Generic02 {public static void…

WGS1984快速度确定平面坐标系UTM分带(快速套表、公式计算、软件范围判定)

之前我们介绍了坐标系3带6带快速确定带号及中央经线&#xff08;快速套表、公式计算、软件范围判定&#xff09;就&#xff0c;讲的是CGCS2000 高斯克吕格的投影坐标系。 那还有我们经常用的WGS1984的平面坐标系一般用什么投影呢? 对于全球全国的比如在线地图使用&#xff1a…

探索GraphRAG:用yfiles-jupyter-graphs将知识库可视化!

yfiles-jupyter-graphs 可视化 GraphRAG 结构 前言 前面我们通过 GraphRag 命令生成了知识库文件 parquet&#xff0c;这节我们看一下如何使用 yfiles-jupyter-graphs 添加 parquet 文件的交互式图形可视化以及如何可视化 graphrag 查询的结果。 yfiles-jupyter-graphs 是一…

前端-js例子:收钱转账

支付宝转账 在这里用到周期定时器setInterval(function,time)&#xff0c;设置达到目标钱数时停止定时器。 点击转账按钮时&#xff0c;开始函数显示。 同时要确定输入框里输入的是数字。&#xff08;有一定容错&#xff09; window.onloadfunction(){var btn document.que…

vue3 + ts + pnpm:nprogress / 页面顶部进度条

一、简介 nprogress 是一个轻量级的进度条库&#xff0c;它适用于在网页上添加顶部进度条&#xff0c;用于指示页面加载进度或任何长时间的运行过程。这个库非常流行&#xff0c;因为它易于使用且视觉效果很好。 二、安装 pnpm add nprogress 三、在使用的页面引入 / src/v…

MySQL连接查询解析与性能优化成本

文章目录 一、连接查询1.连接查询基础1. INNER JOIN内连接2. LEFT JOIN (或 LEFT OUTER JOIN)左外连接3. RIGHT JOIN (或 RIGHT OUTER JOIN)右外连接4. FULL OUTER JOIN 2.连接查询的两种过滤条件3.连接的原理 二、性能优化成本1.基于成本的优化2.调节成本常数(1)mysql.server_…

ECharts基础使用方法 ---vue

1.安装依赖文件 仔细看项目" README.md " 描述&#xff0c;确定用什么安装 npm npm install echarts --save //官网推荐使用 pnpm pnpm install echarts --save 其他也是 在项目根目录&#xff0c;打开当前目录命令控制栏&#xff0c;输入以上命令并运行 安装成功后…

动动手指探索世界,旅游APP如何定制开发?

旅游APP的出现为旅行带来了许多便利。随着移动互联网的发展&#xff0c;旅游行业也在不断寻求创新与变革。旅游APP为游客提供了更加便捷的旅行体验&#xff0c;通过旅游APP&#xff0c;用户可以了解旅游信息、旅游服务、在线咨询等&#xff0c;实现在线一站式解决旅行需求的目标…

Github 2024-09-23 开源项目周报 Top15

根据Github Trendings的统计,本周(2024-09-23统计)共有15个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目6C++项目3C项目3HTML项目2PowerShell项目1TypeScript项目1JavaScript项目1Blade项目1PHP项目1Bootstrap 5: Web上开发响应式、移动优…

【文心智能体】 旅游手绘手帐 开发分享 零代码 手绘风景 记录行程和心情 旅游攻略

旅游手绘手帐&#xff0c;点击文心智能体平台AgentBuilder | 想象即现实 (baidu.com) 目录 背景 创作灵感 开发历程 一、基础配置 二、高级配置 三、引导示例&#xff08;提示词&#xff09; 期待优化 背景 这个智能体是一个零代码智能体&#xff08;文心智能体平台现…

MySQL篇(管理工具)

目录 一、系统数据库 二、常用工具 1. mysql 2. mysqladmin 3. mysqlbinlog 4. mysqlshow 5. mysqldump 6. mysqlimport/source 6.1 mysqlimport 6.2 source 一、系统数据库 MySQL数据库安装完成后&#xff0c;自带了一下四个数据库&#xff0c;具体作用如下&#xf…

JDBC和一下重要的jar包,分层结构

系列文章目录 JDBC和方便使用的jar包 目录 系列文章目录 文章目录 一、JDBC 1.步骤 2.SQL注入 3.SQL注入解决&#xff08;PreparedStatement&#xff09; 4.批处理和事务控制 5.连接池 Druid连接池&#xff08;德鲁伊&#xff09; 6.封装为工具类 7.ThreadLocal 、小秘书 二、…

大语言模型(LLM)入门学习路线图

Github项目上有一个大语言模型学习路线笔记&#xff0c;它全面涵盖了大语言模型的所需的基础知识学习&#xff0c;LLM前沿算法和架构&#xff0c;以及如何将大语言模型进行工程化实践。这份资料是初学者或有一定基础的开发/算法人员入门活深入大型语言模型学习的优秀参考。这份…