U-Net for text-to-image

news2025/1/22 20:19:23

1. Unet for text-to-image

笔记来源:
1.hkproj/pytorch-stable-diffusion
2.understanding u-net a comprehensive tutorial
3.Deep Dive into Self-Attention by Hand
4.Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing.arXiv:2403.03431v1 [cs.CV] 06 Mar 2024


Encoder

The encoder is responsible for capturing high-level features and reducing the spatial dimensions of the input image.
It consists of repeated blocks of convolutional layers followed by max-pooling layers, effectively downsampling the input.

Bottleneck

At the center of the U-Net is a bottleneck layer that captures the most critical features while maintaining spatial information.

Decoder

The decoder is responsible for upsampling the low-resolution feature maps to match the original input size.
It consists of repeated blocks of transposed convolutions (upsampling) followed by concatenation with corresponding feature maps from the contracting path.

1.1 TimeEmbedding

每次给U-Net输入一个t(每个time对应的图片中的噪声程度不同,t越大噪声程度越大,反之,t越小噪声程度越小)

import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention, CrossAttention

class TimeEmbedding(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        # First linear layer to expand embedding size
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd) # input,output
        # Second linear layer for further processing
        self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)

    def forward(self, x):
        # x: (1, 320)
		# Expand embedding size: (1, 320) -> (1, 1280)
        x = self.linear_1(x)
        # Apply SiLU activation function
        # (1, 1280) -> (1, 1280)
        x = F.silu(x) 
        # Further processing: (1, 1280) -> (1, 1280)
        x = self.linear_2(x)
        return x

1.2 ResnetBlock(Resnet+Time_embedding)

ResNetBlocks enable the model to learn richer and more complex feature representations by allowing multiple layers to focus on refining features without the risk of vanishing gradients.

下图来自知乎WeThinkIn

Convolutional Layer: Applies a convolution operation to extract features.
Normalization: Often Batch Normalization or Layer Normalization to stabilize and accelerate training.
Activation Function: Typically SiLU to introduce non-linearity.
Second Convolutional Layer: Another convolution to further process the features.
Normalization and Activation: Additional normalization and activation.
Residual Connection: Adds the input of the block to the output of the block.

import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention, CrossAttention

class UNET_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_time=1280):
        super().__init__()
        # GN
        self.groupnorm_feature = nn.GroupNorm(32, in_channels)
        # Conv
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Linear
        self.linear_time = nn.Linear(n_time, out_channels)
		# 第一次融合结果输入第二次GSC中GN
        self.groupnorm_merged = nn.GroupNorm(32, out_channels)
        # 第二次GSC中最后Conv
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
		# 若residualblock输入channel和输出channel相同则直接skip,否则做一次conv
        if in_channels == out_channels:
            self.residual_layer = nn.Identity() # skip connection
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
    
    def forward(self, feature, time):
        # feature: (Batch_Size, In_Channels, Height, Width)
        # time: (1, 1280)
        
##(1)对 latent feature 进行 GSC (GN+SiLU+Conv)
        residue = feature
        # GN
        # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
        feature = self.groupnorm_feature(feature) #对latent feature进行归一化
        # SiLU
        # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
        feature = F.silu(feature) #对latent feature使用激活函数SiLU
        # Conv
        # (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        feature = self.conv_feature(feature)
        
##(2)对 Time Embedding 进行 SiLU+Linear
        # (1, 1280) -> (1, 1280)
        time = F.silu(time)
        # (1, 1280) -> (1, Out_Channels)
        time = self.linear_time(time)
## 对(1)(2)进行融合
        # Add width and height dimension to time. 
        # (Batch_Size, Out_Channels, Height, Width) + (1, Out_Channels, 1, 1) -> (Batch_Size, Out_Channels, Height, Width)
        merged = feature + time.unsqueeze(-1).unsqueeze(-1)
## 对(1)(2)融合结果进行 GSC (GN+SiLU+Conv)        
        # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        merged = self.groupnorm_merged(merged)
        
        # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        merged = F.silu(merged)
        
        # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        merged = self.conv_merged(merged)
## latent feature进行skip connection 与(1)(2)融合后进行 GSC 后的结果 进行融合
        # (Batch_Size, Out_Channels, Height, Width) + (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        return merged + self.residual_layer(residue)

1.3 SelfAttention

通过SelfAttention机制让模型理解不同位置的像素之间的依赖关系,以更好地理解图像语义

self attention map clearly expresses the outline information of objects
self-attention maps play a crucial role in preserving the geometric and shape details of the source image during the transformation to the target image.


Self-Attention(自注意力机制):自注意力机制的核心是为输入序列中的每一个位置学习一个权重分布,这样模型就能知道在处理当前位置时,哪些位置的信息更为重要。Self-Attention特指在序列内部进行的注意力计算,即序列中的每一个位置都要和其他所有位置进行注意力权重的计算。

下图为笔者个人理解(若有误,请在评论区指正)

Multi-Head Attention(多头注意力机制):为了让模型能够同时关注来自不同位置的信息,Transformer引入了Multi-Head Attention。它的基本思想是将输入序列的表示拆分成多个子空间(头),然后在每个子空间内独立地计算注意力权重,最后将各个子空间的结果拼接起来。这样做的好处是模型可以在不同的表示子空间中捕获到不同的上下文信息。 引用自:Self-Attention 和 Multi-Head Attention 的区别——附最通俗理解!!

下图为笔者个人理解(若有误,请在评论区指正)

下图来自:Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing.arXiv:2403.03431v1 [cs.CV] 06 Mar 2024

the first component of the horse’s self-attention map clearly expresses the outline information of the horse.



形参

def init(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True)
(1) n_heads defines how many heads we split our embeddings into.
(2) d_embed defines the size of each embedding.
(3) in_proj_bias=True and out_proj_bias=True determine whether biases are added to the input and output projection layers, respectively.

def forward(self, x, causal_mask=False) # x: (Batch_Size, Seq_Len, Dim)
(1) Batch_Size refers to the number of samples processed together in one forward and backward pass of the model.
For instance, if you have 64 images in one batch, then your Batch_Size is 64.
(2) Seq_Len stands for Sequence Length, which, in the context of images, typically refers to the number of patches the image is divided into.
For an image of size H×W(Weight×Width)and patch size P×P,the sequence length (Seq_Len) would be (H×W)/(P×P). For example, an image of size 224×244 divided into 16×16 patches would result in 196 patches.
(3) Dim refers to the dimension of the embeddings or feature vectors for each token (patch).
For example, if each 16x16 patch is embedded into a vector of dimension 768, then Dim is 768.

in_proj() 为三个权重矩阵整合成的一个矩阵

import torch
from torch import nn
from torch.nn import functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # This combines the Wq, Wk and Wv matrices into one matrix
        # d_embed: The dimension of the input embeddings.
        # 3 * d_embed: The output dimension is three times the input dimension 
        # to produce the query, key, and value vectors in a single step.
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        # This one represents the Wo matrix
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads # dimension of each head

    def forward(self, x, causal_mask=False):
        # x: # (Batch_Size, Seq_Len, Dim)

        # (Batch_Size, Seq_Len, Dim)
        input_shape = x.shape 
        
        # (Batch_Size, Seq_Len, Dim)
        # Unpack input dimensions
        batch_size, sequence_length, d_embed = input_shape
         
		# Shape to split heads and dimensions
        # (Batch_Size, Seq_Len, H, Dim / H)
        # The 'interim_shape' should be a tuple representing the desired shape, often including batch size,, sequence length, number of heads and head dimension.
        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head) 
        
		# Project input tensor into query, key, and value matrices
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim * 3) -> 3 tensor of shape (Batch_Size, Seq_Len, Dim)
        # Splits the resulting tensor into three equal parts
        # The '3' specifies that we want to split the tensor into 3 chunks.
        # The 'dim=-1' specifies that the split should be done along the last dimension of the tensor.
        q, k, v = self.in_proj(x).chunk(3, dim=-1)
        
        # Reshape and transpose to separate heads
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
        # q.view(interim_shape): Reshapes the 'q' tensor into the specified 'interim_shape'.
        # .transpose(1, 2): Transposes the second and third dimensions of the reshaped tensor.
        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)
# Q·K^T	
		# Compute scaled dot-product attention
        # (Batch_Size, H, Seq_Len, Dim / H) @ (Batch_Size, H, Dim / H, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
        weight = q @ k.transpose(-1, -2) # @表示矩阵乘法
        
        if causal_mask:
        	# Create a mask for the upper triangle (causal attention)
            # Mask where the upper triangle (above the principal diagonal) is 1
            mask = torch.ones_like(weight, dtype=torch.bool).triu(1) 
            # Apply the mask by filling the upper triangle with -inf
            # Fill the upper triangle with -inf
            weight.masked_fill_(mask, -torch.inf) 
# (Q·K^T)/sqrt{d_k}        
        # Divide by d_k (Dim / H). 
        # (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
        weight /= math.sqrt(self.d_head) 
# softmax(Q·K^T/sqrt{d_k})
        # (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
        weight = F.softmax(weight, dim=-1) 
# softmax(Q·K^T/sqrt{d_k})·V
        # (Batch_Size, H, Seq_Len, Seq_Len) @ (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
        output = weight @ v
        
		# Transpose and reshape to combine heads
        # (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, Seq_Len, H, Dim / H)
        output = output.transpose(1, 2) 

        # (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, Seq_Len, Dim)
        output = output.reshape(input_shape) 
        
# softmax(Q·K^T/sqrt{d_k})·V·W^O
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
        output = self.out_proj(output) 
        
        # (Batch_Size, Seq_Len, Dim)
        return output

1.4 CrossAttention

Cross Attention是一种多头注意力机制,它可在两个不同的输入序列之间建立关联,并且可以将其中一个输入序列的信息传递给另一个输入序列
Stable Diffusion中使用Cross Attention模块有助于在输入文本和生成图片之间建立联系,控制文本信息和图像信息的融合交互,通俗来说,控制U-Net把噪声矩阵的某一块与文本里的特定信息相对应。

The cross-attention map is not only a weight measure of the conditional prompt at the corresponding positions in the generated image but also contains the semantic features of the
conditional token.
The cross-attention map enables the diffusion model to locate/align the tokens of the prompt in the image area.


下图为笔者个人理解(若有误,请在评论区指正)

下图来自:Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing.arXiv:2403.03431v1 [cs.CV] 06 Mar 2024



下图来自知乎WeThinkIn

class CrossAttention(nn.Module):
    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # W^Q
        self.q_proj   = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        # W^K
        self.k_proj   = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        # W^V
        self.v_proj   = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        # W^O
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads #头个数
        self.d_head = d_embed // n_heads # 每个头的维度 
    
    def forward(self, x, y):
        # x (latent): # (Batch_Size, Seq_Len_Q, Dim_Q)
        # y (context): # (Batch_Size, Seq_Len_KV, Dim_KV) = (Batch_Size, 77, 768)
		# Matrix C (Seq_Len_KV×Dim_KV)
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        # Divide each embedding of Q into multiple heads such that d_heads * n_heads = Dim_Q
        interim_shape = (batch_size, -1, self.n_heads, self.d_head)
# Q = X·W^Q
        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        q = self.q_proj(x)
# K = Y·W^K
        # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
        k = self.k_proj(y)
# V = Y·W^V
        # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
        v = self.v_proj(y)
        
		# q.view(interim_shape): Reshapes the 'q' tensor into the specified 'interim_shape'.
        # .transpose(1, 2): Transposes the second and third dimensions of the reshaped tensor.	# Transpose the tensor by swapping the dimensions 1 and 2
        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
        q = q.view(interim_shape).transpose(1, 2) 
        # (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
        k = k.view(interim_shape).transpose(1, 2) 
        # (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
        v = v.view(interim_shape).transpose(1, 2) 
# Q·K^T        
        # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) @ (Batch_Size, H, Dim_Q / H, Seq_Len_KV) -> (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
        weight = q @ k.transpose(-1, -2)
# (Q·K^T)/sqrt{d_k}          
        # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
        weight /= math.sqrt(self.d_head)
# softmax(Q·K^T/sqrt{d_k})        
        # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
        weight = F.softmax(weight, dim=-1)
# softmax(Q·K^T/sqrt{d_k})·V        
        # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV) @ (Batch_Size, H, Seq_Len_KV, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
        output = weight @ v
        
        # Transpose the tensor by swapping the dimensions 1 and 2
        # Ensure the tensor is stored in contiguous memory
        # This is important because some operations require the tensor to be contiguous in memory
        # After the transpose operation, the tensor might not be stored contiguously
        # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H)
        output = output.transpose(1, 2).contiguous()
        
        # Reshape the tensor 'output' to the shape specified by 'input_shape'
        # (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        output = output.view(input_shape)
# softmax(Q·K^T/sqrt{d_k})·V·W^O        
        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        output = self.out_proj(output)

        # (Batch_Size, Seq_Len_Q, Dim_Q)
        return output

1.5 AttentionBlock (SelfAttention+CrossAttention)

下图改编自知乎WeThinkIn

class UNET_AttentionBlock(nn.Module):
    def __init__(self, n_head: int, n_embd: int, d_context=768):
        super().__init__()
        channels = n_head * n_embd
        # GN
        self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
        # Conv
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
		# LN
        self.layernorm_1 = nn.LayerNorm(channels)
        # SelfAttention
        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
        # LN
        self.layernorm_2 = nn.LayerNorm(channels)
        # CrossAttention
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
        # LN
        self.layernorm_3 = nn.LayerNorm(channels)
        # GeGLU
        self.linear_geglu_1  = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)
		# Conv
        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
    
    def forward(self, x, context):
        # x: (Batch_Size, Features, Height, Width)
        # context(text_embedding): (Batch_Size, Seq_Len, Dim) 
		
        residue_long = x
# GN
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        x = self.groupnorm(x)
# Conv        
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        x = self.conv_input(x)
        
        n, c, h, w = x.shape
        
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * Width)
        x = x.view((n, c, h * w))
        
        # (Batch_Size, Features, Height * Width) -> (Batch_Size, Height * Width, Features)
        x = x.transpose(-1, -2)
        
        # Normalization + Self-Attention with skip connection
# Basci Transformer Block
        # (Batch_Size, Height * Width, Features)
        residue_short = x
## LN_1    
        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x = self.layernorm_1(x)
## Self Attention       
        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x = self.attention_1(x)
## Skip Connection       
        # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x += residue_short
        
        # (Batch_Size, Height * Width, Features)
        residue_short = x

        # Normalization + Cross-Attention with skip connection
## LN_2        
        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x = self.layernorm_2(x)
## Cross Attention        
        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x = self.attention_2(x, context)
## Skip Connection        
        # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x += residue_short
        
        # (Batch_Size, Height * Width, Features)
        residue_short = x
## LN_3  
        # Normalization + FFN with GeGLU and skip connection      
        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x = self.layernorm_3(x)
## Feed Forward
### GeGLU        
        # GeGLU as implemented in the original code: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/attention.py#L37C10-L37C10
        # (Batch_Size, Height * Width, Features) -> two tensors of shape (Batch_Size, Height * Width, Features * 4)
        x, gate = self.linear_geglu_1(x).chunk(2, dim=-1) 
        
        # Element-wise product: (Batch_Size, Height * Width, Features * 4) * (Batch_Size, Height * Width, Features * 4) -> (Batch_Size, Height * Width, Features * 4)
        x = x * F.gelu(gate)
        
        # (Batch_Size, Height * Width, Features * 4) -> (Batch_Size, Height * Width, Features)
        x = self.linear_geglu_2(x)
## Skip Connection        
        # (Batch_Size, Height * Width, Features) + (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x += residue_short
        
        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height * Width)
        x = x.transpose(-1, -2)
        
        # (Batch_Size, Features, Height * Width) -> (Batch_Size, Features, Height, Width)
        x = x.view((n, c, h, w))
# Conv + Skip Connection
        # Final skip connection between initial input and output of the block
        # (Batch_Size, Features, Height, Width) + (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        return self.conv_output(x) + residue_long

1.6 Upsample

class Upsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    
    def forward(self, x):
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * 2, Width * 2)
        # 使用nearest进行上采样
        x = F.interpolate(x, scale_factor=2, mode='nearest') 
        return self.conv(x)

1.7 SwitchSequential

# Define a custom sequential container class named SwitchSequential
# Inherits from nn.Sequential, which is a container module from PyTorch
class SwitchSequential(nn.Sequential):
# Define the forward method, which specifies how the input data flows through the layers
# x: the input tensor
# context: additional context information, possibly used for attention mechanisms
# time: additional time information, possibly used for temporal aspects in certain layers
    def forward(self, x, context, time):
        for layer in self:
        	# Check if the current layer is an instance of UNET_AttentionBlock
        	# Pass the input tensor and context information through the attention block
            if isinstance(layer, UNET_AttentionBlock):
                x = layer(x, context)
            # Check if the current layer is an instance of UNET_ResidualBlock
            # Pass the input tensor and time information through the residual block
            elif isinstance(layer, UNET_ResidualBlock):
                x = layer(x, time)
            # For all other types of layers
            # Simply pass the input tensor through the layer
            else:
                x = layer(x)
        return x

1.8 Unet

class UNET(nn.Module):
    def __init__(self):
        super().__init__()
# Encoder
        self.encoders = nn.ModuleList([
            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
# CrossAttentionDownBlock2d_1 (320 channels)
## ResnetBlock+AttentionBlock           
            # (Batch_Size, 320, Height / 8, Width / 8) -> # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
## ResnetBlock+AttentionBlock            
            # (Batch_Size, 320, Height / 8, Width / 8) -> # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
## Downsample2D             
            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)
            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
# CrossAttentionDownBlock2d_2 (640 channels)  
## ResnetBlock+AttentionBlock          
            # (Batch_Size, 320, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
## ResnetBlock+AttentionBlock          
            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
## Downsample2D             
            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)
            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
# CrossAttentionDownBlock2d_3 (1280 channels)
## ResnetBlock+AttentionBlock           
            # (Batch_Size, 640, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
## ResnetBlock+AttentionBlock            
            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
# Downsample2D
## ResnetBlock+ResnetBlock            
            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
## DownBlock2D            
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
            
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
        ])
# Bottleneck
        self.bottleneck = SwitchSequential(
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_ResidualBlock(1280, 1280), 
            
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_AttentionBlock(8, 160), 
            
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_ResidualBlock(1280, 1280), 
        )
# Decoder        
        self.decoders = nn.ModuleList([
# UpBlock2D
## ResnetBlock+ResnetBlock 
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
            
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
# CrossAttentionUpBlock2d_3 (1280 channels)
## upsample            
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 32, Width / 32) 
            SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
## ResnetBlock+AttentionBlock           
            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
## ResnetBlock+AttentionBlock            
            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
## ResnetBlock+AttentionBlock            
            # (Batch_Size, 1920, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),
# CrossAttentionUpBlock2d_2 (640 channels)            
            # (Batch_Size, 1920, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
## ResnetBlock+AttentionBlock            
            # (Batch_Size, 1280, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),
## ResnetBlock+AttentionBlock            
            # (Batch_Size, 960, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),
# CrossAttentionUpBlock2d_1 (1280 channels)           
            # (Batch_Size, 960, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
## ResnetBlock+AttentionBlock            
            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
## ResnetBlock+AttentionBlock            
            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
        ])

    def forward(self, x, context, time):
        # x: (Batch_Size, 4, Height / 8, Width / 8)
        # context: (Batch_Size, Seq_Len, Dim) 
        # time: (1, 1280)

        skip_connections = []
# Encoder
        for layers in self.encoders:
            x = layers(x, context, time)
            skip_connections.append(x) # 将每层lay的输出添加到列表中,便于后续up中进行skip connection
# Bottleneck
        x = self.bottleneck(x, context, time)
# Decoder
        for layers in self.decoders:
            # Since we always concat with the skip connection of the encoder, the number of features increases before being sent to the decoder's layer
            x = torch.cat((x, skip_connections.pop()), dim=1) 
            x = layers(x, context, time)
        
        return x

1.9 OutputLayer

class UNET_OutputLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    
    def forward(self, x):
        # x: (Batch_Size, 320, Height / 8, Width / 8)

        # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
        x = self.groupnorm(x)
        
        # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
        x = F.silu(x)
        
        # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
        x = self.conv(x)
        
        # (Batch_Size, 4, Height / 8, Width / 8) 
        return x

1.10 Diffusion

class Diffusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET()
        self.final = UNET_OutputLayer(320, 4)
    
    def forward(self, latent, context, time):
        # latent: (Batch_Size, 4, Height / 8, Width / 8)
        # context: (Batch_Size, Seq_Len, Dim)
        # time: (1, 320)

        # (1, 320) -> (1, 1280)
        time = self.time_embedding(time)
        
        # (Batch, 4, Height / 8, Width / 8) -> (Batch, 320, Height / 8, Width / 8)
        output = self.unet(latent, context, time)
        
        # (Batch, 320, Height / 8, Width / 8) -> (Batch, 4, Height / 8, Width / 8)
        output = self.final(output)
        
        # (Batch, 4, Height / 8, Width / 8)
        return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1873038.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 英文单词联想(100分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是清隆学长 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 💻 ACM银牌🥈| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 &#x1f…

使用uniapp编写微信小程序

使用uniapp编写微信小程序 文章目录 使用uniapp编写微信小程序前言一、项目搭建1.1 创建项目方式1.1.1 HBuilderX工具创建1.1.2 命令行下载1.1.3 直接Gitee下载 1.2 项目文件解构1.2.1 安装依赖1.2.2 项目启动1.2.3 文件结构释义 1.2 引入uni-ui介绍 二、拓展2.1 uni-app使用uc…

CVPR 2024最佳论文分享:生成图像动力学

CVPR 2024最佳论文分享:生成图像动力学 CVPR(Conference on Computer Vision and Pattern Recognition)是计算机视觉领域最有影响力的会议之一,主要方向包括图像和视频处理、目标检测与识别、三维视觉等。近期,CVPR 2…

盘点7款适合团队使用的知识库工具

作为一名技术爱好者和企业管理者,我深知知识库工具在日常工作中的重要性。 无论是个人笔记管理还是企业知识共享,知识库工具都能极大地提升我们的工作效率和信息管理水平。 根据麦肯锡全球研究院报告显示,使用知识库工具可以帮助个人或者企…

JavaWeb-day28_HTML

今日内容 零、 复习昨日 一、HTML 零、 复习昨日 一、Web开发 前端三大件 HTML ,页面展现CSS , 样式JS (JavaScript) , 动起来 二、HTML 2.1 HTML概念 ​ 网页,是网站中的一个页面,通常是网页是构成网站的基本元素,是承载各种网站应用的平台…

普乐蛙景区9d电影体验馆商场影院娱乐设备旋转飞行影院

今天与大家聊聊VR娱乐新潮流,我们普乐蛙的新品——旋转飞行影院!裸眼7D环幕影院,话不多说上产品!我们通过亲身体验来给大家讲讲这款高性价比新品的亮点。 想象一下走上电动伸缩梯,坐进动感舱,舱门缓缓合上&…

RuoYi_Cloud本地搭建

目录 1.先进入若依官网下载源码 2.在git链接在idea本地打开 3.建立数据库 (1)创建一个ruoyi_cloud数据库,设定好账号密码 (2)建表 4.配置nacos (1)nacos官网下载2.0.x以上的版本 &#…

Java常量、变量、成员内部类

文章目录 1.常量2.变量3.成员内部类4.变动 1.常量 实例常量:只用final修饰,是某个具体类的实例 静态常量:finalstatic修饰,属于类,所有实例共享同一个类常量 2.变量 实例变量(成员变量):定义在类内部但在…

上海App开发测试需要注意的内容

在上海app开发中,测试发挥着至关重要的作用。及时、专业的对app进行测试,能够快速发现app存在的漏洞与问题,从而及时进行修正,确保app的顺利上线与发布。那么,在上海app开发测试的过程中,需要注意哪些内容呢…

1.驱动程序框架

驱动是用来控制和操作硬件的软件。 在linux下,一切皆文件。当我们write一个文件时,内核通过文件的file_operations结构体(include/linux/fs.h)来找到对应的驱动函数,最终调用的是存储介质(ssd,硬盘等)驱动提供的write函数(这中间…

米联客FDMA驱动OV5640摄像头—基于野火Zynq7020开发板

使用米联客的ddr3缓存方案 FDMA驱动OV5640摄像头在RGB888屏幕上显示。 总体BLOCK DESIGN框架图 RTC框架图 FDMA设置 FDMA控制器设置 帧选择IP设置 IP核封装及代码在工程文件中 参考 FDMA3.1数据缓存方案全网最细讲解,自创升级版,提供3套视频和音频缓存…

python案例-自动识别图片数字并进行填充,小键盘数字键练习工具轻松达到最高评级!ddddocr+pyauotgui

🌈所属专栏:【python】✨作者主页: Mr.Zwq✔️个人简介:一个正在努力学技术的Python领域创作者,擅长爬虫,逆向,全栈方向,专注基础和实战分享,欢迎咨询!您的点赞、关注、收藏、评论,是对我最大的激励和支持!!!🤩🥰😍 目录 前言 测试工具界面 代码完成思…

VMware Workstation环境下DNS的安装配置,并使用ubuntu来测试

需求说明: 某企业信息中心计划使用IP地址17216.11.0用于虚拟网络测试,注册域名为xyz.net.cn.并将172.16.11.2作为主域名的服务器(DNS服务器)的IP地址,将172.16.11.3分配给虚拟网络测试的DHCP服务器,将172.16.11.4分配给虚拟网络测试的web服务器,将172.16.11.5分配给FTP服务器…

python水仙花数 青少年编程电子学会python编程等级考试三级真题解析2022年3月

python水仙花数 2022年3月 python编程等级考试级编程题 一、题目要求 1、编程实现 明明请你帮忙寻找100-999之间的所有"水仙花数”,并统计个数。"水仙花数"是指一个三位数各位数字的立方和等于该数本身,例如:1531*1*15*5*53*3*3。要求输出结果如下所示: 153…

工业路由器与家用路由器的区别

在现代网络环境中,路由器扮演着至关重要的角色。无论是在家庭网络还是在工业网络,选择合适的路由器都至关重要。本文将从多个角度,对工业路由器与家用路由器进行详细比较,帮助您更好地理解二者的区别。 1、安全性 工业路由器&…

Spring学习02-[Spring容器核心技术IOC学习]

Spring容器核心技术IOC学习 什么是bean?如何配置bean?Component方式bean配合配置类的方式import导入方式实现ImportSelector类的方式-批量注册bean实现ImportBeanDefinitionRegistrar的方式 实例化bean推断构造函数使用实例工厂方法实例化----Bean的方式 使用工厂Bean。实例化…

你的编程小助手:Kimi!!【送源码】

从OpenAI发布AI大模型到现在已经快2年时间,中间随着新模型的不断出现,也让大家认识到了AI的强大之处,现在AI已经渗透到我们生活,工作的方方面面。 这期间国产大模型也在努力发展,不断完善,甚至一些大模型在…

【unity笔记】五、UI面板TextMeshPro 添加中文字体

Unity 中 TextMeshPro不支持中文字体,下面为解决方法: 准备字体文件,从Windows系统文件的Fonts文件夹里拖一个.ttf文件(C盘 > Windows > Fonts ) 准备字库文件,新建一个文本文件,命名为“字库”&…

【计算机毕业设计】079基于微信小程序网上商城设计

🙊作者简介:拥有多年开发工作经验,分享技术代码帮助学生学习,独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。🌹赠送计算机毕业设计600个选题excel文件,帮助大学选题。赠送开题报告模板&#xff…

自研网关架构设计

网关项目 1. 了解网关网关横向对比为什么自研网关 2. 架构设计技术栈技术要点异步化设计使用缓存缓冲合理使用串行化吞吐量为王合适的工作线程 架构图 1. 了解网关 概念 访问数据、业务逻辑或功能的 “前门”负责处理接受和处理调用过程中的所有任务 类型 RESTful APl 使用…