NeRF项目代码详解

news2024/12/25 15:34:03

1 项目结构

开源代码:https://github.com/yenchenlin/nerf-pytorch

在上述框架图中,首先重config_parse 中读取文件参数,

然后通过load_blender加载数据,加载的数据包括训练集、验证集和测试集以及摄像机的内外参数;

在creat_nerf中通过get_embeder 获取 视线方向和三维点的位置编码,并初始化NeRF模型的MLP层 ,

通过get_rays_np 获取视线起点rays_o 和方向rays_d.

在渲染时,若使用LLFF数据集需要调用ndc_rays, 将空间变换到NDC空间中;

在batchify_rays 中,可以通过chunk的大小来控制加载的数据量大小,

在run_network 函数中,将通过get_rays_np  获取视线起点rays_o 和方向rays_d的位置编码输入到MLP网络中,以预测颜色和体素。

在raw2outputs中,通过计算得到的体素和颜色,利用体渲染得到图像,然后计算损失函数。

2 视线方向代码解读

def get_rays_np(H, W, K, c2w):
    i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
    dirs = np.stack([ (i-K[0][2]) / K[0][0],
                     -(j-K[1][2]) / K[1][1],
                     -np.ones_like(i)       ], -1)
    # rotate ray directions from camera frame to the world frame
    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
    return rays_o, rays_d

 

 

3  基本渲染流程

def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
           near=0., far=1.,
           use_viewdirs=False, c2w_staticcam=None,
           **kwargs):
    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = get_rays(H, W, K, c2w)
    else:
        # use provided ray batch
        rays_o, rays_d = rays
    # provide ray directions as input
    if use_viewdirs:
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # special case to visualize effect of viewdirs
            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1,3]).float()

    sh = rays_d.shape # shape: … × 3
    # for forward facing scenes
    if ndc:
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
    # create ray batch
    rays_o = torch.reshape(rays_o, [-1,3]).float()
    rays_d = torch.reshape(rays_d, [-1,3]).float()
    near, far = near * torch.ones_like(rays_d[...,:1]), \
                far  * torch.ones_like(rays_d[...,:1])
    rays = torch.cat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)
    # render and reshape
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k : all_ret[k] for k in all_ret
                                   if k not in k_extract}
    return ret_list + [ret_dict]

3.1 bathcify_rays 代码解读

def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """ render rays in smaller minibatches to avoid OOM
    """
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(rays_flat[i:i+chunk], **kwargs)
        for k in ret:
            if k not in all_ret:
                all_ret[k] = []
            all_ret[k].append(ret[k])
    all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}

    return all_ret

3.2 render_rays 代码解读

def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0., # 1.0, overridden by input
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False):
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:,0:3], \
                     ray_batch[:,3:6] # (ray #, 3)
    viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 \
                                else None
    bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
    near, far = bounds[...,0], \
                bounds[...,1] # (ray #, 1)
    t_vals = torch.linspace(0., 1., steps=N_samples)
    if not lindisp:
        z_vals = near * (1. - t_vals) + far * t_vals
    else:
        z_vals = 1. / (1./near * (1. - t_vals) +
                       1./far  * (     t_vals) )
    # copy sample distances of 1 ray to the others
    z_vals = z_vals.expand([N_rays, N_samples])

    if perturb > 0.:
        # get intervals between samples
        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        upper = torch.cat([mids, z_vals[...,-1:]], -1)
        lower = torch.cat([z_vals[...,:1], mids], -1)
        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape)
        # pytest: overwrite U with fixed NumPy random numbers
        if pytest:
            np.random.seed(0)
            t_rand = np.random.rand(*list(z_vals.shape))
            t_rand = torch.Tensor(t_rand)

        z_vals = lower + (upper - lower) * t_rand

    pts = rays_o[..., None, :] + \
          rays_d[..., None, :] * z_vals[..., :, None] # (ray #, sample #, 3)
    #raw = run_network(pts)
    raw = network_query_fn(pts, viewdirs, network_fn)
    rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
    # hierarchical sampling
    if N_importance > 0:
        # log outputs of coarse network
        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

        z_vals_mid = .5 * (z_vals[..., 1: ] + z_vals[..., :-1])
        z_samples = sample_pdf(z_vals_mid,
                               weights[..., 1:-1],
                               N_importance,
                               det=(perturb==0.), # FALSE by default
                               pytest=pytest)
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[..., None, :] + \
              rays_d[..., None, :] * z_vals[..., :, None] # (ray #, coarse & fine sample #, 3)

        run_fn = network_fn if   network_fine is None \
                            else network_fine
        #raw = run_network(pts, fn=run_fn)
        raw = network_query_fn(pts, viewdirs, run_fn)

        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

    ret = {'rgb_map' :  rgb_map,
           'disp_map': disp_map,
           'acc_map' :  acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0' ] =  rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0' ] =  acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # (ray #)
    for k in ret:
        if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
            print(f"! [Numerical Error] {k} contains nan or inf.")
    return ret

 

 

3.3 raw2outputs 代码详解 

def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
    raw2alpha = lambda raw, dists, act_fn=F.relu : \
                       1. - torch.exp(-act_fn(raw) * dists) # σ column of `raw`

    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, # (ray #, sample #)
                       torch.Tensor([1e10]).expand(dists[..., :1].shape)],
                       -1)
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    rgb = torch.sigmoid(raw[..., :3]) # (ray #, sample #, 3)

    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw[..., 3].shape) * raw_noise_std
        # overwrite randomly sampled data
        if pytest:
            np.random.seed(0)
            noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
            noise = torch.Tensor(noise)

    alpha = raw2alpha(raw[..., 3] + noise, dists) # (ray #, sample #)
    #weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)),
                                               1. - alpha + 1e-10], -1), 
                                    -1)[:, :-1]

    rgb_map   = torch.sum(weights[..., None] * rgb, -2)  # (ray #, 3)
    depth_map = torch.sum(weights * z_vals, -1)
    disp_map  = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
                               depth_map / torch.sum(weights, -1))
    acc_map   = torch.sum(weights, -1)
    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map[..., None])

    return rgb_map, disp_map, acc_map, weights, depth_map

 3.4 sample_rays 代码详解

def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    # get PDF
    weights = weights + 1e-5 # prevent NaN
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (ray #, bin #)
    # Here, `N_samples` refers to `N_importance`.
    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[ :-1]) + [N_samples])
    # if pytest, overwrite u with NumPy fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)
    # invert CDF
    u = u.contiguous()
    inds   = torch.searchsorted(cdf, u, right=True)
    below  = torch.max(torch.zeros_like(inds-1), inds-1)
    above  = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (ray #, sample #, 2)

    #cdf_g  = tf.gather(cdf , inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    #bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g  = torch.gather( cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom<1e-5, torch.ones_like(denom),
                                    denom)
    t = (u - cdf_g[..., 0]) / denom
    samples =  bins_g[..., 0] + \
              (bins_g[..., 1] - bins_g[..., 0]) * t
    return samples # (ray #, sample #), unsorted along each ray

 

3.5 optimization 代码详解

…
    for i in trange(start, N_iters):
        …
        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][...,-1]
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)

        loss.backward()
        optimizer.step()

       
        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        #############
        …

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1630190.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

淘宝、京东、拼多多纷争:“造节”过气,“制剧”当红

经过多年发展,消费者对国内电商三巨头形成了固有印象:拼多多价格低、京东物流快、淘宝生态完善。 消费者的固有印象是淘宝、京东、拼多多在市场上建立的“安全区”,安全区之内已没有挑战,安全区之外才是它们想要征服的新领地。而…

计算机视觉——使用OpenCV GrabCut算法从图像中移除背景

GrabCut算法 GrabCut算法是一种用于图像前景提取的技术,由Carsten Rother、Vladimir Kolmogorov和Andrew Blake三位来自英国剑桥微软研究院的研究人员共同开发。该技术的核心目标是在用户进行最少交互操作的情况下,自动从图像中分割出前景对象。 在Gra…

每日一题:视频拼接

你将会获得一系列视频片段,这些片段来自于一项持续时长为 time 秒的体育赛事。这些片段可能有所重叠,也可能长度不一。 使用数组 clips 描述所有的视频片段,其中 clips[i] [starti, endi] 表示:某个视频片段开始于 starti 并于 …

LeetCode39题: 组合总和(原创)

【题目描述】 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target ,找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合 ,并以列表形式返回。你可以按 任意顺序 返回这些组合。candidates 中的 同一个 数字可以 无限制重复…

技术速递|利用 Redis 使 AI 驱动的 .NET 应用程序更加一致和智能

作者:Catherine Wang 排版:Alan Wang Redis 是一种流行的内存数据存储,可用于解决构建和扩展智能应用程序的关键挑战。在本文中,你将了解如何使用 Redis 的 Azure 缓存来提高使用 Azure OpenAI 的应用程序的效率。 Redis 的 Azur…

聚观早报 | 生数科技推出Vidu;2024款欧拉好猫正式上市

聚观早报每日整理最值得关注的行业重点事件,帮助大家及时了解最新行业动态,每日读报,就读聚观365资讯简报。 整理丨Cutie 4月28日消息 生数科技推出Vidu 2024款欧拉好猫正式上市 雷诺与小米汽车洽谈技术合作 微软张祺谈未来AI如何发展 …

机器学习/算法工程师面试题目与答案-数学基础部分

机器学习/算法工程师面试题目--数学基础部分 一、数学基础1、微积分SGD,Momentum,Adagard,Adam原理L1不可导的时候该怎么办sigmoid函数特性 2、统计学,概率论求 Max(a, b) 期望拿更长的玫瑰花的最好策略最大化工作天数的员工数切比雪夫不等式随机截成三段组成三角形…

基于MSP430F249的电子钟仿真(源码+仿真)

目录 1、前言 2、仿真 3、程序 资料下载地址&#xff1a;基于MSP430F249的电子钟仿真(源码仿真&#xff09; 1、前言 基于MSP430F249的电子钟仿真&#xff0c;数码管显示时分秒&#xff0c;并可以通过按键调节时间。 2、仿真 3、程序 #include <MSP430x24x.h> #def…

Jenkins集成Terraform实现阿里云CDN自动刷新

在互联网业务中&#xff0c;CDN的应用已经成了普遍&#xff0c;SRE的日常需求中&#xff0c;CDN的刷新在前端需求逐渐中占了很大比例&#xff0c;并且比较琐碎。做为合格的SRE&#xff0c;把一切自动化是终极使命&#xff0c;而今天就分享通过JenkinsTerraform实现阿里云的CDN自…

java-动态代理

为什么需要代理&#xff1f; 如何创建代理 注意&#xff1a;实现类和代理需要实现同一个接口 接口 public interface Star {String sing(String song);void dance(); }实现类 public class BigStar implements Star {private String name;public BigStar(String name) {this.…

2024Mac系统热门游戏排行榜 Mac支持的网络游戏有哪些?mac能玩哪些大型网游 苹果电脑Mac游戏资源推荐 Mac玩Windows游戏

“游戏是这个世界上唯一能和女性争夺男朋友的东西&#xff08;/滑稽&#xff0c;有不少女生也喜欢玩游戏&#xff09;。” 虽然只是一句玩笑话&#xff0c;不过也可以看出游戏对大多数男生来说是必不可少的一项娱乐活动了。而网络游戏是游戏中的一大分支&#xff0c;能让玩家们…

uniapp问题归类

最近使用uniapp中&#xff0c;遇到了一些问题&#xff0c;这边mark下。 1. 启动页变形 设置启动页的时候发现在部分android手机上启动页被拉伸了&#xff0c;最后看了下官方建议使用9.png图 生成9.png地址&#xff0c;推荐图片大小为1080x2340 uniapp推荐官方地址传送门 我…

Thread类的基本用法

1.线程创建 这里介绍线程创建常用的五种方法 1.继承Thread&#xff0c;重写run class MyThread extends Thread{public void run(){//这里写的代码就是线程要完成的任务while (true){System.out.println("hello thread");try {Thread.sleep(1000);//线程会休眠一秒…

Springboot+Vue项目-基于Java+MySQL的家政服务平台系统(附源码+演示视频+LW)

大家好&#xff01;我是程序猿老A&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f49e;当前专栏&#xff1a;Java毕业设计 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python毕业设计 &…

FreeRTOS:3.信号量

FreeRTOS信号量 参考链接&#xff1a;FreeRTOS-信号量详解_freertos信号量-CSDN博客 目录 FreeRTOS信号量一、信号量是什么二、 FreeRTOS信号量1、二值信号量1、获取信号量2、释放信号量 2、计数信号量3、互斥信号量1、优先级反转2、优先级继承3、源码解析1、互斥量创建2、获取…

[蓝桥杯2024]-PWN:fd解析(命令符转义,标准输出重定向)

查看保护 查看ida 这里有一次栈溢出&#xff0c;并且题目给了我们system函数。 这里的知识点没有那么复杂 完整exp&#xff1a; from pwn import* pprocess(./pwn) pop_rdi0x400933 info0x601090 system0x400778payloadb"ca\\t flag 1>&2" print(len(paylo…

2024.04.28 Typecho管理视频文件,出现预览功能

需求原因原版的Typecho不支持在线视频预览,只有一个图片预览功能, 所以为了实现可以在线预览视频功能, 修改 typecho/admin/media.php 在大概19行的时候,追加如下内容 <?php if ($attachment->attachment->isImage): ?><p><img src"<?php $att…

装饰器模式【结构型模式C++】

1.概述 装饰器模式是一种结构型设计模式&#xff0c; 允许你通过将对象放入包含行为的特殊封装对象中来为原对象绑定新的行为。 2.结构 抽象构件&#xff08;Component&#xff09;角色&#xff1a;定义一个抽象接口以规范准备接收附加责任的对象。具体构件&#xff08;Concre…

关于文档中心的英文快捷替换方案

背景&#xff1a;文档中心需要接入国际化&#xff0c;想节省时间做统一英文方案处理&#xff1b; 文档中心是基于vuepress框架编写的&#xff1b; 1、利用百度翻译 API 的接口去做底层翻译处理&#xff0c;https://api.fanyi.baidu.com/需要在该平台上注册账号&#xff0c;个人…

决策树学习笔记

一、衡量标准——熵 随机变量不确定性的度量 信息增益&#xff1a;表示特征X使得类Y的不确定性减少的程度。 二、数据集 14天的打球情况 特征&#xff1a;4种环境变化&#xff08;天气、温度等等&#xff09; 在上述数据种&#xff0c;14天中打球的天数为9天&#xff1b;不…