CenterPoint: Paper and Code Walkthrough

news2024/11/15 17:33:57

 

Contents

I. Introduction

II. Paper structure

III. Code


Paper: https://arxiv.org/pdf/2006.11275.pdf

Code: tianweiy/CenterPoint (github.com)

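I. Introduction

CenterPoint is an anchor-free method: it predicts the center point of each object directly and then regresses the box size (w, h, l) from that point. This removes the step of matching anchors to GT boxes (anchor-based methods have to compute the IoU between each GT box and the anchors in order to assign targets), and a point-based prediction is convenient for downstream tasks such as tracking. The experiments at the end of the paper suggest that the method is somewhat better at learning object rotation. One reading is that, because each object starts from a single point, the model is forced to learn more about orientation; anchor-based methods, in contrast, converge more easily because the anchors act as a prior.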

II. Paper structure

 

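The overall architecture is very similar to PointPillars; the main change is that the head is anchor-free, so the analysis below focuses on the head.

In the earlier stages the point cloud is processed by the VFE, scattered onto the BEV plane, and passed through an FPN-style neck to give a feature map of shape [B,C,H,W]. A shared conv then adjusts the channel count, and the result is fed to five heads (each is just a few convolutions plus a final convolution that reduces the channels to the required size), producing reg [B,2,H,W], height [B,1,H,W], dim [B,3,H,W], rot [B,2,H,W] and hm [B,8,H,W]. The predicted reg is an offset within a single cell: it recovers the fractional part of the center position that is lost when the center is truncated to an integer cell index.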

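At inference time: take exp of dim, recover the angle from the sine and cosine in rot, and add reg to the coordinates generated by meshgrid to get absolute coordinates on the feature map. These are concatenated into boxes of shape [B,H*W,7], and sigmoid is applied to hm before post-processing. Post-processing first takes the max of the heatmap over the channel dimension to get a score and a label per cell, masks out cells whose score is below the per-class threshold, and then runs NMS to remove redundant boxes. Only the single-stage pipeline is covered here; the paper uses two stages, with an additional box refinement stage. Note that this CenterPoint implementation does use NMS.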
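
The decoding step for a single BEV cell can be sketched in plain Python. This is a minimal illustration and not the repository code: the function name decode_cell and the numeric values (voxel size 0.32, the point-cloud range, out_size_factor 2) are assumptions chosen for the example.

```python
import math

def decode_cell(ix, iy, reg, height, log_dim, rot_sin, rot_cos,
                voxel_size, pc_range, out_size_factor):
    """Turn the raw head outputs at one BEV cell into a metric 3D box."""
    # Absolute position on the feature map = integer cell index + predicted offset.
    fx = ix + reg[0]
    fy = iy + reg[1]
    # Feature-map coordinates -> metres in the LiDAR frame.
    x = fx * out_size_factor * voxel_size[0] + pc_range[0]
    y = fy * out_size_factor * voxel_size[1] + pc_range[1]
    # Sizes are regressed in log space, so exp() keeps them positive.
    dims = [math.exp(d) for d in log_dim]
    # Recover the yaw angle from its sine and cosine.
    yaw = math.atan2(rot_sin, rot_cos)
    return [x, y, height] + dims + [yaw]

# Hypothetical values, for illustration only.
box = decode_cell(ix=10, iy=20, reg=(0.25, 0.5), height=-1.0,
                  log_dim=(math.log(4.0), math.log(2.0), math.log(1.5)),
                  rot_sin=1.0, rot_cos=0.0,
                  voxel_size=(0.32, 0.32), pc_range=(-10.0, -20.0),
                  out_size_factor=2)
```

The predict method in the listing does the same arithmetic for all H*W cells at once with tensors, then filters by score and range before NMS.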

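At training time: first the GT heatmap and boxes have to be built, so hm [B,8,H,W], anno_box [B,500,8], ind [B,500], mask [B,500] and cat [B,500] are initialised to zero. The number of GT objects differs between samples, so every sample is padded to a maximum of 500 slots and mask marks which slots hold a real GT. For each GT, a Gaussian is drawn on the heatmap of its class. The Gaussian radius is derived from a minimum IoU overlap for the box's footprint on the feature map (three cases: inscribed, circumscribed and crossing); for details see the Zhihu article "说点Cornernet/Centernet代码里面GT heatmap里面如何应用高斯散射核" (zhihu.com). The radius is clamped to a minimum value. Next, find which pillar the center falls in: truncate the center coordinates to integers and subtract to get the offset. Take the log of the size (l, w, h) and the sin/cos of the yaw to build anno_box; ind is the index of the object's center in the flattened H*W map and cat is its class. Together these form example. The Gaussian itself is exp(-d^2 / (2*sigma^2)), where d is the distance to the center, so the closer a cell is to the center, the closer its weight is to 1.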
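
The Gaussian drawing can be sketched without NumPy. This is a simplified illustration, assuming the same sigma = diameter / 6 rule as draw_umich_gaussian in the listing; the function names are mine, and the step that zeroes values below machine epsilon is omitted.

```python
import math

def gaussian_kernel(radius):
    """(2r+1) x (2r+1) Gaussian with sigma = diameter / 6 and peak value 1."""
    diameter = 2 * radius + 1
    sigma = diameter / 6
    return [[math.exp(-(dx * dx + dy * dy) / (2 * sigma * sigma))
             for dx in range(-radius, radius + 1)]
            for dy in range(-radius, radius + 1)]

def draw_gaussian(heatmap, cx, cy, radius):
    """Splat the kernel around (cx, cy), clipping at borders and keeping the max."""
    kernel = gaussian_kernel(radius)
    height, width = len(heatmap), len(heatmap[0])
    for dy in range(-radius, radius + 1):
        for dx in range(-radius, radius + 1):
            x, y = cx + dx, cy + dy
            if 0 <= x < width and 0 <= y < height:
                heatmap[y][x] = max(heatmap[y][x], kernel[dy + radius][dx + radius])
    return heatmap

# A 6 x 8 heatmap with one object centred at x=1, y=2 (hypothetical values).
hm = [[0.0] * 8 for _ in range(6)]
draw_gaussian(hm, cx=1, cy=2, radius=2)
```

Taking the maximum rather than the sum means that overlapping objects of the same class never push a cell above 1, and the center cell of each object stays exactly 1.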

At this point the GT side consists of hm [B,8,H,W], anno_box [B,500,8], ind [B,500], mask [B,500] and cat [B,500].
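
For a single GT box, the ind and anno_box entries can be sketched as follows. This is an illustrative sketch that mirrors the arithmetic in AssignLabel from the listing; the function name encode_target and the numeric values are assumptions for the example.

```python
import math

def encode_target(box, voxel_size, pc_range, out_size_factor, feat_w):
    """Build (ind, anno_box) for one GT box [x, y, z, l, w, h, yaw]."""
    x, y, z, l, w, h, yaw = box
    # Continuous centre position on the feature map.
    cx = (x - pc_range[0]) / voxel_size[0] / out_size_factor
    cy = (y - pc_range[1]) / voxel_size[1] / out_size_factor
    # Integer cell containing the centre (truncation, like astype(np.int32)).
    ix, iy = int(cx), int(cy)
    # Flat index into the H*W feature map (row-major: y * W + x).
    ind = iy * feat_w + ix
    anno_box = [cx - ix, cy - iy, z,
                math.log(l), math.log(w), math.log(h),
                math.sin(yaw), math.cos(yaw)]
    return ind, anno_box

# Hypothetical box and grid settings, for illustration only.
ind, anno_box = encode_target(
    box=[-3.44, -6.88, -1.0, 4.0, 2.0, 1.5, math.pi / 2],
    voxel_size=(0.32, 0.32), pc_range=(-10.0, -20.0),
    out_size_factor=2, feat_w=380)
```

With these numbers the center lands in cell (10, 20), so ind is 20 * 380 + 10 = 7610 and the offsets are about 0.25 and 0.5.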

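The model predicts reg [B,2,H,W], height [B,1,H,W], dim [B,3,H,W], rot [B,2,H,W] and hm [B,8,H,W].

Apply sigmoid to the predicted hm, and concatenate the regression outputs into pred_box [B,8,H,W]. pred_box is flattened to [B,H*W,8] and gathered with ind into [B,500,8], then compared with anno_box using an L1 loss. The heatmap loss is computed directly with a focal loss; the listing below uses FocalLossCenterNet and leaves FastFocalLoss commented out.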
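
The gather-then-L1 step can be sketched for one sample with plain lists. This is a simplified illustration of what RegLoss in the listing computes, without the batch dimension; the function name masked_l1 and the toy numbers are assumptions.

```python
def masked_l1(pred, ind, mask, target):
    """pred: [H*W][C], ind and mask: [M], target: [M][C] -> per-channel loss [C]."""
    num_ch = len(target[0])
    num_pos = sum(mask)
    loss = [0.0] * num_ch
    for slot, cell in enumerate(ind):
        if not mask[slot]:
            continue  # padded slot, contributes nothing
        for c in range(num_ch):
            # pred[cell] is the "gather": the prediction at the object's centre cell.
            loss[c] += abs(pred[cell][c] - target[slot][c])
    # Normalise by the number of real objects (epsilon avoids division by zero).
    return [v / (num_pos + 1e-4) for v in loss]

# Toy example: 4 cells, 2 channels, 3 slots of which only the first is real.
pred = [[0.0, 0.0], [9.0, 9.0], [1.5, 1.0], [7.0, 7.0]]
loss = masked_l1(pred, ind=[2, 0, 0], mask=[1, 0, 0],
                 target=[[1.0, 2.0], [0.0, 0.0], [0.0, 0.0]])
```

Only the cell at index 2 is compared with its target; predictions at every other cell are ignored by the regression loss, which is why the heatmap needs its own loss.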

III. Code

import copy
import logging
import sys  # used by Sequential.__init__
from collections import OrderedDict, defaultdict  # OrderedDict is used by Sequential

from torch import nn


import torch
import numpy as np
import torch.nn.functional as F

from ...ops.iou3d_nms import iou3d_nms_cuda
from ..model_utils import model_nms_utils


class Sequential(torch.nn.Module):
    r"""A sequential container.
    Modules will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of modules can also be passed in.

    To make it easier to understand, given is a small example::

        # Example of using Sequential
        model = Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Example of using Sequential with OrderedDict
        model = Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))

        # Example of using Sequential with kwargs(python 3.6+)
        model = Sequential(
                  conv1=nn.Conv2d(1,20,5),
                  relu1=nn.ReLU(),
                  conv2=nn.Conv2d(20,64,5),
                  relu2=nn.ReLU()
                )
    """

    def __init__(self, *args, **kwargs):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
        for name, module in kwargs.items():
            if sys.version_info < (3, 6):
                raise ValueError("kwargs only supported in py36+")
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, module)

    def __getitem__(self, idx):
        if not (-len(self) <= idx < len(self)):
            raise IndexError("index {} is out of range".format(idx))
        if idx < 0:
            idx += len(self)
        it = iter(self._modules.values())
        for i in range(idx):
            next(it)
        return next(it)

    def __len__(self):
        return len(self._modules)

    def add(self, module, name=None):
        if name is None:
            name = str(len(self._modules))
            if name in self._modules:
                raise KeyError("name exists")
        self.add_module(name, module)

    def forward(self, input):
        # i = 0
        for module in self._modules.values():
            # print(i)
            input = module(input)
            # i += 1
        return input




def rotate_nms_pcdet(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):
    """
    :param boxes: (N, 7) [x, y, z, l, w, h, theta]
    :param scores: (N)
    :param thresh:
    :return:
    """
    # transform back to OpenPCDet's box convention (swap l/w, remap the yaw angle)
    boxes = boxes[:, [0, 1, 2, 4, 3, 5, -1]]
    boxes[:, -1] = -boxes[:, -1] - np.pi /2

    order = scores.sort(0, descending=True)[1]  # indices of the N boxes, sorted by descending score
    if pre_maxsize is not None:  # keep only the top pre_maxsize boxes before NMS
        order = order[:pre_maxsize]

    boxes = boxes[order].contiguous()

    keep = torch.LongTensor(boxes.size(0))

    if len(boxes) == 0:
        num_out =0
    else:
        num_out = iou3d_nms_cuda.nms_gpu(boxes, keep, thresh)

    selected = order[keep[:num_out].cuda()].contiguous()

    if post_max_size is not None:
        selected = selected[:post_max_size]

    return selected 


def kaiming_init(
    module, a=0, mode="fan_out", nonlinearity="relu", bias=0, distribution="normal"
):
    assert distribution in ["uniform", "normal"]
    if distribution == "uniform":
        nn.init.kaiming_uniform_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity
        )
    else:
        nn.init.kaiming_normal_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity
        )
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)

def gaussian_radius(det_size, min_overlap=0.5):
    """
    compute gaussian radius by min_overlap, you can get principle in <<CenterNet :Objects as Points>> paper
    """
    height, width = det_size  # box height and width on the feature map

    a1  = 1
    b1  = (height + width)
    c1  = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
    r1  = (b1 + sq1) / 2

    a2  = 4
    b2  = 2 * (height + width)
    c2  = (1 - min_overlap) * width * height
    sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
    r2  = (b2 + sq2) / 2

    a3  = 4 * min_overlap
    b3  = -2 * min_overlap * (height + width)
    c3  = (min_overlap - 1) * width * height
    sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
    r3  = (b3 + sq3) / 2
    return min(r1, r2, r3)

def gaussian2D(shape, sigma=1):
    """
    compute gaussian
    """
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m+1,-n:n+1]  # e.g. y: [7,1], x: [1,7] for a 7x7 kernel

    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))  # [7,7]; larger closer to the origin
    h[h < np.finfo(h.dtype).eps * h.max()] = 0  # zero out values below machine epsilon times the peak
    return h


def draw_umich_gaussian(heatmap, center, radius, k=1):
    """
    draw gaussian in heatmap
    """
    diameter = 2 * radius + 1  # kernel side length
    # compute gaussian value
    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)  # (2r+1) x (2r+1) matrix, e.g. 7x7 for radius 3

    x, y = int(center[0]), int(center[1])  # integer center coordinates

    height, width = heatmap.shape[0:2]

    # get gaussian map pos
    left, right = min(x, radius), min(width - x, radius + 1)  # clip the extent when the center is closer than radius to a border
    top, bottom = min(y, radius), min(height - y, radius + 1)

    # get masked heatmap pos 
    masked_heatmap  = heatmap[y - top:y + bottom, x - left:x + right]  # heatmap window to update
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]  # matching part of the kernel

    # sanity check kept from debugging; skips empty windows
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)  # element-wise max, written in place
    return heatmap

def _gather_feat(feat, ind, mask=None):
    dim  = feat.size(2) # 8
    ind  = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)  # ind: [B,500] -> [B,500,1] -> [B,500,8]; object indices on the feature map
    feat = feat.gather(1, ind)  # pick the rows given by ind along dim 1 (H*W)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat

def _transpose_and_gather_feat(feat, ind):
    feat = feat.permute(0, 2, 3, 1).contiguous()  # [B,200,380,8]
    feat = feat.view(feat.size(0), -1, feat.size(3)) # [B,H*W,8]
    feat = _gather_feat(feat, ind)
    return feat

def _circle_nms(boxes, min_radius, post_max_size=83):
    """
    NMS according to center distance. Unused here; note that circle_nms is not
    defined or imported in this listing.
    """
    keep = np.array(circle_nms(boxes.cpu().numpy(), thresh=min_radius))[:post_max_size]

    keep = torch.from_numpy(keep).long().to(boxes.device)

    return keep 


class RegLoss(nn.Module):
  '''Regression loss for an output tensor
    Arguments:
      output (batch x dim x h x w)
      mask (batch x max_objects)
      ind (batch x max_objects)
      target (batch x max_objects x dim)
  '''
  def __init__(self):
    super(RegLoss, self).__init__()
  
  def forward(self, output, mask, ind, target):
    # output[B,8,200,380]  pred[B,500,8]
    # compute mask by ind as not all box number is same and not all grid in use
    pred = _transpose_and_gather_feat(output, ind)
    mask = mask.float().unsqueeze(2) 

    # L1 loss: pred and target are both [B,500,8] and are multiplied by mask;
    # summing over the B and 500 dimensions leaves an 8-element loss
    loss = F.l1_loss(pred*mask, target*mask, reduction='none')
    loss = loss / (mask.sum() + 1e-4)
    loss = loss.transpose(2 ,0).sum(dim=2).sum(dim=1)
    return loss

class FastFocalLoss(nn.Module):
  '''
  Reimplemented focal loss, exactly the same as the CornerNet version.
  Faster and costs much less memory.
  '''
  def __init__(self):
    super(FastFocalLoss, self).__init__()

  def forward(self, out, target, ind, mask, cat):
    '''
    Arguments:
      out, target: B x C x H x W
      ind, mask: B x M
      cat (category id for peaks): B x M
    '''
    mask = mask.float()
    gt = torch.pow(1 - target, 4)
    # compute negative loss in heatmap
    neg_loss = torch.log(1 - out) * torch.pow(out, 2) * gt
    neg_loss = neg_loss.sum()

    pos_pred_pix = _transpose_and_gather_feat(out, ind) # B x M x C
    pos_pred = pos_pred_pix.gather(2, cat.unsqueeze(2)) # B x M x 1
    num_pos = mask.sum()

    # compute positive loss in heatmap
    pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2) * \
               mask.unsqueeze(2)
    pos_loss = pos_loss.sum()
    if num_pos == 0:
      return - neg_loss
    return - (pos_loss + neg_loss) / num_pos



def neg_loss_cornernet(pred, gt, mask=None):
    """
    Refer to https://github.com/tianweiy/CenterPoint.
    Modified focal loss. Exactly the same as CornerNet. Runs faster and costs a little bit more memory
    Args:
        pred: (B x 8 x h x w)
        gt: (B x 8 x h x w)
        mask: (B x h x w)
    Returns:
    """
    pos_inds = gt.eq(1).float()  # 1 only where an object center lies
    neg_inds = gt.lt(1).float()  # 1 everywhere else

    neg_weights = torch.pow(1 - gt, 4)  # [B,8,H,W]; shrinks the penalty for negatives close to a center

    loss = 0

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds  # negatives near a center contribute very little

    if mask is not None:
        mask = mask[:, None, :, :].float()
        pos_loss = pos_loss * mask
        neg_loss = neg_loss * mask
        num_pos = (pos_inds.float() * mask).sum()
    else:
        num_pos = pos_inds.float().sum()

    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos  # normalise the summed loss by the number of positives
    return loss


class FocalLossCenterNet(nn.Module):
    """
    Refer to https://github.com/tianweiy/CenterPoint
    """
    def __init__(self):
        super(FocalLossCenterNet, self).__init__()
        self.neg_loss = neg_loss_cornernet

    def forward(self, out, target, mask=None):
        return self.neg_loss(out, target, mask=mask)



class AssignLabel(object):
    def __init__(self, **kwargs):
        """Return CenterNet training labels like heatmap, height, offset"""

        self.tasks = kwargs["tasks"] #assigner_cfg.target_assigner.tasks

        assigner_cfg = kwargs["cfg"]

        self.out_size_factor = assigner_cfg.out_size_factor # 2
        self.gaussian_overlap = assigner_cfg.gaussian_overlap # 0.1
        self._max_objs = assigner_cfg.max_objs  # 500
        self._min_radius = assigner_cfg.min_radius # 2
        # tasks
        self.class_names = self.tasks["class_names"]  # list of the eight class names
        self.num_classes = self.tasks["num_class"]  # 8

    def __call__(self, res,  grid_size , voxel_size , pc_range):
        max_objs = self._max_objs   # 500

        feature_map_size = grid_size[:2] // self.out_size_factor  # feature-map width and height

        draw_gaussian = draw_umich_gaussian
        # each GT row is x, y, z, l, w, h, yaw, class
        gt_boxes = res['gt_boxes'].cpu().numpy()  # GT from data_dict, [B,N,8]
        batch_size = res['batch_size']

        # hm is heatmap
        hms, anno_boxs, inds, masks, cats = [], [], [], [], []

        #jinmu: batch one by one compute now
        for batch_idx in range(batch_size):
            batch_box = gt_boxes[batch_idx,...]  #[n,8]
            # n is the largest object count in the batch, so samples with fewer
            # objects are zero-padded; a non-zero class id marks a real object
            batch_box_mask = batch_box[...,-1] != 0
            if np.all(batch_box_mask == False):
                batch_box_valid_num = 0
            else:  # for a 1-D mask such as [1,1,1,1,0,0,0,0,0], np.where returns only the indices
                batch_box_valid_num = np.where(batch_box_mask)[0].squeeze().max() + 1  # number of valid objects

            # c, h, w  [8, 200,380]
            hm = np.zeros((len(self.class_names), feature_map_size[1], feature_map_size[0]),
                            dtype=np.float32)
            # [500, 8]
            anno_box = np.zeros((max_objs, 8), dtype=np.float32)
            # [500]
            ind = np.zeros((max_objs), dtype=np.int64)
            mask = np.zeros((max_objs), dtype=np.uint8) # [500]
            cat = np.zeros((max_objs), dtype=np.int64)  # [500]

            # every frame is padded to the same number of slots so the batch can be
            # processed at once; since frames hold different numbers of boxes,
            # mask records which slots are real
            num_objs = min(batch_box_valid_num, max_objs)  # number of objects in this frame

            for k in range(num_objs):
                cls_id = batch_box[k][-1] - 1  # zero-based class id
                l, w, h = batch_box[k][3], batch_box[k][4], batch_box[k][5]
                # w and l measured in feature-map cells
                w, l = w / voxel_size[1] / self.out_size_factor, l / voxel_size[0] / self.out_size_factor
                # the radius comes from the minimum-overlap constraint: a quadratic in r
                # is solved for three cases (inscribed, circumscribed, crossing)
                if w > 0 and l > 0:
                    radius = gaussian_radius((l, w), min_overlap=self.gaussian_overlap)  # l, w are floats; min_overlap is 0.1 here
                    radius = max(self._min_radius, int(radius))  # enforce a minimum radius of 2

                    # center position on the feature map
                    x, y, z = batch_box[k][0], batch_box[k][1], batch_box[k][2]
                    coor_x, coor_y = (x - pc_range[0]) / voxel_size[0] / self.out_size_factor, \
                                        (y - pc_range[1]) / voxel_size[1] / self.out_size_factor
                    
                    ct = np.array([coor_x, coor_y], dtype=np.float32)  
                    ct_int = ct.astype(np.int32)  # truncate to integers

                    # throw out not in range objects to avoid out of array area when creating the heatmap
                    # if beyond range, then continue
                    if not (0 <= ct_int[0] < feature_map_size[0] and 0 <= ct_int[1] < feature_map_size[1]):
                        continue 

                    # draw gaussian in heatmap gt
                    draw_gaussian(hm[int(cls_id)], ct, radius)  # draw on the heatmap of this class

                    new_idx = k  # slot of the k-th object
                    x, y = ct_int[0], ct_int[1]

                    cat[new_idx] = cls_id  # class of this object
                    ind[new_idx] = y * feature_map_size[0] + x  # flat index of the object on the feature map
                    mask[new_idx] = 1  # mark the slot as valid
                    rot = batch_box[k][6]
                    # fill regression target, ct - (x,y) is x_offset and y_offset
                    # rot is yaw angle
                    anno_box[new_idx] = np.concatenate(
                        (ct - (x, y), z, np.log(batch_box[k][3:6]),
                        np.sin(rot), np.cos(rot)), axis=None)  # xy offset within the cell, z, log(l, w, h), sin and cos of yaw

            hms.append(hm)
            anno_boxs.append(anno_box)
            masks.append(mask)
            inds.append(ind)
            cats.append(cat)

        hms = torch.from_numpy(np.stack(hms)).cuda()  # stack the per-sample arrays along dim 0
        anno_boxs = torch.from_numpy(np.stack(anno_boxs)).cuda()
        inds = torch.from_numpy(np.stack(inds)).cuda()
        cats = torch.from_numpy(np.stack(cats)).cuda()
        masks = torch.from_numpy(np.stack(masks)).cuda()
        # [B,8,h,w]   [B,500,8]  [B,500] [B,500] [B,500]
        example = {'hm': hms, 'anno_box': anno_boxs, 'ind': inds, 'mask': masks, 'cat': cats}

        return example


class SepHead(nn.Module):
    """
    SepHead holds the actual heads: (heatmap) (x offset, y offset) (z) (dim) (sin(theta), cos(theta))
    """
    def __init__(
        self,
        in_channels,
        heads,
        head_conv=64,
        final_kernel=1,
        bn=False,
        init_bias=-2.19,
        **kwargs,
    ):
        super(SepHead, self).__init__(**kwargs)

        self.heads = heads # {'reg': [2, 2], 'height': [1, 2], 'dim': [3, 2], 'rot': [2, 2], 'hm': [8, 2]}
        for head in self.heads:  # iterates over the keys
            classes, num_conv = self.heads[head]  # output channel count of this head, and its number of conv layers

            fc = Sequential()
            # layers number decided by config
            for i in range(num_conv-1):
                fc.add(nn.Conv2d(in_channels, head_conv,
                    kernel_size=final_kernel, stride=1, 
                    padding=final_kernel // 2, bias=True))
                if bn:
                    fc.add(nn.BatchNorm2d(head_conv))
                fc.add(nn.ReLU())

            # output conv
            fc.add(nn.Conv2d(head_conv, classes,
                    kernel_size=final_kernel, stride=1, 
                    padding=final_kernel // 2, bias=True))    
            # the hm head gets a fixed bias on its last conv; the other heads use Kaiming initialisation
            if 'hm' in head:
                fc[-1].bias.data.fill_(init_bias)
            else:
                for m in fc.modules():
                    if isinstance(m, nn.Conv2d):
                        kaiming_init(m)
            # each head is num_conv - 1 conv blocks followed by one output conv that
            # produces the prediction channels
            # registering fc with setattr lets forward() fetch it by head name
            self.__setattr__(head, fc)
        

    def forward(self, x):
        ret_dict = dict()        
        for head in self.heads:
            ret_dict[head] = self.__getattr__(head)(x)
        # ret_dict is a dict: reg [B,2,200,380] height [B,1,200,380] dim [B,3,200,380] rot [B,2,200,380] hm [B,8,200,380]
        return ret_dict


class CenterHead(nn.Module):
    def __init__(
        self,
        model_cfg,
        input_channels=[128,],
        num_class=1,
        class_names=None,
        grid_size=[0.32,0.32,0.16],
        point_cloud_range=None,
        predict_boxes_when_training=False,
        logger=None,
        init_bias=-2.19,
        num_hm_conv=2,
    ):
        super(CenterHead, self).__init__()
        assert(len(class_names) == num_class)
        
        tasks = dict(num_class=num_class, class_names=class_names)
        self.label_assigner = AssignLabel(cfg=model_cfg.TARGET_ASSIGNER_CONFIG, tasks=tasks)
        
        self.out_size_factor = model_cfg.TARGET_ASSIGNER_CONFIG.out_size_factor # 2
        self.model_cfg = model_cfg

        self.class_names = [class_names]  # class_names is a list, so this becomes [[a, b, c, ...]]
        self.num_classes = [num_class]  # [8]

        self.code_weights = model_cfg.code_weights #[5.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0]
        self.weight = model_cfg.weight # 0.25 
        
        self.in_channels = input_channels # 384

        #self.crit = FastFocalLoss()
        self.crit = FocalLossCenterNet()
        self.crit_reg = RegLoss()

        

        common_heads = model_cfg.common_heads #{'reg': [ 2, 2 ],'height': [ 1, 2 ],'dim': [ 3, 2 ],'rot': [ 2, 2 ]}

        self.box_n_dim = 9 if 'vel' in common_heads else 7  # 7
        self.use_direction_classifier = False 

        if not logger:
            logger = logging.getLogger("CenterHead")
        self.logger = logger

        logger.info(
            f"num_classes: {self.num_classes}"
        )

        # a shared convolution 
        share_conv_channel = 64 if "share_conv_channel" not in model_cfg else model_cfg.share_conv_channel # 64
        self.shared_conv = nn.Sequential(
            nn.Conv2d(self.in_channels, share_conv_channel,
            kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(share_conv_channel),
            nn.ReLU(inplace=True)
        )

        self.tasks = nn.ModuleList()
        print("Use HM Bias: ", init_bias)

        for num_cls in self.num_classes:  # [8], so this loop runs once
            heads = copy.deepcopy(common_heads) 
            heads.update(dict(hm=(num_cls, num_hm_conv))) #{'reg': [2, 2], 'height': [1, 2], 'dim': [3, 2], 'rot': [2, 2], 'hm': [8, 2]}
            self.tasks.append(
                SepHead(share_conv_channel, heads, bn=True, init_bias=init_bias, final_kernel=3)
            )

        self.frozen_param = model_cfg.FROZON_PARAM
        self.frozen_parameters()

        logger.info("Finish CenterHead Initialization")

    def forward(self, data_dict, *kwargs):

        x = data_dict['spatial_features_2d'] # [B, 384, 200, 380]
        x = self.shared_conv(x)  # reduce the channels to 64 first
        ret_dicts = []

        for task in self.tasks:
            ret_dicts.append(task(x))
        # each entry is a dict: reg [B,2,H,W] height [B,1,H,W] dim [B,3,H,W] rot [B,2,H,W] hm [B,8,H,W]
        data_dict['centerhead_preds'] = ret_dicts

        return data_dict

    def _sigmoid(self, x):
        y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
        return y

    def loss(self, data_dict, **kwargs):
        # dict built from the GT: hm [B,8,H,W], anno_box [B,n,8], ind [B,n], mask [B,n], cat [B,n]
        example = self.label_assigner(data_dict, kwargs["grid_size"], kwargs["voxel_size"], kwargs["pc_range"])

        # get centerhead output reg[B,2,200,380] heigh[B,1,200,380] dim [B,3,200,380] rot [B,2,200,380] hm [B,8,200,380]
        preds_dicts = data_dict['centerhead_preds']

        assert(len(preds_dicts) == 1)
        # TODO refactor this
        preds_dict = preds_dicts[0]  # preds_dicts is a list; take its only dict
        
        # apply sigmoid for heatmap output
        preds_dict['hm'] = self._sigmoid(preds_dict['hm'])  # clamped sigmoid on the heatmap logits, so log() in the focal loss stays finite
        # hm_loss = self.crit(
        #     preds_dict['hm'], 
        #     example['hm'], 
        #     example['ind'], 
        #     example['mask'], 
        #     example['cat']
        #     )
        
        hm_loss = self.crit(preds_dict['hm'], example['hm'])  # uses FocalLossCenterNet

        target_box = example['anno_box']
        # not care about vel as not vel now
        if 'vel' in preds_dict:
            preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
                                                preds_dict['vel'], preds_dict['rot']), dim=1)  
        else:
            preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
                                                preds_dict['rot']), dim=1)   

        # Regression loss for offset, height, dimension, rotation; an 8-element loss tensor
        box_loss = self.crit_reg(preds_dict['anno_box'], example['mask'], example['ind'], target_box)
        box_loss = box_loss * box_loss.new_tensor(self.code_weights)  # new_tensor matches box_loss's dtype and device

        # anno_box layout: reg (0:2), height (2), dim (3:6), rot (6:8)
        reg_loss = box_loss[:2]
        height_loss = box_loss[2]
        dim_loss = box_loss[3:6]
        rot_loss = box_loss[6:]
        
        loc_loss = box_loss.sum()
        loc_loss *= self.weight

        # total loss
        loss = hm_loss + loc_loss
        #ret = {'loss': loss, 'hm_loss': hm_loss, 'loc_loss':loc_loss, 'loc_loss_elem': box_loss.detach().cpu(), 'num_positive': example['mask'][0].float().sum()}
        # ret = {'hm_loss': hm_loss, 'loc_loss': loc_loss, 
        #         'reg_loss': reg_loss, 'height_loss': height_loss, 
        #         'dim_loss': dim_loss, 'rot_loss': rot_loss}

        ret = {'hm_loss': hm_loss, 'loc_loss': loc_loss}
        
        return ret
    
    def frozen_parameters(self):
        if self.frozen_param:
            for parameter in self.parameters():
                parameter.requires_grad = False

    @torch.no_grad()
    def predict(self, preds_dicts, test_cfg, **kwargs):
        """decode, nms, then return the detection result.
        """

        voxel_size = kwargs["voxel_size"]
        pc_range = kwargs["pc_range"]

        post_center_range = pc_range
        # each entry is a dict: reg [B,2,H,W] height [B,1,H,W] dim [B,3,H,W] rot [B,2,H,W] hm [B,8,H,W]
        preds_dicts = preds_dicts['centerhead_preds']

        if len(post_center_range) > 0:
            post_center_range = torch.tensor(
                post_center_range,
                dtype=preds_dicts[0]['hm'].dtype,
                device=preds_dicts[0]['hm'].device,
            )

        rets = []
        #jinmu now only support one task
        for task_id, preds_dict in enumerate(preds_dicts):
            # convert B C H W to B H W C 
            for key, val in preds_dict.items():
                preds_dict[key] = val.permute(0, 2, 3, 1).contiguous()

            batch_size = preds_dict['hm'].shape[0]
            batch_hm = torch.sigmoid(preds_dict['hm'])

            # exp for dim output to keep dim > 0
            batch_dim = torch.exp(preds_dict['dim'])  # sizes in the same order as the targets (l, w, h)

            # sin(theta) and cos(theta)
            batch_rots = preds_dict['rot'][..., 0:1]
            batch_rotc = preds_dict['rot'][..., 1:2]

            # x offset and y offset output
            batch_reg = preds_dict['reg']
            # z output
            batch_hei = preds_dict['height']

            # atan to recover true theta
            batch_rot = torch.atan2(batch_rots, batch_rotc)  # recover the angle from sin and cos

            batch, H, W, num_cls = batch_hm.size()

            # reshape for compute convenient
            batch_reg = batch_reg.reshape(batch, H*W, 2)
            batch_hei = batch_hei.reshape(batch, H*W, 1)

            batch_rot = batch_rot.reshape(batch, H*W, 1)
            batch_dim = batch_dim.reshape(batch, H*W, 3)
            batch_hm = batch_hm.reshape(batch, H*W, num_cls)  # flatten H and W together for easier processing

            # per-cell x and y grid indices, combined later with x_offset and
            # y_offset to recover the LiDAR-frame x and y
            ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
            ys = ys.view(1, H, W).repeat(batch, 1, 1).to(batch_hm.device).float()
            xs = xs.view(1, H, W).repeat(batch, 1, 1).to(batch_hm.device).float()

            # x y  + x_offset y_offset to recover continuous x y value
            xs = xs.view(batch, -1, 1) + batch_reg[:, :, 0:1]
            ys = ys.view(batch, -1, 1) + batch_reg[:, :, 1:2]

            xs = xs * self.out_size_factor * voxel_size[0] + pc_range[0]
            ys = ys * self.out_size_factor * voxel_size[1] + pc_range[1]

            # jinmu: vel can be ignored, there is no vel output for now
            if 'vel' in preds_dict:
                batch_vel = preds_dict['vel']
                batch_vel = batch_vel.reshape(batch, H*W, 2)
                batch_box_preds = torch.cat([xs, ys, batch_hei, batch_dim, batch_vel, batch_rot], dim=2)
            else: 
                batch_box_preds = torch.cat([xs, ys, batch_hei, batch_dim, batch_rot], dim=2)

            if test_cfg.get('per_class_nms', False):
                pass 
            else:
                rets.append(self.post_processing(batch_box_preds, batch_hm, test_cfg, post_center_range)) 

        assert(len(rets) == 1) # only one task

        return rets[0]

    @torch.no_grad()
    def post_processing(self, batch_box_preds, batch_hm, test_cfg, post_center_range):
        batch_size = len(batch_hm)
        # batch_box_preds [B,H*W,7] batch_hm [B,H*W,8]
        prediction_dicts = []
        for i in range(batch_size):  # process the batch one sample at a time
            box_preds = batch_box_preds[i]
            hm_preds = batch_hm[i]

            # score and label come from a max over the class dimension of the heatmap
            scores, labels = torch.max(hm_preds, dim=-1)  # best score and its index (the class), both of shape [H*W]

            # per-class score threshold
            #score_mask = scores > test_cfg.score_threshold 
            score_threshold = torch.tensor(test_cfg.score_threshold)[labels]  # threshold of the predicted class at each of the H*W cells
            score_mask = scores > score_threshold.cuda()  # keep cells whose score exceeds the threshold

            # distance_mask keeps only boxes whose 3D center lies inside the range
            # not use this in perception postprocess code
            distance_mask = (box_preds[..., :3] >= post_center_range[:3]).all(1) \
                & (box_preds[..., :3] <= post_center_range[3:]).all(1)

            # mask is intersection of two mask
            mask = distance_mask & score_mask 

            # get masked data
            box_preds = box_preds[mask]  # the boxes among the H*W candidates that pass both masks
            scores = scores[mask]
            labels = labels[mask]

            # get box for nms, each box in [x y z dx dy dz theta] format
            boxes_for_nms = box_preds[:, [0, 1, 2, 3, 4, 5, -1]]

            # bev rotated box nms
            selected = rotate_nms_pcdet(boxes_for_nms, scores, 
                                thresh=test_cfg.nms.nms_iou_threshold,
                                pre_maxsize=test_cfg.nms.nms_pre_max_size,
                                post_max_size=test_cfg.nms.nms_post_max_size)

            # selected is box mask after nms
            selected_boxes = box_preds[selected]
            selected_scores = scores[selected]
            selected_labels = labels[selected]

            # fill result, selected_boxes: n * 7, selected_scores: n * 1,
            # selected_labels: n * 1
            record_dict = {
                'pred_boxes': selected_boxes,
                'pred_scores': selected_scores,
                'pred_labels': selected_labels + 1
            }

            prediction_dicts.append(record_dict)

        return prediction_dicts 
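
The heatmap loss in neg_loss_cornernet above can be checked on tiny inputs with a plain-Python version. This is a sketch under assumptions: it works on flat lists instead of tensors, has no mask argument, and the parameter names alpha and beta are mine (the listing hard-codes the exponents 2 and 4).

```python
import math

def focal_loss(pred, gt, alpha=2, beta=4):
    """pred: heatmap values in (0, 1); gt: Gaussian targets in [0, 1]."""
    pos_loss = 0.0
    neg_loss = 0.0
    num_pos = 0
    for p, g in zip(pred, gt):
        if g == 1:
            # Object centre: penalise low confidence.
            pos_loss += math.log(p) * (1 - p) ** alpha
            num_pos += 1
        else:
            # Background: penalise high confidence, less so near a centre.
            neg_loss += math.log(1 - p) * p ** alpha * (1 - g) ** beta
    if num_pos == 0:
        return -neg_loss
    return -(pos_loss + neg_loss) / num_pos

good = focal_loss([0.9, 0.1], [1, 0])          # confident and correct
bad = focal_loss([0.1, 0.9], [1, 0])           # confident and wrong
near = focal_loss([0.9, 0.9], [1, 0.9])        # false positive next to a centre
far = focal_loss([0.9, 0.9], [1, 0.0])         # false positive far from any centre
```

The (1 - g) ** beta factor is what makes a high score next to a true center much cheaper than the same score far away from any object.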
