网格上低分辨率的分割结果到高分辨率的投影与可视化
- 引言
- 一、到高分辨率的投影
- 1.1 准确率
- 1.2 主要代码
- 1.3 投影核心代码
- 二、可视化代码
引言
三角网格的结构特性决定了其仅用少量三角形即可表示一个完整的3D模型。增加其分辨率可以展示更多模型的形状细节。对于网格分割
来说,并不需要很多模型细节,只需要知晓其数据元素所属部分(类别)即可。
- 上图分别为低分辨率分割结果、高分辨率投影结果以及Ground truth
在简化网格上进行预测,然后投影到高分辨率网格上是一个可行的方案。例如:
MeshWalker1使用的边界平滑
A Spectral Segmentation Method for Large Meshes2的feature-aware的网格简化
一、到高分辨率的投影
1.1 准确率
以面标签版本的COSEG外星人数据集为例,可参考三角网格(Triangular Mesh)分割数据集
简化网格上的准确率:96.94
到高分辨率网格投影:95.53
时间上也会快很多,毕竟计算高分辨率网格的输入特征较为费时
1.2 主要代码
部分代码来自3:MeshCNN
TriTransNet
是对简化三角网格进行分割的网络,可替换为其它神经网络
import potpourri3d as pp3d
import numpy as np
import os
import pickle
from scipy.spatial import cKDTree
import time
import torch
from config.config import Config
from network.TriTransNet import TriTransNet
from postprocessing.mesh_project import get_faces_BorderPoints
def is_mesh_file(filename):
    """Return True if *filename* has a supported mesh extension (.obj / .off).

    Fix: the original extension list contained 'off' without the leading
    dot, so any filename merely ending in "off" (e.g. "standoff") matched.
    """
    return filename.endswith(('.obj', '.off'))
def fix_vertices(vs):
    """Normalize vertex positions in place.

    Swaps the y and z columns, translates each axis so it starts at 0,
    and uniformly scales so the largest bounding-box extent becomes 1.
    Returns the same (mutated) array.
    """
    # Swap y <-> z (fancy indexing on the RHS makes a copy, so this is safe).
    vs[:, [1, 2]] = vs[:, [2, 1]]
    per_axis_min = vs.min(axis=0)
    per_axis_range = vs.max(axis=0) - per_axis_min
    vs -= per_axis_min
    # Uniform scale: divide every axis by the largest extent.
    vs /= per_axis_range.max()
    return vs
def get_seg_files(paths, seg_dir, seg_ext='.eseg'):
    """Return the label-file path in *seg_dir* matching each mesh in *paths*.

    Each mesh's basename (extension replaced by *seg_ext*) must exist in
    *seg_dir*; a missing file triggers an AssertionError.
    """
    seg_files = []
    for mesh_path in paths:
        stem = os.path.splitext(os.path.basename(mesh_path))[0]
        candidate = os.path.join(seg_dir, stem + seg_ext)
        assert os.path.isfile(candidate)
        seg_files.append(candidate)
    return seg_files
def make_dataset(path):
    """Recursively collect the paths of all mesh files under *path*."""
    assert os.path.isdir(path), '%s is not a valid directory' % path
    found = []
    for root, _, fnames in sorted(os.walk(path)):
        # Keep only files whose extension marks them as meshes.
        found.extend(os.path.join(root, fname)
                     for fname in fnames if is_mesh_file(fname))
    return found
# Driver: predict labels on simplified meshes, project them onto the
# original high-resolution meshes, and report both accuracies.
if __name__ == '__main__':
    # Simplified (low-resolution) meshes.
    sim_root = '../../../datasets/face_label/coseg_aliens'
    sim_paths = make_dataset(os.path.join(sim_root, 'test'))
    # sim_labels = get_seg_files(sim_paths, seg_dir=os.path.join(sim_root, 'seg'))
    # Original (high-resolution) meshes and their ground-truth labels.
    org_root = '../../../datasets/aliens' # '../../datasets/vases'
    org_paths = make_dataset(os.path.join(org_root, 'test')) # shapes or seg
    org_labels = get_seg_files(org_paths, seg_dir=os.path.join(org_root, 'seg'), seg_ext='.seg')
    # Build the segmentation network and load trained weights.
    cfg = Config()
    cfg.class_n = 4
    cfg.mode = 'seg'
    net = TriTransNet(cfg)
    state_dict = torch.load('../../../results/aliens_1500/model/latest_xyz_net.pth') # latest_xyz_net 95.53432
    if hasattr(state_dict, '_metadata'):
        del state_dict._metadata
    net.load_state_dict(state_dict)
    net.eval()
    # Accuracy accumulators.
    all_acc = 0  # projected accuracy on the high-resolution meshes
    sim_acc = 0  # accuracy on the simplified meshes
    are_acc = 0  # (unused) area-weighted accuracy, see commented block below
    for i in range(len(sim_paths)):
        # Locate the cached preprocessed data for this sample.
        sim_name = sim_paths[i]
        filename, _ = os.path.splitext(sim_name)
        prefix = os.path.basename(filename)
        cache = os.path.join('../../../results/aliens_1500/cache/', prefix + '.pkl')
        with open(cache, 'rb') as f:  # load the cache instead of recomputing features
            meta = pickle.load(f)
        # Mesh data from the cache.
        sim_mesh = meta['mesh']
        sim_label = meta['label']
        vs = fix_vertices(sim_mesh.vs)
        # Predict per-face labels on the simplified mesh.
        with torch.no_grad():
            face_features = np.concatenate([sim_mesh.face_features, sim_mesh.xyz], axis = 0) # sim_mesh.hks[0:3]
            face_features = torch.from_numpy(face_features).float().unsqueeze(0)
            out = net(face_features, [sim_mesh])
            label = out.data.max(1)[1]
        sim_correct = label.eq(torch.from_numpy(sim_label).long()).sum().float() / sim_mesh.faces_num
        sim_acc += sim_correct
        # Area-weighted accuracy (disabled).
        # idex = label.eq(torch.from_numpy(sim_label).long()).numpy().reshape(-1)
        # face_area = sim_mesh.face_features[6, :]
        # sum_area = face_area.sum()
        # are_acc += face_area[idex].sum() / sum_area
        # Time only the projection step.
        t = time.time()
        # Build labelled border points on the simplified mesh for projection.
        label = label.numpy().reshape(-1)
        BorderPoints_xyz, BorderPoints_label = get_faces_BorderPoints(vs, sim_mesh.faces, label, border_k=0.01, border_num=10)
        # border_k=0.01, border_num=10 -> 95.53432
        # border_k=0.5, border_num=1 degenerates to plain nearest-neighbour -> 94.02
        kdt = cKDTree(BorderPoints_xyz)
        # Load the high-resolution mesh and its ground truth.
        org_vs, org_faces = pp3d.read_mesh(org_paths[i])
        org_vs = fix_vertices(org_vs)
        org_label = np.loadtxt(open(org_labels[i], 'r'), dtype='float64') -1
        # Face centroids of the original mesh.
        mean_vs = org_vs[org_faces]
        mean_vs = mean_vs.sum(axis=1) / 3.0
        dist, indices = kdt.query(mean_vs, workers=-1)
        # Projected accuracy: label of the nearest border point vs ground truth.
        org_prolabels = BorderPoints_label[indices].reshape(-1)
        pro_cnt = np.equal(org_prolabels, org_label).sum()
        pro_acc = pro_cnt / len(org_label)
        all_acc += pro_acc
        print(filename, ':', pro_acc, ' time:', time.time()-t)
    print(all_acc / len(sim_paths))
    print(sim_acc / len(sim_paths))
    # print(are_acc / len(sim_paths))
1.3 投影核心代码
def get_faces_BorderPoints(vs, faces, labels, border_k=0.1, border_num=1):
    """Sample labelled "border points" just inside each triangle edge.

    Assumes simplification roughly preserves segmentation boundaries, so the
    simplified mesh and the original mesh are approximately aligned:
    1. Simplified faces are larger, so each face's edges are sampled
       uniformly; every sample inherits the face's label.
    2. An original face is later labelled by the nearest border point to
       its centroid (done by the caller via a KD-tree).

    Parameters
    ----------
    vs : (V, 3) float array of vertex positions
    faces : (F, 3) int array of vertex indices
    labels : (F,) per-face labels
    border_k : how far each sample is pulled from the edge toward the
        opposite vertex (0 = on the edge)
    border_num : number of samples per edge

    Returns
    -------
    (F*3*border_num, 3) float64 point array and (F*3*border_num, 1) int32
    label array.

    Note: the original code special-cased border_num == 1 with the explicit
    midpoint formula; that is exactly the general formula at j = 0, so the
    duplicated branch has been removed.
    """
    n_points = len(faces) * 3 * border_num
    BorderPoints_xyz = -np.ones((n_points, 3), np.float64)
    BorderPoints_label = -np.ones((n_points, 1), np.int32)
    cnt = 0
    for face_id in range(len(faces)):
        face = faces[face_id]
        label = labels[face_id]
        for i in range(3):
            # Edge (p1, p2) with p the opposite vertex of the triangle.
            p1, p2, p = vs[face[i]], vs[face[(i + 1) % 3]], vs[face[(i + 2) % 3]]
            for j in range(border_num):
                # Uniform sample along the edge, excluding the endpoints...
                center_p = p1 + (p2 - p1) / (border_num + 1) * (j + 1)
                # ...then pulled slightly toward the interior of the face.
                BorderPoints_xyz[cnt] = center_p + (p - center_p) * border_k
                BorderPoints_label[cnt] = label
                cnt = cnt + 1
    return BorderPoints_xyz, BorderPoints_label
二、可视化代码
减小可视化网格边的边长,查看模型细节:
import potpourri3d as pp3d
import numpy as np
import os
import pickle
from scipy.spatial import cKDTree
import time
import pylab as pl
import torch
from config.config import Config
from network.TriTransNet import TriTransNet
from postprocessing.mesh_project import get_faces_BorderPoints
import mpl_toolkits.mplot3d as a3
import matplotlib.colors as colors
from scipy import linalg
def rot_vs_axis_z(vs, radian, scale):
    """Scale vertices about their mean, rotate them around the z-axis by
    *radian*, and return a new array (input is not modified in place).

    NOTE(review): np.mean(vs) with no axis argument is a single scalar over
    all coordinates, not the per-axis centroid — np.mean(vs, axis=0) may
    have been intended; confirm before changing, the plots depend on it.
    """
    bias = np.mean(vs)
    vs = vs - bias
    vs *= scale
    # Rotation matrix via the matrix exponential of the skew-symmetric
    # cross-product matrix of the (already unit-length) z-axis times angle.
    rot_matrix = linalg.expm(np.cross(np.eye(3), [0, 0, 1] / linalg.norm([0, 0, 1]) * radian))
    vs = np.dot(rot_matrix, vs.T)
    vs = vs.T + bias
    return vs
def init_ax(ax):
    """Hide the panes, axis lines and ticks of a 3D axes; return the axes.

    Fix: the ax.w_xaxis / w_yaxis / w_zaxis accessors were deprecated and
    removed in matplotlib 3.8 — use ax.xaxis / ax.yaxis / ax.zaxis instead.
    """
    # hide axis, thanks to
    # https://stackoverflow.com/questions/29041326/3d-plot-with-matplotlib-hide-axes-but-keep-axis-labels/
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        # Transparent pane background.
        axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
        # Get rid of the spines.
        axis.line.set_color((1.0, 1.0, 1.0, 0.0))
    # Get rid of the ticks.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    return ax
def is_mesh_file(filename):
    """Return True if *filename* has a supported mesh extension (.obj / .off).

    Fix: the original extension list contained 'off' without the leading
    dot, so any filename merely ending in "off" (e.g. "standoff") matched.
    """
    return filename.endswith(('.obj', '.off'))
def fix_vertices(vs):
    """Swap y and z in place, shift each axis to start at 0, and scale the
    mesh so its largest bounding-box edge has length 1. Returns *vs*."""
    # y <-> z column swap (RHS fancy indexing copies, so no aliasing issue).
    vs[:, [1, 2]] = vs[:, [2, 1]]
    lo = vs.min(axis=0)
    extent = vs.max(axis=0) - lo
    vs -= lo
    vs /= extent.max()
    return vs
def get_seg_files(paths, seg_dir, seg_ext='.eseg'):
    """For each mesh in *paths*, return the matching label file in *seg_dir*.

    The label file is the mesh basename with its extension replaced by
    *seg_ext*; every label file must exist (AssertionError otherwise).
    """
    def _seg_path(mesh_path):
        # Mesh basename with its extension swapped for the label extension.
        base = os.path.splitext(os.path.basename(mesh_path))[0]
        return os.path.join(seg_dir, base + seg_ext)

    seg_files = [_seg_path(p) for p in paths]
    for seg_file in seg_files:
        assert os.path.isfile(seg_file)
    return seg_files
def make_dataset(path):
    """Walk *path* recursively and return every mesh file found."""
    assert os.path.isdir(path), '%s is not a valid directory' % path
    collected = []
    for root, _, fnames in sorted(os.walk(path)):
        for fname in fnames:
            # Skip anything that is not a mesh file.
            if not is_mesh_file(fname):
                continue
            collected.append(os.path.join(root, fname))
    return collected
# Driver: predict labels on one simplified mesh, project them onto the
# high-resolution mesh, and render simplified / projected / ground-truth
# results side by side with matplotlib.
if __name__ == '__main__':
    # Simplified (low-resolution) meshes.
    sim_root = '../../../datasets/face_label/coseg_aliens'
    sim_paths = make_dataset(os.path.join(sim_root, 'test'))
    # sim_labels = get_seg_files(sim_paths, seg_dir=os.path.join(sim_root, 'seg'))
    # Original (high-resolution) meshes and their ground-truth labels.
    org_root = '../../../datasets/aliens' # '../../datasets/vases'
    org_paths = make_dataset(os.path.join(org_root, 'test')) # shapes or seg
    org_labels = get_seg_files(org_paths, seg_dir=os.path.join(org_root, 'seg'), seg_ext='.seg')
    # Build the segmentation network and load trained weights.
    cfg = Config()
    cfg.class_n = 4
    cfg.mode = 'seg'
    net = TriTransNet(cfg)
    state_dict = torch.load('../../../results/aliens_1500/model/latest_xyz_net.pth') # latest_xyz_net 95.53432
    if hasattr(state_dict, '_metadata'):
        del state_dict._metadata
    net.load_state_dict(state_dict)
    net.eval()
    # Accuracy accumulators.
    all_acc = 0  # projected accuracy on the high-resolution meshes
    sim_acc = 0  # accuracy on the simplified meshes
    are_acc = 0  # unused here
    for i in range(len(sim_paths)):
        # Locate the cached preprocessed data for this sample.
        sim_name = sim_paths[i]
        filename, _ = os.path.splitext(sim_name)
        prefix = os.path.basename(filename)
        # Visualize only one chosen mesh.
        #if prefix != '132':
        #    continue
        if i != 3:
            continue
        cache = os.path.join('../../../results/aliens_1500/cache/', prefix + '.pkl')
        with open(cache, 'rb') as f:  # load the cache instead of recomputing features
            meta = pickle.load(f)
        # Mesh data from the cache.
        sim_mesh = meta['mesh']
        sim_label = meta['label']
        vs = fix_vertices(sim_mesh.vs)
        # Predict per-face labels on the simplified mesh.
        with torch.no_grad():
            face_features = np.concatenate([sim_mesh.face_features, sim_mesh.xyz], axis = 0) # sim_mesh.hks[0:3]
            face_features = torch.from_numpy(face_features).float().unsqueeze(0)
            out = net(face_features, [sim_mesh])
            label = out.data.max(1)[1]
        sim_correct = label.eq(torch.from_numpy(sim_label).long()).sum().float() / sim_mesh.faces_num
        sim_acc += sim_correct
        # Time only the projection step.
        t = time.time()
        # Build labelled border points on the simplified mesh for projection.
        label = label.numpy().reshape(-1)
        BorderPoints_xyz, BorderPoints_label = get_faces_BorderPoints(vs, sim_mesh.faces, label, border_k=0.01, border_num=10)
        # border_k=0.01, border_num=10 -> 95.53432
        # border_k=0.5, border_num=1 degenerates to plain nearest-neighbour -> 94.02
        kdt = cKDTree(BorderPoints_xyz)
        # Load the high-resolution mesh and its ground truth.
        org_vs, org_faces = pp3d.read_mesh(org_paths[i])
        org_vs = fix_vertices(org_vs)
        org_label = np.loadtxt(open(org_labels[i], 'r'), dtype='float64') -1
        # Face centroids of the original mesh.
        mean_vs = org_vs[org_faces]
        mean_vs = mean_vs.sum(axis=1) / 3.0
        dist, indices = kdt.query(mean_vs, workers=-1)
        # Projected accuracy: label of the nearest border point vs ground truth.
        org_prolabels = BorderPoints_label[indices].reshape(-1)
        pro_cnt = np.equal(org_prolabels, org_label).sum()
        pro_acc = pro_cnt / len(org_label)
        all_acc += pro_acc
        print(filename, ':', pro_acc, ' time:', time.time()-t)
        # ---- Visualization ----
        f = pl.figure()
        ax = f.add_subplot(1, 1, 1, projection='3d')
        ax = init_ax(ax)
        # RGB tuple (0-255) -> matplotlib hex color string.
        r2h = lambda x: colors.rgb2hex(tuple(map(lambda y: y / 255., x)))
        f_colors = [r2h((0, 0, 255)), r2h((0, 255, 255)), r2h((255, 0, 255)), r2h((0, 255, 0))]
        vis_bias = 0.3  # horizontal gap between the three rendered meshes
        # Simplified mesh with predicted labels.
        faces_color = []
        for l in label:
            faces_color.append(f_colors[l - 1])
        vs = rot_vs_axis_z(vs, 0.95, 1)
        tri = a3.art3d.Poly3DCollection(vs[sim_mesh.faces],
                                        facecolors=faces_color,
                                        edgecolors=r2h((0, 0, 0)),
                                        linewidths=0.1, # 0.1
                                        # linestyles='dashdot',
                                        alpha=1)
        ax.add_collection3d(tri)
        # High-resolution mesh with projected labels, shifted right.
        org_vs = rot_vs_axis_z(org_vs, 0.95, 1)
        faces_color = []
        for l in org_prolabels.astype(int):
            faces_color.append(f_colors[l - 1])
        org_vs[:, 0] += vs[:, 0].max()/2 + vis_bias
        tri1 = a3.art3d.Poly3DCollection(org_vs[org_faces],
                                         facecolors=faces_color,
                                         edgecolors=r2h((0, 0, 0)),
                                         linewidths=0.1,
                                         # linestyles='dashdot',
                                         alpha=1)
        ax.add_collection3d(tri1)
        max_x = org_vs[:, 0].max()
        # High-resolution mesh with ground-truth labels, shifted further right.
        faces_color = []
        for l in org_label.astype(int):
            faces_color.append(f_colors[l - 1])
        org_vs[:, 0] += vs[:, 0].max() / 2 + vis_bias
        tri2 = a3.art3d.Poly3DCollection(org_vs[org_faces],
                                         facecolors=faces_color,
                                         edgecolors=r2h((0, 0, 0)),
                                         linewidths=0.1,
                                         # linestyles='dashdot',
                                         alpha=1)
        ax.add_collection3d(tri2)
        max_x = org_vs[:, 0].max()
        ax.auto_scale_xyz([0, max_x], [0, max_x], [0, max_x])
        ax.view_init(0, -90)
        pl.tight_layout()
        pl.savefig('corr.png', dpi=1000) # i
        pl.show()
        break
    print(all_acc / len(sim_paths))
    print(sim_acc / len(sim_paths))
MeshWalker: Deep Mesh Understanding by Random Walks ↩︎
A Spectral Segmentation Method for Large Meshes ↩︎
MeshCNN ↩︎