【三维深度补全模型】PENet

news2025/1/4 10:26:07

 【版权声明】
本文为博主原创文章,未经博主允许严禁转载,我们会定期进行侵权检索。   

参考书籍:《人工智能点云处理及深度学习算法》

 本文为专栏《Python三维点云实战宝典》系列文章,专栏介绍地址“【python三维深度学习】python三维点云从基础到深度学习_python3d点云从基础到深度学习-CSDN博客”。配套书籍《人工智能点云处理及深度学习算法》提供更加全面和系统的解析。

        PENet是一个由浙江大学和上海华为于ICRA 2021发布的深度补全模型(Sparse-Depth-Completion),即通过RGB图像和雷达稀疏点云来获取更加稠密的点云。论文题目和地址分别为《PENet: Towards Precise and Efficient Image Guided Depth Completion》和“https://arxiv.org/abs/2103.00783”。该模型采用了coarse-refine结构,即粗补全和精补全(精度微调)相结合,并且模型在粗补全阶段对不同尺度图像、稀疏点云和几何特征进行充分融合以提高模型深度补全精度。另一方面,模型对CSPN++网络卷积操作进行优化以提高模型运行速度。PENet提出时在KITTI深度补全数据集上取得了最好成绩,目前排名仍然靠前。下图是其在paperwithcode官网上的排名情况,地址为“https://paperswithcode.com/sota/depth-completion-on-kitti-depth-completion”。

图 PENet排名情况

1 PENet模型结构

        PENet模型总体结构如下图所示,采用了coarse-refine结构。其深度粗补全网络称为ENet,采用两条主干网络进行深度补全特征提取。两条主干网络均融合了雷达所采集的稀疏点云,区别在于第一条主干网络融合了RGB色彩信息,而第二条网络融合了第一条网络预测的深度结果。主干网络采用了类似UNet结构的编码-解码结构,实现对不同尺度特征进行融合。因此,ENet对特征类型和特征空间都进行了充分融合,以获取更加丰富的深度特征。

图 PENet模型结构

        PENet深度粗补全结果是两条主干分支网络预测结果的融合,即图中Fused Depth。由于点深度信息与其邻近点密切关联,作者采用DA CSPN++网络对粗补全结果进行微调,进一步提高模型预测精度。

2 输入数据

2.1 KITTI数据集下载

        PENet模型官方程序地址为“https://github.com/JUGGHM/PENet_ICRA2021”,本节将结合该程序进行详细介绍。程序中模型输入数据集为KITTI补全数据集,需要分别下载KITTI原始数据和补全数据。

        KITTI原始数据集如下图所示,包含City、Residential、Road、Campus、Person 和Calibration6个类别,下载地址为“https://www.cvlibs.net/datasets/kitti/raw_data.php?type=city”。如需进行完整训练和测试验证,程序需要下载这6个类别下全部数据,共包括138个可用数据。如果仅进行程序学习或验证测试,那么我们下载部分数据即可,例如City类别下的2011_09_26_drive_0001、2011_09_26_drive_0002、2011_09_26_drive_0005和2011_09_26_drive_0009。下载数据解压得到以日期命名的文件夹,如2011_09_26。

图 KITTI原始数据集下载

        KITTI深度补全数据集主要包含稠密点云深度,以提供稀疏点云补全的真实标签,下载地址为“https://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion”。数据集下载内容包括下图所示部分,即annotated depth maps data set (14 GB)、raw LiDaR scans data set (5 GB)、manually selected validation and test data sets (2 GB)、development kit (48 K)。解压后文件下包含了训练和验证样本数据。

图 KITII深度补全数据集下载

2.2 数据集预处理

        PENet输入数据集目录如下所示,需将上述所下载文件整理成该目录结构形式。我们如果仅下载部分原始数据(2011_09_26_drive_0001、2011_09_26_drive_0002、2011_09_26_drive_0005和2011_09_26_drive_0009),那么需要对深度补全数据集进行相应设置。data_depth_annotated和data_depth_velodyne文件夹下训练train文件夹仅保留2011_09_26_drive_0001_sync和2011_09_26_drive_0009_sync,删除其它文件夹。data_depth_annotated和data_depth_velodyne文件夹下验证val文件夹仅保留2011_09_26_drive_0002_sync和2011_09_26_drive_0005_sync,删除其它文件夹。

├── kitti_depth

|   ├── depth

|   |   ├──data_depth_annotated

|   |   |  ├── train

|   |   |  ├── val

|   |   ├── data_depth_velodyne

|   |   |  ├── train

|   |   |  ├── val

|   |   ├── data_depth_selection

|   |   |  ├── test_depth_completion_anonymous

|   |   |  |── test_depth_prediction_anonymous

|   |   |  ├── val_selection_cropped

├── kitti_raw

|   ├── 2011_09_26

|   ├── 2011_09_28

|   ├── 2011_09_29

|   ├── 2011_09_30

|   ├── 2011_10_03

        完整训练集包括138个文件夹,而这里仅使用如上两个文件夹数据进行模型解析。完整验证集包括1000个样本。

        模型输入数据由rgb、d、gt、g、position和K等6部分组成。

        (1)rgb

        rgb数据来自于KITTI的2号和3号彩色相机,即彩色图像数据。训练集和验证集图片路径分别为“kitti_raw/*/*_sync/image_0[2,3]/*.png”和“data_depth_selection/val_selection_cropped/image/*.png”。原始图片维度为3x375x1242,经过固定裁剪和随机裁剪后维度为3x320x1216。图片像素深度为8bit,像素取值范围0~255。

        (2)d

        d为激光雷达所采集的稀疏点云深度数据,以16位png图片存储,取值范围0~65535。取值除以256可得到深度值,且取值为零的点表示无效点,即未采集到深度数据。训练集和验证集路径分别为“data_depth_velodyne/train/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png”和“data_depth_selection/val_selection_cropped/velodyne_raw/*.png”。深度图片维度为3x375x1242,经过固定裁剪和随机裁剪后维度为3x320x1216。

        (3)gt

        gt为稠密点云深度的真实标签数据,以16位png图片存储,取值范围0~65535。取值除以256可得到深度值,且取值为零的点表示无效点,即未采集到深度数据。训练集和验证集路径分别为“data_depth_annotated/train/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png”和“data_depth_selection/val_selection_cropped/groundtruth_depth/*.png”。深度图片维度为3x375x1242,经过固定裁剪和随机裁剪后维度为3x320x1216。

        (4)g

        g为rgb彩色图像数据转换后的灰度图像数据,维度为1x320x1216。

        (5)position

        positon是图片像素坐标经过归一化后取值,归一化范围为-1~1。由于像素横纵坐标分别进行归一化处理,因而position维度为2x352x1216,并经过随机裁剪后维度为2x320x1216。

xx_channel = xx_channel.astype('float32') / (self.y_dim - 1)#除以最大值,0~1
yy_channel = yy_channel.astype('float32') / (self.x_dim - 1)#除以最大值,0~1
xx_channel = xx_channel*2 - 1#变换到-1~1
yy_channel = yy_channel*2 - 1#变换到-1~1
ret = np.concatenate([xx_channel, yy_channel], axis=-1)#拼接

        (6)K

        K为3x3维度相机内参矩阵,包含了x、y方向上焦距和光心偏移信息,用于像素坐标和相机坐标系间坐标变换。除直接从calib_cam_to_cam.txt标定文件中读取原始内参矩阵之外,此时K矩阵还需要根据图像裁剪情况对光心偏移进行调整。

def load_calib():
    """
    Temporarily hardcoding the calibration matrix using calib file from 2011_09_26
    """
    calib = open("dataloaders/calib_cam_to_cam.txt", "r")
    lines = calib.readlines()
    P_rect_line = lines[25]
    Proj_str = P_rect_line.split(":")[1].split(" ")[1:]
    Proj = np.reshape(np.array([float(p) for p in Proj_str]), (3, 4)).astype(np.float32)
    K = Proj[:3, :3]  # camera matrix
    # note: we will take the center crop of the images during augmentation
    # that changes the optical centers, but not focal lengths
    # K[0, 2] = K[0, 2] - 13  # from width = 1242 to 1216, with a 13-pixel cut on both sides
    # K[1, 2] = K[1, 2] - 11.5  # from width = 375 to 352, with a 11.5-pixel cut on both sides
    K[0, 2] = K[0, 2] - 13;
    K[1, 2] = K[1, 2] - 11.5;
    return K

3 ENet主干网络

        ENet主干网络包含两条分支,其中一条支路是图像rgb和稀疏深度d融合对稠密深度的预测,另一条支路是预测结果进一步与稀疏深度d融合并对稠密深度进行再次预测。两条支路预测结果融合得到ENet对稠密深度最终预测结果。

3.1 ENet主干支路一

        程序首先通过平均值池化对position(2x320x1216)进行下采样,采样倍数分别为2、4、8、16、32,从而得到6种不同尺度分辨率的像素坐标(vnorm_sx和unorm_sx)。同样地,激光雷达稀疏深度d也采用最大值池化得到相应分辨率下的深度图d_sx。像素坐标与相机坐标系的对应关系可通过如下公式进行计算,那么根据像素坐标和深度坐标可计算得到目标在相机坐标系下的空间坐标(x,y,z)。

        程序相应函数为GeometryFeature,具体计算过程如下所示。由于position坐标已归一化到-1~1,因此需要结合图片尺寸恢复出像素坐标绝对值,然后使用内参和距离参数得到相机坐标,并称该坐标为几何特征。6种分辨率下的像素坐标和稀疏深度分别进行计算,从而得到6种不同分辨率的几何特征geo_sx。

x = z*(0.5*h*(vnorm+1)-ch)/fh
y = z*(0.5*w*(unorm+1)-cw)/fw
return torch.cat((x, y, z),1)

        第一条主干支路输入的图像rgb和深度d拼接并经过卷积Conv2d(4, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)运算后得到32x320x1216维特征rgb_feature。程序将rgb特征和几何特征按照如下图所示过程逐步进行特征融合与特征提取(特征编码),进而得到不同尺度融合特征rgb_featurex。

Concate(rgb_feature_x, geo_sx)->Conv(s=2) 1
Concate(1, geo_sx) Conv1 2
Downsample(rgb_feaure) 3
Add(rgb_feature_sx, 3) rgb_feature_x+1
Concate(rgb_feature_x+1, geo_sx+1)->Conv(s=1) 4
Concate(4, geo_sx+1) Conv1  5
Add(rgb_feature_sx+1, 5) rgb_feature_x+2

图 图像特征与几何特征融合

        rgb特征与几何特征融合过程如下:

  1. rgb输入特征rgb_feature_x维度为C1xH1xW1,两种尺度几何特征geo_sx和geo_sx+1维度分别为3xH1xW1和3xH2xW2,且H2=H1/2、W2=W1/2。
  2. rgb_feature_x与geo_sx进行拼接后经过卷积Conv(C1+3, 2*C1, 3, 2)得到2*C1xH2xW2维度特征。
  3. (2)中特征进一步与geo_sx+1拼接并经过卷积Conv(2*C1+3, 2*C1, 3, 1)得到2*C1xH2xW2维度特征。
  4. rgb_feature_x与geo_sx进行拼接后经过卷积Conv(C1+3, 2*C1, 3, 2)下采样得到2*C1xH2xW2维度特征。
  5. (3)和(4)中特征进行求和得到融合后rgb特征rgb_feature_x+1,维度为2*C1xH2xW2。
  6. rgb_feature_x+1与geo_sx+1进行拼接后经过卷积Conv(2*C1+3, 2*C1, 3, 1)得到2*C1xH2xW2维度特征。
  7. (6)中特征进一步与geo_sx+1拼接并经过卷积Conv(2*C1+3, 2*C1, 3, 1)得到2*C1xH2xW2维度特征。
  8. (7)中特征和rgb_feature_x+1进行求和得到新的融合后rgb特征rgb_feature_x+2,维度为2*C1xH2xW2。

        从上述步骤可以看到,rgb特征与几何特征进行多次融合,以获取更加充分的几何特征信息。融合后rgb特征rgb_feature10、rgb_feature8、rgb_feature6、rgb_feature4、rgb_feature2、rgb_feature的维度分别为1024x10x38、512x20x76、256x40x152、128x80x304、64x160x608、32x320x1216。

        rgb特征解码阶段从最小尺度rgb特征逐步通过逆卷积上采样与特征融合得到解码后的不同尺度rgb特征,分别为rgb_feature8_plus(512x20x76)、rgb_feature6_plus(256x40x152)、rgb_feature4_plus(128x80x304)、rgb_feature2_plus(64x160x608)、rgb_feature0_plus(32x320x1216)。rgb_feature0_plus经过逆卷积ConvTranspose2d(32, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)得到2x320x1216维度预测结果,这两个维度分别为第一条主干支路深度预测结果(rgb_depth,1x320x1216)及其置信度(rgb_conf,1x320x1216)。

        以上为第一条主干分支的模型处理过程,关键程序解析如下所示。

vnorm_s2 = self.pooling(vnorm)#每两个点平均池化,160x608
vnorm_s3 = self.pooling(vnorm_s2)#每两个点平均池化,80x304
vnorm_s4 = self.pooling(vnorm_s3)#每两个点平均池化,40x152
vnorm_s5 = self.pooling(vnorm_s4)#每两个点平均池化,20x76
vnorm_s6 = self.pooling(vnorm_s5)#每两个点平均池化,10x38
unorm_s2 = self.pooling(unorm)#每两个点平均池化,160x608
unorm_s3 = self.pooling(unorm_s2)#每两个点平均池化,80x304
unorm_s4 = self.pooling(unorm_s3)#每两个点平均池化,40x152
unorm_s5 = self.pooling(unorm_s4)#每两个点平均池化,20x76
unorm_s6 = self.pooling(unorm_s5)#每两个点平均池化,10x38
#不同尺度深度图
valid_mask = torch.where(d>0, torch.full_like(d, 1.0), torch.full_like(d, 0.0))#深度大于0的点为有效点
d_s2, vm_s2 = self.sparsepooling(d, valid_mask)#深度最大值池化,160x608
d_s3, vm_s3 = self.sparsepooling(d_s2, vm_s2)#深度最大值池化,80x304
d_s4, vm_s4 = self.sparsepooling(d_s3, vm_s3)#深度最大值池化,40x152
d_s5, vm_s5 = self.sparsepooling(d_s4, vm_s4)#深度最大值池化,20x76
d_s6, vm_s6 = self.sparsepooling(d_s5, vm_s5)#深度最大值池化,10x38
geo_s1 = self.geofeature(d, vnorm, unorm, 352, 1216, c352, c1216, f352, f1216)#像素坐标到相机坐标,x,y,z,3x320x1216
geo_s2 = self.geofeature(d_s2, vnorm_s2, unorm_s2, 352 / 2, 1216 / 2, c352, c1216, f352, f1216)#像素坐标到相机坐标,x,y,z,3x160x608
geo_s3 = self.geofeature(d_s3, vnorm_s3, unorm_s3, 352 / 4, 1216 / 4, c352, c1216, f352, f1216)#像素坐标到相机坐标,x,y,z,3x80x304
geo_s4 = self.geofeature(d_s4, vnorm_s4, unorm_s4, 352 / 8, 1216 / 8, c352, c1216, f352, f1216)#像素坐标到相机坐标,x,y,z,3x40x152
geo_s5 = self.geofeature(d_s5, vnorm_s5, unorm_s5, 352 / 16, 1216 / 16, c352, c1216, f352, f1216)#像素坐标到相机坐标,x,y,z,3x20x76
geo_s6 = self.geofeature(d_s6, vnorm_s6, unorm_s6, 352 / 32, 1216 / 32, c352, c1216, f352, f1216)#像素坐标到相机坐标,x,y,z,3x10x38
rgb_feature = self.rgb_conv_init(torch.cat((rgb, d), dim=1))#rgbd特征提取,4x320x1216 -> 32x320x1216
rgb_feature1 = self.rgb_encoder_layer1(rgb_feature, geo_s1, geo_s2) #64x160x608,不同尺度rgb与坐标特征融合
rgb_feature2 = self.rgb_encoder_layer2(rgb_feature1, geo_s2, geo_s2) #64x160x608,不同尺度rgb与坐标特征融合
rgb_feature3 = self.rgb_encoder_layer3(rgb_feature2, geo_s2, geo_s3) #128x80x304,不同尺度rgb与坐标特征融合
rgb_feature4 = self.rgb_encoder_layer4(rgb_feature3, geo_s3, geo_s3) #128x80x304,不同尺度rgb与坐标特征融合
rgb_feature5 = self.rgb_encoder_layer5(rgb_feature4, geo_s3, geo_s4) #256x40x152,不同尺度rgb与坐标特征融合
rgb_feature6 = self.rgb_encoder_layer6(rgb_feature5, geo_s4, geo_s4) #256x40x152,不同尺度rgb与坐标特征融合
rgb_feature7 = self.rgb_encoder_layer7(rgb_feature6, geo_s4, geo_s5) #512x20x76,不同尺度rgb与坐标特征融合
rgb_feature8 = self.rgb_encoder_layer8(rgb_feature7, geo_s5, geo_s5) #512x20x76,不同尺度rgb与坐标特征融合
rgb_feature9 = self.rgb_encoder_layer9(rgb_feature8, geo_s5, geo_s6) #1024x10x38,不同尺度rgb与坐标特征融合
rgb_feature10 = self.rgb_encoder_layer10(rgb_feature9, geo_s6, geo_s6) #1024x10x38,不同尺度rgb与坐标特征融合
rgb_feature_decoder8 = self.rgb_decoder_layer8(rgb_feature10)#逆卷积上采样,512x20x76
rgb_feature8_plus = rgb_feature_decoder8 + rgb_feature8#特征融合,512x20x76
rgb_feature_decoder6 = self.rgb_decoder_layer6(rgb_feature8_plus)#逆卷积上采样,256x40x152
rgb_feature6_plus = rgb_feature_decoder6 + rgb_feature6#特征融合,256x40x152
rgb_feature_decoder4 = self.rgb_decoder_layer4(rgb_feature6_plus)#逆卷积上采样,128x80x304
rgb_feature4_plus = rgb_feature_decoder4 + rgb_feature4#特征融合,128x80x304
rgb_feature_decoder2 = self.rgb_decoder_layer2(rgb_feature4_plus)#逆卷积上采样,64x160x608
rgb_feature2_plus = rgb_feature_decoder2 + rgb_feature2#特征融合,64x160x608
rgb_feature_decoder0 = self.rgb_decoder_layer0(rgb_feature2_plus)#逆卷积上采样,32x320x1216
rgb_feature0_plus = rgb_feature_decoder0 + rgb_feature#特征融合,32x320x1216
rgb_output = self.rgb_decoder_output(rgb_feature0_plus)#深度和置信度预测,2x320x1216
rgb_depth = rgb_output[:, 0:1, :, :]#1x320x1216
rgb_conf = rgb_output[:, 1:2, :, :]#1x320x1216

3.2 ENet主干支路二

        ENet第二条主干支路输入为稀疏深度d和支路一预测深度rgb_depth,二者拼接并经过卷积Conv2d(2, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)得到32x320x1216维度融合特征。模型将该特征定义为稀疏特征,即sparsed_feature。该支路仍然采用特征编码-解码的结构进行特征提取。

        与支路一操作类似,不同尺度下稀疏特征也与几何特征进行两次融合。除此之外,稀疏特征sparsed_featurex还与相同尺度的rgb特征rgb_featurex_plus进行拼接融合。稀疏特征、几何特征和rgb特征相互融合,完成特征编码,主要输出为sparsed_feature10(1024x10x38)、sparsed_feature8(512x20x76)、sparsed_feature6(256x40x152)、sparsed_feature4(128x80x304)、sparsed_feature2(64x160x608)。

        稀疏特征解码阶段从最小尺度稀疏特征逐步通过逆卷积上采样与特征融合得到解码后的不同尺度稀疏特征,分别为decoder_feature1(512x20x76)、decoder_feature2(256x40x152)、decoder_feature3(128x80x304)、decoder_feature4(64x160x608)、decoder_feature5(32x320x1216)。rgb_feature0_plus经过逆卷积ConvTranspose2d(32, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)得到2x320x1216维度预测结果,这两个维度分别为第二条主干支路深度预测结果(d_depth,1x320x1216)及其置信度(d_conf,1x320x1216)。

        以上为第二条主干分支的模型处理过程,关键程序解析如下所示。

sparsed_feature = self.depth_conv_init(torch.cat((d, rgb_depth), dim=1))#雷达特征与RGB预测深度融合后提取特征,32x320x1216
sparsed_feature1 = self.depth_layer1(sparsed_feature, geo_s1, geo_s2)#类似深度信息与几何坐标信息融合,64x160x608
sparsed_feature2 = self.depth_layer2(sparsed_feature1, geo_s2, geo_s2) #64x160x608
sparsed_feature2_plus = torch.cat([rgb_feature2_plus, sparsed_feature2], 1)#128x160x608
sparsed_feature3 = self.depth_layer3(sparsed_feature2_plus, geo_s2, geo_s3) #128x80x304
sparsed_feature4 = self.depth_layer4(sparsed_feature3, geo_s3, geo_s3)#128x80x304
sparsed_feature4_plus = torch.cat([rgb_feature4_plus, sparsed_feature4], 1)#256x80x304
sparsed_feature5 = self.depth_layer5(sparsed_feature4_plus, geo_s3, geo_s4) #256x40x152
sparsed_feature6 = self.depth_layer6(sparsed_feature5, geo_s4, geo_s4) #256x40x152
sparsed_feature6_plus = torch.cat([rgb_feature6_plus, sparsed_feature6], 1)#512x40x152
sparsed_feature7 = self.depth_layer7(sparsed_feature6_plus, geo_s4, geo_s5) #512x20x76
sparsed_feature8 = self.depth_layer8(sparsed_feature7, geo_s5, geo_s5) #512x20x76
sparsed_feature8_plus = torch.cat([rgb_feature8_plus, sparsed_feature8], 1)#1024x20x76
sparsed_feature9 = self.depth_layer9(sparsed_feature8_plus, geo_s5, geo_s6) #1024x10x38
sparsed_feature10 = self.depth_layer10(sparsed_feature9, geo_s6, geo_s6) #1024x10x38
fusion1 = rgb_feature10 + sparsed_feature10#1024x10x38
decoder_feature1 = self.decoder_layer1(fusion1)#逆卷积上采样,512x20x76
fusion2 = sparsed_feature8 + decoder_feature1#特征融合,512x20x76
decoder_feature2 = self.decoder_layer2(fusion2)#逆卷积上采样,256x40x152
fusion3 = sparsed_feature6 + decoder_feature2#特征融合,256x40x152
decoder_feature3 = self.decoder_layer3(fusion3)#逆卷积上采样,128x80x304
fusion4 = sparsed_feature4 + decoder_feature3#特征融合,128x80x304
decoder_feature4 = self.decoder_layer4(fusion4)#逆卷积上采样,64x160x608
fusion5 = sparsed_feature2 + decoder_feature4#特征融合,64x160x608
decoder_feature5 = self.decoder_layer5(fusion5)#逆卷积上采样,32x320x1216
depth_output = self.decoder_layer6(decoder_feature5)#卷积,2x320x1216
d_depth, d_conf = torch.chunk(depth_output, 2, dim=1)#1x320x1216,1x320x1216

3.3 分支融合

        ENet两条支路均预测了深度及其置信度,其中第一条支路预测结果为深度(rgb_depth,1x320x1216)及其置信度(rgb_conf,1x320x1216);第二条支路预测结果为深度(d_depth,1x320x1216)及其置信度(d_conf,1x320x1216)。融合时最终预测深度来源于两条支路预测深度的加权求和,权重由置信度经过softmax得到,即置信度概率越大,权重占比越大。关键程序解析如下所示。

rgb_conf, d_conf = torch.chunk(self.softmax(torch.cat((rgb_conf, d_conf), dim=1)), 2, dim=1)#将两条支路的置信度转换为权重
output = rgb_conf*rgb_depth + d_conf*d_depth#深度预测结果,1x320x1216

        模型返回值为rgb_depth、d_depth、output(融合预测深度)。

3.4 ENet损失函数

        ENet训练损失包含rgb深度损失、稀疏深度损失和融合深度损失,对应预测结果为rgb_depth、d_depth、output。其损失函数均为MaskedMSELoss,衡量预测深度与真实深度标签gt之间的偏差。

        训练前两个迭代周期中,rgb深度损失和稀疏深度损失的权重为0.2,并在第3~4个周期内降为0.05。从第5个训练周期开始,ENet训练损失函数仅包括融合深度损失depth_loss。

        ENet训练损失关键程序解析如下所示。

st1_pred, st2_pred, pred = model(batch_data)#rgb_depth、d_depth、output(融合预测深度)
round1, round2, round3 = 1, 3, None
if(actual_epoch <= round1):
    w_st1, w_st2 = 0.2, 0.2
elif(actual_epoch <= round2):
    w_st1, w_st2 = 0.05, 0.05

else:
    w_st1, w_st2 = 0, 0
depth_loss = depth_criterion(pred, gt)#MaskedMSELoss()
st1_loss = depth_criterion(st1_pred, gt)#MaskedMSELoss()
st2_loss = depth_criterion(st2_pred, gt)#MaskedMSELoss()
loss = (1 - w_st1 - w_st2) * depth_loss + w_st1 * st1_loss + w_st2 * st2_loss

4 DA CSPN++

        DA (dilated and acceleratedm,膨胀加速)CSPN++网络是对ENet预测结果进行微调以获取更加准确的深度信息。其输入包括ENet所提取特征(feature_s1 64x320x1216,feature_s2 128x160x608)与深度预测结果(coarse_depth,1x320x1216),其中特征feature_s1 和feature_s2是rgb深度特征和融合特征的融合。根据膨胀比例,模型设置相应尺度的输入特征。卷积膨胀的作用是为了使卷积核覆盖范围更大,从而使卷积视野范围更广。从另外一个角度上来说,卷积膨胀相当于在下采样的特征图上进行普通卷积操作,该模型的后续操作便是采用这种方法。因此,假设膨胀系数为2,那么所需特征图尺寸为160x608。模型输入特征包含两部分,一部分为原始尺度特征feature_s1,即rgb_feature0_plus和 decoder_feature5拼接融合,维度为64x320x1216;另一部分为用于膨胀操作的特征feature_s2,即rgb_feature2_plus和 decoder_feature4拼接融合,维度为128x160x608。

#ENet输出
torch.cat((rgb_feature0_plus, decoder_feature5), 1), torch.cat((rgb_feature2_plus, decoder_feature4),1), output
feature_s1, feature_s2, coarse_depth = self.backbone(input)#由ENet得到的特征与预测深度,64x320x1216,128x160x608,1x320x1216
depth = coarse_depth#1x320x1216

        CSPN++网络核心思想是采用模型来自主学习卷积核权重,而不是使用卷积直接对输入进行操作,这一点类似于transformer的QK操作。DA CSPN++用于学习卷积核权重的输入特征为feature_s2(128x160x608)。CSPN++网络的另一个特点为采用多种尺度卷积核进行特征提取,参考程序使用了尺寸为3、5、7的卷积核。每个卷积核所提取特征采用加权求和的方法进行融合,其中置信度权重网路的输入也为feature_s2。

        作者对CSPN++网络进行了加速设计,将卷积操作转换为矩阵乘法,从而实现并行计算。例如,3x3卷积核在HxW维度特征图进行滑动操作可转换为9xHxW维度卷积和9xHxW特征图的矩阵乘法。

图 CSPN++加速

        模型对原始输入特征和膨胀特征均会进行CSPN++微调,主要包括卷积核参数及其权重学习、DA CSPN++结果微调、feature_s1 CSPN++结果微调、特征加权求和融合等步骤。

4.1 卷积核参数及其权重学习

        feature_s2(128x160x608)经过卷积 Conv2d(128, 3)和softmax后得到3x160x608维度卷积核权重,对应kernel_conf3_s2(1x160x608)、kernel_conf5_s2(1x160x608)、kernel_conf7_s2(1x160x608)。另一方面,feature_s2经过卷积 Conv2d和padding操作得到卷积核参数,分别为guide3_s2(9x162x610)、guide5_s2(25x164x612)、guide7_s2(49x166x614)。

        feature_s1(64x320x1216)经过卷积 Conv2d(64, 3)和softmax后得到3x320x1216维度卷积核权重,对应kernel_conf3(1x320x1216)、kernel_conf5(1x320x1216)、kernel_conf7(1x320x1216)。另一方面,feature_s1经过卷积 Conv2d和padding操作得到卷积核参数,分别为guide3(9x322x1218)、guide5(25x324x1220)、guide7(49x326x1222)。

kernel_conf_s2 = self.kernel_conf_layer_s2(feature_s2)#128x320x1216 -> 3x160x608
kernel_conf_s2 = self.softmax(kernel_conf_s2)#转换为权重,3x160x608,通道维度和为1,即不同卷积核特征的置信度权重
kernel_conf3_s2 = kernel_conf_s2[:, 0:1, :, :]#1x160x608,3x3卷积核特征权重
kernel_conf5_s2 = kernel_conf_s2[:, 1:2, :, :]#1x160x608,5x5卷积核特征权重
kernel_conf7_s2 = kernel_conf_s2[:, 2:3, :, :]#1x160x608,7x7卷积核特征权重
kernel_conf = self.kernel_conf_layer(feature_s1)#64x320x1216 -> 3x320x1216
kernel_conf = self.softmax(kernel_conf)#转换为权重,3x320x1216,通道维度和为1,即不同卷积核特征的置信度权重
kernel_conf3 = kernel_conf[:, 0:1, :, :]#1x320x1216,3x3卷积核特征权重
kernel_conf5 = kernel_conf[:, 1:2, :, :]#1x320x1216,5x5卷积核特征权重
kernel_conf7 = kernel_conf[:, 2:3, :, :]#1x320x1216,7x7卷积核特征权重
guide3_s2 = self.iter_guide_layer3_s2(feature_s2)#学习3x3卷积CSPN,9x162x610
guide5_s2 = self.iter_guide_layer5_s2(feature_s2)#学习5x5卷积CSPN,25x164x612
guide7_s2 = self.iter_guide_layer7_s2(feature_s2)#学习7x7卷积CSPN,49x166x614
guide3 = self.iter_guide_layer3(feature_s1)#学习3x3卷积CSPN,9x322x1218
guide5 = self.iter_guide_layer5(feature_s1)#学习3x3卷积CSPN,25x324x1220
guide7 = self.iter_guide_layer7(feature_s1)#学习3x3卷积CSPN,49x326x1222

4.2 DA CSPN++结果微调

        ENet预测深度(1x320x1216)下采样成4张子深度图(1x160x608),由于子深度图可构成完整原始深度图,因而这种下采样不会带来信息丢失。DA CSPN++对这4种特征图(depth_s2_00、depth_s2_01、depth_s2_10、depth_s2_11)分别进行深度微调结果预测。每个子深度图进行6次连续CSPN++操作以利用更深层次特征来预测新的微调深度,并且每次进行CSPN++操作时都会与ENet子深度图和激光雷达稀疏深度图d_s2进行融合。子深度图预测结果为3种卷积核提取特征的加权求和。DA CSPN++预测深度(depth_s2_00、depth_s2_01、depth_s2_10、depth_s2_11)重新拼接成原始尺寸,即depth_s2(1x320x1216)。

d_s2, valid_mask_s2 = self.downsample(d, valid_mask)#原始雷达深度最大值池化,1x160x608
mask_s2 = self.mask_layer_s2(feature_s2)#128x320x1216 -> 1x160x608
mask_s2 = torch.sigmoid(mask_s2)#转化为权重, 1x160x608,即DA CSPN++输出的权重
mask_s2 = mask_s2*valid_mask_s2#深度mask与预测mask相乘,1x160x608
feature_12 = torch.cat((feature_s1, self.upsample(self.dimhalf_s2(feature_s2))), 1)#128x320x1216,两种输入特征融合
att_map_12 = self.softmax(self.att_12(feature_12))#2x320x1216,用于ENet预测深度和DA CSPN++微调深度融合
depth_s2 = depth#1x320x1216
depth_s2_00 = depth_s2[:, :, 0::2, 0::2]#深度图拆分,1x160x608
depth_s2_01 = depth_s2[:, :, 0::2, 1::2]#深度图拆分,1x160x608
depth_s2_10 = depth_s2[:, :, 1::2, 0::2]#深度图拆分,1x160x608
depth_s2_11 = depth_s2[:, :, 1::2, 1::2]#深度图拆分,1x160x608
depth3_s2_00 = self.CSPN3(guide3_s2, depth3_s2_00, depth_s2_00_h0)#1x160x608,CSPN特征提取
depth3_s2_00 = mask_s2*d_s2 + (1-mask_s2)*depth3_s2_00#1x160x608,与原始输入稀疏特征加权求和融合
depth5_s2_00 = self.CSPN5(guide5_s2, depth5_s2_00, depth_s2_00_h0)#1x160x608,CSPN特征提取
depth5_s2_00 = mask_s2*d_s2 + (1-mask_s2)*depth5_s2_00#1x160x608,与原始输入稀疏特征加权求和融合
depth7_s2_00 = self.CSPN7(guide7_s2, depth7_s2_00, depth_s2_00_h0)#1x160x608,CSPN特征提取
depth7_s2_00 = mask_s2*d_s2 + (1-mask_s2)*depth7_s2_00#1x160x608,与原始输入稀疏特征加权求和融合
depth_s2_00 = kernel_conf3_s2*depth3_s2_00 + kernel_conf5_s2*depth5_s2_00 + kernel_conf7_s2*depth7_s2_00#不同卷积核特征加权求和融合,1x160x608
depth_s2[:, :, 0::2, 0::2] = depth_s2_00#将深度重新拼接成原始尺度,1x320x1216
refined_depth_s2 = depth*att_map_12[:, 0:1, :, :] + depth_s2*att_map_12[:, 1:2, :, :]#与ENet深度加权求和融合,1x320x1216

4.3 feature_s1 CSPN++结果微调

        模型再次使用feature_s1学习的三种卷积核参数对DA CSPN++的预测深度结果depth_s2进行结果微调。模型此时同样采用连续6次CSPN++操作,,并且每次进行CSPN++操作时都会与depth_s2和激光雷达稀疏深度图d进行融合。三种尺寸卷积核对应的CSPN预测结果(depth3、depth5、depth7)加权求和即可得到模型最终微调后的预测深度refined_depth(1x320x1216)。

mask = self.mask_layer(feature_s1)#64x320x1216 -> 1x320x1216
mask = torch.sigmoid(mask)#转化为权重,1x320x1216,非膨胀CSPN++卷积输出的权重
mask = mask*valid_mask#深度mask与预测mask相乘,1x320x1216
for i in range(6):
    depth3 = self.CSPN3(guide3, depth3, depth)#采用CSPN再次进行深度微调,1x320x1216
    depth3 = mask*d + (1-mask)*depth3#与原始输入稀疏特征加权求和融合,1x320x1216
    depth5 = self.CSPN5(guide5, depth5, depth)#采用CSPN再次进行深度微调,1x320x1216
    depth5 = mask*d + (1-mask)*depth5#与原始输入稀疏特征加权求和融合,1x320x1216
    depth7 = self.CSPN7(guide7, depth7, depth)#采用CSPN再次进行深度微调,1x320x1216
    depth7 = mask*d + (1-mask)*depth7#与原始输入稀疏特征加权求和融合,1x320x1216
refined_depth = kernel_conf3*depth3 + kernel_conf5*depth5 + kernel_conf7*depth7#深度加权求和融合,1x320x1216

4.4 损失函数

        DA CSPN++阶段训练损失函数仅由depth_loss组成,即CsPN++深度预测结果与真实标签之间偏差,损失函数类型为MaskedMSELoss。

5 模型训练

        PENet模型训练包括三个步骤,分别是ENet训练、DA CSPN++训练和PENet训练,分别对应下图中I、II、III。

图PENet训练示意图

        ENet训练命令为“CUDA_VISIBLE_DEVICES="0,1" python main.py -b 6 -n e”,CUDA_VISIBLE_DEVICES="0,1"部分可以根据实际情况设置GPU序号。作者提供的Enet预训练模型下载地址为“https://drive.google.com/file/d/1TRVmduAnrqDagEGKqbpYcKCT307HVQp1/view?usp=sharing”。

        DA-CSPN++训练命令为“CUDA_VISIBLE_DEVICES="0,1" python main.py -b 6 -f -n pe --resume [enet-checkpoint-path]”。“-f”表示训练DA-CSPN++网络时ENet主干网络是固定的,即不通过梯度传播更新参数。

        当ENet和DA-CSPN++分别训练完成后,模型再次进行整体训练,训练命令为“CUDA_VISIBLE_DEVICES="0,1" python main.py -b 10 -n pe -he 160 -w 576 --resume [penet-checkpoint-path]”。该训练程序与DA CSPN++训练的区别在于不再使用-f参数,即ENet主干网络也需要进行训练更新。预训练模型可以使用DA-CSPN++训练得到的模型,也可直接使用作者提供的PENet预训练模型,下载地址为“https://drive.google.com/file/d/1RDdKlKJcas-G5OA49x8OoqcUDiYYZgeM/view?usp=sharing”。

6 【python三维深度学习】python三维点云从基础到深度学习_python3d点云从基础到深度学习-CSDN博客

【版权声明】
本文为博主原创文章,未经博主允许严禁转载,我们会定期进行侵权检索。 
 

更多python与C++技巧、三维算法、深度学习算法总结、大模型请关注我的博客,欢迎讨论与交流:https://blog.csdn.net/suiyingy,或”乐乐感知学堂“公众号。Python三维领域专业书籍推荐:《人工智能点云处理及深度学习算法》。

 本文为专栏《Python三维点云实战宝典》系列文章,专栏介绍地址“【python三维深度学习】python三维点云从基础到深度学习_python3d点云从基础到深度学习-CSDN博客”。配套书籍《人工智能点云处理及深度学习算法》提供更加全面和系统的解析。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2069359.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java结合Ai

Spring AI Spring AI提供的API支持跨人工智能提供商的 聊天,文本到图像,和嵌入模型等,同时支持同步和流API选项; 介绍 Spring AI 是 AI 工程的应用框架。其目标是将Spring生态系统的设计原则(如可移植性和模块化设计)应用于AI领域,并促进使用POJO作为应用程序的构建块…

大数据-100 Spark 集群 Spark Streaming DStream转换 黑名单过滤的三种实现方式

喜大普奔&#xff01;破百了&#xff01; 点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff08;已更完&#xff09;HDFS&#xff08;已更完&a…

【精选】基于django柚子校园影院(咨询+解答+辅导)

博主介绍&#xff1a; ✌我是阿龙&#xff0c;一名专注于Java技术领域的程序员&#xff0c;全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师&#xff0c;我在计算机毕业设计开发方面积累了丰富的经验。同时&#xff0c;我也是掘金、华为云、阿里云、InfoQ等平台…

[240824] 微软更新导致部分 Linux 用户无法启动系统,谁之过?| Chrome 稳定版更新(128.0.6613.84)

目录 微软更新导致部分 Linux 用户无法启动系统&#xff0c;谁之过&#xff1f;Chrome 稳定版更新 (128.0.6613.84) 微软更新导致部分 Linux 用户无法启动系统&#xff0c;谁之过&#xff1f; 最近&#xff0c;微软推送的一项 Windows 更新导致部分 Linux 用户无法启动系统&am…

基于Springboot + vue + mysql 藏区特产销售平台 设计实现

目录 &#x1f4da; 前言 &#x1f4d1;摘要 1.1 研究背景 &#x1f4d1;操作流程 &#x1f4da; 系统架构设计 &#x1f4da; 数据库设计 &#x1f4ac; E-R表 系统功能模块 系统首页 特产信息 ​编辑 个人中心 购物车 用户注册 管理员功能模块 管理员登录 管…

Stable diffusion模型如何区分?通俗易懂,入门必看!

在Stable Diffusion的基础学习中&#xff0c;很多小伙伴们可能看到繁杂的大模型就蒙圈了&#xff0c;那么多的模型后缀&#xff0c;究竟代表什么呢&#xff1f;如何区分呢&#xff1f;今天就带大家来学习一下&#xff5e; 不同后缀模型介绍 在Stable diffusion中&#xff0c;…

【Tomact源码解析】——组件介绍

目录 一、简介 二、组件和体系架构简介 三、组件详情 Server Service Connector Engine ​编辑Host Context Wrapper 四、容器详情 生命周期机制 监听器机制 管道机制 五、补充内容 一、简介 Tomcat 服务器是一个免费的开放源代码的 Web 应用服务器,属于…

支持在线编辑的文件管理系统MxsDoc

DocSys是一个基于Web的文件管理系统&#xff08;全平台支持:Linux&#xff0c;Windows&#xff0c;Mac&#xff09;&#xff0c;它提供了丰富的功能和特性&#xff0c;以满足不同用户在不同场景下的需求。 开源地址&#xff1a;DocSys: MxsDoc是基于Web的文件管理系统&#xff…

校友林小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;树木管理管理&#xff0c;所属科管理&#xff0c;树木领取管理&#xff0c;树跟踪状态管理&#xff0c;用户信息统计管理&#xff0c;树木捐款管理&#xff0c;留言板管理 微信端…

【芯片往事】陈大同-展讯和TD

前言&#xff1a;几年前&#xff08;2012&#xff09;&#xff0c;应邀为校友刊物《水木清华》写了一年创业专栏&#xff0c;其中有几期回忆了当年先后创办硅谷豪威科技&#xff08;OmniVision&#xff09;和上海展讯通信&#xff08;SpreadTrum&#xff09;的经历&#xff0c;…

ZMQ发布订阅模型

案例一 发布者Publisher(server) // server.cpp #include <zmq.hpp> #include <string> #include <iostream> #include <chrono> #include <thread> using namespace std; using namespace zmq; int main() {context_t context(1);socket_t so…

维纳滤波(Wiener Filtering)

维纳滤波&#xff08;Wiener Filtering&#xff09; 引言 维纳滤波&#xff08;Wiener Filtering&#xff09;是一种最优线性滤波方法&#xff0c;广泛应用于信号处理、图像处理和通信系统中。它旨在从含噪声的信号中恢复原始信号&#xff0c;最小化均方误差&#xff08;MSE&…

谷粒商城实战笔记-251-商城业务-消息队列-Exchange类型

文章目录 一&#xff0c;Exchange二&#xff0c;Exchange的四种类型1&#xff0c;direct2&#xff0c;fanout3&#xff0c;topic 三&#xff0c;实操1&#xff0c;创建一个exchange2&#xff0c;创建一个queue3&#xff0c;将queue绑定到exchange 一&#xff0c;Exchange AMQP …

二叉树的链式存储(代码实现)

二叉树的链式存储 用链表实现&#xff0c;基于完全二叉树规律来构建树&#xff0c;按照完全二叉树的编号方法&#xff0c;从上到下&#xff0c;从左到右。一共n个节点。 第i个节点&#xff1a; 左子节点编号&#xff1a;2*i &#xff08;2*i<n&#xff09; 右子节点编号&…

【C++题解】1146. 求S的值

欢迎关注本专栏《C从零基础到信奥赛入门级&#xff08;CSP-J&#xff09;》 问题&#xff1a;1146. 求S的值 类型&#xff1a;递归基础、函数 题目描述&#xff1a; 求 S12471116…的值刚好大于等于 5000 时 S 的值。 输入&#xff1a; 无。 输出&#xff1a; 一行&…

写作手三天速成攻略【数学建模国赛赛前必看内容】

第一天&#xff1a;准备论文模板&#xff0c;学习各类基础画图技巧 1、论文模板 对于写作手&#xff0c;除了内容的连贯性&#xff0c;排版是非常重要的&#xff0c;可以说有一个好的排版&#xff0c;只要论文是完整的&#xff0c;有结果的&#xff0c;基本上保底有省奖&#…

CSP-CCF 201412-2 Z字形扫描

目录 一、问题描述 二、解答 三、总结 一、问题描述 在图像编码的算法中&#xff0c;需要将一个给定的方形矩阵进行Z字形扫描(Zigzag Scan)。给定一个nn的矩阵&#xff0c;Z字形扫描的过程如下图所示&#xff1a; 对于下面的44的矩阵&#xff0c;   1 5 3 9   3 7 5 6  …

玩客云刷机armbian后docker启动不起来,提示bpf_prog_query(BPF_CGROUP_DEVICE) failed

/ ___| ( _ )/ |___ \ \___ \ / _ \| | __) |___) | (_) | |/ __/ |____/ \___/|_|_____|Welcome to Armbian 20.12 Bullseye with Linux 5.10.61-aml-s812Linux aml-s812 5.10.61-aml-s812 #20.12 SMP Thu Sep 2 20:11:09 CST 2021 armv7l GNU/Linux 玩客云刷机armbian后dock…

工业气膜仓储:高效、灵活的仓储解决方案—轻空间

在现代工业生产中&#xff0c;仓储设施的选择至关重要。作为一种新型的仓储解决方案&#xff0c;工业气膜仓储凭借其高效、灵活、经济的优势&#xff0c;正在逐渐取代传统建筑仓库&#xff0c;成为各类企业的理想选择。 一、快速搭建&#xff0c;满足多种需求 工业气膜仓储采用…

24年浙江事业单位考试报名流程保姆级教程

2024年浙江事业单位考试报名马上就要开始了&#xff0c;有想要参加考试报名的同学可以提前了解一下报名流程&#xff0c;以及报名照要求。 一、考试时间安排&#xff1a; 报名时间&#xff1a;8月27日9:00 9月2日16:00 资格审核时间&#xff1a;8月27日—9月3日 网上缴费时…