扩散模型之DiT:纯Transformer架构 - 知乎扩散模型大部分是采用 UNet架构来进行建模,UNet可以实现输出和输入一样维度,所以天然适合扩散模型。扩散模型使用的UNet除了包含基于残差的卷积模块,同时也往往采用self-attention。自从ViT之后,transformer架…https://zhuanlan.zhihu.com/p/6410131571.introduction
基于transformer架构在扩散模型上的scalability的能力,展示了通过在VAE的潜空间中训练扩散模型的LDM的框架下构建和评估DiT,利用transformer替换UNet,其中最大的模型DiT-XL/2在ImageNet 256x256的类别条件生成达到了2.27FID。
2.Diffusion transformers
2.1 Preliminaries
在介绍DiT模型架构之前,我们先来看一下DiT所采用的扩散模型。 首先,DiT并没有采用常规的pixel diffusion,而是采用了latent diffusion架构,这也是Stable Diffusion所采用的架构。latent diffusion采用一个autoencoder来将图像压缩为低维度的latent,扩散模型用来生成latent,然后再采用autoencoder来重建出图像。DiT采用的autoencoder是SD所使用的KL-f8,对于256x256x3的图像,其压缩得到的latent大小为32x32x4,这就降低了扩散模型的计算量(后面我们会看到这将减少transformer的token数量)。另外,这里扩散过程的nosie scheduler采用简单的linear scheduler(timesteps=1000,beta_start=0.0001,beta_end=0.02),这个和SD是不同的。 其次,DiT所使用的扩散模型沿用了OpenAI的Improved DDPM,相比原始DDPM一个重要的变化是不再采用固定的方差,而是采用网络来预测方差。在DDPM中,生成过程的分布采用一个参数化的高斯分布来建模:
2.2 Diffusion Transformer design space
DiTs是新的扩散模型架构,重点是对图像的DDPM进行训练(图像的空间表示),DiT基于ViT架构。首先是一个patch embedding来将输入进行patch化,得到一系列的tokens,其中patch size属于一个超参数,直接决定了tokens的数量,影响模型的计算量。DiT的patch size有3种设置:2,4,8。token之后还要加上positional embeddings,采用非学习的sin-cos位置编码。将输入token化之后,可以像ViT一样接入transformer blocks,但是对于扩散模型,还需要在网络嵌入额外的条件信息,无论是timesteps还是类别标签,都可以采用一个embedding来进行编码,这和sd有所不同,DiT设计了四种方案来实现额外的另个embedding的嵌入:
1.In-context conditioning:将两个embeddings看成两个tokens合并在输入的tokens中,这种处理方式有点类似ViT中的cls token,实现起来比较简单,也不基本上不额外引入计算量。
2.Cross-attention block:将两个embeddings拼接成一个数量为2的序列,然后在transformer block中插入一个cross attention,条件embeddings作为cross attention的key和value;这种方式也是目前文生图模型所采用的方式,它需要额外引入15%的Gflops。
3.Adaptive layer norm(adaLN) block:采用adaLN,这里是将time embedding和class embedding相加,然后来回归scale和shift两个参数,这种方式也基本不增加计算量。
4.adaLN-Zero block:采用zero初始化的adaLN,这里是将adaLN的linear层参数初始化为zero,这样网络初始化时transformer block的残差模块就是一个identity函数;另外一点是,这里除了在LN之后回归scale和shift,还在每个残差模块结束之前回归一个scale。
上面四种嵌入,adaLN-Zero最好,DiT默认这种方式来嵌入条件embedding。DiT发现adaLN-Zero最好,但是这种方式只适合这种只有类别信息的简单条件嵌入,只需要引入一个class embedding,但对于文生图来说,条件往往是序列化的text embeddings,因此采用cross-attention通常是更合适的方式。
由于对输入进行了token化,所以在网络的最后还需要一个decoder来恢复输入的原始维度,DiT采用一个简单的linear层来实现,直接将每个token映射为pxpx2C的tensor,然后再进行reshape来得到和原始输入空间维度一样的输出,但是特征维度大小是原来的2倍,分别用来预测噪音和方差。
注意这里先进行LayerNorm,同时也引入了zero adaLN,并且decoder的linear层也采用zero初始化。 仿照ViT,DiT也设计了4种不同规模的模型,分别是DiT-S、DiT-B、DiT-L和DiT-XL,其中最大的模型DiT-XL参数量为675M,计算量Gflops为29.1(256x256图像,patch size=4时)。四个模型的具体配置如下所示:
论文重点探究了不同规模的DiT的性能,即模型的scalability能力,不同模型的性能对比如下所示:
在具体性能上,最大的模型DiT-XL/2采用classifier free guidance可以在class-conditional image generation on ImageNet 256×256任务上实现当时的sota。