【Diffusion实战】训练一个diffusion模型生成蝴蝶图像(Pytorch代码详解)

news2025/1/18 11:54:32

  上一篇Diffusion实战是确确实实一步一步走的公式,这回采用一个更方便的库:diffusers,来实现Diffusion模型训练。


Diffusion实战篇:
  【Diffusion实战】训练一个diffusion模型生成S曲线(Pytorch代码详解)
Diffusion综述篇:
  【Diffusion综述】医学图像分析中的扩散模型(一)
  【Diffusion综述】医学图像分析中的扩散模型(二)


0、所需安装

pip install diffusers  # diffusers库
pip install datasets  

1、数据集下载

  下载地址:蝴蝶数据集
  下载好后的文件夹中包括以下文件,放在当前目录下就可以了。

在这里插入图片描述
加载数据集,并对一批数据进行可视化:

import torch
import torchvision
from datasets import load_dataset
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image

def show_images(x):
    """Given a batch of images x, make a grid and convert to PIL"""
    x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x)
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
    return grid_im

def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# 数据加载
dataset = load_dataset("./smithsonian_butterflies_subset", split='train')

image_size = 32
batch_size = 64

# 数据增强
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # Resize
        transforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)
        transforms.ToTensor(),  # Convert to tensor (0, 1)
        transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
    ]
)

dataset.set_transform(transform)

# 数据装载
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 抽取一批数据可视化
xb = next(iter(train_dataloader))["images"].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST)

输出可视化结果:

在这里插入图片描述


2、加噪调度器

  即DDPM论文中需要预定义的 β t {\beta_t } βt ,可使用DDPMScheduler类来定义,其中num_train_timesteps参数为时间步 t {t} t

from diffusers import DDPMScheduler

# βt值
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

plt.figure(dpi=300)
plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");

根据定义的 β t {\beta_t } βt ,可视化 α ˉ t {\sqrt {{{\bar \alpha }_t}}} αˉt 1 − α ˉ t {\sqrt {1 - {{\bar \alpha }_t}}} 1αˉt

在这里插入图片描述

  通过设置beta_start、beta_end和beta_schedule三个参数来控制噪声调度器的超参数 β t {\beta_t } βt

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)

在这里插入图片描述

  beta_schedule可以通过一个函数映射来为模型推理的每一步生成一个 β t {\beta_t } βt值。

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

在这里插入图片描述

x t = α ˉ t x 0 + 1 − α ˉ t ε {{x_t} = \sqrt {{{\bar \alpha }_t}} {x_0} + \sqrt {1 - {{\bar \alpha }_t}} \varepsilon } xt=αˉt x0+1αˉt ε 加噪前向过程可视化:

timesteps = torch.linspace(0, 999, 8).long().to(device)  # 随机采样时间步
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)  # 加噪
print("Noisy X shape", noisy_xb.shape)
show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)

输出为:

在这里插入图片描述


3、扩散模型定义

  diffusers库中模型的定义也非常简洁:

# 创建模型
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 128, 256),  # More channels -> more parameters
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",  # a regular ResNet upsampling block
    ),
)

model.to(device)
with torch.no_grad():
    model_prediction = model(noisy_xb, timesteps).sample
model_prediction.shape  # 验证输出与输出尺寸相同

4、扩散模型训练

  定义优化器,和传统模型一样的训练写法:

# 定义噪声调度器
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)

# 优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

losses = []

for epoch in range(30):
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["images"].to(device)
        
        # 为图像添加随机噪声
        noise = torch.randn(clean_images.shape).to(clean_images.device)  # eps
        bs = clean_images.shape[0]

        # 为每一张图像随机选择一个时间步
        timesteps = torch.randint(
            0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
        ).long()  

        # 根据时间步,向清晰的图像中加噪声, 前向过程:根号下αt^ * x0 + 根号下(1-αt^) * eps
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        # 获得模型预测结果
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

        # 计算损失, 损失回传
        loss = F.mse_loss(noise_pred, noise)  
        loss.backward(loss)
        losses.append(loss.item())

        # 更新模型参数
        optimizer.step()
        optimizer.zero_grad()

    if (epoch + 1) % 5 == 0:
        loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
        print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")

30个epoch训练过程如下所示:

在这里插入图片描述

可用以下代码查看损失曲线:

# 损失曲线可视化
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))  # 对数坐标
plt.show()

损失曲线可视化:

在这里插入图片描述


5、图像生成

  (1)通过建立pipeline生成图像:

# 图像生成
# 方法一:建立一个pipeline, 打包模型和噪声调度器
from diffusers import DDPMPipeline
image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)

pipeline_output = image_pipe()
plt.figure()
plt.imshow(pipeline_output.images[0])
plt.axis('off')
plt.show()

# 保存pipeline
image_pipe.save_pretrained("my_pipeline")  # 在当前目录下保存了一个 my_pipeline 的文件夹

生成的蝴蝶图像如下:

在这里插入图片描述

生成的my_pipeline文件夹如下:

在这里插入图片描述

  (2)通过随机采样循环生成图像:

# 方法二:模型调用, 写采样循环 
# 随机初始化8张图像:
sample = torch.randn(8, 3, 32, 32).to(device)

for i, t in enumerate(noise_scheduler.timesteps):

    # 获得模型预测结果
    with torch.no_grad():
        residual = model(sample, t).sample

    # 根据预测结果更新图像
    sample = noise_scheduler.step(residual, t, sample).prev_sample

show_images(sample)

8张生成图像如下:
在这里插入图片描述


6、代码汇总

import torch
import torchvision
from datasets import load_dataset
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image


def show_images(x):
    """Given a batch of images x, make a grid and convert to PIL"""
    x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x)
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
    return grid_im


def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}

# --------------------------------------------------------------------------------
# 1、数据集加载与可视化
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# 数据加载
dataset = load_dataset("./smithsonian_butterflies_subset", split='train')

image_size = 32
batch_size = 64

# 数据增强
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # Resize
        transforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)
        transforms.ToTensor(),  # Convert to tensor (0, 1)
        transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
    ]
)

dataset.set_transform(transform)

# 数据装载
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 抽取一批数据可视化
xb = next(iter(train_dataloader))["images"].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST)
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 2、噪声调度器
from diffusers import DDPMScheduler

# 加噪声的系数βt
# noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
# noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

plt.figure(dpi=300)
plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 加噪声可视化
timesteps = torch.linspace(0, 999, 8).long().to(device)  # 随机采样时间步
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)  # 加噪
print("Noisy X shape", noisy_xb.shape)
show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 3、创建模型
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 128, 256),  # More channels -> more parameters
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",  # a regular ResNet upsampling block
    ),
)

model.to(device)
with torch.no_grad():
    model_prediction = model(noisy_xb, timesteps).sample
model_prediction.shape  # 验证输出与输出尺寸相同
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 4、扩散模型训练
# 定义噪声调度器
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)

# 优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

losses = []

for epoch in range(30):
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["images"].to(device)
        
        # 为图像添加随机噪声
        noise = torch.randn(clean_images.shape).to(clean_images.device)  # eps
        bs = clean_images.shape[0]

        # 为每一张图像随机选择一个时间步
        timesteps = torch.randint(
            0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
        ).long()  

        # 根据时间步,向清晰的图像中加噪声, 前向过程:根号下αt^ * x0 + 根号下(1-αt^) * eps
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        # 获得模型预测结果
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

        # 计算损失, 损失回传
        loss = F.mse_loss(noise_pred, noise)  
        loss.backward(loss)
        losses.append(loss.item())

        # 更新模型参数
        optimizer.step()
        optimizer.zero_grad()

    if (epoch + 1) % 5 == 0:
        loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
        print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 损失曲线可视化
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))  # 对数坐标
plt.show()
# --------------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# 5、图像生成
# 方法一:建立一个pipeline, 打包模型和噪声调度器
from diffusers import DDPMPipeline
image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)

pipeline_output = image_pipe()

plt.figure()
plt.imshow(pipeline_output.images[0])
plt.axis('off')
plt.show()

image_pipe.save_pretrained("my_pipeline")  # 在当前目录下保存了一个 my_pipeline 的文件夹

# 方法二:模型调用, 写采样循环 
# 随机初始化8张图像:
sample = torch.randn(8, 3, 32, 32).to(device)

for i, t in enumerate(noise_scheduler.timesteps):

    # 获得模型预测结果
    with torch.no_grad():
        residual = model(sample, t).sample

    # 根据预测结果更新图像
    sample = noise_scheduler.step(residual, t, sample).prev_sample

show_images(sample)

grid_im = show_images(sample).resize((8 * 64, 64), resample=Image.NEAREST)
plt.figure(dpi=300)
plt.imshow(grid_im)
plt.axis('off')
plt.show()
# --------------------------------------------------------------------------------

  参考资料:扩散模型从原理到实践. 人民邮电出版社. 李忻玮, 苏步升等.

  diffusers确实很方便使用,有点子PyCaret的感觉了~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1628048.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Linux学习】​​学习Linux的准备工作和Linux的基本指令

˃͈꒵˂͈꒱ write in front ꒰˃͈꒵˂͈꒱ ʕ̯•͡˔•̯᷅ʔ大家好,我是xiaoxie.希望你看完之后,有不足之处请多多谅解,让我们一起共同进步૮₍❀ᴗ͈ . ᴗ͈ აxiaoxieʕ̯•͡˔•̯᷅ʔ—CSDN博客 本文由xiaoxieʕ̯•͡˔•̯᷅ʔ 原创 CSDN 如…

并并并并·病查坤

P1、什么是并查集 引用自百度百科: 并查集,在一些有N个元素的集合应用问题中,我们通常是在开始时让每个元素构成一个单元素的集合,然后按一定顺序将属于同一组的元素所在的集合合并,其间要反复查找一个元素在哪个集合…

【数据标注】使用LabelImg标注YOLO格式的数据(案例演示)

文章目录 LabelImg介绍LabelImg安装LabelImg界面标注常用的快捷键标注前的一些设置案例演示检查YOLO标签中的标注信息是否正确参考文章 LabelImg介绍 LabelImg是目标检测数据标注工具,可以标注两种格式: VOC标签格式,标注的标签存储在xml文…

insightface 环境配置

首先创建续集环境: conda create -n insightface3 python3.8 然后打开此虚拟环境:conda activate insightface3 然后安装: pip install insightface 再安装:pip install onnxruntime-gpu 就可以了

零基础俄语培训哪家好,柯桥俄语培训

1、Мощный дух спасает расслабленное тело. 强大的精神可以拯救孱弱的肉体。 2、Единственное правило в жизни, по которому нужно жить — оставаться человеком в лю…

物联网的基本功能及五大核心技术——青创智通

工业物联网解决方案-工业IOT-青创智通 物联网基本功能 物联网的最基本功能特征是提供“无处不在的连接和在线服务”,其具备十大基本功能。 (1)在线监测:这是物联网最基本的功能,物联网业务一般以集中监测为主、控制为…

qml和c++结合使用

目录 文章简介1. 创建qml工程2. 创建一个类和qml文件,修改main函数3. 函数说明:4. qml 文件间的调用5. 界面布局6. 代码举例 文章简介 初学qml用来记录qml的学习过程,方便后面归纳总结整理。 1. 创建qml工程 如下图,我使用的是…

本地Mysql开启远程访问(图文)

目录 1. 问题所示2. 原理分析3. 解决方法 1. 问题所示 事因是访问同事的数据库时,出现无法访问 出现1130 - Host ‘IT07’ is not allowed to connect to this MySQL server截图如下: 2. 原理分析 如果账号密码地址都正常的情况下,这是没开…

SQLite尽如此轻量

众所周知,SQLite是个轻量级数据库,适用于中小型服务应用等,在我真正使用的时候才发现,它虽然轻量,但不知道它却如此轻量。 下载 官网: SQLite Download Page 安装 1、将下载好的两个压缩包同时解压到一个…

基于springboot的高校宣讲会管理系统

文章目录 项目介绍主要功能截图:部分代码展示设计总结项目获取方式 🍅 作者主页:超级无敌暴龙战士塔塔开 🍅 简介:Java领域优质创作者🏆、 简历模板、学习资料、面试题库【关注我,都给你】 &…

前端工程化Vue使用Node.js永久设置国内高速npm镜像源

前端工程化Vue使用Node.js永久设置国内高速npm镜像源 接续上篇错误收录,此篇通过简单配置永久设置国内高速npm镜像源方法 1.更换新版镜像 清空npm缓存 npm cache clean --force修改回原版镜像源或直接删除配置过的镜像源 npm config set registry https://registr…

pve(Proxmox VE)安装i225v网卡驱动

配置pve源 备份原来的源 mv /etc/apt/sources.list /etc/apt/sources.list.bak打开文件 vi /etc/apt/sources.list将以下内容粘贴进去 deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm main contrib non-free non-free-firmwaredeb https://mirrors.tuna.tsing…

2024景区五一游园会活动策划方案

2024景区营地五一发酵面包市集主题游园会(面包狂想曲主题)活动策划方案-59P 方案页码:59页 文件格式:pptx 方案简介: 面包派对 如约而至 你有见过面包发酵的过程吗? 看着面团从手掌大小 不断膨胀到一倍、两倍、…

CLIP论文笔记:Learning Transferable Visual Models From Natural Language Supervision

导语 会议:ICML 2021链接:https://proceedings.mlr.press/v139/radford21a/radford21a.pdf 当前的计算机视觉系统通常只能识别预先设定的对象类别,这限制了它们的广泛应用。为了突破这一局限,本文探索了一种新的学习方法&#x…

校车车载4G视频智能监控系统方案

一、项目背景 随着社会的快速发展,校车安全问题日益受到人们的关注。为了提高校车运营的安全性,保障学生的生命安全,我们提出了一套校车车载4G视频智能监控系统方案。该系统能够实时监控校车内部和外部环境,及时发现并处理潜在的…

全志ARM-网络链接

命令扫描周围的WIFI热点 nmcli dev wifi 命令接入网络 nmcli dev wifi connect (WiFi名,不要有空格)password (WiFi密码) 查看IP地址 ip addr show wlan0或ifconfig 出现successfully就连接成功了

qt5core.dll怎么下载,qt5core.dll丢失能否修复?

qt5core.dll的丢失真是让人头疼。这个Visual C Redistributable for Visual Studio 2015的运行时库被许多程序和游戏所依赖,一旦缺失了qt5core.dll,就会面临无法打开程序或游戏,甚至系统崩溃等一系列问题。 qt5core.dll的消失会带来以下麻烦 …

沉浸式推理乐趣:体验线上剧本杀小程序的魅力

在这个信息爆炸的时代,人们的娱乐方式也在不断地推陈出新。其中,线上剧本杀小程序以其独特的沉浸式推理乐趣,成为了许多人的新宠。它不仅让我们在闲暇之余享受到了推理的快乐,更让我们在虚拟的世界里感受到了人性的复杂与多彩。 线…

可视化大屏对浅色系果断说“不”

可视化大屏拒绝浅色系的原因有以下几点: 对比度不足:浅色系在大屏上可能会导致对比度不足,使得数据和图表的展示不清晰,影响信息的传达和理解。可读性差:浅色系的文字和图表在大屏上可能会显得模糊不清,使…