Training for Stable Diffusion
笔记来源:
1.Denoising Diffusion Probabilistic Models
2.最大似然估计(Maximum likelihood estimation)
3.Understanding Maximum Likelihood Estimation
4.How to Solve ‘CUDA out of memory’ in PyTorch
1.1 Introduction
训练过程也就是正向扩散过程(Forward Diffusion Process),即为训练集中每个epoch中的每张照片进行加噪,根据所有加噪照片计算一个概率分布 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt−1∣xt,x0)(续上一篇关于DDPM的博客),至于为什么要计算这个分布 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt−1∣xt,x0),简要来说是此分布作为了反向扩散过程 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt−1∣xt) 的 ground truth 从而进行MSE,相当于对反向扩散过程进行了一个引导。
1.2 Loss Function
1.2.1 Maximum Likelihood Estimation (MLE)
概率(Probability)与似然(Likelihood)
概率是在特定环境下某件事情发生的可能性,也就是结果没有产生之前依据环境所对应的参数来预测某件事情发生的可能性,比如抛硬币,抛之前我们不知道最后是哪一面朝上,但是根据硬币的性质我们可以推测任何一面朝上的可能性均为50%,这个概率只有在抛硬币之前才是有意义的,抛完硬币后的结果便是确定的;
P ( x ∣ θ ) P(x|\theta) P(x∣θ) 在已知参数 θ \theta θ的情况下,得到结果 x x x的概率
概率描述的是在一定条件下某个结果发生的可能性,概率越大说明这件事情越可能会发生
似然刚好相反,是在确定的结果下去推测产生这个结果的可能环境(参数),还是抛硬币的例子,假设我们随机抛掷一枚硬币1,000次,结果500次人头朝上,500次数字朝上(实际情况一般不会这么理想,这里只是举个例子),我们很容易判断这是一枚标准的硬币,两面朝上的概率均为50%,这个过程就是我们根据结果来判断这个事情本身的性质(参数),也就是似然。
L ( θ ∣ x ) \mathcal{L}(\theta|x) L(θ∣x) 在已知结果 x x x的情况下,得到参数 θ \theta θ的概率
似然描述的是结果已知的情况下,该结果在不同条件下发生的可能性,似然函数的值越大说明该结果在对应的条件下发生的可能性越大
概率和似然在数值上相等, P ( x ∣ θ ) P(x|\theta) P(x∣θ)= L ( θ ∣ x ) \mathcal{L}(\theta|x) L(θ∣x),但意义不同,得知参数 θ \theta θ和结果 x x x的顺序不同。
L ( θ ∣ x ) \mathcal{L}(\theta|x) L(θ∣x)是关于 θ \theta θ的函数, P ( x ∣ θ ) P(x|\theta) P(x∣θ)是关于 x x x的函数,两者从不同角度描述了同一件事情
似然函数(Likelihood Function)
The likelihood function helps us find the best parameters for our distribution.
L ( θ ∣ x 1 , x 2 , ⋯ , x n ) = f ( x 1 , x 2 , ⋯ , x n ∣ θ ) = ∏ i = 1 n f ( x i ∣ θ ) \mathcal{L}(\theta|x_1,x_2,\cdots,x_n)=f(x_1,x_2,\cdots,x_n|\theta)=\prod_{i=1}^{n}f(x_i|\theta) L(θ∣x1,x2,⋯,xn)=f(x1,x2,⋯,xn∣θ)=i=1∏nf(xi∣θ)
where θ \theta θ is the parameter to maximize
x 1 , x 2 , ⋯ , x n x_1,x_2,\cdots,x_n x1,x2,⋯,xn are observations for n n n random variables from a distribution
f f f is the joint density function of our distribution with the parameter θ \theta θ
For example, in the case of a normal distribution, we could have θ = ( μ , σ ) \theta=(\mu,\sigma) θ=(μ,σ)
L ( θ ∣ x 1 , x 2 , ⋯ , x n ) \mathcal{L}(\theta|x_1,x_2,\cdots,x_n) L(θ∣x1,x2,⋯,xn) 不是概率密度函数,这意味着在特定区间上进行积分不会产生该区间上的“概率”。相反,它讨论的是具有特定参数值 θ \theta θ的分布适合我们的数据的可能性。
the variance tells about how much the blue intensities in the image vary or deviate from the average blue intensity (0.8).
极大似然估计 (Maximum Likelihood Estimation)
最大似然估计(简称 MLE)是估计分布参数的过程,该过程最大化观测数据属于该分布的可能性。 简而言之,当我们执行 MLE 时,我们试图找到最适合我们数据的分布。分布参数的结果值称为最大似然估计。
1.2.2 Image and Probability Distribution
RGB图片各通道的值范围为:[0, 255]
我们将各通道的通过(
R
/
255
,
G
/
255
,
B
/
255
R/255,G/255,B/255
R/255,G/255,B/255)归一化到范围:[0, 1]
图片单个通道的概率分布(1D Gaussian)
图片两个通道的概率分布(2D Gaussian)
μ
=
[
μ
x
1
,
μ
x
2
]
=
[
μ
b
l
u
e
,
μ
g
r
e
e
n
]
\bf{\mu}=[\mu_{x_1},\mu_{x_2}]=[\mu_{blue},\mu_{green}]
μ=[μx1,μx2]=[μblue,μgreen]
Σ = [ σ x 1 2 σ x 1 , x 2 σ x 2 , x 1 σ x 2 2 ] = [ σ b l u e 2 σ b l u e , g r e e n σ g r e e n , b l u e σ g r e e n 2 ] \Sigma=\begin{bmatrix} \sigma_{x_1}^2 & \sigma_{x_1,x_2}\\ \sigma_{x_2,x_1} & \sigma_{x_2}^2 \end{bmatrix}=\begin{bmatrix} \sigma_{blue}^2 & \sigma_{blue,green}\\ \sigma_{green,blue} & \sigma_{green}^2 \end{bmatrix} Σ=[σx12σx2,x1σx1,x2σx22]=[σblue2σgreen,blueσblue,greenσgreen2]
图片三个通道的概率分布(3D Gaussian)
μ
=
[
μ
x
,
μ
y
,
μ
z
]
=
[
μ
r
e
d
,
μ
g
r
e
e
n
,
μ
b
l
u
e
]
\bf{\mu}=[\mu_{x},\mu_{y},\mu_{z}]=[\mu_{red},\mu_{green},\mu_{blue}]
μ=[μx,μy,μz]=[μred,μgreen,μblue]
Σ
=
[
σ
x
2
σ
x
y
σ
x
z
σ
y
x
σ
y
2
σ
y
z
σ
z
x
σ
z
σ
z
2
]
\Sigma=\begin{bmatrix} \sigma_{x}^2 & \sigma_{xy} & \sigma_{xz}\\ \sigma_{yx} & \sigma_{y}^2 & \sigma_{yz}\\ \sigma_{zx} & \sigma_{z} & \sigma_{z}^2\\ \end{bmatrix}
Σ=
σx2σyxσzxσxyσy2σzσxzσyzσz2
在Stable Diffusion训练过程中我们要给clear image加噪声,则我们需要在三维标准正态分布中进行随机采样,这样采样得到的tensor shape与图片tensor的shape一致
ϵ
∼
N
(
0
,
I
)
\epsilon \sim N(0,I)
ϵ∼N(0,I)
1.2.3 Maximize ELBO (Maximize Evidence Lower Bound)
我们想要收集大量样本数据,使得这些数据的分布尽可能的接近真实分布(已知的所有图片数据的分布)
通过最大化样本概率(极大化似然)使得样本数据的分布尽可能符合真实分布
第
i
i
i张样本图片的概率分布
p
θ
(
x
i
)
p_{\theta}(x^i)
pθ(xi),将数据集中
m
m
m张照片的分布相乘得到联合概率分布,求该联合分布的极大似然,最终得到一个最优的参数
θ
=
(
μ
,
σ
)
\theta=(\mu,\sigma)
θ=(μ,σ)
目前Stable Diffusion的Unet有三种预测方案:
(1)Unet 直接预测
x
0
x_0
x0,但是效果不好
(2)Unet 预测要去掉的噪声分布(本次训练使用这种方案)
(3)Unet 预测分数
1.3 Training (from DDPM thesis)
batch size, iteration, and epoch
一个数据集由一个epoch组成,一个数据集训练n遍(n个epoch),也就是说一个周期(epoch)包含了数据集的所有数据
一个epoch由多个batch组成,一个batch由多张image组成
完整训练代码
import os.path
import torch
import torch.nn as nn
import torch.optim as optim
from ddpm import DDPMSampler
from diffusion import UNET, Diffusion
import logging
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from pipeline import get_time_embedding
from create_dataset import train_loader
import logging
'''
Algorithm Training
1:repeat
2: x_0 ~ q(x_0)
# sample a batch from a epoch
# for epoch for batch for every image tensor
train_loader
3: t ~ Uniform({1...T})
# sample randomly a t for every image tensor
# t: num_inference_step
# T: num_training_step
t = diffusion.sample_timesteps(images.shape[0]).to(device)
4: epsilon ~ N(0,I)
# 3d standard normal distribution
# noise tensor shape that sample from this distribution,which is same as image tensor shape
noisy_image_tensor = add_noise(t)
5: Take gradient descent step on
# nabla_{theta} L2(|| epsilon - epsilon_{theta}(noisy image tensor,t,y)||)
6: until converged
'''
'''
1.Data Preprocessing
(1) Loading and Transforming Data: Data is loaded from the dataset and transformed to a suitable format for training.
Common transformations include resizing, normalization, and converting to tensors.
(2) Creating Data Loaders: Data loaders are created to efficiently load the data in batches, shuffle the training data,
and manage parallel processing.
2.Model Initialization
(1) Define the UNet Model: The UNet architecture is defined, which typically consists of an encoder-decoder structure
with skip connections. The encoder captures context while the decoder enables precise localization.
(2) Move Model to Device: The model is moved to the appropriate device (CPU or GPU) to leverage hardware acceleration.
3.Loss Function and Optimizer
(1) Loss Function: The loss function measures the difference between the predicted output and the true output.
(2) Optimizer: The optimizer updates the model parameters to minimize the loss. Common optimizers include Adam,SGD,etc.
4.Training Loop
(1) Set Model to Training Mode: The model is set to training mode using model.train().
(2) Iterate Over Data: For each epoch, iterate over batches of data.
Forward Pass: Pass input data through the model to get predictions.
A random time step t will be selected for each training sample (image)
Apply the Gaussian noise (corresponding to t) to each image
Convert the time steps to embeddings (vector)
Compute Loss: Calculate the loss using the predictions and ground truth.
Backward Pass: Perform backpropagation to compute gradients.
Update Parameters: Use the optimizer to update model parameters based on the gradients.
(3) Monitor Training: Track and print training loss to monitor progress.
5.Validation
After each epoch, validate the model using a separate validation set to ensure the model is not overfitting and
to monitor its generalization performance.
6.Checkpoint Saving
Save Model Checkpoint: Save the model's state, optimizer state, and any relevant training information after each epoch
to allow for resuming training if needed.
'''
# A PyTorch random number generator.
generator = torch.Generator(device='cuda')
# Sets the seed for generating random numbers. Returns a torch. Generator object.
generator.manual_seed(42)
# Initialize the DDPMSampler with the random generator
ddpm_sampler = DDPMSampler(generator)
diffusion = Diffusion()
def timesteps_to_time_emb(timesteps):
time_embeddings = []
for i, timestep in enumerate(timesteps):
# (1,320)
time_emb_320 = get_time_embedding(timestep).to('cuda')
embedding = diffusion.time_embedding.to('cuda')
time_embedding = embedding(time_emb_320).squeeze(0) # Ensure shape is (1280)
# (1,1280)
time_embeddings.append(time_embedding)
return torch.stack(time_embeddings) # Final shape should be (batch_size, 1280)
print('Start training now !')
def train(args):
device = args.device # Get the device to run the training on
model = UNET().to(device) # Initialize the model and move it to the device
model.train()
optimizer = optim.AdamW(model.parameters(), lr=args.lr) # set up the optimizer with AdamW
mse = nn.MSELoss() # Mean Squared Error loss function
logger = SummaryWriter(os.path.join("runs", args.run_name))
len_train = len(train_loader)
print('Start into the loop !')
for epoch in range(args.epochs):
logging.info(f"Starting epoch {epoch}:") # log the start of the epoch
progress_bar = tqdm(train_loader) # progress bar for the dataloader
optimizer.zero_grad() # Explicitly zero the gradient buffers
accumulation_steps = 4
# Load all data into a batch
for batch_idx, (images, captions) in enumerate(progress_bar):
images = images.to(device) # move images to the device
# The dataloaer will add a batch size dimension to the tensor, but I've already added batch size to the VAE
# and CLIP input, so we're going to remove a batch size and just keep the batch size of the dataloader
images = torch.squeeze(images, dim=1)
captions = captions.to(device) # move caption to the device
text_embeddings = torch.squeeze(captions, dim=1) # squeeze batch_size
timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device) # Sample random timesteps
noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps) # Add noise to the images
time_embeddings = timesteps_to_time_emb(timesteps)
# x_t (batch_size, channel, Height/8, Width/8) (bs,4,256/8,256/8)
# caption (batch_size, seq_len, dim) (bs, 77, 768)
# t (batch_size, channel) (batch_size, 1280)
# (bs,320,H/8,W/8)
with torch.no_grad():
last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
# (bs,4,H/8,W/8)
final_output = diffusion.final.to(device)
predicted_noise = final_output(last_decoder_noise).to(device)
loss = mse(noises, predicted_noise) # Compute the loss
loss.backward() # Backpropagate the loss
if (batch_idx + 1) % accumulation_steps == 0: # Wait for several backward passes
optimizer.step() # Now we can do an optimizer step
optimizer.zero_grad() # Reset gradients to zero
progress_bar.set_postfix(MSE=loss.item()) # Update the progress bar with the loss
# log the loss to TensorBoard
logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)
# Save the model checkpoint
os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
torch.save(model.state_dict(), os.path.join("models", args.run_name, f"stable_diffusion.ckpt"))
torch.save(optimizer.state_dict(),
os.path.join("models", args.run_name, f"optim.pt")) # Save the optimizer state
def launch():
import argparse # Import the argparse module for command-line argument parsing
parser = argparse.ArgumentParser() # Create an argument parser
args = parser.parse_args() # Parse the command-line arguments
# Set the default values for the arguments
args.run_name = " Condition_Unet" # Name for the run, used for logging and saving models
args.epochs = 40 # Number of epochs to train the model
args.batch_size = 10 # Batch size for the dataloader
args.image_size = 256 # Size of the images
args.device = "cuda" # Device to run the training on ('cuda' for GPU or 'cpu')
args.lr = 3e-4 # Learning rate for the optimizer
train(args) # Call the train function with the parsed arguments
if __name__ == '__main__':
launch() # Call the launch function if this script is run as the main program