潜在一致性模型:通过少步推理合成高分辨率图像
Paper Title: Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference
Paper是清华发表在Arxiv 2023的工作
Paper地址
Code地址
ABSTRACT
潜在扩散模型 (LDM) 在合成高分辨率图像方面取得了显著成果。然而,迭代采样过程计算量大,导致生成速度慢。受一致性模型 (Song et al, 2023) 的启发,我们提出了潜在一致性模型 (LCM),能够在任何预训练的 LDM(包括稳定扩散 (Rombach et al, 2022))上以最少的步骤进行快速推理。将引导的反向扩散过程视为求解增强概率流 ODE (PF-ODE),LCM 旨在直接预测此类 ODE 在潜在空间中的解,从而减少大量迭代的需要并实现快速、高保真采样。从预训练的无分类器引导扩散模型中高效提炼而来,高质量的 768×768 2∼4 步 LCM 仅需 32 个 A100 GPU 小时的训练时间。此外,我们引入了潜在一致性微调 (LCF),这是一种新方法,专门用于在定制图像数据集上微调 LCM。在 LAION-5B-Aesthetics 数据集上的评估表明,LCM 仅需几步推理即可实现最先进的文本到图像生成性能。项目页面:地址
1 INTRODUCTION
扩散模型已成为强大的生成模型,在各个领域引起了广泛关注并取得了显着成果(Ho et al, 2020; Song et al, 2020a; Nichol & Dhariwal, 2021; Ramesh et al, 2022; Song & Ermon, 2019; Song et al, 2021)。特别是,潜在扩散模型(LDM)(例如稳定扩散(Rombach et al, 2022))表现出色,尤其是在高分辨率文本到图像合成任务中。LDM 可以通过利用对样本进行逐步去噪的迭代反向采样过程来生成以文本描述为条件的高质量图像。然而,扩散模型有一个明显的缺点:迭代反向采样过程导致生成速度慢,限制了它们的实时适用性。为了克服这个缺点,研究人员提出了几种提高采样速度的方法,包括通过增强 ODE 求解器来加速去噪过程 (Ho et al, 2020; Lu et al, 2022a;b),可以在 10∼20 个采样步骤内生成图像。另一种方法是将预先训练的扩散模型蒸馏成能够进行几步推理的模型 Salimans & Ho (2022); Meng et al (2023)。特别是,Meng et al (2023) 提出了一种两阶段蒸馏方法来提高无分类器引导模型的采样效率。最近,Song et al (2023) 提出了一致性模型作为一种有前途的替代方案,旨在加快生成过程。通过学习在 ODE 轨迹上保持点一致性的一致性映射,这些模型允许单步生成,从而无需计算密集型迭代。然而,Song 等人 (2023) 受限于像素空间图像生成任务,因此不适合合成高分辨率图像。此外,条件扩散模型的应用和无分类器引导的结合尚未得到探索,因此他们的方法不适合文本到图像的生成合成。
图 1:CFG 尺度为 ω = 8.0 的潜在一致性模型 (LCM) 生成的图像。只需 4,000 个训练步骤(约 32 个 A100 GPU 小时)即可从任何预训练的稳定扩散 (SD) 中提取 LCM,以在 2∼4 步甚至一步内生成高质量的 768×768 分辨率图像,从而显著加快文本到图像的生成。我们使用 LCM 仅用 4,000 次训练迭代就提取了 Dreamer-V7 版本的 SD。
在本文中,我们介绍了用于快速、高分辨率图像生成的潜在一致性模型 (LCM)。类似LDM,我们在稳定扩散 (Rombach 等人,2022) 的预训练自编码器的图像潜在空间中采用一致性模型。我们提出了一种单阶段引导蒸馏方法,通过求解增强 PF-ODE 将预训练的引导扩散模型有效地转换为潜在一致性模型。此外,我们提出了潜在一致性微调,它允许对预训练的 LCM 进行微调,以支持对定制图像数据集进行几步推理。我们的主要贡献总结如下:
-
我们提出了潜在一致性模型 (LCM),用于快速、高分辨率图像生成。LCM 在图像潜在空间中采用一致性模型,能够在预先训练的潜在扩散模型(例如稳定扩散 (SD))上实现快速的几步甚至一步高保真采样。
-
我们提供了一种简单高效的单阶段引导一致性蒸馏方法,用于蒸馏几步(2∼4)甚至 1 步采样的 SD。我们提出了 SKIPPING-STEP 技术来进一步加速收敛。对于 2 步和 4 步推理,我们的方法仅花费 32 个 A100 GPU 小时进行训练,并在 LAION-5B-Aesthetics 数据集上实现了最佳性能。
-
我们为 LCM 引入了一种新的微调方法,称为潜在一致性微调,可以使预训练的 LCM 有效地适应定制数据集,同时保留快速推理的能力。
2 RELATED WORK
扩散模型在图像生成方面取得了巨大成功(Ho et al, 2020; Song et al, 2020a; Nichol & Dhariwal, 2021; Ramesh et al, 2022; Rombach et al, 2022; Song & Ermon, 2019)。
它们经过训练,可以对受噪声破坏的数据进行去噪,以估计数据分布的分数。 在推理过程中,通过运行反向扩散过程来抽取样本,以逐渐对数据点进行去噪。 与 VAE(Kingma & Welling, 2013; Sohn et al, 2015)和 GAN(Goodfellow et al, 2020)相比,扩散模型具有训练稳定性和更好的似然估计的优势。
加速 DM。 然而,扩散模型的瓶颈在于其缓慢的生成速度。
已经提出了各种方法,包括无需训练的方法,例如 ODE 求解器(Song 等人,2020a;Lu 等人,2022a;b)、自适应步长求解器(Jolicoeur-Martineau 等人,2021)、预测校正方法(Song 等人,2020b)。基于训练的方法包括优化离散化(Watson 等人,2021)、截断扩散(Lyu 等人,2022;Zheng 等人,2022)、神经算子(Zheng 等人,2023)和蒸馏(Salimans & Ho,2022;Meng 等人,2023)。最近,还提出了用于更快采样的新型生成模型(Liu 等人,2022;2023)。
潜在扩散模型 (LDM) (Rombach 等人,2022) 在合成高分辨率文本到图像方面表现出色。例如,稳定扩散 (SD) 在数据潜在空间中执行正向和反向扩散过程,从而提高计算效率。
一致性模型 (CM) (Song 等人,2023) 作为一种新型生成模型,在保持生成质量的同时实现更快的采样,显示出巨大的潜力。CM 采用一致性映射将 ODE 轨迹中的任何点直接映射到其原点,从而实现快速的一步生成。CM 可以通过蒸馏预训练的扩散模型或作为独立的生成模型进行训练。CM 的细节将在下一节中详细说明。
3 PRELIMINARIES
在本节中,我们简要回顾了扩散模型和一致性模型,并定义了相关的符号。
扩散模型:扩散模型,或基于评分的生成模型 (Ho et al., 2020; Song et al., 2020a),是一类生成模型,它通过逐步向数据中注入高斯噪声,然后通过反向去噪过程从噪声中生成样本。具体来说,扩散模型定义了一个将原始数据分布 p data ( x ) p_{\text{data}}(x) pdata(x) 转换为边缘分布 q t ( x t ) q_t\left(\boldsymbol{x}_t\right) qt(xt) 的前向过程,通过转移核: q 0 t ( x t ∣ x 0 ) = N ( x t ∣ α ( t ) x 0 , σ 2 ( t ) I ) q_{0 t}\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)=\mathcal{N}\left(\boldsymbol{x}_t \mid \alpha(t) \boldsymbol{x}_0, \sigma^2(t) \boldsymbol{I}\right) q0t(xt∣x0)=N(xt∣α(t)x0,σ2(t)I),其中 α ( t ) , σ ( t ) \alpha(t), \sigma(t) α(t),σ(t) 指定了噪声时间表。从连续时间的角度来看,前向过程可以通过随机微分方程 (SDE) 来描述 (Song et al., 2020b; Lu et al., 2022a; Karras et al., 2022) 对于 t ∈ [ 0 , T ] t \in [0, T] t∈[0,T]:
d x t = f ( t ) x t d t + g ( t ) d w t , x 0 ∼ p data ( x 0 ) \mathrm{d} \boldsymbol{x}_t=f(t) \boldsymbol{x}_t \mathrm{~d} t+g(t) \mathrm{d} \boldsymbol{w}_t, \boldsymbol{x}_0 \sim p_{\text {data }}\left(\boldsymbol{x}_0\right) dxt=f(t)xt dt+g(t)dwt,x0∼pdata (x0)
其中 w t \boldsymbol{w}_t wt 是标准布朗运动,并且
f ( t ) = d log α ( t ) d t , g 2 ( t ) = d σ 2 ( t ) d t − 2 d log α ( t ) d t σ 2 ( t ) ( 1 ) f(t)=\frac{\mathrm{d} \log \alpha(t)}{\mathrm{d} t}, \quad g^2(t)=\frac{\mathrm{d} \sigma^2(t)}{\mathrm{d} t}-2 \frac{\mathrm{d} \log \alpha(t)}{\mathrm{d} t} \sigma^2(t)\quad(1) f(t)=dtdlogα(t),g2(t)=dtdσ2(t)−2dtdlogα(t)σ2(t)(1)
通过考虑反向时间SDE(更多详细信息请参见附录A),可以证明边缘分布 q t ( x ) q_t(\boldsymbol{x}) qt(x) 满足以下常微分方程,称为概率流ODE (PF-ODE) (Song et al., 2020b; Lu et al., 2022a):
d x t d t = f ( t ) x t − 1 2 g 2 ( t ) ∇ x log q t ( x t ) , x T ∼ q T ( x T ) ( 2 ) \frac{\mathrm{d} \boldsymbol{x}_t}{\mathrm{~d} t}=f(t) \boldsymbol{x}_t-\frac{1}{2} g^2(t) \nabla_{\boldsymbol{x}} \log q_t\left(\boldsymbol{x}_t\right), \boldsymbol{x}_T \sim q_T\left(\boldsymbol{x}_T\right)\quad(2) dtdxt=f(t)xt−21g2(t)∇xlogqt(xt),xT∼qT(xT)(2)
在扩散模型中,我们训练噪声预测模型 ϵ θ ( x t , t ) \boldsymbol{\epsilon}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right) ϵθ(xt,t) 以拟合 − ∇ log q t ( x t ) -\nabla \log q_t\left(\boldsymbol{x}_t\right) −∇logqt(xt)(称为得分函数)。通过使用噪声预测模型来近似得分函数,可以得到以下用于采样的经验 P F − O D E PF-ODE PF−ODE:
d x t d t = f ( t ) x t + g 2 ( t ) 2 σ t ϵ θ ( x t , t ) , x T ∼ N ( 0 , σ ~ 2 I ) ( 3 ) \frac{\mathrm{d} \boldsymbol{x}_t}{\mathrm{~d} t}=f(t) \boldsymbol{x}_t+\frac{g^2(t)}{2 \sigma_t} \boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_t, t\right), \quad \boldsymbol{x}_T \sim \mathcal{N}\left(\mathbf{0}, \tilde{\sigma}^2 \boldsymbol{I}\right)\quad(3) dtdxt=f(t)xt+2σtg2(t)ϵθ(xt,t),xT∼N(0,σ~2I)(3)
对于类条件扩散模型,Classifier-Free Guidance (CFG) (Ho & Salimans, 2022) 是一种有效的技术,可以显著提高生成样本的质量,并已广泛应用于多个大规模扩散模型中,包括GLIDE (Nichol et al., 2021)、Stable Diffusion (Rombach et al., 2022)、DALL·E 2 (Ramesh et al., 2022) 和 Imagen (Saharia et al., 2022)。给定CFG尺度 ω \omega ω,原始噪声预测被替换为条件和无条件噪声预测的线性组合,即 ϵ ~ θ ( z t , ω , c , t ) = ( 1 + ω ) ϵ θ ( z t , c , t ) − ω ϵ θ ( z , ∅ , t ) \tilde{\boldsymbol{\epsilon}}_{\boldsymbol{\theta}}\left(\boldsymbol{z}_t, \omega, \boldsymbol{c}, t\right)=(1+\omega) \boldsymbol{\epsilon}_{\boldsymbol{\theta}}\left(\boldsymbol{z}_t, \boldsymbol{c}, t\right)-\omega \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{z}, \varnothing, t) ϵ~θ(zt,ω,c,t)=(1+ω)ϵθ(zt,c,t)−ωϵθ(z,∅,t)。
一致性模型:一致性模型 (CM) (Song et al., 2023) 是一类新的生成模型,可以实现一步或少步生成。CM的核心思想是学习将PF-ODE轨迹上的任意点映射到该轨迹的起点(即PF-ODE的解)的函数。更正式地,一致性函数定义为 f : ( x t , t ) ⟼ x ϵ f:\left(x_t, t\right) \longmapsto x_\epsilon f:(xt,t)⟼xϵ,其中 ϵ \epsilon ϵ 是一个固定的小正数。一个重要的观察是,一致性函数应满足自一致性属性:
f ( x t , t ) = f ( x t ′ , t ′ ) , ∀ t , t ′ ∈ [ ϵ , T ] ( 4 ) \boldsymbol{f}\left(\boldsymbol{x}_t, t\right)=\boldsymbol{f}\left(\boldsymbol{x}_{t^{\prime}}, t^{\prime}\right), \forall t, t^{\prime} \in[\epsilon, T]\quad(4) f(xt,t)=f(xt′,t′),∀t,t′∈[ϵ,T](4)
(Song et al., 2023) 中学习一致性模型 f θ f_\theta fθ 的关键思想是通过有效地在公式4中强制执行自一致性属性来从数据中学习一致性函数。为了确保 f θ ( x , ϵ ) = x f_\theta(x, \epsilon)=x fθ(x,ϵ)=x,一致性模型 f θ f_\theta fθ 被参数化为:
f θ ( x , t ) = c skip ( t ) x + c out ( t ) F θ ( x , t ) ( 5 ) \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x}, t)=c_{\text {skip }}(t) \boldsymbol{x}+c_{\text {out }}(t) \boldsymbol{F}_{\boldsymbol{\theta}}(\boldsymbol{x}, t)\quad(5) fθ(x,t)=cskip (t)x+cout (t)Fθ(x,t)(5)
其中 c skip ( t ) c_{\text {skip }}(t) cskip (t) 和 c out ( t ) c_{\text {out }}(t) cout (t) 是可微函数,并且 c skip ( ϵ ) = 1 c_{\text {skip }}(\epsilon)=1 cskip (ϵ)=1 和 c out ( ϵ ) = 0 c_{\text {out }}(\epsilon)=0 cout (ϵ)=0,而 F θ ( x , t ) \boldsymbol{F}_{\boldsymbol{\theta}}(\boldsymbol{x}, t) Fθ(x,t) 是一个深度神经网络。CM可以通过从预训练的扩散模型中提炼(称为一致性提炼)或从头开始训练来获得。为了强制执行自一致性属性,我们维护一个目标模型 θ − \boldsymbol{\theta}^{-} θ−,通过参数 θ \boldsymbol{\theta} θ 的指数移动平均 (EMA) 更新,即 θ − ← μ θ − + ( 1 − μ ) θ \boldsymbol{\theta}^{-} \leftarrow \mu \boldsymbol{\theta}^{-}+(1-\mu) \boldsymbol{\theta} θ−←μθ−+(1−μ)θ,并定义一致性损失为:
L ( θ , θ − ; Φ ) = E x , t [ d ( f θ ( x t n + 1 , t n + 1 ) , f θ − ( x ^ i n ϕ , t n ) ) ] ( 6 ) \mathcal{L}\left(\boldsymbol{\theta}, \boldsymbol{\theta}^{-} ; \Phi\right)=\mathbb{E}_{\boldsymbol{x}, t}\left[d\left(\boldsymbol{f}_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t_{n+1}}, t_{n+1}\right), \boldsymbol{f}_{\boldsymbol{\theta}^{-}}\left(\hat{\boldsymbol{x}}_{i_n}^\phi, t_n\right)\right)\right]\quad(6) L(θ,θ−;Φ)=Ex,t[d(fθ(xtn+1,tn+1),fθ−(x^inϕ,tn))](6)
其中 d ( ⋅ , ⋅ ) d(\cdot, \cdot) d(⋅,⋅) 是用于测量两个样本之间距离的度量函数,例如平方 ℓ 2 \ell_2 ℓ2 距离 d ( x , y ) = ∥ x − y ∥ 2 2 d(\boldsymbol{x}, \boldsymbol{y})=\|\boldsymbol{x}-\boldsymbol{y}\|_2^2 d(x,y)=∥x−y∥22。 x ^ t n ϕ \hat{\boldsymbol{x}}_{t_n}^\phi x^tnϕ 是 x t n \boldsymbol{x}_{t_n} xtn 的一步估计:
x ^ t n ϕ ← x t n + 1 + ( t n − t n + 1 ) Φ ( x t n + 1 , t n + 1 ; ϕ ) ( 7 ) \hat{\boldsymbol{x}}_{t_n}^\phi \leftarrow \boldsymbol{x}_{t_{n+1}}+\left(t_n-t_{n+1}\right) \Phi\left(\boldsymbol{x}_{t_{n+1}}, t_{n+1} ; \phi\right)\quad(7) x^tnϕ←xtn+1+(tn−tn+1)Φ(xtn+1,tn+1;ϕ)(7)
其中 Φ \Phi Φ 表示应用于公式24中的PF-ODE的一步ODE求解器。 (Song et al., 2023) 使用Euler (Song et al., 2020b) 或 Heun求解器 (Karras et al., 2022) 作为数值ODE求解器。更多详细信息和一致性蒸馏的伪代码(算法2)请参见附录A。
4 LATENT CONSISTENCY MODELS
一致性模型 (CMs) (Song et al., 2023) 仅专注于 ImageNet 64 × 64 64 \times 64 64×64 (Deng et al., 2009) 和 LSUN 256 × 256 256 \times 256 256×256 (Yu et al., 2015) 上的图像生成任务。CMs 在生成更高分辨率的文本到图像任务上的潜力仍未被探索。在本文中,我们在第4.1节中引入了潜在一致性模型 (LCMs) 来应对这些更具挑战性的任务,释放CMs的潜力。与LDMs类似,我们的LCMs在图像潜在空间中采用一致性模型。我们选择强大的Stable Diffusion (SD) 作为基础扩散模型进行蒸馏。我们的目标是在不损失图像质量的前提下,实现SD的少步 ( 2 ∼ 4 ) (2 \sim 4) (2∼4) 甚至一步推理。分类器自由引导 (CFG) (Ho & Salimans, 2022) 是一种有效的技术,可以进一步提高样本质量,并广泛应用于SD中。然而,其在CMs中的应用尚未被探索。我们在第4.2节中提出了一种简单的单阶段引导蒸馏方法,通过求解增强的PF-ODE,有效地将CFG整合到LCM中。在第4.3节中,我们提出了跳步技术 (SKIPPING-STEP) 来加速LCMs的收敛。最后,我们在第4.4节中提出潜在一致性微调 (Latent Consistency Fine-tuning) 以在自定义数据集上微调预训练的LCM,实现少步推理。
4.1 CONSISTENCY DISTILLATION IN THE LATENT SPACE
在大规模扩散模型如Stable Diffusion (SD) (Rombach et al., 2022)中利用图像潜在空间有效提升了图像生成质量并减少了计算负载。在SD中,首先训练一个自编码器 ( E , D ) (\mathcal{E}, \mathcal{D}) (E,D),将高维图像数据压缩为低维潜在向量 z = E ( x ) z=\mathcal{E}(x) z=E(x),然后解码以重建图像为 x ^ = D ( z ) \hat{x}=\mathcal{D}(z) x^=D(z)。在潜在空间中训练扩散模型相比于基于像素的模型大大降低了计算成本并加快了推理过程;LDMs使得在笔记本GPU上生成高分辨率图像成为可能。对于LCMs,我们利用潜在空间的一致性蒸馏优势,与CMs (Song et al., 2023)中使用的像素空间形成对比。这种方法称为潜在一致性蒸馏 (LCD),应用于预训练的SD,使得在 1 ∼ 4 1 \sim 4 1∼4步内合成高分辨率 (例如 768 × 768 768 \times 768 768×768) 图像。我们专注于条件生成。回顾逆扩散过程的PF-ODE (Song et al., 2020b; Lu et al., 2022a) 为
d z t d t = f ( t ) z t + g 2 ( t ) 2 σ t ϵ θ ( z t , c , t ) , z T ∼ N ( 0 , σ ~ 2 I ) ( 8 ) \frac{\mathrm{d} \boldsymbol{z}_t}{\mathrm{~d} t}=f(t) \boldsymbol{z}_t+\frac{g^2(t)}{2 \sigma_t} \boldsymbol{\epsilon}_\theta\left(\boldsymbol{z}_t, \boldsymbol{c}, t\right), \quad \boldsymbol{z}_T \sim \mathcal{N}\left(\mathbf{0}, \tilde{\sigma}^2 \boldsymbol{I}\right) \quad(8) dtdzt=f(t)zt+2σtg2(t)ϵθ(zt,c,t),zT∼N(0,σ~2I)(8)
其中 z t \boldsymbol{z}_t zt为图像潜变量, ϵ θ ( z t , c , t ) \boldsymbol{\epsilon}_\theta\left(\boldsymbol{z}_t, \boldsymbol{c}, t\right) ϵθ(zt,c,t)是噪声预测模型, c \boldsymbol{c} c是给定的条件 (例如文本)。可以通过求解从 T T T到0的PF-ODE来绘制样本。为了执行 L C D \mathbf{LCD} LCD,我们引入了一致性函数 f θ : ( z t , c , t ) ↦ z 0 f_\theta:\left(z_t, c, t\right) \mapsto z_0 fθ:(zt,c,t)↦z0以直接预测 t = 0 t=0 t=0的PF-ODE (Eq. 8) 的解。我们通过噪声预测模型 ϵ ^ θ \hat{\boldsymbol{\epsilon}}_\theta ϵ^θ对 f θ f_\theta fθ进行参数化,具体如下:
f θ ( z , c , t ) = c skip ( t ) z + c out ( t ) ( z − σ t ϵ ^ θ ( z , c , t ) α t ) , ( ϵ -Prediction ) ( 9 ) \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{z}, \boldsymbol{c}, t)=c_{\text {skip }}(t) \boldsymbol{z}+c_{\text {out }}(t)\left(\frac{\boldsymbol{z}-\sigma_t \hat{\boldsymbol{\epsilon}}_\theta(\boldsymbol{z}, \boldsymbol{c}, t)}{\alpha_t}\right), \quad(\boldsymbol{\epsilon} \text {-Prediction })\quad(9) fθ(z,c,t)=cskip (t)z+cout (t)(αtz−σtϵ^θ(z,c,t)),(ϵ-Prediction )(9)
其中 c skip ( 0 ) = 1 , c out ( 0 ) = 0 c_{\text {skip }}(0)=1, c_{\text {out }}(0)=0 cskip (0)=1,cout (0)=0, ϵ ^ θ ( z , c , t ) \hat{\boldsymbol{\epsilon}}_\theta(\boldsymbol{z}, \boldsymbol{c}, t) ϵ^θ(z,c,t)是一个噪声预测模型,初始化时与教师扩散模型具有相同的参数。值得注意的是, f θ \boldsymbol{f}_{\boldsymbol{\theta}} fθ可以通过多种方式进行参数化,具体取决于教师扩散模型对预测的参数化 (例如, x , ϵ \boldsymbol{x}, \boldsymbol{\epsilon} x,ϵ (Ho et al., 2020), v \boldsymbol{v} v (Salimans & Ho, 2022))。我们在附录D中讨论了其他可能的参数化方法。
我们假设有一个有效的ODE求解器 Ψ ( z t , t , s , c ) \Psi\left(\boldsymbol{z}_t, t, s, c\right) Ψ(zt,t,s,c)可用于从时间 t t t到 s s s的近似Eq 8右侧的积分。在实践中,我们可以使用DDIM (Song et al., 2020a)、DPM-Solver (Lu et al., 2022a)或DPM-Solver++ (Lu et al., 2022b)作为 Ψ ( ⋅ , ⋅ , ⋅ , ⋅ ) \Psi(\cdot, \cdot, \cdot, \cdot) Ψ(⋅,⋅,⋅,⋅)。请注意,我们仅在训练/蒸馏时使用这些求解器,而不是在推理时。当我们在第4.3节中介绍SKIPPING-STEP技术时,我们将进一步讨论这些求解器。LCM旨在通过最小化一致性蒸馏损失 (Song et al., 2023) 来预测PF-ODE的解:
L C D ( θ , θ − ; Ψ ) = E z , c , n [ d ( f θ ( z t n + 1 , c , t n + 1 ) , f θ − ( z ^ t n Ψ , c , t n ) ) ] ( 10 ) \mathcal{L}_{\mathcal{C D}}\left(\boldsymbol{\theta}, \boldsymbol{\theta}^{-} ; \Psi\right)=\mathbb{E}_{\boldsymbol{z}, c, n}\left[d\left(\boldsymbol{f}_{\boldsymbol{\theta}}\left(\boldsymbol{z}_{t_{n+1}}, \boldsymbol{c}, t_{n+1}\right), \boldsymbol{f}_{\boldsymbol{\theta}^{-}}\left(\hat{\boldsymbol{z}}_{t_n}^{\Psi}, \boldsymbol{c}, t_n\right)\right)\right]\quad(10) LCD(θ,θ−;Ψ)=Ez,c,n[d(fθ(ztn+1,c,tn+1),fθ−(z^tnΨ,c,tn))](10)
这里, z ^ t n Ψ \hat{z}_{t_n}^{\Psi} z^tnΨ是从 t n + 1 → t n t_{n+1} \rightarrow t_n tn+1→tn 使用ODE求解器 Ψ \Psi Ψ对 P F − O D E P F-O D E PF−ODE的演化的估计:
z ^ t n Ψ − z t n + 1 = ∫ t n + 1 t n ( f ( t ) z t + g 2 ( t ) 2 σ t ϵ θ ( z t , c , t ) ) d t ≈ Ψ ( z t n + 1 , t n + 1 , t n , c ) ( 11 ) \hat{\boldsymbol{z}}_{t_n}^{\Psi}-\boldsymbol{z}_{t_{n+1}}=\int_{t_{n+1}}^{t_n}\left(f(t) \boldsymbol{z}_t+\frac{g^2(t)}{2 \sigma_t} \boldsymbol{\epsilon}_\theta\left(\boldsymbol{z}_t, \boldsymbol{c}, t\right)\right) \mathrm{d} t \approx \Psi\left(\boldsymbol{z}_{t_{n+1}}, t_{n+1}, t_n, \boldsymbol{c}\right)\quad(11) z^tnΨ−ztn+1=∫tn+1tn(f(t)zt+2σtg2(t)ϵθ(zt,c,t))dt≈Ψ(ztn+1,tn+1,tn,c)(11)
其中,求解器 Ψ ( ⋅ , ⋅ , ⋅ , ⋅ ) \Psi(\cdot, \cdot, \cdot, \cdot) Ψ(⋅,⋅,⋅,⋅)用于近似从 t n + 1 → t n t_{n+1} \rightarrow t_n tn+1→tn的积分。
4.2 ONE-STAGE GUIDED DISTILLATION BY SOLVING AUGMENTED PF-ODE
Classifier-free guidance (CFG) (Ho & Salimans, 2022) 在SD中对于生成高质量文本对齐图像至关重要,通常需要一个CFG缩放 ω \omega ω超过6。因此,将CFG整合到蒸馏方法中变得不可或缺。之前的方法Guided-Distill (Meng et al., 2023)引入了一个两阶段的蒸馏方法,以支持从一个引导的扩散模型进行少步采样。然而,它计算量很大(例如,据估计在(Liu et al., 2023)中,用于2步推理的计算量至少需要45个A100 GPU日)。相较之下,一个LCM仅需32个A100 GPU小时的训练即可实现2步推理,如图1所示。此外,两阶段的引导蒸馏可能导致累积误差,从而导致次优性能。相比之下,LCMs采用高效的一阶段引导蒸馏,通过求解增强的PF-ODE来实现。回顾在逆扩散过程中使用的CFG:
ϵ ~ θ ( z t , ω , c , t ) : = ( 1 + ω ) ϵ θ ( z t , c , t ) − ω ϵ θ ( z t , ∅ , t ) ( 12 ) \tilde{\boldsymbol{\epsilon}}_\theta\left(\boldsymbol{z}_t, \omega, \boldsymbol{c}, t\right):=(1+\omega) \boldsymbol{\epsilon}_\theta\left(\boldsymbol{z}_t, \boldsymbol{c}, t\right)-\omega \boldsymbol{\epsilon}_\theta\left(\boldsymbol{z}_t, \varnothing, t\right)\quad(12) ϵ~θ(zt,ω,c,t):=(1+ω)ϵθ(zt,c,t)−ωϵθ(zt,∅,t)(12)
其中,原始的噪声预测被条件和无条件噪声的线性组合所取代,而 ω \omega ω被称为引导比例。要从引导的逆过程进行采样,我们需要求解以下增强的PF-ODE:(即,增加了与 ω \omega ω相关的项)
d z t d t = f ( t ) z t + g 2 ( t ) 2 σ t ϵ ~ θ ( z t , ω , c , t ) , z T ∼ N ( 0 , σ ~ 2 I ) ( 13 ) \frac{\mathrm{d} \boldsymbol{z}_t}{\mathrm{~d} t}=f(t) \boldsymbol{z}_t+\frac{g^2(t)}{2 \sigma_t} \tilde{\boldsymbol{\epsilon}}_\theta\left(\boldsymbol{z}_t, \omega, \boldsymbol{c}, t\right), \quad \boldsymbol{z}_T \sim \mathcal{N}\left(\mathbf{0}, \tilde{\sigma}^2 \boldsymbol{I}\right)\quad(13) dtdzt=f(t)zt+2σtg2(t)ϵ~θ(zt,ω,c,t),zT∼N(0,σ~2I)(13)
为了高效地执行一阶段引导蒸馏,我们引入了一个增强的一致性函数 f θ : ( z t , ω , c , t ) ↦ z 0 \boldsymbol{f}_\theta:\left(z_t, \omega, \boldsymbol{c}, t\right) \mapsto \boldsymbol{z}_0 fθ:(zt,ω,c,t)↦z0,直接预测增强的PF-ODE (Eq. 13)在 t = 0 t=0 t=0时的解。我们对 f θ f_{\boldsymbol{\theta}} fθ的参数化方式与Eq. 9相同,不同的是 ϵ ^ θ ( z , c , t ) \hat{\boldsymbol{\epsilon}}_\theta(\boldsymbol{z}, \boldsymbol{c}, t) ϵ^θ(z,c,t)被 ϵ ^ θ ( z , ω , c , t ) \hat{\boldsymbol{\epsilon}}_\theta(\boldsymbol{z}, \omega, \boldsymbol{c}, t) ϵ^θ(z,ω,c,t)替代,这是一个噪声预测模型,初始化时与教师扩散模型的参数相同,但也包含用于 ω \omega ω调节的额外可训练参数。一致性损失与Eq. 10相同,只不过我们使用增强的一致性函数 f θ ( z t , ω , c , t ) \boldsymbol{f}_{\boldsymbol{\theta}}\left(\boldsymbol{z}_{\boldsymbol{t}}, \omega, \boldsymbol{c}, t\right) fθ(zt,ω,c,t)。
L C D ( θ , θ − ; Ψ ) = E z , c , ω , n [ d ( f θ ( z t n + 1 , ω , c , t n + 1 ) , f θ − ( z ^ t n Ψ , ω , ω , c , t n ) ) ] ( 14 ) \mathcal{L}_{\mathcal{C D}}\left(\boldsymbol{\theta}, \boldsymbol{\theta}^{-} ; \Psi\right)=\mathbb{E}_{\boldsymbol{z}, c, \omega, n}\left[d\left(\boldsymbol{f}_{\boldsymbol{\theta}}\left(\boldsymbol{z}_{t_{n+1}}, \omega, \boldsymbol{c}, t_{n+1}\right), \boldsymbol{f}_{\boldsymbol{\theta}^{-}}\left(\hat{\boldsymbol{z}}_{t_n}^{\Psi, \omega}, \omega, \boldsymbol{c}, t_n\right)\right)\right]\quad(14) LCD(θ,θ−;Ψ)=Ez,c,ω,n[d(fθ(ztn+1,ω,c,tn+1),fθ−(z^tnΨ,ω,ω,c,tn))](14)
在Eq 14中, ω \omega ω和 n n n分别从区间 [ ω min , ω max ] \left[\omega_{\min }, \omega_{\max }\right] [ωmin,ωmax]和 { 1 , … , N − 1 } \{1, \ldots, N-1\} {1,…,N−1}中均匀采样。 z ^ t n Ψ , ω \hat{\boldsymbol{z}}_{t_n}^{\Psi, \omega} z^tnΨ,ω使用新的噪声模型 ϵ ~ θ ( z t , ω , c , t ) \tilde{\boldsymbol{\epsilon}}_\theta\left(\boldsymbol{z}_t, \omega, \boldsymbol{c}, t\right) ϵ~θ(zt,ω,c,t)进行估计,如下:
z ^ t n D , ω − z t n + 1 = ∫ t n + 1 t n ( f ( t ) z t + g 2 ( t ) 2 σ t ϵ ~ θ ( z t , ω , c , t ) ) d t = ( 1 + ω ) ∫ t n + 1 t n ( f ( t ) z t + g 2 ( t ) 2 σ t ϵ θ ( z t , c , t ) ) d t − ω ∫ t n + 1 t n ( f ( t ) z t + g 2 ( t ) 2 σ t ϵ θ ( z t , ∅ , t ) ) d t ≈ ( 1 + ω ) Ψ ( z t n + 1 , t n + 1 , t n , c ) − ω Ψ ( z t n + 1 , t n + 1 , t n , ∅ ) ( 15 ) \begin{aligned} \hat{\boldsymbol{z}}_{t_n}^{\mathbb{D}, \omega}-\boldsymbol{z}_{t_{n+1}} & =\int_{t_{n+1}}^{t_n}\left(f(t) \boldsymbol{z}_t+\frac{g^2(t)}{2 \sigma_t} \tilde{\boldsymbol{\epsilon}}_\theta\left(\boldsymbol{z}_t, \omega, \boldsymbol{c}, t\right)\right) \mathrm{d} t \\ & =(1+\omega) \int_{t_{n+1}}^{t_n}\left(f(t) \boldsymbol{z}_t+\frac{g^2(t)}{2 \sigma_t} \boldsymbol{\epsilon}_\theta\left(\boldsymbol{z}_t, \boldsymbol{c}, t\right)\right) \mathrm{d} t-\omega \int_{t_{n+1}}^{t_n}\left(f(t) \boldsymbol{z}_t+\frac{g^2(t)}{2 \sigma_t} \boldsymbol{\epsilon}_\theta\left(\boldsymbol{z}_t, \varnothing, t\right)\right) \mathrm{d} t \\ & \approx(1+\omega) \Psi\left(\boldsymbol{z}_{t_{n+1}}, t_{n+1}, t_n, \boldsymbol{c}\right)-\omega \Psi\left(\boldsymbol{z}_{t_{n+1}}, t_{n+1}, t_n, \varnothing\right) \end{aligned}\quad(15) z^tnD,ω−ztn+1=∫tn+1tn(f(t)zt+2σtg2(t)ϵ~θ(zt,ω,c,t))dt=(1+ω)∫tn+1tn(f(t)zt+2σtg2(t)ϵθ(zt,c,t))dt−ω∫tn+1tn(f(t)zt+2σtg2(t)ϵθ(zt,∅,t))dt≈(1+ω)Ψ(ztn+1,tn+1,tn,c)−ωΨ(ztn+1,tn+1,tn,∅)(15)
同样,我们可以使用DDIM (Song et al., 2020a)、DPM-Solver (Lu et al., 2022a)或DPM-Solver++ (Lu et al., 2022b)作为PF-ODE求解器 Ψ ( ⋅ , ⋅ , ⋅ , ⋅ ) \Psi(\cdot, \cdot, \cdot, \cdot) Ψ(⋅,⋅,⋅,⋅)。
4.3 ACCELERATING DISTILLATION WITH SKIPPING TIME STEPS
离散扩散模型 (Ho et al., 2020; Song & Ermon, 2019) 通常使用一个长时间步长计划 { t i } i \left\{t_i\right\}_i {ti}i(也称为离散计划或时间计划)来训练噪声预测模型,以实现高质量的生成结果。例如,Stable Diffusion (SD) 的时间计划长度为1000。然而,直接将潜在一致性蒸馏(Latent Consistency Distillation, LCD)应用于SD具有这样一个扩展的计划可能会有问题。模型需要在所有1000个时间步长之间进行采样,并且一致性损失试图将LCM模型 f θ ( z t n + 1 , c , t n + 1 ) \boldsymbol{f}_{\boldsymbol{\theta}}\left(\boldsymbol{z}_{t_{n+1}}, \boldsymbol{c}, t_{n+1}\right) fθ(ztn+1,c,tn+1)的预测与沿着相同轨迹的下一个时间步长的预测 f θ ( z t n , c , t n ) \boldsymbol{f}_{\boldsymbol{\theta}}\left(\boldsymbol{z}_{t_n}, \boldsymbol{c}, t_n\right) fθ(ztn,c,tn)对齐。由于 t n − t n + 1 t_n-t_{n+1} tn−tn+1非常小, z t n \boldsymbol{z}_{t_n} ztn和 z t n + 1 \boldsymbol{z}_{t_{n+1}} ztn+1(以及 f θ ( z t n + 1 , c , t n + 1 ) \boldsymbol{f}_{\boldsymbol{\theta}}\left(\boldsymbol{z}_{t_{n+1}}, \boldsymbol{c}, t_{n+1}\right) fθ(ztn+1,c,tn+1)和 f θ ( z t n , c , t n ) \boldsymbol{f}_{\boldsymbol{\theta}}\left(\boldsymbol{z}_{t_n}, \boldsymbol{c}, t_n\right) fθ(ztn,c,tn))已经非常接近,导致一致性损失较小,因此导致收敛缓慢。为了解决这些问题,我们引入了SKIPPING-STEP方法,大大缩短了时间计划的长度(从数千缩短到数十),以实现快速收敛,同时保持生成质量。
一致性模型 (CMs) (Song et al., 2023) 使用EDM (Karras et al., 2022) 连续时间计划,并使用Euler或Heun Solver作为数值连续PF-ODE求解器。对于LCMs,为了适应Stable Diffusion中的离散时间计划,我们使用DDIM (Song et al., 2020a)、DPM-Solver (Lu et al., 2022a) 或 DPM-Solver++ (Lu et al., 2022b) 作为ODE求解器。 (Lu et al., 2022a) 表明这些高级求解器可以高效地求解Eq. 8中的PF-ODE。现在,我们在潜在一致性蒸馏 (LCD) 中介绍SKIPPING-STEP方法。LCMs旨在确保当前时间步长与 k k k步远的时间步长之间的一致性, t n + k → t n t_{n+k} \rightarrow t_n tn+k→tn,而不是确保相邻时间步长 t n + 1 → t n t_{n+1} \rightarrow t_n tn+1→tn之间的一致性。注意,设置 k = 1 k=1 k=1会还原为(Song et al., 2023)中的原始计划,导致收敛缓慢,而非常大的 k k k可能会导致ODE求解器的大近似误差。在我们的主要实验中,我们将 k k k设置为20,大大将时间计划长度从数千减少到数十。Sec. 5.2中的结果展示了各种 k k k值的效果,并揭示了SKIPPING-STEP方法在加速LCD过程中的重要性。具体来说,Eq. 14中的一致性提炼损失被修改为确保从 t n + k t_{n+k} tn+k到 t n t_n tn的一致性:
L C D ( θ , θ − ; Ψ ) = E z , c , ω , n [ d ( f θ ( z t n + k , ω , c , t n + k ) , f θ − ( z ^ t n Ψ , ω , ω , c , t n ) ) ] ( 16 ) \mathcal{L}_{\mathcal{C D}}\left(\boldsymbol{\theta}, \boldsymbol{\theta}^{-} ; \Psi\right)=\mathbb{E}_{\boldsymbol{z}, \boldsymbol{c}, \omega, n}\left[d\left(\boldsymbol{f}_{\boldsymbol{\theta}}\left(\boldsymbol{z}_{t_{n+k}}, \omega, \boldsymbol{c}, t_{n+k}\right), \boldsymbol{f}_{\boldsymbol{\theta}-}\left(\hat{\boldsymbol{z}}_{t_n}^{\Psi, \omega}, \omega, \boldsymbol{c}, t_n\right)\right)\right]\quad(16) LCD(θ,θ−;Ψ)=Ez,c,ω,n[d(fθ(ztn+k,ω,c,tn+k),fθ−(z^tnΨ,ω,ω,c,tn))](16)
其中 z ^ t n Ψ , ω \hat{\boldsymbol{z}}_{t_n}^{\Psi, \omega} z^tnΨ,ω是使用数值增强PF-ODE求解器 Ψ \Psi Ψ估计的 z t n \boldsymbol{z}_{t_n} ztn:
z ^ t n Ψ , ω ⟵ z t n + k + ( 1 + ω ) Ψ ( z t n + k , t n + k , t n , c ) − ω Ψ ( z t n + k , t n + k , t n , ∅ ) ( 17 ) \hat{\boldsymbol{z}}_{t_n}^{\Psi, \omega} \longleftarrow \boldsymbol{z}_{t_{n+k}}+(1+\omega) \Psi\left(\boldsymbol{z}_{t_{n+k}}, t_{n+k}, t_n, \boldsymbol{c}\right)-\omega \Psi\left(z_{t_{n+k}}, t_{n+k}, t_n, \varnothing\right)\quad(17) z^tnΨ,ω⟵ztn+k+(1+ω)Ψ(ztn+k,tn+k,tn,c)−ωΨ(ztn+k,tn+k,tn,∅)(17)
上述推导类似于Eq. 15。对于LCM,我们在此使用三种可能的ODE求解器:DDIM (Song et al., 2020a)、DPM-Solver (Lu et al., 2022a)、DPM-Solver++ (Lu et al., 2022b),我们在Sec 5.2中比较了它们的性能。实际上,DDIM (Song et al., 2020a)是DPM-Solver的一级离散化近似(已在(Lu et al., 2022a)中证明)。在此我们提供从 t n + k t_{n+k} tn+k到 t n t_n tn的DDIM PF-ODE求解器 Ψ DIIM \Psi_{\text {DIIM }} ΨDIIM 的详细公式。其他两个求解器 Ψ DPM-Solver , Ψ DPM-Solver++ \Psi_{\text {DPM-Solver }}, \Psi_{\text {DPM-Solver++ }} ΨDPM-Solver ,ΨDPM-Solver++ 的公式在附录E中提供。
Ψ D D I M ( z t n + k , t n + k , t n , c ) = α t n α t n + k z t n + k − σ t n ( σ t n + k ⋅ α t n α t n + k ⋅ σ t n − 1 ) ϵ ^ θ ( z t n + k , c , t n + k ) ⏟ DDIM Estimated z t n − z t n + k ( 18 ) \Psi_{\mathrm{DDIM}}\left(\boldsymbol{z}_{t_{n+k}}, t_{n+k}, t_n, \boldsymbol{c}\right)=\underbrace{\frac{\alpha_{t_n}}{\alpha_{t_{n+k}}} \boldsymbol{z}_{t_{n+k}}-\sigma_{t_n}\left(\frac{\sigma_{t_{n+k}} \cdot \alpha_{t_n}}{\alpha_{t_{n+k}} \cdot \sigma_{t_n}}-1\right) \hat{\boldsymbol{\epsilon}}_\theta\left(\boldsymbol{z}_{t_{n+k}}, \boldsymbol{c}, t_{n+k}\right)}_{\text {DDIM Estimated } \boldsymbol{z}_{t_n}}-\boldsymbol{z}_{t_{n+k}}\quad(18) ΨDDIM(ztn+k,tn+k,tn,c)=DDIM Estimated ztn αtn+kαtnztn+k−σtn(αtn+k⋅σtnσtn+k⋅αtn−1)ϵ^θ(ztn+k,c,tn+k)−ztn+k(18)
我们在算法1中提供了具有CFG和SKIPPING-STEP技术的LCD伪代码。对(Song et al., 2023)中原始一致性蒸馏(CD)算法的修改部分以蓝色标出。此外,LCM采样算法3在附录B中提供。
4.4 LATENT CONSISTENCY FINE-TUNING FOR CUSTOMIZED DATASET
稳定扩散等基础生成模型在各种文本到图像生成任务中表现出色,但通常需要在定制数据集上进行微调以满足下游任务的要求。
我们提出了潜在一致性微调 (LCF),这是一种针对预训练 LCM 的微调方法。受一致性训练 (CT) (Song et al, 2023) 的启发,LCF 可以在不依赖于此类数据上训练的教师扩散模型的情况下对定制数据集进行高效的几步推理。这种方法为扩散模型的传统微调方法提供了一种可行的替代方案。LCF 的伪代码在算法 4 中提供,附录 C 中有更详细的说明。