扩散模型对抗蒸馏:ADD 和 Latent-ADD
ADD(Adversarial Diffusion Distillation)和 Latent-ADD 是 StabilityAI 公司提出的一系列针对 Stable Diffusion 的扩散模型对抗蒸馏方法,通过对抗训练和蒸馏训练来提高扩散模型的采样速度,并保持较高的生图质量。
Adversarial Diffusion Distillation
Adversarial Diffusion Distillation,扩散模型对抗蒸馏,顾名思义,是结合对抗训练和蒸馏训练两种方式来对扩散模型的生图进行加速的方法。
整个 ADD 方法在训练过程中有三个模型参与:参数为 ψ \psi ψ 的教师扩散模型,参数为 ϕ \phi ϕ 的判别器模型,以及参数为 θ \theta θ 的学生模型,其中学生模型和判别器模型的(部分)参数是可训练的,教师模型的参数 ψ \psi ψ 则是固定的。训练过程概览如下图所示。教师模型采样时间步 { 1 , … , 1000 } \{1,\dots,1000\} {1,…,1000},学生模型采样时间步 { τ 1 , … , τ N } \{\tau_1,\dots,\tau_N\} {τ1,…,τN}, N N N 为学生模型的总步数。本文取 N = 4 N=4 N=4 即学生模型通过四步采样即可完成生图,并取 τ N = 1000 \tau_N=1000 τN=1000 从而保证初始步是零信噪比的纯噪声,与预训练时分布保持一致。
ADD 由对抗损失和蒸馏损失两部分驱动训练。首先从数据集中随机采样一张干净图片
x
0
x_0
x0 后对其进行加噪得到
x
s
=
α
s
x
0
+
σ
s
ϵ
x_s=\alpha_s x_0+\sigma_s\epsilon
xs=αsx0+σsϵ,然后送入到学生模型进行去噪,得到
x
^
θ
(
x
s
,
s
)
\hat{x}_\theta(x_s,s)
x^θ(xs,s),将
x
0
x_0
x0 和
x
s
x_s
xs 送入判别器,计算对抗损失
L
adv
G
\mathcal{L}^G_\text{adv}
LadvG。在按照教师模型的 schedule 对学生模型的输出
x
^
θ
\hat{x}_\theta
x^θ 进行加噪,得到
x
^
θ
,
t
=
α
t
x
^
θ
+
σ
t
ϵ
′
\hat{x}_{\theta,t}=\alpha_t\hat{x}_\theta+\sigma_t\epsilon'
x^θ,t=αtx^θ+σtϵ′,送入到教师模型进行去噪得到
x
^
ψ
(
x
^
θ
,
t
,
t
)
\hat{x}_\psi(\hat{x}_{\theta,t},t)
x^ψ(x^θ,t,t),对
x
^
θ
\hat{x}_\theta
x^θ 和
x
^
ψ
\hat{x}_\psi
x^ψ 计算蒸馏损失
L
distill
\mathcal{L}_\text{distill}
Ldistill。最终的总损失为:
L
=
L
adv
G
(
x
^
θ
(
x
s
,
s
)
,
ϕ
)
+
λ
L
distill
(
x
^
θ
(
x
s
,
s
)
,
ψ
)
\mathcal{L}=\mathcal{L}^G_{\text{adv}}(\hat{x}_\theta(x_s,s),\phi)+\lambda\mathcal{L}_\text{distill}(\hat{x}_\theta(x_s,s),\psi)
L=LadvG(x^θ(xs,s),ϕ)+λLdistill(x^θ(xs,s),ψ)
以下分别详细介绍对抗损失和蒸馏损失。
对抗损失
在对抗损失的计算中,本文参考 StyleGAN-T 的做法。使用一个预训练的网络 F F F(DINO/CLIP 等 ViT 模型,最后发现 DINO v2 效果较好),冻结其参数,作为特征提取模型,在其每一层接一个参数可训练的判别器头 D ϕ , k \mathcal{D}_{\phi,k} Dϕ,k。为了进一步提升性能,还可以通过 projection 引入额外的条件信息,一般文生图模型就是引入文本条件 embedding c text c_\text{text} ctext。本文方法还支持了引入图像条件 c img c_\text{img} cimg。
本文采用 hinge loss 作为目标函数。学生网络(生成器)的优化目标为:
L
adv
G
(
x
^
θ
(
x
s
,
s
)
,
ϕ
)
=
−
E
s
,
ϵ
,
x
0
[
∑
k
D
ϕ
,
k
(
F
k
(
x
^
θ
(
x
s
,
s
)
)
)
]
\mathcal{L}^G_\text{adv}(\hat{x}_\theta(x_s,s),\phi)=-\mathbb{E}_{s,\epsilon,x_0}[\sum_k\mathcal{D}_{\phi,k}(F_k(\hat{x}_\theta(x_s,s)))]
LadvG(x^θ(xs,s),ϕ)=−Es,ϵ,x0[k∑Dϕ,k(Fk(x^θ(xs,s)))]
判别器的优化目标为:
L
adv
D
(
x
^
θ
(
x
s
,
s
)
,
ϕ
)
=
E
x
0
[
∑
k
max
(
0
,
1
−
D
ϕ
,
k
(
F
k
(
x
0
)
)
)
+
γ
R
1
(
ϕ
)
]
+
E
x
^
θ
[
∑
k
max
(
0
,
1
+
D
ϕ
,
k
(
F
k
(
x
^
θ
)
)
)
]
\mathcal{L}_\text{adv}^D(\hat{x}_\theta(x_s,s),\phi)=\mathbb{E}_{x_0}[\sum_k\max(0,1-\mathcal{D}_{\phi,k}(F_k(x_0)))+\gamma R1(\phi)] +\mathbb{E}_{\hat{x}_\theta}[\sum_k\max(0,1+\mathcal{D}_{\phi,k}(F_k(\hat{x}_\theta)))]
LadvD(x^θ(xs,s),ϕ)=Ex0[k∑max(0,1−Dϕ,k(Fk(x0)))+γR1(ϕ)]+Ex^θ[k∑max(0,1+Dϕ,k(Fk(x^θ)))]
蒸馏损失
ADD 的蒸馏损失为:
L
distill
(
x
^
θ
(
x
s
,
s
)
,
ψ
)
=
E
t
,
ϵ
′
[
c
(
t
)
d
(
x
^
θ
,
x
^
ψ
(
sg
(
x
^
θ
,
t
)
;
t
)
)
]
\mathcal{L}_\text{distill}(\hat{x}_\theta(x_s,s),\psi)=\mathbb{E}_{t,\epsilon'}[c(t)d(\hat{x}_\theta,\hat{x}_\psi(\text{sg}(\hat{x}_{\theta,t});t))]
Ldistill(x^θ(xs,s),ψ)=Et,ϵ′[c(t)d(x^θ,x^ψ(sg(x^θ,t);t))]
就是用某种距离
d
d
d 去计算学生模型的去噪结果
x
^
θ
\hat{x}_\theta
x^θ 和教师模型的去噪结果
x
^
ψ
\hat{x}_\psi
x^ψ 之间的差异。ADD 中采用的距离度量方式为欧式距离,
c
(
t
)
c(t)
c(t) 是一个关于时间
t
t
t 的加权函数。
Latent Adversarial Diffusion Distillation
ADD 通过对抗和蒸馏训练取得了不错的加速效果,但仍存在几个问题。首先,ADD 使用的是固定参数的预训练 DINOv2 模型作为判别器的特征提取器,其输入分辨率是固定的;第二,没办法直接控制判别器反馈的层级,比如是全局特征还是局部特征;第三,由于判别器没有在隐层空间上训练过,因此蒸馏 LDM 时,必须要先解码回像素空间,这使得在高分辨率图像上的蒸馏训练成本极高。另外,GAN 的训练很不稳定且未观察到 LLM/DM 中出现的 scaling law。
本文中提出的 LADD,可以在超高分辨率图像上对预训练的 DiT 模型进行稳定的、可扩展的对抗蒸馏训练。在对抗训练中,LADD 没有采用常用的判别式特征(比如自监督训练的 DINOv2 特征),而是采用了教师扩散模型本身的生成式特征,这使得模型能够进行多长宽比图像的训练,并可以通过采样对应的噪声层级来提取更全局(噪声层级更高)或者更局部(噪声层级更低)的判别器特征。此外,LADD 不需要解码回像素空间,可以在隐层空间进行蒸馏训练,从而可以在更高分辨率的图像上训练更大的学生模型。最终,LADD 的训练流程比 ADD 要简化得多,并且达到了更好的效果。
下图对比了 ADD 和 LADD 的训练流程。LADD 对 ADD 的核心改进就是使用教师扩散模型同时作为判别器。
在 ADD 中,采用 Projected GAN 的范式,使用一个预训练的判别式模型(DINOv2)提取图像的特征,然后在不同层各自接判别头来对输入图像的真假进行判别。LADD 则是采用了生成式特征,将教师扩散模型同时作为特征提取器,提取教师模型去噪过程中各层的 token sequence 作为图像特征,再接判别头,并且这些判别头需要同时接收噪声水平和文本提示词作为条件。
使用生成式特征相比于判别式特征有四点优势:一是简单高效,LADD 的训练过程完全在隐层空间上,不需要再解码回像素空间,更高效节省训练资源;二是可以提取不同噪声层级对应的特征,通过调整加噪的输入层级,可以选择更全局或更细节的特征;三是支持不同长宽比的图像训练,因为本来教师模型就是在不同长宽比的图像上进行训练的;四是更接近人类偏好,判别式特征会更聚焦于纹理细节而非全局特征,而生成式特征则与人类偏好更相似,
在判别头的网络结构上,LADD 还是基本延续了 ADD 中的选择。但是没有选择了 1D 的卷积,而是将 token sequence reshape 回原图的空间排布后,再使用 2D 卷积进行处理,更好地支持多长宽比输入。
总结
ADD 和 LADD 是 SD 系列同步推出的模型加速方法,结合对抗训练和蒸馏训练,提高模型的生图速度的同时兼顾生成质量,分别训练出了 SDXL-Turbo、SD3-Turbo 等模型,在开源加速模型中有比较大的影响力。