100行Pytorch代码实现三维重建技术神经辐射场 (NeRF)

news2025/2/26 14:50:06

提起三维重建技术,NeRF是一个绝对绕不过去的名字。这项逆天的技术,一经提出就被众多研究者所重视,对该技术进行深入研究并提出改进已经成为一个热点。不到两年的时间,NeRF及其变种已经成为重建领域的主流。本文通过100行的Pytorch代码实现最初的 NeRF 论文。

NeRF全称为Neural Radiance Fields(神经辐射场),是一项利用多目图像重建三维场景的技术。该项目的作者来自于加州大学伯克利分校,Google研究院,以及加州大学圣地亚哥分校。NeRF使用一组多目图作为输入,通过优化一个潜在连续的体素场景方程来得到一个完整的三维场景。该方法使用一个全连接深度网络来表示场景,使用的输入是一个单连通的5D坐标(空间位置x,y,z以及观察视角θ,),输出为一个体素场景,可以以任意视角查看,并通过体素渲染技术,生成需要视角的照片。该方法同样支持视频合成。

该方法是一个基于体素重建的方法,通过在多幅图片中的五维坐标建立一个由粗到细的对应,进而恢复出原始的三维体素场景。

NeRF 和神经渲染的基本概念

Rendering

渲染是从 3D 模型创建图像的过程。该模型将包含纹理、阴影、阴影、照明和视点等特征,渲染引擎的作用是处理这些特征以创建逼真的图像。

三种常见的渲染算法类型是光栅化,它根据模型中的信息以几何方式投影对象,没有光学效果;光线投射,使用基本的光学反射定律从特定角度计算图像;和光线追踪,它使用蒙特卡罗技术在更短的时间内获得逼真的图像。光线追踪用于提高 NVIDIA GPU 中的渲染性能。

Volume Rendering

立体渲染使能够创建 3D 离散采样数据集的 2D 投影。

对于给定的相机位置,立体渲染算法为空间中的每个体素获取 RGBα(红色、绿色、蓝色和 Alpha 通道),相机光线通过这些体素投射。RGBα 颜色转换为 RGB 颜色并记录在 2D 图像的相应像素中。对每个像素重复该过程,直到呈现整个 2D 图像。

View Synthesis

视图合成与立体渲染相反——它涉从一系列 2D 图像创建 3D 视图。这可以使用一系列从多个角度显示对象的照片来完成,创建对象的半球平面图,并将每个图像放置在对象周围的适当位置。视图合成函数尝试在给定一系列描述对象不同视角的图像的情况下预测深度。

NeRF是如何工作的

NeRF使用一组稀疏的输入视图来优化连续的立体场景函数。这种优化的结果是能够生成复杂场景的新视图。

NeRF使用一组多目图作为输入:

输入为一个单连通的5D坐标(空间位置x,y,z以及观察视角(θ; Φ)

输出为一个体素场景 c = (r; g; b) 和体积密度 (α)。

下面是如何从一个特定的视点生成一个NeRF:

  • 通过移动摄像机光线穿过场景生成一组采样的3D点
  • 将采样点及其相应的2D观察方向输入神经网络,生成密度和颜色的输出集
  • 通过使用经典的立体渲染技术,将密度和颜色累积到2D图像中

上述过程深度的全连接、多层感知器(MLP)进行优化,并且不需要使用卷积层。它使用梯度下降来最小化每个观察到的图像和从表示中呈现的所有相应视图之间的误差。

Pytorch代码实现

渲染

神经辐射场的一个关键组件,是一个可微分渲染,它将由NeRF模型表示的3D表示映射到2D图像。该问题可以表述为一个简单的重构问题

这里的A是可微渲染,x是NeRF模型,b是目标2D图像。

代码如下:

 defrender_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
     device=ray_origins.device
     t=torch.linspace(hn, hf, nb_bins, device=device).expand(ray_origins.shape[0], nb_bins)
     # Perturb sampling along each ray.
     mid= (t[:, :-1] +t[:, 1:]) /2.
     lower=torch.cat((t[:, :1], mid), -1)
     upper=torch.cat((mid, t[:, -1:]), -1)
     u=torch.rand(t.shape, device=device)
     t=lower+ (upper-lower) *u  # [batch_size, nb_bins]
     delta=torch.cat((t[:, 1:] -t[:, :-1], torch.tensor([1e10], device=device).expand(ray_origins.shape[0], 1)), -1)
 
     x=ray_origins.unsqueeze(1) +t.unsqueeze(2) *ray_directions.unsqueeze(1)   # [batch_size, nb_bins, 3]
     ray_directions=ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)
 
     colors, sigma=nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))
     colors=colors.reshape(x.shape)
     sigma=sigma.reshape(x.shape[:-1])
 
     alpha=1-torch.exp(-sigma*delta)  # [batch_size, nb_bins]
     weights=compute_accumulated_transmittance(1-alpha).unsqueeze(2) *alpha.unsqueeze(2)
     c= (weights*colors).sum(dim=1)  # Pixel values
     weight_sum=weights.sum(-1).sum(-1)  # Regularization for white background
     returnc+1-weight_sum.unsqueeze(-1)

渲染将NeRF模型和来自相机的一些光线作为输入,并使用立体渲染返回与每个光线相关的颜色。

代码的初始部分使用分层采样沿射线选择3D点。然后在这些点上查询神经辐射场模型(连同射线方向)以获得密度和颜色信息。模型的输出可以用蒙特卡罗积分计算每条射线的线积分。

累积透射率(论文中Ti)用下面的专用函数中单独计算。

 defcompute_accumulated_transmittance(alphas):
     accumulated_transmittance=torch.cumprod(alphas, 1)
     returntorch.cat((torch.ones((accumulated_transmittance.shape[0], 1), device=alphas.device),
                       accumulated_transmittance[:, :-1]), dim=-1)

NeRF

我们已经有了一个可以从3D模型生成2D图像的可微分模拟器,下面就是实现NeRF模型。

根据上面的介绍,NeRF非常的复杂,但实际上NeRF模型只是多层感知器(MLPs)。但是具有ReLU激活函数的mlp倾向于学习低频信号。当试图用高频特征建模物体和场景时,这就出现了一个问题。为了抵消这种偏差并允许模型学习高频信号,使用位置编码将神经网络的输入映射到高维空间。

 classNerfModel(nn.Module):
     def__init__(self, embedding_dim_pos=10, embedding_dim_direction=4, hidden_dim=128):
         super(NerfModel, self).__init__()
 
         self.block1=nn.Sequential(nn.Linear(embedding_dim_pos*6+3, hidden_dim), nn.ReLU(),
                                     nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                     nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                     nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), )
 
         self.block2=nn.Sequential(nn.Linear(embedding_dim_pos*6+hidden_dim+3, hidden_dim), nn.ReLU(),
                                     nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                     nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                     nn.Linear(hidden_dim, hidden_dim+1), )
 
         self.block3=nn.Sequential(nn.Linear(embedding_dim_direction*6+hidden_dim+3, hidden_dim//2), nn.ReLU(), )
         self.block4=nn.Sequential(nn.Linear(hidden_dim//2, 3), nn.Sigmoid(), )
 
         self.embedding_dim_pos=embedding_dim_pos
         self.embedding_dim_direction=embedding_dim_direction
         self.relu=nn.ReLU()
 
     @staticmethod
     defpositional_encoding(x, L):
         out= [x]
         forjinrange(L):
             out.append(torch.sin(2**j*x))
             out.append(torch.cos(2**j*x))
         returntorch.cat(out, dim=1)
 
     defforward(self, o, d):
         emb_x=self.positional_encoding(o, self.embedding_dim_pos)
         emb_d=self.positional_encoding(d, self.embedding_dim_direction)
         h=self.block1(emb_x)
         tmp=self.block2(torch.cat((h, emb_x), dim=1))
         h, sigma=tmp[:, :-1], self.relu(tmp[:, -1])
         h=self.block3(torch.cat((h, emb_d), dim=1))
         c=self.block4(h)
         returnc, sigma

训练

训练循环也很简单,因为它也是监督学习。我们可以直接最小化预测颜色和实际颜色之间的L2损失。

 deftrain(nerf_model, optimizer, scheduler, data_loader, device='cpu', hn=0, hf=1, nb_epochs=int(1e5),
           nb_bins=192, H=400, W=400):
     training_loss= []
     for_intqdm(range(nb_epochs)):
         forbatchindata_loader:
             ray_origins=batch[:, :3].to(device)
             ray_directions=batch[:, 3:6].to(device)
             ground_truth_px_values=batch[:, 6:].to(device)
 
             regenerated_px_values=render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)
             loss= ((ground_truth_px_values-regenerated_px_values) **2).sum()
 
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             training_loss.append(loss.item())
         scheduler.step()
 
         forimg_indexinrange(200):
             test(hn, hf, testing_dataset, img_index=img_index, nb_bins=nb_bins, H=H, W=W)
 
     returntraining_loss

测试

训练过程完成,NeRF模型就可以用于从任何角度生成图像。测试函数通过使用来自测试图像的射线数据集进行操作,然后使用渲染函数和优化的NeRF模型为这些射线生成图像。

 @torch.no_grad()
 deftest(hn, hf, dataset, chunk_size=10, img_index=0, nb_bins=192, H=400, W=400):
     ray_origins=dataset[img_index*H*W: (img_index+1) *H*W, :3]
     ray_directions=dataset[img_index*H*W: (img_index+1) *H*W, 3:6]
 
     data= []
     foriinrange(int(np.ceil(H/chunk_size))):
         ray_origins_=ray_origins[i*W*chunk_size: (i+1) *W*chunk_size].to(device)
         ray_directions_=ray_directions[i*W*chunk_size: (i+1) *W*chunk_size].to(device)
 
         regenerated_px_values=render_rays(model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
         data.append(regenerated_px_values)
     img=torch.cat(data).data.cpu().numpy().reshape(H, W, 3)
 
     plt.figure()
     plt.imshow(img)
     plt.savefig(f'novel_views/img_{img_index}.png', bbox_inches='tight')
     plt.close()

所有的部分都可以很容易地组合起来。

 if__name__=='main':
     device='cuda'
     training_dataset=torch.from_numpy(np.load('training_data.pkl', allow_pickle=True))
     testing_dataset=torch.from_numpy(np.load('testing_data.pkl', allow_pickle=True))
     model=NerfModel(hidden_dim=256).to(device)
     model_optimizer=torch.optim.Adam(model.parameters(), lr=5e-4)
     scheduler=torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)
 
     data_loader=DataLoader(training_dataset, batch_size=1024, shuffle=True)
     train(model, model_optimizer, scheduler, data_loader, nb_epochs=16, device=device, hn=2, hf=6, nb_bins=192, H=400,
           W=400)

这样一个简单的NeRF就完成了,看看效果:

希望本文对你有所帮助,如果你对NeRF感兴趣可以看看这个项目:

https://avoid.overfit.cn/post/3d89b7ed625b437993e3fde57f36c70a

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/347514.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

部门新来个00后卷王,太让人崩溃了,想离职了....

在职场上,什么样的人最让人反感? 是技术不好的人吗? 并不是。技术不好的同事,我们可以帮他。 是技术太强的人吗? 也不是。技术很强的同事,可遇不可求,向他学习还来不及呢。 真正让人反感的…

【uniapp】getOpenerEventChannel().once 接收参数无效的解决方案

uniapp项目开发跨平台应用常会遇到接收参数无效的问题,无法判断是哪里出错了,这里是讲替代的方案,现有三种方案可选。 原因 一般我们是这样处理向另一个页面传参,代码是这样写的 //... let { title, type, rank } args; uni.n…

STM32 HAL库-定时器中断

STM32 HAL库-定时器中断一、STM32F407定时器介绍定时器计算公式二、CubeMX配置定时器三、基本定时器中断配置流程1)开启定时器时钟2)初始化定时器参数,设置自动重装值,分频系数,计数方式等3)使能定时器更新中断&#x…

Ubuntu 系统 OpenCV 4 无法打开视频文件解决方案

目录 一、我的运行环境 二、问题描述 三、问题定位及分析 四、解决方案 一、我的运行环境 设备NVIDIA Jetson Nano处理器ARMv8 Processor rev 1 (v8l) 4 GPUNVIDIA Tegra X1 (nvgpu)/integrated操作系统ubuntu 18.04 LTSOpenCV版本4.6.0语言C 二、问题描述 之前一直用的O…

8 冒泡排序

文章目录1 基本介绍1 代码实现1.1 java1.1 scala1 基本介绍 冒泡排序(Bubble Sorting)的基本思想是:通过对待排序序列从前向后(从下标较小的元素开始),依次比较相邻元素的值,若发现逆序则交换,使…

存储管理(6)

存储管理 1 程序的装入与链接 编译:源代码——目标代码 链接:目标代码所需库函数装入模块 装入:将装入模块装入内存,该过程也叫做地址重定位,也称地址映射 地址空间: 源程序经编译后得到的目标程序&…

Leetcode 1223. 掷骰子模拟【动态规划】

有一个骰子模拟器会每次投掷的时候生成一个 1 到 6 的随机数。 不过我们在使用它时有个约束,就是使得投掷骰子时,连续 掷出数字 i 的次数不能超过 rollMax[i](i 从 1 开始编号)。 现在,给你一个整数数组 rollMax 和一…

WebDAV之葫芦儿·派盘+NMM

NMM 支持WebDAV方式连接葫芦儿派盘。 推荐一款文件管理器,可以对手机中的文件进行多方面的管理,支持语法高亮和ftp等远程的文件的管理。支持从WebDav服务器连接葫芦儿派盘服务下载文件和上传文件。 NMM文本编辑器是一款文件管理器,在功能上面更加的适合于一些编程人员进行使…

2023年应该了解的黑客知识

网络犯罪的艺术处于不断变化和演变的状态。与这些趋势保持同步是网络安全人员工作的重要组成部分。 今天的现代网络安全必须确保他们始终为下一个大趋势做好准备并保持领先于对手。 当我们开始迈向 2023 年时,安全格局与一年前相比已经发生了变化,更不…

Spark on hive Hive on spark

文章目录Spark on hive & Hive on sparkHive 架构与基本原理Spark on hiveHive on sparkSpark on hive & Hive on spark Hive 架构与基本原理 Hive 的核心部件主要是 User Interface(1)和 Driver(3)。而不论是元数据库&a…

webpack(高级)--性能优化-代码分离

webpack webpack性能优化 优化一:打包后的结果 上线时的性能优化 (比如分包处理 减少包体积 CDN服务器) 优化二:优化打包速度 开发或者构建优化打包速度 (比如exclude cache-loader等) 大多数情况下我们侧…

css 安全区域 safe-area-inset-

前言 安全区域与边界是iOS11 新增特性。 安全区域 安全区域的内容不受圆角(corners)、齐刘海(sensor housing)、小黑条(Home Indicator)影响。Webkit 为此增加了相应的CSS 函数,用于获取安全…

B树系列与MySQL数据库

前篇提到B树及其实现:一文看懂---B树及其简单实现_b树实现_且随疾风前行->的博客-CSDN博客 本篇继续谈B树系列的B树,B*树和它们与MySQL数据库的关系。 目录 B树系列 B树 B树的特性: B*树 B树系列总结 MySQL索引简介 MyISAM Inno…

Sphinx文档生成工具(二)

rst语法 官方的语法手册 行内的样式: #斜体 *message* #粗体 **message** #等宽 不能有换行 message标题 一级标题 ^^^^^^^^ 二级标题 --------- 三级标题 >>>>>>>>> 四级标题 ::::::::: 五级标题六级标题 """"…

Vue+node.js医院预约挂号信息管理系统vscode

网上预约挂号系统将会是今后医院发展的主要趋势。 前端技术:nodejsvueelementui,视图层其实质就是vue页面,通过编写vue页面从而展示在浏览器中,编写完成的vue页面要能够和控制器类进行交互,从而使得用户在点击网页进行操作时能够正…

关于 mysql数据库插入中文变空白 的解决方法

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/129048030 红胖子网络科技的博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV、OpenGL、ffmpeg、OSG、单片机、软…

Session详解,学习 Session对象一篇文章就够了

目录 1 Session概述 2 Session原理 3 Session使用 3.1 获取Session 3.2 Session保存数据 3.3 Session获取数据 3.4 Session移除数据 4 Session与Request应用区别 4.1 Session和request存储数据 4.2 获取session和request中的值 4.3 session和request区别效果 5 Sess…

Transformer:开启CV研究新时代

来源:投稿 作者:魔峥 编辑:学姐 起源回顾 有关Attention的论文早在上世纪九十年代就提出了。 在2012年后的深度学习时代,Attention再次被翻了出来,被用在自然语言处理任务,提高RNN模型的训练速度。但是由…

数据库(4)--视图的定义和使用

一、学习目的 加深对视图的理解,熟练视图的定义、查看、修改等操作 二、实验环境 Windows 11 Sql server2019 三、实验内容 学生(学号,年龄,性别,系名) 课程(课号,课名,…

收藏|一文掌握数据分析在企业的实际流程

一、数据分析概念 1.1 数据分析 是指用适当的统计分析方法对收集来的大量数据进行分析,将他们加以汇总和理解并消化,以求最大化地开发数据的功能,发挥数据的作用。 1.2 数据分析包括 描述性数据分析(初级数据分析)…