【Diffusion实战】训练一个类别引导diffusion模型(Pytorch代码详解)

news2025/1/9 20:30:15

  又学习了一种方法,类别引导diffusion模型,使用mnist数据集,记录一下它的用法吧。


Diffusion实战篇:
  【Diffusion实战】训练一个diffusion模型生成S曲线(Pytorch代码详解)
  【Diffusion实战】训练一个diffusion模型生成蝴蝶图像(Pytorch代码详解)
  【Diffusion实战】引导一个diffusion模型根据文字生成图像(Pytorch代码详解)
Diffusion综述篇:
  【Diffusion综述】医学图像分析中的扩散模型(一)
  【Diffusion综述】医学图像分析中的扩散模型(二)


1、数据集装载

  使用mnist数据集来训练类别引导diffusion模型,因为其比较简单清晰:

import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=False, 
                                     transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# 查看MNIST数据集样本
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
plt.axis('off')
plt.show()

  看一看我们朴素的样本:
在这里插入图片描述


2、创建条件扩散模型

  创建了一个名为ClassConditionedUnet的条件扩散模型,定义了一个可学习的嵌入层,用以将数字类别映射到特征向量上,将类别嵌入与原始输入拼接之后,送入常规的UNet网络即可。

  知识传送:【python函数】torch.nn.Embedding函数用法图解

class ClassConditionedUnet(nn.Module):
  def __init__(self, num_classes=10, class_emb_size=4):
    super().__init__()
    
    # 嵌入层将数字类别映射到特征向量上
    self.class_emb = nn.Embedding(num_classes, class_emb_size)

    # 一个常规的UNet网络
    self.model = UNet2DModel(
        sample_size=28,           # 图像尺寸
        in_channels=1 + class_emb_size, # 增加一个通道, 用于条件生成
        out_channels=1,           # 输出通道
        layers_per_block=2,       # 残差连接层数目
        block_out_channels=(32, 64, 64), 
        down_block_types=( 
            "DownBlock2D",        # a regular ResNet downsampling block
            "AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention
            "AttnDownBlock2D",
        ), 
        up_block_types=(
            "AttnUpBlock2D", 
            "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",          # a regular ResNet upsampling block
          ),
    )

  def forward(self, x, t, class_labels):
    bs, ch, w, h = x.shape  # [8, 1, 28, 28] 
    
    # 类别条件以额外通道的形式输入
    class_cond = self.class_emb(class_labels)  # [8, 4]
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)  # [8, 4, 28, 28]
    
    # 拼接原始输入与类别条件映射
    net_input = torch.cat((x, class_cond), 1)   # (8, 5, 28, 28)

    # 模型预测
    return self.model(net_input, t).sample  # (8, 1, 28, 28)

noisy_xb = torch.randn(8, 1, 28, 28).to(device)
timesteps = torch.linspace(0, 999, 8).long().to(device)
y = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]).to(device)
model = ClassConditionedUnet().to(device)
with torch.no_grad():
    model_prediction = model(noisy_xb, timesteps, y)
model_prediction.shape  # 验证输出与输出尺寸相同

3、模型训练

  训练过程就跟之前的一样啦~

# 创建调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

n_epochs = 10
net = ClassConditionedUnet().to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3) 

losses = []
for epoch in range(n_epochs):
    for x, y in tqdm(train_dataloader):
        
        # 获取数据并添加噪声
        x = x.to(device) * 2 - 1  # 归一化到[-1, 1]
        y = y.to(device)
        noise = torch.randn_like(x)
        timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)
        # 前向加噪
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

        # 获得模型预测结果
        pred = net(noisy_x, timesteps, y)  # 此处传入了类别标签

        # 损失计算
        loss = loss_fn(pred, noise) 

        # 损失回传, 参数更新
        opt.zero_grad()
        loss.backward()
        opt.step()

        # 损失保存
        losses.append(loss.item())

    # 输出损失
    avg_loss = sum(losses[-100:])/100
    print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')

# 查看损失曲线
plt.figure(dpi=300)
plt.plot(losses)
plt.show()

  输出损失曲线为:

在这里插入图片描述


4、模型推理

  进行采样循环,用类别标签引导图像生成:

x = torch.randn(80, 1, 28, 28).to(device)  # 随机噪声
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)  # 类别标签

# 采样循环
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):

    # 模型预测结果
    with torch.no_grad():
        residual = net(x, t, y)

    # 根据预测噪声和时间步更新图像
    x = noise_scheduler.step(residual, t, x).prev_sample

# 结果可视化
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], 'Greys')
ax.axis('off')

  类别引导效果如下,效果还是挺好的哩:

在这里插入图片描述


5、代码汇总

import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# -----------------------------------------------------------------------------
# 1、数据集装载
dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=False, 
                                     transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# 查看MNIST数据集样本
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
plt.axis('off')
plt.show()
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# 2、创建条件扩散模型
class ClassConditionedUnet(nn.Module):
  def __init__(self, num_classes=10, class_emb_size=4):
    super().__init__()
    
    # 嵌入层将数字类别映射到特征向量上
    self.class_emb = nn.Embedding(num_classes, class_emb_size)

    # 一个常规的UNet网络
    self.model = UNet2DModel(
        sample_size=28,           # 图像尺寸
        in_channels=1 + class_emb_size, # 增加一个通道, 用于条件生成
        out_channels=1,           # 输出通道
        layers_per_block=2,       # 残差连接层数目
        block_out_channels=(32, 64, 64), 
        down_block_types=( 
            "DownBlock2D",        # a regular ResNet downsampling block
            "AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention
            "AttnDownBlock2D",
        ), 
        up_block_types=(
            "AttnUpBlock2D", 
            "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",          # a regular ResNet upsampling block
          ),
    )

  def forward(self, x, t, class_labels):
    bs, ch, w, h = x.shape  # [8, 1, 28, 28] 
    
    # 类别条件以额外通道的形式输入
    class_cond = self.class_emb(class_labels)  # [8, 4]
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)  # [8, 4, 28, 28]
    
    # 拼接原始输入与类别条件映射
    net_input = torch.cat((x, class_cond), 1)   # (8, 5, 28, 28)

    # 模型预测
    return self.model(net_input, t).sample  # (8, 1, 28, 28)

noisy_xb = torch.randn(8, 1, 28, 28).to(device)
timesteps = torch.linspace(0, 999, 8).long().to(device)
y = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]).to(device)
model = ClassConditionedUnet().to(device)
with torch.no_grad():
    model_prediction = model(noisy_xb, timesteps, y)
model_prediction.shape  # 验证输出与输出尺寸相同
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# 3、模型训练
# 创建调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

n_epochs = 10
net = ClassConditionedUnet().to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3) 

losses = []
for epoch in range(n_epochs):
    for x, y in tqdm(train_dataloader):
        
        # 获取数据并添加噪声
        x = x.to(device) * 2 - 1  # 归一化到[-1, 1]
        y = y.to(device)
        noise = torch.randn_like(x)
        timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)
        # 前向加噪
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

        # 获得模型预测结果
        pred = net(noisy_x, timesteps, y)  # 此处传入了类别标签

        # 损失计算
        loss = loss_fn(pred, noise) 

        # 损失回传, 参数更新
        opt.zero_grad()
        loss.backward()
        opt.step()

        # 损失保存
        losses.append(loss.item())

    # 输出损失
    avg_loss = sum(losses[-100:])/100
    print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')

# 查看损失曲线
plt.figure(dpi=300)
plt.plot(losses)
plt.show()
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# 4、模型推理
x = torch.randn(80, 1, 28, 28).to(device)  # 随机噪声
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)  # 类别标签

# 采样循环
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):

    # 模型预测结果
    with torch.no_grad():
        residual = net(x, t, y)

    # 根据预测噪声和时间步更新图像
    x = noise_scheduler.step(residual, t, x).prev_sample

# 结果可视化
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], 'Greys')
ax.axis('off')
# -----------------------------------------------------------------------------

  diffusion的修炼境界又提升了一级~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1649420.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何处理多模态数据噪声不均衡动态?天大等最新《低质量数据的多模态融合》综述

多模态融合致力于整合来自多种模态的信息,目的是实现更准确的预测。在包括自动驾驶和医疗诊断等广泛的场景中,多模态融合已取得显著进展。然而,在低质量数据环境下,多模态融合的可靠性大部分仍未被探索。本文综述了开放多模态融合…

Linux专栏03:使用Xshell远程连接云服务器

博客主页:Duck Bro 博客主页系列专栏:Linux专栏关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ 使用Xshell远程连接云服务器 编号:03 文章目录 使用Xsh…

鸿蒙OpenHarmony实战开发-MiniCanvas

介绍 基于OpenHarmony的Cavas组件封装了一版极简操作的MiniCanvas,屏蔽了原有Canvas内部复杂的调用流程,支持一个API就可以实现相应的绘制能力,该库还在继续完善中,也欢迎PR。 使用说明 1.添加MiniCanvas依赖 在项目entry目录…

Spring Boot3.x集成Disruptor4.0

Disruptor介绍 Disruptor是一个高性能内存队列,研发的初衷是解决内存队列的延迟问题(在性能测试中发现竟然与I/O操作处于同样的数量级)。基于Disruptor开发的系统单线程能支撑每秒600万订单,2010年在QCon演讲后,获得了业界关注。2011年&…

C++手写协程项目(协程实现线程结构体、线程调度器定义,线程挂起函数、线程切换函数、线程恢复函数、线程结束函数、线程结束判断函数,模块测试)

协程结构体定义 之前我们使用linux下协程函数实现了线程切换,使用的是ucontext_t结构体,和基于这个结构体的四个函数。现在我们要用这些工具来实现我们自己的一个线程结构体,并实现线程调度和线程切换、挂起。 首先我们来实现以下线程结构体…

【iOS】——浅析CALayer

文章目录 一、CALayer介绍二、UIview与CALayer1.区别2.联系 三、CALayer的使用1.初始化方法2.常用属性 四.CALayer坐标系1.position属性和anchorPoint属性2.position和anchorPoint的关系3.position、anchorPoint和frame的关系 五、CALayerDelegate六、CALayer绘图机制1.绘图流程…

官方教程来啦!上手体验YashanDB主备部署、同步延迟和自动切换能力

在上一篇深度干货 | 如何兼顾性能与可靠性?一文解析YashanDB主备高可用技术中,我们深入探讨了YashanDB高可用的架构设计原理和关键技术,本文将聚焦于实践操作,快速体验YashanDB的主备高可用能力。 概要 YashanDB提供了不同部署形…

C++程序设计教案

文章目录: 一:软件安装环境 第一种:vc2012 第二种:Dev-C 第三种:小熊猫C 二:语法基础 1.相关 1.1 注释 1.2 换行符 1.3 规范 1.4 关键字 1.5 ASCll码表 1.6 转义字符 2.基本框架 2.1 第一种&…

如果insightface/instantID安装失败怎么办(关于InsightFaceLoader_Zho节点的报错)

可能性有很多,但是今天帮朋友解决问题的时候又收集了一种新的思路。 首先,可以先按照这篇文章里边提到的方法去安装: 【全网最详细】ComfyUI下,Insightface安装指南-聚梦小课堂_insightface如何安装-CSDN博客 其次,…

解决Python中的 `ModuleNotFoundError: No module named ‘fcmeans‘` 错误

ModuleNotFoundError: No module named fcmeans 解决Python中的 ModuleNotFoundError: No module named fcmeans 错误如何解决这个错误fcmeans 库简介应用实例 解决Python中的 ModuleNotFoundError: No module named fcmeans 错误 在进行数据科学或机器学习项目时,…

Linux内核之获取文件系统超级块:sget用法实例(六十八)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…

大眼橙C1 Air投影仪:千元预算内的明智之选

在科技日新月异的今天,投影仪已经不再是会议室或教室的专属,而是越来越多地走入了寻常百姓家。家庭影院的概念越来越流行,尤其在都市人之间逐渐成为一股风尚。市场上投影仪非常多,如何选到一台合适的投影仪也成为困扰广大用户的一…

了解TMS运输管理系统,实现物流高效运转

TMS运输管理系统(Transportation Management System)是一种集成物流和信息技术的解决方案,通过优化运输流程、实时跟踪货物信息和自动化管理操作,提高物流效率,降低运营成本,实现高效运输。 TMS运输管理系…

软件设计师-重点的构造型设计模式

一、桥接模式(Bridge): 意图: 将抽象部分与其实现部分分离,使它们都可以独立地变化。 结构: 适用性: 不希望在抽象和它的实现部分之间有一个固定的绑定关系。例如,这种情况可能是…

探索大模型能力--prompt工程

1 prompt工程是什么 1.1 什么是Prompt? LLM大语言模型终究也只是一个工具,我们不可能每个人都去训一个大模型,但是我们可以思考如何利用好大模型,让他提升我们的工作效率。就像计算器工具一样,要你算10的10倍&#x…

【计算机网络】计算机网络的性能指标

计算机网络的性能指标被用来从不同方面度量计算机网络的性能。常用的八个计算机网络性能指标:速率、带宽、吞吐量、时延、时延带宽积、往返时间、利用率、丢包率。 一.速率 (1) 数据量 比特(bit,记为小写b)是计算机中数据量的基…

JavaWEB 框架安全:Spring 漏洞序列.(CVE-2022-22965)

什么叫 Spring 框架. Spring 框架是一个用于构建企业级应用程序的开源框架。它提供了一种全面的编程和配置模型,可以简化应用程序的开发过程。Spring 框架的核心特性包括依赖注入(Dependency Injection)、面向切面编程(Aspect-Or…

c++ 线程交叉场景试验

1.需求 1.处理一个列表的数据,要求按照列表的数据处理10个数据 2.可以使用多线程处理,但是针对每个线程,1~10的处理顺序不能变。 3.每个数据的处理必须原子,即只有一个线程可以针对某个数据进行处理,但是10个数据是可…

2024年CSC公派联合培养博士项目申报即将开始~

一、选派计划 联合培养博士研究生面向全国各博士学位授予单位选拔。 联合培养博士研究生的留学期限、资助期限为6-24个月。留学期限应根据拟留学单位学制、外方录取通知(或正式邀请信)中列明的留学时间确定。个人申报的资助期限应不超过留学期限&#…

79、贪心-跳跃游戏II

思路: 首先理解题意:从首位置跳最少多少次到达末尾。 第一种:使用递归,将所有跳转路径都获取到进行求出最小值。 第二种:使用动态规划,下一次最优取决上一次的最优解 第三针:贪心&#xff…