【三维重建】【深度学习】NeuS代码Pytorch实现--测试阶段代码解析(上)

news2024/12/25 9:04:34

【三维重建】【深度学习】NeuS代码Pytorch实现–测试阶段代码解析(上)

论文提出了一种新颖的神经表面重建方法,称为NeuS,用于从2D图像输入以高保真度重建对象和场景。在NeuS中建议将曲面表示为有符号距离函数(SDF)的零级集,并开发一种新的体绘制方法来训练神经SDF表示,因此即使没有掩模监督,也可以实现更准确的表面重建。NeuS在高质量的表面重建方面的性能优于现有技术,特别是对于具有复杂结构和自遮挡的对象和场景。本篇博文将根据代码执行流程解析测试阶段具体的功能模块代码。

文章目录

  • 【三维重建】【深度学习】NeuS代码Pytorch实现--测试阶段代码解析(上)
  • 前言
  • save_checkpoint
  • validate_image
  • gen_rays_at
  • validate_mesh
  • extract_geometry
  • extract_fields
  • 总结


前言

在详细解析NeuS网络之前,首要任务是搭建NeuS【win10下参考教程】所需的运行环境,并完成模型的训练和测试,展开后续工作才有意义。
本博文将对NeuS测试阶段涉及的功能代码模块进行解析。

博主将各功能模块的代码在不同的博文中进行了详细的解析,点击【win10下参考教程】,博文的目录链接放在前言部分。

这里的代码段是exp_runner.py文件的train函数部分,它是在属于广义上的训练阶段的一部分,但是由于不参与NeuS网络的更新,只是对NeuS网络进行阶段性验证,因此博主放到该博文中进行详细讲解。

if self.iter_step % self.save_freq == 0:
    self.save_checkpoint()

if self.iter_step % self.val_freq == 0:
    self.validate_image()

if self.iter_step % self.val_mesh_freq == 0:
    self.validate_mesh()

self.update_learning_rate()

if self.iter_step % len(image_perm) == 0:
    image_perm = self.get_image_perm()

save_checkpoint

属于exp_runner.py文件的Runner类中的成员方法,目的是保存完成阶段训练的NeuS权重。

def save_checkpoint(self):
    checkpoint = {
        'nerf': self.nerf_outside.state_dict(),     # 各深度学习网络参数权重
        'sdf_network_fine': self.sdf_network.state_dict(),
        'variance_network_fine': self.deviation_network.state_dict(),
        'color_network_fine': self.color_network.state_dict(),
        'optimizer': self.optimizer.state_dict(),   # 优化器
        'iter_step': self.iter_step,                # 训练的次数
    }
    # 创建放置权重模型的文件夹
    os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True)
    # 保存
    torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step)))

validate_image

阶段性的完成NeuS模型训练后,需要渲染图片并与真实的训练图片进行比较从而验证模型训练的效果。
首先需要gen_rays_at函数生成整张图片(下采样后)的光线rays,然后获取rays光线上采样点(前景)的最远点和最近点,最后通过renderer函数获取所需的结果。

def validate_image(self, idx=-1, resolution_level=-1):
    # 假设验证图像的序号小于0,随机获取一个图片序号
    if idx < 0:
        idx = np.random.randint(self.dataset.n_images)

    print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx))

    if resolution_level < 0:
        # 下采样倍数
        resolution_level = self.validate_resolution_level

    # [W, H, 3]
    rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level)
    H, W, _ = rays_o.shape

    # 按照batch_size切分,[W*H,3]=>tuple形式:W*H/batch_size个[batch_size, 3]
    rays_o = rays_o.reshape(-1, 3).split(self.batch_size)
    rays_d = rays_d.reshape(-1, 3).split(self.batch_size)

    out_rgb_fine = []
    out_normal_fine = []

    for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
        # 最近点和最远点
        near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch)
        # 背景颜色
        background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None
        render_out = self.renderer.render(rays_o_batch,
                                          rays_d_batch,
                                          near,
                                          far,
                                          cos_anneal_ratio=self.get_cos_anneal_ratio(),
                                          background_rgb=background_rgb)

        def feasible(key): return (key in render_out) and (render_out[key] is not None)

        # 前景颜色
        if feasible('color_fine'):
            out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
        # 梯度信息和采样点权重
        if feasible('gradients') and feasible('weights'):
            n_samples = self.renderer.n_samples + self.renderer.n_importance
            # 梯度信息权重加成
            normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None]  # [batch_size,n_samples,3]
            # 采样点是否在球体内
            if feasible('inside_sphere'):
                # 只保留采样点在球体内的部分
                normals = normals * render_out['inside_sphere'][..., None]  # [batch_size,n_samples,3]
            # normals是带有权重的有效梯度信息
            normals = normals.sum(dim=1).detach().cpu().numpy()     # [batch_size,3]
            out_normal_fine.append(normals)
        del render_out

gen_rays_at

Dataset数据管理器的定义的函数,在models/dataset.py文件下。博主【NeuS总览】的博文中,已经简单介绍过这个过程。

def gen_rays_at(self, img_idx, resolution_level=1):
    """
    Generate rays at world space from one camera.
    一个摄影机在世界空间中生成光线
    """
    # 下采样倍数
    l = resolution_level
    # 获取2D图像上所有的像素点(下采样后的)
    tx = torch.linspace(0, self.W - 1, self.W // l)
    ty = torch.linspace(0, self.H - 1, self.H // l)

    # 生成网格用于生成坐标
    pixels_x, pixels_y = torch.meshgrid(tx, ty)     # [W, H]

    # 相机坐标系下的方向向量:内参(逆)×像素坐标系
    p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1)    # [W, H, 3]
    p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze()  # [W, H, 3]

    # 单位方向向量:对方向向量做归一化处理
    rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # [W, H, 3]

    # 世界坐标系下的方向向量:外参(逆)×相机坐标系
    rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze()  # [W, H, 3]
    # 世界坐标系下的光心位置(外参的逆对应的平移矩阵t)
    rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape)  # [W, H, 3]
    return rays_o.transpose(0, 1), rays_v.transpose(0, 1)       # [H, W, 3]

代码的执行示意图如下图所示,函数返回了rays_o(光心)和rays_v(单位方向向量)。

注意区分训练过程和验证过程生成光线rays的不同,训练过程中是随机选取batch_size个像素点从而生成穿过这些像素点的光线rays,而验证过程是需要选取整个图片的所有像素点从而生成穿过整个图片像素点的光线rays。


validate_mesh

阶段性的完成NeuS模型训练后,同样需要三维重建出实物模型从而验证模型训练的效果。
首先需要划定重建的空间范围,然后通过绘制算法获取顶点坐标和面索引,最后输出实际的三维模型文件。

def validate_mesh(self, world_space=False, resolution=64, threshold=0.0):
    # 获取提取域(方体)的对角线顶点
    bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32)
    bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32)

    # 面绘制算法获取vertices顶点坐标和triangles面索引
    vertices, triangles =\
        self.renderer.extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold)
    os.makedirs(os.path.join(self.base_exp_dir, 'meshes'), exist_ok=True)

    if world_space:
        # 再次缩放位移
        vertices = vertices * self.dataset.scale_mats_np[0][0, 0] + self.dataset.scale_mats_np[0][:3, 3][None]

    # 表示和操作三角网格模型
    mesh = trimesh.Trimesh(vertices, triangles)
    # 保存mesh模型
    mesh.export(os.path.join(self.base_exp_dir, 'meshes', '{:0>8d}.ply'.format(self.iter_step)))
    logging.info('End')

下图展示的是bound_min 和bound_max划定了三维重建范围。

这里提醒一下,三维重建的范围和渲染成二维图片的范围是不一样的,都是各自有各自的设定,别搞混了。


extract_geometry

都在models/renderer.py文件下,这里源码作者做了个套娃,前一个extract_geometry是属于NeuSRenderer类的类成员方法,后一个是独立的函数。

def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
    return extract_geometry(bound_min,
                            bound_max,
                            resolution=resolution,
                            threshold=threshold,
                            query_func=lambda pts: -self.sdf_network.sdf(pts))

marching_cubes面绘制算法参考,extract_fields是为了获得三维重建范围每个点的sdf值。

def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
    print('threshold: {}'.format(threshold))
    # 获取提取域多的sdf
    u = extract_fields(bound_min, bound_max, resolution, query_func)
    # 面绘制算法
    # vertices 顶点坐标[N,3] N是根据具有情况而通过算法得出,与其他无关
    # triangles 面索引[M,3] 索引指向顶点坐标数组中的对应顶点,3个顶点一个面
    vertices, triangles = mcubes.marching_cubes(u, threshold)

    # 提取域的对角顶点
    b_max_np = bound_max.detach().cpu().numpy()     # [3]
    b_min_np = bound_min.detach().cpu().numpy()     # [3]

    # 缩小位移
    vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
    return vertices, triangles

extract_fields

该函数的作用是在三维重建范围内获取到合适的提取点(体素),并为每个提取点(体素)的计算出对应的sdf值。

def extract_fields(bound_min, bound_max, resolution, query_func):
    N = 64
    # 根据提取域(方体)的对角顶点,获取提取域在各xyz轴的范围(max-min)和单位刻度((max-min)/resolution)
    X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
    Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
    Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)

    # 初始化对应方体的sdf值
    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
    with torch.no_grad():
        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    # 网格化
                    xx, yy, zz = torch.meshgrid(xs, ys, zs)     # [N,N,N]
                    # [N^3,3]
                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
                    # 找到对应点的sdf
                    val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
                    # 为方体正确的赋sdf值
                    u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
    return u

代码的执行示意图如下图所示,橙色方块就是提取点(体素),可以根据划分要求更细致的划分出更小的提取点(体素)。


总结

尽可能简单、详细的介绍NeuS测试阶段部分代码:validate_image渲染图片和validate_mesh重建模型的过程。后续会讲解测试阶段的剩余代码。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/911498.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前馈神经网络解密:深入理解人工智能的基石

目录 一、前馈神经网络概述什么是前馈神经网络前馈神经网络的工作原理应用场景及优缺点 二、前馈神经网络的基本结构输入层、隐藏层和输出层激活函数的选择与作用网络权重和偏置 三、前馈神经网络的训练方法损失函数与优化算法反向传播算法详解避免过拟合的策略 四、使用Python…

SVM详解

公式太多了&#xff0c;就用图片用笔记呈现&#xff0c;SVM虽然算法本质一目了然&#xff0c;但其中用到的数学推导还是挺多的&#xff0c;其中拉格朗日约束关于α>0这块证明我看了很长时间&#xff0c;到底是因为悟性不够。对偶问题也是&#xff0c;用了一个简单的例子才明…

超级计算机

超级计算机是一种高性能计算机&#xff0c;它能够以极高的速度执行大规模的计算任务。超级计算机通常由数千个甚至数百万个处理器组成&#xff0c;这些处理器能够同时处理大量的数据&#xff0c;从而实现高效的计算。超级计算机广泛应用于科学、工程、金融、天气预报等领域&…

shell的变量

一、什么是变量 二、变量的命名 三、查看变量的值 env显示全局变量&#xff0c;刚刚定义的root_mess是局部变量 四、变量的定义 旧版本&#xff08;7、8四个文件都加载&#xff09;和新版本&#xff08;9只加载两个etc&#xff09;不一样&#xff0c;所以su - 现在要永久生效在…

IIS缺少MIME类型配置导致404错误

最近程序准备发布到测试服务器供用户端测试&#xff0c;公司提供的测试环境竟是Windows Server 2008 R2 Enterprise IIS 7.5 由于程序引用了fontawesome字体库&#xff0c;发布后浏览发现库中引用的woff、woff2均返回404错误&#xff1a; 前期本地测试以及本地IIS 10并无出现该…

Java“牵手”拼多多产品详情接口API-产品SKU,价格,优惠券,图文介绍,拼多多API接口实现海量商品采集

拼多多是中国的一家电子商务平台&#xff0c;以团购模式为主&#xff0c;成立于2015年。拼多多的宝贝详情页是指在商品页面上展示商品信息和图片的区域&#xff0c;是用户了解和购买商品的重要窗口。下面就让我们来全面解析拼多多宝贝详情页&#xff0c;帮助你更好地了解商品信…

2023-08-21 LeetCode每日一题(移动片段得到字符串)

2023-08-21每日一题 一、题目编号 2337. 移动片段得到字符串二、题目链接 点击跳转到题目位置 三、题目描述 给你两个字符串 start 和 target &#xff0c;长度均为 n 。每个字符串 仅 由字符 ‘L’、‘R’ 和 ‘_’ 组成&#xff0c;其中&#xff1a; 字符 ‘L’ 和 ‘R…

电脑运行缓慢?4个方法,加速电脑运行!

“我电脑才用了没多久哎&#xff01;怎么突然就变得运行很缓慢了呢&#xff1f;有什么方法可以加速电脑运行速度吗&#xff1f;真的很需要&#xff0c;看看我吧&#xff01;” 电脑的运行速度快会让用户在使用电脑时感觉愉悦&#xff0c;而电脑运行缓慢可能会影响我们的工作效率…

ASP.NET实验室信息管理系统源码 LIMS成品源码

实验室信息管理系统&#xff08;Laboratory Information Management System&#xff09;简称LIMS系统&#xff0c;是指通过计算机对实验室的各种信息进行管理的计算机软、硬件系统&#xff0c;并将实验室的设备各种信息通过计算机网络连接起来&#xff0c;采用科学的管理思想和…

Python爬取斗罗大陆全集

打开网址http://www.luoxu.cc/dmplay/C888H-1-265.html F12打开Fetch/XHR&#xff0c;看到m3u8&#xff0c;ts&#xff0c;一眼顶真&#xff0c;打开index.m3u8 由第一个包含第二个index.m3u8的地址&#xff0c;ctrlf在源代码中一查index&#xff0c;果然有&#xff0c;不过/…

Hadoop集群搭建(hadoop-3.3.5)

一、修改服务器配置文件 1、配置环境变量 vim /etc/profile #java环境变量 export JAVA_HOME/usr/local/jdk/jdk8 export JRE_HOME$JAVA_HOME/jre export CLASSPATH$JAVA_HOME/lib:$JRE_HOME/lib:$CLASSPATH export PATH$JAVA_HOME/bin:$JRE_HOME/bin:$PATH #hadoop环境变量 …

ssm实验中心管理系统的设计与实现

ssm实验中心管理系统的设计与实现040 开发工具&#xff1a;idea 数据库mysql5.7 数据库链接工具&#xff1a;navcat,小海豚等 技术&#xff1a;ssm 研究目的与意义&#xff1a; 随着高校硬件水平的提高和教学改革的深入&#xff0c;实验教学所占的地位越来越重要&#x…

【PHP】PHP常见语法

文章目录 PHP简介前置知识了解静态网站的特点动态网站特点 PHP基础语法代码标记注释语句分隔(结束)符变量变量的基本概念变量的使用变量命名规则预定义变量可变变量变量传值内存分区 常量基本概念常量定义形式命名规则使用形式系统常量魔术常量 数据类型简单&#xff08;基本&a…

Qt应用开发(基础篇)——富文本浏览器 QTextBrowser

一、前言 QTextBrowser类继承于QTextEdit&#xff0c;是一个具有超文本导航的富文本浏览器。 框架类 QFramehttps://blog.csdn.net/u014491932/article/details/132188655 滚屏区域基类 QAbstractScrollAreahttps://blog.csdn.net/u014491932/article/details/132245486 文…

智慧化工地SaaS平台源码,PC端+APP端+智慧数据可视化大屏端,源码完全开源不封装,自主研发,支持二开,项目使用,微服务+Java++vue+mysql

智慧工地管理平台充分运用数字化技术&#xff0c;聚焦施工现场岗位一线&#xff0c;依托物联网、互联网、AI等技术&#xff0c;围绕施工现场管理的人、机、料、法、环五大维度&#xff0c;以及施工过程管理的进度、质量、安全三大体系为基础应用&#xff0c;实现全面高效的工程…

js使用for of遍历map

//使用for of遍历map console.log("---") console.log(odata.studentDetails) let obj odata.studentDetails[0].answerSituation for(let [key,value] of Object.entries(obj)){console.log(value) }

vscode远程调试

安装ssh 在vscode扩展插件搜索remote-ssh安装 如果连接失败&#xff0c;出现 Resolver error: Error: XHR failedscode 报错&#xff0c;可以看这篇帖子vscode ssh: Resolver error: Error: XHR failedscode错误_阿伟跑呀的博客-CSDN博客 添加好后点击左上角的加号&#xff0…

【HCIP】12.BGP基础

AS之间传递路由&#xff08;不产生路由&#xff0c;只传递路由&#xff09;BGP属于应用层&#xff0c;采用TLV价格。AS号&#xff0c;16bit与32bit。运行BGP的路由器成为BGP发言者&#xff0c;或者BGP路由器 概述 采用目的端口179&#xff0c;触发式更新能承载大量路由信息13…

美创科技荣获“2023年网络安全优秀创新成果大赛—杭州分站赛”两项优胜奖

近日&#xff0c;由浙江省互联网信息办公室指导、中国网络安全产业联盟&#xff08;CCIA&#xff09;主办&#xff0c;浙江省网络空间安全协会承办的“2023年网络安全优秀创新成果大赛-杭州分站赛”正式公布评选结果。 经专家评审&#xff0c;美创科技报名参赛的解决方案—“医…

万界星空科技/免费MES系统/免费质量检测系统

质量管理也是万界星空科技免费MES中的一个重要组成部分&#xff0c;旨在帮助制造企业实现全面的质量管理。该系统涵盖了供应商来料、生产过程、质量检验、数据分析等各个环节&#xff0c;为企业提供了一站式的质量管理解决方案。 1. 实时质量监控 质量管理能够实时监控生产过程…