Training for Stable Diffusion

news2024/9/20 19:38:56

Training for Stable Diffusion

笔记来源:
1.Denoising Diffusion Probabilistic Models
2.最大似然估计(Maximum likelihood estimation)
3.Understanding Maximum Likelihood Estimation
4.How to Solve ‘CUDA out of memory’ in PyTorch

1.1 Introduction

训练过程也就是正向扩散过程(Forward Diffusion Process),即为训练集中每个epoch中的每张照片进行加噪,根据所有加噪照片计算一个概率分布 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0)(续上一篇关于DDPM的博客),至于为什么要计算这个分布 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0),简要来说是此分布作为了反向扩散过程 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt1xt) 的 ground truth 从而进行MSE,相当于对反向扩散过程进行了一个引导。

1.2 Loss Function

1.2.1 Maximum Likelihood Estimation (MLE)

概率(Probability)与似然(Likelihood)

概率在特定环境下某件事情发生的可能性,也就是结果没有产生之前依据环境所对应的参数来预测某件事情发生的可能性,比如抛硬币,抛之前我们不知道最后是哪一面朝上,但是根据硬币的性质我们可以推测任何一面朝上的可能性均为50%,这个概率只有在抛硬币之前才是有意义的,抛完硬币后的结果便是确定的;
P ( x ∣ θ ) P(x|\theta) P(xθ) 在已知参数 θ \theta θ的情况下,得到结果 x x x的概率

概率描述的是在一定条件下某个结果发生的可能性,概率越大说明这件事情越可能会发生

似然刚好相反,是在确定的结果下去推测产生这个结果的可能环境(参数),还是抛硬币的例子,假设我们随机抛掷一枚硬币1,000次,结果500次人头朝上,500次数字朝上(实际情况一般不会这么理想,这里只是举个例子),我们很容易判断这是一枚标准的硬币,两面朝上的概率均为50%,这个过程就是我们根据结果来判断这个事情本身的性质(参数),也就是似然。
L ( θ ∣ x ) \mathcal{L}(\theta|x) L(θx) 在已知结果 x x x的情况下,得到参数 θ \theta θ的概率

似然描述的是结果已知的情况下,该结果在不同条件下发生的可能性,似然函数的值越大说明该结果在对应的条件下发生的可能性越大

概率和似然在数值上相等, P ( x ∣ θ ) P(x|\theta) P(xθ)= L ( θ ∣ x ) \mathcal{L}(\theta|x) L(θx),但意义不同,得知参数 θ \theta θ和结果 x x x的顺序不同。
L ( θ ∣ x ) \mathcal{L}(\theta|x) L(θx)是关于 θ \theta θ的函数, P ( x ∣ θ ) P(x|\theta) P(xθ)是关于 x x x的函数,两者从不同角度描述了同一件事情

似然函数(Likelihood Function)

The likelihood function helps us find the best parameters for our distribution.
L ( θ ∣ x 1 , x 2 , ⋯   , x n ) = f ( x 1 , x 2 , ⋯   , x n ∣ θ ) = ∏ i = 1 n f ( x i ∣ θ ) \mathcal{L}(\theta|x_1,x_2,\cdots,x_n)=f(x_1,x_2,\cdots,x_n|\theta)=\prod_{i=1}^{n}f(x_i|\theta) L(θx1,x2,,xn)=f(x1,x2,,xnθ)=i=1nf(xiθ)
where θ \theta θ is the parameter to maximize
x 1 , x 2 , ⋯   , x n x_1,x_2,\cdots,x_n x1,x2,,xn are observations for n n n random variables from a distribution
f f f is the joint density function of our distribution with the parameter θ \theta θ
For example, in the case of a normal distribution, we could have θ = ( μ , σ ) \theta=(\mu,\sigma) θ=(μ,σ)
L ( θ ∣ x 1 , x 2 , ⋯   , x n ) \mathcal{L}(\theta|x_1,x_2,\cdots,x_n) L(θx1,x2,,xn) 不是概率密度函数,这意味着在特定区间上进行积分不会产生该区间上的“概率”。相反,它讨论的是具有特定参数值 θ \theta θ的分布适合我们的数据的可能性
the variance tells about how much the blue intensities in the image vary or deviate from the average blue intensity (0.8).

极大似然估计 (Maximum Likelihood Estimation)

最大似然估计(简称 MLE)是估计分布参数的过程,该过程最大化观测数据属于该分布的可能性。 简而言之,当我们执行 MLE 时,我们试图找到最适合我们数据的分布。分布参数的结果值称为最大似然估计。

1.2.2 Image and Probability Distribution

RGB图片各通道的值范围为:[0, 255]
我们将各通道的通过( R / 255 , G / 255 , B / 255 R/255,G/255,B/255 R/255,G/255,B/255)归一化到范围:[0, 1]
图片单个通道的概率分布(1D Gaussian)
图片两个通道的概率分布(2D Gaussian)

μ = [ μ x 1 , μ x 2 ] = [ μ b l u e , μ g r e e n ] \bf{\mu}=[\mu_{x_1},\mu_{x_2}]=[\mu_{blue},\mu_{green}] μ=[μx1,μx2]=[μblue,μgreen]

Σ = [ σ x 1 2 σ x 1 , x 2 σ x 2 , x 1 σ x 2 2 ] = [ σ b l u e 2 σ b l u e , g r e e n σ g r e e n , b l u e σ g r e e n 2 ] \Sigma=\begin{bmatrix} \sigma_{x_1}^2 & \sigma_{x_1,x_2}\\ \sigma_{x_2,x_1} & \sigma_{x_2}^2 \end{bmatrix}=\begin{bmatrix} \sigma_{blue}^2 & \sigma_{blue,green}\\ \sigma_{green,blue} & \sigma_{green}^2 \end{bmatrix} Σ=[σx12σx2,x1σx1,x2σx22]=[σblue2σgreen,blueσblue,greenσgreen2]

图片三个通道的概率分布(3D Gaussian)


μ = [ μ x , μ y , μ z ] = [ μ r e d , μ g r e e n , μ b l u e ] \bf{\mu}=[\mu_{x},\mu_{y},\mu_{z}]=[\mu_{red},\mu_{green},\mu_{blue}] μ=[μx,μy,μz]=[μred,μgreen,μblue]

Σ = [ σ x 2 σ x y σ x z σ y x σ y 2 σ y z σ z x σ z σ z 2 ] \Sigma=\begin{bmatrix} \sigma_{x}^2 & \sigma_{xy} & \sigma_{xz}\\ \sigma_{yx} & \sigma_{y}^2 & \sigma_{yz}\\ \sigma_{zx} & \sigma_{z} & \sigma_{z}^2\\ \end{bmatrix} Σ= σx2σyxσzxσxyσy2σzσxzσyzσz2
在Stable Diffusion训练过程中我们要给clear image加噪声,则我们需要在三维标准正态分布中进行随机采样,这样采样得到的tensor shape与图片tensor的shape一致
ϵ ∼ N ( 0 , I ) \epsilon \sim N(0,I) ϵN(0,I)

1.2.3 Maximize ELBO (Maximize Evidence Lower Bound)

我们想要收集大量样本数据,使得这些数据的分布尽可能的接近真实分布(已知的所有图片数据的分布)

通过最大化样本概率(极大化似然)使得样本数据的分布尽可能符合真实分布
i i i张样本图片的概率分布 p θ ( x i ) p_{\theta}(x^i) pθ(xi),将数据集中 m m m张照片的分布相乘得到联合概率分布,求该联合分布的极大似然,最终得到一个最优的参数 θ = ( μ , σ ) \theta=(\mu,\sigma) θ=(μ,σ)






目前Stable Diffusion的Unet有三种预测方案:
(1)Unet 直接预测 x 0 x_0 x0,但是效果不好
(2)Unet 预测要去掉的噪声分布(本次训练使用这种方案)

(3)Unet 预测分数

1.3 Training (from DDPM thesis)



batch size, iteration, and epoch

一个数据集由一个epoch组成,一个数据集训练n遍(n个epoch),也就是说一个周期(epoch)包含了数据集的所有数据
一个epoch由多个batch组成,一个batch由多张image组成


完整训练代码

import os.path
import torch
import torch.nn as nn
import torch.optim as optim
from ddpm import DDPMSampler
from diffusion import UNET, Diffusion
import logging
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from pipeline import get_time_embedding
from create_dataset import train_loader
import logging

'''
Algorithm Training
1:repeat
2: x_0 ~ q(x_0) 
# sample a batch from a epoch
# for epoch for batch for every image tensor
train_loader
3: t ~ Uniform({1...T})
# sample randomly a t for every image tensor
# t: num_inference_step
# T: num_training_step
t = diffusion.sample_timesteps(images.shape[0]).to(device)
4: epsilon ~ N(0,I) 
# 3d standard normal distribution
# noise tensor shape that sample from this distribution,which is same as image tensor shape
noisy_image_tensor = add_noise(t)
5: Take gradient descent step on 
# nabla_{theta} L2(|| epsilon - epsilon_{theta}(noisy image tensor,t,y)||)
6: until converged
'''

'''
1.Data Preprocessing
(1) Loading and Transforming Data: Data is loaded from the dataset and transformed to a suitable format for training. 
Common transformations include resizing, normalization, and converting to tensors.
(2) Creating Data Loaders: Data loaders are created to efficiently load the data in batches, shuffle the training data, 
and manage parallel processing.
2.Model Initialization
(1) Define the UNet Model: The UNet architecture is defined, which typically consists of an encoder-decoder structure 
with skip connections. The encoder captures context while the decoder enables precise localization.
(2) Move Model to Device: The model is moved to the appropriate device (CPU or GPU) to leverage hardware acceleration.
3.Loss Function and Optimizer
(1) Loss Function: The loss function measures the difference between the predicted output and the true output. 
(2) Optimizer: The optimizer updates the model parameters to minimize the loss. Common optimizers include Adam,SGD,etc.
4.Training Loop
(1) Set Model to Training Mode: The model is set to training mode using model.train().
(2) Iterate Over Data: For each epoch, iterate over batches of data.
    Forward Pass: Pass input data through the model to get predictions.
        A random time step t will be selected for each training sample (image)
        Apply the Gaussian noise (corresponding to t) to each image
        Convert the time steps to embeddings (vector)
    Compute Loss: Calculate the loss using the predictions and ground truth.
    Backward Pass: Perform backpropagation to compute gradients.
    Update Parameters: Use the optimizer to update model parameters based on the gradients.
(3) Monitor Training: Track and print training loss to monitor progress.
5.Validation
After each epoch, validate the model using a separate validation set to ensure the model is not overfitting and 
to monitor its generalization performance.
6.Checkpoint Saving
Save Model Checkpoint: Save the model's state, optimizer state, and any relevant training information after each epoch 
to allow for resuming training if needed.
'''

# A PyTorch random number generator.
generator = torch.Generator(device='cuda')
# Sets the seed for generating random numbers. Returns a torch. Generator object.
generator.manual_seed(42)
# Initialize the DDPMSampler with the random generator
ddpm_sampler = DDPMSampler(generator)


diffusion = Diffusion()


def timesteps_to_time_emb(timesteps):
    time_embeddings = []
    for i, timestep in enumerate(timesteps):
        # (1,320)
        time_emb_320 = get_time_embedding(timestep).to('cuda')
        embedding = diffusion.time_embedding.to('cuda')
        time_embedding = embedding(time_emb_320).squeeze(0)  # Ensure shape is (1280)
        # (1,1280)
        time_embeddings.append(time_embedding)
    return torch.stack(time_embeddings)  # Final shape should be (batch_size, 1280)


print('Start training now !')


def train(args):
    device = args.device  # Get the device to run the training on
    model = UNET().to(device)   # Initialize the model and move it to the device
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)  # set up the optimizer with AdamW
    mse = nn.MSELoss()  # Mean Squared Error loss function
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    len_train = len(train_loader)
    print('Start into the loop !')
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")  # log the start of the epoch
        progress_bar = tqdm(train_loader)  # progress bar for the dataloader
        optimizer.zero_grad()  # Explicitly zero the gradient buffers
        accumulation_steps = 4
        # Load all data into a batch
        for batch_idx, (images, captions) in enumerate(progress_bar):
            images = images.to(device)  # move images to the device
            # The dataloaer will add a batch size dimension to the tensor, but I've already added batch size to the VAE
            # and CLIP input, so we're going to remove a batch size and just keep the batch size of the dataloader
            images = torch.squeeze(images, dim=1)
            captions = captions.to(device)  # move caption to the device
            text_embeddings = torch.squeeze(captions, dim=1) # squeeze batch_size
            timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)  # Sample random timesteps
            noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)  # Add noise to the images
            time_embeddings = timesteps_to_time_emb(timesteps)
            # x_t (batch_size, channel, Height/8, Width/8) (bs,4,256/8,256/8)
            # caption (batch_size, seq_len, dim) (bs, 77, 768)
            # t (batch_size, channel) (batch_size, 1280)
            # (bs,320,H/8,W/8)
            with torch.no_grad():
                last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
            # (bs,4,H/8,W/8)
            final_output = diffusion.final.to(device)
            predicted_noise = final_output(last_decoder_noise).to(device)
            loss = mse(noises, predicted_noise)  # Compute the loss
            loss.backward()  # Backpropagate the loss
            if (batch_idx + 1) % accumulation_steps == 0:  # Wait for several backward passes
                optimizer.step()  # Now we can do an optimizer step
                optimizer.zero_grad()  # Reset gradients to zero
            progress_bar.set_postfix(MSE=loss.item())  # Update the progress bar with the loss
            # log the loss to TensorBoard
            logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)
        # Save the model checkpoint
        os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", args.run_name, f"stable_diffusion.ckpt"))
        torch.save(optimizer.state_dict(),
                   os.path.join("models", args.run_name, f"optim.pt"))  # Save the optimizer state


def launch():
    import argparse  # Import the argparse module for command-line argument parsing
    parser = argparse.ArgumentParser()  # Create an argument parser
    args = parser.parse_args()  # Parse the command-line arguments

    # Set the default values for the arguments
    args.run_name = " Condition_Unet"  # Name for the run, used for logging and saving models
    args.epochs = 40      # Number of epochs to train the model
    args.batch_size = 10  # Batch size for the dataloader
    args.image_size = 256  # Size of the images
    args.device = "cuda"  # Device to run the training on ('cuda' for GPU or 'cpu')
    args.lr = 3e-4  # Learning rate for the optimizer

    train(args)  # Call the train function with the parsed arguments


if __name__ == '__main__':
    launch()  # Call the launch function if this script is run as the main program

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1937068.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FastBee物联网开源项目本地启动调试

一、本地环境准备 (1)Visual Studio Code(启动前端项目) (2)IntelliJ IDEA Community Edition (启动后端项目) (3)Navicat或者DBeaver(用来操…

Godot学习笔记2——GDScript变量与函数

目录 一、代码编写界面 二、变量 三、函数 四、变量的类型 Godot使用的编程语言是GDS,语法上与python有些类似。 一、代码编写界面 在新建的Godot项目中,点击“创建根节点”中的“其他节点”,选择“Node”。 点击场景界面右上角的绿色…

【STM32】按键控制LED光敏传感器控制蜂鸣器(江科大)

一、按键控制LED LED.c #include "stm32f10x.h" // Device header/*** 函 数:LED初始化* 参 数:无* 返 回 值:无*/ void LED_Init(void) {/*开启时钟*/RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOA, ENAB…

此扩展在此工作区中被禁用,因为其被定义为在远程扩展主机中运行。

使用VScode打开代码时,无法跳转函数,不提示报错。 安装python时显示, 此扩展在此工作区中被禁用,因为其被定义为在远程扩展主机中运行。 解决方法: CtrlShiftP :键入trust ,工作区&#xff…

空间计算新时代:Vision Pro引领AR/VR/MR市场变革

随着2024年第二季度的结束,空间计算领域的市场动态愈发引人关注。根据国际数据公司(IDC)的最新报告,我们见证了行业格局的重大变化,尤其是苹果Vision Pro的突出表现,以及AR/VR/MR设备市场的整体趋势。以下是…

LabVIEW软件开发的雷区在哪里?

在LabVIEW软件开发中,有几个需要注意的雷区,以避免常见的错误和提高开发效率: 1. 不良的代码结构 雷区:混乱的代码结构和不清晰的程序逻辑。 后果:导致难以维护和调试的代码,增加了错误和故障的风险。 …

无人机侦察:二维机扫雷达探测设备技术详解

二维机扫雷达探测设备采用机械扫描方式,通过天线在水平方向和垂直方向上的转动,实现对目标空域的全方位扫描。雷达发射机发射电磁波信号,遇到目标后产生反射,反射信号被雷达接收机接收并处理,进而得到目标的位置、速度…

搜维尔科技:【研究】动作捕捉加速游戏开发行业的发展

动作捕捉加速游戏开发行业的发展 Sunjata 的故事始于 2004 年,它将席卷乌干达视频游戏行业,然后席卷全世界。但首先,Klan Of The Kings 的小团队需要工具来实现他们的愿景。 漫画家兼非洲民间传说爱好者罗纳德卡伊马 (Ronald Kayima) 在将…

怎样在 PostgreSQL 中进行用户权限的精细管理?

🍅关注博主🎗️ 带你畅游技术世界,不错过每一次成长机会!📚领书:PostgreSQL 入门到精通.pdf 文章目录 怎样在 PostgreSQL 中进行用户权限的精细管理?一、权限管理的重要性二、PostgreSQL 中的权…

[解决方法]Request failed with status code 500错误之一

在写项目时访问后端api时我的axios拦截器进入了错误 然后去浏览器搜索,但是大部分都是因为axios参数或参数格式问题导致的,然而在访问api的编写没有任何问题,后来我反复检查,发现是我写前后端写混了,我把express的 Co…

学习大数据DAY20 Linux环境配置与Linux基本指令

目录 Linux 介绍 Linux 发行版 Linux 和 Windows 比较 Linux 就业方向: 下载 CentOS Linux 目录树 Linux 目录结构 作业 1 常用命令分类 文件目录类 作业 2 vim 编辑文件 作业 3 你问我第 19 天去哪了?第 19 天在汇报第一阶段的知识总结,没什…

深入浅出WebRTC—GCC

GoogCcNetworkController 是 GCC 的控制中心,它由 RtpTransportControllerSend 通过定时器和 TransportFeedback 来驱动。GoogCcNetworkController 不断更新内部各个组件的状态,并协调组件之间相互配合,向外输出目标码率等重要参数&#xff0…

汽车及零部件研发项目管理系统:一汽东机工选择奥博思 PowerProject 提升研发项目管理效率

在汽车行业中,汽车零部件的研发和生产是一个关键的环节。随着汽车市场的不断扩大和消费者需求的不断增加,汽车零部件项目管理的重要性日益凸显。通过有效的项目管理方法及利用先进的数字项目管理系统,可以大幅提高项目的成功率和顺利度&#…

WebRTC QOS方法十三.1(TimestampExtrapolator接收时间预估)

一、背景介绍 虽然我们可通过时间戳的差值和采样率计算出发送端视频帧的发送节奏,但是由于网络延迟、抖动、丢包,仅知道视频发送端的发送节奏是明显不够的。我们还需要评估出视频接收端的视频帧的接收节奏,然后进行适当平滑,保证…

关于 Qt输入法在arm特定的某些weston下出现调用崩溃 的解决方法

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/140423667 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…

C#知识|账号管理系统-修改账号按钮功能的实现

哈喽,你好啊,我是雷工! 前边学习了通过选择条件查询账号的功能: 《提交查询按钮事件的编写》 本节继续学习练习C#,今天练习修改账号的功能实现。 以下为学习笔记。 01 实现功能 ①:从查询到的账号中,选择某一账号,然后点击【修改账号】按钮,将选中的信息获取显示到…

攻防世界 re新手模式

Reversing-x64Elf-100 64位ida打开 看if语句,根据i的不同,选择不同的数组,后面的2*i/3选择数组中的某一个元素,我们输入的是a1 直接逆向得到就行 二维字符数组写法:前一个是代表有几个字符串,后一个是每…

《蔚蓝档案》模拟器联动皮肤H5+KOC

《蔚蓝档案》模拟器联动皮肤H5KOC 《蔚蓝档案》自上线以来老师们与MuMu模拟器的共同历程,重温难忘瞬间,回忆游戏历程。蔚蓝档案一周年模拟器联动主题皮肤福利,于7月18日-8月16日,在MuMu模拟器搜索【蔚蓝档案联动】进入活动页面&a…

离散数学,半群性质的证明,群,群的性质,子群

目录 1.半群性质的证明 半群的性质 定理5-3.2证明 定理5-3.3证明 半群的性质 定理5-3.4证明 例子 2.群 群是每个元素都可逆的独异点 例子 有限群,阶数,无限群,平凡群 3.群的性质 群中不可能有零元 群中任一元素逆元…

Paypal个人支付申请及沙箱测试配置

目录 一. 申请paypal账号二. Sanbox 测试配置申请买家Account申请卖家AccountSandbox的Client ID及密钥申请Live的Client ID及密钥申请IPN回调设置 一. 申请paypal账号 浏览器输入https://www.paypal.com, 单击注册按钮 2. 我这里申请个人账户,如果你需要企业账户&…