摘要
分类器指导是最近引入的一种方法,在训练后在条件扩散模型中权衡模式覆盖率和样本保真度,在精神上与其他类型的生成模型中的低温采样或截断相同。分类器引导将扩散模型的分数估计与图像分类器的梯度相结合,因此需要训练与扩散模型分离的图像分类器。它还提出了一个问题,即是否可以在没有分类器的情况下执行指导。我们表明,如果没有这样的分类器,引导确实可以由纯生成模型执行:在我们称之为无分类器指导的情况下,我们联合训练一个条件扩散模型和一个无条件扩散模型,我们结合得到的条件和无条件分数估计来获得样本质量和多样性之间的权衡,类似于使用分类器指导获得的。
引言
分类器引导将扩散模型的分数估计与a的对数概率的输入梯度混合在一起的分类器。通过改变分类器梯度的强度,Dhariwal和Nichol可以权衡Inception分数(Salimans等人,2016)和FID分数(Heusel等人,2017)(或精度和召回率),其方式类似于改变BigGAN的截断参数。
针对64x64 ImageNet扩散模型的malamute类的无分类器指导。从左到右:增加无分类器引导的数量,从左边的非引导样本开始。
引导器对三个高斯混合的影响,每个混合分量代表一个类条件下的数据。最左边的图是非引导的边际密度。从左到右是随引导强度增加的归一化引导条件的混合密度。
我们感兴趣的是是否可以在没有分类器的情况下进行分类器引导。分类器引导使扩散模型训练管道变得复杂,因为它需要训练一个额外的分类器,而且这个分类器必须在有噪声的数据上训练,所以通常不可能插入一个预训练的分类器。此外,由于分类器引导在采样期间混合了分数估计和分类器梯度,分类器引导的扩散采样可以被解释为试图将图像分类器与基于梯度的对抗性攻击混淆。这就提出了一个问题:分类器指导是否能够成功地提高基于分类器的度量,比如FID和Inception分数(is),仅仅是因为它与这些分类器是对立的。在分类器梯度方向上的步进也与GAN训练有一些相似之处,特别是使用非参数生成器;这也提出了一个问题,即分类器引导的扩散模型是否在基于分类器的指标上表现良好,因为它们开始类似于gan,而gan已经在这些指标上表现良好。
为了解决这些问题,我们提出了无分类器的引导方法,即完全避免使用任何分类器的引导方法。与在图像分类器的梯度方向上采样不同,无分类器引导混合了条件扩散模型和联合训练的无条件扩散模型的分数估计。通过扫过混合权重,我们获得了类似于分类器引导所获得的FID/IS权衡。我们的无分类器引导结果表明,纯生成扩散模型能够与其他类型的生成模型合成极高保真度的样本。
算法
虽然分类器指导成功地权衡了截断或低温采样预期的IS和FID,但它仍然依赖于图像分类器的梯度,我们试图消除分类器,原因见第1节。在这里,我们描述了无分类器指导,在没有这种梯度的情况下实现了相同的效果。无分类器指导是修改与分类器引导具有相同的效果的替代方法,但没有分类器。算法 1 和 2 详细描述了使用无分类器指导进行训练和采样。
算法1
训练时
算法2
推理时
讨论
我们的无分类器指导方法最实际的优点是它极其简单:在训练期间(随机放弃条件)和在采样期间(混合条件和无条件分数估计),只需要对代码进行一行更改。相比之下,分类器引导使训练管道变得复杂,因为它需要训练额外的分类器。这个分类器必须在有噪声的zλ上进行训练,因此不可能插入一个标准的预训练分类器。
代码片段
if args.conditioning_dropout_prob is not None:
random_p = torch.rand(bsz, device=latents.device, generator=generator)
# Sample masks for the edit prompts.
prompt_mask = random_p < 2 * args.conditioning_dropout_prob
prompt_mask = prompt_mask.reshape(bsz, 1, 1)
# Final text conditioning.
encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
# Sample masks for the original images.
image_mask_dtype = original_image_embeds.dtype
image_mask = 1 - ((random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
* (random_p < * args.conditioning_dropout_prob).to(image_mask_dtype)
)
image_mask = image_mask.reshape(bsz, 1, 1, 1)
# Final image conditioning.
original_image_embeds = image_mask * original_image_embeds
# Concatenate the `original_image_embeds` with the `noisy_latents`.
concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)
参考
diffusers/src/diffusers/training_utils.py at main · huggingface/diffusers · GitHub