Diffusion Models column index: Getting Started and Hands-On Practice
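Preface: Open-source DiT video generation models are still scarce. Open-Sora has the best developer ecosystem among them, and it covers the classic building blocks of diffusion-based video generation: DiT, spatial-temporal DiT, 3D VAE, Rectified Flow, causal convolution, and more. This post starts from the Open-Sora code and explains the principles behind it in depth.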
Contents
Key Improvements of DiT over UNet
Tokenization
Causal 3D Convolution
Adaptive Layer Norm (adaLN) block
Full DiT Block Design
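Key Improvements of DiT over UNet
Although the Transformer architecture has shown excellent scalability on many natural language processing and computer vision tasks, UNet is still the dominant backbone for diffusion models. Replacing UNet with a DiT architecture mainly requires answering the following key questions:
- Tokenization. A Transformer takes a 1D sequence of shape $\mathbb{R}^{T\times d}$ as input (ignoring the batch dimension), whereas the latent representation of an LDM, $z \in \mathbb{R}^{\frac{H}{f}\times\frac{W}{f}\times C}$, is a spatial tensor. A suitable tokenization method is therefore needed to map the spatial latent to a 1D sequence.
- Conditioning. A key reason Stable Diffusion became so popular is that it generates high-quality images from users' text prompts, which depends on injecting text features into the diffusion model so that they guide generation. In addition, every denoising step must incorporate a time embedding that tells the model the current timestep. Replacing UNet with a Transformer therefore requires a systematic study of how conditions are embedded in the Transformer.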
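To make the conditioning problem concrete, here is a minimal sketch of one common way to turn a scalar timestep into a vector: sinusoidal features followed by a small MLP, in the style of the `TimestepEmbedder` used by DiT. The names `timestep_embedding` and `SimpleTimestepEmbedder`, the 256-dim frequency size, and the random stand-in text embedding are assumptions for illustration only; the final sum `t + y` mirrors how Open-Sora's `DiT.forward` builds its condition vector.

```python
import math

import torch
import torch.nn as nn


def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    """Map a (B,) tensor of timesteps to (B, dim) sinusoidal features."""
    half = dim // 2
    # geometric sequence of frequencies from 1 down towards 1 / max_period
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None, :]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad one zero column when dim is odd
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb


class SimpleTimestepEmbedder(nn.Module):
    """Hypothetical helper: sinusoidal features followed by a 2-layer MLP."""

    def __init__(self, hidden_size: int, freq_dim: int = 256):
        super().__init__()
        self.freq_dim = freq_dim
        self.mlp = nn.Sequential(
            nn.Linear(freq_dim, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        return self.mlp(timestep_embedding(t, self.freq_dim))


t_emb = SimpleTimestepEmbedder(hidden_size=64)(torch.tensor([0, 500, 999]))
y_emb = torch.randn(3, 64)  # stand-in for a pooled text or class embedding
c = t_emb + y_emb           # one condition vector per sample
print(c.shape)              # torch.Size([3, 64])
```

Adding the two embeddings keeps the condition a single vector of the model width, which is what the adaLN layers discussed later consume.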
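Tokenization
Assume the original image is $x \in \mathbb{R}^{256\times256\times3}$; after the autoencoder it becomes a latent $z \in \mathbb{R}^{32\times32\times4}$. DiT first converts the latent $z$ into a token sequence with ViT-style patchification, then adds positional embeddings to the sequence. The figure shows the patchify process. The patch size $p$ is a hyperparameter.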
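The patchify step can be sketched in NumPy. The hypothetical `patchify_2d` helper below cuts the $32\times32\times4$ latent into non-overlapping $p\times p$ patches and flattens each patch into one token; in DiT each token is then linearly projected to the hidden size and given a positional embedding, which this sketch omits. The sequence length is $T = (32/p)^2$, so halving $p$ quadruples the number of tokens.

```python
import numpy as np


def patchify_2d(z: np.ndarray, p: int) -> np.ndarray:
    """Split an (H, W, C) latent into non-overlapping p x p patches.

    Returns a (T, p*p*C) token sequence with T = (H/p) * (W/p), row-major over patches.
    """
    H, W, C = z.shape
    assert H % p == 0 and W % p == 0, "H and W must be divisible by p"
    x = z.reshape(H // p, p, W // p, p, C)  # (h, p, w, p, C)
    x = x.transpose(0, 2, 1, 3, 4)          # (h, w, p, p, C)
    return x.reshape((H // p) * (W // p), p * p * C)


z = np.random.randn(32, 32, 4).astype(np.float32)  # latent of a 256x256x3 image
for p in (2, 4, 8):
    tokens = patchify_2d(z, p)
    print(p, tokens.shape)
# 2 (256, 16)
# 4 (64, 64)
# 8 (16, 256)
```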
That is how the original DiT paper describes tokenization. For video, Open-Sora performs tokenization with a PatchEmbed3D module:
```python
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed3D(nn.Module):
    """Video to Patch Embedding.

    Args:
        patch_size (tuple[int, int, int]): Patch token size. Default: (2,4,4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self,
        patch_size=(2, 4, 4),
        in_chans=3,
        embed_dim=96,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # kernel_size == stride: every non-overlapping 3D patch is projected to one vector
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding: right-pad W, H and T (D) up to multiples of the patch size
        _, _, D, H, W = x.size()
        if W % self.patch_size[2] != 0:
            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        if H % self.patch_size[1] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        if D % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))

        x = self.proj(x)  # (B C T H W)
        if self.norm is not None:
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
        return x
```
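PatchEmbed3D first right-pads the time, height, and width dimensions of the video up to multiples of the corresponding patch size (with the default (2, 4, 4): time to a multiple of 2, height and width to multiples of 4). A 3D convolution whose kernel size equals its stride then compresses time and space together: every non-overlapping patch becomes one vector, and the channel dimension grows to embed_dim (3 → 96 with the class defaults; in Open-Sora's DiT, from the 4 latent channels to hidden_size = 1152). Finally, the temporal and spatial dimensions are flattened into a single token dimension N:

```python
x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
```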
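As a shape check, the sketch below reproduces the two essential steps of `PatchEmbed3D` with a plain `nn.Conv3d`, using the default settings of Open-Sora's DiT (`in_channels=4`, `patch_size=(1, 2, 2)`, `hidden_size=1152`, latent `input_size=(16, 32, 32)`). `pad_to_multiple` is a hypothetical helper that performs the same right-padding as the `if` checks in `forward`, written compactly with `(-n) % p`. The resulting token count, 16 × 16 × 16 = 4096, is what `DiT.__init__` computes as `num_patches`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, patch_size) -> torch.Tensor:
    """Right-pad T, H, W of a (B, C, T, H, W) tensor to multiples of patch_size."""
    pt, ph, pw = patch_size
    _, _, T, H, W = x.shape
    # F.pad pairs start from the last dim: (W_left, W_right, H_left, H_right, T_left, T_right)
    return F.pad(x, (0, (-W) % pw, 0, (-H) % ph, 0, (-T) % pt))


patch_size = (1, 2, 2)  # Open-Sora DiT default
proj = nn.Conv3d(4, 1152, kernel_size=patch_size, stride=patch_size)

latent = torch.randn(1, 4, 16, 32, 32)         # (B, C, T, H, W)
x = proj(pad_to_multiple(latent, patch_size))  # (1, 1152, 16, 16, 16)
tokens = x.flatten(2).transpose(1, 2)          # (B, N, D)
print(tokens.shape)  # torch.Size([1, 4096, 1152])
```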
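Causal 3D Convolution
The tokenizer above uses an ordinary 3D convolution. Some other codebases use causal 3D convolution instead, which is very common in video tasks.
A causal 3D convolution is a 3D convolution that preserves causality along the time dimension of data such as video: the output at a given time step depends only on the current and earlier time steps, never on later ones. As the kernel slides along the time axis, it only sees the current and past frames. This property matters in sequence modeling and time-series prediction, where the model's outputs must respect causal order.
Compared with an ordinary 3D convolution, a causal 3D convolution pads the time dimension so that, with a temporal stride of 1, the output has the same number of frames as the input. All of this padding, $(k_t - 1)\cdot d_t$ frames for temporal kernel size $k_t$ and temporal dilation $d_t$, is added at the start of the time axis (the past side) rather than split between both ends, so computing the current frame never uses information from later frames. The padded frames can be zeros; EasyAnimate instead repeats the first frame (`mode="replicate"`).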
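The padding rule can be checked directly. Below is a hypothetical, stripped-down causal 3D convolution (not EasyAnimate's class) that prepends $(k_t - 1)\cdot d_t$ replicated frames and uses ordinary symmetric padding in space. Perturbing only the last two frames leaves all earlier outputs unchanged, which is exactly the causality property described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MinimalCausalConv3d(nn.Module):
    """Hypothetical minimal causal 3D conv: all temporal padding is prepended (past side)."""

    def __init__(self, in_ch: int, out_ch: int, kernel_size=(3, 3, 3), t_dilation: int = 1):
        super().__init__()
        kt, kh, kw = kernel_size
        self.t_pad = (kt - 1) * t_dilation  # frames to prepend so that T_out == T_in (stride 1)
        self.conv = nn.Conv3d(
            in_ch,
            out_ch,
            kernel_size,
            padding=(0, kh // 2, kw // 2),  # symmetric padding in H and W only
            dilation=(t_dilation, 1, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W); F.pad order is (W_l, W_r, H_l, H_r, T_front, T_back)
        x = F.pad(x, (0, 0, 0, 0, self.t_pad, 0), mode="replicate")
        return self.conv(x)


torch.manual_seed(0)
conv = MinimalCausalConv3d(2, 3)
x = torch.randn(1, 2, 6, 5, 5)
y = conv(x)

x_future = x.clone()
x_future[:, :, 4:] += 1.0  # change only frames 4 and 5
y_future = conv(x_future)

print(y.shape)                                          # torch.Size([1, 3, 6, 5, 5])
print(torch.allclose(y[:, :, :4], y_future[:, :, :4]))  # True: earlier outputs are unaffected
```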
Below is EasyAnimate's implementation:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalConv3d(nn.Conv3d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,  # : int | tuple[int, int, int],
        stride=1,  # : int | tuple[int, int, int] = 1,
        padding=1,  # : int | tuple[int, int, int],  # TODO: change it to 0.
        dilation=1,  # : int | tuple[int, int, int] = 1,
        **kwargs,
    ):
        kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
        assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."

        stride = stride if isinstance(stride, tuple) else (stride,) * 3
        assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."

        dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
        assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."

        t_ks, h_ks, w_ks = kernel_size
        t_stride, h_stride, w_stride = stride
        t_dilation, h_dilation, w_dilation = dilation

        # causal padding: (t_ks - 1) * t_dilation frames, all prepended in forward()
        t_pad = (t_ks - 1) * t_dilation
        # TODO: align with SD
        if padding is None:
            h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
            w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
        elif isinstance(padding, int):
            h_pad = w_pad = padding
        else:
            raise NotImplementedError("padding must be None or an int")

        self.temporal_padding = t_pad
        # symmetric ("same"-style) temporal padding, computed from the temporal dilation/stride
        self.temporal_padding_origin = math.ceil(((t_ks - 1) * t_dilation + (1 - t_stride)) / 2)
        self.padding_flag = 0

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=(0, h_pad, w_pad),  # no built-in temporal padding; forward() pads time manually
            **kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        if self.padding_flag == 0:
            # causal mode: prepend t_pad copies of the first frame, nothing at the end
            x = F.pad(
                x,
                pad=(0, 0, 0, 0, self.temporal_padding, 0),
                mode="replicate",  # TODO: check if this is necessary
            )
        else:
            # padding_flag 1 or 2: symmetric zero padding on both ends (not causal)
            x = F.pad(
                x,
                pad=(0, 0, 0, 0, self.temporal_padding_origin, self.temporal_padding_origin),
            )
        return super().forward(x)

    def set_padding_one_frame(self):
        # recursively switch every submodule that has a padding_flag to mode 1
        def _set_padding_one_frame(name, module):
            if hasattr(module, 'padding_flag'):
                print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 1
            for sub_name, sub_mod in module.named_children():
                _set_padding_one_frame(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_padding_one_frame(name, module)

    def set_padding_more_frame(self):
        # recursively switch every submodule that has a padding_flag to mode 2
        def _set_padding_more_frame(name, module):
            if hasattr(module, 'padding_flag'):
                print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = 2
            for sub_name, sub_mod in module.named_children():
                _set_padding_more_frame(sub_name, sub_mod)

        for name, module in self.named_children():
            _set_padding_more_frame(name, module)
```
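For the default `padding_flag == 0` path, the temporal output length follows from the standard convolution length formula: with $(k-1)\cdot d$ frames prepended and none appended, $T_{out} = \lfloor (T-1)/s \rfloor + 1$ for temporal stride $s$. With stride 1 this is simply $T$; with stride 2, a clip of $1 + 2m$ frames maps to $1 + m$ frames. The sketch below checks the formula against an actual `nn.Conv3d`; `causal_out_len` and `run_causal_conv` are hypothetical helpers, and the kernel is $k\times1\times1$ so that only the time axis matters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def causal_out_len(T: int, k: int, s: int = 1, d: int = 1) -> int:
    """Temporal output length when (k - 1) * d frames are prepended and none appended."""
    padded = T + (k - 1) * d
    # standard conv length formula; simplifies to (T - 1) // s + 1
    return (padded - d * (k - 1) - 1) // s + 1


def run_causal_conv(T: int, k: int, s: int, d: int = 1) -> torch.Tensor:
    """Apply the padding_flag == 0 style padding, then a k x 1 x 1 temporal convolution."""
    conv = nn.Conv3d(1, 1, kernel_size=(k, 1, 1), stride=(s, 1, 1), dilation=(d, 1, 1))
    x = torch.randn(1, 1, T, 2, 2)
    x = F.pad(x, (0, 0, 0, 0, (k - 1) * d, 0), mode="replicate")
    return conv(x)


for T, k, s in [(9, 3, 1), (9, 3, 2), (17, 3, 2), (16, 3, 2)]:
    y = run_causal_conv(T, k, s)
    print(T, k, s, y.shape[2], causal_out_len(T, k, s))
```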
Adaptive Layer Norm (adaLN) block
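This is one of the most important designs in DiT: the adaptive normalization layer (adaLN), which replaces the standard layer norm in each Transformer block. Put simply, the scale parameter $\gamma$ and shift parameter $\beta$ that layer norm uses for its affine transform are no longer fixed learned vectors; instead they are regressed from the condition embedding (in the code below, the sum of the timestep and caption/class embeddings), so every sample gets its own $\gamma$ and $\beta$. The variant DiT actually uses, adaLN-Zero, additionally regresses a gate $\alpha$ that scales each residual branch, and initializes the regression layer to zero so that every block starts out as the identity function.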
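A minimal PyTorch sketch of the idea: a LayerNorm with no learnable affine parameters, followed by a per-sample shift and scale regressed from the condition vector `c` by `SiLU` + `Linear`. The `modulate` function here is a simplified stand-in (it takes the already-normalized tensor) that uses the `x * (1 + scale) + shift` form, so a zero scale and shift leave the normalized input unchanged; all sizes are arbitrary.

```python
import torch
import torch.nn as nn


def modulate(x, shift, scale):
    """adaLN: apply a per-sample shift/scale to normalized tokens.

    x: (B, N, D) tokens; shift, scale: (B, D), broadcast over the N tokens.
    """
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


B, N, D = 2, 10, 32
norm = nn.LayerNorm(D, elementwise_affine=False, eps=1e-6)  # no learned gamma/beta
ada = nn.Sequential(nn.SiLU(), nn.Linear(D, 2 * D))         # regresses (shift, scale) from c

x = torch.randn(B, N, D)
c = torch.randn(B, D)  # condition embedding, e.g. timestep + text
shift, scale = ada(c).chunk(2, dim=1)
out = modulate(norm(x), shift, scale)
print(out.shape)  # torch.Size([2, 10, 32])
```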
The standard LayerNorm (NumPy):
```python
import numpy as np


class LayerNorm:
    def __init__(self, feature_dim, epsilon=1e-6):
        self.epsilon = epsilon
        self.gamma = np.ones(feature_dim)  # scale parameters (learned, initialized to 1)
        self.beta = np.zeros(feature_dim)  # shift parameters (learned, initialized to 0)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Args:
            x (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
        return:
            x_layer_norm (np.ndarray): shape: (batch_size, sequence_length, feature_dim)
        """
        _mean = np.mean(x, axis=-1, keepdims=True)
        _var = np.var(x, axis=-1, keepdims=True)
        # normalize over the feature dimension, then apply the affine transform
        x_layer_norm = self.gamma * (x - _mean) / np.sqrt(_var + self.epsilon) + self.beta
        return x_layer_norm
```
The adaLN design in DiT (Open-Sora's DiTBlock):
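```python
import torch.nn as nn

# get_layernorm, Attention, Mlp, approx_gelu and modulate are Open-Sora helpers (not shown here).
# modulate(norm, x, shift, scale) normalizes x and then applies x * (1 + scale) + shift,
# broadcasting the per-sample (B, D) shift/scale over all N tokens.


class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        enable_flash_attn=False,
        enable_layernorm_kernel=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.enable_flash_attn = enable_flash_attn
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        # LayerNorms without learned affine parameters: scale/shift come from adaLN instead
        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flash_attn=enable_flash_attn,
        )
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        # regresses 6 * hidden_size values from the condition: (shift, scale, gate) for attention and for the MLP
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x, c):
        # x: (B, N, D) tokens; c: (B, D) condition embedding (timestep + caption/class)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        # gated residual branches; the gates start at zero because DiT zero-initializes adaLN_modulation
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1, x, shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2, x, shift_mlp, scale_mlp))
        return x
```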
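The adaLN-Zero initialization can be checked with a reduced block. `TinyAdaLNZeroBlock` is a hypothetical stand-in that keeps the wiring of `DiTBlock` but swaps Open-Sora's `Attention` and `Mlp` for `nn.MultiheadAttention` and a two-layer `nn.Sequential`. Because the last `Linear` of `adaLN_modulation` starts at zero (as `DiT.initialize_weights` does for every block), both gates are exactly zero and the block returns its input unchanged.

```python
import torch
import torch.nn as nn


class TinyAdaLNZeroBlock(nn.Module):
    """Hypothetical reduced DiT block: same adaLN-Zero wiring, torch built-ins only."""

    def __init__(self, d: int, heads: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(d, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(d, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(approximate="tanh"), nn.Linear(4 * d, d))
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(d, 6 * d))
        # adaLN-Zero: the regression layer starts at zero, so every shift/scale/gate is 0
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        h = self.norm1(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        x = x + gate_msa.unsqueeze(1) * self.attn(h, h, h, need_weights=False)[0]
        h = self.norm2(x) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(h)
        return x


block = TinyAdaLNZeroBlock(d=32)
x, c = torch.randn(2, 10, 32), torch.randn(2, 32)
print(torch.equal(block(x, c), x))  # True: the freshly initialized block is the identity
```

Starting every residual branch at zero means a deep stack of such blocks initially passes tokens through untouched, and each block only departs from the identity as the gates are learned.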
Full DiT Block Design
With that, the core DiT block is in place. Stacking 28 of these DiTBlocks (depth=28) gives the complete DiT model:
```python
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

# MODELS, PatchEmbed3D, DiTBlock, LabelEmbedder, CaptionEmbedder, TimestepEmbedder, FinalLayer,
# approx_gelu, auto_grad_checkpoint, get_1d_sincos_pos_embed and get_2d_sincos_pos_embed
# come from elsewhere in the Open-Sora codebase.


@MODELS.register_module()
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        input_size=(16, 32, 32),
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        learn_sigma=True,
        condition="text",
        no_temporal_pos_emb=False,
        caption_channels=512,
        model_max_length=77,
        dtype=torch.float32,
        enable_flash_attn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.input_size = input_size
        # total tokens (T/pt) * (H/ph) * (W/pw), split into temporal x spatial counts
        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
        self.num_patches = num_patches
        self.num_temporal = input_size[0] // patch_size[0]
        self.num_spatial = num_patches // self.num_temporal
        self.num_heads = num_heads
        self.dtype = dtype
        self.use_text_encoder = not condition.startswith("label")
        if enable_flash_attn:
            assert dtype in [
                torch.float16,
                torch.bfloat16,
            ], f"Flash attention only supports float16 and bfloat16, but got {self.dtype}"
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.mlp_ratio = mlp_ratio
        self.depth = depth
        assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in DiT"

        # fixed sin-cos positional embeddings, factorized into a spatial and a temporal part
        self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed())
        self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())

        # embedders for the video latent, the condition (class label or caption) and the timestep
        self.x_embedder = PatchEmbed3D(patch_size, in_channels, embed_dim=hidden_size)
        if not self.use_text_encoder:
            num_classes = int(condition.split("_")[-1])
            self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        else:
            self.y_embedder = CaptionEmbedder(
                in_channels=caption_channels,
                hidden_size=hidden_size,
                uncond_prob=class_dropout_prob,
                act_layer=approx_gelu,
                token_num=1,  # pooled token
            )
        self.t_embedder = TimestepEmbedder(hidden_size)

        # stack of `depth` DiT blocks (28 by default)
        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    enable_flash_attn=enable_flash_attn,
                    enable_layernorm_kernel=enable_layernorm_kernel,
                )
                for _ in range(depth)
            ]
        )
        self.final_layer = FinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
        self.initialize_weights()

        self.enable_flash_attn = enable_flash_attn
        self.enable_layernorm_kernel = enable_layernorm_kernel

    def get_spatial_pos_embed(self):
        # 2D sin-cos embedding over the (H/ph) x (W/pw) grid (assumes a square grid)
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[1] // self.patch_size[1],
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def get_temporal_pos_embed(self):
        # 1D sin-cos embedding over the T/pt temporal positions
        pos_embed = get_1d_sincos_pos_embed(
            self.hidden_size,
            self.input_size[0] // self.patch_size[0],
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def unpatchify(self, x):
        # (B, N, pt*ph*pw*c) tokens -> (B, c, T, H, W) video
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size

        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def forward(self, x, t, y):
        """
        Forward pass of DiT.
        x: (B, C, T, H, W) tensor of inputs
        t: (B,) tensor of diffusion timesteps
        y: condition, i.e. caption embeddings from the text encoder or class labels
        """
        # origin inputs should be float32, cast to specified dtype
        x = x.to(self.dtype)
        if self.use_text_encoder:
            y = y.to(self.dtype)

        # embedding
        x = self.x_embedder(x)  # (B, N, D)
        x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
        x = x + self.pos_embed_spatial  # same spatial embedding for every frame
        if not self.no_temporal_pos_emb:
            x = rearrange(x, "b t s d -> b s t d")
            x = x + self.pos_embed_temporal  # same temporal embedding for every spatial position
            x = rearrange(x, "b s t d -> b (t s) d")
        else:
            x = rearrange(x, "b t s d -> b (t s) d")

        t = self.t_embedder(t, dtype=x.dtype)  # (B, D)
        y = self.y_embedder(y, self.training)
        if self.use_text_encoder:
            y = y.squeeze(1).squeeze(1)  # drop the singleton token dims -> (B, D)
        condition = t + y  # one (B, D) condition vector shared by all blocks

        # blocks
        for block in self.blocks:
            x = auto_grad_checkpoint(block, x, condition)  # (B, N, D)

        # final process
        x = self.final_layer(x, condition)  # (B, N, pt * ph * pw * out_channels)
        x = self.unpatchify(x)  # (B, out_channels, T, H, W)

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                if module.weight.requires_grad:
                    torch.nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv3d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks (adaLN-Zero):
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

        # Initialize the caption projection MLP:
        if self.use_text_encoder:
            nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
            nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
```
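`unpatchify` is the inverse of tokenization: the final layer emits `pt*ph*pw*out_channels` values per token (1·2·2·8 = 32 with the defaults, since `learn_sigma=True` doubles the 4 latent channels), and these are folded back into a video tensor. The sketch below re-implements the same reshape with `permute` in place of `einops.rearrange`, pairs it with a hypothetical inverse `patchify`, and checks the round trip. The token order (t, h, w) matches the `flatten(2)` in `PatchEmbed3D`; the per-token layout (pt, ph, pw, c) is what `unpatchify` assumes of the final layer's output.

```python
import torch


def patchify(video: torch.Tensor, patch_size) -> torch.Tensor:
    """Hypothetical inverse helper: (B, C, T, H, W) -> (B, N, pt*ph*pw*C), tokens in (t, h, w) order."""
    B, C, T, H, W = video.shape
    pt, ph, pw = patch_size
    t, h, w = T // pt, H // ph, W // pw
    x = video.reshape(B, C, t, pt, h, ph, w, pw)
    x = x.permute(0, 2, 4, 6, 3, 5, 7, 1)  # -> (B, t, h, w, pt, ph, pw, C)
    return x.reshape(B, t * h * w, pt * ph * pw * C)


def unpatchify(x: torch.Tensor, out_channels: int, grid, patch_size) -> torch.Tensor:
    """Same steps as DiT.unpatchify, with einops.rearrange replaced by permute."""
    t, h, w = grid
    pt, ph, pw = patch_size
    c = out_channels
    x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c)
    x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)  # "n t h w r p q c -> n c t r h p w q"
    return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw)


video = torch.randn(2, 8, 16, 32, 32)  # out_channels = 2 * 4 with learn_sigma=True
tokens = patchify(video, (1, 2, 2))
print(tokens.shape)  # torch.Size([2, 4096, 32])
print(torch.equal(unpatchify(tokens, 8, (16, 16, 16), (1, 2, 2)), video))  # True
```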