3D Gaussian Splatting代码中的train和render两个文件代码解读

news2025/1/11 2:50:26

现在来聊一聊训练和渲染是如何进行的

training

train.py
line 31
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
    # 初始化第一次迭代的索引为0
    first_iter = 0
    
    # 准备输出和日志记录器
    tb_writer = prepare_output_and_logger(dataset)
    
    # 初始化高斯模型,参数为数据集的球谐函数(SH)级别
    gaussians = GaussianModel(dataset.sh_degree)
    
    # 创建场景对象,包含数据集和高斯模型
    scene = Scene(dataset, gaussians)
    
    # 设置高斯模型的训练配置
    gaussians.training_setup(opt)
    
    # 加载检查点(如果有),恢复模型参数和设置起始迭代次数
    if checkpoint:
        (model_params, first_iter) = torch.load(checkpoint)
        gaussians.restore(model_params, opt)

    # 设置背景颜色,如果数据集背景为白色,则设置为白色([1, 1, 1]),否则为黑色([0, 0, 0])
    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    
    # 将背景颜色转换为CUDA张量,以便在GPU上使用
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    # 创建两个CUDA事件,用于记录迭代开始和结束的时间
    iter_start = torch.cuda.Event(enable_timing=True)
    iter_end = torch.cuda.Event(enable_timing=True)

    # 初始化视点堆栈为空
    viewpoint_stack = None
    
    # 用于记录指数移动平均损失的变量,初始值为0.0
    ema_loss_for_log = 0.0
    
    # 创建进度条,用于显示训练进度,从起始迭代数到总迭代数
    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
    
    # 增加起始迭代数,以便从下一次迭代开始
    first_iter += 1

    for iteration in range(first_iter, opt.iterations + 1):
        # 尝试连接网络GUI,如果当前没有连接
        if network_gui.conn == None:
            network_gui.try_connect()

        # 如果已经连接网络GUI,处理接收和发送数据
        while network_gui.conn != None:
            try:
                # 初始化网络图像字节为None
                net_image_bytes = None

                # 从网络GUI接收数据
                custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()

                # 如果接收到自定义相机数据,则进行渲染
                if custom_cam != None:
                    # 使用自定义相机数据、当前的高斯模型、管道和背景颜色进行渲染
                    net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]

                    # 将渲染结果转为字节格式,并转换为内存视图
                    net_image_bytes = memoryview(
                        (torch.clamp(net_image, min=0, max=1.0) * 255).byte()
                        .permute(1, 2, 0).contiguous().cpu().numpy()
                    )

                # 发送渲染结果到网络GUI,并附带数据集的源路径
                network_gui.send(net_image_bytes, dataset.source_path)

                # 如果需要进行训练,并且当前迭代次数小于总迭代次数,或不需要保持连接,则退出循环
                if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
                    break

            except Exception as e:
                # 如果出现异常,断开网络连接
                network_gui.conn = None

        # 记录当前迭代的开始时间,用于计算每次迭代的持续时间
        iter_start.record()


        # 更新学习率
        gaussians.update_learning_rate(iteration)

        # 每1000次迭代增加一次SH级别,直到达到最大度
        if iteration % 1000 == 0:
            gaussians.oneupSHdegree()

        # 随机选择一个相机视角
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
        # 从相机视角堆栈中随机弹出一个相机视角
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))

        # 渲染
        if (iteration - 1) == debug_from:
            pipe.debug = True

        # 如果设置了随机背景颜色,则生成一个随机背景颜色,否则使用预定义的背景颜色
        bg = torch.rand((3), device="cuda") if opt.random_background else background

        # 使用选定的相机视角、高斯模型、渲染管道和背景颜色进行渲染
        render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
        # 提取渲染结果、视点空间点张量、可见性过滤器和半径
        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]

        # 计算损失
        gt_image = viewpoint_cam.original_image.cuda()  # 获取地面真实图像
        Ll1 = l1_loss(image, gt_image)  # 计算L1损失
        # 计算总损失,结合L1损失和结构相似性损失(SSIM)
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
        loss.backward()  # 反向传播计算梯度

        # 记录当前迭代的结束时间,用于计算每次迭代的持续时间
        iter_end.record()

        # 在不需要计算梯度的上下文中进行操作
        with torch.no_grad():
            # 更新进度条和日志
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log  # 更新指数移动平均损失
            if iteration % 10 == 0:
                # 每10次迭代更新一次进度条
                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
                progress_bar.update(10)
            if iteration == opt.iterations:
                progress_bar.close()

            # 记录训练报告并保存
            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
            if iteration in saving_iterations:
                # 在指定的迭代次数保存高斯模型
                print("\n[ITER {}] Saving Gaussians".format(iteration))
                scene.save(iteration)

            # 密集化操作
            if iteration < opt.densify_until_iter:
                # 跟踪图像空间中的最大半径,用于修剪
                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

                # 在指定的迭代范围和间隔内进行密集化和修剪
                if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
                    size_threshold = 20 if iteration > opt.opacity_reset_interval else None
                    gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)

                # 在指定的间隔内或满足特定条件时重置不透明度
                if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
                    gaussians.reset_opacity()

            # 优化器步骤
            if iteration < opt.iterations:
                gaussians.optimizer.step()  # 更新模型参数
                gaussians.optimizer.zero_grad(set_to_none=True)  # 清空梯度

            # 保存检查点
            if iteration in checkpoint_iterations:
                print("\n[ITER {}] Saving Checkpoint".format(iteration))
                torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")

render

现在是渲染的这个文件进行方式,首先是主文件里单张图片的渲染和整个数据集的渲染方法:

render.py
line 24
# 渲染一组视角并保存渲染结果和对应的真实图像
def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
    # 定义渲染结果和真实图像的保存路径
    render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
    gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")

    # 创建保存路径,如果路径不存在
    makedirs(render_path, exist_ok=True)
    makedirs(gts_path, exist_ok=True)

    # 遍历每个视角进行渲染
    for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
        # 渲染图像
        rendering = render(view, gaussians, pipeline, background)["render"]
        # 获取对应的真实图像
        gt = view.original_image[0:3, :, :]
        # 保存渲染结果和真实图像到指定路径
        torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
        torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))

# 渲染训练集和测试集的图像,并保存结果
def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams, skip_train: bool, skip_test: bool):
    with torch.no_grad():
        # 初始化高斯模型和场景
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        # 设置背景颜色
        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        # 如果不跳过训练集渲染,则渲染训练集的图像
        if not skip_train:
            render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)

        # 如果不跳过测试集渲染,则渲染测试集的图像
        if not skip_test:
            render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)

但是这两个方法都是外层函数,并没有展示渲染如何进行参数传递和具体操作,在以下代码中才是最关键的内容:

gaussian_renderer\__init__.py
line 18
def render(viewpoint_camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, scaling_modifier=1.0, override_color=None):
    """
    渲染场景。
    
    参数:
    viewpoint_camera - 摄像机视角
    pc - 高斯模型
    pipe - 管道参数
    bg_color - 背景颜色张量,必须在GPU上
    scaling_modifier - 缩放修饰符,默认为1.0
    override_color - 覆盖颜色,默认为None
    """
 
    # 创建一个全零张量,用于使PyTorch返回2D(屏幕空间)均值的梯度
    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()  # 保留梯度信息
    except:
        pass

    # 设置光栅化配置
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)  # 计算视角的X轴正切
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)  # 计算视角的Y轴正切

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),  # 图像高度
        image_width=int(viewpoint_camera.image_width),  # 图像宽度
        tanfovx=tanfovx,  # 视角X轴正切
        tanfovy=tanfovy,  # 视角Y轴正切
        bg=bg_color,  # 背景颜色
        scale_modifier=scaling_modifier,  # 缩放修饰符
        viewmatrix=viewpoint_camera.world_view_transform,  # 世界视图变换矩阵
        projmatrix=viewpoint_camera.full_proj_transform,  # 投影变换矩阵
        sh_degree=pc.active_sh_degree,  # 球谐函数度数
        campos=viewpoint_camera.camera_center,  # 摄像机中心
        prefiltered=False,  # 预过滤
        debug=pipe.debug  # 调试模式
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)  # 初始化光栅化器

    means3D = pc.get_xyz  # 获取3D均值
    means2D = screenspace_points  # 获取2D均值
    opacity = pc.get_opacity  # 获取不透明度

    # 如果提供了预计算的3D协方差,则使用它。如果没有,则从光栅化器的缩放/旋转中计算。
    scales = None
    rotations = None
    cov3D_precomp = None
    if pipe.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance(scaling_modifier)  # 计算3D协方差
    else:
        scales = pc.get_scaling  # 获取缩放
        rotations = pc.get_rotation  # 获取旋转

    # 如果提供了预计算的颜色,则使用它们。否则,如果需要在Python中预计算SH到颜色的转换,则进行转换。
    shs = None
    colors_precomp = None
    if override_color is None:
        if pipe.convert_SHs_python:
            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2)
            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
            dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)  # 计算颜色
        else:
            shs = pc.get_features  # 获取球谐函数特征
    else:
        colors_precomp = override_color  # 覆盖颜色

    # 将可见的高斯体光栅化为图像,并获取它们在屏幕上的半径。
    rendered_image, radii = rasterizer(
        means3D=means3D,
        means2D=means2D,
        shs=shs,
        colors_precomp=colors_precomp,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=cov3D_precomp)

    # 那些被视锥剔除或半径为0的高斯体是不可见的。
    # 它们将被排除在用于分裂标准的值更新之外。
    return {
        "render": rendered_image,  # 渲染图像
        "viewspace_points": screenspace_points,  # 视图空间点
        "visibility_filter": radii > 0,  # 可见性过滤器
        "radii": radii  # 半径
    }

最值得关注的光栅化器,如果转到定义去查看,其实会发现它就是第二期里讲forward的代码,只是这里面用python写了变量的调用,实际的操作方式还是在cu文件里面。所以在此就不多做赘述,可以看上一期博客里面对forward的解读。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1890503.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

0703_ARM7

练习&#xff1a; 封装exti&#xff0c;cic初始化函数 //EXTI初始化 void hal_key_exti_init(int id,int exticr,int mode){//获取偏移地址int address_offset (id%4)*8;//获取寄存器编号int re_ser (id/4)1;//printf("address_offset%d,re_ser%d\n",address_o…

苹果手机怎么刷机?适合小白的刷机办法!

自己的苹果手机用时间长了&#xff0c;有些人想要为自己的手机重新刷新一下&#xff0c;但又不知道怎么刷机。不要慌现在就来给大家详细介绍一下苹果手机怎么刷机&#xff0c;希望可以帮助到大家。 iPhone常见的刷机方式&#xff0c;分为iTunes官方和第三方软件两种刷机方式。 …

基于Web技术的教育辅助系统设计与实现(SpringBoot MySQL)+文档

&#x1f497;博主介绍&#x1f497;&#xff1a;✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示&#xff1a;文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…

强行仅用time.localtime制作“日历牌”——全程记录“顶牛”“调戏”我的AI学习搭子

强行只用time.localtime制作“日历牌”&#xff0c;码好代码试炼通过&#xff0c;想榨取ai智能优化算法&#xff0c;结果失败。本文详细记录“顶牛”全过程。 (笔记模板由python脚本于2024年07月01日 19:16:26创建&#xff0c;本篇笔记适合喜欢python&#xff0c;喜欢搞“事儿”…

p2p、分布式,区块链笔记: 通过libp2p的Kademlia网络协议实现kv-store

Kademlia 网络协议 Kademlia 是一种分布式哈希表协议和算法&#xff0c;用于构建去中心化的对等网络&#xff0c;核心思想是通过分布式的网络结构来实现高效的数据查找和存储。在这个学习项目里&#xff0c;Kademlia 作为 libp2p 中的 NetworkBehaviour的组成。 以下这些函数或…

controller不同的后端路径对应vue前端传递数据发送请求的方式,vue请求参数 param 与data 如何对应后端参数

目录 案例一&#xff1a; 为什么使用post发送请求&#xff0c;参数依旧会被拼接带url上呢&#xff1f;这应该就是param 与data传参的区别。即param传参数参数会被拼接到url后&#xff0c;data会以请求体传递 补充&#xff1a;后端controller 参数上如果没写任何注解&#xff0c…

Redis中hash类型的操作命令(命令的语法、返回值、时间复杂度、注意事项、操作演示)

文章目录 字符串和哈希类型相比hset 命令hget 命令hexistshdelhkeyshvalshgetallhmgethlenhsetnxhincrbyhincrbyfloat 字符串和哈希类型相比 假设有以下一种场景&#xff1a;现在要在 Redis 中存储一个用户的基本信息(id1、namezhangsan、age17)&#xff0c;下图表示使用字符串…

Vue3轻松创建交互式仪表盘

本文由ScriptEcho平台提供技术支持 项目地址&#xff1a;传送门 基于 Plotly.js 的 Vue 仪表盘组件 应用场景介绍 仪表盘是一种交互式可视化工具&#xff0c;用于监控和分析关键指标。它广泛应用于各种行业&#xff0c;例如金融、医疗保健和制造业。 代码基本功能介绍 本…

Linux源码阅读笔记12-RCU案例分析

在之前的文章中我们已经了解了RCU机制的原理和Linux的内核源码&#xff0c;这里我们要根据RCU机制写一个demo来展示他应该如何使用。 RCU机制的原理 RCU&#xff08;全称为Read-Copy-Update&#xff09;,它记录所有指向共享数据的指针的使用者&#xff0c;当要修改构想数据时&…

搭建论坛和mysql数据库安装和php安装

目录 概念 步骤 安装mysql8.0.30 安装php 安装Discuz 概念 搭建论坛的架构&#xff1a; lnmpDISCUZ l 表示linux操作系统 n 表示nginx前端页面的web服务 m 表示 mysql 数据库 用来保存用户和密码以及论坛的相关内容 p 表示php 动态请求转发的中间件 步骤 &#xff…

基于Cardinal的AWD攻防平台搭建与使用以及基于docker的题目环境部署

关于 CTF 靶场的搭建与完善勇师傅前面已经总结过了&#xff0c;参考&#xff1a; CTF靶场搭建及Web赛题制作与终端docker环境部署_ctfoj搭建-CSDN博客 基于H1ve一分钟搭好CTF靶场-CSDN博客 Nginx首页修改及使用Nginx实现端口转发_nginx 修改欢迎首页-CSDN博客 关于H1ve导…

《IT 领域准新生暑期预习指南:开启未来科技之旅》

IT专业入门&#xff0c;高考假期预习指南 高考的落幕&#xff0c;只是人生长途中的一个逗号&#xff0c;对于心怀 IT 梦想的少年们&#xff0c;新的征程已然在脚下铺展。这个七月&#xff0c;当分数尘埃落定&#xff0c;你们即将迈向新的知识殿堂&#xff0c;而这个假期&#…

235、二叉搜索树的最近公共祖先

给定一个二叉搜索树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为&#xff1a;“对于有根树 T 的两个结点 p、q&#xff0c;最近公共祖先表示为一个结点 x&#xff0c;满足 x 是 p、q 的祖先且 x 的深度尽可能大&#xff08;一个节点也可以是它自…

代码随想录第42天|动态规划

198.打家劫舍 参考 dp[j] 表示偷盗的总金额, j 表示前 j 间房(包括j)的总偷盗金额初始化: dp[0] 一定要偷, dp[1] 则取房间0,1的最大值遍历顺序: 从小到大 class Solution { public:int rob(vector<int>& nums) {if (nums.size() < 2) {return nums[0];}vector&…

Docker安装PostgreSQL详细教程

本章教程,使用Docker安装PostgreSQL具体步骤。 一、拉取镜像 docker pull postgres二、启动容器 docker run -it --name postgres --restart always -e POSTGRES_PASSWORD=123456 -e

VideoPrism——探索视频分析领域模型的算法与应用

概述 论文地址:https://arxiv.org/pdf/2402.13217.pdf 视频是我们观察世界的生动窗口&#xff0c;记录了从日常瞬间到科学探索的各种体验。在这个数字时代&#xff0c;视频基础模型&#xff08;ViFM&#xff09;有可能分析如此海量的信息并提取新的见解。迄今为止&#xff0c;…

全国数学建模大赛(一)

全国数学建模大赛 &#x1f388;1.数学模型是什么&#xff1f;&#x1f52d;1.1原型与模型&#x1f52d;1.2模型的分类&#x1f52d;1.3数学模型的分类&#x1f52d;1.4数学模型的全过程&#x1f52d;1.5论文写作基本流程&#x1f52d;1.6数学建模的六个步骤&#x1f52d;1.7小…

【SpringBoot配置文件读取】无法读取yaml文件中文字符

1. yaml配置文件 注意要将该文件编码格式改为UTF-8 spring:application:name: 好好学习admin:name: 李斯age: 24books:- name: 数据结构desc: 数据书- name: 编译原理desc: 编译书2.配置实体类 Data设置get&#xff0c;set方法Component注册为BeanConfigurationProperties(p…

第6章:结构化开发方法

第6章&#xff1a;结构化开发方法 系统设计基本原理 1、抽象 抽象是一种设计技术&#xff0c;重点说明一个实体的本质方面&#xff0c;而忽略或者掩盖不是很重要或非本质的方面。 模块化 模块化是指将一个待开发的软件分解成若干个小的、简单的部分一模块&#xff0c;每个模…