1. U-Net for text-to-image
Sources:
1. hkproj/pytorch-stable-diffusion
2. Understanding U-Net: A Comprehensive Tutorial
3. Deep Dive into Self-Attention by Hand
4. Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing. arXiv:2403.03431v1 [cs.CV], 6 Mar 2024
Encoder
The encoder is responsible for capturing high-level features and reducing the spatial dimensions of the input image.
It consists of repeated blocks of convolutional layers followed by max-pooling layers, effectively downsampling the input.
Bottleneck
At the center of the U-Net is a bottleneck layer that captures the most critical features while maintaining spatial information.
Decoder
The decoder is responsible for upsampling the low-resolution feature maps to match the original input size.
It consists of repeated blocks of transposed convolutions (upsampling) followed by concatenation with corresponding feature maps from the contracting path.
1.1 TimeEmbedding
At each step the U-Net is given a timestep t (each timestep corresponds to a different noise level in the image: the larger t is, the more noise the image contains; the smaller t is, the less noise).
import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention, CrossAttention
class TimeEmbedding(nn.Module):
def __init__(self, n_embd):
super().__init__()
# First linear layer to expand embedding size
self.linear_1 = nn.Linear(n_embd, 4 * n_embd) # (in_features, out_features)
# Second linear layer for further processing
self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
def forward(self, x):
# x: (1, 320)
# Expand embedding size: (1, 320) -> (1, 1280)
x = self.linear_1(x)
# Apply SiLU activation function
# (1, 1280) -> (1, 1280)
x = F.silu(x)
# Further processing: (1, 1280) -> (1, 1280)
x = self.linear_2(x)
return x
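The (1, 320) vector fed to TimeEmbedding comes from a sinusoidal encoding of the integer timestep t. The helper below is a minimal sketch of that encoding (the name get_time_embedding and the 160-frequency split are illustrative assumptions, mirroring the scheme commonly used in the reference pipeline):
# Sketch: turn an integer timestep into the (1, 320) vector consumed by TimeEmbedding
def get_time_embedding(timestep):
    # 160 frequencies spanning several orders of magnitude: (160,)
    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
    # Scale every frequency by the timestep: (1, 1) * (1, 160) -> (1, 160)
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    # Concatenate cosine and sine components: (1, 320)
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
# Example: t = 999 -> (1, 320) -> TimeEmbedding -> (1, 1280)
t_emb = TimeEmbedding(320)(get_time_embedding(999))
print(t_emb.shape) # torch.Size([1, 1280])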
1.2 ResnetBlock (ResNet + Time Embedding)
ResNetBlocks enable the model to learn richer and more complex feature representations by allowing multiple layers to focus on refining features without the risk of vanishing gradients.
The figure below is from WeThinkIn on Zhihu.
Convolutional Layer: Applies a convolution operation to extract features.
Normalization: Often Batch Normalization or Layer Normalization to stabilize and accelerate training.
Activation Function: Typically SiLU to introduce non-linearity.
Second Convolutional Layer: Another convolution to further process the features.
Normalization and Activation: Additional normalization and activation.
Residual Connection: Adds the input of the block to the output of the block.
import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention, CrossAttention
class UNET_ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, n_time=1280):
super().__init__()
# GN
self.groupnorm_feature = nn.GroupNorm(32, in_channels)
# Conv
self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
# Linear
self.linear_time = nn.Linear(n_time, out_channels)
# GroupNorm at the start of the second GSC (GN + SiLU + Conv) block, applied to the merged result
self.groupnorm_merged = nn.GroupNorm(32, out_channels)
# Final Conv of the second GSC block
self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
# If the block's input and output channels match, skip directly; otherwise project the residual with a 1x1 conv
if in_channels == out_channels:
self.residual_layer = nn.Identity() # skip connection
else:
self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
def forward(self, feature, time):
# feature: (Batch_Size, In_Channels, Height, Width)
# time: (1, 1280)
## (1) Apply GSC (GN + SiLU + Conv) to the latent feature
residue = feature
# GN
# (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
feature = self.groupnorm_feature(feature) # normalize the latent feature
# SiLU
# (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
feature = F.silu(feature) # apply the SiLU activation to the latent feature
# Conv
# (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
feature = self.conv_feature(feature)
## (2) Apply SiLU + Linear to the time embedding
# (1, 1280) -> (1, 1280)
time = F.silu(time)
# (1, 1280) -> (1, Out_Channels)
time = self.linear_time(time)
## Merge (1) and (2)
# Add width and height dimension to time.
# (Batch_Size, Out_Channels, Height, Width) + (1, Out_Channels, 1, 1) -> (Batch_Size, Out_Channels, Height, Width)
merged = feature + time.unsqueeze(-1).unsqueeze(-1)
## Apply GSC (GN + SiLU + Conv) to the merged result of (1) and (2)
# (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
merged = self.groupnorm_merged(merged)
# (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
merged = F.silu(merged)
# (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
merged = self.conv_merged(merged)
## Add the skip connection of the input latent feature to the GSC output of the merged result
# (Batch_Size, Out_Channels, Height, Width) + (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
return merged + self.residual_layer(residue)
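A quick shape check of the block above (the channel counts and the 32×32 resolution are arbitrary choices for illustration):
block = UNET_ResidualBlock(320, 640) # change 320 channels into 640
feature = torch.randn(1, 320, 32, 32) # (Batch_Size, In_Channels, Height, Width)
time = torch.randn(1, 1280) # (1, n_time), normally the TimeEmbedding output
out = block(feature, time)
print(out.shape) # torch.Size([1, 640, 32, 32])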
1.3 SelfAttention
The self-attention mechanism lets the model capture dependencies between pixels at different positions, so that it can better understand the semantics of the image.
The self-attention map clearly expresses the outline information of objects.
Self-attention maps play a crucial role in preserving the geometric and shape details of the source image during the transformation to the target image.
Self-Attention: the core idea of self-attention is to learn a weight distribution over the input sequence for every position, so that when processing the current position the model knows which other positions carry the most relevant information. Self-attention refers specifically to attention computed within a single sequence: every position computes attention weights against all other positions of the same sequence.
The figure below reflects my own understanding (if anything is wrong, please point it out in the comments).
Multi-Head Attention: to let the model attend to information from different positions simultaneously, the Transformer introduces multi-head attention. The basic idea is to split the input representation into several subspaces (heads), compute attention weights independently within each subspace, and finally concatenate the results of all subspaces. The benefit is that the model can capture different contextual information in different representation subspaces. Quoted from: Self-Attention 和 Multi-Head Attention 的区别——附最通俗理解!!
The figure below reflects my own understanding (if anything is wrong, please point it out in the comments).
The figure below is from: Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing, arXiv:2403.03431v1 [cs.CV], 6 Mar 2024.
The first component of the horse's self-attention map clearly expresses the outline information of the horse.
Parameters:
def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True)
(1) n_heads defines how many heads we split our embeddings into.
(2) d_embed defines the size of each embedding.
(3) in_proj_bias=True and out_proj_bias=True determine whether biases are added to the input and output projection layers, respectively.
def forward(self, x, causal_mask=False) # x: (Batch_Size, Seq_Len, Dim)
(1) Batch_Size refers to the number of samples processed together in one forward and backward pass of the model.
For instance, if you have 64 images in one batch, then your Batch_Size is 64.
(2) Seq_Len stands for Sequence Length, which, in the context of images, typically refers to the number of patches the image is divided into.
For an image of size H×W (Height×Width) and patch size P×P, the sequence length (Seq_Len) would be (H×W)/(P×P). For example, a 224×224 image divided into 16×16 patches results in 196 patches.
(3) Dim refers to the dimension of the embeddings or feature vectors for each token (patch).
For example, if each 16x16 patch is embedded into a vector of dimension 768, then Dim is 768.
in_proj is a single matrix that combines the three weight matrices Wq, Wk, and Wv.
import torch
from torch import nn
from torch.nn import functional as F
import math
class SelfAttention(nn.Module):
def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
super().__init__()
# This combines the Wq, Wk and Wv matrices into one matrix
# d_embed: The dimension of the input embeddings.
# 3 * d_embed: The output dimension is three times the input dimension
# to produce the query, key, and value vectors in a single step.
self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
# This one represents the Wo matrix
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
self.n_heads = n_heads
self.d_head = d_embed // n_heads # dimension of each head
def forward(self, x, causal_mask=False):
# x: (Batch_Size, Seq_Len, Dim)
# (Batch_Size, Seq_Len, Dim)
input_shape = x.shape
# (Batch_Size, Seq_Len, Dim)
# Unpack input dimensions
batch_size, sequence_length, d_embed = input_shape
# Shape to split heads and dimensions
# (Batch_Size, Seq_Len, H, Dim / H)
# The 'interim_shape' is a tuple representing the desired shape: batch size, sequence length, number of heads, and head dimension.
interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
# Project input tensor into query, key, and value matrices
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim * 3) -> 3 tensor of shape (Batch_Size, Seq_Len, Dim)
# Splits the resulting tensor into three equal parts
# The '3' specifies that we want to split the tensor into 3 chunks.
# The 'dim=-1' specifies that the split should be done along the last dimension of the tensor.
q, k, v = self.in_proj(x).chunk(3, dim=-1)
# Reshape and transpose to separate heads
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
# q.view(interim_shape): Reshapes the 'q' tensor into the specified 'interim_shape'.
# .transpose(1, 2): Transposes the second and third dimensions of the reshaped tensor.
q = q.view(interim_shape).transpose(1, 2)
k = k.view(interim_shape).transpose(1, 2)
v = v.view(interim_shape).transpose(1, 2)
# Q·K^T
# Compute scaled dot-product attention
# (Batch_Size, H, Seq_Len, Dim / H) @ (Batch_Size, H, Dim / H, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
weight = q @ k.transpose(-1, -2) # @ denotes matrix multiplication
if causal_mask:
# Create a mask for the upper triangle (causal attention)
# Mask where the upper triangle (above the principal diagonal) is 1
mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
# Apply the mask by filling the upper triangle with -inf
# Fill the upper triangle with -inf
weight.masked_fill_(mask, -torch.inf)
# (Q·K^T)/sqrt{d_k}
# Divide by sqrt(d_k), where d_k = Dim / H.
# (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
weight /= math.sqrt(self.d_head)
# softmax(Q·K^T/sqrt{d_k})
# (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
weight = F.softmax(weight, dim=-1)
# softmax(Q·K^T/sqrt{d_k})·V
# (Batch_Size, H, Seq_Len, Seq_Len) @ (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
output = weight @ v
# Transpose and reshape to combine heads
# (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, Seq_Len, H, Dim / H)
output = output.transpose(1, 2)
# (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, Seq_Len, Dim)
output = output.reshape(input_shape)
# softmax(Q·K^T/sqrt{d_k})·V·W^O
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
output = self.out_proj(output)
# (Batch_Size, Seq_Len, Dim)
return output
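A small usage example of the layer above: 8 heads over 64 "pixel tokens" of dimension 320 (the numbers are illustrative):
attn = SelfAttention(n_heads=8, d_embed=320)
x = torch.randn(2, 64, 320) # (Batch_Size, Seq_Len, Dim)
out = attn(x) # causal_mask=False: every token may attend to every other token
print(out.shape) # torch.Size([2, 64, 320])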
1.4 CrossAttention
Cross-attention is a multi-head attention mechanism that relates two different input sequences, allowing information from one sequence to be passed to the other.
In Stable Diffusion, the cross-attention module helps build the connection between the input text and the generated image, controlling how text information and image information are fused. Informally, it lets the U-Net associate a particular region of the noise matrix with specific information in the text.
The cross-attention map is not only a weight measure of the conditional prompt at the corresponding positions in the generated image but also contains the semantic features of the conditional token.
The cross-attention map enables the diffusion model to locate/align the tokens of the prompt in the image area.
The figure below reflects my own understanding (if anything is wrong, please point it out in the comments).
The figure below is from: Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing, arXiv:2403.03431v1 [cs.CV], 6 Mar 2024.
The figure below is from WeThinkIn on Zhihu.
class CrossAttention(nn.Module):
def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
super().__init__()
# W^Q
self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
# W^K
self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
# W^V
self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
# W^O
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
self.n_heads = n_heads # number of heads
self.d_head = d_embed // n_heads # dimension of each head
def forward(self, x, y):
# x (latent): (Batch_Size, Seq_Len_Q, Dim_Q)
# y (context): (Batch_Size, Seq_Len_KV, Dim_KV) = (Batch_Size, 77, 768)
# Matrix C (Seq_Len_KV×Dim_KV)
input_shape = x.shape
batch_size, sequence_length, d_embed = input_shape
# Divide each embedding of Q into multiple heads such that d_heads * n_heads = Dim_Q
interim_shape = (batch_size, -1, self.n_heads, self.d_head)
# Q = X·W^Q
# (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
q = self.q_proj(x)
# K = Y·W^K
# (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
k = self.k_proj(y)
# V = Y·W^V
# (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
v = self.v_proj(y)
# q.view(interim_shape): Reshapes the 'q' tensor into the specified 'interim_shape'.
# .transpose(1, 2): Transposes the second and third dimensions of the reshaped tensor.
# (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
q = q.view(interim_shape).transpose(1, 2)
# (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
k = k.view(interim_shape).transpose(1, 2)
# (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
v = v.view(interim_shape).transpose(1, 2)
# Q·K^T
# (Batch_Size, H, Seq_Len_Q, Dim_Q / H) @ (Batch_Size, H, Dim_Q / H, Seq_Len_KV) -> (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
weight = q @ k.transpose(-1, -2)
# (Q·K^T)/sqrt{d_k}
# (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
weight /= math.sqrt(self.d_head)
# softmax(Q·K^T/sqrt{d_k})
# (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
weight = F.softmax(weight, dim=-1)
# softmax(Q·K^T/sqrt{d_k})·V
# (Batch_Size, H, Seq_Len_Q, Seq_Len_KV) @ (Batch_Size, H, Seq_Len_KV, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
output = weight @ v
# Transpose the tensor by swapping the dimensions 1 and 2
# Ensure the tensor is stored in contiguous memory
# This is important because some operations require the tensor to be contiguous in memory
# After the transpose operation, the tensor might not be stored contiguously
# (Batch_Size, H, Seq_Len_Q, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H)
output = output.transpose(1, 2).contiguous()
# Reshape the tensor 'output' to the shape specified by 'input_shape'
# (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, Dim_Q)
output = output.view(input_shape)
# softmax(Q·K^T/sqrt{d_k})·V·W^O
# (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
output = self.out_proj(output)
# (Batch_Size, Seq_Len_Q, Dim_Q)
return output
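A quick shape check: 64 latent tokens of dimension 320 query a CLIP-style text context of shape (Batch_Size, 77, 768) (the numbers are illustrative):
cross = CrossAttention(n_heads=8, d_embed=320, d_cross=768)
latent = torch.randn(2, 64, 320) # x: (Batch_Size, Seq_Len_Q, Dim_Q)
context = torch.randn(2, 77, 768) # y: (Batch_Size, Seq_Len_KV, Dim_KV)
out = cross(latent, context)
print(out.shape) # torch.Size([2, 64, 320])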
1.5 AttentionBlock (SelfAttention+CrossAttention)
The figure below is adapted from WeThinkIn on Zhihu.
class UNET_AttentionBlock(nn.Module):
def __init__(self, n_head: int, n_embd: int, d_context=768):
super().__init__()
channels = n_head * n_embd
# GN
self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
# Conv
self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
# LN
self.layernorm_1 = nn.LayerNorm(channels)
# SelfAttention
self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
# LN
self.layernorm_2 = nn.LayerNorm(channels)
# CrossAttention
self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
# LN
self.layernorm_3 = nn.LayerNorm(channels)
# GeGLU
self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
self.linear_geglu_2 = nn.Linear(4 * channels, channels)
# Conv
self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
def forward(self, x, context):
# x: (Batch_Size, Features, Height, Width)
# context(text_embedding): (Batch_Size, Seq_Len, Dim)
residue_long = x
# GN
# (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
x = self.groupnorm(x)
# Conv
# (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
x = self.conv_input(x)
n, c, h, w = x.shape
# (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * Width)
x = x.view((n, c, h * w))
# (Batch_Size, Features, Height * Width) -> (Batch_Size, Height * Width, Features)
x = x.transpose(-1, -2)
# Normalization + Self-Attention with skip connection
# Basic Transformer block
# (Batch_Size, Height * Width, Features)
residue_short = x
## LN_1
# (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
x = self.layernorm_1(x)
## Self Attention
# (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
x = self.attention_1(x)
## Skip Connection
# (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
x += residue_short
# (Batch_Size, Height * Width, Features)
residue_short = x
# Normalization + Cross-Attention with skip connection
## LN_2
# (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
x = self.layernorm_2(x)
## Cross Attention
# (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
x = self.attention_2(x, context)
## Skip Connection
# (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
x += residue_short
# (Batch_Size, Height * Width, Features)
residue_short = x
## LN_3
# Normalization + FFN with GeGLU and skip connection
# (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
x = self.layernorm_3(x)
## Feed Forward
### GeGLU
# GeGLU as implemented in the original code: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/attention.py#L37C10-L37C10
# (Batch_Size, Height * Width, Features) -> two tensors of shape (Batch_Size, Height * Width, Features * 4)
x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
# Element-wise product: (Batch_Size, Height * Width, Features * 4) * (Batch_Size, Height * Width, Features * 4) -> (Batch_Size, Height * Width, Features * 4)
x = x * F.gelu(gate)
# (Batch_Size, Height * Width, Features * 4) -> (Batch_Size, Height * Width, Features)
x = self.linear_geglu_2(x)
## Skip Connection
# (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
x += residue_short
# (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height * Width)
x = x.transpose(-1, -2)
# (Batch_Size, Features, Height * Width) -> (Batch_Size, Features, Height, Width)
x = x.view((n, c, h, w))
# Conv + Skip Connection
# Final skip connection between initial input and output of the block
# (Batch_Size, Features, Height, Width) + (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
return self.conv_output(x) + residue_long
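A hypothetical shape check for the block above: 8 heads × 40 dimensions give the 320 channels of the first U-Net stage, and the 768-dimensional context stands in for the CLIP text embedding:
block = UNET_AttentionBlock(8, 40) # channels = 8 * 40 = 320
x = torch.randn(1, 320, 32, 32) # (Batch_Size, Features, Height, Width)
context = torch.randn(1, 77, 768) # (Batch_Size, Seq_Len, Dim)
out = block(x, context)
print(out.shape) # torch.Size([1, 320, 32, 32])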
1.6 Upsample
class Upsample(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
def forward(self, x):
# (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * 2, Width * 2)
# upsample using nearest-neighbour interpolation
x = F.interpolate(x, scale_factor=2, mode='nearest')
return self.conv(x)
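A minimal example: nearest-neighbour interpolation doubles the spatial size and the 3×3 convolution keeps the channel count (numbers are illustrative):
up = Upsample(640)
x = torch.randn(1, 640, 16, 16)
print(up(x).shape) # torch.Size([1, 640, 32, 32])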
1.7 SwitchSequential
# Define a custom sequential container class named SwitchSequential
# Inherits from nn.Sequential, which is a container module from PyTorch
class SwitchSequential(nn.Sequential):
# Define the forward method, which specifies how the input data flows through the layers
# x: the input tensor
# context: additional context information, possibly used for attention mechanisms
# time: additional time information, possibly used for temporal aspects in certain layers
def forward(self, x, context, time):
for layer in self:
# Check if the current layer is an instance of UNET_AttentionBlock
# Pass the input tensor and context information through the attention block
if isinstance(layer, UNET_AttentionBlock):
x = layer(x, context)
# Check if the current layer is an instance of UNET_ResidualBlock
# Pass the input tensor and time information through the residual block
elif isinstance(layer, UNET_ResidualBlock):
x = layer(x, time)
# For all other types of layers
# Simply pass the input tensor through the layer
else:
x = layer(x)
return x
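A hypothetical example of the dispatching behaviour: within one stage, the residual block receives the time embedding while the attention block receives the text context:
stage = SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80))
x = torch.randn(1, 320, 32, 32) # latent feature map
context = torch.randn(1, 77, 768) # text embedding
time = torch.randn(1, 1280) # time embedding
out = stage(x, context, time)
print(out.shape) # torch.Size([1, 640, 32, 32])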
1.8 UNET
class UNET(nn.Module):
def __init__(self):
super().__init__()
# Encoder
self.encoders = nn.ModuleList([
# (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
# CrossAttentionDownBlock2d_1 (320 channels)
## ResnetBlock+AttentionBlock
# (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
## ResnetBlock+AttentionBlock
# (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
## Downsample2D
# (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)
SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
# CrossAttentionDownBlock2d_2 (640 channels)
## ResnetBlock+AttentionBlock
# (Batch_Size, 320, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
## ResnetBlock+AttentionBlock
# (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
## Downsample2D
# (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)
SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
# CrossAttentionDownBlock2d_3 (1280 channels)
## ResnetBlock+AttentionBlock
# (Batch_Size, 640, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
## ResnetBlock+AttentionBlock
# (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
# Downsample2D
# (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 64, Width / 64)
SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
# DownBlock2D
## ResnetBlock+ResnetBlock
# (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
SwitchSequential(UNET_ResidualBlock(1280, 1280)),
# (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
SwitchSequential(UNET_ResidualBlock(1280, 1280)),
])
# Bottleneck
self.bottleneck = SwitchSequential(
# (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
UNET_ResidualBlock(1280, 1280),
# (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
UNET_AttentionBlock(8, 160),
# (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
UNET_ResidualBlock(1280, 1280),
)
# Decoder
self.decoders = nn.ModuleList([
# UpBlock2D
## ResnetBlock+ResnetBlock
# (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
SwitchSequential(UNET_ResidualBlock(2560, 1280)),
# (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
SwitchSequential(UNET_ResidualBlock(2560, 1280)),
# CrossAttentionUpBlock2d_3 (1280 channels)
## upsample
# (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 32, Width / 32)
SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
## ResnetBlock+AttentionBlock
# (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
## ResnetBlock+AttentionBlock
# (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
## ResnetBlock+AttentionBlock
# (Batch_Size, 1920, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 16, Width / 16)
SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),
# CrossAttentionUpBlock2d_2 (640 channels)
# (Batch_Size, 1920, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
## ResnetBlock+AttentionBlock
# (Batch_Size, 1280, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),
## ResnetBlock+AttentionBlock
# (Batch_Size, 960, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 8, Width / 8)
SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),
# CrossAttentionUpBlock2d_1 (320 channels)
# (Batch_Size, 960, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
## ResnetBlock+AttentionBlock
# (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
## ResnetBlock+AttentionBlock
# (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
])
def forward(self, x, context, time):
# x: (Batch_Size, 4, Height / 8, Width / 8)
# context: (Batch_Size, Seq_Len, Dim)
# time: (1, 1280)
skip_connections = []
# Encoder
for layers in self.encoders:
x = layers(x, context, time)
skip_connections.append(x) # store each encoder stage's output so the decoder can reuse it as a skip connection
# Bottleneck
x = self.bottleneck(x, context, time)
# Decoder
for layers in self.decoders:
# Since we always concat with the skip connection of the encoder, the number of features increases before being sent to the decoder's layer
x = torch.cat((x, skip_connections.pop()), dim=1)
x = layers(x, context, time)
return x
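An end-to-end shape check of the backbone for a 512×512 image, i.e. a 64×64 latent (this builds the full model, so it is slow on CPU; the batch size and context are illustrative):
unet = UNET()
latent = torch.randn(1, 4, 64, 64) # (Batch_Size, 4, Height / 8, Width / 8)
context = torch.randn(1, 77, 768) # CLIP text embedding
time = torch.randn(1, 1280) # output of TimeEmbedding
out = unet(latent, context, time)
print(out.shape) # torch.Size([1, 320, 64, 64])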
1.9 OutputLayer
class UNET_OutputLayer(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.groupnorm = nn.GroupNorm(32, in_channels)
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, x):
# x: (Batch_Size, 320, Height / 8, Width / 8)
# (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
x = self.groupnorm(x)
# (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
x = F.silu(x)
# (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
x = self.conv(x)
# (Batch_Size, 4, Height / 8, Width / 8)
return x
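A minimal shape check: the output layer projects the 320 U-Net features back to the 4 latent channels:
head = UNET_OutputLayer(320, 4)
x = torch.randn(1, 320, 64, 64)
print(head(x).shape) # torch.Size([1, 4, 64, 64])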
1.10 Diffusion
class Diffusion(nn.Module):
def __init__(self):
super().__init__()
self.time_embedding = TimeEmbedding(320)
self.unet = UNET()
self.final = UNET_OutputLayer(320, 4)
def forward(self, latent, context, time):
# latent: (Batch_Size, 4, Height / 8, Width / 8)
# context: (Batch_Size, Seq_Len, Dim)
# time: (1, 320)
# (1, 320) -> (1, 1280)
time = self.time_embedding(time)
# (Batch, 4, Height / 8, Width / 8) -> (Batch, 320, Height / 8, Width / 8)
output = self.unet(latent, context, time)
# (Batch, 320, Height / 8, Width / 8) -> (Batch, 4, Height / 8, Width / 8)
output = self.final(output)
# (Batch, 4, Height / 8, Width / 8)
return output
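Putting everything together, a hypothetical single denoising step: a raw timestep is turned into a (1, 320) vector (using the get_time_embedding sketch from section 1.1) and the model predicts the noise in the latent:
model = Diffusion()
latent = torch.randn(1, 4, 64, 64) # noisy latent of a 512x512 image
context = torch.randn(1, 77, 768) # text embedding from the CLIP encoder
time = get_time_embedding(999) # (1, 320)
noise_pred = model(latent, context, time)
print(noise_pred.shape) # torch.Size([1, 4, 64, 64])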