CMU 10423 Generative AI:lec7、8、9(专题2:一张图理解diffusion model结构、代码实现和效果)

news2024/9/27 21:27:38

本文介绍diffusion model是什么(包括:模型详细的架构图、各模块原理和输入输出、训练算法解读、推理算法解读)、以及全套demo代码和效果。至于为什么要这么设计、以及公式背后的数学原理,过程推导很长很长,可见参考资料。

文章目录

  • 1 Diffusion Model概述
  • 2 一张图看懂diffusion model结构
  • 3 Diffusion Model训练和推理算法解释
    • 3.1 Diffusion Model的训练算法解释
    • 3.2 Diffusion Model的推理算法解释
  • 4 diffusion model 代码实现
    • 数据集
    • **噪声调度与预计算**
        • 2\. 重新参数化公式
        • 3\. 推理阶段公式
        • 4\. 代码实现与预计算
        • **4. 解释代码中的每个步骤**
    • 前向扩散过程(添加噪声)与反向采样过程(去除噪声)
        • 1\. 前向扩散过程:`forward_diffusion_viz`
        • 2\. 反向采样过程:`make_inference`
      • 前向扩散过程可视化效果
    • 开始训练
    • 效果展示

1 Diffusion Model概述

Diffusion Model严格意义上最早源于2015年的《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》,但如下这篇论文才真正将Diffusion Model效果发扬光大,有点类似2013年的alexnet网络和1998年的lenet-5网络感觉。

全称:《Denoising Diffusion Probabilistic Models》
时间:2020年
作者人数:3人,加州伯克利大学
论文地址:https://proceedings.neurips.cc/paper_files/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf
优缺点优点:生成图像的效果非常惊艳,超越VAE、生成式对抗网络等方法,SOTA级别。
缺点:生成图像的速度非常缓慢。

一段话总结DDPM算法:

  1. 目的:训练一个生成模型,以生成符合训练集分布的随机图像。
  2. 核心思想:灵感来自非平衡热力学,扩散模型是通过添加高斯噪声来破坏训练数据(称为正向扩散过程),然后学习如何通过逐步逆转此噪声过程来恢复原始信息(称为反向扩散过程)。
  3. 训练过程:输入一张参入噪声的图像,使用U-Net架构模型,使其学会预测在特定时步 $ t $​ 时对输入图像中添加的噪声。
  4. 生成过程:从一个高斯分布的纯噪声图像开始,输入模型进行噪声预测,然后逐步去除噪声,重复这一过程 T T T 次,以获得一张清晰的全新图像。

在这里插入图片描述

至于为什么这种方式能生效,其背后的数学原理 可见参考资料。

参考资料

  • 李宏毅的《机器学习》——Diffusion Model 原理剖析 (大概1.5小时)

    • 课程网址:https://speech.ee.ntu.edu.tw/~hylee/ml/2023-spring.php
    • B站转载视频:https://www.bilibili.com/video/BV14c411J7f2
  • 《Diffusion Model原理详解及源码解析》

    • https://juejin.cn/post/7210175991837507621
  • 《扩散模型是如何工作的:从零开始的数学原理》

    • https://shao.fun/blog/w/how-diffusion-models-work.html
  • 《What are Diffusion Models?》OpenAI 安全系统团队负责人翁丽莲(Lilian Weng 出品,必是精品)
    a. https://lilianweng.github.io/posts/2021-07-11-diffusion-models/

  • 《an-introduction-to-diffusion-models-and-stable-diffusion》

    • https://blog.marvik.ai/2023/11/28/an-introduction-to-diffusion-models-and-stable-diffusion/

2 一张图看懂diffusion model结构

在这里插入图片描述

3 Diffusion Model训练和推理算法解释

3.1 Diffusion Model的训练算法解释

在这里插入图片描述

  1. 第1步:repeat (重复)
  • 这是一个循环,意味着我们会不断重复以下步骤,直到模型收敛。
  1. 第2步:从数据分布​ q ( x 0 ) q(x_0) q(x0)​中采样 x 0 x_0 x0
  • 这里的 x 0 x_0 x0​ 是一张干净的原始图像,表示我们从真实的图像数据分布中采样(即从图像数据集中随机选择一张图片)。
  • q ( x 0 ) q(x_0) q(x0)​ 表示的是真实数据的分布(训练时它也就是个数据集),目标是让模型最终生成的图片和这个分布一致。
  1. 第3步:从均匀分布中采样 ​ t ∼ Uniform ( { 1 , … , T } ) t \sim \text{Uniform}(\{1,\ldots,T\}) tUniform({1,,T})
  • 这里的 t t t​ 是从1到 T T T​ 的均匀分布中采样的一个时间步。这个时间步控制扩散过程的阶段,表示噪声的加噪程度。
  • T T T​ 是扩散过程的总时间步数,表示噪声逐步增加的步骤。
  • 随机采样时间步,这使得模型能够在任何时间步学习逆转扩散过程,从而增强其适应性。
  1. 第4步:从正态分布 ​ ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I) ϵN(0,I)​ 中采样噪声 ϵ \epsilon ϵ
  • 采样一个噪声向量 ϵ \epsilon ϵ​,它是从标准正态分布中生成的。这个噪声表示在每个时间步加入的噪声大小。
  1. 第5步:梯度下降更新
  • 这里我们需要根据误差 ϵ − ϵ θ ( ⋅ ) \epsilon - \epsilon_\theta(\cdot) ϵϵθ()​ 进行梯度下降。
  • 其中 ϵ θ ( ⋅ ) \epsilon_\theta(\cdot) ϵθ()​ 是一个神经网络,负责预测在当前时间步 t t t​ 下的噪声。训练的目标是让模型的预测噪声 ϵ θ \epsilon_\theta ϵθ​ 尽可能接近真实的加噪噪声 ϵ \epsilon ϵ​。
  • 更新的目标是最小化预测噪声与真实噪声之间的差异,公式为:

∇ θ ∥ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∥ 2 \nabla_\theta \| \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}x_0 + \sqrt{1 - \bar{\alpha}_t}\epsilon, t) \|^2 θϵϵθ(αˉt x0+1αˉt ϵ,t)2

其中:

  • α ˉ t \bar{\alpha}_t αˉt​ :是时间步 t t t​ 对应的参数,用来控制在第 t t t​ 步时,图像的真实信息和噪声之间的权重。 α ˉ t \bar{\alpha}_t αˉt​ 是随着时间步 t t t​ 变化的一个量。具体地,噪声调度定义了每个时间步 t t t​ 如何给图像加噪声,常见的噪声调度方法是根据线性或指数衰减来设定一系列 α t \alpha_t αt​,然后通过这些 α t \alpha_t αt​ 累积计算出 α ˉ t \bar{\alpha}_t αˉt​。一般来说,在训练开始时, α ˉ t \bar{\alpha}_t αˉt​ 的值比较大,而随着时间 t t t​ 的增加, α ˉ t \bar{\alpha}_t αˉt​ 的值逐渐减小,这意味着图像中的真实信息减少,而噪声的比例逐渐增加。换句话说, α ˉ t \bar{\alpha}_t αˉt​ 越小,图像中的噪声就越多。到时间步 T T T​ 时,图像几乎完全变成了随机噪声,原始图像信息几乎无法辨认。
  • α ˉ t x 0 \sqrt{\bar{\alpha}_t}x_0 αˉt x0​ :是根据 x 0 x_0 x0​(即干净图像)经过时间步 t t t​ 后加噪的图像。
  • 1 − α ˉ t ϵ \sqrt{1 - \bar{\alpha}_t}\epsilon 1αˉt ϵ​ :是对应噪声在 t t t​ 时间步的成分。
  • ϵ θ \epsilon_\theta ϵθ​ :即模型,需要学会预测噪声。
  1. 第6步:直到收敛(until converged)
  • 当模型收敛时,梯度下降的目标值逐渐变小,意味着模型学会了如何准确地预测噪声,从而可以有效去噪图像。
  1. 备注:当T取不同值时,输入模型的图像含噪幅度示意图如下:
    a. 在这里插入图片描述

训练总结

这个训练过程的核心思想是:通过神经网络 $ \epsilon_\theta $​ 预测扩散过程中给定时间步下的噪声。模型的目标是最小化预测噪声与真实噪声之间的差异,采用均方误差(MSE)损失。

3.2 Diffusion Model的推理算法解释

在这里插入图片描述

  1. x T x_T xT​(左上角的图像)
  • 这个图像最开始是从正态分布 N ( 0 , I ) \mathcal{N}(0, I) N(0,I)​ 中采样的纯噪声图像。推理的任务是逐步从这个纯噪声图像中一步步的去噪(总共经过T步),从而生成一个清晰的图像。
  1. 噪声预测器(Noise Predictor)
  • 即之前训练好的U-net模型。在每个时间步 t t t​,噪声预测器(即神经网络 ϵ θ \epsilon_\theta ϵθ​)基于当前噪声图像 x t x_t xt​ 和时间步 t t t​ 来预测出噪声 ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t)​。这是恢复图像过程中的关键一步。
  1. 计算 x t − 1 x_{t-1} xt1
  • 根据 t 步的图像 x t x_{t} xt,使用公式计算出时间步 t − 1 t-1 t1 的图像 x t − 1 x_{t-1} xt1。这一过程包括两个部分:
    • 减去噪声成分 x t x_t xt 中的噪声成分由 ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t) 给出,通过减去该噪声可以得到较少噪声的图像。
    • 添加额外噪声 ​ z z z​:加入正态分布噪声 z z z​。它保证了生成过程中的随机性,使得生成的图像具有多样性;同时,它通过逐步减小噪声的尺度,使生成过程平滑过渡到清晰的图像。如果不加这一噪声项,生成过程会变得确定性,导致生成的图像多样性不足,质量下降,甚至可能出现模式坍塌问题。李宏毅课程中提到,如果不加这个额外噪声,模型几乎完全无法生成有意义的图像,如下图:
    • 在这里插入图片描述
  1. 重复T次,最终生成图像
  • 不断重复这个过程,逐步从时间步 T T T 恢复到时间步 1,噪声逐渐减少,图像逐步清晰。
  • 整个过程表示如下:
  • 在这里插入图片描述
  1. 备注:如果超过T次采样,继续去噪下去,图像将逐渐变得模糊、失去原有结构,变成无意义的噪声图。

4 diffusion model 代码实现

代码出处(某人完成的优达学城AIGC课程练习):

  • https://github.com/amanpreetsingh459/Generative-AI/tree/main/3.%20Computer%20Vision%20and%20Generative%20AI/3.%20Diffusion%20Models
  • 备注:pytorch官网数据集链接已经失效了,需要在kaggle或他人处下载,加载数据集的代码我已经修正了。
/home/ym/AIGC_udacity/exerise/diffusion/stanford_cars/
├── cars_annos.mat
├── cars_test_annos_withlabels.mat
├── car_ims/
    ├── 00001.jpg
    ├── 00002.jpg
    └── ...

数据集

使用斯坦福汽车数据集。该数据集包含 196 类汽车,共 16,185 张图片。在这个练习中,我们不需要任何标签,也不需要测试数据集。我们还将把图像转换为 64x64,以便更快地完成练习:

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import multiprocessing

IMG_SIZE = 64
BATCH_SIZE = 100

class ImageDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        
        # 获取所有图像文件路径
        self.image_paths = [os.path.join(root, 'car_ims', fname) 
                            for fname in os.listdir(os.path.join(root, 'car_ims')) 
                            if fname.endswith(('.jpg', '.png'))]
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, index):
        img_path = self.image_paths[index]
        image = Image.open(img_path).convert('RGB')
        
        if self.transform is not None:
            image = self.transform(image)
        
        # 只返回图像
        return image

# 数据预处理
data_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)  # 将数据归一化到 [-1, 1]
])

# 创建数据集。文件夹结构:
# /home/ym/AIGC_udacity/exerise/diffusion/stanford_cars/
# ├── cars_annos.mat
# ├── cars_test_annos_withlabels.mat
# ├── car_ims/
#     ├── 00001.jpg
#     ├── 00002.jpg
#     └── ...
dataset = ImageDataset(root='/home/ym/AIGC_udacity/exerise/diffusion/stanford_cars', 
                       transform=data_transform)

# 创建数据加载器
dataloader = DataLoader(
    dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    drop_last=False, 
    pin_memory=True, 
    num_workers=multiprocessing.cpu_count(),
    persistent_workers=True
)

数据集可视化如下:

在这里插入图片描述

噪声调度与预计算

在扩散模型的前向过程(Forward Process)中,我们需要根据一个噪声调度策略向数据中添加随机噪声。为了方便模型训练和推理过程中的计算,我们需要预先定义并计算一些常量。

1. 噪声调度的定义

# Define beta schedule
T = 512  # number of diffusion steps
# YOUR CODE HERE
betas = torch.linspace(start=0.0001, end=0.02, steps=T)  # linear schedule

plt.plot(range(T), betas.numpy(), label='Beta Values')
plt.xlabel('Diffusion Step')
plt.ylabel('Beta Value')
_ = plt.title('Beta Schedule over Diffusion Steps')

在这里插入图片描述

  1. 噪声调度的定义
  • 代码中设置 $ T = 512 $​,表示扩散过程共有512个步骤。
  • 使用 torch.linspace(start=0.0001, end=0.02, steps=T) 创建了一个线性调度器,其中噪声参数 $ \beta_t $​ 从 0.0001 均匀递增至 0.02,生成512个数值,代表每个扩散步骤中的噪声幅度。
  1. 解释图像
  • 图像显示, β t \beta_t βt 随扩散步骤呈线性增长,这意味着在扩散模型的每一步,添加到样本中的噪声量逐渐增加。
  • 这种线性调度策略通常用于扩散模型的训练,帮助模型在生成过程中的前期添加较小噪声,后期添加较大噪声,使得模型能够更稳定地学习数据分布。
  1. 备注:后续改进的DDPM使用余弦噪声调度,能获取更好性能。原因:线性调度可能会导致输入图像中的信息快速丢失。因此,这通常会导致突然的扩散过程。相比之下,余弦调度提供了更平滑的退化。因此,允许后续步骤对没有被噪声完全淹没的图像进行操作。
    a. 在这里插入图片描述
2. 重新参数化公式

前向过程中的重新参数化允许我们在任意步骤生成带噪声的图像,而无需按顺序遍历所有的前面步骤。这通过以下公式实现:

α ˉ t = ∏ s = 1 t ( 1 − β s ) \bar{\alpha}_t = \prod_{s=1}^t (1 - \beta_s) αˉt=s=1t(1βs)

q ( x t ∣ x 0 ) = N ( α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(x_t|x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) I) q(xtx0)=N(αˉt x0,(1αˉt)I)

这意味着我们可以直接根据 x 0 x_0 x0 生成任意步骤 x t x_t xt 的带噪声样本,而不必每次从 x 1 , x 2 , . . . , x t − 1 x_1, x_2, ..., x_{t-1} x1,x2,...,xt1 逐步采样。

3. 推理阶段公式

推理阶段的反向过程使用以下公式生成样本 x t − 1 x_{t-1} xt1

x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z xt1=αt 1(xt1αˉt 1αtϵθ(xt,t))+σtz

其中

σ t 2 = ( 1 − α ˉ t − 1 ) ( 1 − α ˉ t ) β t \sigma_t^2 = \frac{(1 - \bar{\alpha}_{t-1})}{(1 - \bar{\alpha}_t)} \beta_t σt2=(1αˉt)(1αˉt1)βt

这个公式描述了如何利用当前带噪声的样本 x t x_t xt 生成上一步的样本 x t − 1 x_{t-1} xt1

4. 代码实现与预计算

为了高效地进行前向和反向过程,我们在代码中对上述所有常数进行了预先计算:

# 预先计算在封闭形式中用到的各项
alphas = 1. - betas  # 计算每个步骤的 α_t
alphas_cumprod = torch.cumprod(alphas, axis=0)  # 计算累积乘积 \bar{\alpha}_t

# 计算 \bar{\alpha}_{t-1}
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) 

# 计算 \sqrt{\bar{\alpha}_t}
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) 

# 计算 1 / \sqrt{\alpha_t}
sqrt_recip_alphas = torch.sqrt(1.0 / alphas) 

# 计算 \sqrt{1 - \bar{\alpha}_t}
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) 

# 计算推理阶段的 σ_t^2
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
4. 解释代码中的每个步骤
  • alphas = 1. - betas:计算每个步骤的 α t \alpha_t αt
  • alphas_cumprod = torch.cumprod(alphas, axis=0):计算 α ˉ t = ∏ s = 1 t α s \bar{\alpha}_t = \prod_{s=1}^t \alpha_s αˉt=s=1tαs
  • alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0):计算 α ˉ t − 1 \bar{\alpha}_{t-1} αˉt1 并在序列开始处补1,使其长度与其他参数匹配。
  • sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod):计算 α ˉ t \sqrt{\bar{\alpha}_t} αˉt
  • sqrt_recip_alphas = torch.sqrt(1.0 / alphas):计算 1 α t \frac{1}{\sqrt{\alpha_t}} αt 1
  • sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod):计算 1 − α ˉ t \sqrt{1 - \bar{\alpha}_t} 1αˉt
  • posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod):计算推理阶段 x t − 1 x_{t-1} xt1 所需的方差 σ t 2 \sigma_t^2 σt2

前向扩散过程(添加噪声)与反向采样过程(去除噪声)

  • forward_diffusion_viz(): 实现了输入图像从原始样本到完全随机噪声的前向扩散过程。
  • make_inference(): 实现了扩散模型的逆过程,即将随机噪声逐步还原为原始图像。
@torch.no_grad()
def forward_diffusion_viz(image, device='cpu', num_images=16, dpi=75, interleave=False):
    """
    Generate the forward sequence of noisy images taking the input image to pure noise
    """
    # Visualize only num_images diffusion steps, instead of all of them
    stepsize = int(T/num_images)
    
    imgs = []
    noises = []
    
    for i in range(0, T, stepsize):
        t = torch.full((1,), i, device=device, dtype=torch.long)

        # Forward diffusion process
        bs = image.shape[0]
        noise = torch.randn_like(image, device=device)
        img = (
            sqrt_alphas_cumprod[t].view(bs, 1, 1, 1) * image + 
            sqrt_one_minus_alphas_cumprod[t].view(bs, 1, 1, 1) * noise
        )

        imgs.append(torch.clamp(img, -1, 1).squeeze(dim=0))
        noises.append(torch.clamp(noise, -1, 1).squeeze(dim=0))
    
    if interleave:
        imgs = [item for pair in zip(imgs, noises) for item in pair]
        
    fig = display_sequence(imgs, dpi=dpi)
    
    return fig, imgs[-1]


@torch.no_grad()
def make_inference(input_noise, return_all=False):
    """
    Implements the sampling algorithm from the DDPM paper
    """
    
    x = input_noise
    bs = x.shape[0]
    
    imgs = []
    
    # YOUR CODE HERE
    for time_step in range(0, T)[::-1]:
        
        noise = torch.randn_like(x) if time_step > 0 else 0
        
        t = torch.full((bs,), time_step, device=device, dtype=torch.long)
        
        # YOUR CODE HERE
        x = sqrt_recip_alphas[t].view(bs, 1, 1, 1) * (
            x - betas[t].view(bs, 1, 1, 1) * model(x, t) / 
            sqrt_one_minus_alphas_cumprod[t].view(bs, 1, 1, 1)
        ) + torch.sqrt(posterior_variance[t].view(bs, 1, 1, 1)) * noise
        
        imgs.append(torch.clamp(x, -1, 1))
    
    if return_all:
        return imgs
    else:
        return imgs[-1]
    
    return x
1. 前向扩散过程:forward_diffusion_viz

该函数实现了将输入图像通过扩散过程逐步变成纯噪声,并在过程中可视化各步骤的噪声生成效果。

代码解析:

  • 装饰器 @torch.no_grad():表示在该函数中不会计算梯度,节省内存并加快计算速度,这是因为我们只是想观察生成过程而不是进行训练。
  • 输入参数
    • image: 输入图像张量,通常是一个单个样本。
    • num_images: 想要可视化的扩散步骤数,默认是16。
    • dpi: 绘图的分辨率。
    • interleave: 是否交错显示生成的噪声图像。

主要步骤

  1. 计算扩散步长stepsize = int(T / num_images) 用于确定可视化过程中要间隔多少步。
  2. 在循环中实现前向扩散过程
  • 对每个步骤 t t t,计算当前噪声 n o i s e noise noise 并利用公式生成带噪声的图像 img

x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon xt=αˉt x0+1αˉt ϵ

  • 将结果添加到 imgs 列表中,noises 列表中存储对应的噪声。

  • 交错选项 interleave:如果 interleave=True,将噪声图像和带噪声图像交替存储在 imgs 列表中。

  • 显示结果:调用 display_sequence(imgs, dpi=dpi) 来显示扩散过程,并返回最终图像。

2. 反向采样过程:make_inference

该函数实现了扩散模型的反向采样过程,将随机噪声转化为生成的图像样本。这个过程基于DDPM论文中的采样算法。

代码解析:

  • 输入参数
    • input_noise: 输入噪声,用于开始反向采样过程。
    • return_all: 控制是否返回所有步骤的结果。

主要步骤

  1. 初始化:将输入噪声赋值给 x
  2. 反向扩散过程
  • T T T 到 0,逐步反向遍历扩散步骤。
  • 在每一步 t t t,根据扩散模型的公式生成上一步的样本 x t − 1 x_{t-1} xt1

x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z xt1=αt 1(xt1αˉt 1αtϵθ(xt,t))+σtz

  • 使用 model(x, t) 预测当前步骤的噪声,使用 torch.sqrt(posterior_variance[t]) 计算加权噪声项。

  • 将生成的结果 x 添加到 imgs 列表中。

  • 返回结果:如果 return_all=True,则返回所有步骤的结果,否则只返回最后一个步骤的图像。

前向扩散过程可视化效果

在这里插入图片描述

开始训练

  • 下述代码实现了一个扩散模型的训练过程,主要包括模型初始化、数据处理、训练循环、损失计算和梯度更新等。
  • 特别注意了混合精度训练的部分,使用了 torch.cuda.amp 中的 autocastGradScaler 来实现自动混合精度训练,从而提高了训练效率并减少显存使用。(消耗10.5G显存,不使用混合精度训练估计需要18G显存)
  • 同时加入了学习率预热(warmup)和余弦退火(Cosine Annealing)的调度策略,以确保模型训练的稳定性和效果。
# 导入自定义的UNet模型
from unet import UNet

# 初始化UNet模型,使用默认的通道倍增数
model = UNet(ch_mults = (1, 2, 1, 1))

# 如果想进行非常长时间的训练,可以选择注释掉上面的模型初始化
# 并启用下面这一行的模型初始化
# model = UNet(ch_mults = (1, 2, 2, 2))

# 计算模型参数数量,并打印
n_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {n_params:,}")

# 设置设备为GPU或CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# 将模型移动到对应设备
model.to(device)

# 将所有需要的参数移动到设备(GPU/CPU)
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device)
alphas = alphas.to(device)
alphas_cumprod = alphas_cumprod.to(device)
alphas_cumprod_prev = alphas_cumprod_prev.to(device)
sqrt_recip_alphas = sqrt_recip_alphas.to(device)
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device)
posterior_variance = posterior_variance.to(device)
betas = betas.to(device)

# 定义损失函数
criterion = torch.nn.MSELoss()

# 设置训练相关参数
base_lr = 0.0006 # 基础学习率
epochs = 10 # 总的训练轮次
T_max = epochs  # Cosine Annealing 的最大步数
warmup_epochs = 2  # 预热训练轮数

# 如果想要进行非常长时间的训练,可以启用以下设置
# base_lr = 0.0001 # 基础学习率
# epochs = 300 # 总训练轮数
# T_max = epochs  # Cosine Annealing 的最大步数
# warmup_epochs = 10  # 预热训练轮数

# 初始化优化器和学习率调度器
optimizer = Adam(model.parameters(), lr=base_lr)
scheduler = CosineAnnealingLR(
    optimizer, 
    T_max=T_max - warmup_epochs,  # 调度器的最大步数
    eta_min=base_lr / 10  # 学习率最小值
)

# 导入用于混合精度训练的模块
from torch.cuda.amp import autocast, GradScaler

# 初始化GradScaler用于缩放梯度
scaler = GradScaler()

# 生成固定的噪声,用于在训练过程中检查模型的生成效果
fixed_noise = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)

# 设置 EMA 损失平滑因子
alpha = 0.1  # EMA(指数移动平均)平滑因子
ema_loss = None  # 初始化EMA损失

# 开始训练循环
for epoch in range(epochs):
    
    if epoch < warmup_epochs:
        # 线性预热学习率
        lr = base_lr * (epoch + 1) / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        # 预热完成后使用余弦退火学习率
        scheduler.step()

    current_lr = optimizer.param_groups[0]['lr']  # 当前学习率
        
    for batch in tqdm(dataloader):
        
        batch = batch.to(device)  # 将当前batch移动到设备
        bs = batch.shape[0]  # 获取当前batch大小
        
        optimizer.zero_grad()  # 清空优化器的梯度
        
        # 混合精度训练开始
        with autocast():
            # 随机选择t时刻
            t = torch.randint(0, T, (batch.shape[0],), device=device).long()
            
            # 生成目标噪声并添加到图像中
            noise = torch.randn_like(batch, device=device)
            x_noisy = (
                sqrt_alphas_cumprod[t].view(bs, 1, 1, 1) * batch + 
                sqrt_one_minus_alphas_cumprod[t].view(bs, 1, 1, 1) * noise
            )
            
            # 通过模型预测噪声
            noise_pred = model(x_noisy, t)
            loss = criterion(noise, noise_pred)  # 计算损失
        
        # 使用 scaler 进行混合精度训练的反向传播
        scaler.scale(loss).backward()
        
        # 更新优化器的权重
        scaler.step(optimizer)
        
        # 更新 scaler 
        scaler.update()
        
        if ema_loss is None:
            # 第一个 batch 初始化 ema_loss
            ema_loss = loss.item()
        else:
            # 计算损失的指数移动平均
            ema_loss = alpha * loss.item() + (1 - alpha) * ema_loss
    
    if epoch == epochs-1:
        with torch.no_grad():
            # 在训练结束时对固定噪声进行推理,查看生成结果
            imgs = make_inference(fixed_noise, return_all=True)
            fig = display_sequence([imgs[0].squeeze(dim=0)] + [x.squeeze(dim=0) for x in imgs[63::64]], nrow=9, dpi=150)
            plt.show(fig)
        
        # 保存结果图像
        os.makedirs("diffusion_output_long", exist_ok=True)
        fig.savefig(f"diffusion_output_long/frame_{epoch:05d}.png")
    
    # 打印当前轮次的损失和学习率
    print(f"epoch {epoch+1}: loss: {ema_loss:.3f}, lr: {current_lr:.6f}")

训练结果:

在这里插入图片描述

效果展示

考虑到这个模型如此之小,我们对它的训练又如此之少,这个结果已经相当不错了。我们已经可以看出,它确实在创建汽车,而且还带有挡风玻璃和车轮,尽管这还只是初步阶段。如果我们训练的时间更长,和/或使用更大的模型(例如,上面注释行中定义的模型有 5,500 万个参数),并让它训练几个小时,我们会得到更好的结果,就像这样:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2171114.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Bug:ThreadPoolTaskScheduler搭配CronTask完成定时任务,关闭scheduler后CronTask任务仍然执行?

【问题】执行下面代码后&#xff0c;关闭ThreadPoolTaskScheduler&#xff0c;CronTask仍然继续执行。 Configuration public class config {Beanpublic String getString() throws InterruptedException {Runnable runnable () -> {try {System.out.println("hello r…

动态规划算法:13.简单多状态 dp 问题_打家劫舍II_C++

目录 题目链接&#xff1a;LCR 090. 打家劫舍 II - 力扣&#xff08;LeetCode&#xff09; 一、题目解析 题目&#xff1a; 解析&#xff1a; 二、算法原理 1、状态表示 2、状态转移方程 状态转移方程推理&#xff1a; 1、i位置状态分析 2、首尾状态分析 3、初始化 d…

Meta震撼发布Llama3.2大规模模型

在2024.9.26的年Meta Connect大会上&#xff0c;Meta正式推出了Llama3.2模型&#xff0c;旨在提升边缘AI和视觉任务的能力。Llama3.2系列包括11亿和90亿参数的中型视觉模型&#xff0c;以及为移动设备优化的1亿和3亿参数的小型模型&#xff0c;并针对高通和联发科的硬件平台进行…

Webpack 介绍

Webpack 介绍 Date: August 29, 2024 全文概要 Webpack概念&#xff1a; Webpack是一个静态的模块化的打包工具&#xff0c;可以为现代的 JavaSript 应用程序进行打包。 1-静态&#xff1a;Webpack可以将代码打包成最终的静态资源 2-模块化&#xff1a;webpack支持各种模块…

教师工作量评估与管理软件

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

Spring异常处理-@ExceptionHandler-@ControllerAdvice-全局异常处理

文章目录 ResponseBodyControllerAdvice最终的异常处理方式 异常的处理分两类 编程式处理&#xff1a;也就是我们的try-catch 声明式处理&#xff1a;使用注解处理 ResponseBody /*** 测试声明式异常处理*/ RestController public class HelloController {//编程式的异常处理&a…

EasyAR自定义相机RTSP视频流(CustomCamera)

EasyAR可以使用视频源作为输入源&#xff0c;官方给出了示例和文档&#xff0c;但是对于大部分Unity开发人员来说看了文档还是一头雾水。 在Android Studio中将custom-camera.jar添加libs中&#xff0c;就可以查看源代码了 分析其源代码&#xff0c;主要是ExternalCameraSampl…

【linux 多进程并发】linux下使用常见命令,来解析进程家族体系脉络

0101 Linux进程 ​专栏内容&#xff1a; postgresql使用入门基础手写数据库toadb并发编程 个人主页&#xff1a;我的主页 管理社区&#xff1a;开源数据库 座右铭&#xff1a;天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物. 文章目录 0101 Li…

ASP.NET Core 打包net8.0框架在Linux CentOS7上部署问题

问题1 libstdc.so.6版本过低。 CentOS7默认安装的gcc版本太低&#xff0c;达不到.net8的启动条件。 /lib64/libstdc.so.6: version GLIBCXX_3.4.20’ not found (required by ./IDT_net) /lib64/libstdc.so.6: version GLIBCXX_3.4.21’ not found (required by ./IDT_net) 解…

恢复丢失的数据:恢复数据库网络解决方案

探索恢复数据库网络的深度对于了解现代企业如何防御其数据不断增长的威胁至关重要。在一个时代&#xff0c;数字证据和取证网络安全在法律和商业领域扮演关键角色&#xff0c;这些网络提供的弹性是不可或缺的。深入研究恢复数据库网络的重要性不仅仅是数据保护&#xff0c;它还…

ubuntu安装mysql 8,mysql密码的修改

目录 1.安装mysql 82.查看当前状态3.手动给数据库设置密码mysql5mysql8 4.直接把数据库验证密码的功能关闭掉 1.安装mysql 8 apt install mysql-server-8.0敲 Y 按回车 table 选ok 2.查看当前状态 service mysql status显示active&#xff08;running&#xff09;证明安装成…

媒界:吉利星瑞百炼成钢,持续引领中国汽车价值向上

秋风送爽绘秋色&#xff0c;出行良辰恰逢时。9月28日至9月29日&#xff0c;2024安行中国汽车安全科技公益巡展迎来尾声&#xff0c;安行中国携手吉利汽车&#xff0c;步履轻盈地踏入苏州星湖天街&#xff0c;共同呈献一场融合环保科技前沿、安全驾驶理念与深厚文化底蕴的48小时…

使用jQuery处理Ajax

使用jQuery处理Ajax HTTP协议 超文本传输协议&#xff08;HTTP&#xff0c;HyperText Transfer Protocol)是互联网上应用最为广泛的一种网络协议 设计HTTP最初的目的是为了提供一种发布和接收HTML页面的方法 所有的WWW文件都必须遵守这个标准 一次HTTP操作称为一个事务&am…

如何使用 CCF Communicator 框架快速开发设备接口

什么是 CCF Communicator Framework&#xff1f; 通信器框架通过封装 CCF 和设备之间的连接&#xff0c;简化了硬件之间的低级消息处理。 举例来说&#xff0c;考虑一下控制软件和硬件设备之间的连接方式。ASCII 串行连接需要使用 TCP 的套接字连接、用于处理设备发送/接收的…

肺癌类器官培养研究概述

前 言 2023年是类器官被《Science》杂志评为年度十大技术的10周年。10年后类器官技术发展迅猛&#xff0c;犹如一颗璀璨的明珠&#xff0c;不断的为生命科学研究揭示新的奥秘&#xff0c;推动生物医学领域不断前行。肺类器官培养条件也在不断完善&#xff0c;在基础和临床研究…

MySQL面试知识汇总

学习链接 创建索引有哪些注意点&#xff1f; 索引应该建在查询频繁的字段&#xff0c;比如where查询、order排序索引的个数应该适量&#xff08;最多64个&#xff09;&#xff0c;索引需要占用空间&#xff0c;更新时也需要维护区分度低的字段&#xff0c;例如性别&#xff0c…

声阔头戴式耳机怎么样?西圣、jBL、声阔头戴式耳机终极pk测评推荐

我们深知&#xff0c;一款优秀的头戴式耳机&#xff0c;不仅仅是音乐的传递者&#xff0c;更是用户情感与个性的延伸。因此&#xff0c;在设计之初&#xff0c;便将极致的佩戴舒适度视为核心追求&#xff0c;通过人体工学的精准设计与优质材料的精心挑选&#xff0c;力求让每一…

Linux 配置与管理 SWAP(虚拟内存)

Linux 配置与管理 SWAP(虚拟内存&#xff09; 一、作用二、创建交换文件&#xff08;以创建一个2GB的交换文件为例&#xff09;1. 创建交换文件2. 设置文件权限2.1. **关于 sudo chmod 600 /root/swapfile 是否一定要执行**2.2. **关于其他用户启动是否没权限用到交换分区** 3.…

大数据电商数仓项目--实战(一)数据准备

第一章 数仓分层 1.1 为什么要分层 1.2 数仓命名规范 1.2.1 表命名 ODS层命名为ods_表名DIM层命名为dim_表名DWD层命名为dwd_表名DWS层命名为dws_表名DWT层命名为dwt_表名ADS层命名为ads_表名临时表命名为tmp_表名 1.2.2 表字段类型 数量类型为bigint金额类型为decimal(16…

猫咪独自在家可以吗?希喂、美的、有哈宠物空气净化器哪款好?

这不是快要国庆了吗&#xff0c;本来计划去旅游的&#xff0c;结果我妈让我假期回家。收拾行李已经很烦了&#xff0c;行李箱旁的猫咪更是让我头疼。我妈因为之前浮毛过敏的事情&#xff0c;禁止我把猫咪再带回家&#xff0c;朋友们也各有计划&#xff0c;甚至连上门喂养都约满…