NeRF学习——NeRF-Pytorch的源码解读

news2024/11/24 5:04:44

学习 github 上 NeRF 的 pytorch 实现项目(https://github.com/yenchenlin/nerf-pytorch)的一些笔记

1 参数

部分参数配置:

  1. 训练参数:

    这段代码是在设置一些命令行参数,这些参数用于控制NeRF(Neural Radiance Fields)的训练选项。具体来说:

    • netdepth:神经网络的层数。默认值为8

    • netwidth:每层的通道数。默认值为256

    • netdepth_fine:精细网络的层数。默认值为8

    • netwidth_fine:精细网络每层的通道数。默认值为256

    • N_rand:批量大小(每个梯度步骤的随机光线数)。默认值为 32 × 32 × 4 32 \times 32 \times 4 32×32×4

    • lrate:学习率。默认值为5e-4

    • lrate_decay:指数学习率衰减(在1000步中)。默认值为250

    • chunk:并行处理的光线数,如果内存不足,可以减少这个值。默认值为1024*32

    • netchunk:并行通过网络发送的点数,如果内存不足,可以减少这个值。默认值为1024*64

    • no_batching:是否只从一张图像中取随机光线

    • no_reload:是否不从保存的检查点重新加载权重

    • ft_path:用于重新加载粗网络的特定权重npy文件。默认值为None

    • precrop_iters:在中心裁剪上训练的步数。默认值为0。如果这个值大于0,那么在训练的开始阶段,模型将只在图像的中心部分进行训练,这可以帮助模型更快地收敛

    • precrop_frac:用于中心裁剪的图像的比例。默认值为0.5。这个值决定了在进行中心裁剪时,应该保留图像的多少部分。例如,如果这个值为0.5,那么将保留图像中心的50%

  2. 渲染参数:

    • N_samples:每条光线的粗采样数。默认64

    • N_importance:每条光线的额外精细采样数(分层采样)。默认0

    • perturb:设置为0表示没有抖动,设置为1表示有抖动。抖动可以增加采样点的随机性。默认1

    • use_viewdirs:是否使用完整的5D输入,而不是3D。5D输入包括3D位置和2D视角

    • i_embed:设置为0表示使用默认的位置编码,设置为-1表示不使用位置编码。默认0

    • multires:位置编码的最大频率的对数(用于3D位置)。默认10

    • multires_views:位置编码的最大频率的对数(用于2D方向)。默认4

      我们设置 d = 10 d=10 d=10 用于位置坐标 ϕ ( x ) ϕ(\bf x) ϕ(x) ,所以输入是60维的向量; d = 4 d=4 d=4 用于相机位姿 ϕ ( d ) ϕ(\bf d) ϕ(d) 对应的则是24维

    • raw_noise_std:添加到 sigma_a 输出的噪声的标准偏差,用于正则化 sigma_a 输出。默认0

    • render_only:如果设置,那么不进行优化,只加载权重并渲染出 render_poses 路径

    • render_test:如果设置,那么渲染测试集,而不是 render_poses 路径

    • render_factor:降采样因子,用于加速渲染。设置为4或8可以快速预览。默认0

  3. LLFF(Light Field Photography)数据集:

    • factor:LLFF图像的降采样因子。默认值为8。这个值决定了在处理LLFF图像时,应该降低多少分辨率

    • no_ndc:是否不使用归一化设备坐标(NDC)。如果在命令行中指定了这个参数,那么其值为True。这个选项应该在处理非前向场景时设置

    • lindisp:是否在视差中线性采样,而不是在深度中采样。如果在命令行中指定了这个参数,那么其值为True

    • spherify:是否处理球形360度场景。如果在命令行中指定了这个参数,那么其值为True

    • llffhold:每N张图像中取一张作为LLFF测试集。默认值为8。这个值决定了在处理LLFF数据集时,应该把多少图像作为测试集

      # 加载数据时,每隔args.llffhold个图像取一张图形
      i_test = np.arange(images.shape[0])[::args.llffhold]
      

2 大致过程

2.1 加载LLFF数据
  1. load_llff_data 函数返回五个值:images(图像),poses(姿态),bds(深度范围),render_poses(渲染姿态)和i_test(测试图像索引)

    • hwf是从poses中提取的图像的高度宽度焦距
    images, poses, bds, render_poses, i_test = load_llff_data(.....)
    hwf = poses[0,:3,-1]
    poses = poses[:,:3,:4]
    
  2. 将图像数据集划分为三个部分:训练集(i_train)、验证集(i_val)和测试集(i_test

    # 每隔args.llffhold个图像取一张做测试集
    i_test = np.arange(images.shape[0])[::args.llffhold]
    # 验证集 = 测试集
    i_val = i_test
    # 所有不在测试集和验证集中的图像
    i_train = np.array([i for i in np.arange(int(images.shape[0])) if
                    (i not in i_test and i not in i_val)])
    
2.2 创建神经网络模型
  1. 将采样点坐标和观察坐标通过位置编码 get_embedder 成63维和27维
  2. 实例化NeRF模型和NeRF精细模型
  3. 创建网络查询函数 network_query_fn() ,用于运行网络
  4. 创建 Adam 优化器
  5. 加载检查点(如果有),即从检查点中重新加载模型和优化器状态
  6. 创建用于训练和测试的渲染参数 render_kwargs_trainrender_kwargs_test
  7. 根据数据集类型(只有LLFF才行)和参数确定是否使用NDC
2.3 准备光线

使用批处理:

  1. 对于每一个姿态,使用get_rays_np函数获取光线原点和方向( ro+rd ),然后将所有的光线堆叠起来,得到rays
  2. 将射线的原点和方向与图像的颜色通道连接起来( ro+rd+rgb
  3. 对张量进行重新排列和整形,只保留训练集中的图像
  4. 对训练数据进行随机重排
2.4 训练迭代
  1. 设置训练迭代次数 N_iters = 200000 + 1

  2. 开始进行训练迭代

    • 准备光线数据:在每次迭代中,从rays_rgb中取出一批(批处理)光线数据,数量为参数值N_rand,并准备好目标值 target_s

      如果完成一个了周期(i_batch >= rays_rgb.shape[0] ),则对数据进行打乱

    • 渲染:使用渲染函数 render()

    • 计算损失:计算渲染结果的损失。这里使用了均方误差损失函数 img2mse() 来计算图像损失
      L = ∑ r ∈ R ∥ C ^ c ( r ) − C ( r ) ∥ 2 2 + ∥ C ^ f ( r ) − C ( r ) ∥ 2 2 \mathcal{L} = \sum_{\mathbf{r} \in \mathcal{R}} \left\| \hat{C}^c(\mathbf{r}) - C(\mathbf{r}) \right\|_2^2 + \left\| \hat{C}^f(\mathbf{r}) - C(\mathbf{r}) \right\|_2^2 L=rR C^c(r)C(r) 22+ C^f(r)C(r) 22

      img2mse = lambda x, y : torch.mean((x - y) ** 2)
      
    • 反向传播:进行反向传播,并执行优化

    • 更新学习率:这里采用指数衰减的学习率调度策略,学习率在每个一定的步骤(decay_steps)内以一定的速率(decay_rate)衰减

  3. 根据参数设置的频率输出相关状态、视频和测试集

3 神经网络模型

模型结构如下:

image-20240316162459526

  • 应用 ReLU 激活函数

  • 采样点坐标和观察坐标通过位置编码成63维和27维

  • 中间有一个跳跃连接在第四次 256->256 的线性层

    跳跃连接可以将某一层的输入直接传递到后面的层,从而避免梯度消失和表示瓶颈,提高网络的性能

4 体积渲染

4.1 render()

渲染主函数是调用 render() 函数:

def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
                  near=0., far=1.,
                  use_viewdirs=False, c2w_staticcam=None,
                  **kwargs):

其有两种用法:

  1. 测试用:

    rgb, disp, acc, _ = render(H, W, K, 
                               chunk=chunk, 
                               c2w=c2w[:3,:4], 
                               **render_kwargs)
    

    c2w=c2w[:3,:4] 意味着光线的起点和方向是由函数内部通过相机参数计算得出的

    这个只在 render_path() 函数中用到,其在给定相机路径下渲染图像

    • 不训练只渲染时直接渲染时
    • 定期输出结果时
  2. 训练用:

    rgb, disp, acc, extras = render(H, W, K, 
                                    chunk=args.chunk, 
                                    rays=batch_rays,
                                    verbose=i < 10, 
                                    retraw=True,
                                    **render_kwargs_train)
    

    rays=batch_rays 意味着光线的起点和方向是预先计算好的,而不是由函数内部通过相机参数计算得出

    这个只在训练迭代时用到:Core optimization loop 中,对从rays_rgb中取出一批(批处理)光线进行渲染,得到的 rgb 值与 target_s (也来自预先计算好的 rays_rgb )计算 loss,来进行神经网络的训练

4.2 batchify_rays()

在主函数 render() 中,渲染工作是调用的 batchify_rays()

主要目的是将大量的光线分批处理,以避免在渲染过程中出现内存溢出(OOM)的问题

4.3 render_rays()

分批处理函数 batchify_rays() 中的渲染操作是由 render_rays() 进行,其是真正的渲染操作的函数

def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False):

其参数:光线批次(ray_batch)、网络函数(network_fn)、网络查询函数(network_query_fn)、样本数量(N_samples)等等

返回:一个字典 ,包含了 RGB 颜色映射、视差映射、累积不透明度等信息

其大致过程为:

  1. 从光线批次中提取出光线的起点、方向、视线方向以及近远边界

    • 根据是否进行线性分布采样,计算出每个光线上的采样点的深度值

    • 若设置扰动( perturb ),则在每个采样间隔内进行分层随机采样

  2. 函数计算出每个采样点在空间中的位置

    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
    
  3. 然后使用 network_query_fn() 对每个采样点进行预测,得到原始的预测结果 raw

  4. 使用 raw2outputs()(请看下一节4.4) 函数将原始预测结果转换为 RGB 颜色映射、视差映射、累积不透明度等输出

  5. 若分层采样 N_importance > 0,调用 sample_pdf() 分层采样,并将这些额外的采样点传递给精细网络 network_fine 进行预测

  6. 最后,函数返回一个字典,包含了所有的输出结果

4.4 raw2outputs()

其将模型的原始预测转换为语义上有意义的值,主要基于论文中离散形式的积分方程实现:

累积不透明度函数 C ^ ( r ) \hat{C}(r) C^(r) 的估计公式如下:

C ^ ( r ) = ∑ i = 1 N T i ( 1 − exp ⁡ ( − σ i δ i ) ) c i \hat{C}(r) = \sum_{i=1}^{N} T_i (1 - \exp(-\sigma_i \delta_i)) c_i C^(r)=i=1NTi(1exp(σiδi))ci

其中,

  • N N N 是样本点的数量,
  • T i = exp ⁡ ( − ∑ j = 1 i − 1 σ j δ j ) T_i = \exp \left( - \sum_{j=1}^{i-1} \sigma_j \delta_j \right) Ti=exp(j=1i1σjδj) 是权重系数
  • δ i = t i + 1 − t i \delta_i = t_{i+1} - t_i δi=ti+1ti 表示相邻样本之间的距离
  • c i c_i ci 是颜色值
  • σ i \sigma_i σi 是不透明度值(体积密度)

根据代码,我们可以得出以下关系:

  • c i c_i ci 对应着 rgb = torch.sigmoid(raw[...,:3]),表示颜色值
  • σ i \sigma_i σi 对应着 raw[...,3],表示不透明度值

然后,我们可以根据公式中的每个项逐一解释如何在代码中实现:

  1. δ i = t i + 1 − t i \delta_i = t_{i+1} - t_i δi=ti+1ti:计算相邻样本之间的距离。在代码中:

     dists = z_vals[...,1:] - z_vals[...,:-1]
    
  2. 1 − exp ⁡ ( − σ i δ i ) 1 - \exp(-\sigma_i \delta_i) 1exp(σiδi):计算每个样本的不透明度。在代码中:

    raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)
    
    alpha = raw2alpha(raw[...,3] + noise, dists)
    
  3. T i = exp ⁡ ( − ∑ j = 1 i − 1 σ j δ j ) T_i = \exp \left( - \sum_{j=1}^{i-1} \sigma_j \delta_j \right) Ti=exp(j=1i1σjδj)​:计算权重系数。在代码中:

    即对 1 − ( 1 − exp ⁡ ( − σ i δ i ) ) 1 - (1 - \exp(-\sigma_i \delta_i)) 1(1exp(σiδi)) 累乘

    torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    
  4. C ^ ( r ) = ∑ i = 1 N T i ( 1 − exp ⁡ ( − σ i δ i ) ) c i \hat{C}(r) = \sum_{i=1}^{N} T_i (1 - \exp(-\sigma_i \delta_i)) c_i C^(r)=i=1NTi(1exp(σiδi))ci​​:计算累积不透明度。在代码中:

    w i = T i ( 1 − exp ⁡ ( − σ i δ i ) ) w_i = T_i(1 - \exp(-\sigma_i\delta_i)) wi=Ti(1exp(σiδi))

    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]
    

最终,代码返回估计的 RGB 颜色、视差图、累积权重、权重以及估计的距离图

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1522636.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

sqllab第二十六关通关笔记

知识点&#xff1a; 空格替换 %09 %0a %0b %0c %0d %a0 (%2b)or替换&#xff1a;|| ||是不需要空格区分的and替换&#xff1a;&& &&同样不需要空格区分的双写绕过&#xff0c;但是绕过后需要和内容进行空格区分的&#xff0c;要不然不发挥作用&#xff1b;这关…

确保云原生部署中的网络安全

数字环境正在以惊人的速度发展&#xff0c;组织正在迅速采用云原生部署和现代化使用微服务和容器构建的应用程序&#xff08;通常运行在 Kubernetes 等平台上&#xff09;&#xff0c;以推动增长。 无论我们谈论可扩展性、效率还是灵活性&#xff0c;对于努力提供无与伦比的用…

【python开发】并发编程(上)

并发编程&#xff08;上&#xff09; 一、进程和线程&#xff08;一&#xff09;多线程&#xff08;二&#xff09;多进程&#xff08;三&#xff09;GIL锁 二、多线程开发&#xff08;一&#xff09;t.start()&#xff08;二&#xff09;t.join()&#xff08;三&#xff09;t.…

基于ESTAR指数平滑转换自回归模型的CPI数据统计分析matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 4.1 ESTAR模型概述 4.2 WNL值&#xff0c;P值&#xff0c; Q值&#xff0c;12阶ARCH值 4.3ADF检验 5.完整程序 1.程序功能描述 基于ESTAR指数平滑转换自回归模型的CPI数据统计分析matlab仿…

【Hadoop大数据技术】——MapReduce经典案例实战(倒排索引、数据去重、TopN)

&#x1f4d6; 前言&#xff1a;MapReduce是一种分布式并行编程模型&#xff0c;是Hadoop核心子项目之一。实验前需确保搭建好Hadoop 3.3.5环境、安装好Eclipse IDE &#x1f50e; 【Hadoop大数据技术】——Hadoop概述与搭建环境&#xff08;学习笔记&#xff09; 目录 &#…

基于springboot+mysql+Shiro实现的宠物医院管理系统

1.项目介绍 系统主要为用户提供了管理员权限的用户&#xff0c;实现了前台查看客户信息、在线添加预约等&#xff1b;后台管理医生坐诊信息、管理就诊信息、修改密码&#xff0c;管理公告、管理宠物分类、管理就诊、管理用户、修改密码等。在设计方面&#xff0c;本系统采用MV…

Echo框架:高性能的Golang Web框架

Echo框架&#xff1a;高性能的Golang Web框架 在Golang的Web开发领域&#xff0c;选择一个适合的框架是构建高性能和可扩展应用程序的关键。Echo是一个备受推崇的Golang Web框架&#xff0c;以其简洁高效和强大功能而广受欢迎。本文将介绍Echo框架的基本特点、使用方式及其优势…

计算机网络——物理层(数据交换方式)

计算机网络——数据交换方式 提高数据交换方式的必要性电路交换电路交换原理电路交换的阶段建立阶段通信阶段和连接拆除阶段 电路交换的优缺点报文交换什么是报文报文交换的阶段报文交换的优缺点 分组交换分组交换的阶段分组交换的优缺点 数据交换方式的选择数据报方式数据报方…

【二】【单片机】有关独立按键的实验

自定义延时函数Delay 分别用Delay.c文件存储Delay函数。用Delay.h声明Delay函数。每次将这两个文件复制到工程中&#xff0c;直接使用。 //Delay.c void Delay(unsigned int xms) //11.0592MHz {while(xms--){unsigned char i, j;i 2;j 199;do{while (--j);}…

[自研开源] MyData 数据集成之数据过滤 v0.7.2

开源地址&#xff1a;gitee | github 详细介绍&#xff1a;MyData 基于 Web API 的数据集成平台 部署文档&#xff1a;用 Docker 部署 MyData 使用手册&#xff1a;MyData 使用手册 试用体验&#xff1a;https://demo.mydata.work 交流Q群&#xff1a;430089673 概述 本篇基于…

Ubuntu 虚拟机安装

最小化安装后常用工具 sudo apt-get install vim# ifconfig apt install net-tools # nload apt install nload # 很多都要用到 apt install build-essential # 开发相关 apt install gcc gapt install iproute2 ntpdate tcpdump telnet traceroute \ nfs-kernel-server nfs…

Java项目:57 ssm011线上旅行信息管理系统ssm+vue

作者主页&#xff1a;源码空间codegym 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 项目介绍 本线上旅行信息管理系统&#xff0c;主要实现了用户功能模块和管理员功能模块两大部分 用户可查看旅行相关信息&#xff0c;注册登录后还可实…

第十二届蓝桥杯EDA省赛真题分析

前言&#xff1a; 第十二届蓝桥杯EDA比赛用的是AD软件&#xff0c;从第十四届起都是使用嘉立创EDA专业版&#xff0c;所以在这里我用嘉立创EDA专业版实现题目要求。 一、省赛第一套真题题目 主观题整套题目如下&#xff1a; 试题一&#xff1a;库文件设计&#xff08;5分&am…

Django 解决新建表删除后无法重新创建等问题

Django 解决新建表删除后无法重新创建等问题 问题发生描述处理办法首先删除了app对应目录migrations下除 __init__.py以外的所有文件:然后&#xff0c;删除migrations中关于你的app的同步数据数据库记录最后&#xff0c;重新执行迁移插入 问题发生描述 Django创建的表&#xf…

【计算机视觉】二、图像形成——实验:2D变换编辑(Pygame)

文章目录 一、向量和矩阵的基本运算二、几何基元和变换1、几何基元(Geometric Primitives)2、几何变换(Geometric Transformations)2D变换编辑器0. 程序简介环境说明程序流程 1. 各种变换平移变换旋转变换等比缩放变换缩放变换镜像变换剪切变换 2. 按钮按钮类创建按钮 3. Pygam…

前端vue-Taro框架中使用插件 ---pinyin 将城市树形分类

1.需求 当我做一个获取城市的功能的时候 我发向后端返回的数据 和我想i选要的相差太多 这样的在手机端可以滑动 并且 快捷选中的城市列表 目前的数据是这样的&#xff0c;就是一个城市数组 目前这样的数组 我要想显示我的页面实现功能是不行的 需要是树形结够 所以我前端…

CI/CD实战-git工具使用 1

版本控制系统 本地版本控制系统 集中化的版本控制系统 分布式版本控制系统 git官网文档&#xff1a;https://git-scm.com/book/zh/v2 Git 有三种状态&#xff1a;已提交&#xff08;committed&#xff09;、已修改&#xff08;modified&#xff09; 和 已暂存&#xff08;sta…

【CTF web1】

CTF web 一、CTF web -PHP弱类型1、是否相等&#xff1f;2、转换规则: 二、CTF web -md5绕过1、若类型比较绕过2、null绕过3、碰撞绕过 三、习题 一、CTF web -PHP弱类型 1、是否相等&#xff1f; &#xff1a;在进行比较的时候&#xff0c;会先判断两种字符串的类型是否相等&…

EVENG环境安装及测试 1

文章目录 下载eve镜像导入镜像访问测试导入自定义镜像 下载eve镜像 下载地址 链接&#xff1a;https://pan.baidu.com/s/1NqGE34oE5qZ6TCugMymPDg 提取码&#xff1a;f4m1 导入镜像 安装vmware 虚拟机&#xff0c;文件->打开 选中上述镜像 输入虚拟机的名称和保存 路径&a…

接口幂等性问题和常见解决方案

接口幂等性问题和常见解决方案 1.什么是接口幂等性问题1.1 会产生接口幂等性的问题1.2 解决思路 2.接口幂等性的解决方案2.1 唯一索引解决方案2.2 乐观锁解决方案2.3 分布式锁解决方案2.4 Token解决方案(最优方案) 1.什么是接口幂等性问题 幂等性: 用户同一操作发起的一次或多…