paper:Lift, Splat, Shoot: Encoding Images from Arbitrary Camera Rigs by Implicitly Unprojecting to 3D
code:https://github.com/nv-tlabs/lift-splat-shoot
一、完整复现代码(可一键运行)和效果图
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import cv2
import numpy as np
# BEV grid definition: every bound is [lower, upper, step] in metres.
# gen_dx_bx() turns xbound/ybound/zbound into dx (cell size) = [0.5, 0.5, 20.0],
# bx (centre of the first cell) = [-49.75, -49.75, 0.0] and nx (cell count) = [200, 200, 1].
xbound = [-50.0, 50.0, 0.5]  # front/back 100 m, 0.5 m per cell -> 200 cells along x
ybound = [-50.0, 50.0, 0.5]  # left/right 100 m, 0.5 m per cell -> 200 cells along y
zbound = [-10.0, 10.0, 20.0]  # up/down 20 m, one 20 m cell -> 1 cell along z
dbound = [4.0, 45.0, 1.0]  # candidate depths 4, 5, ..., 44 m (1 m step) -> 41 bins
# Number of depth bins. Computed with the same torch.arange call that builds the
# frustum's depth axis, so the two can never disagree (plain int() truncation
# undercounts whenever (upper - lower) is not an exact multiple of step).
D_ = torch.arange(*dbound, dtype=torch.float).numel()
def gen_dx_bx(xbound, ybound, zbound):
    """Derive voxel-grid parameters from per-axis ``[lower, upper, step]`` bounds.

    Args:
        xbound, ybound, zbound: sequences ``[lower, upper, step]`` (metres). Each
            range is expected to span a whole number of steps.

    Returns:
        dx: float tensor [3], cell size per axis (the ``step`` values).
        bx: float tensor [3], centre of the first cell per axis (``lower + step / 2``).
        nx: long tensor [3], number of cells per axis (``(upper - lower) / step``).
        All three are wrapped in ``nn.Parameter(requires_grad=False)``, i.e. they
        are fixed, non-learnable values.
    """
    bounds = [xbound, ybound, zbound]
    dx = torch.Tensor([row[2] for row in bounds])
    bx = torch.Tensor([row[0] + row[2] / 2.0 for row in bounds])
    # round() rather than truncation: the float ratio can come out as e.g.
    # 2.9999999999999996 for [0.0, 0.3, 0.1], which truncation would turn into 2.
    nx = torch.LongTensor([round((row[1] - row[0]) / row[2]) for row in bounds])
    dx = nn.Parameter(dx, requires_grad=False)
    bx = nn.Parameter(bx, requires_grad=False)
    nx = nn.Parameter(nx, requires_grad=False)
    return dx, bx, nx
batch_size = 1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Network input size and the backbone's downsampling factor.
in_H = 128
in_W = 352
scale_downsample = 16
# Size of the image feature map produced by the backbone: 8 x 22.
feat_W16 = in_W // scale_downsample
feat_H16 = in_H // scale_downsample
semantic_channels = 64
# Parameters for two cameras; every tensor below has leading dims (batch_size, num_cams).
num_cams = 2
# Rotations get_geometry_feat() uses to take camera-frame points into the ego frame, (1, 2, 3, 3).
# NOTE(review): camera 1's matrix is not orthonormal (e.g. its first row has norm ~0.93
# and several entries are exactly 0), so it is not a pure rotation -- confirm these demo values.
rots = torch.Tensor([[[[ 8.2076e-01, -3.4144e-04, 5.7128e-01],[-5.7127e-01, 3.2195e-03, 8.2075e-01],[-2.1195e-03, -9.9999e-01, 2.4474e-03]],
                      [[-9.3478e-01, 0, 0],[ 3.5507e-01, 0, -9.3477e-01],[-1.0805e-02, -9.9981e-01, 0]]]])
# Intrinsic matrices K, (1, 2, 3, 3).
# NOTE(review): camera 0 has principal point cx = 0 while camera 1 has cx = 807.25 -- confirm intent.
intrins = torch.Tensor([[[[1.2726e+03, 0.0, 0],[0.0000e+00, 1.2726e+03, 4.7975e+02],[0.0000e+00, 0.0000e+00, 1.0000e+00]],
                         [[1.2595e+03, 0.0000e+00, 8.0725e+02], [0.0000e+00, 1.2595e+03, 5.0120e+02],[0.0000e+00, 0.0000e+00, 1.0000e+00]]]])
# Image-augmentation rotation/scale (here a 0.22x resize of u and v), (1, 2, 3, 3).
post_rots = torch.Tensor([[[[0.2200, 0.0000, 0.0000],[0.0000, 0.2200, 0.0000],[0.0000, 0.0000, 1.0000]],
                           [[0.2200, 0.0000, 0.0000],[0.0000, 0.2200, 0.0000],[0.0000, 0.0000, 1.0000]]]])
# Image-augmentation translation, (batch_size, num_cams, 3): get_geometry_feat() views
# it as (B, N, 1, 1, 1, 3), so it must hold exactly one (du, dv, dd) triple per camera.
post_trans = torch.Tensor([[[0., 0., 0.], [0., 0., 0.]]])
# Camera positions in the ego frame (metres), (1, 2, 3).
trans = torch.Tensor([[[ 1.5239, 0.4946, 1.5093], [ 1.0149, -0.4806, 1.5624]]])
def create_uvd_frustum(depth_bound=None, in_hw=None, feat_hw=None):
    """Build the (u, v, d) frustum: every feature-map pixel paired with every candidate depth.

    Args:
        depth_bound: ``[start, stop, step]`` for the depth axis (``stop`` excluded,
            as in ``torch.arange``). Defaults to the module-level ``dbound``.
        in_hw: ``(H, W)`` of the network input image. Defaults to ``(in_H, in_W)``.
        feat_hw: ``(fH, fW)`` of the feature map. Defaults to ``(feat_H16, feat_W16)``.

    Returns:
        Non-learnable ``nn.Parameter`` of shape [D, fH, fW, 3] holding
        (u, v, d). u and v are image-pixel coordinates of the feature-map cells,
        and d is the depth in metres. With the defaults: [41, 8, 22, 3],
        u in [0, 351], v in [0, 127], d in 4..44.
    """
    if depth_bound is None:
        depth_bound = dbound
    if in_hw is None:
        in_hw = (in_H, in_W)
    if feat_hw is None:
        feat_hw = (feat_H16, feat_W16)
    img_h, img_w = in_hw
    fh, fw = feat_hw

    # Depth axis: D values, broadcast to [D, fH, fW].
    distance = torch.arange(*depth_bound, dtype=torch.float).view(-1, 1, 1).expand(-1, fh, fw)
    n_depth = distance.shape[0]
    # u (column) coordinate: fW samples evenly spaced over [0, img_w - 1] -> [D, fH, fW].
    x_stride = torch.linspace(0, img_w - 1, fw, dtype=torch.float).view(1, 1, fw).expand(n_depth, fh, fw)
    # v (row) coordinate: fH samples evenly spaced over [0, img_h - 1] -> [D, fH, fW].
    y_stride = torch.linspace(0, img_h - 1, fh, dtype=torch.float).view(1, fh, 1).expand(n_depth, fh, fw)
    frustum = torch.stack((x_stride, y_stride, distance), -1)  # [D, fH, fW, 3]
    # A fixed geometric grid: wrapped as a parameter without gradients (not learned).
    return nn.Parameter(frustum, requires_grad=False)
def plot_uvd_frustum(frustum, path='uvd_frustum.png'):
    """Scatter-plot the (u, v, d) frustum in 3-D, save it to ``path`` and show it.

    Args:
        frustum: tensor whose last dim is (u, v, d), e.g. the [D, fH, fW, 3]
            output of create_uvd_frustum().
        path: image file to write; defaults to the previously hard-coded name.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # detach()/cpu() so plotting also works for tensors that need grad or live on a GPU.
    frustum_np = frustum.detach().cpu().numpy()
    u = frustum_np[..., 0].flatten()
    v = frustum_np[..., 1].flatten()
    d = frustum_np[..., 2].flatten()
    ax.scatter(u, v, d, c=d, cmap='viridis', marker='o')
    ax.set_xlabel('u')
    ax.set_ylabel('v')
    ax.set_zlabel('d')
    # Save BEFORE show(): after a blocking show() returns, the figure is closed and
    # unregistered, so a later plt.savefig() would write a new, empty figure.
    fig.savefig(path)
    plt.show()
    plt.close(fig)
def get_geometry_feat(frustum, rots, trans, intrins, post_rots, post_trans):
    """Map every (u, v, d) frustum point of every camera into the ego frame.

    Args:
        frustum: [D, fH, fW, 3] grid of (u, v, d) points (see create_uvd_frustum).
        rots: B*N 3x3 rotations, viewed as (B, N, 1, 1, 1, 3, 3).
        trans: [B, N, 3] translations added after rotation; also supplies B and N.
        intrins: B*N 3x3 intrinsic matrices (inverted here).
        post_rots: B*N 3x3 image-augmentation matrices (inverted here).
        post_trans: B*N*3 image-augmentation offsets, viewed as (B, N, 1, 1, 1, 3).

    Returns:
        [B, N, D, fH, fW, 3] tensor of (X, Y, Z) coordinates.
    """
    batch, n_cams, _ = trans.shape
    bcast = (batch, n_cams, 1, 1, 1)

    # Undo the image augmentation: remove the offset, then apply the inverse matrix.
    pts = frustum - post_trans.view(*bcast, 3)
    pts = torch.inverse(post_rots).view(*bcast, 3, 3).matmul(pts.unsqueeze(-1))

    # Undo the perspective division: (u, v, d) -> (u * d, v * d, d).
    depth = pts[..., 2:3, :]
    pts = torch.cat((pts[..., :2, :] * depth, depth), 5)

    # Camera frame -> ego frame: multiply by rots @ K^-1, then add the camera translation.
    cam_to_ego = rots.matmul(torch.inverse(intrins)).view(*bcast, 3, 3)
    pts = cam_to_ego.matmul(pts).squeeze(-1)
    pts += trans.view(*bcast, 3)
    return pts
def plot_XYZ_frustum(frustum, path):
    """Scatter-plot each camera's 3-D point set in one figure, save it to ``path`` and show it.

    Args:
        frustum: tensor indexable along dim 0 by camera; each item's last dim is
            (X, Y, Z). Example: XYZ_frustum[0] of shape [N, D, fH, fW, 3] (ego
            coordinates) or the matching voxel indices.
        path: image file to write.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for cam_points in frustum:
        # detach()/cpu() so plotting also works for tensors that need grad or live on a GPU.
        pts = cam_points.detach().cpu().numpy()
        x = pts[..., 0].flatten()
        y = pts[..., 1].flatten()
        z = pts[..., 2].flatten()
        ax.scatter(x, y, z, c=z, cmap='viridis', marker='o')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    # Save BEFORE show(): after a blocking show() returns, the figure is closed and
    # unregistered, so a later plt.savefig() would write a new, empty figure.
    fig.savefig(path)
    plt.show()
    plt.close(fig)
def cumsum_trick(cam_feat, geom_feat, ranks):
    """Sum the features of consecutive points that share a rank (voxel id).

    Args:
        cam_feat: [M, C] point features, sorted so equal ranks are contiguous.
        geom_feat: [M, k] per-point voxel indices, in the same order.
        ranks: [M] voxel ids, sorted.

    Returns:
        (voxel_sums, voxel_geom): one row per distinct run of ranks. voxel_sums
        holds the summed features of the run, and voxel_geom the indices of its
        last point.
    """
    # Prefix sum over the point axis (dim 0).
    running = cam_feat.cumsum(0)
    # A point closes its run when the next point's rank differs; the final point always does.
    is_last = torch.ones(running.shape[0], device=running.device, dtype=torch.bool)
    is_last[:-1] = ranks[1:] != ranks[:-1]
    run_totals = running[is_last]
    voxel_geom = geom_feat[is_last]
    # Differencing consecutive run-end prefix sums yields each run's own total.
    head = run_totals[:1]
    tail = run_totals[1:] - run_totals[:-1]
    voxel_sums = torch.cat((head, tail))
    return voxel_sums, voxel_geom
def plot_bev(bev, name='bev'):
    """Show a BEV map in an OpenCV window until a key is pressed.

    Args:
        bev: tensor of shape [1, C, H, W] (e.g. the [1, 1, 200, 200] bev_encoder
            output). Values in [0, 1] map to gray levels 0..255; values outside
            that range are clipped.
        name: OpenCV window title.
    """
    # [1, C, H, W] -> [C, H, W] numpy array, detached from the autograd graph.
    img = bev.squeeze(0).cpu().detach().numpy()
    # Clip before casting: raw network outputs can be negative or > 1, and casting
    # such floats straight to uint8 wraps around or is undefined.
    img = np.clip(img * 255, 0, 255).astype(np.uint8)
    # OpenCV expects channels last: [C, H, W] -> [H, W, C].
    img = img.transpose(1, 2, 0)
    cv2.imshow(name, img)
    cv2.waitKey(0)
    cv2.destroyWindow(name)
if __name__ == "__main__":
    # 1. Build the (u, v, d) frustum: feature-map pixel grid x candidate depths.
    uvd_frustum = create_uvd_frustum()
    plot_uvd_frustum(uvd_frustum)

    # 2. Lift the frustum into the ego frame with the camera intrinsics/extrinsics:
    #    XYZ_frustum has shape [B, N, D, fH, fW, 3].
    XYZ_frustum = get_geometry_feat(uvd_frustum, rots, trans, intrins, post_rots, post_trans)
    plot_XYZ_frustum(XYZ_frustum[0], path='EGO_XYZ_frustum.png')

    # 3. Voxelize: ego coordinates -> integer voxel indices.
    dx, bx, nx = gen_dx_bx(xbound, ybound, zbound)
    # floor() before long(): long() alone truncates toward zero, so points slightly
    # below the lower bound (normalized coordinate in (-1, 0)) would get index 0
    # instead of -1 and wrongly survive the range filter in step 4.3.
    geom_feats = ((XYZ_frustum - (bx - dx / 2.)) / dx).floor().long()
    plot_XYZ_frustum(geom_feats[0], path='voxel.png')

    # 4. bev_pool
    # 4.1 Flatten: random stand-in camera features [B, N, D, fH, fW, C] -> [L, C];
    #     voxel indices -> [L, 3].
    cam_feats = torch.rand(batch_size, num_cams, D_, feat_H16, feat_W16, semantic_channels)
    B, N, D, H, W, C = cam_feats.shape
    L__ = B * N * D * H * W
    cam_feats = cam_feats.reshape(L__, C)
    geom_feats = geom_feats.view(L__, 3)
    # 4.2 Append each point's batch index as a 4th column: geom_feats [L, 4] = (x, y, z, b).
    batch_index = torch.cat([torch.full([L__ // B, 1], ix, device=cam_feats.device, dtype=torch.long) for ix in range(B)])
    geom_feats = torch.cat((geom_feats, batch_index), 1)
    # 4.3 Keep only points inside the grid: 0 <= x < 200, 0 <= y < 200, 0 <= z < 1.
    kept = ((geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < nx[0])
            & (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < nx[1])
            & (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < nx[2]))
    cam_feats = cam_feats[kept]
    geom_feats = geom_feats[kept]
    # 4.4 Encode (x, y, z, batch) as one integer rank, then sort so points of the
    #     same voxel become adjacent.
    ranks = (geom_feats[:, 0] * (nx[1] * nx[2] * B)  # X
             + geom_feats[:, 1] * (nx[2] * B)  # Y
             + geom_feats[:, 2] * B  # Z
             + geom_feats[:, 3])  # batch index
    sorts = ranks.argsort()
    cam_feats, geom_feats, ranks = cam_feats[sorts], geom_feats[sorts], ranks[sorts]
    # 4.5 Sum the features of all points that share a voxel.
    cam_feats, geom_feats = cumsum_trick(cam_feats, geom_feats, ranks)
    # 4.6 Scatter per-voxel sums into a dense grid [B, C, Z, X, Y] = [1, 64, 1, 200, 200].
    final = torch.zeros((B, C, nx[2], nx[0], nx[1]), device=cam_feats.device)
    final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = cam_feats
    # 4.7 Drop the Z axis (dim 2) by concatenating its slices along channels: [B, C * Z, X, Y].
    final = torch.cat(final.unbind(dim=2), 1)

    # 5. bev_encoder: a 1x1 conv reduces the 64 channels to one map for display
    #    (in_channels matches only because Z == 1 here).
    bev_encoder = nn.Conv2d(semantic_channels, 1, kernel_size=1, stride=1, padding=0, bias=False)
    bev = bev_encoder(final)
    plot_bev(bev, name='bev')
二、逐步代码讲解+图解
完整流程:
1.创建uv coord + depth estimation (2d image + depth)
2.视锥化(uv coord -> world coord)(根据相机内外参构建 3x4 的变换 [rots·intrins⁻¹ | trans],把视锥点映射到自车(EGO)坐标系)
3.体素化(world coord -> voxel coord)(按预设的世界坐标范围划分网格,并确定各维度的刻度/分辨率)
4.bev_pool(voxel coord -> bev coord)(去掉Z轴)
1.创建uv coord + depth estimation (2d image + depth)
# Step 1: build the (u, v, d) frustum ([41, 8, 22, 3] with the default settings) and plot it.
uvd_frustum = create_uvd_frustum()
plot_uvd_frustum(uvd_frustum)
注意
1.坐标范围:u、v 的取值由模型输入尺寸(W=352, H=128)决定,分别落在 [0,351] 和 [0,127] 内;d 取 4,5,…,44(torch.arange 不包含终点 45)。
2.u轴有22个柱子(pillar),22=352//16;v轴有8个柱子(pillar),8=128//16;d轴有41个刻度,41=(45-4)//1
2.视锥化(uv coord -> world coord)(根据相机内外参构建 3x4 的变换 [rots·intrins⁻¹ | trans],把视锥点映射到自车(EGO)坐标系)
# Step 2: lift the frustum into ego coordinates, shape [B, N, D, fH, fW, 3],
# then plot all cameras of batch element 0.
XYZ_frustum = get_geometry_feat(uvd_frustum,rots, trans, intrins, post_rots, post_trans)
plot_XYZ_frustum(XYZ_frustum[0],path = f'EGO_XYZ_frustum.png')
我这里为了看起来更直观点,选了两个相机,实际在使用过程中,可以灵活使用1个,2个,4个,6个相机。
3.体素化(world coord -> voxel coord)(按预设的世界坐标范围划分网格,并确定各维度的刻度/分辨率)
# Step 3: voxelize -- ego coordinates -> integer voxel indices.
dx, bx, nx = gen_dx_bx(xbound, ybound, zbound)
# floor() before long(): long() alone truncates toward zero, so a point slightly below
# the lower bound (normalized coordinate in (-1, 0)) would get index 0 instead of -1
# and wrongly pass the 0 <= index range filter in step 4.3.
geom_feats = ((XYZ_frustum - (bx - dx / 2.)) / dx).floor().long()
plot_XYZ_frustum(geom_feats[0], path = f'voxel.png')
为什么上面和下面的形状不一样呢?原因有二:1.相机内外参数的影响;2.(旋转、平移)数据增强的影响。
注意观察:此时 XYZ 的数值已经换算成体素索引,与 (200,200,1) 的 bev 网格对应;落在该范围之外的点会在 4.3 步中被过滤掉。
4.bev_pool(voxel coord -> bev coord)(去掉Z轴)
- 4.1. cam_feats,geom_feats 展平
# 4.1: random stand-in camera features [B, N, D, fH, fW, C]; flatten the features
# to [L, C] and the voxel indices to [L, 3], where L = B * N * D * fH * fW.
cam_feats = torch.rand(batch_size, num_cams, D_, feat_H16, feat_W16, semantic_channels)
B, N, D, H, W, C = cam_feats.shape
L__ = B * N * D * H * W
cam_feats = cam_feats.reshape(L__, C)
geom_feats = geom_feats.view(L__, 3)
- 4.2.geom_feat增加batch维度
# 4.2: append each point's batch index as a 4th column -> geom_feats [L, 4] = (x, y, z, b).
batch_index = torch.cat([torch.full([L__ // B, 1], ix, device=cam_feats.device, dtype=torch.long) for ix in range(B)])
geom_feats = torch.cat((geom_feats, batch_index), 1)
- 4.3.filter by (0<=X<200, 0<=Y<200, 0<=Z<1)
# 4.3: keep only points whose voxel index lies inside the grid (0 <= index < nx on every axis).
kept = (geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < nx[0]) & (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < nx[1]) & (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < nx[2])
cam_feats = cam_feats[kept]
geom_feats = geom_feats[kept]
- 4.4.voxel index 位置编码,排序
# 4.4: encode (x, y, z, batch) as one integer rank (x most significant, batch least),
# then sort so that points falling into the same voxel become adjacent.
ranks = (geom_feats[:, 0] * (nx[1] * nx[2] * B) # X
+ geom_feats[:, 1] * (nx[2] * B) # Y
+ geom_feats[:, 2] * B # Z
+ geom_feats[:, 3]) # batch_index
sorts = ranks.argsort()
cam_feats, geom_feats, ranks = cam_feats[sorts], geom_feats[sorts], ranks[sorts]
可以参考我画的示意图
- 4.5. sum
# 4.5: sum the features of all points that share a rank (voxel) using the cumsum trick.
cam_feats, geom_feats = cumsum_trick(cam_feats, geom_feats, ranks)
- 4.6.根据视锥获取相应的cam_feat, final:[1,64,1,200,200]
# 4.6: scatter the per-voxel sums into a dense grid [B, C, Z, X, Y] ([1, 64, 1, 200, 200] here).
final = torch.zeros((B, C, nx[2], nx[0], nx[1]), device=cam_feats.device)
final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 0], geom_feats[:, 1]] = cam_feats
- 4.7.去掉Z维度, dim_Z维度属于dim=2, 生成bev图
# 4.7: unbind the Z axis (dim 2) and concatenate its slices along channels -> [B, C * Z, X, Y].
final = torch.cat(final.unbind(dim=2), 1)
5.bev_encoder
# Step 5: a 1x1 conv reduces the 64 BEV channels to a single map, which plot_bev displays.
bev_encoder = nn.Conv2d(semantic_channels, 1, kernel_size=1, stride=1, padding=0,bias=False)
bev = bev_encoder(final)
plot_bev(bev, name = f'bev')
bev尺寸为200x200