我自己的原文哦~ https://blog.51cto.com/whaosoft/13059544
一、PyTorch DDP
正在郁闷呢 jetson nx 的torchvision安装~~ 自带就剩5g 想弄到ssd 项目中的 venv中又 cuda.h没有... 明明已经装好什么都对
算了说今天主题 啊对 还是搬运啊 学习之工具人而已 勿怪
DistributedDataParallel(DDP)是一个支持多机多卡、分布式训练的深度学习工程方法。其能达到略低于卡数的加速比,是目前最流行的多机多卡训练方法。
本文DDP在实际生产中的应用,如在DDP中引入SyncBN,多机多卡环境下的inference加速等。
基本原理与入门:https://zhuanlan.zhihu.com/p/178402798
实现原理与源代码解析:https://zhuanlan.zhihu.com/p/187610959
在过去的两篇文章里,我们已经对DDP的理论、代码进行了充分、详细的介绍,相信大家都已经了然在胸。但是,实践也是很重要的。正所谓理论联系实践,如果只掌握理论而不进行实践,无疑是纸上谈兵。
在这篇文章里,我们通过几个实战例子,来给大家介绍一下DDP在实际生产中的应用。希望能对大家有所帮助!
- 在DDP中引入SyncBN
- DDP下的Gradient Accumulation的进一步加速
- 多机多卡环境下的inference加速
- 保证DDP性能:确保数据的一致性
- 和DDP有关的小技巧
- 控制不同进程的执行顺序
- 避免DDP带来的冗余输出
请欢快地开始阅读吧!
依赖:pytorch(gpu)>=1.5,python>=3.6
一. 在DDP中引入SyncBN
什么是Batch Normalization(BN)? 这里就不多加以介绍。附上BN文章(https://arxiv.org/abs/1502.03167)。接下来,让我们来深入了解下BN在多级多卡环境上的完整实现:SyncBN。
什么是SyncBN? SyncBN就是Batch Normalization(BN)。其跟一般所说的普通BN的不同在于工程实现方式:SyncBN能够完美支持多卡训练,而普通BN在多卡模式下实际上就是单卡模式。 我们知道,BN中有moving mean和moving variance这两个buffer,这两个buffer的更新依赖于当前训练轮次的batch数据的计算结果。但是在普通多卡DP模式下,各个模型只能拿到自己的那部分计算结果,所以在DP模式下的普通BN被设计为只利用主卡上的计算结果来计算moving mean和moving variance,之后再广播给其他卡。这样,实际上BN的batch size就只是主卡上的batch size那么大。当模型很大、batch size很小时,这样的BN无疑会限制模型的性能。 为了解决这个问题,PyTorch新引入了一个叫SyncBN的结构,利用DDP的分布式计算接口来实现真正的多卡BN。 SyncBN的原理 SyncBN的原理很简单:SyncBN利用分布式通讯接口在各卡间进行通讯,从而能利用所有数据进行BN计算。为了尽可能地减少跨卡传输量,SyncBN做了一个关键的优化,即只传输各自进程的各自的 小batch mean和 小batch variance,而不是所有数据。具体流程请见下面: 前向传播 在各进程上计算各自的 小batch mean和小batch variance 各自的进程对各自的 小batch mean和小batch variance进行all_gather操作,每个进程都得到s的全局量。 注释:只传递mean和variance,而不是整体数据,可以大大减少通讯量,提高速度。 每个进程分别计算总体mean和总体variance,得到一样的结果 注释:在数学上是可行的,有兴趣的同学可以自己推导一下。 接下来,延续正常的BN计算。 注释:因为从前向传播的计算数据中得到的batch mean和batch variance在各卡间保持一致,所以,running_mean和running_variance就能保持一致,不需要显式地同步了! 后向传播:和正常的一样 贴一下关键代码,有兴趣的同学可以研究下:pytorch源码(https://github.com/pytorch/pytorch/blob/release/1.5/torch/nn/modules/_functions.py#L5) SyncBN与DDP的关系 一句话总结,当前PyTorch SyncBN只在DDP单进程单卡模式中支持。SyncBN用到 all_gather这个分布式计算接口,而使用这个接口需要先初始化DDP环境。 复习一下DDP的伪代码中的准备阶段中的DDP初始化阶段
这里有三个点需要注意: 这里的为可能的SyncBN层做准备,实际上就是检测当前是否是DDP单进程单卡模式,如果不是,会直接停止。 这告诉我们,SyncBN需要在DDP环境初始化后初始化,但是要在DDP模型前就准备好。 为什么当前PyTorch SyncBN只支持DDP单进程单卡模式? 从SyncBN原理中我们可以看到,其强依赖了all_gather计算,而这个分布式接口当前是不支持单进程多卡或者DP模式的。当然,不排除未来也是有可能支持的。 怎么用SyncBN? 怎么样才能在我们的代码引入SyncBN呢?很简单:
# DDP init
dist.init_process_group(backend='nccl')
# 按照原来的方式定义模型,这里的BN都使用普通BN就行了。
model = MyModel()
# 引入SyncBN,这句代码,会将普通BN替换成SyncBN。
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
# 构造DDP模型
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
又是熟悉的模样,像DDP一样,一句代码就解决了问题。这是怎么做到的呢?
convert_sync_batchnorm
的原理:
torch.nn.SyncBatchNorm.convert_sync_batchnorm
会搜索model里面的每一个module,如果发现这个module是、或者继承了torch.nn.modules.batchnorm._BatchNorm
类,就把它替换成SyncBN。也就是说,如果你的Normalization层是自己定义的特殊类,没有继承过_BatchNorm
类,那么convert_sync_batchnorm
是不支持的,需要你自己实现一个新的SyncBN!
下面给一下convert_sync_batchnorm
的源码(https://github.com/pytorch/pytorch/blob/v1.5.0/torch/nn/modules/batchnorm.py#L474),可以看到convert的过程中,新的SyncBN复制了原来的BN层的所有参数:
@classmethod
def convert_sync_batchnorm(cls, module, process_group=None):
r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to
:class:`torch.nn.SyncBatchNorm` layers.
"""
module_output = module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
module_output = torch.nn.SyncBatchNorm(module.num_features,
module.eps, module.momentum,
module.affine,
module.track_running_stats,
process_group)
if module.affine:
with torch.no_grad():
module_output.weight = module.weight
module_output.bias = module.bias
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
del module
return module_output
二. DDP下的Gradient Accumulation的进一步加速什么是Gradient Accmulation?
Gradient Accumulation,即梯度累加,相信大家都有所了解,是一种增大训练时batch size的技术,造福了无数硬件条件窘迫的我等穷人。不了解的同学请看这个知乎链接(https://www.zhihu.com/question/303070254/answer/573037166)。
为什么还能进一步加速?
我们仔细思考一下DDP下的gradient accumulation。
# 单卡模式,即普通情况下的梯度累加
for 每次梯度累加循环
optimizer.zero_grad()
for 每个小step
prediction = model(data)
loss_fn(prediction, label).backward() # 积累梯度,不应用梯度改变
optimizer.step() # 应用梯度改变
我们知道,DDP的gradient all_reduce阶段发生在loss_fn(prediction, label).backward()
。这意味着,在梯度累加的情况下,假设一次梯度累加循环有K个step,每次梯度累加循环会进行K次 all_reduce!但事实上,每次梯度累加循环只会有一次 optimizer.step(),即只应用一次参数修改,这意味着在每一次梯度累加循环中,我们其实只要进行一次gradient all_reduce即可满足要求,有K-1次 all_reduce被浪费了!而每次 all_reduce的时间成本是很高的!
如何加速
解决问题的思路在于,对前K-1次step取消其梯度同步。幸运的是,DDP给我们提供了一个暂时取消梯度同步的context函数 no_sync()
(源代码:https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/distributed.py#L548)。在这个context下,DDP不会进行梯度同步。
所以,我们可以这样实现加速:
model = DDP(model)
for 每次梯度累加循环
optimizer.zero_grad()
# 前K-1个step,不进行梯度同步,累积梯度。
for K-1个小step:
with model.no_sync():
prediction = model(data)
loss_fn(prediction, label).backward()
# 第K个step,进行梯度同步
prediction = model(data)
loss_fn(prediction, label).backward()
optimizer.step()
给一个优雅写法(同时兼容单卡、DDP模式哦):
from contextlib import nullcontext
# 如果你的python版本小于3.7,请注释掉上面一行,使用下面这个:
# from contextlib import suppress as nullcontext
if local_rank != -1:
model = DDP(model)
optimizer.zero_grad()
for i, (data, label) in enumerate(dataloader):
# 只在DDP模式下,轮数不是K整数倍的时候使用no_sync
my_context = model.no_sync if local_rank != -1 and i % K != 0 else nullcontext
with my_context():
prediction = model(data)
loss_fn(prediction, label).backward()
if i % K == 0:
optimizer.step()
optimizer.zero_grad()
是不是很漂亮!
三. 多机多卡环境下的inference加速
问题
有一些非常现实的需求,相信大家肯定碰到过:
- 一般,训练中每几个epoch我们会跑一下inference、测试一下模型性能。在DDP多卡训练环境下,能不能利用多卡来加速inference速度呢?
- 我有一堆数据要跑一些网络推理,拿到inference结果。DP下多卡加速比太低,能不能利用DDP多卡来加速呢?
解法
这两个问题实际是同一个问题。答案肯定是可以的,但是,没有现成、省力的方法。
测试和训练的不同在于:
- 测试的时候不需要进行梯度反向传播,inference过程中各进程之间不需要通讯。
- 测试的时候,不同模型的inference结果、性能指标的类型多种多样,没有统一的形式。
- 我们很难定义一个统一的框架,像训练时
model=DDP(model)
那样方便地应用DDP多卡加速。
解决问题的思路很简单,就是各个进程中各自进行单卡的inference,然后把结果收集到一起。单卡inference很简单,我们甚至可以直接用DDP包装前的模型。问题其实只有两个:
- 我们要如何把数据split到各个进程中
- 我们要如何把结果合并到一起
如何把数据split到各个进程中:新的data sampler
大家肯定还记得,在训练的时候,我们用的 torch.utils.data.distributed.DistributedSampler
帮助我们把数据不重复地分到各个进程上去。但是,其分的方法是:每段连续的N个数据,拆成一个一个,分给N个进程,所以每个进程拿到的数据不是连续的。这样,不利于我们在inference结束的时候将结果合并到一起。
所以,这里我们需要实现一个新的data sampler。它的功能,是能够连续地划分数据块,不重复地分到各个进程上去。直接给代码:
# 来源:https://github.com/huggingface/transformers/blob/447808c85f0e6d6b0aeeb07214942bf1e578f9d2/src/transformers/trainer_pt_utils.py
class SequentialDistributedSampler(torch.utils.data.sampler.Sampler):
"""
Distributed Sampler that subsamples indicies sequentially,
making it easier to collate all results at the end.
Even though we only use this sampler for eval and predict (no training),
which means that the model params won't have to be synced (i.e. will not hang
for synchronization even if varied number of forward passes), we still add extra
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
"""
def __init__(self, dataset, batch_size, rank=None, num_replicas=None):
if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.batch_size = batch_size
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas)) * self.batch_size
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += [indices[-1]] * (self.total_size - len(indices))
# subsample
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
return iter(indices)
def __len__(self):
return self.num_samples
如何把结果合并到一起: all_gather
通过torch.distributed
提供的分布式接口all_gather
,我们可以把各个进程的prediction结果集中到一起。
难点就在这里。因为世界上存在着千奇百怪的神经网络模型,有着千奇百怪的输出,所以,把数据集中到一起不是一件容易的事情。但是,如果你的网络输出在不同的进程中有着一样的大小,那么这个问题就好解多了。下面给一个方法,其要求网络的prediction结果在各个进程中的大小是一模一样的:
# 合并结果的函数
# 1. all_gather,将各个进程中的同一份数据合并到一起。
# 和all_reduce不同的是,all_reduce是平均,而这里是合并。
# 2. 要注意的是,函数的最后会裁剪掉后面额外长度的部分,这是之前的SequentialDistributedSampler添加的。
# 3. 这个函数要求,输入tensor在各个进程中的大小是一模一样的。
def distributed_concat(tensor, num_total_examples):
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
return concat[:num_total_examples]
完整的流程
结合上面的介绍,我们可以得到下面这样一个完整的流程。
## 构造测试集
# 假定我们的数据集是这个
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
my_testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
# 使用我们的新sampler
test_sampler = SequentialDistributedSampler(my_testset, batch_size=16)
testloader = torch.utils.data.DataLoader(my_testset, batch_size=16, sampler=test_sampler)
# DDP和模型初始化,略。
# ......
# 正式训练和evaluation
for epoch in range(total_epoch_size):
# 训练代码,略
# .......
# 开始测试
with torch.no_grad():
# 1. 得到本进程的prediction
predictions = []
labels = []
for data, label in testloader:
data, label = data.to(local_rank), label.to(local_rank)
predictions.append(model(data))
labels.append(label)
# 进行gather
predictions = distributed_concat(torch.concat(predictions, dim=0),
len(test_sampler.dataset))
labels = distributed_concat(torch.concat(labels, dim=0),
len(test_sampler.dataset))
# 3. 现在我们已经拿到所有数据的predictioin结果,进行evaluate!
my_evaluate_func(predictions, labels)
更简化的解法
- 如果我们的目的只是得到性能数字,那么,我们甚至可以直接在各个进程中计算各自的性能数字,然后再合并到一起。上面给的解法,是为了更通用的情景。一切根据你的需要来定!
- 我们可以单向地把predictions、labels集中到 rank=0的进程,只在其进行evaluation并输出。PyTorch也提供了相应的接口(链接:https://pytorch.org/docs/stable/distributed.html,send和recv)。
四. 保证DDP性能:确保数据的一致性性能期望
从原理上讲,当没有开启SyncBN时,(或者更严格地讲,没有BN层;但一般有的话影响也不大),以下两种方法训练出来的模型应该是性能相似的:
- 进程数为N的DDP训练
- accumulation为N、其他配置完全相同的单卡训练
如果我们发现性能对不上,那么,往往是DDP中的某些设置出了问题。在DDP系列第二篇中,我们介绍过一个check list,可以根据它检查下自己的配置。其中,在造成性能对不齐的原因中,最有可能的是数据方面出现了问题。
DDP训练时,数据的一致性必须被保证:各个进程拿到的数据,要像是accumulation为N、其他配置完全相同的单卡训练中同个accumulation循环中不同iteration拿到的数据。想象一下,如果各个进程拿到的数据是一样的,或者分布上有任何相似的地方,那么,这就会造成训练数据质量的下降,最终导致模型性能下降。
容易错的点:随机数种子
为保证实验的可复现性,一般我们会在代码在开头声明一个固定的随机数种子,从而使得同一个配置下的实验,无论启动多少次,都会拿到同样的结果。
import random
import numpy as np
import torch
def init_seeds(seed=0, cuda_deterministic=True):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
if cuda_deterministic: # slower, more reproducible
cudnn.deterministic = True
cudnn.benchmark = False
else: # faster, less reproducible
cudnn.deterministic = False
cudnn.benchmark = True
def main():
# 一般都直接用0作为固定的随机数种子。
init_seeds(0)
但是在DDP训练中,如果还是像以前一样,使用0作为随机数种子,不做修改,就会造成以下后果:
- DDP的N个进程都使用同一个随机数种子
- 在生成数据时,如果我们使用了一些随机过程的数据扩充方法,那么,各个进程生成的数据会带有一定的同态性。
- 比如说,YOLOv5会使用mosaic数据增强(从数据集中随机采样3张图像与当前的拼在一起,组成一张里面有4张小图的大图)。这样,因为各卡使用了相同的随机数种子,你会发现,各卡生成的图像中,除了原本的那张小图,其他三张小图都是一模一样的!
- 同态性的数据,降低了训练数据的质量,也就降低了训练效率!最终得到的模型性能,很有可能是比原来更低的。
所以,我们需要给不同的进程分配不同的、固定的随机数种子:
def main():
rank = torch.distributed.get_rank()
# 问题完美解决!
init_seeds(1 + rank)
五. 和DDP有关的小技巧控制不同进程的执行顺序
一般情况下,各个进程是各自执行的,速度有快有慢,只有在gradient all-reduce的时候,快的进程才会等一下慢的进程,也就是进行同步。那么,如果我们需要在其他地方进行同步呢?比如说,在加载数据前,如果数据集不存在,我们要下载数据集:
- 我们只需要在唯一一个进程中开启一次下载
- 我们需要让其他进程等待其下载完成,再去加载数据
怎么解决这个问题呢?torch.distributed
提供了一个barrier()
的接口,利用它我们可以同步各个DDP中的各个进程!当使用barrier函数时,DDP进程会在函数的位置进行等待,知道所有的进程都跑到了 barrier函数的位置,它们才会再次向下执行。
只在某进程执行,无须同步:
这是最简单的,只需要一个简单的判断,用不到barrier()
if rank == 0:
code_only_run_in_rank_0()
简单的同步:
没什么好讲的,只是一个示范
code_before()
# 在这一步同步
torch.distributed.barrier()
code_after()
在某个进程中执行A操作,其他进程等待其执行完成后再执行B操作:
也简单。
if rank == 0:
do_A()
torch.distributed.barrier()
else:
do_B()
torch.distributed.barrier()
在某个进程中优先执行A操作,其他进程等待其执行完成后再执行A操作:
这个值得深入讲一下,因为这个是非常普遍的需求。利用contextlib.contextmanager
,我们可以把这个逻辑给优雅地包装起来!
from contextlib import contextmanager
@contextmanager
def torch_distributed_zero_first(rank: int):
"""Decorator to make all processes in distributed training wait for each local_master to do something.
"""
if rank not in [-1, 0]:
torch.distributed.barrier()
# 这里的用法其实就是协程的一种哦。
yield
if rank == 0:
torch.distributed.barrier()
然后我们就可以这样骚操作:
with torch_distributed_zero_first(rank):
if not check_if_dataset_exist():
download_dataset()
load_dataset()
优雅地解决了需求!
避免DDP带来的冗余输出
问题:
当我们在自己的模型中加入DDP模型时,第一的直观感受肯定是,终端里的输出变成了N倍了。这是因为我们现在有N个进程在同时跑整个程序。这不光是对有洁癖的同学造成困扰,其实对所有人都会造成困扰。因为各个进程的速度并不一样快,在茫茫的输出海洋中,我们难以debug、把控实验状态。
解法:
那么,有什么办法能避免这个现象呢?下面,笔者给一个可行的方法:logging模块+输出信息等级控制。即用logging输出代替所有print输出,并给不同进程设置不同的输出等级,只在0号进程保留低等级输出。举一个例子:
import logging
# 给主要进程(rank=0)设置低输出等级,给其他进程设置高输出等级。
logging.basicConfig(level=logging.INFO if rank in [-1, 0] else logging.WARN)
# 普通log,只会打印一次。
logging.info("This is an ordinary log.")
# 危险的warning、error,无论在哪个进程,都会被打印出来,从而方便debug。
logging.error("This is a fatal log!")
simple but powerful!
二、PyTorch~SyncBatchNorm
对于一些模型占用显存很大,导致可以上的 batch size 很小这类任务来说,分布式训练的时候就需要用 SyncBatchNorm 来使得统计量更加的准确。本文对SyncBatchNorm的前向以及反向实现细节进行阐述。
我们知道在分布式数据并行多卡训练的时候,BatchNorm 的计算过程(统计均值和方差)在进程之间是独立的,也就是每个进程只能看到本地 GlobalBatchSize / NumGpu 大小的数据。对于一般的视觉任务比如分类,分布式训练的时候,单卡的 batch size 也足够大了,所以不需要在计算过程中同步 batchnorm 的统计量,因为同步也会让训练效率下降。但是对于一些模型占用显存很大,导致可以上的 batch size 很小这类任务来说,分布式训练的时候就需要用 SyncBatchNorm 来使得统计量更加的准确。
SyncBatchNorm 前向实现
前向第一步,计算本地均值和方差
假设在4张GPU上做分布式数据并行训练,我们来看下各个进程上 SyncBN 的行为:
如上图所示,SyncBN前向实现的第一步是,每个GPU先单独计算各自本地数据 X_i
对应均值和方差(mean_i
和 var_i
) 。
而计算均值和方差的 CUDA kernel 具体实现是实现采用的 Welford
迭代计算算法
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
我们知道传统方法计算均值,是要先把所有数据加起来然后再除以个数,而方差是在平均值的基础上做进一步的计算。
但是这样的计算方式有个问题是,在数据量非常之大的情况下,把所有数相加的结果是一个非常大的值,容易导致精度溢出。
而Welford
迭代计算算法,则只需要对数据集进行单次遍历,然后根据迭代公式计算均值,可以避免传统算法可能导致的精度溢出的问题,且 Welford
算法可以并行化。
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
假设现在输入张量形状是 (B,C,H,W)
,下面解释输入张量是 NCHW 格式的时候, CUDA kernel 具体开启线程的方式和每个线程具体计算细节。
由于线程的配置是按照固定的公式计算出来的,这里为了解释方便就固定为其中一种情况:
如上图所示,总共起了 C
个 thread block,也就是 grid
大小等于通道数。每个 thread block 负责计算某一个通道的均值和方差。
每个 thread block 的形状是两维,x维度是 512, y 维度是 1,共处理 B * H * W
大小的数据,其中数据组织形式是 x 方向是 H * W
维度,y 方向是 B
维度。
每个thread block 负责处理的数据大小和其中每个线程负责处理的位置,如下图所示:
如上图所示紫色方块表示thread block中的一个thread,紫色箭头指向表示,在kernel执行过程中,该线程所要负责处理的数据。
每个线程在x方向上每处理完一个数据,移动的步长为 blockDim.x=512
,x方向遍历完之后,y方向移动步长为blockDim.y=1
,以此类推。
kernel 执行的第一步就是,所有线程处理完自己所负责的数据,然后同步一下,接着就是合并每个线程计算得到的局部均值和方差。
而我们知道一个 thread block 内的线程,是按全局 id 顺序从0开始每 32 个线程分为一组,也就是一个 warp,然后以warp为单位来执行。
kernel 执行的第二步就是,每个 warp 内的线程合并均值和方差,通过 warp 级的同步元语库函数 __shfl_xor_sync
来实现 warp 内线程结果的合并。
这一步做完之后,warp 内每个线程都包含了合并之后的均值和方差,下面解释如何通过 __shfl_xor_sync
来实现 warp 内线程结果合并的:
上图中的每一行的32个黑点表示一个 warp 内的32个线程,上方的id 表示每个线程在warp内的id。
然后我们看合并 mean 和 var 的循环,这里可视化了每个循环内线程之间的交互。
__shfl_xor_sync
简单来理解,只需要关注第 2 和 3 个参数,第二个参数是线程之间要交换的值,第三个参数传 i。
具体作用就是,当前线程的 id 和 这个 i 做异或 xor
位运算,计算得到的结果也是 id ,也就是当前线程要与哪个线程交换值。
当 i = 1
的时候,
对于线程 id 0 和 1, 0 xor 1 = 1
, 1 xor 1 = 0
,则就是线程 0 和 1 交换各自的均值和方差,然后就都持有了合并之后的均值和方差了。
再看线程 id 2 和 3, 2 xor 1 = 3
,3 oxr 1 = 2
,所以 2 和 3 交换。
同理可得第一轮循环,是线程按顺序2个为一组组内合并。
当 i = 2
的时候,
对于线程 id 0 和 2, 0 xor 2 = 2
, 2 xor 2 = 0
,
对于线程 id 1 和 3,1 xor 2 = 3
, 3 xor 2 = 1
所以交换完合并之后,thread 0 ~ 3 就都持有了这4个线程合并之后的均值和方差了。
同理可得,
i = 2
的时候线程按顺序4个为一组,组内根据异或运算计算交换线程对合并均值和方差。
i = 4
的时候,线程按顺序8个为一组,
i = 8
的时候,线程按顺序16个为一组,
当最后一轮 i = 16
循环完了之后,warp 内每个线程就都持有了该 warp 的所有线程合并的均值和方差了。
kernel 执行的最后一步是,上面每个 warp 内结果合并完,会做一次全局的线程同步。之后再将所有 warp 的结果合并就得到该 thread block 所负责计算的通道均值和方差了。
前向第二步,GPU之间同步均值和方差
通过集合通信操作 AllGather
让每个 GPU 上的进程都拿到所有 GPU 上的均值和方差,最后就是每个GPU内计算得到全局的均值和方差,同时更新 running_mean
和 running_var
前向第三步,计算 SyncBN 的输出
最后这一步就一个常规的batchnorm操作,对输入 x 做 normalize 操作得到输出,cuda kernel 就是一个 eltwise 的操作,因为不需要计算均值和方差了。这里就不展开了,有兴趣的读者可以看文末的参考链接,去阅读torch的源码,也可以学习一下对于 NHWC
格式的 cuda kernel 是如何实现的。
SyncBatchNorm 反向实现细节
BatchNorm 反向计算公式
首先复习一下 BatchNorm 反向,输入格式是 (B,C,H,W)
则某个通道(通道索引 c
)对应的 输入 x 、weight 和 bias 梯度计算公式,这里不做推导只列出公式:
前置公式:
输出梯度为 y_grad
weight 对应通道 c 的梯度:
bias 对应通道 c 的梯度:
输入 x 对应通道 c 上某个位置 b, h, w 的梯度:
反向计算流程
每个GPU都计算出本地对应的 weight_grad
,bias_grad
,sum_dy
和 sum_dy_xmu
,具体CUDA kernel 实现思路和前向第一步类似,这里就不展开了,有兴趣可以去阅读源码。
由于分布式数据并行下,权值的梯度会自动做全局同步,所以 SyncBN 就不需要管权值梯度的跨 GPU 的同步。
而对于sum_dy
和 sum_dy_xmu
,则通过集合通信操作 AllReduce
将所有GPU上的结果累加,使得每个GPU上持有全局累加的结果。
最后每个 GPU 根据上面的计算公式计算本地输入x对应的梯度,但是需要注意的是,由于 sum_dy
和 sum_dy_xmu
是跨 GPU 全局累加的结果,所以上面公式中的 rc=B*H*W
要改为 rc=B*H*W*num_gpu
。该 CUDA kernel 的实现,根据上述公式,也是一个 eltiwse 的操作,细节可以去阅读torch源码。
参考资料
- https://hangzhang.org/PyTorch-Encoding/tutorials/syncbn.html
- https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html
- https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
- https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh
- https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cu
- https://developer.nvidia.com/blog/using-cuda-warp-level-primitives/
- https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#simt-architecture
- https://people.maths.ox.ac.uk/gilesm/cuda/2019/lecture_04.pdf
- https://mpitutorial.com/tutorials/mpi-scatter-gather-and-allgather/
三、PyTorch~NeRF
笔者通过整理分析了NeRF论文和相关参考代码,将为读者朋友讲述利用PyTorch框架,从0到1简单复现一个NeRF(神经辐射场)的实现细节和过程。
在解释代码之前,首先对NeRF(神经辐射场)的原理与含义进行简单回顾。而NeRF论文中是这样解释NeRF算法流程的:
“我们提出了一个当前最优的方法,应用于复杂场景下合成新视图的任务,具体的实现原理是使用一个稀疏的输入视图集合,然后不断优化底层的连续体素场景函数。我们的算法,使用一个全连接(非卷积)的深度网络,表示一个场景,这个深度网络的输入是一个单独的5D坐标(空间位置(x,y,z)和视图方向(xita,sigma)),其对应的输出则是体素密度和视图关联的辐射向量。我们通过查询沿着相机射线的5D坐标合成新的场景视图,以及通过使用经典的体素渲染技术将输出颜色和密度投射到图像中。因为体素渲染具有天然的可变性,所以优化我们的表示方法所需的唯一输入就是一组已知相机位姿的图像。我们介绍如何高效优化神经辐射场照度,以渲染具有复杂几何形状和外观的逼真新颖视图,并展示了由于之前神经渲染和视图合成工作的结果。”
图1|NeRF实现流程
基于前文的原理,本节开始讲述具体的代码实现。首先,导入算法需要的Python库文件。
import os
from typing import Optional,Tuple,List,Union,Callable
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from mpl\_toolkits.mplot3d import axes3d
from tqdm import trange
# 设置GPU还是CPU设备
device = torch.device\('cuda' if torch.cuda.is\_available\(\) else 'cpu'\)
1 输入
根据相关论文中的介绍可知,NeRF的输入是一个包含空间位置坐标与视图方向的5D坐标。然而,在PyTorch构建NeRF过程中使用的数据集只是一般的3D到2D图像数据集,包含拍摄相机的内参:位姿和焦距。因此在后面的操作中,我们会把输入数据集转为算法模型需要的输入形式。
在这一流程中使用乐高推土机图像作为简单NeRF算法的数据集,如图2所示:(具体的数据链接请在文末查看)
图2|乐高推土机数据集
这项工作中使用的小型乐高数据集由 106 幅乐高推土机的图像组成,并配有位姿数据和常用焦距数值。与其他数据集一样,这里保留前 100 张图像用于训练,并保留一张测试图像用于验证,具体的加载数据操作如下:
data = np.load\('tiny\_nerf\_data.npz'\) # 加载数据集
images = data\['images'\] # 图像数据
poses = data\['poses'\] # 位姿数据
focal = data\['focal'\] # 焦距数值
print\(f'Images shape: \{images.shape\}'\)
print\(f'Poses shape: \{poses.shape\}'\)
print\(f'Focal length: \{focal\}'\)
height, width = images.shape\[1:3\]
near, far = 2., 6.
n\_training = 100 # 训练数据数量
testimg\_idx = 101 # 测试数据下标
testimg, testpose = images\[testimg\_idx\], poses\[testimg\_idx\]
plt.imshow\(testimg\)
print\('Pose'\)
print\(testpose\)
2 数据处理
一般而言,为了收集这些特点输入数据,算法中需要对输入图像进行反渲染操作。具体来讲就是通过每个像素点在三维空间中绘制投影线,并从中提取样本。
要从图像以外的三维空间采样输入数据点,首先就得从乐高照片集中获取每台相机的初始位姿,然后通过一些矢量数学运算,将这些4x4姿态矩阵转换成「表示原点的三维坐标和表示方向的三维矢量」——这两类信息最终会结合起来描述一个矢量,该矢量用以表征拍摄照片时相机的指向。
下列代码则正是通过绘制箭头来描述这一操作,箭头表示每一帧图像的原点和方向:
# 方向数据
dirs = np.stack\(\[np.sum\(\[0, 0, -1\] \* pose\[:3, :3\], axis=-1\) for pose in poses\]\)
# 原点数据
origins = poses\[:, :3, -1\]
# 绘图的设置
ax = plt.figure\(figsize=\(12, 8\)\).add\_subplot\(projection='3d'\)
\_ = ax.quiver\(
origins\[..., 0\].flatten\(\),
origins\[..., 1\].flatten\(\),
origins\[..., 2\].flatten\(\),
dirs\[..., 0\].flatten\(\),
dirs\[..., 1\].flatten\(\),
dirs\[..., 2\].flatten\(\), length=0.5, normalize=True\)
ax.set\_xlabel\('X'\)
ax.set\_ylabel\('Y'\)
ax.set\_zlabel\('z'\)
plt.show\(\)
最终绘制出来的箭头结果如下图所示:
图3|采样点相机拍摄指向
当有了这些相机位姿数据之后,我们就可以沿着图像的每个像素找到投影线,而每条投影线都是由其原点(x,y,z)和方向联合定义。其中每个像素的原点可能相同,但方向一般是不同的。这些方向射线都略微偏离中心,因此不会存在两条平行方向线,如下图所示:
图4|相机内参示意图
根据图4所述的原理,我们就可以确定每条射线的方向和原点,相关代码如下:
def get\_rays\(
height: int, # 图像高度
width: int, # 图像宽带
focal\_length: float, # 焦距
c2w: torch.Tensor
\) -> Tuple\[torch.Tensor, torch.Tensor\]:
"""
通过每个像素和相机原点,找到射线的原点和方向。
"""
# 应用针孔相机模型收集每个像素的方向
i, j = torch.meshgrid\(
torch.arange\(width, dtype=torch.float32\).to\(c2w\),
torch.arange\(height, dtype=torch.float32\).to\(c2w\),
indexing='ij'\)
i, j = i.transpose\(-1, -2\), j.transpose\(-1, -2\)
# 方向数据
directions = torch.stack\(\[\(i - width \* .5\) / focal\_length,
-\(j - height \* .5\) / focal\_length,
-torch.ones\_like\(i\)
\], dim=-1\)
# 用相机位姿求出方向
rays\_d = torch.sum\(directions\[..., None, :\] \* c2w\[:3, :3\], dim=-1\)
# 默认所有射线原点相同
rays\_o = c2w\[:3, -1\].expand\(rays\_d.shape\)
return rays\_o, rays\_d
得到每个像素对应的射线的方向数据和原点数据之后,就能够获得了NeRF算法中需要的五维数据输入,下面将这些数据调整为算法输入的格式:
# 转为PyTorch的tensor
images = torch.from\_numpy\(data\['images'\]\[:n\_training\]\).to\(device\)
poses = torch.from\_numpy\(data\['poses'\]\).to\(device\)
focal = torch.from\_numpy\(data\['focal'\]\).to\(device\)
testimg = torch.from\_numpy\(data\['images'\]\[testimg\_idx\]\).to\(device\)
testpose = torch.from\_numpy\(data\['poses'\]\[testimg\_idx\]\).to\(device\)
# 针对每个图像获取射线
height, width = images.shape\[1:3\]
with torch.no\_grad\(\):
ray\_origin, ray\_direction = get\_rays\(height, width, focal, testpose\)
print\('Ray Origin'\)
print\(ray\_origin.shape\)
print\(ray\_origin\[height // 2, width // 2, :\]\)
print\(''\)
print\('Ray Direction'\)
print\(ray\_direction.shape\)
print\(ray\_direction\[height // 2, width // 2, :\]\)
print\(''\)
2.1 分层采样
当算法输入模块有了NeRF算法需要的输入数据,也就是包含原点和方向向量组合的线条时,就可以在线条上进行采样。这一过程是采用从粗到细的采样策略,即分层采样策略。
具体来说,分层采样就是将光线分成均匀分布的小块,接着在每个小块内随机抽样。其中扰动的设置决定了是均匀取样的,还是直接简单使用分区中心作为采样点。具体操作代码如下所示:
# 采样函数定义
def sample\_stratified\(
rays\_o: torch.Tensor, # 射线原点
rays\_d: torch.Tensor, # 射线方向
near: float,
far: float,
n\_samples: int, # 采样数量
perturb: Optional\[bool\] = True, # 扰动设置
inverse\_depth: bool = False # 反向深度
\) -> Tuple\[torch.Tensor, torch.Tensor\]:
"""
从规则的bin中沿着射线进行采样。
"""
# 沿着射线抓取采样点
t\_vals = torch.linspace\(0., 1., n\_samples, device=rays\_o.device\)
if not inverse\_depth:
# 由远到近线性采样
z\_vals = near \* \(1.-t\_vals\) + far \* \(t\_vals\)
else:
# 在反向深度中线性采样
z\_vals = 1./\(1./near \* \(1.-t\_vals\) + 1./far \* \(t\_vals\)\)
# 沿着射线从bins中统一采样
if perturb:
mids = .5 \* \(z\_vals\[1:\] + z\_vals\[:-1\]\)
upper = torch.concat\(\[mids, z\_vals\[-1:\]\], dim=-1\)
lower = torch.concat\(\[z\_vals\[:1\], mids\], dim=-1\)
t\_rand = torch.rand\(\[n\_samples\], device=z\_vals.device\)
z\_vals = lower + \(upper - lower\) \* t\_rand
z\_vals = z\_vals.expand\(list\(rays\_o.shape\[:-1\]\) + \[n\_samples\]\)
# 应用相应的缩放参数
pts = rays\_o\[..., None, :\] + rays\_d\[..., None, :\] \* z\_vals\[..., :, None\]
return pts, z\_vals
接着就到了对这些采样点做可视化分析的步骤。如图5中所述,未受扰动的蓝 色点是bin的“中心“,而红点对应扰动点的采样。请注意,红点与上方的蓝点略有偏移,但所有点都在远近采样设定值之间。具体代码如下:
y\_vals = torch.zeros\_like\(z\_vals\)
# 调用采样策略函数
\_, z\_vals\_unperturbed = sample\_stratified\(rays\_o, rays\_d, near, far, n\_samples,
perturb=False, inverse\_depth=inverse\_depth\)
# 绘图相关
plt.plot\(z\_vals\_unperturbed\[0\].cpu\(\).numpy\(\), 1 + y\_vals\[0\].cpu\(\).numpy\(\), 'b-o'\)
plt.plot\(z\_vals\[0\].cpu\(\).numpy\(\), y\_vals\[0\].cpu\(\).numpy\(\), 'r-o'\)
plt.ylim\(\[-1, 2\]\)
plt.title\('Stratified Sampling \(blue\) with Perturbation \(red\)'\)
ax = plt.gca\(\)
ax.axes.yaxis.set\_visible\(False\)
plt.grid\(True\)
图5|采样结果示意图
3 位置编码
与Transformer一样,NeRF也使用了位置编码器。因此NeRF就需要借助位置编码器将输入映射到更高的频率空间,以弥补神经网络在学习低频函数时的偏差。
这一环节将会为位置编码器建立一个简单的 torch.nn.Module 模块,相同的编码器可同时用于对输入样本和视图方向的编码操作。注意,这些输入被指定了不同的参数。代码如下所示:
# 位置编码类
class PositionalEncoder\(nn.Module\):
"""
对输入点,做sine或者consine位置编码。
"""
def \_\_init\_\_\(
self,
d\_input: int,
n\_freqs: int,
log\_space: bool = False
\):
super\(\).\_\_init\_\_\(\)
self.d\_input = d\_input
self.n\_freqs = n\_freqs
self.log\_space = log\_space
self.d\_output = d\_input \* \(1 + 2 \* self.n\_freqs\)
self.embed\_fns = \[lambda x: x\]
# 定义线性或者log尺度的频率
if self.log\_space:
freq\_bands = 2.\*\*torch.linspace\(0., self.n\_freqs - 1, self.n\_freqs\)
else:
freq\_bands = torch.linspace\(2.\*\*0., 2.\*\*\(self.n\_freqs - 1\), self.n\_freqs\)
# 替换sin和cos
for freq in freq\_bands:
self.embed\_fns.append\(lambda x, freq=freq: torch.sin\(x \* freq\)\)
self.embed\_fns.append\(lambda x, freq=freq: torch.cos\(x \* freq\)\)
def forward\(
self,
x
\) -> torch.Tensor:
"""
实际使用位置编码的函数。
"""
return torch.concat\(\[fn\(x\) for fn in self.embed\_fns\], dim=-1\)
4 NeRF模型
在此,定义一个NeRF 模型——主要由线性层模块列表构成,而列表中进一步包含非线性激活函数和残差连接。该模型有一个可选的视图方向输入,如果在实例化时提供具体的方向信息,那么会改变模型结构。
(本实现基于原始论文NeRF:Representing Scenes as Neural Radiance Fields for View Synthesis 的第3节,并使用相同的默认设置)
具体代码如下所示:
# 定义NeRF模型
class NeRF\(nn.Module\):
"""
神经辐射场模块。
"""
def \_\_init\_\_\(
self,
d\_input: int = 3,
n\_layers: int = 8,
d\_filter: int = 256,
skip: Tuple\[int\] = \(4,\),
d\_viewdirs: Optional\[int\] = None
\):
super\(\).\_\_init\_\_\(\)
self.d\_input = d\_input # 输入
self.skip = skip # 残差连接
self.act = nn.functional.relu # 激活函数
self.d\_viewdirs = d\_viewdirs # 视图方向
# 创建模型的层结构
self.layers = nn.ModuleList\(
\[nn.Linear\(self.d\_input, d\_filter\)\] +
\[nn.Linear\(d\_filter + self.d\_input, d\_filter\) if i in skip \\
else nn.Linear\(d\_filter, d\_filter\) for i in range\(n\_layers - 1\)\]
\)
# Bottleneck 层
if self.d\_viewdirs is not None:
# 如果使用视图方向,分离alpha和RGB
self.alpha\_out = nn.Linear\(d\_filter, 1\)
self.rgb\_filters = nn.Linear\(d\_filter, d\_filter\)
self.branch = nn.Linear\(d\_filter + self.d\_viewdirs, d\_filter // 2\)
self.output = nn.Linear\(d\_filter // 2, 3\)
else:
# 如果不使用试图方向,则简单输出
self.output = nn.Linear\(d\_filter, 4\)
def forward\(
self,
x: torch.Tensor,
viewdirs: Optional\[torch.Tensor\] = None
\) -> torch.Tensor:
r"""
带有视图方向的前向传播
"""
# 判断是否设置视图方向
if self.d\_viewdirs is None and viewdirs is not None:
raise ValueError\('Cannot input x\_direction if d\_viewdirs was not given.'\)
# 运行bottleneck层之前的网络层
x\_input = x
for i, layer in enumerate\(self.layers\):
x = self.act\(layer\(x\)\)
if i in self.skip:
x = torch.cat\(\[x, x\_input\], dim=-1\)
# 运行 bottleneck
if self.d\_viewdirs is not None:
# Split alpha from network output
alpha = self.alpha\_out\(x\)
# 结果传入到rgb过滤器
x = self.rgb\_filters\(x\)
x = torch.concat\(\[x, viewdirs\], dim=-1\)
x = self.act\(self.branch\(x\)\)
x = self.output\(x\)
# 拼接alpha一起作为输出
x = torch.concat\(\[x, alpha\], dim=-1\)
else:
# 不拼接,简单输出
x = self.output\(x\)
return x
5 体积渲染
上面得到NeRF模型的输出结果之后,仍需将NeRF的输出转换成图像。也就是通过渲染模块对每个像素沿光线方向的所有样本进行加权求和,从而得到该像素的估计颜色值,此外每个RGB样本都会根据其Alpha值进行加权。其中Alpha值越高,表明采样区域不透明的可能性越大,因此沿射线方向越远的点越有可能被遮挡,累加乘积可确保更远处的点受到抑制。具体代码如下:
# 体积渲染
def cumprod\_exclusive\(
tensor: torch.Tensor
\) -> torch.Tensor:
"""
\(Courtesy of https://github.com/krrish94/nerf-pytorch\)
和tf.math.cumprod\(..., exclusive=True\)功能类似
参数:
tensor \(torch.Tensor\): Tensor whose cumprod \(cumulative product, see \`torch.cumprod\`\) along dim=-1
is to be computed.
返回值:
cumprod \(torch.Tensor\): cumprod of Tensor along dim=-1, mimiciking the functionality of
tf.math.cumprod\(..., exclusive=True\) \(see \`tf.math.cumprod\` for details\).
"""
# 首先计算规则的cunprod
cumprod = torch.cumprod\(tensor, -1\)
cumprod = torch.roll\(cumprod, 1, -1\)
# 用1替换首个元素
cumprod\[..., 0\] = 1.
return cumprod
# 输出到图像的函数
def raw2outputs\(
raw: torch.Tensor,
z\_vals: torch.Tensor,
rays\_d: torch.Tensor,
raw\_noise\_std: float = 0.0,
white\_bkgd: bool = False
\) -> Tuple\[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor\]:
"""
将NeRF的输出转换为RGB输出。
"""
# 沿着\`z\_vals\`轴元素之间的差值.
dists = z\_vals\[..., 1:\] - z\_vals\[..., :-1\]
dists = torch.cat\(\[dists, 1e10 \* torch.ones\_like\(dists\[..., :1\]\)\], dim=-1\)
# 将每个距离乘以相应方向射线的法线,转换为现实世界中的距离(考虑非单位方向)。
dists = dists \* torch.norm\(rays\_d\[..., None, :\], dim=-1\)
# 为模型预测密度添加噪音。可用于在训练过程中对网络进行正则化(防止出现浮点伪影)。
noise = 0.
if raw\_noise\_std > 0.:
noise = torch.randn\(raw\[..., 3\].shape\) \* raw\_noise\_std
# Predict density of each sample along each ray. Higher values imply
# higher likelihood of being absorbed at this point. \[n\_rays, n\_samples\]
alpha = 1.0 - torch.exp\(-nn.functional.relu\(raw\[..., 3\] + noise\) \* dists\)
# 预测每条射线上每个样本的密度。数值越大,表示该点被吸收的可能性越大。\[n\_ 射线,n\_样本]
weights = alpha \* cumprod\_exclusive\(1. - alpha + 1e-10\)
# 计算RGB图的权重。
rgb = torch.sigmoid\(raw\[..., :3\]\) # \[n\_rays, n\_samples, 3\]
rgb\_map = torch.sum\(weights\[..., None\] \* rgb, dim=-2\) # \[n\_rays, 3\]
# 估计预测距离的深度图。
depth\_map = torch.sum\(weights \* z\_vals, dim=-1\)
# 稀疏图
disp\_map = 1. / torch.max\(1e-10 \* torch.ones\_like\(depth\_map\),
depth\_map / torch.sum\(weights, -1\)\)
# 沿着每条射线加权。
acc\_map = torch.sum\(weights, dim=-1\)
# 要合成到白色背景上,请使用累积的 alpha 贴图。
if white\_bkgd:
rgb\_map = rgb\_map + \(1. - acc\_map\[..., None\]\)
return rgb\_map, depth\_map, acc\_map, weights
6 分层体积采样
事实上,三维空间中的遮挡物非常稀疏,因此大多数点对渲染图像的贡献不大。所以,对积分有贡献的区域进行超采样会有更好的效果。这里,笔者对第一组样本应用基于归一化的权重来创建整个光线的概率密度函数,然后对该密度函数应用反变换采样来收集第二组样本。具体代码如下:
# 采样概率密度函数
def sample\_pdf\(
bins: torch.Tensor,
weights: torch.Tensor,
n\_samples: int,
perturb: bool = False
\) -> torch.Tensor:
"""
应用反向转换采样到一组加权点。
"""
# 正则化权重得到概率密度函数。
pdf = \(weights + 1e-5\) / torch.sum\(weights + 1e-5, -1, keepdims=True\) # \[n\_rays, weights.shape\[-1\]\]
# 将概率密度函数转为累计分布函数。
cdf = torch.cumsum\(pdf, dim=-1\) # \[n\_rays, weights.shape\[-1\]\]
cdf = torch.concat\(\[torch.zeros\_like\(cdf\[..., :1\]\), cdf\], dim=-1\) # \[n\_rays, weights.shape\[-1\] + 1\]
# 从累计分布函数中提取样本位置。perturb == 0 时为线性。
if not perturb:
u = torch.linspace\(0., 1., n\_samples, device=cdf.device\)
u = u.expand\(list\(cdf.shape\[:-1\]\) + \[n\_samples\]\) # \[n\_rays, n\_samples\]
else:
u = torch.rand\(list\(cdf.shape\[:-1\]\) + \[n\_samples\], device=cdf.device\) # \[n\_rays, n\_samples\]
# 沿累计分布函数找出 u 值所在的索引。
u = u.contiguous\(\) # 返回具有相同值的连续张量。
inds = torch.searchsorted\(cdf, u, right=True\) # \[n\_rays, n\_samples\]
# 夹住超出范围的索引。
below = torch.clamp\(inds - 1, min=0\)
above = torch.clamp\(inds, max=cdf.shape\[-1\] - 1\)
inds\_g = torch.stack\(\[below, above\], dim=-1\) # \[n\_rays, n\_samples, 2\]
# 从累计分布函数和相应的 bin 中心取样。
matched\_shape = list\(inds\_g.shape\[:-1\]\) + \[cdf.shape\[-1\]\]
cdf\_g = torch.gather\(cdf.unsqueeze\(-2\).expand\(matched\_shape\), dim=-1,
index=inds\_g\)
bins\_g = torch.gather\(bins.unsqueeze\(-2\).expand\(matched\_shape\), dim=-1,
index=inds\_g\)
# 将样本转换为射线长度。
denom = \(cdf\_g\[..., 1\] - cdf\_g\[..., 0\]\)
denom = torch.where\(denom \< 1e-5, torch.ones\_like\(denom\), denom\)
t = \(u - cdf\_g\[..., 0\]\) / denom
samples = bins\_g\[..., 0\] + t \* \(bins\_g\[..., 1\] - bins\_g\[..., 0\]\)
return samples # \[n\_rays, n\_samples\]
7 整体的前向传播流程
此时应将上面所有内容整合在一起,通过模型计算一次前向传递。
由于潜在的内存问题,前向传递以“块“为单位进行计算,然后汇总到一个批次中。梯度传播是在整个批次处理完毕后进行的,因此有“块“和“批次“之分。对于内存紧张环境来说,分块处理尤为重要,因为该环境下提供的资源比原始论文中引用的资源更为有限。具体代码如下所示:
def get\_chunks\(
inputs: torch.Tensor,
chunksize: int = 2\*\*15
\) -> List\[torch.Tensor\]:
"""
输入分块。
"""
return \[inputs\[i:i + chunksize\] for i in range\(0, inputs.shape\[0\], chunksize\)\]
def prepare\_chunks\(
points: torch.Tensor,
encoding\_function: Callable\[\[torch.Tensor\], torch.Tensor\],
chunksize: int = 2\*\*15
\) -> List\[torch.Tensor\]:
"""
对点进行编码和分块,为 NeRF 模型做好准备。
"""
points = points.reshape\(\(-1, 3\)\)
points = encoding\_function\(points\)
points = get\_chunks\(points, chunksize=chunksize\)
return points
def prepare\_viewdirs\_chunks\(
points: torch.Tensor,
rays\_d: torch.Tensor,
encoding\_function: Callable\[\[torch.Tensor\], torch.Tensor\],
chunksize: int = 2\*\*15
\) -> List\[torch.Tensor\]:
r"""
对视图方向进行编码和分块,为 NeRF 模型做好准备。
"""
viewdirs = rays\_d / torch.norm\(rays\_d, dim=-1, keepdim=True\)
viewdirs = viewdirs\[:, None, ...\].expand\(points.shape\).reshape\(\(-1, 3\)\)
viewdirs = encoding\_function\(viewdirs\)
viewdirs = get\_chunks\(viewdirs, chunksize=chunksize\)
return viewdirs
def nerf\_forward\(
rays\_o: torch.Tensor,
rays\_d: torch.Tensor,
near: float,
far: float,
encoding\_fn: Callable\[\[torch.Tensor\], torch.Tensor\],
coarse\_model: nn.Module,
kwargs\_sample\_stratified: dict = None,
n\_samples\_hierarchical: int = 0,
kwargs\_sample\_hierarchical: dict = None,
fine\_model = None,
viewdirs\_encoding\_fn: Optional\[Callable\[\[torch.Tensor\], torch.Tensor\]\] = None,
chunksize: int = 2\*\*15
\) -> Tuple\[torch.Tensor, torch.Tensor, torch.Tensor, dict\]:
"""
计算一次前向传播
"""
# 设置参数
if kwargs\_sample\_stratified is None:
kwargs\_sample\_stratified = \{\}
if kwargs\_sample\_hierarchical is None:
kwargs\_sample\_hierarchical = \{\}
# 沿着每条射线的样本查询点。
query\_points, z\_vals = sample\_stratified\(
rays\_o, rays\_d, near, far, \*\*kwargs\_sample\_stratified\)
# 准备批次。
batches = prepare\_chunks\(query\_points, encoding\_fn, chunksize=chunksize\)
if viewdirs\_encoding\_fn is not None:
batches\_viewdirs = prepare\_viewdirs\_chunks\(query\_points, rays\_d,
viewdirs\_encoding\_fn,
chunksize=chunksize\)
else:
batches\_viewdirs = \[None\] \* len\(batches\)
# 稀疏模型流程。
predictions = \[\]
for batch, batch\_viewdirs in zip\(batches, batches\_viewdirs\):
predictions.append\(coarse\_model\(batch, viewdirs=batch\_viewdirs\)\)
raw = torch.cat\(predictions, dim=0\)
raw = raw.reshape\(list\(query\_points.shape\[:2\]\) + \[raw.shape\[-1\]\]\)
# 执行可微分体积渲染,重新合成 RGB 图像。
rgb\_map, depth\_map, acc\_map, weights = raw2outputs\(raw, z\_vals, rays\_d\)
outputs = \{
'z\_vals\_stratified': z\_vals
\}
if n\_samples\_hierarchical > 0:
# Save previous outputs to return.
rgb\_map\_0, depth\_map\_0, acc\_map\_0 = rgb\_map, depth\_map, acc\_map
# 对精细查询点进行分层抽样。
query\_points, z\_vals\_combined, z\_hierarch = sample\_hierarchical\(
rays\_o, rays\_d, z\_vals, weights, n\_samples\_hierarchical,
\*\*kwargs\_sample\_hierarchical\)
# 像以前一样准备输入。
batches = prepare\_chunks\(query\_points, encoding\_fn, chunksize=chunksize\)
if viewdirs\_encoding\_fn is not None:
batches\_viewdirs = prepare\_viewdirs\_chunks\(query\_points, rays\_d,
viewdirs\_encoding\_fn,
chunksize=chunksize\)
else:
batches\_viewdirs = \[None\] \* len\(batches\)
# 通过精细模型向前传递新样本。
fine\_model = fine\_model if fine\_model is not None else coarse\_model
predictions = \[\]
for batch, batch\_viewdirs in zip\(batches, batches\_viewdirs\):
predictions.append\(fine\_model\(batch, viewdirs=batch\_viewdirs\)\)
raw = torch.cat\(predictions, dim=0\)
raw = raw.reshape\(list\(query\_points.shape\[:2\]\) + \[raw.shape\[-1\]\]\)
# 执行可微分体积渲染,重新合成 RGB 图像。
rgb\_map, depth\_map, acc\_map, weights = raw2outputs\(raw, z\_vals\_combined, rays\_d\)
# 存储输出
outputs\['z\_vals\_hierarchical'\] = z\_hierarch
outputs\['rgb\_map\_0'\] = rgb\_map\_0
outputs\['depth\_map\_0'\] = depth\_map\_0
outputs\['acc\_map\_0'\] = acc\_map\_0
# 存储输出
outputs\['rgb\_map'\] = rgb\_map
outputs\['depth\_map'\] = depth\_map
outputs\['acc\_map'\] = acc\_map
outputs\['weights'\] = weights
return outputs
到这一步骤,就几乎拥有了训练模型所需的一切模块。现在为一个简单的训练过程做一些设置,创建超参数和辅助函数,然后来训练模型。
7.1 超参数
所有用于训练的超参数都在此设置,默认值取自原始论文中数据,除非计算上有限制。在计算受限情况下,本次讨论采用的都是合理的默认值。
# 编码器
d\_input = 3 # 输入维度
n\_freqs = 10 # 输入到编码函数中的样本点数量
log\_space = True # 如果设置,频率按对数空间缩放
use\_viewdirs = True # 如果设置,则使用视图方向作为输入
n\_freqs\_views = 4 # 视图编码功能的数量
# 采样策略
n\_samples = 64 # 每条射线的空间样本数
perturb = True # 如果设置,则对采样位置应用噪声
inverse\_depth = False # 如果设置,则按反深度线性采样点
# 模型
d\_filter = 128 # 线性层滤波器的尺寸
n\_layers = 2 # bottleneck层数量
skip = \[\] # 应用输入残差的层级
use\_fine\_model = True # 如果设置,则创建一个精细模型
d\_filter\_fine = 128 # 精细网络线性层滤波器的尺寸
n\_layers\_fine = 6 # 精细网络瓶颈层数
# 分层采样
n\_samples\_hierarchical = 64 # 每条射线的样本数
perturb\_hierarchical = False # 如果设置,则对采样位置应用噪声
# 优化器
lr = 5e-4 # 学习率
# 训练
n\_iters = 10000
batch\_size = 2\*\*14 # 每个梯度步长的射线数量(2 的幂次)
one\_image\_per\_step = True # 每个梯度步骤一个图像(禁用批处理)
chunksize = 2\*\*14 # 根据需要进行修改,以适应 GPU 内存
center\_crop = True # 裁剪图像的中心部分(每幅图像裁剪一次)
center\_crop\_iters = 50 # 经过这么多epoch后,停止裁剪中心
display\_rate = 25 # 每 X 个epoch显示一次测试输出
# 早停
warmup\_iters = 100 # 热身阶段的迭代次数
warmup\_min\_fitness = 10.0 # 在热身\_iters 处继续训练的最小 PSNR 值
n\_restarts = 10 # 训练停滞时重新开始的次数
# 捆绑了各种函数的参数,以便一次性传递。
kwargs\_sample\_stratified = \{
'n\_samples': n\_samples,
'perturb': perturb,
'inverse\_depth': inverse\_depth
\}
kwargs\_sample\_hierarchical = \{
'perturb': perturb
\}
7.2 训练类和函数
这一环节会创建一些用于训练的辅助函数。NeRF很容易出现局部最小值,在这种情况下,训练很快就会停滞并产生空白输出。必要时,会利用EarlyStopping重新启动训练。
# 绘制采样函数
def plot\_samples\(
z\_vals: torch.Tensor,
z\_hierarch: Optional\[torch.Tensor\] = None,
ax: Optional\[np.ndarray\] = None\):
r"""
绘制分层样本和(可选)分级样本。
"""
y\_vals = 1 + np.zeros\_like\(z\_vals\)
if ax is None:
ax = plt.subplot\(\)
ax.plot\(z\_vals, y\_vals, 'b-o'\)
if z\_hierarch is not None:
y\_hierarch = np.zeros\_like\(z\_hierarch\)
ax.plot\(z\_hierarch, y\_hierarch, 'r-o'\)
ax.set\_ylim\(\[-1, 2\]\)
ax.set\_title\('Stratified Samples \(blue\) and Hierarchical Samples \(red\)'\)
ax.axes.yaxis.set\_visible\(False\)
ax.grid\(True\)
return ax
def crop\_center\(
img: torch.Tensor,
frac: float = 0.5
\) -> torch.Tensor:
r"""
从图像中裁剪中心方形。
"""
h\_offset = round\(img.shape\[0\] \* \(frac / 2\)\)
w\_offset = round\(img.shape\[1\] \* \(frac / 2\)\)
return img\[h\_offset:-h\_offset, w\_offset:-w\_offset\]
class EarlyStopping:
r"""
基于适配标准的早期停止辅助器
"""
def \_\_init\_\_\(
self,
patience: int = 30,
margin: float = 1e-4
\):
self.best\_fitness = 0.0
self.best\_iter = 0
self.margin = margin
self.patience = patience or float\('inf'\) # 在epoch停止提高后等待的停止时间
def \_\_call\_\_\(
self,
iter: int,
fitness: float
\):
r"""
检查是否符合停止标准。
"""
if \(fitness - self.best\_fitness\) > self.margin:
self.best\_iter = iter
self.best\_fitness = fitness
delta = iter - self.best\_iter
stop = delta >= self.patience # 超过耐性则停止训练
return stop
def init\_models\(\):
r"""
为 NeRF 训练初始化模型、编码器和优化器。
"""
# 编码器
encoder = PositionalEncoder\(d\_input, n\_freqs, log\_space=log\_space\)
encode = lambda x: encoder\(x\)
# 视图方向编码
if use\_viewdirs:
encoder\_viewdirs = PositionalEncoder\(d\_input, n\_freqs\_views,
log\_space=log\_space\)
encode\_viewdirs = lambda x: encoder\_viewdirs\(x\)
d\_viewdirs = encoder\_viewdirs.d\_output
else:
encode\_viewdirs = None
d\_viewdirs = None
# 模型
model = NeRF\(encoder.d\_output, n\_layers=n\_layers, d\_filter=d\_filter, skip=skip,
d\_viewdirs=d\_viewdirs\)
model.to\(device\)
model\_params = list\(model.parameters\(\)\)
if use\_fine\_model:
fine\_model = NeRF\(encoder.d\_output, n\_layers=n\_layers, d\_filter=d\_filter, skip=skip,
d\_viewdirs=d\_viewdirs\)
fine\_model.to\(device\)
model\_params = model\_params + list\(fine\_model.parameters\(\)\)
else:
fine\_model = None
# 优化器
optimizer = torch.optim.Adam\(model\_params, lr=lr\)
# 早停
warmup\_stopper = EarlyStopping\(patience=50\)
return model, fine\_model, encode, encode\_viewdirs, optimizer, warmup\_stopper
7.3 训练循环
下面就是具体的训练循环过程函数:
def train\(\):
r"""
启动 NeRF 训练。
"""
# 对所有图像进行射线洗牌。
if not one\_image\_per\_step:
height, width = images.shape\[1:3\]
all\_rays = torch.stack\(\[torch.stack\(get\_rays\(height, width, focal, p\), 0\)
for p in poses\[:n\_training\]\], 0\)
rays\_rgb = torch.cat\(\[all\_rays, images\[:, None\]\], 1\)
rays\_rgb = torch.permute\(rays\_rgb, \[0, 2, 3, 1, 4\]\)
rays\_rgb = rays\_rgb.reshape\(\[-1, 3, 3\]\)
rays\_rgb = rays\_rgb.type\(torch.float32\)
rays\_rgb = rays\_rgb\[torch.randperm\(rays\_rgb.shape\[0\]\)\]
i\_batch = 0
train\_psnrs = \[\]
val\_psnrs = \[\]
iternums = \[\]
for i in trange\(n\_iters\):
model.train\(\)
if one\_image\_per\_step:
# 随机选择一张图片作为目标。
target\_img\_idx = np.random.randint\(images.shape\[0\]\)
target\_img = images\[target\_img\_idx\].to\(device\)
if center\_crop and i \< center\_crop\_iters:
target\_img = crop\_center\(target\_img\)
height, width = target\_img.shape\[:2\]
target\_pose = poses\[target\_img\_idx\].to\(device\)
rays\_o, rays\_d = get\_rays\(height, width, focal, target\_pose\)
rays\_o = rays\_o.reshape\(\[-1, 3\]\)
rays\_d = rays\_d.reshape\(\[-1, 3\]\)
else:
# 在所有图像上随机显示。
batch = rays\_rgb\[i\_batch:i\_batch + batch\_size\]
batch = torch.transpose\(batch, 0, 1\)
rays\_o, rays\_d, target\_img = batch
height, width = target\_img.shape\[:2\]
i\_batch += batch\_size
# 一个epoch后洗牌
if i\_batch >= rays\_rgb.shape\[0\]:
rays\_rgb = rays\_rgb\[torch.randperm\(rays\_rgb.shape\[0\]\)\]
i\_batch = 0
target\_img = target\_img.reshape\(\[-1, 3\]\)
# 运行 TinyNeRF 的一次迭代,得到渲染后的 RGB 图像。
outputs = nerf\_forward\(rays\_o, rays\_d,
near, far, encode, model,
kwargs\_sample\_stratified=kwargs\_sample\_stratified,
n\_samples\_hierarchical=n\_samples\_hierarchical,
kwargs\_sample\_hierarchical=kwargs\_sample\_hierarchical,
fine\_model=fine\_model,
viewdirs\_encoding\_fn=encode\_viewdirs,
chunksize=chunksize\)
# 检查任何数字问题。
for k, v in outputs.items\(\):
if torch.isnan\(v\).any\(\):
print\(f"\! \[Numerical Alert\] \{k\} contains NaN."\)
if torch.isinf\(v\).any\(\):
print\(f"\! \[Numerical Alert\] \{k\} contains Inf."\)
# 反向传播
rgb\_predicted = outputs\['rgb\_map'\]
loss = torch.nn.functional.mse\_loss\(rgb\_predicted, target\_img\)
loss.backward\(\)
optimizer.step\(\)
optimizer.zero\_grad\(\)
psnr = -10. \* torch.log10\(loss\)
train\_psnrs.append\(psnr.item\(\)\)
# 以给定的显示速率评估测试值。
if i \% display\_rate == 0:
model.eval\(\)
height, width = testimg.shape\[:2\]
rays\_o, rays\_d = get\_rays\(height, width, focal, testpose\)
rays\_o = rays\_o.reshape\(\[-1, 3\]\)
rays\_d = rays\_d.reshape\(\[-1, 3\]\)
outputs = nerf\_forward\(rays\_o, rays\_d,
near, far, encode, model,
kwargs\_sample\_stratified=kwargs\_sample\_stratified,
n\_samples\_hierarchical=n\_samples\_hierarchical,
kwargs\_sample\_hierarchical=kwargs\_sample\_hierarchical,
fine\_model=fine\_model,
viewdirs\_encoding\_fn=encode\_viewdirs,
chunksize=chunksize\)
rgb\_predicted = outputs\['rgb\_map'\]
loss = torch.nn.functional.mse\_loss\(rgb\_predicted, testimg.reshape\(-1, 3\)\)
print\("Loss:", loss.item\(\)\)
val\_psnr = -10. \* torch.log10\(loss\)
val\_psnrs.append\(val\_psnr.item\(\)\)
iternums.append\(i\)
# 绘制输出示例
fig, ax = plt.subplots\(1, 4, figsize=\(24,4\), gridspec\_kw=\{'width\_ratios': \[1, 1, 1, 3\]\}\)
ax\[0\].imshow\(rgb\_predicted.reshape\(\[height, width, 3\]\).detach\(\).cpu\(\).numpy\(\)\)
ax\[0\].set\_title\(f'Iteration: \{i\}'\)
ax\[1\].imshow\(testimg.detach\(\).cpu\(\).numpy\(\)\)
ax\[1\].set\_title\(f'Target'\)
ax\[2\].plot\(range\(0, i + 1\), train\_psnrs, 'r'\)
ax\[2\].plot\(iternums, val\_psnrs, 'b'\)
ax\[2\].set\_title\('PSNR \(train=red, val=blue'\)
z\_vals\_strat = outputs\['z\_vals\_stratified'\].view\(\(-1, n\_samples\)\)
z\_sample\_strat = z\_vals\_strat\[z\_vals\_strat.shape\[0\] // 2\].detach\(\).cpu\(\).numpy\(\)
if 'z\_vals\_hierarchical' in outputs:
z\_vals\_hierarch = outputs\['z\_vals\_hierarchical'\].view\(\(-1, n\_samples\_hierarchical\)\)
z\_sample\_hierarch = z\_vals\_hierarch\[z\_vals\_hierarch.shape\[0\] // 2\].detach\(\).cpu\(\).numpy\(\)
else:
z\_sample\_hierarch = None
\_ = plot\_samples\(z\_sample\_strat, z\_sample\_hierarch, ax=ax\[3\]\)
ax\[3\].margins\(0\)
plt.show\(\)
# 检查 PSNR 是否存在问题,如果发现问题,则停止运行。
if i == warmup\_iters - 1:
if val\_psnr \< warmup\_min\_fitness:
print\(f'Val PSNR \{val\_psnr\} below warmup\_min\_fitness \{warmup\_min\_fitness\}. Stopping...'\)
return False, train\_psnrs, val\_psnrs
elif i \< warmup\_iters:
if warmup\_stopper is not None and warmup\_stopper\(i, psnr\):
print\(f'Train PSNR flatlined at \{psnr\} for \{warmup\_stopper.patience\} iters. Stopping...'\)
return False, train\_psnrs, val\_psnrs
return True, train\_psnrs, val\_psnrs
最终的结果如下图所示:
6|运行结果示意图
引用:
[1]https://www.matthewtancik.com/nerf
[2]http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz
[3]https://towardsdatascience.com/its-nerf-from-nothing-build-a-vanilla-nerf-with-pytorch-7846e4c45666
[4]https://medium.com/@rparikshat1998/nerf-from-scratch-fe21c08b145d
四、Pytorch~SimCLR
使用Pytorch实现对比学习SimCLR 进行自监督预训练
这里将深入研究 SimCLR 框架并探索该算法的关键组件,包括数据增强、对比损失函数以及编码器和投影的head 架构。
SimCLR(Simple Framework for Contrastive Learning of Representations)是一种学习图像表示的自监督技术。与传统的监督学习方法不同,SimCLR 不依赖标记数据来学习有用的表示。它利用对比学习框架来学习一组有用的特征,这些特征可以从未标记的图像中捕获高级语义信息。
SimCLR 已被证明在各种图像分类基准上优于最先进的无监督学习方法。并且它学习到的表示可以很容易地转移到下游任务,例如对象检测、语义分割和小样本学习,只需在较小的标记数据集上进行最少的微调。
SimCLR 主要思想是通过增强模块 T 将图像与同一图像的其他增强版本进行对比,从而学习图像的良好表示。这是通过通过编码器网络 f(.) 映射图像,然后进行投影来完成的。head g(.) 将学习到的特征映射到低维空间。然后在同一图像的两个增强版本的表示之间计算对比损失,以鼓励对同一图像的相似表示和对不同图像的不同表示。
我们这里使用来自 Kaggle 的垃圾分类数据集来进行实验。
增强模块
SimCLR 中最重要的就是转换图像的增强模块。SimCLR 论文的作者建议,强大的数据增强对于无监督学习很有用。因此,我们将遵循论文中推荐的方法。
- 调整大小的随机裁剪
- 50% 概率的随机水平翻转
- 随机颜色失真(颜色抖动概率为 80%,颜色下降概率为 20%)
- 50% 概率为随机高斯模糊
def get_complete_transform(output_shape, kernel_size, s=1.0):
"""
Color distortion transform
Args:
s: Strength parameter
Returns:
A color distortion transform
"""
rnd_crop = RandomResizedCrop(output_shape)
rnd_flip = RandomHorizontalFlip(p=0.5)
color_jitter = ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
rnd_color_jitter = RandomApply([color_jitter], p=0.8)
rnd_gray = RandomGrayscale(p=0.2)
gaussian_blur = GaussianBlur(kernel_size=kernel_size)
rnd_gaussian_blur = RandomApply([gaussian_blur], p=0.5)
to_tensor = ToTensor()
image_transform = Compose([
to_tensor,
rnd_crop,
rnd_flip,
rnd_color_jitter,
rnd_gray,
rnd_gaussian_blur,
])
return image_transform
class ContrastiveLearningViewGenerator(object):
"""
Take 2 random crops of 1 image as the query and key.
"""
def __init__(self, base_transform, n_views=2):
self.base_transform = base_transform
self.n_views = n_views
def __call__(self, x):
views = [self.base_transform(x) for i in range(self.n_views)]
return views
下一步就是定义一个PyTorch 的 Dataset 。
class CustomDataset(Dataset):
def __init__(self, list_images, transform=None):
"""
Args:
list_images (list): List of all the images
transform (callable, optional): Optional transform to be applied on a sample.
"""
self.list_images = list_images
self.transform = transform
def __len__(self):
return len(self.list_images)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = self.list_images[idx]
image = io.imread(img_name)
if self.transform:
image = self.transform(image)
return image
作为样例,我们使用比较小的模型 ResNet18 作为主干,所以他的输入是 224x224 图像,我们按照要求设置一些参数并生成dataloader
out_shape = [224, 224]
kernel_size = [21, 21] # 10% of out_shape
# Custom transform
base_transforms = get_complete_transform(output_shape=out_shape, kernel_size=kernel_size, s=1.0)
custom_transform = ContrastiveLearningViewGenerator(base_transform=base_transforms)
garbage_ds = CustomDataset(
list_images=glob.glob("/kaggle/input/garbage-classification/garbage_classification/*/*.jpg"),
transform=custom_transform
)
BATCH_SZ = 128
# Build DataLoader
train_dl = torch.utils.data.DataLoader(
garbage_ds,
batch_size=BATCH_SZ,
shuffle=True,
drop_last=True,
pin_memory=True)
SimCLR
我们已经准备好了数据,开始对模型进行复现。上面的增强模块提供了图像的两个增强视图,它们通过编码器前向传递以获得相应的表示。SimCLR 的目标是通过鼓励模型从两个不同的增强视图中学习对象的一般表示来最大化这些不同学习表示之间的相似性。编码器网络的选择不受限制,可以是任何架构。上面已经说了,为了简单演示,我们使用 ResNet18。编码器模型学习到的表示决定了相似性系数,为了提高这些表示的质量,SimCLR 使用投影头将编码向量投影到更丰富的潜在空间中。这里我们将ResNet18的512维度的特征投影到256的空间中,看着很复杂,其实就是加了一个带relu的mlp。
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class SimCLR(nn.Module):
def __init__(self, linear_eval=False):
super().__init__()
self.linear_eval = linear_eval
resnet18 = models.resnet18(pretrained=False)
resnet18.fc = Identity()
self.encoder = resnet18
self.projection = nn.Sequential(
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 256)
)
def forward(self, x):
if not self.linear_eval:
x = torch.cat(x, dim=0)
encoding = self.encoder(x)
projection = self.projection(encoding)
return projection
对比损失
对比损失函数,也称为归一化温度标度交叉熵损失 (NT-Xent),是 SimCLR 的一个关键组成部分,它鼓励模型学习相同图像的相似表示和不同图像的不同表示。
NT-Xent 损失是使用一对通过编码器网络传递的图像的增强视图来计算的,以获得它们相应的表示。对比损失的目标是鼓励同一图像的两个增强视图的表示相似,同时迫使不同图像的表示不相似。
NT-Xent 将 softmax 函数应用于增强视图表示的成对相似性。softmax 函数应用于小批量内的所有表示对,得到每个图像的相似性概率分布。温度参数temperature 用于在应用 softmax 函数之前缩放成对相似性,这有助于在优化过程中获得更好的梯度。
在获得相似性的概率分布后,通过最大化同一图像的匹配表示的对数似然和最小化不同图像的不匹配表示的对数似然来计算 NT-Xent 损失。
LABELS = torch.cat([torch.arange(BATCH_SZ) for i in range(2)], dim=0)
LABELS = (LABELS.unsqueeze(0) == LABELS.unsqueeze(1)).float() #one-hot representations
LABELS = LABELS.to(DEVICE)
def ntxent_loss(features, temp):
"""
NT-Xent Loss.
Args:
z1: The learned representations from first branch of projection head
z2: The learned representations from second branch of projection head
Returns:
Loss
"""
similarity_matrix = torch.matmul(features, features.T)
mask = torch.eye(LABELS.shape[0], dtype=torch.bool).to(DEVICE)
labels = LABELS[~mask].view(LABELS.shape[0], -1)
similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
logits = torch.cat([positives, negatives], dim=1)
labels = torch.zeros(logits.shape[0], dtype=torch.long).to(DEVICE)
logits = logits / temp
return logits, labels
所有的准备都完成了,让我们训练 SimCLR 看看效果!
simclr_model = SimCLR().to(DEVICE)
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.Adam(simclr_model.parameters())
epochs = 10
with tqdm(total=epochs) as pbar:
for epoch in range(epochs):
t0 = time.time()
running_loss = 0.0
for i, views in enumerate(train_dl):
projections = simclr_model([view.to(DEVICE) for view in views])
logits, labels = ntxent_loss(projections, temp=2)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# print stats
running_loss += loss.item()
if i%10 == 9: # print every 10 mini-batches
print(f"Epoch: {epoch+1} Batch: {i+1} Loss: {(running_loss/100):.4f}")
running_loss = 0.0
pbar.update(1)
print(f"Time taken: {((time.time()-t0)/60):.3f} mins")
上面代码训练了10轮,假设我们已经完成了预训练过程,可以将预训练的编码器用于我们想要的下游任务。这可以通过下面的代码来完成。
from torchvision.transforms import Resize, CenterCrop
resize = Resize(255)
ccrop = CenterCrop(224)
ttensor = ToTensor()
custom_transform = Compose([
resize,
ccrop,
ttensor,
])
garbage_ds = ImageFolder(
root="/kaggle/input/garbage-classification/garbage_classification/",
transform=custom_transform
)
classes = len(garbage_ds.classes)
BATCH_SZ = 128
train_dl = torch.utils.data.DataLoader(
garbage_ds,
batch_size=BATCH_SZ,
shuffle=True,
drop_last=True,
pin_memory=True,
)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class LinearEvaluation(nn.Module):
def __init__(self, model, classes):
super().__init__()
simclr = model
simclr.linear_eval=True
simclr.projection = Identity()
self.simclr = simclr
for param in self.simclr.parameters():
param.requires_grad = False
self.linear = nn.Linear(512, classes)
def forward(self, x):
encoding = self.simclr(x)
pred = self.linear(encoding)
return pred
eval_model = LinearEvaluation(simclr_model, classes).to(DEVICE)
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.Adam(eval_model.parameters())
preds, labels = [], []
correct, total = 0, 0
with torch.no_grad():
t0 = time.time()
for img, gt in tqdm(train_dl):
image = img.to(DEVICE)
label = gt.to(DEVICE)
pred = eval_model(image)
_, pred = torch.max(pred.data, 1)
total += label.size(0)
correct += (pred == label).float().sum().item()
print(f"Time taken: {((time.time()-t0)/60):.3f} mins")
print(
"Accuracy of the network on the {} Train images: {} %".format(
total, 100 * correct / total
)
)
上面的代码最主要的部分就是读取刚刚训练的simclr模型,然后冻结所有的权重,然后再创建一个分类头self.linear ,进行下游的分类任务
总结
本文介绍了SimCLR框架,并使用它来预训练随机初始化权重的ResNet18。预训练是深度学习中使用的一种强大的技术,用于在大型数据集上训练模型,学习可以转移到其他任务中的有用特征。SimCLR论文认为,批量越大,性能越好。我们的实现只使用128个批大小,只训练10个epoch。所以这不是模型的最佳性能,如果需要性能对比还需要进一步的训练。
下图是论文作者给出的性能结论:
论文地址:https://arxiv.org/abs/2002.05709