The walkthrough below again follows the PyTorch implementation. The main building blocks of PointNet++ all live in utils.py, which assembles the network's core modules (sampling & grouping, set abstraction, and so on) from small pieces, so utils.py is what we mainly analyze.
This PyTorch implementation of PointNet++ is written very clearly and is well worth reading line by line; the code structure is also very clean, so why not distill the author's ideas and knowledge into something of my own? Hehe.
(1) The pointnet2_utils.py file:
1. Reference: the CSDN blog post "PointNet++论文解读和代码解析".
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np
# Print the elapsed time since t and return the current time
def timeit(tag, t):
print("{}: {}s".format(tag, time() - t))
return time()
# Normalize the point cloud: center it at the centroid and scale it into a sphere of radius 1
def pc_normalize(pc):
# pc has shape [N, 3]
l = pc.shape[0]
# Compute the centroid: the column-wise mean of pc, i.e. [x_mean, y_mean, z_mean]
centroid = np.mean(pc, axis=0)
# Convert every point to coordinates relative to the centroid
pc = pc - centroid
# Square and sum within each row, take the square root, then take the max: the largest distance from the centroid, sqrt(x^2 + y^2 + z^2)
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
# Divide by that max distance so every point lies inside the unit sphere (note: this is max-distance scaling, not Z-score standardization)
pc = pc / m
return pc
# Used mainly in the ball query step to measure each point's distance to the query points; returns the pairwise squared Euclidean distances between two point sets as an N x M matrix
def square_distance(src, dst):
"""
Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
# torch.matmul is also a matrix multiplication, but it supports broadcasting, so tensors with different (batched) shapes can be multiplied
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) #[B,N,M]
dist += torch.sum(src ** 2, -1).view(B, N, 1)  # [B,N,M] + [B,N,1]: broadcast, adding the column vector to every column of dist
dist += torch.sum(dst ** 2, -1).view(B, 1, M)  # [B,N,M] + [B,1,M]: broadcast, adding the row vector to every row of dist
return dist
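# --- Quick sanity check (my own sketch, not part of the original file; shapes B=2, N=5, M=7
# are arbitrary): the formula above should match the squared output of torch.cdist.
src = torch.rand(2, 5, 3)
dst = torch.rand(2, 7, 3)
d_formula = square_distance(src, dst)        # [2, 5, 7] squared distances
d_reference = torch.cdist(src, dst) ** 2     # cdist returns (non-squared) distances
print(torch.allclose(d_formula, d_reference, atol=1e-5))  # expected: True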
# Gather the points specified by idx from the input point cloud
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S]
Return:
new_points:, indexed points data, [B, S, C]
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape) #view_shape=[B,S]
view_shape[1:] = [1] * (len(view_shape) - 1)  # keep dim 0 and set all remaining dims to 1, giving [B, 1]
repeat_shape = list(idx.shape)
repeat_shape[0] = 1 #[1,S]
# arange gives [0, ..., B-1]; view turns it into a column [B, 1]; repeat expands it to [B, S]
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
# This line felt hard to follow at first; a small verification sketch follows this function
new_points = points[batch_indices, idx, :]  # for each position, pick points[batch_indices[b, s], idx[b, s], :]
return new_points
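# --- To answer the question in the comment above (my own verification sketch, arbitrary sizes):
# the advanced indexing pairs batch_indices[b, s] with idx[b, s], so new_points[b, s, :]
# equals points[b, idx[b, s], :].
points_demo = torch.arange(2 * 6 * 3, dtype=torch.float).reshape(2, 6, 3)  # B=2, N=6, C=3
idx_demo = torch.tensor([[0, 2, 5], [1, 1, 4]])                            # B=2, S=3
out_demo = index_points(points_demo, idx_demo)                             # [2, 3, 3]
print(torch.equal(out_demo[1, 2], points_demo[1, 4]))                      # True: batch 1, point 4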
# Farthest point sampling: returns the indices (into the original cloud) of the npoint sampled points
def farthest_point_sample(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
# Initialize the centroid matrix that will store the indices of the sampled points
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
# distance records, for every point in each batch, its distance to the sampled set; initialized very large and shrunk iteratively
distance = torch.ones(B, N).to(device) * 1e10
# farthest is the index of the current farthest point, randomly initialized in [0, N), one per batch
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
# Indices 0 .. B-1, one per batch element
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint):
centroids[:, i] = farthest  # store the current farthest point's index (the first one is the random initialization)
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)  # fetch the coordinates of those B points
dist = torch.sum((xyz - centroid) ** 2, -1)  # distance of every point in each batch to the current centroid, [B, N]
# Build a mask: wherever dist is smaller than distance, update distance, so distance always holds each point's minimum distance to the already-sampled set
mask = dist < distance
distance[mask] = dist[mask]
farthest = torch.max(distance, -1)[1]  # the index of the maximum distance becomes the next sampled point
return centroids
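# --- Minimal usage sketch (my own example, arbitrary sizes): sample 4 centroids from a batch
# of 2 clouds with 128 points each, then gather their coordinates with index_points.
xyz_demo = torch.rand(2, 128, 3)
fps_idx_demo = farthest_point_sample(xyz_demo, 4)        # [2, 4] indices into the 128 points
centroids_demo = index_points(xyz_demo, fps_idx_demo)    # [2, 4, 3]
print(fps_idx_demo.shape, centroids_demo.shape)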
# Find the points inside a spherical neighborhood; S is the number of centroids produced by FPS
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz)  # squared distances between the centroids and all points, [B, S, N]
group_idx[sqrdists > radius ** 2] = N  # any point farther than the radius gets its group_idx set to N; the rest stay unchanged
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]  # sort the indices ascending (out-of-radius entries were set to N, so they fall to the end) and keep the first nsample
# The first nsample entries may still contain out-of-radius points (when the ball holds fewer than nsample points); those are filled by repeating the first point in the group
# group_idx[:, :, 0] is that first point, shape [B, S], so view it and repeat it to [B, S, nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
# Build a mask of the entries that fell outside the ball and overwrite them with that first point's index
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
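# --- Usage sketch (my own example): a ball query of radius 0.2 keeping at most 16 neighbors
# around 4 FPS centroids; all sizes are arbitrary.
xyz_demo = torch.rand(2, 128, 3)
new_xyz_demo = index_points(xyz_demo, farthest_point_sample(xyz_demo, 4))  # [2, 4, 3]
group_idx_demo = query_ball_point(0.2, 16, xyz_demo, new_xyz_demo)         # [2, 4, 16]
grouped_xyz_demo = index_points(xyz_demo, group_idx_demo)                  # [2, 4, 16, 3]
print(group_idx_demo.shape, grouped_xyz_demo.shape)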
# Sampling and grouping. The difference between xyz and points: xyz holds only coordinates, points holds the other features
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
"""
Input:
npoint:
radius:
nsample:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, npoint, nsample, 3]
new_points: sampled points data, [B, npoint, nsample, 3+D]
"""
B, N, C = xyz.shape
# S centroids
S = npoint
# Indices of the points chosen from the original cloud by FPS
fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
new_xyz = index_points(xyz, fps_idx) #[B,npoint,C]
idx = query_ball_point(radius, nsample, xyz, new_xyz)  # indices of the nsample points grouped around each centroid, [B, npoint, nsample]
grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
# Subtract the centroid from each point in its group (local coordinates)
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, idx)
new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
else:
new_points = grouped_xyz_norm
if returnfps:
return new_xyz, new_points, grouped_xyz, fps_idx
else:
return new_xyz, new_points
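# --- Usage sketch (my own example, arbitrary hyper-parameters): with D=2 extra per-point
# features the grouped tensor carries 3 + D channels.
xyz_demo = torch.rand(2, 128, 3)
feats_demo = torch.rand(2, 128, 2)
new_xyz_demo, new_points_demo = sample_and_group(npoint=4, radius=0.2, nsample=16,
                                                 xyz=xyz_demo, points=feats_demo)
print(new_xyz_demo.shape)     # torch.Size([2, 4, 3])
print(new_points_demo.shape)  # torch.Size([2, 4, 16, 5]) -> 3 (relative xyz) + D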
# Treat all points as a single group
def sample_and_group_all(xyz, points):
"""
Input:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, 1, 3]
new_points: sampled points data, [B, 1, N, 3+D]
"""
device = xyz.device
B, N, C = xyz.shape
new_xyz = torch.zeros(B, 1, C).to(device)  # the origin is the single sampled point
grouped_xyz = xyz.view(B, 1, N, C)
if points is not None:
new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
else:
new_points = grouped_xyz
return new_xyz, new_points
# This class implements the plain Set Abstraction layer: sample_and_group forms local groups, an MLP is applied to every point in each group, and a max pooling collapses each group into a local "global" feature
class PointNetSetAbstraction(nn.Module):
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
super(PointNetSetAbstraction, self).__init__()
self.npoint = npoint
self.radius = radius
self.nsample = nsample
# nn.ModuleList is a container that registers each module's parameters with the network; any nn.Module subclass (nn.Conv2d, nn.Linear, ...) can be appended to it
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.group_all = group_all
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
if self.group_all:
new_xyz, new_points = sample_and_group_all(xyz, points)
else:
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
# new_xyz: sampled points position data, [B, npoint, C]
# new_points: sampled points data, [B, npoint, nsample, C+D]
new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
# Below is the PointNet part: an MLP over each local group, implemented as 1x1 2D convolutions that treat C+D as the feature channels
# i.e. a point-wise convolution over the [nsample, npoint] dimensions
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
# Max pooling over each group yields the local "global" feature, [B, D', npoint]
new_points = torch.max(new_points, 2)[0]
new_xyz = new_xyz.permute(0, 2, 1)
return new_xyz, new_points
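# --- Forward-pass sketch (my own example; channel sizes are arbitrary, not the repo defaults).
# in_channel counts the 3 xyz channels plus the D feature channels, and the inputs are
# channel-first [B, C, N]; with points=None only the 3 xyz channels are used.
sa_demo = PointNetSetAbstraction(npoint=64, radius=0.2, nsample=16,
                                 in_channel=3, mlp=[32, 32, 64], group_all=False)
xyz_demo = torch.rand(2, 3, 128)
new_xyz_demo, new_points_demo = sa_demo(xyz_demo, None)
print(new_xyz_demo.shape)     # torch.Size([2, 3, 64])
print(new_points_demo.shape)  # torch.Size([2, 64, 64]) -> [B, D'=64, npoint]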
# Set abstraction with MSG (multi-scale grouping); radius_list is a list of radii
class PointNetSetAbstractionMsg(nn.Module):
# e.g. 128, [0.2, 0.4, 0.8], [32, 64, 128], 320, [[64, 64, 128], [128, 128, 256], [128, 128, 256]]
def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
super(PointNetSetAbstractionMsg, self).__init__()
self.npoint = npoint
self.radius_list = radius_list
self.nsample_list = nsample_list
self.conv_blocks = nn.ModuleList()
self.bn_blocks = nn.ModuleList()
for i in range(len(mlp_list)):
convs = nn.ModuleList()
bns = nn.ModuleList()
last_channel = in_channel + 3
for out_channel in mlp_list[i]:
convs.append(nn.Conv2d(last_channel, out_channel, 1))
bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.conv_blocks.append(convs)
self.bn_blocks.append(bns)
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
B, N, C = xyz.shape
S = self.npoint
# Find the S centroids
new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
# Run a ball query for each radius, store each scale's features in new_points_list, and concatenate them at the end
new_points_list = []
for i, radius in enumerate(self.radius_list):
K = self.nsample_list[i]
# Group by ball query
group_idx = query_ball_point(radius, K, xyz, new_xyz)
grouped_xyz = index_points(xyz, group_idx)
# Convert to coordinates relative to the centroid
grouped_xyz -= new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, group_idx)
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
else:
grouped_points = grouped_xyz
# Swap dimensions to prepare for the convolutions: D feature channels, K points per group
grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
for j in range(len(self.conv_blocks[i])):
conv = self.conv_blocks[i][j]
bn = self.bn_blocks[i][j]
grouped_points = F.relu(bn(conv(grouped_points)))
# After the convolutions, max pool over the points within each group
new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
new_points_list.append(new_points)
new_xyz = new_xyz.permute(0, 2, 1)
new_points_concat = torch.cat(new_points_list, dim=1)  # concatenate along the feature dimension
return new_xyz, new_points_concat
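# --- Sketch of the MSG variant (my own example, much smaller than the values quoted above).
# With points=None, in_channel is 0 and each branch sees only the 3 xyz channels; the branch
# outputs are concatenated along the channel dimension, so D' = 32 + 64 here.
sa_msg_demo = PointNetSetAbstractionMsg(npoint=64, radius_list=[0.1, 0.2],
                                        nsample_list=[8, 16], in_channel=0,
                                        mlp_list=[[16, 32], [32, 64]])
xyz_demo = torch.rand(2, 3, 128)
new_xyz_demo, new_points_demo = sa_msg_demo(xyz_demo, None)
print(new_xyz_demo.shape)     # torch.Size([2, 3, 64])
print(new_points_demo.shape)  # torch.Size([2, 96, 64])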
# Feature upsampling module. When the sparser level has only one point, it is simply repeated N times; otherwise features are upsampled by interpolation over the 3 nearest neighbors, concatenated with the skip features from the corresponding SA level, and passed through a per-point MLP
class PointNetFeaturePropagation(nn.Module):
def __init__(self, in_channel, mlp):
super(PointNetFeaturePropagation, self).__init__()
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm1d(out_channel))
last_channel = out_channel
def forward(self, xyz1, xyz2, points1, points2):
"""
Input:
xyz1: input points position data, [B, C, N]
xyz2: sampled input points position data, [B, C, S]
points1: input points data, [B, D, N]
points2: input points data, [B, D, S]
Return:
new_points: upsampled points data, [B, D', N]
"""
xyz1 = xyz1.permute(0, 2, 1) #[B,N,C]
xyz2 = xyz2.permute(0, 2, 1) #[B,S,C]
points2 = points2.permute(0, 2, 1) #[B,S,D]
B, N, C = xyz1.shape
_, S, _ = xyz2.shape
# If that level has only one point, upsampling is just repeating it N times
if S == 1:
interpolated_points = points2.repeat(1, N, 1)
else:
dists = square_distance(xyz1, xyz2)  # distances between this level's points and the sparser level's points, [B, N, S]
dists, idx = dists.sort(dim=-1)  # ascending sort; for each of the N points keep the 3 closest of the S points
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
dist_recip = 1.0 / (dists + 1e-8)  # reciprocal of the distance: the farther away, the smaller the weight
norm = torch.sum(dist_recip, dim=2, keepdim=True)  # sum of the three nearest points' reciprocal distances
weight = dist_recip / norm  # the three weights, normalized so they sum to 1
# index_points gives [B, N, 3, C]; summing over dim 2 yields the weighted sum of the three neighbors' features, [B, N, C]
interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
if points1 is not None:
points1 = points1.permute(0, 2, 1)
new_points = torch.cat([points1, interpolated_points], dim=-1)
else:
new_points = interpolated_points
new_points = new_points.permute(0, 2, 1)
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
return new_points
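# --- Sketch of the FP layer (my own example, arbitrary sizes): propagate features from S=16
# sampled points back to N=64 points; in_channel = D1 + D2 = 16 + 32 = 48.
fp_demo = PointNetFeaturePropagation(in_channel=48, mlp=[32, 32])
xyz1_demo = torch.rand(2, 3, 64)      # denser level, N = 64
xyz2_demo = torch.rand(2, 3, 16)      # sparser level, S = 16
points1_demo = torch.rand(2, 16, 64)  # skip features from the denser level, [B, D1, N]
points2_demo = torch.rand(2, 32, 16)  # features from the sparser level, [B, D2, S]
print(fp_demo(xyz1_demo, xyz2_demo, points1_demo, points2_demo).shape)  # torch.Size([2, 32, 64])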
All of the main modules from the PointNet++ paper are implemented here. The blogger's annotations are excellent; only a very few explanations are inaccurate, and they hardly get in the way.
2. Interfaces worth borrowing:
(1) Select from the points set the point data at the given index positions:
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S]
Return:
new_points:, indexed points data, [B, S, C]
"""
(2) Compute the N x M matrix of pairwise distances between any two point sets:
def square_distance(src, dst):
"""
Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
"""
(3) The FPS (farthest point sampling) algorithm:
def farthest_point_sample(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
(4) Form the groups around each chosen centroid with a ball query:
Around each of the S sampled points, nsample points are gathered into a group using a spherical (ball) query
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
(5) Implement sample && group by calling the FPS, index_points and query_ball_point functions implemented above
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
"""
Input:
npoint:
radius:
nsample:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, npoint, nsample, 3]
new_points: sampled points data, [B, npoint, nsample, 3+D]
"""
Summary:
In essence, each module in utils implements one of the algorithmic building blocks needed to assemble the model.
For me, the goal of the first pass is to use every means available to genuinely understand each line of code, ideally Feynman-style, turning the author's ideas and knowledge into something of my own.
That said, I don't need to memorize every implementation. On the second pass I will summarize the structure of each algorithm; what matters is being able to reuse them later, so the key is to remember each interface's inputs and outputs, be able to call these interfaces fluently from my own code, and understand the processing that happens inside.
(2) The ModelNetDataLoader.py file
Note: learn how to use the test code in its main block, and get fluent with the argparse argument-parsing pattern.
Also, building the ModelNet40 DataLoader is something worth trying on my own at some point; this time I'm just walking through it.
'''
@author: Xu Yan
@file: ModelNet.py
@time: 2021/3/19 15:51
'''
# Import the required libraries
import os
import numpy as np
import warnings
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
# Normalize the coordinates of all points
def pc_normalize(pc):
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
return pc
# FPS implemented in NumPy (note: unlike the torch version above, this one returns the sampled points themselves rather than their indices)
def farthest_point_sample(point, npoint):
"""
Input:
xyz: pointcloud data, [N, D]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [npoint, D]
"""
N, D = point.shape
xyz = point[:,:3]
centroids = np.zeros((npoint,))
distance = np.ones((N,)) * 1e10
farthest = np.random.randint(0, N)
for i in range(npoint):
centroids[i] = farthest
centroid = xyz[farthest, :]
dist = np.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = np.argmax(distance, -1)
point = point[centroids.astype(np.int32)]
return point
class ModelNetDataLoader(Dataset):
def __init__(self, root, args, split='train', process_data=False):
self.root = root
self.npoints = args.num_point
self.process_data = process_data
self.uniform = args.use_uniform_sample
self.use_normals = args.use_normals
self.num_category = args.num_category
if self.num_category == 10:
self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')
else:
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat))))
shape_ids = {}
if self.num_category == 10:
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]
else:
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
assert (split == 'train' or split == 'test')
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
in range(len(shape_ids[split]))]
print('The size of %s data is %d' % (split, len(self.datapath)))
if self.uniform:
self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))
else:
self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))
if self.process_data:
if not os.path.exists(self.save_path):
print('Processing data %s (only running in the first time)...' % self.save_path)
self.list_of_points = [None] * len(self.datapath)
self.list_of_labels = [None] * len(self.datapath)
for index in tqdm(range(len(self.datapath)), total=len(self.datapath)):
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
if self.uniform:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints, :]
self.list_of_points[index] = point_set
self.list_of_labels[index] = cls
with open(self.save_path, 'wb') as f:
pickle.dump([self.list_of_points, self.list_of_labels], f)
else:
print('Load processed data from %s...' % self.save_path)
with open(self.save_path, 'rb') as f:
self.list_of_points, self.list_of_labels = pickle.load(f)
def __len__(self):
return len(self.datapath)
def _get_item(self, index):
if self.process_data:
point_set, label = self.list_of_points[index], self.list_of_labels[index]
else:
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
label = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
if self.uniform:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints, :]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
if not self.use_normals:
point_set = point_set[:, 0:3]
return point_set, label[0]
def __getitem__(self, index):
return self._get_item(index)
import argparse
def parse_args():
parser = argparse.ArgumentParser('myparser')
parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')
parser.add_argument('--use_normals',action = 'store_true',default= False)
parser.add_argument('--num_category',type = int, default = 40)
return parser.parse_args()
if __name__ == '__main__':
import torch
args = parse_args()
data = ModelNetDataLoader('../data/modelnet40_normal_resampled/', args = args, split='train')
DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
for point, label in DataLoader:
print(point.shape) # 12 x 1024 x 3: 12 is the batch_size, 1024 is the number of points in each cloud, 3 is each point's (x, y, z) coordinates
print(label.shape) # 12 x 1 or 12: the batch_size; for the classification task each cloud needs only one class label
(3) The train_classification.py file
1. What inplace ReLU means and how it is enabled
Turn on the inplace attribute of the ReLU in every module whose class name contains the string 'ReLU'
def inplace_relu(m):
classname = m.__class__.__name__
if classname.find('ReLU') != -1:
m.inplace=True
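For example (a minimal sketch of my own, not from the repo), nn.Module.apply walks the sub-modules recursively, so a single call flips the flag on every ReLU layer:
import torch.nn as nn
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU())
model.apply(inplace_relu)
print(all(m.inplace for m in model.modules() if isinstance(m, nn.ReLU)))  # True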
2. My own detailed annotations of the entire training script are below:
"""
Author: Benny
Date: Nov 2019
"""
# Import the system libraries
import os
import sys
import torch
import numpy as np
import datetime
import logging
import provider
import importlib
import shutil
import argparse
from pathlib import Path
from tqdm import tqdm
from data_utils.ModelNetDataLoader import ModelNetDataLoader
# Path handling, including the global DIR variables: os.path.abspath(__file__) returns this file's absolute path, and sys.path.append makes sure our own modules can be found
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models'))
# Get familiar with the parse_args pattern
import argparse
def parse_args():
'''PARAMETERS'''
parser = argparse.ArgumentParser('training')
parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
parser.add_argument('--model', default='pointnet_cls', help='model name [default: pointnet_cls]')
parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training')
parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training')
parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
parser.add_argument('--log_dir', type=str, default=None, help='experiment root')
parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')
return parser.parse_args()
# Turn on the inplace attribute of the ReLUs
def inplace_relu(m):
classname = m.__class__.__name__
if classname.find('ReLU') != -1:
m.inplace=True
# This is effectively the validation part
# To get my code-reading skills back, I went through the logic below line by line. Finished!
def test(model, loader, num_class=40):
mean_correct = []
class_acc = np.zeros((num_class, 3)) # why num_class x 3? column 0 accumulates each class's per-batch accuracy, column 1 counts the batches containing that class, column 2 holds their ratio, i.e. that class's mean accuracy
# note that class_acc is the per-class accuracy over all instances of each class!
classifier = model.eval()
for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):
if not args.use_cpu:
points, target = points.cuda(), target.cuda()
points = points.transpose(2, 1)
pred, _ = classifier(points) # model predictions
pred_choice = pred.data.max(1)[1] # take the index of the largest logit as the predicted class
for cat in np.unique(target.cpu()):
# The line below splits into two parts: "pred_choice[target == cat]" and "target[target == cat].long().data"
# eq then counts how many of them match, i.e. how many samples of this category were predicted correctly
classacc = pred_choice[target == cat].eq(target[target == cat].long().data).cpu().sum()
class_acc[cat, 0] += classacc.item() / float(points[target == cat].size()[0])
class_acc[cat, 1] += 1
correct = pred_choice.eq(target.long().data).cpu().sum()
mean_correct.append(correct.item() / float(points.size()[0]))
class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1]
class_acc = np.mean(class_acc[:, 2])
instance_acc = np.mean(mean_correct)
return instance_acc, class_acc
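# --- A tiny concrete example (my own toy tensors) of the pred_choice / eq bookkeeping above:
# pred.data.max(1)[1] picks the arg-max class and eq(...).sum() counts the correct predictions.
toy_pred = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])  # logits for 2 samples, 3 classes
toy_target = torch.tensor([1, 2])
toy_choice = toy_pred.data.max(1)[1]                         # tensor([1, 0])
toy_correct = toy_choice.eq(toy_target.long().data).sum()    # tensor(1): only the first is right
print(toy_choice, toy_correct.item())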
# main contains the whole training procedure:
def main(args):
# Log the incoming string str and also print it to the console
def log_string(str):
logger.info(str)
print(str)
# Expose args.gpu (a string such as '0') through the CUDA_VISIBLE_DEVICES environment variable
'''HYPER PARAMETER'''
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# Create the directory variables and set up the corresponding paths; the author handles this very cleanly
'''CREATE DIR'''
timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')) # the current date and time as a string
exp_dir = Path('./log/')
exp_dir.mkdir(exist_ok=True)
exp_dir = exp_dir.joinpath('classification')
exp_dir.mkdir(exist_ok=True)
if args.log_dir is None:
exp_dir = exp_dir.joinpath(timestr)
else:
exp_dir = exp_dir.joinpath(args.log_dir)
exp_dir.mkdir(exist_ok=True)
checkpoints_dir = exp_dir.joinpath('checkpoints/')
checkpoints_dir.mkdir(exist_ok=True)
log_dir = exp_dir.joinpath('logs/')
log_dir.mkdir(exist_ok=True)
# Set up the logging handlers; I'm getting more and more fond of logging! The author's setup here is very tidy
'''LOG'''
args = parse_args()
logger = logging.getLogger("Model")
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
log_string('PARAMETER ...')
log_string(args)
# Build the DataLoaders used later
'''DATA LOADING'''
log_string('Load dataset ...')
data_path = 'data/modelnet40_normal_resampled/'
train_dataset = ModelNetDataLoader(root=data_path, args=args, split='train', process_data=args.process_data)
test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=args.process_data)
trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=10)
# Create the model instance, the optimizer and so on
'''MODEL LOADING'''
num_class = args.num_category
model = importlib.import_module(args.model) # model here is actually the module object for pointnet2_cls_ssg.py
shutil.copy('./models/%s.py' % args.model, str(exp_dir)) # these three lines simply copy the .py files into the log directory
shutil.copy('models/pointnet2_utils.py', str(exp_dir))
shutil.copy('./train_classification.py', str(exp_dir))
# Get the model and loss_func instances from that module object
classifier = model.get_model(num_class, normal_channel=args.use_normals)
criterion = model.get_loss()
classifier.apply(inplace_relu)
if not args.use_cpu:
classifier = classifier.cuda()
criterion = criterion.cuda()
try:
checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
start_epoch = checkpoint['epoch']
classifier.load_state_dict(checkpoint['model_state_dict'])
log_string('Use pretrain model')
except:
log_string('No existing model, starting training from scratch...')
start_epoch = 0
if args.optimizer == 'Adam':
optimizer = torch.optim.Adam(
classifier.parameters(),
lr=args.learning_rate,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=args.decay_rate
)
else:
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
# Global variables updated during training
global_epoch = 0
global_step = 0
best_instance_acc = 0.0
best_class_acc = 0.0
# All preparation is done; training proceeds below, epoch by epoch...
'''TRAINING'''
logger.info('Start training...')
for epoch in range(start_epoch, args.epoch):
# What happens in each epoch:
log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
mean_correct = [] # stores this epoch's per-batch classification accuracies
classifier = classifier.train() # switch the model into training mode
scheduler.step() # advance the learning-rate scheduler by one step
# Process the data batch by batch
for batch_id, (points, target) in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
optimizer.zero_grad() # clear the gradients accumulated in the optimizer at the start of every training step
# This part preprocesses the points, producing the augmented point set
points = points.data.numpy()
points = provider.random_point_dropout(points)
points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
points = torch.Tensor(points)
points = points.transpose(2, 1)
if not args.use_cpu:
points, target = points.cuda(), target.cuda()
# Feed the points into the model, get the predictions, and let the optimizer update the weights based on the loss
pred, trans_feat = classifier(points)
loss = criterion(pred, target.long(), trans_feat) # trans_feat is not used here
pred_choice = pred.data.max(1)[1]
# Compared with the validation code earlier, the only addition is the loss computation and backpropagation
correct = pred_choice.eq(target.long().data).cpu().sum()
mean_correct.append(correct.item() / float(points.size()[0])) # so every entry of mean_correct is one batch's mean classification accuracy
loss.backward()
optimizer.step()
global_step += 1
# Per epoch: averaging over all batches gives the final instance accuracy
train_instance_acc = np.mean(mean_correct)
log_string('Train Instance Accuracy: %f' % train_instance_acc)
# The rest of the per-epoch handling:
with torch.no_grad():
# Run validation to obtain instance_acc and class_acc
instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class)
# Track the best results
if (instance_acc >= best_instance_acc):
best_instance_acc = instance_acc
best_epoch = epoch + 1
if (class_acc >= best_class_acc):
best_class_acc = class_acc
log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))
if (instance_acc >= best_instance_acc):
logger.info('Save model...')
savepath = str(checkpoints_dir) + '/best_model.pth'
log_string('Saving at %s' % savepath)
state = {
'epoch': best_epoch,
'instance_acc': instance_acc,
'class_acc': class_acc,
'model_state_dict': classifier.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(state, savepath)
global_epoch += 1
logger.info('End of training...')
if __name__ == '__main__':
args = parse_args()
main(args)
(4) The test_classification.py file
(Before reading it, my expectation was that the author simply moved test() over from the training script.)
Indeed there is no real difference, especially when the voting mechanism has no effect.
"""
Author: Benny
Date: Nov 2019
"""
# Import the required libraries: the system ones plus our own DataLoader
from data_utils.ModelNetDataLoader import ModelNetDataLoader
import argparse
import numpy as np
import os
import torch
import logging
from tqdm import tqdm
import sys
import importlib
# Set the DIR variables so that the models directory is on the import path
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models'))
# Parse the hyperparameters, same as before
def parse_args():
'''PARAMETERS'''
parser = argparse.ArgumentParser('Testing')
parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
parser.add_argument('--log_dir', type=str, required=True, help='Experiment root')
parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')
parser.add_argument('--num_votes', type=int, default=3, help='Aggregate classification scores with voting')
return parser.parse_args()
def test(model, loader, num_class=40, vote_num=1):
mean_correct = []
classifier = model.eval()
class_acc = np.zeros((num_class, 3))
for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):
if not args.use_cpu:
points, target = points.cuda(), target.cuda()
points = points.transpose(2, 1)
# ------Apart from this voting part, everything matches the validation code inside the training script:
# with vote_num set to 1 the voting mechanism makes no difference at all
vote_pool = torch.zeros(target.size()[0], num_class).cuda()
# The same points are passed through the model vote_num times and the logits are averaged
for _ in range(vote_num):
pred, _ = classifier(points)
vote_pool += pred
pred = vote_pool / vote_num
pred_choice = pred.data.max(1)[1]
for cat in np.unique(target.cpu()):
classacc = pred_choice[target == cat].eq(target[target == cat].long().data).cpu().sum()
class_acc[cat, 0] += classacc.item() / float(points[target == cat].size()[0])
class_acc[cat, 1] += 1
correct = pred_choice.eq(target.long().data).cpu().sum()
mean_correct.append(correct.item() / float(points.size()[0]))
class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1]
class_acc = np.mean(class_acc[:, 2])
instance_acc = np.mean(mean_correct)
return instance_acc, class_acc
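# --- The voting above is just logit averaging over vote_num forward passes (my own sketch with
# random stand-in logits). Since the model is in eval() mode and the very same points are fed
# each time, all votes are identical in this script, which is why it makes no real difference.
vote_num_demo, batch_size_demo, num_class_demo = 3, 4, 40
vote_pool_demo = torch.zeros(batch_size_demo, num_class_demo)
for _ in range(vote_num_demo):
    logits_demo = torch.randn(batch_size_demo, num_class_demo)  # stand-in for classifier(points)[0]
    vote_pool_demo += logits_demo
pred_demo = vote_pool_demo / vote_num_demo
print(pred_demo.data.max(1)[1])                                 # one predicted class per sample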
def main(args):
def log_string(str):
logger.info(str)
print(str)
'''HYPER PARAMETER'''
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
'''CREATE DIR'''
experiment_dir = 'log/classification/' + args.log_dir
'''LOG'''
args = parse_args()
logger = logging.getLogger("Model")
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
log_string('PARAMETER ...')
log_string(args)
'''DATA LOADING'''
log_string('Load dataset ...')
data_path = 'data/modelnet40_normal_resampled/'
test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=False)
testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=10)
'''MODEL LOADING'''
num_class = args.num_category
model_name = os.listdir(experiment_dir + '/logs')[0].split('.')[0]
model = importlib.import_module(model_name)
classifier = model.get_model(num_class, normal_channel=args.use_normals)
if not args.use_cpu:
classifier = classifier.cuda()
checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
classifier.load_state_dict(checkpoint['model_state_dict'])
with torch.no_grad():
instance_acc, class_acc = test(classifier.eval(), testDataLoader, vote_num=args.num_votes, num_class=num_class)
log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
if __name__ == '__main__':
args = parse_args()
main(args)
(5) The provider.py file
This file only uses NumPy; the functions it defines are the ones used for point-cloud data augmentation.
import numpy as np
# Normalize a whole batch of data -- the coordinates are re-centered at the origin
def normalize_data(batch_data):
""" Normalize the batch data, use coordinates of the block centered at origin,
Input:
BxNxC array
Output:
BxNxC array
"""
B, N, C = batch_data.shape
normal_data = np.zeros((B, N, C))
for b in range(B):
pc = batch_data[b]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
normal_data[b] = pc
return normal_data
# Shuffles along the batch dimension
def shuffle_data(data, labels):
""" Shuffle data and labels.
Input:
data: B,N,... numpy array
label: B,... numpy array
Return:
shuffled data, label and shuffle indices
"""
idx = np.arange(len(labels))
np.random.shuffle(idx)
return data[idx, ...], labels[idx], idx
# Shuffles the points within each cloud of a batch
def shuffle_points(batch_data):
""" Shuffle orders of points in each point cloud -- changes FPS behavior.
Use the same shuffling idx for the entire batch.
Input:
BxNxC array
Output:
BxNxC array
"""
idx = np.arange(batch_data.shape[1])
np.random.shuffle(idx)
return batch_data[:,idx,:]
# Just like a rotation in the 2D plane, a 3D rotation is simply a matrix multiplication
def rotate_point_cloud(batch_data):
""" Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data
def rotate_point_cloud_z(batch_data):
""" Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, sinval, 0],
[-sinval, cosval, 0],
[0, 0, 1]])
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data
def rotate_point_cloud_with_normal(batch_xyz_normal):
''' Randomly rotate XYZ, normal point cloud.
Input:
batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
Output:
B,N,6, rotated XYZ, normal point cloud
'''
for k in range(batch_xyz_normal.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_xyz_normal[k,:,0:3]
shape_normal = batch_xyz_normal[k,:,3:6]
batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations
Input:
BxNx6 array, original batch of point clouds and point normals
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1,0,0],
[0,np.cos(angles[0]),-np.sin(angles[0])],
[0,np.sin(angles[0]),np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
[0,1,0],
[-np.sin(angles[1]),0,np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
[np.sin(angles[2]),np.cos(angles[2]),0],
[0,0,1]])
R = np.dot(Rz, np.dot(Ry,Rx))
shape_pc = batch_data[k,:,0:3]
shape_normal = batch_data[k,:,3:6]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
return rotated_data
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k,:,0:3]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle.
Input:
BxNx6 array, original batch of point clouds with normal
scalar, angle of rotation
Return:
BxNx6 array, rotated batch of point clouds iwth normal
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k,:,0:3]
shape_normal = batch_data[k,:,3:6]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix)
return rotated_data
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1,0,0],
[0,np.cos(angles[0]),-np.sin(angles[0])],
[0,np.sin(angles[0]),np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
[0,1,0],
[-np.sin(angles[1]),0,np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
[np.sin(angles[2]),np.cos(angles[2]),0],
[0,0,1]])
R = np.dot(Rz, np.dot(Ry,Rx))
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
return rotated_data
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
""" Randomly jitter points. jittering is per point.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, jittered batch of point clouds
"""
B, N, C = batch_data.shape
assert(clip > 0)
jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
jittered_data += batch_data
return jittered_data
# shift_point_cloud below, together with random_scale and random_dropout, is also used for data augmentation during training
def shift_point_cloud(batch_data, shift_range=0.1):
""" Randomly shift point cloud. Shift is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, shifted batch of point clouds
"""
B, N, C = batch_data.shape
shifts = np.random.uniform(-shift_range, shift_range, (B,3))
for batch_index in range(B):
batch_data[batch_index,:,:] += shifts[batch_index,:]
return batch_data
# Each batch element draws one scale factor and multiplies all of its point coordinates by it
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
""" Randomly scale the point cloud. Scale is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, scaled batch of point clouds
"""
B, N, C = batch_data.shape
scales = np.random.uniform(scale_low, scale_high, B)
for batch_index in range(B):
batch_data[batch_index,:,:] *= scales[batch_index]
return batch_data
# Randomly drop a number of points. The implementation below is not the most intuitive; there are other ways to do it, but to keep the point count at 1024
# the author simply sets every dropped point's coordinates to those of the first point
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
''' batch_pc: BxNx3 '''
for b in range(batch_pc.shape[0]):
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]
if len(drop_idx)>0:
batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point
return batch_pc
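# --- Tying this back to the training loop (my own sketch): the augmentations actually used in
# train_classification.py are dropout, random scaling and shifting, applied to the NumPy batch
# before it is converted back to a tensor and transposed to channel-first layout.
import torch
batch_demo = np.random.rand(4, 1024, 3).astype(np.float32)   # stand-in for points.data.numpy()
batch_demo = random_point_dropout(batch_demo)
batch_demo[:, :, 0:3] = random_scale_point_cloud(batch_demo[:, :, 0:3])
batch_demo[:, :, 0:3] = shift_point_cloud(batch_demo[:, :, 0:3])
points_demo = torch.Tensor(batch_demo).transpose(2, 1)       # [B, 3, N], ready for the model
print(points_demo.shape)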