目录
背景:
(1)Ha_NeRF论文解读
(2)Ha_NeRF源码复现
(3)train_mask_grid_sample.py 运行
train_mask_grid_sample.py解读
1 NeRFSystem 模块
2 forward()详解
3 模型训练tranining_step()详解
4 模型验证validation_step()详解:
5 validation_epoch_end() 详解
6 main() 详解
背景:
(1)Ha_NeRF论文解读
NeRF系列(4):Ha-NeRF: Hallucinated Neural Radiance Fields in the Wild论文解读_LeapMay的博客-CSDN博客文章浏览阅读389次,点赞3次,收藏3次。提出了一个外观幻化模块,用于处理时间变化的外观并将其转移到新视角上。考虑到旅游图像中的复杂遮挡情况,我们引入了一个抗遮挡模块,用于准确地分解静态物体以获取清晰的可见性。https://blog.csdn.net/qq_35831906/article/details/131247784?spm=1001.2101.3001.6650.5&utm_medium=distribute.pc_relevant.none-task-blog-2~default~BlogCommendFromBaidu~Rate-5-131247784-blog-131334579.235%5Ev38%5Epc_relevant_yljh&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2~default~BlogCommendFromBaidu~Rate-5-131247784-blog-131334579.235%5Ev38%5Epc_relevant_yljh&utm_relevant_index=6
(2)Ha_NeRF源码复现
Ha-NeRF: Hallucinated Neural Radiance Fields in the Wild 代码复现与解读_LeapMay的博客-CSDN博客文章浏览阅读244次。code:本机环境: python 3.6.3,torch 1.8.1+cu102,pytorch-lightning 1.1.5。https://blog.csdn.net/qq_35831906/article/details/131334579
(3)train_mask_grid_sample.py 运行
python train_mask_grid_sample.py --root_dir ./data/IMC-PT/brandenburg_gate --dataset_name phototourism --save_dir save --img_downscale 2 --use_cache --N_importanc
e 64 --N_samples 64 --num_epochs 20 --batch_size 1024 --optimizer adam --lr 5e-4 --lr_scheduler cosine --exp_name exp_HaNeRF_Brandenburg_Gate --N_emb_xyz 15 --N_vocab 1500 --use_mask --maskrs_max 5e-2 --maskrs_min 6e-3 --maskrs_
k 1e-3 --maskrd 0 --encode_a --N_a 48 --weightKL 1e-5 --encode_random --weightRecA 1e-3 --weightMS 1e-6 --num_gpus 1
train_mask_grid_sample.py解读
1 NeRFSystem 模块
# 导入必要的库和模块
import torch
from models.nerf import NeRF # 假设 NeRF 在 models.nerf 模块中定义
from models.networks import E_attr, implicit_mask, PosEmbedding # 导入所需模块
# 定义一个名为 NeRFSystem 的 PyTorch Lightning 模块类
class NeRFSystem(LightningModule):
def __init__(self, hparams):
super().__init__() # 调用 LightningModule 构造函数
self.hparams = hparams # 存储模型的超参数
self.loss = loss_dict['hanerf'](hparams, coef=1) # 设置损失函数 'hanerf' 并指定系数为 1
self.models_to_train = [] # 初始化一个列表,用于存储需要训练的模型
self.embedding_xyz = PosEmbedding(hparams.N_emb_xyz - 1, hparams.N_emb_xyz) # 用于 XYZ 坐标的位置编码
self.embedding_dir = PosEmbedding(hparams.N_emb_dir - 1, hparams.N_emb_dir) # 用于方向的位置编码
self.embedding_uv = PosEmbedding(10 - 1, 10) # 用于 UV 坐标的位置编码
self.embeddings = {'xyz': self.embedding_xyz, 'dir': self.embedding_dir} # 将位置编码存储在字典中
# 如果需要属性编码
if hparams.encode_a:
self.enc_a = E_attr(3, hparams.N_a) # 使用指定维度创建属性编码器
self.models_to_train += [self.enc_a] # 将编码器添加到需要训练的模型列表中
self.embedding_a_list = [None] * hparams.N_vocab # 初始化属性编码列表
# 创建具有指定输入通道(XYZ 和方向)的粗糙 NeRF 模型
self.nerf_coarse = NeRF('coarse', in_channels_xyz=6 * hparams.N_emb_xyz + 3, in_channels_dir=6 * hparams.N_emb_dir + 3)
self.models = {'coarse': self.nerf_coarse} # 将粗糙 NeRF 模型存储在字典中
# 如果需要精细 NeRF 模型
if hparams.N_importance > 0:
# 创建具有指定输入通道(XYZ、方向和外观编码)的精细 NeRF 模型
self.nerf_fine = NeRF('fine', in_channels_xyz=6 * hparams.N_emb_xyz + 3, in_channels_dir=6 * hparams.N_emb_dir + 3,
encode_appearance=hparams.encode_a, in_channels_a=hparams.N_a,
encode_random=hparams.encode_random)
self.models['fine'] = self.nerf_fine # 将精细 NeRF 模型存储在模型字典中
self.models_to_train += [self.models] # 将模型添加到需要训练的列表中
# 如果需要使用遮罩
if hparams.use_mask:
self.implicit_mask = implicit_mask() # 初始化隐式遮罩模型
self.models_to_train += [self.implicit_mask] # 将隐式遮罩添加到需要训练的列表中
self.embedding_view = torch.nn.Embedding(hparams.N_vocab, 128) # 创建一个嵌入视图
self.models_to_train += [self.embedding_view] # 将嵌入视图添加到需要训练的列表中
NeRFSystem
类的初始化函数(__init__
):
- 首先调用
super().__init__()
来继承LightningModule
的初始化方法。self.hparams
用于存储模型的超参数。self.loss
是模型的损失函数,通过loss_dict
字典选择了名为 'hanerf' 的损失函数并初始化它。self.models_to_train
是一个模型列表,用于存储需要训练的模型组件。self.embedding_xyz
、self.embedding_dir
和self.embedding_uv
是位置嵌入(Positional Embedding)对象,用于编码不同类型的空间坐标。- 根据超参数的设定,如果
hparams.encode_a
为真,将创建属性编码器self.enc_a
,并将其加入self.models_to_train
列表中。- 通过
NeRF
类创建了粗糙(coarse)和精细(fine)NeRF模型,将这些模型添加到self.models
字典中,并将需要训练的模型也加入self.models_to_train
列表中。- 如果
hparams.use_mask
为真,将创建隐式遮罩(implicit_mask)模型和一个嵌入层(embedding layer),同样加入了self.models_to_train
列表中。
forward
方法和其他训练、验证相关的方法并未在这段代码中提供,这些方法一般用于执行前向传播,定义损失计算方式,指定优化器和学习率调度器,加载数据等等总体来说,这段代码创建了一个 PyTorch Lightning 模型,其中包含了多个 NeRF 相关的组件,根据指定的超参数和需求组织了不同的模型和模型组件,并将它们用于训练过程。这个类的功能主要是提供一个结构化的接口,以便构建和管理神经体积渲染模型,使得模型的训练和验证能够更加方便和易于管理。
2 forward()详解
def forward(self, rays, ts, whole_img, W, H, rgb_idx, uv_sample, test_blender):
results = defaultdict(list) # 使用 defaultdict 初始化结果存储
kwargs = {} # 初始化空字典 kwargs,用于存储关键字参数
# 如果需要对属性进行编码
if self.hparams.encode_a:
if test_blender:
# 如果是测试渲染器
kwargs['a_embedded_from_img'] = self.embedding_a_list[0] if self.embedding_a_list[0] is not None else self.enc_a(whole_img)
else:
# 否则,在图像数据上进行属性编码
kwargs['a_embedded_from_img'] = self.enc_a(whole_img)
# 如果需要编码随机属性(self.hparams.encode_random为True):
# - 获取非空属性编码列表的索引idexlist,其中k为索引,v为属性编码值
# - 若idexlist为空,意味着属性编码列表中没有非空值的属性编码
# - 将a_embedded_random设置为a_embedded_from_img,表示使用来自整个图像的属性编码
# - 否则,从属性编码列表中随机选择一个索引对应的属性编码,作为随机属性编码
if self.hparams.encode_random:
idexlist = [k for k, v in enumerate(self.embedding_a_list) if v is not None]
if len(idexlist) == 0:
kwargs['a_embedded_random'] = kwargs['a_embedded_from_img']
else:
# 随机选择一个非空属性编码,作为随机属性编码
random_index = random.choice(idexlist)
kwargs['a_embedded_random'] = self.embedding_a_list[random_index]
"""Do batched inference on rays using chunk."""
B = rays.shape[0] # 获取射线的批量大小
for i in range(0, B, self.hparams.chunk):
rendered_ray_chunks = render_rays(
self.models, # 使用预定义的模型
self.embeddings, # 使用预定义的嵌入
rays[i:i + self.hparams.chunk], # 批量处理的射线
ts[i:i + self.hparams.chunk], # 批量处理的时间点
self.hparams.N_samples, # 数值采样
self.hparams.use_disp, # 使用视差
self.hparams.perturb, # 扰动
self.hparams.noise_std, # 噪声标准差
self.hparams.N_importance, # 重要性数
self.hparams.chunk, # 有效的块大小
self.train_dataset.white_back, # 白色背景
**kwargs # 关键字参数
)
for k, v in rendered_ray_chunks.items():
results[k] += [v]
for k, v in results.items():
results[k] = torch.cat(v, 0) # 将结果连接起来
if self.hparams.use_mask:
if test_blender:
results['out_mask'] = torch.zeros(results['rgb_fine'].shape[0], 1).to(results['rgb_fine'])
else:
uv_embedded = self.embedding_uv(uv_sample)
results['out_mask'] = self.implicit_mask(torch.cat((self.embedding_view(ts), uv_embedded), dim=-1))
if self.hparams.encode_a:
results['a_embedded'] = kwargs['a_embedded_from_img'] # 存储属性编码结果
if self.hparams.encode_random:
results['a_embedded_random'] = kwargs['a_embedded_random'] # 存储随机属性编码结果
rec_img_random = results['rgb_fine_random'].view(1, H, W, 3).permute(0, 3, 1, 2) * 2 - 1
results['a_embedded_random_rec'] = self.enc_a(rec_img_random)
self.embedding_a_list[ts[0]] = kwargs['a_embedded_from_img'].clone().detach()
return results # 返回结果字典
在给定的
forward
方法中:
results = defaultdict(list)
- 创建一个defaultdict
,用于存储模型前向传播的结果,其结构是列表形式的字典。在后续的循环中,这将存储渲染的射线结果。通过
kwargs
存储关键字参数,这些参数将在render_rays
函数中使用。kwargs
会根据不同条件动态更改。对属性编码的处理:
if self.hparams.encode_a
:如果需要对属性进行编码。if test_blender
:根据条件是否是测试渲染器来确定是否使用整体图像(whole_img
)对属性进行编码。if self.hparams.encode_random
:如果需要对属性进行随机编码。
- 通过
idexlist
获取非空属性编码列表的索引,如果列表为空,则默认使用整体图像对属性进行编码。- 否则,随机选择一个非空属性编码,并存储为随机属性编码。
循环
for i in range(0, B, self.hparams.chunk)
:这里的代码进行了分块的射线渲染,根据射线的数量和块大小分块进行渲染。得到的结果存储在results
中。对结果进行整理:
for k, v in results.items()
循环结果字典,将分块的结果连接起来。- 如果需要使用遮罩(
use_mask
):
- 通过
if test_blender
和else
,确定是否使用输出遮罩。遮罩将根据不同的条件生成不同的值。如果需要对属性进行编码:
results['a_embedded']
和results['a_embedded_random']
存储属性编码的结果。results['a_embedded_random_rec']
存储随机属性编码的结果。self.embedding_a_list[ts[0]]
更新属性编码列表中的对应索引,将其设置为来自整体图像的属性编码的克隆值。返回结果
results
,这些结果包括射线渲染的数据和属性编码的结果。这段代码实现了射线渲染过程中对属性进行编码的功能,并存储了相关结果。
3 模型训练tranining_step()详解
def training_step(self, batch, batch_nb):
# 从批处理中提取数据
rays, ts = batch['rays'].squeeze(), batch['ts'].squeeze() # 提取射线和时间点
rgbs = batch['rgbs'].squeeze() # 提取 RGB 值
uv_sample = batch['uv_sample'].squeeze() # 提取 UV 样本
# 检查是否需要编码属性或使用掩膜
if self.hparams.encode_a or self.hparams.use_mask:
whole_img = batch['whole_img'] # 提取整个图像
rgb_idx = batch['rgb_idx'] # 提取 RGB 索引
else:
whole_img = None
rgb_idx = None
# 从 RGB 值的平方根计算高度和宽度
H = int(sqrt(rgbs.size(0)))
W = int(sqrt(rgbs.size(0)))
test_blender = False # 设置 test_blender 标志
# 执行前向传递以生成预测和损失
results = self(rays, ts, whole_img, W, H, rgb_idx, uv_sample, test_blender)
loss_d, AnnealingWeight = self.loss(results, rgbs, self.hparams, self.global_step)
loss = sum(l for l in loss_d.values()) # 计算总损失
# 记录与训练相关的指标
with torch.no_grad():
typ = 'fine' if 'rgb_fine' in results else 'coarse' # 确定结果类型
psnr_ = psnr(results[f'rgb_{typ}'], rgbs) # 计算 PSNR 指标
self.log('lr', get_learning_rate(self.optimizer)) # 记录学习率
self.log('train/loss', loss) # 记录总损失
self.log('train/AnnealingWeight', AnnealingWeight) # 记录 AnnealingWeight
self.log('train/min_scale_cur', batch['min_scale_cur']) # 记录最小规模
# 记录各个损失
for k, v in loss_d.items():
self.log(f'train/{k}', v)
self.log('train/psnr', psnr_) # 记录 PSNR 指标
# 特定步骤的可视化
if (self.global_step + 1) % 5000 == 0:
# 格式化图像、深度图和蒙版以进行可视化
img = results[f'rgb_{typ}'].detach().view(H, W, 3).permute(2, 0, 1).cpu()
img_gt = rgbs.detach().view(H, W, 3).permute(2, 0, 1).cpu()
depth = visualize_depth(results[f'depth_{typ}'].detach().view(H, W))
# 记录图像和可视化到实验日志器
if self.hparams.use_mask:
mask = results['out_mask'].detach().view(H, W, 1).permute(2, 0, 1).repeat(3, 1, 1).cpu()
if 'rgb_fine_random' in results:
img_random = results[f'rgb_fine_random'].detach().view(H, W, 3).permute(2, 0, 1).cpu()
stack = torch.stack([img_gt, img, depth, img_random, mask])
self.logger.experiment.add_images('train/GT_pred_depth_random_mask', stack, self.global_step)
else:
stack = torch.stack([img_gt, img, depth, mask])
self.logger.experiment.add_images('train/GT_pred_depth_mask', stack, self.global_step)
elif 'rgb_fine_random' in results:
img_random = results[f'rgb_fine_random'].detach().view(H, W, 3).permute(2, 0, 1).cpu()
stack = torch.stack([img_gt, img, depth, img_random])
self.logger.experiment.add_images('train/GT_pred_depth_random', stack, self.global_step)
else:
stack = torch.stack([img_gt, img, depth])
self.logger.experiment.add_images('train/GT_pred_depth', stack, self.global_step)
return loss # 返回计算的损失
以上代码是一个 PyTorch Lightning 中的
training_step
方法,用于执行一个训练步骤。它主要执行以下操作:
数据提取和预处理:
- 从传入的批量数据中提取射线、时间、RGB值和UV样本。
- 根据属性编码和掩膜的需求,提取整个图像和RGB索引。
预测和损失计算:
- 对提取的数据执行前向传递,得到预测结果和损失值。
- 根据损失结果计算总损失值,并在不需要梯度计算时计算 PSNR 指标。
记录指标和损失:
- 记录学习率、总损失、
AnnealingWeight
、最小规模以及各个损失值。- 记录 PSNR 指标作为训练指标。
特定步骤的可视化:
- 当全局步数是 5000 的倍数时,进行特定步骤的可视化。
- 将图像、深度图像和mask以图像格式准备好。
- 如果存在mask,将mask图、深度图、原始图像、预测图像、随机预测图像以图像堆叠的形式记录到实验日志器中。
- 如果没有mask,将深度图、原始图像和预测图像以图像堆叠的形式记录到实验日志器中。
返回损失:返回计算得到的损失值。
这个方法主要负责训练过程中的模型训练、指标记录和可视化。
4 模型验证validation_step()详解:
def validation_step(self, batch, batch_nb):
# 提取输入数据
rays, ts = batch['rays'].squeeze(), batch['ts'].squeeze()
rgbs = batch['rgbs'].squeeze()
# 根据数据集名称设置 uv_sample、W 和 H
if self.hparams.dataset_name == 'phototourism':
uv_sample = batch['uv_sample'].squeeze()
WH = batch['img_wh']
W, H = WH[0, 0].item(), WH[0, 1].item()
else:
W, H = self.hparams.img_wh
uv_sample = None
# 处理需要属性编码、mask 或去遮挡处理的情况
if self.hparams.encode_a or self.hparams.use_mask or self.hparams.deocclusion:
if self.hparams.dataset_name == 'phototourism':
whole_img = batch['whole_img']
else:
# 对于非 phototourism 数据集,构建张量表示原始图像
whole_img = rgbs.view(1, H, W, 3).permute(0, 3, 1, 2) * 2 - 1
rgb_idx = batch['rgb_idx']
else:
whole_img = None
rgb_idx = None
# 根据数据集设置测试渲染器
test_blender = (self.hparams.dataset_name == 'blender')
# 进行前向传播
results = self(rays, ts, whole_img, W, H, rgb_idx, uv_sample, test_blender)
# 计算损失和其他评估指标
loss_d, AnnealingWeight = self.loss(results, rgbs, self.hparams, self.global_step)
loss = sum(l for l in loss_d.values())
log = {'val_loss': loss}
for k, v in loss_d.items():
log[k] = v
# 计算 PSNR 和 SSIM
typ = 'fine' if 'rgb_fine' in results else 'coarse'
img = results[f'rgb_{typ}'].view(H, W, 3).permute(2, 0, 1).cpu()
img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu()
# 在第一个 batch 时计算并记录深度图像和 mask
if batch_nb == 0:
depth = visualize_depth(results[f'depth_{typ}'].view(H, W))
if self.hparams.use_mask:
mask = results['out_mask'].detach().view(H, W, 1).permute(2, 0, 1).repeat(3, 1, 1).cpu()
if 'rgb_fine_random' in results:
img_random = results[f'rgb_fine_random'].detach().view(H, W, 3).permute(2, 0, 1).cpu()
stack = torch.stack([img_gt, img, depth, img_random, mask])
self.logger.experiment.add_images('val/GT_pred_depth_random_mask', stack, self.global_step)
else:
stack = torch.stack([img_gt, img, depth, mask])
self.logger.experiment.add_images('val/GT_pred_depth_mask', stack, self.global_step)
elif 'rgb_fine_random' in results:
img_random = results[f'rgb_fine_random'].detach().view(H, W, 3).permute(2, 0, 1).cpu()
stack = torch.stack([img_gt, img, depth, img_random])
self.logger.experiment.add_images('val/GT_pred_depth_random', stack, self.global_step)
else:
stack = torch.stack([img_gt, img, depth])
self.logger.experiment.add_images('val/GT_pred_depth', stack, self.global_step)
# 计算 PSNR 和 SSIM 并记录到日志
psnr_ = psnr(results[f'rgb_{typ}'], rgbs)
ssim_ = ssim(img[None, ...], img_gt[None, ...])
log['val_psnr'] = psnr_
log['val_ssim'] = ssim_
return log # 返回评估指标
这段代码是 PyTorch Lightning 中用于执行模型验证步骤的方法。
提取输入数据:
- 从输入批次中提取射线
rays
、时间ts
和颜色值rgbs
。对于特定数据集('phototourism'),还提取了uv_sample
和图像宽高信息WH
。- 根据数据集名称和条件,设置了
W
和H
。处理编码、遮罩和去遮挡:
- 根据模型是否需要属性编码、遮罩或者去遮挡,从输入数据中提取相应的参数。对于特定数据集,整个图像
whole_img
和颜色索引rgb_idx
也会被提取。设置测试渲染器:
- 如果数据集是 'blender',则设置
test_blender
为True
。执行前向传播:
- 利用模型执行前向传播,计算输出
results
。计算损失和评估指标:
- 利用计算得到的输出结果
results
计算损失loss
和其他评估指标。将损失值和其他指标记录在log
字典中。图像和深度可视化:
- 计算得到
results
中的图像img
和真实图像img_gt
,以及可能的深度图像depth
。- 在第一个 batch 时,如果使用了 mask,计算
mask
和可能的随机图像img_random
,并将它们与其他图像一起记录到实验日志中。计算 PSNR 和 SSIM:
- 利用计算得到的结果,计算 PSNR 和 SSIM,并将其记录在
log
字典中。返回结果:
- 返回包含评估指标的
log
字典。这个方法主要用于执行验证步骤,评估模型在给定数据集上的性能,并记录相应的指标。
5 validation_epoch_end() 详解
def validation_epoch_end(self, outputs):
# 检查 outputs 的长度以决定是否更新全局变量的当前 epoch
if len(outputs) == 1:
global_val.current_epoch = self.current_epoch # 当 outputs 的长度为 1 时,将 global_val.current_epoch 设置为当前 self.current_epoch
else:
global_val.current_epoch = self.current_epoch + 1 # 否则,将 global_val.current_epoch 设置为当前 self.current_epoch + 1
# 计算 outputs 中验证集上的损失、PSNR 和 SSIM 的平均值
mean_loss = torch.stack([x['val_loss'] for x in outputs]).mean() # 平均验证损失
mean_psnr = torch.stack([x['val_psnr'] for x in outputs]).mean() # 平均 PSNR
mean_ssim = torch.stack([x['val_ssim'] for x in outputs]).mean() # 平均 SSIM
# 记录验证集指标到训练日志中
self.log('val/loss', mean_loss) # 记录平均验证损失
self.log('val/psnr', mean_psnr, prog_bar=True) # 记录平均 PSNR,并显示在进度条中
self.log('val/ssim', mean_ssim, prog_bar=True) # 记录平均 SSIM,并显示在进度条中
# 如果使用遮罩,记录其他相关指标
if self.hparams.use_mask:
self.log('val/c_l', torch.stack([x['c_l'] for x in outputs]).mean()) # 记录 c_l 指标的平均值
self.log('val/f_l', torch.stack([x['f_l'] for x in outputs]).mean()) # 记录 f_l 指标的平均值
self.log('val/r_ms', torch.stack([x['r_ms'] for x in outputs]).mean()) # 记录 r_ms 指标的平均值
self.log('val/r_md', torch.stack([x['r_md'] for x in outputs]).mean()) # 记录 r_md 指标的平均值
这个函数是 PyTorch Lightning 中用于在验证 epoch 结束时执行的方法。它的作用是整合并计算在整个验证集上的损失和指标,以便进行日志记录和报告。
让我们来解释一下这段代码的作用:
全局变量更新:
- 通过检查
outputs
的长度来决定是否在全局变量global_val
中更新当前 epoch。如果outputs
的长度为 1,则将global_val.current_epoch
设置为当前的self.current_epoch
;否则,将global_val.current_epoch
设置为当前的self.current_epoch + 1
。计算平均值:
- 从
outputs
中提取所有 epoch 的验证损失、PSNR 和 SSIM,并计算它们的平均值。- 将这些平均值记录到训练日志中。
记录其他指标:
- 如果使用了遮罩 (
use_mask
),还记录了其他相关指标,如c_l
、f_l
、r_ms
和r_md
的平均值。这个方法的主要作用是汇总整个验证集上的指标,并将这些指标记录在训练日志中,以便在训练过程中进行跟踪和分析。
6 main() 详解
def main(hparams):
# 创建 NeRFSystem 实例
system = NeRFSystem(hparams)
# 设置模型保存的检查点配置
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(hparams.save_dir, f'ckpts/{hparams.exp_name}', '{epoch:d}'),
monitor='val/psnr', # 监控 PSNR 指标
mode='max', # 以最大值作为监控模式
save_top_k=-1 # 保存所有检查点
)
# 设置日志记录器
logger = TestTubeLogger(
save_dir=os.path.join(hparams.save_dir, "logs"), # 日志保存路径
name=hparams.exp_name, # 实验名称
debug=False, # 调试模式
create_git_tag=False, # 是否创建 git tag
log_graph=False # 是否记录图表
)
# 设置训练器 Trainer
trainer = Trainer(
max_epochs=hparams.num_epochs, # 最大 epoch 数
checkpoint_callback=checkpoint_callback, # 检查点配置
resume_from_checkpoint=hparams.ckpt_path, # 从检查点路径恢复
logger=logger, # 日志记录器
weights_summary=None, # 不显示权重摘要
progress_bar_refresh_rate=hparams.refresh_every, # 进度条刷新频率
gpus=hparams.num_gpus, # GPU 数量
accelerator='ddp' if hparams.num_gpus > 1 else None, # 使用分布式数据并行(如果有多个 GPU)
num_sanity_val_steps=-1, # 验证步数
benchmark=True, # 启用性能基准
profiler="simple" if hparams.num_gpus == 1 else None # 使用简单的性能分析器(单 GPU)
)
# 开始模型训练
trainer.fit(system)
主函数主要是用于设置训练过程的配置,并调用
Trainer
来训练NeRFSystem
模型。配置包括模型保存的检查点、日志记录器、训练器的设置等。Trainer
类是 PyTorch Lightning 提供的用于管理训练循环的高级接口。
if __name__ == '__main__':
# 获取命令行参数作为超参数
hparams = get_opts()
# 调用主函数开始训练
main(hparams)