系列文章目录
- 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
- 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
- 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
- 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
- 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
- 【扩散模型(六)】Stable Diffusion 3 diffusers 源码详解1-推理代码-文本处理部分
- 本系列将对比介绍 DiT 和 MMDiT 的区别和具体的代码实现,本文先介绍 DiT 的核心代码。
文章目录
- 系列文章目录
- 一、DiT
- DiT 整体代码
- DiT Block
- 这六个参数是否是相同的值?
- 代替 Cross-attention 的 adaLN-Zero Block
一、DiT
- DiT 1 是 SD3 中 MMDiT 的核心基础,而
- 通过将 Diffusion 中的 Unet 换成了 DiT Block,来实现基于条件的图像生成。
- 原文中的条件是类别标签,而非文本提示词。
- 原文测试了多种设置,最终采用了 adaLN-Zero 作为 Cross-Attention 的替代。
DiT 整体代码
官方代码仓库为 https://github.com/facebookresearch/DiT,下面代码的具体位置在 /path/to/DiT/models.py
- 下方代码为上图的左边部分,输入 x 是 Noised Latent,t 是 Timestep,Label 为 y
- 其中 block 则是上图中的 DiT Block,将 x 和 c 共同作为输入,以 c 为条件来生成 x (对 Noised Latent 进行去噪)。
def forward(self, x, t, y):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(t) # (N, D)
y = self.y_embedder(y, self.training) # (N, D)
c = t + y # (N, D)
for block in self.blocks:
x = block(x, c) # (N, T, D)
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
DiT Block
与下面代码中 forward 函数内对应的变量在 DiT Block 中的位置。
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class DiTBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
在这个 DiTBlock
类中,shift_msa
、scale_msa
、gate_msa
、shift_mlp
、scale_mlp
和 gate_mlp
是从 adaLN_modulation(c)
这一步中得到的,它们在具体功能上是有所区别的,虽然它们是通过同一个输入 c
生成的。
-
shift_msa
和scale_msa
这两个变量与Multi-Head Self-Attention (MSA)
模块的自适应层归一化(adaptive LayerNorm, adaLN)有关:shift_msa
: 这个变量用于平移LayerNorm
的输出,也就是在归一化的基础上加上一个偏置。它在调节 MSA 模块的激活输出时用作偏移量。scale_msa
: 这个变量用于缩放LayerNorm
的输出,即对归一化的结果乘以一个比例因子。它控制了 MSA 模块中激活的放大或缩小程度。
-
gate_msa
: 这个变量是作为一个门控(gate)信号,作用于 MSA 模块的输出上。它决定了 MSA 模块输出在累加到x
之前的权重。如果gate_msa
很小,那么这个输出会被抑制;如果gate_msa
接近1,则输出会如常累加。 -
shift_mlp
和scale_mlp
这两个变量与Pointwise Feedforward (MLP)
模块的自适应层归一化(adaLN)有关,类似于shift_msa
和scale_msa
,但它们作用在 MLP 模块上:shift_mlp
: 用于平移LayerNorm
的输出,在 MLP 模块中作为偏移量。scale_mlp
: 用于缩放LayerNorm
的输出,在 MLP 模块中控制激活的放大或缩小。
-
gate_mlp
: 类似于gate_msa
,但它控制的是 MLP 模块的输出。它决定了 MLP 模块输出在累加到x
之前的权重。
这六个参数是否是相同的值?
在 adaLN_modulation(c)
中,c 经过一个 nn.Linear 层(即 nn.Linear(hidden_size, 6 * hidden_size, bias=True)),然后被 chunk(6, dim=1) 分成六个部分,分别得到 shift_msa、scale_msa、gate_msa、shift_mlp、scale_mlp 和 gate_mlp。
虽然这些变量来自于同一个线性层的输出,但由于 nn.Linear 层的权重在训练过程中是可学习的,并且是随机初始化的,因此这些权重会在训练过程中被更新为不同的值。
代替 Cross-attention 的 adaLN-Zero Block
那么为什用 adaLN-Zero 来代替 Cross-Attention 呢?主要是因为计算资源。(DiT 原文提到 Cross-attention adds the most Gflops to the model, roughly a 15% overhead.)
-
什么是adaLN-Zero Block?
adaLN-Zero Block是一种改进版的adaLN(Adaptive Layer Normalization)模块,主要用于扩散模型(Diffusion Model)中。它的核心思想是通过初始化技巧和引入额外的缩放参数,来加速模型训练并提高生成样本的质量。 -
为什么引入adaLN-Zero Block?
- 加速训练: 通过将残差块初始化为恒等映射,模型在训练初期更容易收敛,从而加快训练速度。
- 提升性能: 引入维度缩放参数,使得模型能够学习到更具表达能力的特征表示,从而生成质量更高的样本。
- 增强稳定性: 恒等初始化有助于稳定模型的训练过程,尤其对于深层模型。
-
adaLN-Zero Block的工作原理
- 恒等初始化: 对于每个残差块的最后一个adaLN层,将缩放参数γ初始化为0。这使得该层在初始阶段相当于一个恒等映射,不会对输入数据进行缩放。
- 维度缩放参数α: 在残差连接之前,引入一个维度缩放参数α,用于对特征进行缩放。这个参数是可学习的,能够自适应地调整特征的尺度。
-
与传统adaLN的区别
- 初始化方式不同: adaLN-Zero对缩放参数γ进行了特殊的初始化,而传统的adaLN通常使用随机初始化。
- 参数数量增加: adaLN-Zero引入了额外的维度缩放参数α,增加了模型的参数数量。
-
为什么有效?
- 恒等初始化使得模型在训练初期能够快速学习到残差部分,从而加速训练过程。
- 维度缩放参数提供了更大的灵活性,使得模型能够更好地适应不同尺度的特征。
最后也附上原文,便于对照理解。
Peebles, William, and Saining Xie. “Scalable diffusion models with transformers.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023. ↩︎