DDP Example


https://zhuanlan.zhihu.com/p/602305591
https://zhuanlan.zhihu.com/p/178402798


On model saving and loading: saved checkpoints actually come in two kinds, with and without the "module." prefix (the Zhihu article above covers the with-module case).

For an explanation of the two cases, with and without the prefix:
https://blog.csdn.net/hustwayne/article/details/120324639
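As a minimal sketch of the two save styles (assuming model is already wrapped in DistributedDataParallel; the file names are illustrative):

import torch

# A DDP-wrapped model prefixes every state_dict key with "module.":
torch.save(model.state_dict(), "ckpt_with_module.pt")
# Saving the inner module instead drops the prefix:
torch.save(model.module.state_dict(), "ckpt_without_module.pt")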

In this project, checkpoints are saved without the module prefix, and loading pretrained weights removes some keys; this was later changed to mmcv's load_checkpoint, which matches key-value pairs adaptively.
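
A sketch of both loading paths, assuming model and ckpt_path are defined as in the scripts below: the manual "module." strip used first in this project, and the mmcv helper it later switched to.

import torch
from collections import OrderedDict
from mmcv.runner import load_checkpoint

# Manual path: drop the "module." prefix so an unwrapped model accepts the keys.
state = torch.load(ckpt_path, map_location='cpu')
state = OrderedDict(
    (k[7:] if k.startswith('module.') else k, v) for k, v in state.items()
)
model.load_state_dict(state, strict=False)

# mmcv path: load_checkpoint matches keys itself (as used at the end of this post).
load_checkpoint(model, ckpt_path, map_location='cpu', strict=False)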


Old model main.py DDP example

"""
Copyright (C) 2020 NVIDIA Corporation.  All rights reserved.
Licensed under the NVIDIA Source Code License. See LICENSE at https://github.com/nv-tlabs/lift-splat-shoot.
Authors: Jonah Philion and Sanja Fidler
"""
import warnings
warnings.filterwarnings("error", "MAGMA*")
from fire import Fire
import argparse
import torch
import src
import os
"""
Copyright (C) 2020 NVIDIA Corporation.  All rights reserved.
Licensed under the NVIDIA Source Code License. See LICENSE at https://github.com/nv-tlabs/lift-splat-shoot.
Authors: Jonah Philion and Sanja Fidler
"""

import os
import numpy as np
from time import time
from torch import nn
from src.models_goe_1129_nornn_2d_2_ori import compile_model
# from src.models_goe_1129_nornn_2d_2_zj import compile_model
from tensorboardX import SummaryWriter
from src.data_tfmap_newcxy_nextmask2 import compile_data  # out-of-bounds points added for both the current frame and stitched frames
# from src.data_tfmap_newcxy_ori import compile_data  # no out-of-bounds points
#from src.data_tfmap import compile_data
from src.tools import SimpleLoss, RegLoss, SegLoss, BCEFocalLoss, get_batch_iou, get_val_info, denormalize_img
import sys
import cv2
from collections import OrderedDict

from src.config.defaults import get_cfg_defaults
from src.options import get_opts
from src.rendering.neuconw_helper import NeuconWHelper
import open3d as o3d
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
os.environ['LOCAL_RANK'] = "0,1"
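# NOTE: LOCAL_RANK is normally a single per-process index set by the launcher;
# this script actually reads args.local_rank from get_opts() below, so this
# comma-separated value is unusual and effectively unused.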
torch.set_num_threads(8)


# os.environ["CUDA_VISIBLE_DEVICES"] = "4"
# os.environ['RANK'] = "0"
# os.environ['WORLD_SIZE'] = "1"
# os.environ['MASTER_ADDR'] = "localhost"
# os.environ['MASTER_PORT'] = "12345"
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

"动静态分离里, 构造sample时rays要加一个type的维度"



def project_from_lidar_2_cam(img, points, rots, trans, intrins, post_rots, post_trans):
    color_arr = np.zeros((points.shape[0], 3))
    # ego_to_cam
    points -= trans

    points = torch.inverse(rots.view(1, 3, 3)).matmul(points.unsqueeze(-1)).squeeze(-1)
    depths = points[..., 2:]
    points = torch.cat((points[..., :2] / depths, torch.ones_like(depths)), -1)

    # cam_to_img
    points = intrins.view(1, 3, 3).matmul(points.unsqueeze(-1)).squeeze(-1)
    points = post_rots.view(1, 3, 3).matmul(points.unsqueeze(-1)).squeeze(-1)
    points = points + post_trans.view(1, 3)
    # points = points.view(B, N, Z, Y, X, 3)[..., :2]
    points = points.view(-1, 3).int().numpy()

    # imshow
    # pts = points[0,0,2,...].reshape(-1, 2).cpu().numpy()
    # image = np.zeros((128, 352, 3), dtype=np.uint8)
    # for i in range(pts.shape[0]):
    #     cv2.circle(image, (int(pts[i, 0]), int(pts[i, 1])), 1, (255, 255, 255), 2)
    # cv2.imshow("local_map", image)
    # cv2.waitKey(-1)

    # normalize_coord
    img = np.array(img)
    # for i in range(points.shape[0]):
    #     cv2.circle(img, (points[i,0], points[i,1]), 1, tuple(color_arr[i].tolist()), -1)
    return img

def main():
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--local_rank", default = 0, type=int)
    # args = parser.parse_args()

    args = get_opts()
    config = get_cfg_defaults()
    config.merge_from_file(args.cfg_path)
    print(config)

    # args.local_rank = 2
    print("sssss",args.local_rank)
    # 新增3:DDP backend初始化
#       a.根据local_rank来设定当前使用哪块GPU
#       b.初始化DDP,使用默认backend(nccl)就行。如果是CPU模型运行,需要选择其他后端。
    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)
        device=torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method='env://')
    
    version = "0"
    #dataroot = "/defaultShare/aishare/share"
    dataroot = "/data/zjj/data/aishare/share"
    nepochs=10000
 
    final_dim=(128, 352)
    max_grad_norm=5.0
    #max_grad_norm=2.0
    pos_weight=2.13

    logdir=f'/mnt/sdb/xzq/occ_project/occ_nerf_st/log/{args.exp_name}'
   
    xbound=[0.0, 102., 0.85]
    ybound=[-10.0, 10.0, 0.5]
    zbound=[-2.0, 4.0, 1]
    dbound=[3.0, 103.0, 2.]

    # xbound=[0.0, 96., 0.5]
    # ybound=[-12.0, 12.0, 0.5]
    # zbound=[-2.0, 4.0, 1]
    # dbound=[3.0, 103.0, 2.]

    bsz=4
    seq_len=5 #5
    nworkers=1 #2
    lr=1e-4
    # weight_decay=1e-7
    weight_decay = 0
    sample_num = 1024
    datatype = "single"    #multi   single
        
    torch.backends.cudnn.benchmark = True
    grid_conf = {
        'xbound': xbound,
        'ybound': ybound,
        'zbound': zbound,
        'dbound': dbound,
    }
 
    ### bevgnd
    data_aug_conf = {
                'resize_lim': [(0.05, 0.4), (0.3, 0.90)],#(0.3-0.9)
                'final_dim': (128, 352),
                'rot_lim': (-5.4, 5.4),
                # 'H': H, 'W': W,
                'rand_flip': False,
                'bot_pct_lim': [(0.04, 0.35), (0.15, 0.4)],
                'cams': ['CAM_FRONT0', 'CAM_FRONT1'],
                'Ncams': 2,
            }
    
    train_sampler, val_sampler, trainloader, valloader = compile_data(version, dataroot, data_aug_conf=data_aug_conf,
                                          grid_conf=grid_conf, bsz=bsz, seq_len=seq_len, sample_num=sample_num, nworkers=nworkers,
                                          parser_name='segmentationdata', datatype=datatype)
    print("train lengths: ", len(trainloader))
    # print("val lengths: ", len(valloader))
    # device = torch.device('cpu') if gpuid < 0 else torch.device(f'cuda:{gpuid}')
    writer = SummaryWriter(logdir=logdir)
    model = compile_model(grid_conf, data_aug_conf, seq_len=seq_len, batchsize=int(bsz), config=config, args=args, writer=writer)
    counter = 0
 
    if 0:
        print('==> loading existing model')
        model_info = torch.load('/data/zjj/project/bev_osr_distort_multi_addtime_nornn_align_h5_nerf_multi2/checkpoints/models_20231113_nornn_120_21_6_b2_lall_sample1024_v1/checkpts/model_30000.pt')
        # model_info = torch.load('/zhangjingjuan/NeRF/bev_osr_distort_multi_addtime_nornn_align_h5_nerf_multi2/checkpoints/models_20231114_nornn_v2/checkpts/model_50000.pt')
        #model_info = torch.load('/data/zjj/bev_osr_distort_multi_addtime_nornn_align_h5_nerf_multi2/checkpoints/models_20231120_nornn_v1/checkpts/model_18000.pt')

        counter = 0

        new_state_dict = OrderedDict()
        for k, v in model_info.items():
            if 'semantic_net' in k:
                continue
            # if 'SEnet' in k or 'voxels' in k or 'bevencode.downchannel' in k or 'bevencode.up3' in k or 'bevencode.conv1_block' in k:
            #    continue
            # if 'voxels' in k:
            #     continue
            # if 'color_net' in k:
            #     continue
    
            
            if "neuconw_helper" in k:
                name = k[22:]
            elif "module." in k:
                name = k[7:]  # remove "module."
                #print(k)
            else:
                name = k
            
            '''
            if "module." in k:
                name = k[7:]  # remove "module."
            else:
                name = k
            '''
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict, strict=False)
        model.dx.data = torch.tensor([0.85, 0.5, 1.0]).to(device)
        # model.dx.data = torch.tensor([0.5, 0.5, 0.5]).to(device)
        # model.nx.data = torch.tensor([204, 40, 12]).to(device)
        # model.bx.data = torch.tensor([0.25, -9.75, -1.75]).to(device)
    # move the model to its GPU before wrapping with DDP
    model.to(device)

    neuconw_helper = NeuconWHelper(args, config, model.neuconw, model.embedding_a, writer)
    # DDP wrapping
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                            output_device=args.local_rank,find_unused_parameters=True)

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # opt = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)


    loss_fn = SegLoss(pos_weight).cuda(args.local_rank)
    loss_fn_ll = SegLoss(pos_weight).cuda(args.local_rank)
    loss_fn_sl = SegLoss(pos_weight).cuda(args.local_rank)
    loss_fn_zc = SegLoss(pos_weight).cuda(args.local_rank)
    loss_fn_ar = SegLoss(pos_weight).cuda(args.local_rank)
    loss_fn_rs = SegLoss(pos_weight).cuda(args.local_rank)
    loss_fn_cl = SimpleLoss(pos_weight).cuda(args.local_rank)
    loss_fn_lf_pred = SimpleLoss(pos_weight).cuda(args.local_rank)
    loss_fn_lf_norm = RegLoss(0).cuda(args.local_rank)
    # loss_fn_patch = SimpleLoss(pos_weight).cuda(args.local_rank)
    
    val_step = 1000
    t1 = time()
    t2 = time()
    model.train()
    scaler = torch.cuda.amp.GradScaler()

    train_bev = False # False
    train_occ = True
    for epoch in range(nepochs):
        np.random.seed()
        train_sampler.set_epoch(epoch)
        start = time()
        for batchi, (imgs, rots, trans, intrins, dist_coeffss, post_rots, post_trans, cam_pos_embeddings, binimgs, lf_label_gt, lf_norm_gt, fork_scales_gt, fork_offsets_gt, fork_oris_gt, rays, theta_mat_2d, theta_mat_3d) in enumerate(trainloader):
            t0 = time()
            t =  t0 - t1
            tt = t0 - t2
            t1 = time()

            # print("img_path = ", img_paths[-1][0])
            if 1:
                seg_preds1, seg_preds2, lf_preds, _, _ , loss_osr = model(imgs.to(device), rots.to(device), trans.to(device), intrins.to(device), dist_coeffss.to(device), post_rots.to(device), 
                        post_trans.to(device), cam_pos_embeddings.to(device), fork_scales_gt.to(device),fork_offsets_gt.to(device),fork_oris_gt.to(device), rays.to(device), theta_mat_2d.to(device), counter, 'train')

                if train_bev:
                    lf_pred = lf_preds[:, :, :1].contiguous()
                    lf_norm = lf_preds[:, :, 1:(1+4)].contiguous()
                    # lf_kappa = lf_preds[:, :, (1+4):(1+4+2)].contiguous()

                    lf_out = lf_pred.sigmoid()
                    out = seg_preds1.sigmoid()
                    out1 = seg_preds2.sigmoid()

                    binimgs = binimgs.to(device)
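                    # NOTE: mask_gt is never unpacked from the batch above; this
                    # train_bev branch assumes the data loader provides it and is
                    # disabled here (train_bev = False).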
                    seg_preds_0 = seg_preds1[:, :, 0] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    binimgs0 = binimgs[:, :, 0] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    seg_preds_1 = seg_preds1[:, :, 1] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    binimgs1 = binimgs[:, :, 1] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    seg_preds_2 = seg_preds1[:, :, 2] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    binimgs2 = binimgs[:, :, 2] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    seg_preds_3 = seg_preds2[:, :, 0] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    binimgs3 = binimgs[:, :, 3] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    seg_preds_4 = seg_preds1[:, :, 3] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    binimgs4 = binimgs[:, :, 4] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    seg_preds_5 = seg_preds1[:, :, 4] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])
                    binimgs5 = binimgs[:, :, 5] * mask_gt[:, :, 0] + (-1) * (1 - mask_gt[:, :, 0])

                    loss_ll = loss_fn_ll(seg_preds1[:, :, 0].contiguous(), binimgs[:, :, 0].contiguous()) + loss_fn_ll(
                        seg_preds_0.contiguous(), binimgs0.contiguous())
                    loss_sl = loss_fn_sl(seg_preds1[:, :, 1].contiguous(), binimgs[:, :, 1].contiguous()) + loss_fn_sl(
                        seg_preds_1.contiguous(), binimgs1.contiguous())
                    loss_zc = loss_fn_zc(seg_preds1[:, :, 2].contiguous(), binimgs[:, :, 2].contiguous()) + loss_fn_zc(
                        seg_preds_2.contiguous(), binimgs2.contiguous())
                    loss_ar = loss_fn_ar(seg_preds2[:, :, 0].contiguous(), binimgs[:, :, 3].contiguous()) + loss_fn_ar(
                        seg_preds_3.contiguous(), binimgs3.contiguous())
                    loss_rs = loss_fn_rs(seg_preds1[:, :, 3].contiguous(), binimgs[:, :, 4].contiguous()) + loss_fn_rs(
                        seg_preds_4.contiguous(), binimgs4.contiguous())
                    loss_cl = loss_fn_cl(seg_preds1[:, :, 4].contiguous(), binimgs[:, :, 5].contiguous()) + loss_fn_cl(
                        seg_preds_5.contiguous(), binimgs5.contiguous())
            
                    # lf_norm_gt0 = torch.unsqueeze(torch.sum(lf_norm_gt, 2), 2)
                    norm_mask = (lf_norm_gt > -500)
                    # norm_mask = ((lf_label_gt>-0.5)).repeat(1, 1, 4, 1, 1)

                    scale_lf = 5.
                    loss_lf = loss_fn_lf_pred(lf_pred, lf_label_gt.to(device)) + loss_fn_lf_norm(lf_norm[norm_mask], scale_lf*lf_norm_gt[norm_mask].to(device))
                    # loss_ilf = loss_fn_lf_pred(lf_ipred, lf_label_gt.to(device)) + loss_fn_lf_norm(scale_lf*lf_inorm[norm_mask], scale_lf*lf_norm_gt[norm_mask].to(device))
                    # loss_lf_crop = loss_fn_patch(lf_crop_preds, fork_patch_gt.to(device))
                    # print('lf_loss = ', loss_lf)
                    loss_gnd = loss_lf + loss_ll + loss_sl + loss_zc + loss_ar + loss_rs + loss_cl# + loss_ilf
                    # loss = loss_ll + loss_sl + loss_zc + loss_ar + loss_rs + loss_cl

                if train_occ:
                    # loss = loss_gnd + loss_osr
                    loss = loss_osr
                    #loss = loss_gnd
                    opt.zero_grad()
                    # scaler.scale(loss).backward()
                    loss.backward()
                    clip_debug = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    opt.step()
            # except:
                # continue
            # scaler.step(opt)
            # scaler.update()
            t2 = time()
            writer.add_scalar('train/clip_debug', clip_debug.item(), counter)
            if counter % 10 == 0 and args.local_rank==0:
                print(counter, loss.item(),  time() - start)

            if train_bev:
                if counter % 10 == 0 and args.local_rank==0:
                    # print(loss_lf.item(), loss_ll.item(), loss_sl.item(), loss_zc.item(), loss_ar.item(), loss_rs.item(), loss_cl.item())
                    # print(counter, loss.item(), loss_gnd.item(), loss_osr.item(), time() - start)
                    # print(counter, loss.item(), time() - start)
                    writer.add_scalar('train/loss', loss, counter)
                    writer.add_scalar('train/loss_ll', loss_ll, counter)
                    writer.add_scalar('train/loss_sl', loss_sl, counter)
                    writer.add_scalar('train/loss_zc', loss_zc, counter)
                    writer.add_scalar('train/loss_ar', loss_ar, counter)
                    writer.add_scalar('train/loss_rs', loss_rs, counter)
                    writer.add_scalar('train/loss_cl', loss_cl, counter)
                    writer.add_scalar('train/loss_lf', loss_lf, counter)
                    # writer.add_scalar('train/loss_lf_crop', loss_lf_crop, counter)
                    writer.add_scalar('train/loss_gnd', loss_gnd, counter)
                    writer.add_scalar('train/loss_osr', loss_osr, counter)
                    writer.add_scalar('train/clip_debug', clip_debug.item(), counter)

                if counter % 50 == 0 and args.local_rank==0:
                    _, _, iou_ll = get_batch_iou(seg_preds1[:, :, 0].contiguous(), binimgs[:, :, 0].contiguous())
                    _, _, iou_sl = get_batch_iou(seg_preds1[:, :, 1].contiguous(), binimgs[:, :, 1].contiguous())
                    _, _, iou_zc = get_batch_iou(seg_preds1[:, :, 2].contiguous(), binimgs[:, :, 2].contiguous())
                    _, _, iou_ar = get_batch_iou(seg_preds2[:, :, 0].contiguous(), binimgs[:, :, 3].contiguous())
                    _, _, iou_rs = get_batch_iou(seg_preds1[:, :, 3].contiguous(), binimgs[:, :, 4].contiguous())
                    _, _, iou_cl = get_batch_iou(seg_preds1[:, :, 4].contiguous(), binimgs[:, :, 5].contiguous())
                    writer.add_scalar('train/iou_ll', iou_ll, counter)
                    writer.add_scalar('train/iou_sl', iou_sl, counter)
                    writer.add_scalar('train/iou_zc', iou_zc, counter)
                    writer.add_scalar('train/iou_ar', iou_ar, counter)
                    writer.add_scalar('train/iou_rs', iou_rs, counter)
                    writer.add_scalar('train/iou_cl', iou_cl, counter)
                    writer.add_scalar('train/epoch', epoch, counter)
                    writer.add_scalar('train/step_time', t, counter)
                    writer.add_scalar('train/data_time', tt, counter)

                if counter % 200 == 0 and args.local_rank==0:
                    fH = final_dim[0]
                    fW = final_dim[1]
                    image0 =np.array(denormalize_img(imgs[0, 0]))
                    image1 =np.array(denormalize_img(imgs[0, 1]))
                    # image2 =np.array(denormalize_img(imgs[0, 2]))
                    # image3 =np.array(denormalize_img(imgs[0, 3]))
                    writer.add_image('train/image/00', image0, global_step=counter, dataformats='HWC')
                    writer.add_image('train/image/01', image1, global_step=counter, dataformats='HWC')
                    # writer.add_image('train/image/02', image2, global_step=counter, dataformats='HWC')
                    # writer.add_image('train/image/03', image3, global_step=counter, dataformats='HWC')
                    writer.add_image('train/binimg/0', (binimgs[0, 1, 0:1]+1.)/2.01, global_step=counter)

                    writer.add_image('train/binimg/1', (binimgs[0, 1, 1:2]+1.)/2.01, global_step=counter)
                    writer.add_image('train/binimg/2', (binimgs[0, 1, 2:3]+1.)/2.01, global_step=counter)
                    writer.add_image('train/binimg/3', (binimgs[0, 1, 3:4]+1.)/2.01, global_step=counter)
                    writer.add_image('train/binimg/4', (binimgs[0, 1, 4:5]+1.)/2.01, global_step=counter)
                    writer.add_image('train/binimg/5', (binimgs[0, 1, 5:6]+1.)/2.01, global_step=counter)
                    writer.add_image('train/out/0', out[0, 1, 0:1], global_step=counter)
                    writer.add_image('train/out/1', out[0, 1, 1:2], global_step=counter)
                    writer.add_image('train/out/2', out[0, 1, 2:3], global_step=counter)
                    writer.add_image('train/out/3', out1[0, 1, 0:1], global_step=counter)
                    writer.add_image('train/out/4', out[0, 1, 3:4], global_step=counter)
                    writer.add_image('train/out/5', out[0, 1, 4:5], global_step=counter)

                    writer.add_image('train/lf_label_gt/0', (lf_label_gt[0, 1]+1.)/2.01, global_step=counter)
                    writer.add_image('train/lf_out/0', lf_out[0, 1], global_step=counter)
                    # writer.add_image('train/fork_patch/0', (fork_patch_gt[0, 1, 0:1]+1.)/2.01, global_step=counter)
                    # writer.add_image('train/fork_patch/1', (fork_patch_gt[0, 1, 1:2]+1.)/2.01, global_step=counter)
                    # writer.add_image('train/lf_crop_out/0', lf_crop_out[0, 1, 0:1], global_step=counter)
                    # writer.add_image('train/lf_crop_out/1', lf_crop_out[0, 1, 1:2], global_step=counter)

                    seg_ll_data = binimgs[0, 1, 0].cpu().detach().numpy()
                    seg_cl_data = binimgs[0, 1, 5].cpu().detach().numpy()

                    lf_label_data_gt = lf_label_gt[0, 1, 0].numpy()
                    lf_norm_data_gt = lf_norm_gt[0, 1].numpy()

                    lf_norm_show = np.zeros((480, 160, 3), dtype=np.uint8)
                    ys, xs = np.where(seg_ll_data > 0.5)
                    lf_norm_show[ys, xs, :] = 255

                    ys, xs = np.where(lf_label_data_gt> -0.5)
                    lf_norm_show[ys, xs, :] = 128

                    labels = np.logical_or(seg_ll_data[ys, xs] > 0.5, seg_cl_data[ys, xs] > 0.5)
                    ys = ys[labels]
                    xs = xs[labels]
                    scale = 1.7

                    if ys.shape[0] > 0:
                        for mm in range(0, ys.shape[0], 10):
                            y = ys[mm]
                            x = xs[mm]
                            norm0 = lf_norm_data_gt[0:2, y, x]
                            if norm0[0] == -999.:
                                continue
                            cv2.line(lf_norm_show, (x, y), (x+int(round(norm0[0]*50)), y + int(round(scale * (norm0[1]+1)*50))), (0, 0, 255))
                            norm1 = lf_norm_data_gt[2:4, y, x]
                            if norm1[0] == -999.:
                                continue
                            cv2.line(lf_norm_show, (x, y), (x+int(round(norm1[0]*50)), y + int(round(scale * (norm1[1]+1)*50))), (255, 0, 0))
                    writer.add_image('train/lf_norm_gt/0',  lf_norm_show, global_step=counter, dataformats='HWC')

                    lf_norm_data = lf_norm[0, 1].detach().cpu().numpy()
                    ys, xs = np.where(np.logical_or(seg_ll_data > 0.5, seg_cl_data > 0.5))
                    lf_norm_show = np.zeros((480, 160, 3), dtype=np.uint8)
                    if ys.shape[0] > 0:
                        for mm in range(0, ys.shape[0], 10):
                            y = ys[mm]
                            x = xs[mm]
                            norm0 = lf_norm_data[0:2, y, x]/scale_lf
                            cv2.line(lf_norm_show, (x, y), (x+int(round(norm0[0]*50)), y+int(round(scale * (norm0[1]+1)*50))), (0, 0, 255))
                            norm1 = lf_norm_data[2:4, y, x]/scale_lf
                            cv2.line(lf_norm_show, (x, y), (x+int(round(norm1[0]*50)), y+int(round(scale * (norm1[1]+1)*50))), (255, 0, 0))
                    writer.add_image('train/lf_norm/0',  lf_norm_show, global_step=counter, dataformats='HWC')

            if counter % (1*val_step) == 0 and args.local_rank==0:
                model.eval()
                #mname = os.path.join(logdir, "model{}.pt".format(0))
                #mname = os.path.join(logdir, "model{}.pt".format(counter))#counter))
                #print('saving', mname)
                #torch.save(model.state_dict(), mname)

                checkpt_dir = f"{config.TRAINER.SAVE_DIR}/{args.exp_name}/checkpts/"
                os.makedirs(checkpt_dir, exist_ok=True)
                mname = os.path.join(checkpt_dir, f"model_{counter}.pt")
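                # NOTE: model is DDP-wrapped at this point, so the saved
                # state_dict keys carry the "module." prefix.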
                torch.save(model.state_dict(), mname)

            counter += 1

if __name__ == '__main__':
    main()

train.sh

PORT=${PORT:-29512}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# --nproc_per_node must match the number of GPUs; the comment lives up here
# because an inline comment after a trailing backslash breaks the continuation.
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
    --master_addr=$MASTER_ADDR \
    --master_port=$PORT \
    --nproc_per_node=2 \
    main_multii_conv2d.py \
    --cfg_path /mnt/sdb/xzq/occ_project/occ_nerf_st/src/config/train_tongfan_ngp.yaml \
    --num_epochs 50 \
    --num_gpus 2 \
    --num_nodes 1 \
    --batch_size 2048 \
    --test_batch_size 512 \
    --num_workers 2 \
    --exp_name models_20231207_nornn_2d_2_ori_theatmatvalid__st_v0_1bag_bsz4_rays1024_data_tfmap_newcxy_nextmask2_bevgrid_conf_adjustnearfar2

Note:

  1. Single-node multi-GPU training apparently does not need an explicit communication address/port.
  2. Only multi-node multi-GPU training needs them (see the Python sketch after the example below).

# single-node multi-GPU example
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
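
For the multi-node case, a minimal Python-side sketch of what init_method='env://' consumes (these are the standard torch.distributed environment variables; the localhost defaults are illustrative):

import os
import torch.distributed as dist

def init_distributed():
    # Single-node runs can fall back to localhost defaults; multi-node runs
    # must set MASTER_ADDR/MASTER_PORT on every node to the rank-0 machine.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29512")
    # RANK / WORLD_SIZE / LOCAL_RANK are injected per process by the launcher.
    dist.init_process_group(backend="nccl", init_method="env://")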

Old model original inference script - removing keys

"""
Copyright (C) 2020 NVIDIA Corporation.  All rights reserved.
Licensed under the NVIDIA Source Code License. See LICENSE at https://github.com/nv-tlabs/lift-splat-shoot.
Authors: Jonah Philion and Sanja Fidler
"""

import os
import torch
import numpy as np
from torch import nn
from collections import OrderedDict
from src.models_goe_1129_nornn_2d_2_ori import compile_model
# from src.models_goe_1129_nornn_2d_2_ori_flash import compile_model
from tensorboardX import SummaryWriter
# from src.data_tfmap_newcxy_ori import compile_data
from src.data_tfmap_newcxy_nextmask2 import compile_data
from src.tools import SimpleLoss, RegLoss, SegLoss, BCEFocalLoss, get_batch_iou, get_val_info, denormalize_img
import sys
import cv2
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ['RANK'] = "0"
os.environ['WORLD_SIZE'] = "1"
os.environ['MASTER_ADDR'] = "localhost"
os.environ['MASTER_PORT'] = "12332"
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import argparse
import open3d as o3d
import json
from src.config.defaults import get_cfg_defaults
from src.options import get_opts
from src.utils.visualization import extract_mesh, extract_mesh2, extract_alpha
from src.rendering.neuconw_helper import NeuconWHelper

pi = 3.1415926

def convert_rollyawpitch_to_rot(roll, yaw, pitch):
    roll *= pi/180.
    yaw *= pi/180.
    pitch *= pi/180.
    Rr = np.array([[0.0, -1.0, 0.0],
                   [0.0, 0.0, -1.0],
                   [1.0, 0.0, 0.0]], dtype=np.float32)
    Rx = np.array([[1.0, 0.0, 0.0],
                   [0.0, np.cos(roll), np.sin(roll)],
                   [0.0, -np.sin(roll), np.cos(roll)]], dtype=np.float32)
    Ry = np.array([[np.cos(pitch), 0.0, -np.sin(pitch)],
                   [0.0, 1.0, 0.0],
                   [np.sin(pitch), 0.0, np.cos(pitch)]], dtype=np.float32)
    Rz = np.array([[np.cos(yaw), np.sin(yaw), 0.0],
                   [-np.sin(yaw), np.cos(yaw), 0.0],
                   [0.0, 0.0, 1.0]], dtype=np.float32)
    R = np.matrix(Rr) * np.matrix(Rx) * np.matrix(Ry) * np.matrix(Rz)
    return R

def get_view_control(vis, idx):
    view_control = vis.get_view_control()
    if idx == 0:
        ### cam view
        # view_control.set_front([-1, 0, 0])
        # view_control.set_lookat([8, 0, 2])
        # view_control.set_up([0, 0, 1])
        # view_control.set_zoom(0.025)
        # view_control.rotate(0, 2100 / 40)

        ### bev observe object depth
        view_control.set_front([-1, 0, 1])
        view_control.set_lookat([30, 0, 0])
        view_control.set_up([0, 0, 1])
        view_control.set_zoom(0.3)
        view_control.rotate(0, 2100 / 20)

    elif idx == 1:
        view_control.set_front([-1, 0, 0])
        view_control.set_lookat([8, 0, 0])
        # view_control.set_lookat([8, 0, 2])  ### look down
        view_control.set_up([0, 0, 1])
        view_control.set_zoom(0.025)
        view_control.rotate(0, 2100 / 40)
    return view_control

def main():
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--local_rank", default = 0, type=int)
    # args = parser.parse_args()

    args = get_opts()
    config = get_cfg_defaults()
    config.merge_from_file(args.cfg_path)

    args.local_rank = 0  # only one GPU is visible (CUDA_VISIBLE_DEVICES="3"), so the local index must be 0
    print("local_rank:", args.local_rank)
    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)
        device=torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method='env://')
    

    # model_path = "/mnt/sdb/xzq/occ_project/occ_nerf_st/checkpoints/models_20231128_nornn_2d_2_ori_st_v0_1bag_bsz4_rays600_data_tfmap_newcxy_ori_theta_matiszero"
    # model_path = "/mnt/sdb/xzq/occ_project/occ_nerf_st/checkpoints/models_20231201_nornn_2d_2_ori_st_v0_1bag_bsz4_rays800_data_tfmap_newcxy_ori_theta_iszero_z6" # 单包, retrain 2d
    # model_path = "/home/algo/mnt/xzq/occ_project/occ_nerf_st/checkpoints/models_20231204_nornn_2d_2_ori_st_v0_10bag_bsz4_rays1024_data_tfmap_newcxy_ori_theta_iszero_z6_adjustnearfar" # 10包, retrain 2d
    # model_path = "/home/algo/mnt/xzq/occ_project/occ_nerf_st/checkpoints/models_20231204_nornn_2d_2_ori_flash_st_v0_1bag_bsz4_rays1024_data_tfmap_newcxy_ori_theta_iszero_z6_adjustnearfar_2"
    # model_path = "/mnt/sdb/xzq/occ_project/occ_nerf_st/checkpoints/models_20231205_nornn_2d_2_ori_st_v0_1bag_bsz4_rays1024_data_tfmap_newcxy_nextmask2_theta_iszero_bevgrid_conf_adjustnearfar2"
    model_path = "/mnt/sdb/xzq/occ_project/occ_nerf_st/checkpoints/models_20231207_nornn_2d_2_ori_st_v0_1bag_bsz4_rays1024_data_tfmap_newcxy_nextmask2_bevgrid_conf_adjustnearfar2"
    
    model_name = "model_32000.pt"
    ckpt_path = model_path + "/checkpts/" + model_name
    to_result_path = "result/" + model_path.split('/')[-1] + '/' + model_name.split('.')[0]
    viz_train = False
    viz_gnd = False
    viz_osr = True

    # xbound=[0.0, 96., 0.5]
    # ybound=[-12.0, 12.0, 0.5]
    # zbound=[-3.0, 5.0, 0.5]
    # dbound=[3.0, 103.0, 2.]

    # xbound=[0.0, 96., 0.5]
    # ybound=[-12.0, 12.0, 0.5]
    # zbound=[-2.0, 4.0, 1]
    # dbound=[3.0, 103.0, 2.]
    xbound=[0.0, 102., 0.85]
    ybound=[-10.0, 10.0, 0.5]
    zbound=[-2.0, 4.0, 1]
    dbound=[3.0, 103.0, 2.]


    bsz=1
    seq_len=5
    nworkers=1
    sample_num = 3200
    datatype = "single"    #multi   single

    version = "0"
    dataroot = "/data/zjj/data/aishare/share"
    # dataroot = "/run/user/1000/gvfs/sftp:host=192.168.1.40%20-p%2022/mnt/inspurfs/share-directory/defaultShare/aishare/share"


    torch.backends.cudnn.benchmark = True
    grid_conf = {
        'xbound': xbound,
        'ybound': ybound,
        'zbound': zbound,
        'dbound': dbound,
    }

    data_aug_conf = {
                'resize_lim': [(0.05, 0.4), (0.3, 0.90)],#(0.3-0.9)
                'final_dim': (128, 352),
                'rot_lim': (-5.4, 5.4),
                # 'H': H, 'W': W,
                'rand_flip': False,
                'bot_pct_lim': [(0.04, 0.35), (0.15, 0.4)],
                'cams': ['CAM_FRONT0', 'CAM_FRONT1'],
                'Ncams': 2,
            }

    # data_aug_conf = {
    #             'resize_lim': [(0.125, 0.125), (0.25, 0.25)],
    #             'final_dim': (128, 352),
    #             'rot_lim': (0, 0),
    #             'rand_flip': False,
    #             'bot_pct_lim': [(0.0, 0.051), (0.2, 0.2)],
    #             'cams': ['CAM_FRONT0', 'CAM_FRONT1'],
    #             'Ncams': 2,
    #     }
    
    train_sampler, val_sampler,trainloader, valloader = compile_data(version, dataroot, data_aug_conf=data_aug_conf,
					  grid_conf=grid_conf, bsz=bsz, seq_len=seq_len, sample_num=sample_num, nworkers=nworkers,
					  parser_name='segmentation1data', datatype=datatype)
    loader = trainloader if viz_train else valloader

    writer = SummaryWriter(logdir=None)
    model = compile_model(grid_conf, data_aug_conf, seq_len=seq_len, batchsize=int(bsz), config=config, args=args, writer=writer,phase='validation')
    checkpoint = torch.load(ckpt_path)
    new_state_dict = OrderedDict()
    for k, v in checkpoint.items():

        if "neuconw_helper" in k:
            name = k[22:]  # remove "neuconw_helper.module."
            # name = k[15:]  # remove "neuconw_helper."
            print(k, name)
            continue
        elif "module." in k:
            name = k[7:]  # remove "module."
            print(k)
        else:
            name = k
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict, True)
    model.to(device)
    num_gpus = torch.cuda.device_count()
    # if num_gpus > 1:
    #     model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
    #                                                          output_device=args.local_rank,find_unused_parameters=True)
    neuconw_helper = NeuconWHelper(args, config, model.neuconw, model.embedding_a, None)

    ww = 160
    hh = 480
    model.eval()
    fps = 30
    flourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    width = int(3715*300./1110)
    n_view = 2
    roi_num = 2
    osr_hh = int((width + ww * 6)/1853/2*1025)
    if viz_gnd:
        if viz_osr:
            out_shape = (width + ww * 6, hh + osr_hh)
        else:
            out_shape = (width + ww * 6, hh)
    else:
        if viz_osr:
            out_shape = (width + ww * 6, 1080)
        else:
            out_shape = (0, 0)

    colors = [(255, 255, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255)]
    # vis = o3d.visualization.Visualizer()
    # vis.create_window(window_name='bev')
    cur_sce_name = None
    
    count = 0
    with torch.no_grad():
        for batchi, (imgs, rots, trans, intrins, dist_coeffss, post_rots, post_trans, cam_pos_embeddings, binimgs, lf_label_gt, lf_norm_gt, fork_scales_gt, fork_offsets_gt, fork_oris_gt, rays, theta_mat_2d, theta_mat_3d, img_paths, sce_name) in enumerate(loader):
        # for batchi, (imgs, rots, trans, intrins, dist_coeffss, post_rots, post_trans, cam_pos_embeddings, binimgs, lf_label_gt, lf_norm_gt, fork_scales_gt, fork_offsets_gt, fork_oris_gt, rays, theta_mat_2d, theta_mat_3d,  sce_id_ind, idx, img_paths, sce_name) in enumerate(loader):
            if count==0:
                count += 1
                continue
            if sce_name[0] != cur_sce_name:
                sname = '_'.join(sce_name[0].split('/')[-6:-3])
                # output_path = model_path + "/result/" + model_name.split('.')[0] + "/" + sname + '_roi3'
                output_path = to_result_path + "/" + sname
                os.makedirs(output_path, exist_ok=True)
                to_video_path = output_path + "/demo_" + sname + "_train.mp4"
                print(to_video_path)
                to_occ_gt_dir = output_path + '/occ_gts/'
                to_mesh_dir = output_path + '/meshes/'
                to_occ_pred_dir = output_path + '/occ_preds/'
                to_img_dir = output_path + '/img_result/'
                # if cur_sce_name is not None:
                #     videoWriter.release()
                # videoWriter = cv2.VideoWriter(to_video_path, flourcc, fps, out_shape)
                os.makedirs(to_occ_gt_dir, exist_ok=True)
                os.makedirs(to_occ_pred_dir, exist_ok=True)
                os.makedirs(to_mesh_dir, exist_ok=True)
                os.makedirs(to_img_dir, exist_ok=True)
                cur_sce_name = sce_name[0]

            voxel_map_data = model(imgs.to(device), rots.to(device), trans.to(device), 
                                    intrins.to(device), dist_coeffss.to(device), post_rots.to(device), 
                                    post_trans.to(device), cam_pos_embeddings.to(device), fork_scales_gt.to(device),fork_offsets_gt.to(device),fork_oris_gt.to(device), 
                                    rays.to(device), theta_mat_2d.to(device), 0, 'validation')
            
            # voxel_map_data  =model(imgs.to(device),
            #                     rots.to(device),
            #                     trans.to(device),
            #                     intrins.to(device),
            #                     dist_coeffss.to(device),
            #                     post_rots.to(device),
            #                     post_trans.to(device),
            #                     cam_pos_embeddings.to(device),
            #                     fork_scales_gt.to(device),
            #                     fork_offsets_gt.to(device),
            #                     fork_oris_gt.to(device),
            #                     rays.to(device),
            #                     theta_mat_2d.to(device),
            #                     0,
            #                     'validation'
            #                     )

            output_img_merge = np.zeros((out_shape[1], out_shape[0], 3), dtype=np.uint8)
            if viz_gnd:
                print('viz_gnd')
                # norm_mask = (lf_norm_gt > -500)
                binimgs = binimgs.cpu().numpy()
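                # NOTE: the lf_preds/seg_preds below are not returned by the
                # model call above; this viz_gnd branch is stale and is only
                # skipped because viz_gnd = False.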
                lf_pred = lf_preds[:, :, :1].contiguous()
                lf_norm = lf_preds[:, :, 1:(1+4)].contiguous()

                seg_out = seg_preds.sigmoid()
                seg_out = seg_out.cpu().numpy()

                lf_out = lf_pred.sigmoid().cpu().numpy()
                lf_norm = lf_norm.cpu().numpy()

                H, W = 944, 1824
                fH, fW = data_aug_conf['final_dim']
                crop0 = []
                crop1 = []
                for cam_idx in range(2):
                    resize = np.mean(data_aug_conf['resize_lim'][cam_idx])
                    resize_dims = (int(fW / resize), int(fH / resize))
                    newfW, newfH = resize_dims
                    # print(newfW, newfH)
                    crop_h = int((1 - np.mean(data_aug_conf['bot_pct_lim'][cam_idx])) * H) - newfH
                    crop_w = int(max(0, W - newfW) / 2)
                    if cam_idx == 0:
                        crop0 = (crop_w, crop_h, crop_w + newfW, crop_h + newfH)
                    else:
                        crop1 = (crop_w, crop_h, crop_w + newfW, crop_h + newfH)

                si = seq_len - 1
                imgname = img_paths[si][0][img_paths[si][0].rfind('/')+1 :]
                print('imgname = ', img_paths[-si][0])
                img_org = cv2.imread(img_paths[si][0])

                imgpath = img_paths[si][0][: img_paths[si][0].rfind('org/')-1]
                param_path = imgpath + '/gen/param_infos.json'
                param_infos = {}
                with open(param_path, 'r') as ff :
                    param_infos = json.load(ff)
                yaw = param_infos['yaw']
                pitch = param_infos['pitch']
                if pitch == 0.789806:
                    pitch = -pitch
                roll = param_infos['roll']
                tran = np.array(param_infos['xyz'])

                H, W = param_infos['imgH_ori'], param_infos['imgW_ori']
                ori_K       = np.array(param_infos['ori_K'],dtype=np.float64).reshape(3,3)
                dist_coeffs = np.array(param_infos['dist_coeffs']).astype(np.float64)

                # cam2car_matrix
                rot = convert_rollyawpitch_to_rot(roll, yaw, pitch).I
                cam2car = np.eye(4, dtype= np.float64)
                cam2car[:3, :3] = rot
                cam2car[:3, 3] = tran.T

                norm = lf_norm[0, 4]
                fork = lf_out[0, 4]
                img_res = np.ones((480, 160, 3), dtype=np.uint8)
                colors = [(255, 255, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),(0, 255, 255)]
                for class_id in range(6):
                    result = seg_out[0][si][class_id]
                    if class_id == 5:
                        img_res[result> 0.4] = np.array(colors[class_id])
                    else:
                        img_res[result> 0.4] = np.array(colors[class_id])

                    ys, xs = np.where(result > 0.4)
                    pt = np.array([ys*0.2125, 0.125*xs-10, np.zeros(ys.shape), np.ones(ys.shape)])
                    if pt.shape[1] == 0:
                        continue
                    car2cam = np.matrix(cam2car).I.dot(pt)[:3, :]

                    rvec, tvec = np.array([0,0,0], dtype=np.float32), np.array([0,0,0], dtype=np.float32)
                    cam2img, _ = cv2.projectPoints(np.array(car2cam.T), rvec, tvec, ori_K, dist_coeffs)

                    for ii in range(cam2img.shape[0]):
                        ptx = round(cam2img[ii,0,0])
                        pty = round(cam2img[ii,0,1])
                        cv2.circle(img_org, (ptx, pty), 3, colors[class_id], -1)


                    # gt = binimgs[0][si][class_id]
                    # img_res[gt< -0.5] = np.array((128,128,128))
                img_res = cv2.flip(cv2.flip(img_res, 0), 1)

                img_gt = np.ones((480, 160, 3), dtype=np.uint8)
                for class_id in range(6):
                    result = binimgs[0][si][class_id]
                    img_gt[result> 0.5] = np.array(colors[class_id])
                    img_gt[result< -0.5] = np.array((128,128,128))


                img_gt = cv2.flip(cv2.flip(img_gt, 0), 1)

                cv2.rectangle(img_org, (int(crop0[0]), int(crop0[1])), (int(crop0[2]), int(crop0[3])), (0,255,255), 2)
                cv2.rectangle(img_org, (int(crop1[0]), int(crop1[1])), (int(crop1[2]), int(crop1[3])), (0,255,0), 2)
                img_org = cv2.resize(img_org, (width, hh))
                img_org_show = np.zeros((hh, width+ww*6, 3), dtype=np.uint8)*255
                img_org_show[:, ww*6:] = img_org

                outs = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                outs1 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                outs2 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                gts = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                gts1 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                gts2 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)

                ys, xs = np.where(lf_label_gt[0, si, 0] > -0.5)
                ys1, xs1 = np.where(lf_label_gt[0, si, 0] > 0.5)
                ys2, xs2 = np.where(lf_out[0, si, 0] > 0.5)


                gts[si][binimgs[0, si, 0] > 0.5] = np.array(colors[0])
                outs[si][seg_out[0, si, 0] > 0.5] = np.array(colors[0])

                gts[si][binimgs[0, si, 4] > 0.6] = np.array(colors[4])
                outs[si][seg_out[0, si, 4] > 0.6] = np.array(colors[4])

                gts[si][binimgs[0, si, 5] > 0.6] = np.array(colors[5])
                outs[si][seg_out[0, si, 5] > 0.6] = np.array(colors[5])

                valid_mask = np.sum(gts[si], axis=-1) > 0
                labels = np.where(valid_mask[ys, xs]> 0.5)
                ys = ys[labels]
                xs = xs[labels]
                gts1[si][ys1, xs1, :] = 255

                mask = torch.squeeze(lf_norm_gt[:,si,0])
                # gts2[si][mask < -500] = (128, 128, 128)
                if xs.shape[0] > 0:
                    for mm in range(0, xs.shape[0], 2):
                        # for mm in range(0, 800, 100):
                        y = ys[mm]
                        x = xs[mm]
                        norm = lf_norm_gt[0, si, 0:2, y, x].numpy()
                        if norm[0] == -999.:
                            continue
                        cv2.line(gts2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (0, 255, 0),1)
                        norm = lf_norm_gt[0, si, 2:4, y, x].numpy()
                        cv2.line(gts2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (255, 0, 0),1)
                        # print (norm)
                        # cv2.circle(gts2[si], (x, y), 3, (0, 255, 255))


                # ys, xs = np.where(np.logical_or(seg_out[0][si][0] > 0.5, seg_out[0][si][5] > 0.5))
                # ys, xs = np.where(np.logical_or(seg_out[0][si][0] > -0.5, seg_out[0][si][5] > -0.5))
                valid_mask = np.sum(outs[si], axis=-1) > 0
                labels = np.where(valid_mask[ys, xs]> 0.5)
                ys = ys[labels]
                xs = xs[labels]
                outs1[si][ys2, xs2, :] = 255
                if xs.shape[0] > 0:
                    for mm in range(0, xs.shape[0], 2):
                        y = ys[mm]
                        x = xs[mm]
                        norm = lf_norm[0, si, 0:2, y, x] / 5.
                        # print (norm)
                        cv2.line(outs2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (0, 255, 0),1)
                        norm = lf_norm[0, si, 2:4, y, x] / 5.
                        cv2.line(outs2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (255, 0, 0),1)

                # gts2[si][lf_label_gt[0, si, 0] < -0.5] = (128,128,128)
                # gts1[si][lf_label_gt[0, si, 0] < -0.5] = (128,128,128)

                img_org_show[:, :ww] = img_res
                img_org_show[:, ww:ww*2] = img_gt
                img_org_show[:, ww*2:ww*3] = cv2.flip(cv2.flip(outs2[si], 0), 1)
                img_org_show[:, ww*3:ww*4] = cv2.flip(cv2.flip(gts2[si], 0), 1)
                img_org_show[:, ww*4:ww*5] = cv2.flip(cv2.flip(outs1[si], 0), 1)
                img_org_show[:, ww*5:ww*6] = cv2.flip(cv2.flip(gts1[si], 0), 1)

                cv2.putText(img_org_show, "NAME:" + imgname + 'seq_id: '+ str(si), (700+320, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
                # print(idxs)

                output_img_merge[:img_org_show.shape[0], :] = img_org_show


            if viz_osr:
                si = seq_len - 1
                imgname = img_paths[si][0][img_paths[si][0].rfind('/')+1 :]
                # print('imgname = ', img_paths[-si][0])
                output_img = np.zeros((1025, 1853*2, 3), dtype=np.uint8)
                to_occ_gt_path = to_occ_gt_dir + imgname.replace('.jpg', '.ply')
                to_occ_pred_path = to_occ_pred_dir + imgname.replace('.jpg', '.ply')
                to_mesh_path = to_mesh_dir + imgname.replace('.jpg', '.ply')
                to_img_path = to_img_dir + imgname
                to_bin_path = to_img_dir + imgname.replace('.jpg', '.bin')
                idx = rays[0, si, :, 15] < 1

                pts_gt = rays[0, si, idx, 0:3] + rays[0, si, idx, 3:6]*rays[0, si, idx, 9:10]  # gt_pts
                semantic_gt = rays[0, si, idx, 8].view(-1,1)

                # pts = rays_all[si][0, :, :3] + rays_all[si][0, :, 3:6] * rays_all[si][0, :, 9:10]
                # semantic_gt = rays_all[si][0, :, 9:10]
                # np.save(to_occ_gt_path, np.concatenate([pts, semantic_gt], axis=1))

                pcd_gt = o3d.geometry.PointCloud()
                pcd_gt.points = o3d.utility.Vector3dVector(pts_gt.numpy())
                pcd_gt.paint_uniform_color([0, 1, 0])  # green
                o3d.io.write_point_cloud(to_occ_gt_path, pcd_gt)

                voxel_map = {
                    "origin": (model.bx - model.dx / 2).to(device),
                    "size": (model.dx * (model.nx - 1)).to(device),
                    "dx": model.dx.to(device),
                    # "origin": (model_bx - model_dx / 2).to(device),
                    # "size": (model_dx * (model_nx - 1)).to(device),
                    # "dx": model_dx.to(device),
                    "data": voxel_map_data[0][si:si + 1, ...],
                    "all_rays": rays[0, si:si + 1, :, :].view(-1, rays.shape[-1]).to(device),
                    "rots": rots[0, si * roi_num:si * roi_num + 1, ...],
                    "trans": trans[0, si * roi_num:si * roi_num + 1, ...],
                    "intrins": intrins[0, si * roi_num:si * roi_num + 1, ...],
                    "post_rots": post_rots[0, si * roi_num:si * roi_num + 1, ...],
                    "post_trans": post_trans[0, si * roi_num:si * roi_num + 1, ...],
                    # "valid_mask": valid_mask_coo[si:si + 1, ...]
                }
                if 1:
                    all_rays = rays[0,si,idx,:].view(-1,rays.shape[-1]).to(device)  # select which frame's rays get rendered
                    sample = {
                        "rays": torch.cat(
                            (all_rays[:, :8], all_rays[:, 9:11],all_rays[:, 15:17]), dim=-1
                        ),
                        "ts": all_rays[:,17],       # delta_t
                        # "ts": torch.ones_like(all_rays[:, -1]).long()*0.,
                        "rgbs": all_rays[:, -3:],     # 索引错的,但是不影响--rgb loss没用上
                        "semantics": all_rays[:, 8],
                    }
                    # pts_generate, depth_loss = neuconw_helper.generate_depth(sample, voxel_map, 0, args.local_rank)  # predicted points from the rendered depth
                    # print(">>>>>>>>>>>>>>depth_loss:",depth_loss.mean())
                    # if depth_loss.mean() > 0.2 : print('--imgname--', imgname)
                    # # depth_loss_mean_list.append(depth_loss.mean().detach().cpu().numpy())
                    # # count_list.append(count)

                    # pts_pred = o3d.geometry.PointCloud()
                    # pts_pred.points = o3d.utility.Vector3dVector(np.array(pts_generate.detach().cpu().numpy()))
                    # pts_pred.paint_uniform_color([0, 0, 1])

                    # idx_high_loss = np.where(depth_loss.cpu().numpy()>1.25)  #>0.5
                    # idx_mid_loss = np.where((depth_loss.cpu().numpy()>0.2)*(depth_loss.cpu().numpy()<=1.25))  #0.2~0.5
                    # idx_low_loss = np.where(depth_loss.cpu().numpy()<0.2)   #<0.2
                    # # idx_lower_loss = np.where(depth_loss.cpu().numpy()<0.2)   #<0.2

                    # np.asarray(pts_pred.colors)[idx_high_loss, :] = [1, 0, 0]
                    # np.asarray(pts_pred.colors)[idx_mid_loss, :] = [1, 1, 0]
                    # np.asarray(pts_pred.colors)[idx_low_loss, :] = [0, 1, 0]

                    # # o3d.io.write_point_cloud(
                    # #     f"/home/algo/1/1/debug_pts_gen_car_" + imgname.split('.jpg')[0] + ".ply", pts_pred)
                    # o3d.io.write_point_cloud(os.path.join(to_occ_pred_dir + imgname.replace('.jpg', '_pred.ply')), pts_pred)

                if 1:
                    out_info = extract_alpha(
                        voxel_map, dim=512,  # np.int(np.round(self.scene_config["radius"]/(3**(1/3))/0.1))
                        # chunk=16384,
                        chunk=8192,
                        with_color=False,
                        embedding_a=neuconw_helper.embedding_a((torch.ones(1).cuda() * 1).long()),
                        renderer=neuconw_helper.renderer
                    )

                    # mesh, out_info = extract_mesh2(voxel_map, renderer=neuconw_helper.renderer)
                    np.save(to_occ_pred_path, out_info)

                    # mesh.export(to_mesh_path)
                    # mesh = o3d.geometry.TriangleMesh(vertices=o3d.utility.Vector3dVector(
                    # mesh.vertices.copy()),
                    # triangles=o3d.utility.Vector3iVector(
                    #     mesh.faces.copy()))
                    # mesh.compute_vertex_normals()

                    # for idx_v in range(n_view):
                    #     if idx_v == 0:
                    #         vis.add_geometry(mesh, True)
                    #         vis.add_geometry(pcd_gt, True)
                    #     else:
                    #         vis.add_geometry(mesh, True)

                    #     view_control = get_view_control(vis, idx_v)
                    #     vis.poll_events()
                    #     vis.update_renderer()
                    #     # vis.run()
                    #     mesh_capture_img = vis.capture_screen_float_buffer(True)
                    #     vis.clear_geometries()
                    #     mesh_capture_img = np.array(np.asarray(mesh_capture_img)[..., ::-1] * 255, dtype=np.uint8)
                    #     output_img[:, mesh_capture_img.shape[1] * idx_v:mesh_capture_img.shape[1] * (idx_v + 1),:] = mesh_capture_img
                    #     output_img_resize = cv2.resize(output_img, (out_shape[0], osr_hh))
                    #     output_img_merge[hh:, :] = output_img_resize

            cv2.imwrite(to_img_path, output_img_merge)
            # videoWriter.write(output_img_merge)
            # c = cv2.waitKey(1)%0x100
            # if c == 27:
            #     break
            print(1)
            count += 1


if __name__ == '__main__':
    main()


Old model - loading with mmcv load_checkpoint

"""
Copyright (C) 2020 NVIDIA Corporation.  All rights reserved.
Licensed under the NVIDIA Source Code License. See LICENSE at https://github.com/nv-tlabs/lift-splat-shoot.
Authors: Jonah Philion and Sanja Fidler
"""

import os
from pathlib import Path 
from collections import OrderedDict
import numpy as np
import torch
# from src.models_goe_1129_nornn_2d_2 import compile_model
from src.models_goe_1129_nornn_v8 import compile_model
from src.data_tfmap_newcxy_ori import compile_data
# from src.data_tfmap_newcxy_nextmask2 import compile_data
import cv2

import open3d as o3d
import json
from src.config.defaults import get_cfg_defaults
from src.options import get_opts
from src.utils.visualization import  extract_alpha
from src.rendering.neuconw_helper import NeuconWHelper

from mmcv.runner import load_checkpoint

"  推理关闭数据层train_sampler --  # train_sampler = val_sampler = None"


os.environ["CUDA_VISIBLE_DEVICES"] = "4"
os.environ['RANK'] = "0"
os.environ['WORLD_SIZE'] = "1"
os.environ['MASTER_ADDR'] = "localhost"
os.environ['MASTER_PORT'] = "12331"
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

pi = 3.1415926

def convert_rollyawpitch_to_rot(roll, yaw, pitch):
    roll *= pi/180.
    yaw *= pi/180.
    pitch *= pi/180.
    Rr = np.array([[0.0, -1.0, 0.0],
                   [0.0, 0.0, -1.0],
                   [1.0, 0.0, 0.0]], dtype=np.float32)
    Rx = np.array([[1.0, 0.0, 0.0],
                   [0.0, np.cos(roll), np.sin(roll)],
                   [0.0, -np.sin(roll), np.cos(roll)]], dtype=np.float32)
    Ry = np.array([[np.cos(pitch), 0.0, -np.sin(pitch)],
                   [0.0, 1.0, 0.0],
                   [np.sin(pitch), 0.0, np.cos(pitch)]], dtype=np.float32)
    Rz = np.array([[np.cos(yaw), np.sin(yaw), 0.0],
                   [-np.sin(yaw), np.cos(yaw), 0.0],
                   [0.0, 0.0, 1.0]], dtype=np.float32)
    R = np.matrix(Rr) * np.matrix(Rx) * np.matrix(Ry) * np.matrix(Rz)
    return R

def get_view_control(vis, idx):
    view_control = vis.get_view_control()
    if idx == 0:
        ### cam view
        # view_control.set_front([-1, 0, 0])
        # view_control.set_lookat([8, 0, 2])
        # view_control.set_up([0, 0, 1])
        # view_control.set_zoom(0.025)
        # view_control.rotate(0, 2100 / 40)

        ### bev observe object depth
        view_control.set_front([-1, 0, 1])
        view_control.set_lookat([30, 0, 0])
        view_control.set_up([0, 0, 1])
        view_control.set_zoom(0.3)
        view_control.rotate(0, 2100 / 20)

    elif idx == 1:
        view_control.set_front([-1, 0, 0])
        view_control.set_lookat([8, 0, 0])
        # view_control.set_lookat([8, 0, 2])  ### look down
        view_control.set_up([0, 0, 1])
        view_control.set_zoom(0.025)
        view_control.rotate(0, 2100 / 40)
    return view_control

def main():
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--local_rank", default = 0, type=int)
    # args = parser.parse_args()

    args = get_opts()
    config = get_cfg_defaults()
    config.merge_from_file(args.cfg_path)

    args.local_rank = 0  # only one GPU is visible (CUDA_VISIBLE_DEVICES="4"), so the local index must be 0
    print("local_rank:", args.local_rank)
    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)
        device=torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method='env://')

    # model_path = "/mnt/sdb/xzq/occ_project/occ_nerf_st/checkpoints/models_20231128_nornn_2d_2_st_v0_1bag_bsz4_rays800_data_tfmap_newcxy_ori"
    model_path = "/mnt/sdb/xzq/occ_project/occ_nerf_st/checkpoints/models_20231128_nornn_2d_2_st_v0_10bag_bsz4_rays800"
    # model_path = "/home/algo/mnt/xzq/occ_project/occ_nerf_st/checkpoints/nerf_1204_nornn_v8_st_pretrain_data_tfmap_newcxy_nextmask2_1bag_adjustnearfar_newcondition"  # adjust_nearfar1
    
    model_name = "model_20000.pt"
    ckpt_path = model_path + "/checkpts/" + model_name

    to_result_path = "result/" + model_path.split('/')[-1] + '/' + model_name.split('.')[0] + '_p2'

    viz_train = False
    viz_gnd = False
    viz_osr = True


    bsz=1
    seq_len=5
    nworkers=6
    sample_num = 512
    datatype = "single"    #multi   single

    version = "0"
    # dataroot = "/home/algo/dataSpace/NeRF/bev_ground/data/aishare/share"
    #dataroot='/defaultShare/user-data'
    dataroot = "/data/zjj/data/aishare/share"

    xbound=[0.0, 96., 0.5]
    ybound=[-12.0, 12.0, 0.5]
    zbound=[-3.0, 5.0, 0.5]
    dbound=[3.0, 103.0, 2.]
    grid_conf = {
        'xbound': xbound,
        'ybound': ybound,
        'zbound': zbound,
        'dbound': dbound,
    }

    data_aug_conf = {
                'resize_lim': [(0.05, 0.4), (0.3, 0.90)],#(0.3-0.9)
                'final_dim': (128, 352),
                'rot_lim': (-5.4, 5.4),
                # 'H': H, 'W': W,
                'rand_flip': False,
                'bot_pct_lim': [(0.04, 0.35), (0.15, 0.4)],
                # 'bot_pct_lim': [(0.04, 0.35), (0.4, 0.4)],
                'cams': ['CAM_FRONT0', 'CAM_FRONT1'],
                'Ncams': 2,
            }


    train_sampler, val_sampler,trainloader, valloader = compile_data(version, dataroot, data_aug_conf=data_aug_conf,
                      grid_conf=grid_conf, bsz=bsz, seq_len=seq_len, sample_num=sample_num, nworkers=nworkers,
                      parser_name='segmentation1data', datatype=datatype)
    loader = trainloader if viz_train else valloader

    model = compile_model(grid_conf, data_aug_conf, seq_len=seq_len, batchsize=int(bsz), config=config, args=args, phase='validation')
    checkpoint = load_checkpoint(model, ckpt_path, map_location='cpu')

# ---- manual alternative to load_checkpoint: strips the DDP "module." prefix ----
#     checkpoint = torch.load(ckpt_path)
#     new_state_dict = OrderedDict()
#     for k, v in checkpoint.items():

#         if "neuconw_helper" in k:
#             # name = k[22:]  # remove "neuconw_helper.module."
#             name = k[15:]  # remove "neuconw_helper."
#             print(k, name)
#             continue
#         elif "module." in k:
#             name = k[7:]  # remove "module."
#             print(k)
#         else:
#             name = k
#         new_state_dict[name] = v

#     model.load_state_dict(new_state_dict, True)
# #------------------------------
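    # (strip_prefix above is a compact version of this remapping; load_checkpoint
    # makes it unnecessary by matching keys adaptively.)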

    
    model.to(device)
    neuconw_helper = NeuconWHelper(args, config, model.neuconw, model.embedding_a, None)

    ww = 160
    hh = 480
    model.eval()
    fps = 30
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    width = int(3715*300./1110)
    n_view = 2
    roi_num = 2
    osr_hh = int((width + ww * 6)/1853/2*1025)
    if viz_gnd:
        if viz_osr:
            out_shape = (width + ww * 6, hh + osr_hh)
        else:
            out_shape = (width + ww * 6, hh)
    else:
        if viz_osr:
            out_shape = (width + ww * 6, 1080)
        else:
            out_shape = (0, 0)

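    # Canvas layout implied by the arithmetic above: six ww-wide BEV strips
    # (prediction/GT pairs) on the left, the resized camera image to their right,
    # and, when viz_osr is on, room for the Open3D capture underneath (that write
    # is commented out near the end of the loop).
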
    colors = [(255, 255, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255)]
    # vis = o3d.visualization.Visualizer()
    # vis.create_window(window_name='bev')
    cur_sce_name = None

    count = 0
    with torch.no_grad():

        for batchi, (imgs, rots, trans, intrins, dist_coeffss, post_rots, post_trans, cam_pos_embeddings, binimg, lf_label,   lf_norm,   fork_scale,    fork_offset, fork_ori, rays, pose_mats_2d, pose_mats_3d, img_paths, sce_name) in enumerate(valloader):

            if sce_name[0] != cur_sce_name:
                sname = '_'.join(sce_name[0].split('/')[-6:-3])
                # output_path = model_path + "/result/" + model_name.split('.')[0] + "/" + sname + '_roi3'
                output_path = to_result_path + "/" + sname
                os.makedirs(output_path, exist_ok=True)
                to_video_path = output_path + "/demo_" + sname + "_train.mp4"
                print(to_video_path)
                to_occ_gt_dir = output_path + '/occ_gts/'
                to_mesh_dir = output_path + '/meshes/'
                to_occ_pred_dir = output_path + '/occ_preds/'
                to_img_dir = output_path + '/img_result/'
                # if cur_sce_name is not None:
                #     videoWriter.release()
                # videoWriter = cv2.VideoWriter(to_video_path, fourcc, fps, out_shape)
                os.makedirs(to_occ_gt_dir, exist_ok=True)
                os.makedirs(to_occ_pred_dir, exist_ok=True)
                os.makedirs(to_mesh_dir, exist_ok=True)
                os.makedirs(to_img_dir, exist_ok=True)
                cur_sce_name = sce_name[0]

            voxel_map_data = model(imgs.to(device),
                                rots.to(device),
                                trans.to(device),
                                intrins.to(device),
                                dist_coeffss.to(device),
                                post_rots.to(device),
                                post_trans.to(device),
                                cam_pos_embeddings.to(device),
                                fork_scale.to(device),
                                fork_offset.to(device),
                                fork_ori.to(device),
                                rays,
                                pose_mats_2d.to(device),
                                0,
                                'validation'
                                )
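
            # voxel_map_data holds the per-frame voxel feature grid predicted by the
            # network; it is consumed below as voxel_map["data"] by the NeuconW renderer.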

            output_img_merge = np.zeros((out_shape[1], out_shape[0], 3), dtype=np.uint8)
            if viz_gnd:
                print('viz_gnd')
                # NOTE: dead by default (viz_gnd = False above). This branch expects
                # seg_preds / lf_preds / lf_label_gt / lf_norm_gt from the older
                # segmentation head; they are not produced anywhere in this listing.
                # norm_mask = (lf_norm_gt > -500)
                binimgs = binimg.cpu().numpy()
                lf_pred = lf_preds[:, :, :1].contiguous()
                lf_norm = lf_preds[:, :, 1:(1+4)].contiguous()

                seg_out = seg_preds.sigmoid()
                seg_out = seg_out.cpu().numpy()

                lf_out = lf_pred.sigmoid().cpu().numpy()
                lf_norm = lf_norm.cpu().numpy()

                H, W = 944, 1824
                fH, fW = data_aug_conf['final_dim']
                crop0 = []
                crop1 = []
                for cam_idx in range(2):
                    resize = np.mean(data_aug_conf['resize_lim'][cam_idx])
                    resize_dims = (int(fW / resize), int(fH / resize))
                    newfW, newfH = resize_dims
                    # print(newfW, newfH)
                    crop_h = int((1 - np.mean(data_aug_conf['bot_pct_lim'][cam_idx])) * H) - newfH
                    crop_w = int(max(0, W - newfW) / 2)
                    if cam_idx == 0:
                        crop0 = (crop_w, crop_h, crop_w + newfW, crop_h + newfH)
                    else:
                        crop1 = (crop_w, crop_h, crop_w + newfW, crop_h + newfH)

                si = seq_len - 1
                imgname = img_paths[si][0][img_paths[si][0].rfind('/')+1 :]
                print('imgname = ', img_paths[si][0])
                img_org = cv2.imread(img_paths[si][0])

                imgpath = img_paths[si][0][: img_paths[si][0].rfind('org/')-1]
                param_path = imgpath + '/gen/param_infos.json'
                param_infos = {}
                with open(param_path, 'r') as ff :
                    param_infos = json.load(ff)
                yaw = param_infos['yaw']
                pitch = param_infos['pitch']
                if pitch == 0.789806:
                    pitch = -pitch
                roll = param_infos['roll']
                tran = np.array(param_infos['xyz'])

                H, W = param_infos['imgH_ori'], param_infos['imgW_ori']
                ori_K       = np.array(param_infos['ori_K'],dtype=np.float64).reshape(3,3)
                dist_coeffs = np.array(param_infos['dist_coeffs']).astype(np.float64)

                # cam2car_matrix
                rot = convert_rollyawpitch_to_rot(roll, yaw, pitch).I
                cam2car = np.eye(4, dtype= np.float64)
                cam2car[:3, :3] = rot
                cam2car[:3, 3] = tran.T

                norm = lf_norm[0, 4]
                fork = lf_out[0, 4]
                img_res = np.ones((480, 160, 3), dtype=np.uint8)
                colors = [(255, 255, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),(0, 255, 255)]
                for class_id in range(6):
                    result = seg_out[0][si][class_id]
                    img_res[result > 0.4] = np.array(colors[class_id])

                    ys, xs = np.where(result > 0.4)
                    pt = np.array([ys*0.2125, 0.125*xs-10, np.zeros(ys.shape), np.ones(ys.shape)])
                    if pt.shape[1] == 0:
                        continue
                    car2cam = np.matrix(cam2car).I.dot(pt)[:3, :]

                    rvec, tvec = np.array([0,0,0], dtype=np.float32), np.array([0,0,0], dtype=np.float32)
                    cam2img, _ = cv2.projectPoints(np.array(car2cam.T), rvec, tvec, ori_K, dist_coeffs)

                    for ii in range(cam2img.shape[0]):
                        ptx = round(cam2img[ii,0,0])
                        pty = round(cam2img[ii,0,1])
                        cv2.circle(img_org, (ptx, pty), 3, colors[class_id], -1)


                    # gt = binimgs[0][si][class_id]
                    # img_res[gt< -0.5] = np.array((128,128,128))
                img_res = cv2.flip(cv2.flip(img_res, 0), 1)

                img_gt = np.ones((480, 160, 3), dtype=np.uint8)
                for class_id in range(6):
                    result = binimgs[0][si][class_id]
                    img_gt[result> 0.5] = np.array(colors[class_id])
                    img_gt[result< -0.5] = np.array((128,128,128))


                img_gt = cv2.flip(cv2.flip(img_gt, 0), 1)

                cv2.rectangle(img_org, (int(crop0[0]), int(crop0[1])), (int(crop0[2]), int(crop0[3])), (0,255,255), 2)
                cv2.rectangle(img_org, (int(crop1[0]), int(crop1[1])), (int(crop1[2]), int(crop1[3])), (0,255,0), 2)
                img_org = cv2.resize(img_org, (width, hh))
                img_org_show = np.zeros((hh, width + ww*6, 3), dtype=np.uint8)
                img_org_show[:, ww*6:] = img_org

                outs = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                outs1 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                outs2 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                gts = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                gts1 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)
                gts2 = np.zeros((seq_len, hh, ww, 3), dtype=np.uint8)

                ys, xs = np.where(lf_label_gt[0, si, 0] > -0.5)
                ys1, xs1 = np.where(lf_label_gt[0, si, 0] > 0.5)
                ys2, xs2 = np.where(lf_out[0, si, 0] > 0.5)


                gts[si][binimgs[0, si, 0] > 0.5] = np.array(colors[0])
                outs[si][seg_out[0, si, 0] > 0.5] = np.array(colors[0])

                gts[si][binimgs[0, si, 4] > 0.6] = np.array(colors[4])
                outs[si][seg_out[0, si, 4] > 0.6] = np.array(colors[4])

                gts[si][binimgs[0, si, 5] > 0.6] = np.array(colors[5])
                outs[si][seg_out[0, si, 5] > 0.6] = np.array(colors[5])

                valid_mask = np.sum(gts[si], axis=-1) > 0
                labels = np.where(valid_mask[ys, xs]> 0.5)
                ys = ys[labels]
                xs = xs[labels]
                gts1[si][ys1, xs1, :] = 255

                mask = torch.squeeze(lf_norm_gt[:,si,0])
                # gts2[si][mask < -500] = (128, 128, 128)
                if xs.shape[0] > 0:
                    for mm in range(0, xs.shape[0], 2):
                        # for mm in range(0, 800, 100):
                        y = ys[mm]
                        x = xs[mm]
                        norm = lf_norm_gt[0, si, 0:2, y, x].numpy()
                        if norm[0] == -999.:
                            continue
                        cv2.line(gts2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (0, 255, 0),1)
                        norm = lf_norm_gt[0, si, 2:4, y, x].numpy()
                        cv2.line(gts2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (255, 0, 0),1)
                        # print (norm)
                        # cv2.circle(gts2[si], (x, y), 3, (0, 255, 255))


                # ys, xs = np.where(np.logical_or(seg_out[0][si][0] > 0.5, seg_out[0][si][5] > 0.5))
                # ys, xs = np.where(np.logical_or(seg_out[0][si][0] > -0.5, seg_out[0][si][5] > -0.5))
                valid_mask = np.sum(outs[si], axis=-1) > 0
                labels = np.where(valid_mask[ys, xs]> 0.5)
                ys = ys[labels]
                xs = xs[labels]
                outs1[si][ys2, xs2, :] = 255
                if xs.shape[0] > 0:
                    for mm in range(0, xs.shape[0], 2):
                        y = ys[mm]
                        x = xs[mm]
                        norm = lf_norm[0, si, 0:2, y, x] / 5.
                        # print (norm)
                        cv2.line(outs2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (0, 255, 0),1)
                        norm = lf_norm[0, si, 2:4, y, x] / 5.
                        cv2.line(outs2[si], (x, y), (x+int(round((norm[1]+1)*100)), y+int(0.5*round(norm[0]*-100))), (255, 0, 0),1)

                # gts2[si][lf_label_gt[0, si, 0] < -0.5] = (128,128,128)
                # gts1[si][lf_label_gt[0, si, 0] < -0.5] = (128,128,128)

                img_org_show[:, :ww] = img_res
                img_org_show[:, ww:ww*2] = img_gt
                img_org_show[:, ww*2:ww*3] = cv2.flip(cv2.flip(outs2[si], 0), 1)
                img_org_show[:, ww*3:ww*4] = cv2.flip(cv2.flip(gts2[si], 0), 1)
                img_org_show[:, ww*4:ww*5] = cv2.flip(cv2.flip(outs1[si], 0), 1)
                img_org_show[:, ww*5:ww*6] = cv2.flip(cv2.flip(gts1[si], 0), 1)

                cv2.putText(img_org_show, "NAME:" + imgname + 'seq_id: '+ str(si), (700+320, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
                # print(idxs)

                output_img_merge[:img_org_show.shape[0], :] = img_org_show


            if viz_osr:
                # si = seq_len - 1
                si = 0
                imgname = img_paths[si][0][img_paths[si][0].rfind('/')+1 :]
                # print('imgname = ', img_paths[-si][0])
                output_img = np.zeros((1025, 1853*2, 3), dtype=np.uint8)
                to_occ_gt_path = to_occ_gt_dir + imgname.replace('.jpg', '.ply')
                to_occ_pred_path = to_occ_pred_dir + imgname.replace('.jpg', '.ply')
                to_mesh_path = to_mesh_dir + imgname.replace('.jpg', '.ply')
                to_img_path = to_img_dir + imgname
                to_bin_path = to_img_dir + imgname.replace('.jpg', '.bin')
                idx = rays[0, si, :, 15] < 1

                pts_gt = rays[0, si, idx, 0:3] + rays[0, si, idx, 3:6]*rays[0, si, idx, 9:10]  # gt_pts
                semantic_gt = rays[0, si, idx, 8].view(-1,1)

                # pts = rays_all[si][0, :, :3] + rays_all[si][0, :, 3:6] * rays_all[si][0, :, 9:10]
                # semantic_gt = rays_all[si][0, :, 9:10]
                # np.save(to_occ_gt_path, np.concatenate([pts, semantic_gt], axis=1))

                pcd_gt = o3d.geometry.PointCloud()
                pcd_gt.points = o3d.utility.Vector3dVector(pts_gt.numpy())
                pcd_gt.paint_uniform_color([0, 1, 0])  # green
                o3d.io.write_point_cloud(to_occ_gt_path, pcd_gt)

                voxel_map = {
                    "origin": (model.bx - model.dx / 2).to(device),
                    "size": (model.dx * (model.nx - 1)).to(device),
                    "dx": model.dx.to(device),
                    # "origin": (model_bx - model_dx / 2).to(device),
                    # "size": (model_dx * (model_nx - 1)).to(device),
                    # "dx": model_dx.to(device),
                    "data": voxel_map_data[0][si:si + 1, ...],
                    "all_rays": rays[0, si:si + 1, :, :].view(-1, rays.shape[-1]).to(device),
                    "rots": rots[0, si * roi_num:si * roi_num + 1, ...],
                    "trans": trans[0, si * roi_num:si * roi_num + 1, ...],
                    "intrins": intrins[0, si * roi_num:si * roi_num + 1, ...],
                    "post_rots": post_rots[0, si * roi_num:si * roi_num + 1, ...],
                    "post_trans": post_trans[0, si * roi_num:si * roi_num + 1, ...],
                    # "valid_mask": valid_mask_coo[si:si + 1, ...]
                }
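                # In the dict above, origin = bx - dx/2 is the minimum corner of the
                # voxel grid and size = dx*(nx-1) its extent, which lets the renderer
                # map world-frame points to voxel indices.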
                all_rays = rays[0, si, idx, :].view(-1, rays.shape[-1]).to(device)  # render rays of the selected frame
                sample = {
                    "rays": torch.cat(
                        (all_rays[:, :8], all_rays[:, 9:11],all_rays[:, 15:17]), dim=-1
                    ),
                    "ts": all_rays[:,17],       # delta_t
                    # "ts": torch.ones_like(all_rays[:, -1]).long()*0.,
                    "rgbs": all_rays[:, -3:],     # 索引错的,但是不影响--rgb loss没用上
                    "semantics": all_rays[:, 8],
                }
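                # Ray tensor layout as used above (inferred from the indexing):
                # cols 0:3 origin, 3:6 direction, 8 semantic label, 9 depth,
                # 15 type flag, 17 delta_t, and the last 3 columns rgb.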
                # pts_generate, depth_loss = neuconw_helper.generate_depth(sample, voxel_map, 0, args.local_rank)  # predicted points from the rendered depth
                # print(">>>>>>>>>>>>>>depth_loss:",depth_loss.mean())
                # if depth_loss.mean() > 0.2 : print('--imgname--', imgname)
                # # depth_loss_mean_list.append(depth_loss.mean().detach().cpu().numpy())
                # # count_list.append(count)

                # pts_pred = o3d.geometry.PointCloud()
                # pts_pred.points = o3d.utility.Vector3dVector(np.array(pts_generate.detach().cpu().numpy()))
                # pts_pred.paint_uniform_color([0, 0, 1]) 

                # idx_high_loss = np.where(depth_loss.cpu().numpy()>1.25)  #>0.5
                # idx_mid_loss = np.where((depth_loss.cpu().numpy()>0.2)*(depth_loss.cpu().numpy()<=1.25))  #0.2~0.5
                # idx_low_loss = np.where(depth_loss.cpu().numpy()<0.2)   #<0.2
                # # idx_lower_loss = np.where(depth_loss.cpu().numpy()<0.2)   #<0.2

                # np.asarray(pts_pred.colors)[idx_high_loss, :] = [1, 0, 0]
                # np.asarray(pts_pred.colors)[idx_mid_loss, :] = [1, 1, 0]
                # np.asarray(pts_pred.colors)[idx_low_loss, :] = [0, 1, 0]

                # # o3d.io.write_point_cloud(
                # #     f"/home/algo/1/1/debug_pts_gen_car_" + imgname.split('.jpg')[0] + ".ply", pts_pred)
                # o3d.io.write_point_cloud(os.path.join(to_occ_pred_dir + imgname.replace('.jpg', '_pred.ply')), pts_pred)

                if 1:  # always-on toggle kept from debugging
                    out_info = extract_alpha(
                        voxel_map, dim=512,  # np.int(np.round(self.scene_config["radius"]/(3**(1/3))/0.1))
                        chunk=16384,
                        with_color=False,
                        embedding_a=neuconw_helper.embedding_a((torch.ones(1).cuda() * 1).long()),
                        renderer=neuconw_helper.renderer,
                        # model=model
                    )

                    # mesh, out_info = extract_mesh2(voxel_map, renderer=neuconw_helper.renderer)
                    np.save(to_occ_pred_path, out_info)
                    occ_pred = out_info.numpy()
                    _, alpha_static, alpha_transient, valid_masks = occ_pred[:, :3], occ_pred[:, 3], occ_pred[:, 4], occ_pred[:,5]
                    # output_mask = valid_masks * np.logical_and((alpha_transient > 0.2), alpha_transient < 1)
                    output_mask = valid_masks * (alpha_transient > 0.2)
                    out_for_vis = occ_pred[output_mask > 0, :5]
                    np.savetxt(Path(to_occ_pred_path).with_suffix('.txt'), out_for_vis)

                    # mesh.export(to_mesh_path)
                    # mesh = o3d.geometry.TriangleMesh(vertices=o3d.utility.Vector3dVector(
                    # mesh.vertices.copy()),
                    # triangles=o3d.utility.Vector3iVector(
                    #     mesh.faces.copy()))
                    # mesh.compute_vertex_normals()

                    # for idx_v in range(n_view):
                    #     if idx_v == 0:
                    #         vis.add_geometry(mesh, True)
                    #         vis.add_geometry(pcd_gt, True)
                    #     else:
                    #         vis.add_geometry(mesh, True)

                    #     view_control = get_view_control(vis, idx_v)
                    #     vis.poll_events()
                    #     vis.update_renderer()
                    #     # vis.run()
                    #     mesh_capture_img = vis.capture_screen_float_buffer(True)
                    #     vis.clear_geometries()
                    #     mesh_capture_img = np.array(np.asarray(mesh_capture_img)[..., ::-1] * 255, dtype=np.uint8)
                    #     output_img[:, mesh_capture_img.shape[1] * idx_v:mesh_capture_img.shape[1] * (idx_v + 1),:] = mesh_capture_img
                    #     output_img_resize = cv2.resize(output_img, (out_shape[0], osr_hh))
                    #     output_img_merge[hh:, :] = output_img_resize

            cv2.imwrite(to_img_path, output_img_merge)
            # videoWriter.write(output_img_merge)
            # c = cv2.waitKey(1)%0x100
            # if c == 27:
            #     break
            # print(1)
            count += 1


if __name__ == '__main__':
    main()
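
A note on launching: with the env:// initialization above, a two-GPU run of this script would look something like

    python -m torch.distributed.launch --nproc_per_node=2 main.py --cfg_path <your_config>.yaml

torch.distributed.launch passes --local_rank to each process; since the argparse for it is commented out and the rank is hard-coded above, restore one of the two before a real multi-GPU run. The config path here is a placeholder.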
