1.似然函数最大化
扩散模型的训练目标是负的对数似然的一个变分下界(VLB)。在本节中,我们总结并调查最近关于扩散模型的似然最大化的工作。首先我们介绍似然函数最大化的意义,然后重点讨论3种类型的方法:噪声调度优化、逆向方差学习和精确的对数似然估计。需要注意的是,目前扩散模型的似然提高方法是通过改善负对数似然的VLB 实现的,不能像归一化流(NormalizingFlow)那样直接改善似然函数值。
在生成模型中,我们认为真实世界的一个个数据是某个随机变量一个一个实现的。为了生成趋于真实的数据,我们希望能够学习到真实数据的分布q,然后通过模拟这个分布来生成新样本。所以我们会建立深度学习模型来对分布q进行参数化和学习。似然函数指的是,数据点在模型中的概率密度函数值即p(x,)所组成的函数,其中x是数据点,是参数,p(·,)是模型在参数下的生成样本的分布。似然函数是一个关于模型参数0的函数,当选择不同的参数0时,似然函数的值是不同的,它描述了在当前参数下,使用模型分布p(x,)产生数据集中所有样本的概率。一个朴素的想法是,在最好的模型参数下,产生数据集中的所有样本的概率是最大的。但在计算机中,多个概率的乘积结果并不方便计算和储存,例如,在计算过程中可能发生数值下溢的问题,即对比较小的、接近于0的数进行四舍五入后成为0。我们可以对似然函数取对数来缓解该问题,即L=log[L],并且仍然求解最好的模型参数ml,使对数似然函数最大:ml argmaxl。可以证明这两者是等价的。在统计学中,参数往往有明确的含义,所以,人们希望知道参数的取值及其置信区间。通过数学推导可以证明,假设数据真实分布是p(x,*),那么在一定的正则条件下ml是*的相合估计,即√n(ml-*)有渐进正态性,并且是渐进最优(渐进有效)的。
但是对于深度学习来说,参数并不一定是可识别的,并且因为深度学习中参数往往没有具体含义,所以我们常常不关心具体的取值。但我们仍然希望能够让似然函数以某种形式最大化,这是因为似然函数的最大化可以视作对模型的分布p和真实的数据分布q做匹配。但是直接计算是非常复杂的,经过一系列等价做近似计算。有的人可能会注意到q的真实分布是我们不知道的,所以没办法显式地计算这个KL散度,但是在数据量较大的情况下可以通过蒙特卡罗方法来模拟。这也是扩散模型最常用的损失函数,不仅如此,基于能量的模型、VAE、归一化流的训练方式都采用的最大化似然方式。GAN的训练方式也是在匹配模型分布和数据分布,但不是通过最大化似然的方式,而是使用GAN的判别器(test function)来评判两个分布的区别。这就导致 GAN 会出现模式崩溃的情况,即产生的样本单一。而最大化似然的方式就不会出现这个问题,因为它强制模型考虑到所有数据点。下面我们介绍如何提高扩散模型的似然值从而获得高质量、多样性的样本。
4.2 加噪策略优化
在扩散模型中,我们希望优化生成样本分布的对数似然,也就是Eq0logp0,其中q0是真实样本的分布,p0是生成的样本的分布。这等价于最小化q0与p0之间的 KL散度Dkl(p0||q0)。但直接计算KL散度是很难处理的,因为在扩散模型中样本是迭代生成的,一般一个样本就需要几百甚至上千次计算。所以为了提高计算效率,我们转而优化Dkl(p||q),这里p是整个前向加噪过程的分布,q是整个逆向去噪过程的分布。根据 KL 散度的性质,可以证明Dk(pllq)是Dkl(p0||q0)的一个上界,即可以通过减小Dkl(pllq)近似优化生成样本的似然。在经典的扩散模型(如DDPM)中,前向过程中的噪声进程是手工调试的,没有可训练的参数。也就是说,唯一能做的事就是学习p的分布使其与qπ匹配。如果q选择得不好,比如加噪的进度过快导致信息丢失过多,那么会导致p难以通过学习的方式匹配q。从最优传输的角度来看,q和p是匹配数据分布q0和先验分布的一座桥梁,而事实上能够匹配数据分布q0和先验分布的随机过程有无限多个。所以我们会期望能够优化或者学习前向过程q,从而使学习p更简单,二者的KL散度更小。通过优化前向噪声的进程和扩散模型的其他参数,人们可以进一步最大化 VLB,以获得更高的对数似然值。
iDDPM的工作表明,经典 DDPM 中的线性噪声在加噪的后期加噪程度过快,导致信息快速丢失,逆向去噪过程就会难以复原丢失的信息。而某种余弦加噪策略可以让信息丢失的速率更平缓,容易复原,从而改善模型的对数似然值。
在变分扩散模型(Variational Diffusion Model,VDM)中,Kingma等人提出通过联合训练加噪策略和其他扩散模型参数来最大化 VLB,从而提高连续时间扩散模型的似然函数值。VDM 使用单调神经网络(t)对加噪策略进行参数化,其中表示单调神经网络中可学习的参数。此外,Kingma等人还证明了在连续时间的情形下(T趋于正无穷),数据点x的VLB可以简化为只取决于信噪比。VDM对前向过程的学习也可以表示为对信噪比的学习。