基于可微分渲染器的相机位置优化【PyTorch3D】

news2024/11/23 19:03:10

在这个教程中,我们将使用可微渲染学习给定参考图像的相机的 [x, y, z] 位置。

我们将首先使用相机的起始位置初始化渲染器。 然后,我们将使用它来生成图像,使用参考图像计算损失,最后通过整个管道进行反向传播以更新相机的位置。

NSDT在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器

本教程展示如何:

  • 从 .obj 文件加载网格
  • 初始化相机、着色器和渲染器,
  • 渲染网格
  • 使用损失函数和优化器设置优化循环

首先确保已安装torch和torchvision,并使用以下代码安装pytorch3d:

import os
import sys
import torch
need_pytorch3d=False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True
if need_pytorch3d:
    if torch.__version__.startswith("2.1.") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{pyt_version_str}"
        ])
        !pip install fvcore iopath
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

导入模块:

import os
import torch
import numpy as np
from tqdm.notebook import tqdm
import imageio
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import img_as_ubyte

# io utils
from pytorch3d.io import load_obj

# datastructures
from pytorch3d.structures import Meshes

# 3D transformations functions
from pytorch3d.transforms import Rotate, Translate

# rendering components
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform, look_at_rotation, 
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex,
)

1、加载obj模型

我们将加载一个 obj 文件并创建一个 Meshes 对象。 网格是 PyTorch3D 中提供的独特数据结构,用于处理批量不同大小的网格。 它有几个在渲染管道中使用的有用的类方法:

# Set the cuda device 
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# Load the obj and ignore the textures and materials.
verts, faces_idx, _ = load_obj("./data/teapot.obj")
faces = faces_idx.verts_idx

# Initialize each vertex to be white in color.
verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)
textures = TexturesVertex(verts_features=verts_rgb.to(device))

# Create a Meshes object for the teapot. Here we have only one mesh in the batch.
teapot_mesh = Meshes(
    verts=[verts.to(device)],   
    faces=[faces.to(device)], 
    textures=textures
)

如果在克隆 PyTorch3D 存储库后在本地运行此笔记本,则网格将已经可用。 如果使用 Google Colab,请获取网格并将其保存在路径 data/ 中:

!mkdir -p data
!wget -P data https://dl.fbaipublicfiles.com/pytorch3d/data/teapot/teapot.obj

2、优化设置

PyTorch3D 中的渲染器由光栅器和着色器组成,每个组件都有许多子组件,例如相机(正交/透视)。 在这里,我们初始化其中一些组件,并对其余组件使用默认值。

2.1 创建渲染器

为了优化相机位置,我们将使用渲染器,它仅生成对象的轮廓,而不应用任何照明或阴影。 我们还将初始化另一个应用完整 Phong 着色的渲染器,并使用它来可视化输出。

# Initialize a perspective camera.
cameras = FoVPerspectiveCameras(device=device)

# To blend the 100 faces we set a few parameters which control the opacity and the sharpness of 
# edges. Refer to blending.py for more details. 
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# Define the settings for rasterization and shading. Here we set the output image to be of size
# 256x256. To form the blended image we use 100 faces for each pixel. We also set bin_size and max_faces_per_bin to None which ensure that 
# the faster coarse-to-fine rasterization method is used. Refer to rasterize_meshes.py for 
# explanations of these parameters. Refer to docs/notes/renderer.md for an explanation of 
# the difference between naive and coarse-to-fine rasterization. 
raster_settings = RasterizationSettings(
    image_size=256, 
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
    faces_per_pixel=100, 
)

# Create a silhouette mesh renderer by composing a rasterizer and a shader. 
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params)
)


# We will also create a Phong renderer. This is simpler and only needs to render one face per pixel.
raster_settings = RasterizationSettings(
    image_size=256, 
    blur_radius=0.0, 
    faces_per_pixel=1, 
)
# We can add a point light in front of the object. 
lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)
)

2.2 创建参考图像

我们将首先定位茶壶并生成图像。 我们使用辅助函数将茶壶旋转到所需的视角。 然后我们可以使用渲染器来生成图像。 在这里,我们将使用两个渲染器并可视化轮廓和全着色图像。

世界坐标系定义为+Y向上、+X向左和+Z向内。世界坐标中的茶壶的壶嘴指向左侧。

我们定义了一个位于 z 轴正方向上的相机,因此可以看到右侧的喷口。

# Select the viewpoint using spherical angles  
distance = 3   # distance from camera to the object
elevation = 50.0   # angle of elevation in degrees
azimuth = 0.0  # No rotation so the camera is positioned on the +Z axis. 

# Get the position of the camera based on the spherical angles
R, T = look_at_view_transform(distance, elevation, azimuth, device=device)

# Render the teapot providing the values of R and T. 
silhouette = silhouette_renderer(meshes_world=teapot_mesh, R=R, T=T)
image_ref = phong_renderer(meshes_world=teapot_mesh, R=R, T=T)

silhouette = silhouette.cpu().numpy()
image_ref = image_ref.cpu().numpy()

plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(silhouette.squeeze()[..., 3])  # only plot the alpha channel of the RGBA image
plt.grid(False)
plt.subplot(1, 2, 2)
plt.imshow(image_ref.squeeze())
plt.grid(False)

输出如下:

2.3 创建基础模型

这里我们创建一个简单的模型类并初始化相机位置的参数。

class Model(nn.Module):
    def __init__(self, meshes, renderer, image_ref):
        super().__init__()
        self.meshes = meshes
        self.device = meshes.device
        self.renderer = renderer
        
        # Get the silhouette of the reference RGB image by finding all non-white pixel values. 
        image_ref = torch.from_numpy((image_ref[..., :3].max(-1) != 1).astype(np.float32))
        self.register_buffer('image_ref', image_ref)
        
        # Create an optimizable parameter for the x, y, z position of the camera. 
        self.camera_position = nn.Parameter(
            torch.from_numpy(np.array([3.0,  6.9, +2.5], dtype=np.float32)).to(meshes.device))

    def forward(self):
        
        # Render the image using the updated camera position. Based on the new position of the 
        # camera we calculate the rotation and translation matrices
        R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)
        T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]   # (1, 3)
        
        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)
        
        # Calculate the silhouette loss
        loss = torch.sum((image[..., 3] - self.image_ref) ** 2)
        return loss, image

3、初始化模型和优化器

现在我们可以创建上述模型的实例并为相机位置参数设置优化器。

# We will save images periodically and compose them into a GIF.
filename_output = "./teapot_optimization_demo.gif"
writer = imageio.get_writer(filename_output, mode='I', duration=0.3)

# Initialize a model using the renderer, mesh and reference image
model = Model(meshes=teapot_mesh, renderer=silhouette_renderer, image_ref=image_ref).to(device)

# Create an optimizer. Here we are using Adam and we pass in the parameters of the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

可视化起始位置和参考位置:

plt.figure(figsize=(10, 10))

_, image_init = model()
plt.subplot(1, 2, 1)
plt.imshow(image_init.detach().squeeze().cpu().numpy()[..., 3])
plt.grid(False)
plt.title("Starting position")

plt.subplot(1, 2, 2)
plt.imshow(model.image_ref.cpu().numpy().squeeze())
plt.grid(False)
plt.title("Reference silhouette");

输出如下:

4、运行优化

我们运行前向和后向传递的多次迭代,并每 10 次迭代保存输出。

loop = tqdm(range(200))
for i in loop:
    optimizer.zero_grad()
    loss, _ = model()
    loss.backward()
    optimizer.step()
    
    loop.set_description('Optimizing (loss %.4f)' % loss.data)
    
    if loss.item() < 200:
        break
    
    # Save outputs to create a GIF. 
    if i % 10 == 0:
        R = look_at_rotation(model.camera_position[None, :], device=model.device)
        T = -torch.bmm(R.transpose(1, 2), model.camera_position[None, :, None])[:, :, 0]   # (1, 3)
        image = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)
        image = image[0, ..., :3].detach().squeeze().cpu().numpy()
        image = img_as_ubyte(image)
        writer.append_data(image)
        
        plt.figure()
        plt.imshow(image[..., :3])
        plt.title("iter: %d, loss: %0.2f" % (i, loss.data))
        plt.axis("off")
    
writer.close()

迭代期间输出如下:

完成后可以查看 ./teapot_optimization_demo.gif,优化过程的炫酷 gif:

5、结束语

在本教程中,我们学习了如何从 obj 文件加载网格,初始化名为 Meshes 的 PyTorch3D 数据结构,设置由光栅化器和着色器组成的渲染器,设置包括模型和损失函数的优化循环,并运行优化。


原文链接:用可微渲染优化相机位置 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1261042.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Web3.0时代:区块链DAPP将如何颠覆传统模式

小编介绍&#xff1a;10年专注商业模式设计及软件开发&#xff0c;擅长企业生态商业模式&#xff0c;商业零售会员增长裂变模式策划、商业闭环模式设计及方案落地&#xff1b;扶持10余个电商平台做到营收过千万&#xff0c;数百个平台达到百万会员&#xff0c;欢迎咨询。 随着…

6.Spring源码解析-loadBeanDefinitions(String location)

这里resourceLoader其实就是ClassPathXmlApplicationContext 1.ClassPathXmlApplicationContext 在上文中图例就能看出来 获取资源组可能存在多个bean.xml 循环单独加载资源组 创建一个编码资源并解析 获取当前正在加载的资源发现是空 创建了一个字节输入流&#xff0c…

竞赛选题 题目:基于大数据的用户画像分析系统 数据分析 开题

文章目录 1 前言2 用户画像分析概述2.1 用户画像构建的相关技术2.2 标签体系2.3 标签优先级 3 实站 - 百货商场用户画像描述与价值分析3.1 数据格式3.2 数据预处理3.3 会员年龄构成3.4 订单占比 消费画像3.5 季度偏好画像3.6 会员用户画像与特征3.6.1 构建会员用户业务特征标签…

cpu飙高问题,案例分析(二)——批处理数据过大引起的应用服务CPU飙高

上接cpu飙高问题&#xff0c;案例分析&#xff08;一&#xff09; 一、批处理数据过大引起的应用服务CPU飙高 1.1 问题场景 某定时任务job 收到cpu连续&#xff08;配置的时间是180s&#xff09;使用超过90%的报警; 1.2 问题定位 观察报警中的jvm监控&#xff0c;发现周期…

LeetCode(33)最小覆盖子串【滑动窗口】【困难】

目录 1.题目2.答案3.提交结果截图 链接&#xff1a; 76. 最小覆盖子串 1.题目 给你一个字符串 s 、一个字符串 t 。返回 s 中涵盖 t 所有字符的最小子串。如果 s 中不存在涵盖 t 所有字符的子串&#xff0c;则返回空字符串 "" 。 注意&#xff1a; 对于 t 中重复字…

python爬虫实习找工作练习测试(以下内容仅供参考学习)

要求&#xff1a;获取下图指定网站的指定数据 空气质量状况报告-中国环境监测总站 输入&#xff1a;用户输入下载时间范围&#xff0c;格式为2022-10 输出&#xff1a;将更新时间在2022年10月1日到31日之间的文件下载到本地目录&#xff08;可配置&#xff09;&#xff0c;并…

在Rust中编写自动化测试

1.摘要 Rust中的测试函数是用来验证非测试代码是否是按照期望的方式运行的, 测试函数体通常需要执行三种操作:1.设置任何所需的数据或状态;2.运行需要测试的代码;3.断言其结果是我们所期望的。本篇文章主要探讨了Rust自动化测试的几种常见场景。 2.测试函数详解 在Rust项目工…

图像去噪——k-Sigma变换,模拟增益,噪声方差

目录 一、k-Sigma变化k-Sigma变换定义式定义式参数解析 二、模拟增益三、噪声方差 一、k-Sigma变化 k-Sigma变换是一种用于图像去噪的方法&#xff0c;它的主要思想是通过一个特定的线性转换&#xff0c;将训练数据从ISO-dependent的域名转换到ISO-independent的域上。这个转换…

QtCreator创建的文件复制到VS下报错

报错1&#xff1a; 错误 C2447 “{”: 缺少函数标题(是否是老式的形式表?) (编译源文件 myselectpoint.cpp) DataTypeLib e:\qtnewproject\linuxversion\videointerpretationanddataprocesssystem_vs\jwycfsoftware\datatypelib\allstructdefine.h 117 解…

快速开发出一个公司网站

问题描述&#xff1a;参加一个创业活动&#xff0c;小组要求做一个公司网站&#xff0c;简单介绍一下自己公司的业务。需要快速完成。 问题解决&#xff1a;从网上找一个网站模板&#xff0c;类似于做PPT&#xff0c;搭建一个网站即可。 这里推荐的是京美建站、wordpress、he…

车规激光雷达再商用车前装市场的应用

1、商用车需要什么样的激光雷达 2、如何实现车规级&#xff08;商用车&#xff09;的激光雷达 3、激光雷达安装部署方案

PHP TCP服务端监听端口接收客户端RFID网络读卡器上传的读卡数据

本示例使用设备&#xff1a;WIFI/TCP/UDP/HTTP协议RFID液显网络读卡器可二次开发语音播报POE-淘宝网 (taobao.com) <?php header("content-type:text/html;charsetGBK");set_time_limit(0); $port39169; //监听端口if(($socket socket_create(AF_INET, SOCK…

pycharm 怎么切换Anaconda简单粗暴

&#xff08;1&#xff09;创建一个环境 &#xff08;2&#xff09;选择一下自己conda的安装路径中conba.exe (3)选择存在的环境&#xff0c;一般会自动检测到conda创建有哪些环境&#xff0c;导入就行

C++二分查找视频教程:两数之和

作者推荐 利用广度优先或模拟解决米诺骨牌 本文涉及的基础知识点 二分查找算法合集 题目 给你一个下标从 1 开始的整数数组 numbers &#xff0c;该数组已按 非递减顺序排列 &#xff0c;请你从数组中找出满足相加之和等于目标数 target 的两个数。如果设这两个数分别是 n…

单片机学习11——矩阵键盘

矩阵键盘&#xff1a; 这个矩阵键盘可以接到P0、P1、P2、P3都是可以的。 使用矩阵键盘是能节省单片机的IO口。 P3.0 P3.1 P3.2 P3.3 称之为行号。 P3.4 P3.5 P3.6 P3.7 称之为列号。 矩阵键盘检测原理&#xff1a; 1、检查是否有键按下&#xff1b; 2、键的抖动处理&#xf…

XIAO ESP32S3之套件简绍

很高兴收到柴火创客空间寄来的XIAO ESP32S3开发套件。 一、套件介绍 1、电路板部分 一块XIAO ESP32S3主板、一块摄像头接口板&#xff08;可接SD卡&#xff09;&#xff0c;一根2.4G天线。 2、配件部分 一根USB-A转TypeC数据线、一个USB3.0转TypeC转接头、一个SD卡读卡器&am…

集动作捕捉与表情捕捉的系统,怎么用于动画制作?

对于传统动画制作来说&#xff0c;将要处理数字人的动作与表情&#xff0c;最原始的方式是打关键帧&#xff0c;通过关键帧的形式来展现数字人的弹跳、行走、奔跑等动作&#xff0c;但这种制作方式往往时间长&#xff0c;成本高&#xff0c;效率低。而一个集动作捕捉与表情捕捉…

spring boot的redis连接数过多导致redis服务器压力过大的一次问题排查

一、背景 在今天上午的时候&#xff0c;突然收到大量的sentry报错&#xff0c;都是关于redis连接超时的警告。 首先想到的是去查看redis的监控&#xff0c;发现那个时间段&#xff0c;redis的请求数剧增&#xff0c;cpu使用率和带宽都陡增双倍。 下面的是redis监控的cpu情况 …

【MySql】14- 实践篇(十二)-grant权限/分区表/自增Id用完怎么办

文章目录 1.grant之后要跟着flush privileges吗&#xff1f;1.1 全局权限1.2 db 权限1.3 表权限和列权限1.4 flush privileges 使用场景 2. 要不要使用分区表?2.1 分区表是什么?2.2 分区表的引擎层行为2.3 分区策略2.4 分区表的 server 层行为2.5 分区表的应用场景 3. 自增Id…

高效视频剪辑:按指定时长批量分割视频,释放无尽创意

随着数字媒体技术的不断发展&#xff0c;视频剪辑已经成为日常生活中不可或缺的一部分。无论是制作电影、电视剧&#xff0c;还是创意生活短视频&#xff0c;视频剪辑都扮演着重要的角色。然而&#xff0c;对于许多非专业人士来说&#xff0c;视频剪辑可能是一项复杂而耗时的任…