2023-ICCV-Scalable Diffusion Models with Transformers
- 使用 Transformer 的可扩展扩散模型
- 摘要
- 1. 引言
- 2. 相关工作
- 3. 扩散 Transformer
- 3.1 准备工作
- 3.2 扩散 Transformer 设计空间
- 4. 实验设置
- 5. 实验
- 5.1 最先进的扩散模型
- 5.2 缩放模型与采样计算
- 6. 结论
- 参考文献
使用 Transformer 的可扩展扩散模型
作者:William Peebles, Saining Xie
单位:UC Berkeley, New York University
论文地址:2023-ICCV-Scalable Diffusion Models with Transformers
摘要
我们探索了一类基于 Transformer 架构的新型扩散模型。我们训练图像的潜在扩散模型,用在潜在 patchs 上运行的 Transformer 替换常用的 U-Net 主干。我们通过以 Gflops 为衡量标准的前向传递复杂度来分析我们的扩散 Transformer(DiT)
的可扩展性。我们发现,通过增加 Transformer 深度/宽度或增加输入 token 数量,具有更高 Gflops 的 DiT 始终具有较低的 FID。除了具有良好的可扩展性之外,我们最大的 DiT-XL/2 模型在类条件 ImageNet 512×512 和 256×256 基准上的表现优于所有先前的扩散模型,在后者上实现了最先进的 2.27 FID。
1. 引言
机器学习正在经历由 Transformer 推动的复兴。在过去五年中,自然语言处理 [42, 8]、视觉 [10] 和其他几个领域的神经架构已被 Transformer [60] 所取代。然而,许多类别的图像级生成模型仍然没有赶上这一趋势——虽然 Transformer 在自回归模型中得到了广泛的应用 [43, 3, 6, 47],但在其他生成建模框架中却很少采用。例如,扩散模型一直处于图像生成领域最新进展的前沿 [9, 46];然而,它们都采用卷积 U-Net 架构作为事实上的主干选择。
Ho 等人 [19] 的开创性工作首次为扩散模型引入了 U-Net 主干。U-Net 最初在像素级自回归模型和条件 GAN [23] 中取得成功,之后从 PixelCNN++ [52, 58] 继承而来,并进行了一些更改。该模型是卷积的,主要由 ResNet [15] 块组成。与标准 U-Net [49] 相比,Transformer 中必不可少的组件——额外的空间自注意力块以较低的分辨率散布。Dhariwal 和 Nichol [9] 放弃了 UNet 的几种架构选择,例如使用自适应归一化层 [40] 为卷积层注入条件信息和通道计数。然而,Ho 等人的 UNet 的高级设计基本保持不变。
通过这项工作,我们旨在揭开扩散模型中架构选择的重要性,并为未来的生成建模研究提供经验基线。我们表明,U-Net 归纳偏差对扩散模型的性能并不重要,它们可以很容易地用 Transformer 等标准设计取代。因此,扩散模型很可能从最近的架构统一趋势中受益——例如,通过继承其他领域的最佳实践和训练配方,以及保留可扩展性、鲁棒性和效率等有利特性。标准化架构还将为跨领域研究开辟新的可能性。在本文中,我们重点关注一类基于 Transformer 的新型扩散模型。我们称它们为扩散 Transformer,简称 DiT。DiT 遵循视觉 Transformer(ViT)[10] 的最佳实践,事实证明,与传统卷积网络(例如 ResNet [15])相比,ViT 在视觉识别方面的扩展效果更佳。
更具体地说,我们研究了 Transformer 相对于网络复杂度与样本质量的扩展行为。我们表明,通过在潜在扩散模型(LDM)[48] 框架下构建和基准测试 DiT 设计空间(其中扩散模型在 VAE 的潜在空间内训练),我们可以成功地用 Transformer 替换 U-Net 主干。我们进一步表明,DiT 是扩散模型的可扩展架构:网络复杂度(以 Gflops 衡量)与样本质量(以 FID 衡量)之间存在很强的相关性。通过简单地扩大 DiT 的规模并使用高容量主干(118.6 Gflops)训练 LDM,我们就能够在类条件 256×256 ImageNet 生成基准上实现 2.27 FID 的最佳结果。
代码和项目页面在这里提供 https://www.wpeebles.com/DiT.html。
2. 相关工作
Transformer。Transformer [60] 已取代语言、视觉 [10]、强化学习 [5, 25] 和元学习 [39] 等领域的特定领域架构。它们在增加模型大小、训练计算和语言领域数据 [26] 的情况下表现出了出色的扩展特性,可用作通用自回归模型 [17] 和 ViT [63]。除了语言之外,Transformer 还被训练用于自回归预测像素 [38, 7, 6]。它们还在离散码本 [59] 上被训练为自回归模型 [11, 47] 和掩码生成模型 [4, 14];前者在多达 20B 个参数的情况下表现出出色的扩展行为 [62]。最后,人们在 DDPM 中探索了使用 Transformer 来合成非空间数据;例如,在 DALL·E2 [46, 41] 中生成 CLIP 图像嵌入。在本文中,我们研究了 Transformer 作为图像扩散模型的骨干时的缩放特性。
去噪扩散概率模型(DDPM)。扩散 [54, 19] 和基于分数的生成模型 [22, 56] 作为图像生成模型特别成功 [35, 46, 50, 48],在许多情况下优于之前最先进的生成对抗网络(GAN)[12]。过去两年对 DDPM 的改进很大程度上得益于改进的采样技术 [19, 55, 27],最显著的是无分类器指导 [21]、重新制定扩散模型以预测噪声而不是像素 [19] 以及使用级联管道,其中低分辨率基础扩散模型与上采样器并行训练 [20, 9]。对于上面列出的所有扩散模型,卷积 U-Nets [49] 是事实上的骨干架构选择。并行工作 [24] 为 DDPM 引入了一种基于注意力机制的新型高效架构;我们探索纯 Transformer。
架构复杂性。在评估图像生成文献中的架构复杂性时,通常使用参数计数。通常,参数计数不能很好地代表图像模型的复杂性,因为它们没有考虑图像分辨率等因素,而图像分辨率会显著影响性能 [44, 45]。相反,本文的大部分分析都是通过计算的视角进行的。这使我们与架构设计文献保持一致,其中广泛使用触发器来衡量复杂性。在实践中,黄金指标将取决于特定的应用场景。Nichol 和 Dhariwal 改进扩散模型的开创性工作 [36, 9] 与我们最相关——在那里,他们分析了 U-Net 架构类的可扩展性。在本文中,我们重点关注 Transformer 类。
3. 扩散 Transformer
3.1 准备工作
扩散公式。在介绍我们的架构之前,我们先简要回顾一下理解扩散模型(DDPM)所需的一些基本概念 [54, 19]。高斯扩散模型假设一个前向噪声过程,该过程逐渐将噪声应用于真实数据 x 0 : q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) x_0:\ q\left(x_t|x_0\right)=\mathcal{N}\left(x_t;\ \sqrt{{\bar{\alpha}}_t}x_0,\ \left(1-{\bar{\alpha}}_t\right)\mathbf{I}\right) x0: q(xt∣x0)=N(xt; αˉtx0, (1−αˉt)I),其中常数 α ˉ t {\bar{\alpha}}_t αˉt 是超参数。通过应用重新参数化技巧,我们可以采样 x t = α ˉ t x 0 + 1 − α ˉ t ϵ t x_t=\sqrt{{\bar{\alpha}}_t}x_0+\sqrt{1-{\bar{\alpha}}_t}\epsilon_t xt=αˉtx0+1−αˉtϵt,其中 ϵ t ∼ N ( 0 , I ) \epsilon_t\sim\mathcal{N}\left(0,\ \mathbf{I}\right) ϵt∼N(0, I)。
扩散模型经过训练,可以学习逆转正向过程腐败的逆过程: p θ ( x t − 1 ∣ x t ) = N ( μ θ ( x t ) , ∑ θ ( x t ) ) p_\theta\left(x_{t-1}|x_t\right)=\ \mathcal{N}\left(\mu_\theta\left(x_t\right),\ \sum_{\theta}\left(x_t\right)\right) pθ(xt−1∣xt)= N(μθ(xt), ∑θ(xt)),其中神经网络用于预测 p θ p_\theta pθ 的统计数据。逆过程模型使用 x 0 x_0 x0 对数似然的变分下界 [30] 进行训练,简化为 L ( θ ) = − p ( x 0 ∣ x 1 ) + ∑ t D K L ( q ∗ ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p θ ( x t − 1 ∣ x t ) ) \mathcal{L}\left(\theta\right)=-p\left(x_0|x_1\right)+\sum_{t}\mathcal{D}_{KL}\left(q^\ast\left(x_{t-1}|x_t,\ x_0\right)||p_\theta\left(x_{t-1}|x_t\right)\right) L(θ)=−p(x0∣x1)+∑tDKL(q∗(xt−1∣xt, x0)∣∣pθ(xt−1∣xt)),不包括与训练无关的附加项。通过将 μ θ \mu_\theta μθ 重新参数化为噪声预测网络 ϵ θ \epsilon_\theta ϵθ,可以使用预测噪声 ϵ θ ( x t ) \epsilon_\theta\left(x_t\right) ϵθ(xt) 和真实采样高斯噪声 ϵ t \epsilon_t ϵt 之间的简单均方误差来训练模型: L s i m p l e ( θ ) = ∣ ∣ ϵ θ ( x t ) − ϵ t ∣ ∣ 2 2 \mathcal{L}_{simple}\left(\theta\right)={||\epsilon_\theta\left(x_t\right)-\epsilon_t||}_2^2 Lsimple(θ)=∣∣ϵθ(xt)−ϵt∣∣22。但是,为了训练具有学习到的逆过程协方差 Σ θ \Sigma_\theta Σθ 的扩散模型,需要优化完整的 D K L \mathcal{D}_{KL} DKL 项。我们遵循 Nichol 和 Dhariwal 的方法 [36]:用 L s i m p l e \mathcal{L}_{simple} Lsimple 训练 ϵ θ \epsilon_\theta ϵθ,用完整的 L \mathcal{L} L 训练 Σ θ \Sigma_\theta Σθ。一旦 p θ p_\theta pθ 训练完成,就可以通过初始化 x t m a x ∼ N ( 0 , I ) x_{t_{max}}\sim\mathcal{N}\left(0,\ \mathbf{I}\right) xtmax∼N(0, I) 并通过重新参数化技巧采样 x t − 1 p θ ( x t − 1 ∣ x t ) x_{t-1}~p_\theta\left(x_{t-1}|x_t\right) xt−1 pθ(xt−1∣xt) 来采样新图像。
无分类器指导。条件扩散模型将额外信息作为输入,例如类标签 c c c。在这种情况下,逆过程变为 p θ ( x t − 1 ∣ x t , c ) p_\theta\left(x_{t-1}|x_t,\ c\right) pθ(xt−1∣xt, c),其中 ϵ θ \epsilon_\theta ϵθ 和 Σ θ \Sigma_\theta Σθ 以 c c c 为条件。在这种情况下,无分类器指导可用于鼓励采样过程找到 x x x,使得 log p ( c ∣ x ) \log\ p\left(c|x\right) log p(c∣x) 较高 [21]。根据贝叶斯规则, log p ( c ∣ x ) ∝ log p ( x ∣ c ) − log p ( x ) \log_\ p\left(c|x\right)\propto\log\ p\left(x|c\right)-\log\ p\left(x\right) log p(c∣x)∝log p(x∣c)−log p(x),因此 ∇ x log p ( c ∣ x ) ∝ ∇ x log p ( x ∣ c ) − ∇ x log p ( x ) \nabla_x\log\ p\left(c|x\right)\propto\nabla_x\log\ p\left(x|c\right)-\nabla_x\log\ p\left(x\right) ∇xlog p(c∣x)∝∇xlog p(x∣c)−∇xlog p(x)。通过将扩散模型的输出解释为得分函数,可以引导 DDPM 采样程序通过以下方式对具有高 p ( x ∣ c ) p\left(x|c\right) p(x∣c) 的 x x x 进行采样: ϵ ^ θ ( x t , c ) = ϵ θ ( x t , ∅ ) + s ⋅ ∇ x log p ( x ∣ c ) ∝ ϵ θ ( x t , ∅ ) + s ⋅ ( ϵ θ ( x t , c ) − ϵ θ ( x t , ∅ ) ) {\hat{\epsilon}}_\theta\left(x_t,\ c\right)=\epsilon_\theta\left(x_t,\ \emptyset\right)+s\cdot\nabla_x\log\ p\left(x|c\right)\propto\epsilon_\theta\left(x_t,\ \emptyset\right)+s\cdot\left(\epsilon_\theta\left(x_t,\ c\right)-\epsilon_\theta\left(x_t,\ \emptyset\right)\right) ϵ^θ(xt, c)=ϵθ(xt, ∅)+s⋅∇xlog p(x∣c)∝ϵθ(xt, ∅)+s⋅(ϵθ(xt, c)−ϵθ(xt, ∅)),其中 s > 1 s>1 s>1 表示指导的规模(请注意, s = 1 s=1 s=1 恢复标准采样)。评估 c = ∅ c=\emptyset c=∅ 的扩散模型是通过在训练期间随机删除 c c c 并将其替换为学习到的 “null” 嵌入 ∅ \emptyset ∅ 来完成的。众所周知,无分类器指导可以比通用采样技术产生显着改进的样本 [21, 35, 46],并且这种趋势也适用于我们的 DiT 模型。
潜在扩散模型。直接在高分辨率像素空间中训练扩散模型在计算上是无法承受的。潜在扩散模型(LDM)[48] 采用两阶段方法解决这个问题:(1) 学习一个自动编码器,使用学习到的编码器 E E E 将图像压缩为更小的空间表示;(2) 训练表示 z = E ( x ) z=E\left(x\right) z=E(x) 的扩散模型,而不是图像 x x x 的扩散模型( E E E 是冻结的)。然后可以通过从扩散模型中采样表示 z z z 来生成新图像,然后使用学习到的解码器 x = D ( z ) x=D\left(z\right) x=D(z) 将其解码为图像
如图 2 所示,LDM 仅使用 ADM 等像素空间扩散模型的一小部分 Gflops 即可实现良好的性能。由于我们关心计算效率,因此这使它们成为架构探索的一个有吸引力的起点。在本文中,我们将 DiT 应用于潜在空间,尽管它们也可以不加修改地应用于像素空间。这使得我们的图像生成管道成为一种基于混合的方法;我们使用现成的卷积 VAE 和基于 Transformer 的 DDPM。
3.2 扩散 Transformer 设计空间
我们引入了扩散 Transformer(DiT),这是一种用于扩散模型的新架构。我们的目标是尽可能忠实于标准 Transformer 架构,以保留其缩放属性。由于我们的重点是训练图像的 DDPM(具体来说,是图像的空间表示),因此 DiT 基于对 patch 序列进行操作的视觉 Transformer(ViT)架构 [10]。DiT 保留了 ViT 的许多最佳实践。图 3 显示了完整 DiT 架构的概览。在本节中,我们将描述 DiT 的前向传递,以及 DiT 类设计空间的组成部分。
Patchify。DiT 的输入是空间表示 z z z(对于 256×256×3 图像, z z z 的形状为 32×32×4)。DiT 的第一层是 “patchify”,它通过将每个 patch 线性嵌入到输入中,将空间输入转换为 T T T 个标记序列,每个标记的维度为 d d d。在 patch 化之后,我们将标准 ViT 基于频率的位置嵌入(正弦-余弦版本)应用于所有输入标记。Patchify 创建的标记 T T T 的数量由 patch 大小超参数 p p p 决定。如图 4 所示,将 p p p 减半将使 T T T 增加四倍,从而使总 Transformer Gflops 至少增加四倍。虽然它对 Gflops 有显著影响,但请注意,更改 p p p 对下游参数数量没有任何重大影响。
我们将 p = 2 , 4 , 8 p=2,\ 4,\ 8 p=2, 4, 8 添加到 DiT 设计空间。
DiT 块设计。在 patchify 之后,输入标记由一系列 Transformer 块处理。除了噪声图像输入外,扩散模型有时还会处理其他条件信息,例如噪声时间步长 t t t、类标签 c c c、自然语言等。我们探索了四种以不同方式处理条件输入的 Transformer 块变体。这些设计对标准 ViT 块设计进行了细微但重要的修改。所有块的设计如图 3 所示。
- 上下文条件。我们只需将 t t t 和 c c c 的向量嵌入作为两个附加标记附加到输入序列中,将它们与图像标记处理在一起。这类似于 ViT 中的 cls 标记,它允许我们使用标准 ViT 块而无需修改。在最后一个块之后,我们从序列中删除条件标记。这种方法为模型引入了可忽略不计的新 Gflops。
- 交叉注意块。我们将 t t t 和 c c c 的嵌入连接成一个长度为 2 的序列,与图像标记序列分开。Transformer 块经过修改,在多头自注意块之后包含一个额外的多头交叉注意层,类似于 Vaswani 等人 [60] 的原始设计,也类似于 LDM 用于条件化类标签的设计。交叉注意为模型增加了最多的 Gflops,大约 15% 的开销。
- 自适应层范数(adaLN)块。随着自适应归一化层 [40] 在 GAN [2, 28] 和具有 UNet 主干 [9] 的扩散模型中的广泛使用,我们探索用自适应层范数(adaLN)替换 Transformer 块中的标准层范数层。我们不是直接学习维度尺度和移位参数 γ \gamma γ 和 β \beta β,而是从 t t t 和 c c c 的嵌入向量之和中回归它们。在我们探索的三种块设计中,adaLN 增加的 Gflops 最少,因此计算效率最高。它也是唯一被限制对所有代币应用相同功能的调节机制。
- adaLN-Zero 块。ResNets 的先前研究发现,将每个残差块初始化为恒等函数是有益的。例如,Goyal 等人发现,将每个块中的最终批量标准比例因子 γ \gamma γ 初始化为零可以加速监督学习设置中的大规模训练 [13]。扩散 U-Net 模型使用类似的初始化策略,在任何残差连接之前将每个块中的最终卷积层初始化为零。我们探索了 adaLN DiT 块的修改,它执行相同的操作。除了回归 γ \gamma γ 和 β \beta β 之外,我们还回归在 DiT 块内的任何残差连接之前立即应用的维度缩放参数 α \alpha α。我们初始化 MLP 以输出所有 α \alpha α 的零向量;这会将完整的 DiT 块初始化为恒等函数。与 vanilla adaLN 块一样,adaLNZero 为模型添加了可忽略不计的 Gflops。
我们在 DiT 设计空间中包含了上下文、交叉注意、自适应层规范和 adaLN-Zero 块。
模型大小。我们应用一系列 N N N DiT 块,每个块在隐藏维度大小 d d d 上运行。在 ViT 之后,我们使用标准 Transformer 配置来联合扩展 N N N、 d d d 和注意头 [10, 63]。具体来说,我们使用四种配置:DiT-S、DiT-B、DiT-L 和 DiT-XL。它们涵盖了广泛的模型大小和浮点数分配,从 0.3 到 118.6 Gflops,使我们能够衡量扩展性能。表 1 给出了配置的详细信息。
我们将 B、S、L 和 XL 配置添加到 DiT 设计空间。
Transformer 解码器。在最终的 DiT 块之后,我们需要将图像标记序列解码为输出噪声预测和输出对角协方差预测。这两个输出的形状都等于原始空间输入。我们使用标准线性解码器来执行此操作;我们应用最后一层范数(如果使用 adaLN,则为自适应的)并将每个标记线性解码为 p × p × 2 C p\times p\times2C p×p×2C 张量,其中 C C C 是 DiT 空间输入中的通道数。最后,我们将解码后的标记重新排列为其原始空间布局,以获得预测的噪声和协方差。
我们探索的完整 DiT 设计空间是 patch 大小、Transformer 块架构和模型大小。
4. 实验设置
我们探索 DiT 设计空间并研究模型类的缩放属性。我们的模型根据其配置和潜在 patch 大小 p p p 命名;例如,DiT-XL/2 指的是 XLarge 配置和 p = 2 p=2 p=2。
训练。我们在 ImageNet 数据集 [31] 上以 256×256 和 512×512 图像分辨率训练类条件潜在 DiT 模型,这是一个竞争激烈的生成建模基准。我们用零初始化最后的线性层,否则使用 ViT 的标准权重初始化技术。我们使用 AdamW [33, 29] 训练所有模型。我们使用 1 × 10 − 4 1\times{10}^{-4} 1×10−4 的恒定学习率,没有权重衰减和 256 的批量大小。我们使用的唯一数据增强是水平翻转。与之前对 ViT 的许多工作 [57, 61] 不同,我们发现学习率预热和正则化对于将 DiT 训练到高性能来说不是必要的。即使没有这些技术,训练在所有模型配置中也非常稳定,我们没有观察到训练 Transformer 时常见的任何损失峰值。按照生成建模文献中的常见做法,我们在训练过程中保持 DiT 权重的指数移动平均值(EMA),衰减率为 0.9999。报告的所有结果均使用 EMA 模型。我们在所有 DiT 模型大小和 patch 大小中使用相同的训练超参数。我们的训练超参数几乎完全保留自 ADM。我们没有调整学习率、decay/warm-up schedules、Adam β 1 / β 2 \beta_1/\beta_2 β1/β2 或权重衰减。
扩散。我们使用来自稳定扩散 [48] 的现成的预训练变分自动编码器(VAE)模型 [30]。VAE 编码器的下采样因子为 8 — 给定一个形状为 256×256×3 的 RGB 图像 x x x, z = E ( x ) z=E\left(x\right) z=E(x) 的形状为 32×32×4。在本节的所有实验中,我们的扩散模型都在此 Z \mathcal{Z} Z- 空间中运行。从我们的扩散模型中采样新的潜在值后,我们使用 VAE 解码器 x = D ( z ) x=D\left(z\right) x=D(z) 将其解码为像素。我们保留了来自 ADM [9] 的扩散超参数;具体来说,我们使用 t m a x = 1000 t_{max}=1000 tmax=1000 线性方差计划,范围从 1 × 10 − 4 1\times{10}^{-4} 1×10−4 到 2 × 10 − 2 2\times{10}^{-2} 2×10−2,ADM 的协方差参数化 Σ θ \Sigma_\theta Σθ 及其嵌入输入时间步长和标签的方法。
评估指标。我们使用 Frechet Inception Distance(FID)[18] 来测量缩放性能,这是评估图像生成模型的标准指标。
在与先前的研究进行比较时,我们遵循惯例,并使用 250 DDPM 采样步骤报告 FID-50K。众所周知,FID 对小的实施细节很敏感 [37];为了确保准确的比较,本文中报告的所有值都是通过导出样本并使用 ADM 的 TensorFlow 评估套件 [9] 获得的。除非另有说明,否则本节中报告的 FID 数字不使用无分类器指导。我们还报告了 Inception Score [51]、sFID [34] 和 Precision/Recall [32] 作为次要指标。
计算。我们在 JAX [1] 中实现所有模型,并使用 TPU-v3 pod 进行训练。DiT-XL/2 是我们计算最密集的模型,在全局批处理大小为 256 的 TPU v3-256 pod 上以大约 5.7 次迭代/秒的速度进行训练。
5. 实验
DiT 块设计。我们训练了四个最高 Gflop DiT-XL/2 模型,每个模型都使用不同的块设计 - 上下文(119.4 Gflops)、交叉注意(137.6 Gflops)、自适应层规范(adaLN,118.6 Gflops)或 adaLN-zero(118.6 Gflops)。我们在训练过程中测量 FID。图 5 显示了结果。adaLN-Zero 块产生的 FID 低于交叉注意和上下文条件,同时计算效率最高。在 400K 次训练迭代中,使用 adaLN-Zero 模型实现的 FID 几乎是上下文模型的一半,这表明条件机制对模型质量有重大影响。初始化也很重要 - adaLNZero 将每个 DiT 块初始化为身份函数,其性能明显优于 vanilla adaLN。对于本文的其余部分,所有模型都将使用 adaLN-Zero DiT 块。
缩放模型大小和 patch 大小。我们训练了 12 个 DiT 模型,涵盖了模型配置(S、B、L、XL)和 patch 大小(8、4、2)。请注意,DiT-L 和 DiT-XL 在相对 Gflops 方面比其他配置更接近。图 2(左)概述了每个模型的 Gflops 及其在 400K 次训练迭代中的 FID。在所有情况下,我们发现增加模型大小和减小 patch 大小可以大大改善扩散模型。
图 6(顶部)展示了 FID 如何随着模型大小的增加和 patch 大小的保持不变而变化。在所有四种配置中,通过使 Transformer 更深更宽,可以在训练的所有阶段获得 FID 的显着改进。同样,图 6(底部)显示了随着 patch 大小的减小和模型大小的保持不变,FID 的显着改进。我们再次观察到,通过简单地缩放 DiT 处理的标记数量,在整个训练过程中,FID 得到了显着的改进,参数保持大致固定。
DiT Gflops 对于提高性能至关重要。图 6 的结果表明,参数数量并不能唯一地决定 DiT 模型的质量。当模型大小保持不变并且 patch 大小减小时,Transformer 的总参数实际上保持不变(实际上,总参数略有减少),只有 Gflops 增加了。这些结果表明,扩展模型 Gflops 实际上是提高性能的关键。为了进一步研究这一点,我们在图 8 中绘制了 400K 训练步骤的 FID-50K 与模型 Gflops 的关系。结果表明,当不同的 DiT 配置的总 Gflops 相似时(例如,DiT-S/2 和 DiT-B/4),它们会获得相似的 FID 值。我们发现模型 Gflops 和 FID-50K 之间存在很强的负相关性,这表明额外的模型计算是改进 DiT 模型的关键因素。在图 12(附录)中,我们发现这种趋势也适用于其他指标,例如 Inception Score。
较大的 DiT 模型具有更高的计算效率。在图 9 中,我们将 FID 绘制为所有 DiT 模型的总训练计算的函数。我们将训练计算估计为模型 G f l o p s ⋅ b a t c h s i z e ⋅ t r a i n i n g s t e p ⋅ 3 {\rm Gflops}\cdot{\rm batch\ size}\cdot{\rm training\ step}\cdot{3} Gflops⋅batch size⋅training step⋅3,其中 3 的因子大致近似于向后传递的计算量是前向传递的两倍。我们发现,即使训练时间更长,小型 DiT 模型最终也会变得计算效率低下,而大型 DiT 模型的训练步骤更少。同样,我们发现,除了 patch 大小之外完全相同的模型即使在控制训练 Gflops 时也具有不同的性能配置文件。例如,在大约 1010 Gflops 之后,XL/4 的表现优于 XL/2。
可视化缩放。我们在图 7 中可视化了缩放对样本质量的影响。在 400K 训练步骤中,我们使用相同的起始噪声 x t m a x x_{t_{max}} xtmax、采样噪声和类标签从我们的 12 个 DiT 模型中的每一个中采样一个图像。这让我们可以直观地解释缩放如何影响 DiT 样本质量。事实上,扩大模型大小和标记数量可以显著提高视觉质量。
5.1 最先进的扩散模型
256×256 ImageNet。在进行缩放分析之后,我们继续训练最高 Gflop 模型 DiT-XL/2,进行 7M 步训练。我们在图 1 中展示了模型样本,并与最先进的类条件生成模型进行了比较。我们在表 2 中报告了结果。当使用无分类器指导时,DiT-XL/2 的表现优于所有先前的扩散模型,将 LDM 之前的最佳 FID-50K 从 3.60 降低到 2.27。图 2(右)显示,DiT-XL/2(118.6 Gflops)相对于潜在空间 U-Net 模型(如 LDM-4(103.6 Gflops))具有计算效率,并且比像素空间 U-Net 模型(如 ADM(1120 Gflops)或 ADM-U(742 Gflops))效率高得多。我们的方法实现了所有先前生成模型中最低的 FID,包括之前最先进的 StyleGANXL [53]。最后,我们还观察到,与 LDM-4 和 LDM-8 相比,DiT-XL/2 在所有测试的无分类器指导量表上实现了更高的召回率。当仅训练 2.35M 步(与 ADM 类似)时,XL/2 仍然优于所有先前的扩散模型,FID 为 2.55。
512×512 ImageNet。我们在 ImageNet 上以 512×512 分辨率训练新的 DiT-XL/2 模型,进行 3M 次迭代,超参数与 256×256 模型相同。在 patch 大小为 2 的情况下,此 XL/2 模型在对 64×64×4 输入潜在 (524.6 Gflops) 进行 patch 处理后,总共处理 1024 个标记。表 3 显示了与最先进方法的比较。XL/2 再次以这种分辨率超越所有先前的扩散模型,将 ADM 之前实现的最佳 FID 从 3.85 提高到 3.04。即使标记数量增加,XL/2 仍然具有计算效率。例如,ADM 使用 1983 Gflops,ADM-U 使用 2813 Gflops;XL/2 使用 524.6 Gflops。我们在图 1 和附录中展示了高分辨率 XL/2 模型的样本。
5.2 缩放模型与采样计算
通过在生成图像时增加采样步骤数,扩散模型可以在训练后使用额外的计算。考虑到模型 Gflops 对样本质量的影响,在本节中,我们将研究较小的 DiT 模型是否可以通过使用更多的采样计算胜过较大的模型。我们在 400K 次训练迭代中为所有 12 个 DiT 模型计算 FID,每张图像使用 [16, 32, 64, 128, 256, 1000] 采样步骤。结果如图 10 所示。考虑使用 1000 个采样步骤的 DiT-L/2 与使用 128 个步骤的 DiT-XL/2。在这种情况下,L/2 使用 80.7 Tflops 对每张图像进行采样;XL/2 使用少 5× 的计算(15.2 Tflops)对每张图像进行采样。尽管如此,XL/2 的 FID-10K 更好(23.7 vs 25.9)。一般来说,扩大采样计算无法弥补模型计算的不足。
6. 结论
我们引入了 Diffusion Transformers(DiTs),这是一种基于 Transformer 的简单扩散模型主干,其性能优于之前的 U-Net 模型,并继承了 Transformer 模型类的出色扩展属性。鉴于本文中令人鼓舞的扩展结果,未来的工作应继续将 DiT 扩展到更大的模型和 token 计数。DiT 还可以作为 DALL·E 2 和 Stable Diffusion 等文本到图像模型的嵌入式主干进行探索。
致谢。我们感谢何凯明、胡荣航、Alexander Berg、Shoubhik Debnath、Tim Brooks、Ilija Radosavovic 和 Tete Xiao 的有益讨论。William Peebles 是由 NSF GRFP 资助的。谢赛宁得到了谷歌 TRC 项目的支持和 Cirrascale 的贡献。
参考文献
[1] James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang. JAX: composable transformations of Python+NumPy programs, 2018. 6
[2] Andrew Brock, Jeff Donahue, and Karen Simonyan. Large scale GAN training for high fidelity natural image synthesis. In ICLR, 2019. 5, 9
[3] Tom B Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al Language models are few-shot learners. In NeurIPS, 2020. 1
[4] Huiwen Chang, Han Zhang, Lu Jiang, Ce Liu, and William T Freeman. Maskgit: Masked generative image transformer. In CVPR, pages 11315–11325, 2022. 2
[5] Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Misha Laskin, Pieter Abbeel, Aravind Srinivas, and Igor Mordatch. Decision transformer: Reinforcement learning via sequence modeling. In NeurIPS, 2021. 2
[6] Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, and Ilya Sutskever. Generative pretraining from pixels. In ICML, 2020. 1, 2
[7] Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509, 2019. 2
[8] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. In NAACL-HCT, 2019. 1
[9] Prafulla Dhariwal and Alexander Nichol. Diffusion models beat gans on image synthesis. In NeurIPS, 2021. 1, 2, 3, 5, 6, 9, 12
[10] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2020. 1, 2, 4, 5
[11] Patrick Esser, Robin Rombach, and Bjorn Ommer. Taming ¨ transformers for high-resolution image synthesis, 2020. 2
[12] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In NIPS, 2014. 3
[13] Priya Goyal, Piotr Dollar, Ross Girshick, Pieter Noord- ´ huis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch sgd: Training imagenet in 1 hour. arXiv:1706.02677, 2017. 5
[14] Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, and Baining Guo. Vector quantized diffusion model for text-to-image synthesis. In CVPR, pages 10696–10706, 2022. 2
[15] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, 2016. 2
[16] Dan Hendrycks and Kevin Gimpel. Gaussian error linear units (gelus). arXiv preprint arXiv:1606.08415, 2016. 12
[17] Tom Henighan, Jared Kaplan, Mor Katz, Mark Chen, Christopher Hesse, Jacob Jackson, Heewoo Jun, Tom B Brown, Prafulla Dhariwal, Scott Gray, et al Scaling laws for autoregressive generative modeling. arXiv preprint arXiv:2010.14701, 2020. 2
[18] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and Sepp Hochreiter. Gans trained by a two time-scale update rule converge to a local nash equilibrium. 2017. 6
[19] Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. In NeurIPS, 2020. 2, 3
[20] Jonathan Ho, Chitwan Saharia, William Chan, David J Fleet, Mohammad Norouzi, and Tim Salimans. Cascaded diffusion models for high fidelity image generation. arXiv:2106.15282, 2021. 3, 9
[21] Jonathan Ho and Tim Salimans. Classifier-free diffusion guidance. In NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications, 2021. 3, 4
[22] Aapo Hyvarinen and Peter Dayan. Estimation of non- ¨ normalized statistical models by score matching. Journal of Machine Learning Research, 6(4), 2005. 3
[23] Phillip Isola, Jun-Yan Zhu, Tinghui Zhou, and Alexei A Efros. Image-to-image translation with conditional adversarial networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 1125–1134, 2017. 2
[24] Allan Jabri, David Fleet, and Ting Chen. Scalable adaptive computation for iterative generation. arXiv preprint arXiv:2212.11972, 2022. 3
[25] Michael Janner, Qiyang Li, and Sergey Levine. Offline reinforcement learning as one big sequence modeling problem. In NeurIPS, 2021. 2
[26] Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. Scaling laws for neural language models. arXiv:2001.08361, 2020. 2, 13
[27] Tero Karras, Miika Aittala, Timo Aila, and Samuli Laine. Elucidating the design space of diffusion-based generative models. In Proc. NeurIPS, 2022. 3
[28] Tero Karras, Samuli Laine, and Timo Aila. A style-based generator architecture for generative adversarial networks. In CVPR, 2019. 5
[29] Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In ICLR, 2015. 5
[30] Diederik P Kingma and Max Welling. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013. 3, 6 [31] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. Imagenet classification with deep convolutional neural networks. In NeurIPS, 2012. 5
[32] Tuomas Kynka¨anniemi, Tero Karras, Samuli Laine, Jaakko ¨ Lehtinen, and Timo Aila. Improved precision and recall metric for assessing generative models. In NeurIPS, 2019. 6
[33] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. arXiv:1711.05101, 2017. 5
[34] Charlie Nash, Jacob Menick, Sander Dieleman, and Peter W Battaglia. Generating images with sparse representations. arXiv preprint arXiv:2103.03841, 2021. 6
[35] Alex Nichol, Prafulla Dhariwal, Aditya Ramesh, Pranav Shyam, Pamela Mishkin, Bob McGrew, Ilya Sutskever, and Mark Chen. Glide: Towards photorealistic image generation and editing with text-guided diffusion models. arXiv:2112.10741, 2021. 3, 4
[36] Alexander Quinn Nichol and Prafulla Dhariwal. Improved denoising diffusion probabilistic models. In ICML, 2021. 3
[37] Gaurav Parmar, Richard Zhang, and Jun-Yan Zhu. On aliased resizing and surprising subtleties in gan evaluation. In CVPR, 2022. 6
[38] Niki Parmar, Ashish Vaswani, Jakob Uszkoreit, Lukasz Kaiser, Noam Shazeer, Alexander Ku, and Dustin Tran. Image transformer. In International conference on machine learning, pages 4055–4064. PMLR, 2018. 2
[39] William Peebles, Ilija Radosavovic, Tim Brooks, Alexei Efros, and Jitendra Malik. Learning to learn with generative models of neural network checkpoints. arXiv preprint arXiv:2209.12892, 2022. 2
[40] Ethan Perez, Florian Strub, Harm De Vries, Vincent Dumoulin, and Aaron Courville. Film: Visual reasoning with a general conditioning layer. In AAAI, 2018. 2, 5
[41] Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, et al Learning transferable visual models from natural language supervision. In ICML, 2021. 2
[42] Alec Radford, Karthik Narasimhan, Tim Salimans, and Ilya Sutskever. Improving language understanding by generative pre-training. 2018. 1
[43] Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever, et al Language models are unsupervised multitask learners. 2019. 1
[44] Ilija Radosavovic, Justin Johnson, Saining Xie, Wan-Yen Lo, and Piotr Dollar. On network design spaces for visual recog- ´ nition. In ICCV, 2019. 3
[45] Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, and Piotr Dollar. Designing network design ´ spaces. In CVPR, 2020. 3
[46] Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, and Mark Chen. Hierarchical text-conditional image generation with clip latents. arXiv:2204.06125, 2022. 1, 2, 3, 4
[47] Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, and Ilya Sutskever. Zero-shot text-to-image generation. In ICML, 2021. 1, 2
[48] Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Bjorn Ommer. High-resolution image syn- ¨ thesis with latent diffusion models. In CVPR, 2022. 2, 3, 4, 6, 9
[49] Olaf Ronneberger, Philipp Fischer, and Thomas Brox. Unet: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention, pages 234–241. Springer, 2015. 2, 3
[50] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho, David J Fleet, and Mohammad Norouzi. Photorealistic text-toimage diffusion models with deep language understanding. arXiv:2205.11487, 2022. 3
[51] Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, Xi Chen, and Xi Chen. Improved techniques for training GANs. In NeurIPS, 2016. 6
[52] Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P Kingma. PixelCNN++: Improving the pixelcnn with discretized logistic mixture likelihood and other modifications. arXiv preprint arXiv:1701.05517, 2017. 2
[53] Axel Sauer, Katja Schwarz, and Andreas Geiger. Styleganxl: Scaling stylegan to large diverse datasets. In SIGGRAPH, 2022. 9
[54] Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. In ICML, 2015. 3
[55] Jiaming Song, Chenlin Meng, and Stefano Ermon. Denoising diffusion implicit models. arXiv:2010.02502, 2020. 3
[56] Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. In NeurIPS, 2019. 3
[57] Andreas Steiner, Alexander Kolesnikov, Xiaohua Zhai, Ross Wightman, Jakob Uszkoreit, and Lucas Beyer. How to train your ViT? data, augmentation, and regularization in vision transformers. TMLR, 2022. 6
[58] Aaron Van den Oord, Nal Kalchbrenner, Lasse Espeholt, Oriol Vinyals, Alex Graves, et al Conditional image generation with pixelcnn decoders. Advances in neural information processing systems, 29, 2016. 2
[59] Aaron Van Den Oord, Oriol Vinyals, et al Neural discrete representation learning. Advances in neural information processing systems, 30, 2017. 2
[60] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NeurIPS, 2017. 1, 2, 5
[61] Tete Xiao, Piotr Dollar, Mannat Singh, Eric Mintun, Trevor Darrell, and Ross Girshick. Early convolutions help transformers see better. In NeurIPS, 2021. 6
[62] Jiahui Yu, Yuanzhong Xu, Jing Yu Koh, Thang Luong, Gunjan Baid, Zirui Wang, Vijay Vasudevan, Alexander Ku, Yinfei Yang, Burcu Karagol Ayan, et al Scaling autoregressive models for content-rich text-to-image generation. arXiv:2206.10789, 2022. 2
[63] Xiaohua Zhai, Alexander Kolesnikov, Neil Houlsby, and Lucas Beyer. Scaling vision transformers. In CVPR, 2022. 2, 5