1 Project Structure
Source code: https://github.com/yenchenlin/nerf-pytorch
In the framework diagram above, the file parameters are first read by config_parser;
the data are then loaded via load_blender_data, which returns the training, validation, and test splits together with the camera intrinsics and extrinsics.
In create_nerf, get_embedder builds the positional encodings for the 3D point positions and the viewing directions, and the MLP layers of the NeRF model are initialized;
get_rays_np computes the ray origins rays_o and directions rays_d.
At render time, LLFF (forward-facing) scenes additionally call ndc_rays to transform the rays into NDC space;
in batchify_rays, the chunk size controls how many rays are processed at a time, bounding memory use.
In run_network, the positional encodings of the points sampled along each ray (derived from rays_o and rays_d) are fed into the MLP, which predicts a color and a volume density per point.
Finally, raw2outputs composites the predicted densities and colors via volume rendering to produce the image, from which the loss is computed.
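Put together, the top-level flow in train() reduces to roughly the following condensed sketch (function names and return signatures are taken from the repository; the argument plumbing is abbreviated, and H, W, K, batch_rays, and target_s are assumed to have been prepared earlier):

parser = config_parser()
args = parser.parse_args()
# load images, poses, and camera parameters for the chosen dataset
images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir)
# build the positional encoders and the coarse/fine MLPs
render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
# one training step: render a batch of rays and regress to the pixel colors
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk,
                                rays=batch_rays, **render_kwargs_train)
loss = img2mse(rgb, target_s)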
2 Ray Direction Code Walkthrough
def get_rays_np(H, W, K, c2w):
i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
dirs = np.stack([ (i-K[0][2]) / K[0][0],
-(j-K[1][2]) / K[1][1],
-np.ones_like(i) ], -1)
# rotate ray directions from camera frame to the world frame
rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
return rays_o, rays_d
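As a quick sanity check, one can assemble the intrinsic matrix K from the focal length and call get_rays_np directly (a minimal sketch; the image size, focal length, and identity pose below are made-up values). Note the Blender/OpenGL convention at work in dirs: the camera looks down its own -z axis with y pointing up, which is why the y and z components are negated.

import numpy as np

H, W, focal = 400, 400, 555.0            # hypothetical image size and focal length
K = np.array([[focal, 0.,    0.5 * W],
              [0.,    focal, 0.5 * H],
              [0.,    0.,    1.     ]])
c2w = np.eye(4)[:3, :4]                  # identity camera-to-world pose

rays_o, rays_d = get_rays_np(H, W, K, c2w)
print(rays_o.shape, rays_d.shape)        # (400, 400, 3) (400, 400, 3)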
3 Basic Rendering Pipeline
def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
near=0., far=1.,
use_viewdirs=False, c2w_staticcam=None,
**kwargs):
if c2w is not None:
# special case to render full image
rays_o, rays_d = get_rays(H, W, K, c2w)
else:
# use provided ray batch
rays_o, rays_d = rays
    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
if c2w_staticcam is not None:
# special case to visualize effect of viewdirs
rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
viewdirs = torch.reshape(viewdirs, [-1,3]).float()
sh = rays_d.shape # shape: … × 3
# for forward facing scenes
if ndc:
rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
# create ray batch
rays_o = torch.reshape(rays_o, [-1,3]).float()
rays_d = torch.reshape(rays_d, [-1,3]).float()
near, far = near * torch.ones_like(rays_d[...,:1]), \
far * torch.ones_like(rays_d[...,:1])
rays = torch.cat([rays_o, rays_d, near, far], -1)
if use_viewdirs:
rays = torch.cat([rays, viewdirs], -1)
# render and reshape
all_ret = batchify_rays(rays, chunk, **kwargs)
for k in all_ret:
k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
all_ret[k] = torch.reshape(all_ret[k], k_sh)
k_extract = ['rgb_map', 'disp_map', 'acc_map']
ret_list = [all_ret[k] for k in k_extract]
ret_dict = {k : all_ret[k] for k in all_ret
if k not in k_extract}
return ret_list + [ret_dict]
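For reference, rendering a full image from a single pose is a single call (a sketch; pose is assumed to come from the data loader and render_kwargs_test from create_nerf):

with torch.no_grad():
    rgb, disp, acc, extras = render(H, W, K, chunk=1024*32,
                                    c2w=pose[:3, :4],
                                    **render_kwargs_test)
# rgb: (H, W, 3) image; disp: inverse depth; acc: accumulated opacity;
# extras collects the remaining maps, e.g. the coarse pass 'rgb0'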
3.1 batchify_rays Code Walkthrough
def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
""" render rays in smaller minibatches to avoid OOM
"""
all_ret = {}
for i in range(0, rays_flat.shape[0], chunk):
ret = render_rays(rays_flat[i:i+chunk], **kwargs)
for k in ret:
if k not in all_ret:
all_ret[k] = []
all_ret[k].append(ret[k])
all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}
return all_ret
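The chunking pattern itself is generic: slice the flattened batch, process each slice, and concatenate along the leading dimension. A toy sketch on dummy data (the 11 columns correspond to rays_o, rays_d, near, far, and viewdirs):

import torch

rays_flat = torch.rand(100000, 11)       # dummy flattened ray batch
chunk = 1024 * 32
outs = [rays_flat[i:i+chunk, :3] * 2.0   # stand-in for render_rays
        for i in range(0, rays_flat.shape[0], chunk)]
out = torch.cat(outs, 0)
assert out.shape[0] == rays_flat.shape[0]

A smaller chunk lowers peak memory at the cost of more forward passes; the results are identical either way.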
3.2 render_rays Code Walkthrough
def render_rays(ray_batch,
network_fn,
network_query_fn,
N_samples,
retraw=False,
lindisp=False,
                perturb=0.,   # overridden to 1.0 by the training config
N_importance=0,
network_fine=None,
white_bkgd=False,
raw_noise_std=0.,
verbose=False,
pytest=False):
N_rays = ray_batch.shape[0]
rays_o, rays_d = ray_batch[:,0:3], \
ray_batch[:,3:6] # (ray #, 3)
viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 \
else None
bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
near, far = bounds[...,0], \
bounds[...,1] # (ray #, 1)
t_vals = torch.linspace(0., 1., steps=N_samples)
if not lindisp:
z_vals = near * (1. - t_vals) + far * t_vals
else:
z_vals = 1. / (1./near * (1. - t_vals) +
1./far * ( t_vals) )
# copy sample distances of 1 ray to the others
z_vals = z_vals.expand([N_rays, N_samples])
if perturb > 0.:
# get intervals between samples
mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
upper = torch.cat([mids, z_vals[...,-1:]], -1)
lower = torch.cat([z_vals[...,:1], mids], -1)
# stratified samples in those intervals
t_rand = torch.rand(z_vals.shape)
# pytest: overwrite U with fixed NumPy random numbers
if pytest:
np.random.seed(0)
t_rand = np.random.rand(*list(z_vals.shape))
t_rand = torch.Tensor(t_rand)
z_vals = lower + (upper - lower) * t_rand
pts = rays_o[..., None, :] + \
rays_d[..., None, :] * z_vals[..., :, None] # (ray #, sample #, 3)
#raw = run_network(pts)
raw = network_query_fn(pts, viewdirs, network_fn)
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
# hierarchical sampling
if N_importance > 0:
# log outputs of coarse network
rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
z_vals_mid = .5 * (z_vals[..., 1: ] + z_vals[..., :-1])
z_samples = sample_pdf(z_vals_mid,
weights[..., 1:-1],
N_importance,
                               det=(perturb==0.), # False during training, where perturb > 0
pytest=pytest)
z_samples = z_samples.detach()
z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
pts = rays_o[..., None, :] + \
rays_d[..., None, :] * z_vals[..., :, None] # (ray #, coarse & fine sample #, 3)
run_fn = network_fn if network_fine is None \
else network_fine
#raw = run_network(pts, fn=run_fn)
raw = network_query_fn(pts, viewdirs, run_fn)
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
ret = {'rgb_map' : rgb_map,
'disp_map': disp_map,
'acc_map' : acc_map}
if retraw:
ret['raw'] = raw
if N_importance > 0:
ret['rgb0' ] = rgb_map_0
ret['disp0'] = disp_map_0
ret['acc0' ] = acc_map_0
ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # (ray #)
for k in ret:
if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
print(f"! [Numerical Error] {k} contains nan or inf.")
return ret
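The stratified sampling step is easiest to see in isolation: linspace spaces the samples evenly between near and far, and the perturbation jitters each sample uniformly within its own interval. A standalone numeric sketch with made-up bounds:

import torch

near, far, N_samples = 2.0, 6.0, 8
t_vals = torch.linspace(0., 1., steps=N_samples)
z_vals = near * (1. - t_vals) + far * t_vals      # evenly spaced depths

mids  = .5 * (z_vals[1:] + z_vals[:-1])
upper = torch.cat([mids, z_vals[-1:]])
lower = torch.cat([z_vals[:1], mids])
z_strat = lower + (upper - lower) * torch.rand(N_samples)
# each z_strat[i] stays inside the i-th interval, so ordering is preserved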
3.3 raw2outputs Code Walkthrough
def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
raw2alpha = lambda raw, dists, act_fn=F.relu : \
1. - torch.exp(-act_fn(raw) * dists) # σ column of `raw`
dists = z_vals[..., 1:] - z_vals[..., :-1]
dists = torch.cat([dists, # (ray #, sample #)
torch.Tensor([1e10]).expand(dists[..., :1].shape)],
-1)
dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
rgb = torch.sigmoid(raw[..., :3]) # (ray #, sample #, 3)
noise = 0.
if raw_noise_std > 0.:
noise = torch.randn(raw[..., 3].shape) * raw_noise_std
# overwrite randomly sampled data
if pytest:
np.random.seed(0)
noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
noise = torch.Tensor(noise)
alpha = raw2alpha(raw[..., 3] + noise, dists) # (ray #, sample #)
#weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)),
1. - alpha + 1e-10], -1),
-1)[:, :-1]
rgb_map = torch.sum(weights[..., None] * rgb, -2) # (ray #, 3)
depth_map = torch.sum(weights * z_vals, -1)
disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
depth_map / torch.sum(weights, -1))
acc_map = torch.sum(weights, -1)
if white_bkgd:
rgb_map = rgb_map + (1. - acc_map[..., None])
return rgb_map, disp_map, acc_map, weights, depth_map
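The weights implement w_i = alpha_i * prod_{j<i}(1 - alpha_j), i.e. local opacity times the transmittance accumulated up to that sample, so along any ray the weights sum to acc_map <= 1. A quick numeric check with made-up densities and spacings:

import torch

sigma = torch.tensor([[0.0, 0.5, 2.0, 0.1]])     # hypothetical densities, one ray
dists = torch.full_like(sigma, 0.25)             # equal sample spacing
alpha = 1. - torch.exp(-sigma * dists)
T = torch.cumprod(torch.cat([torch.ones((1, 1)),
                             1. - alpha + 1e-10], -1), -1)[:, :-1]
weights = alpha * T                              # transmittance × local opacity
print(weights.sum(-1))                           # equals acc_map, below 1.0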
3.4 sample_pdf Code Walkthrough
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
# get PDF
weights = weights + 1e-5 # prevent NaN
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (ray #, bin #)
# Here, `N_samples` refers to `N_importance`.
if det:
u = torch.linspace(0., 1., steps=N_samples)
u = u.expand(list(cdf.shape[:-1]) + [N_samples])
else:
u = torch.rand(list(cdf.shape[ :-1]) + [N_samples])
# if pytest, overwrite u with NumPy fixed random numbers
if pytest:
np.random.seed(0)
new_shape = list(cdf.shape[:-1]) + [N_samples]
if det:
u = np.linspace(0., 1., N_samples)
u = np.broadcast_to(u, new_shape)
else:
u = np.random.rand(*new_shape)
u = torch.Tensor(u)
# invert CDF
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds-1), inds-1)
above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (ray #, sample #, 2)
#cdf_g = tf.gather(cdf , inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
#bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather( cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = cdf_g[..., 1] - cdf_g[..., 0]
denom = torch.where(denom<1e-5, torch.ones_like(denom),
denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + \
(bins_g[..., 1] - bins_g[..., 0]) * t
return samples # (ray #, sample #), unsorted along each ray
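Inverse-transform sampling is easy to verify on a toy distribution: if most of the weight sits in one bin, most of the drawn samples should land there. A standalone sketch with made-up bin edges and weights:

import torch

bins = torch.tensor([[2.0, 3.0, 4.0, 5.0]])      # 3 bins on a single ray
weights = torch.tensor([[0.1, 0.8, 0.1]])        # peaked in the middle bin
samples = sample_pdf(bins, weights, N_samples=1000)
frac = ((samples > 3.0) & (samples < 4.0)).float().mean()
print(frac)                                      # ≈ 0.8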
3.5 Optimization Code Walkthrough
…
for i in trange(start, N_iters):
…
optimizer.zero_grad()
img_loss = img2mse(rgb, target_s)
    trans = extras['raw'][...,-1]  # last (density) channel of the raw output; unused below
loss = img_loss
psnr = mse2psnr(img_loss)
if 'rgb0' in extras:
img_loss0 = img2mse(extras['rgb0'], target_s)
loss = loss + img_loss0
psnr0 = mse2psnr(img_loss0)
loss.backward()
optimizer.step()
### update learning rate ###
decay_rate = 0.1
decay_steps = args.lrate_decay * 1000
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
for param_group in optimizer.param_groups:
param_group['lr'] = new_lrate
#############
…
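The schedule is a plain exponential decay: the learning rate shrinks by decay_rate = 0.1 every decay_steps iterations, smoothly rather than in discrete steps. With the repository's defaults lrate = 5e-4 and lrate_decay = 250 (so decay_steps = 250k), a quick check:

lrate, lrate_decay = 5e-4, 250        # defaults in the repository's config_parser
decay_rate, decay_steps = 0.1, lrate_decay * 1000
for step in [0, 50_000, 250_000, 500_000]:
    print(step, lrate * (decay_rate ** (step / decay_steps)))
# 0 -> 5.0e-04, 50k -> 3.15e-04, 250k -> 5.0e-05, 500k -> 5.0e-06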