【扩散模型】实战:创建一个类别条件扩散模型

news2024/12/25 15:26:51

创建一个类别条件扩散模型

  • 1. 配置和数据准备
  • 2. 创建一个以类别为条件的UNet模型
  • 3. 训练和采样

本文介绍一种给扩散模型添加额外条件信息的方法。具体地,将在MNIST数据集上训练一个以类别为条件的扩散模型。并且可以在推理阶段指定想要生成的是哪个数字。

1. 配置和数据准备

首先安装diffusers库:

!pip install -q diffusers

导入相关依赖包:
导入依赖包
加载MNIST数据集:

# 加载MNIST数据集
dataset = torchvision.datasets.MNIST(
    root="./mnist/", 
    train=True, 
    download=True, 
    transform=torchvision.transforms.ToTensor()
    )
# 创建数据加载器
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
# 查看MNIST数据集中的部分样本
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');

部分样本

2. 创建一个以类别为条件的UNet模型

输入类别这一条件的流程:
(1)创建一个标准的UNet2DModel加入一些额外的输入通道
(2)通过一个嵌入层,把类别标签映射到一个长度为class_emb_size的特征向量上。
(3)把这个信息作为额外通道和原有的输入向量拼接起来。

net_input = torch.cat((x, class_cond), 1)

(4)将net_input(其中包含class_emb_size + 1个通道)输入UNet模型,得到最终的预测结果。

这里,class_emb_size被设置成4,但它其实是可以进行任意修改的,或者把需要学到的nn.Embedding替换成简单地对类别进行one-hot编码,代码如下:

class ClassConditionedUnet(nn.Module):
  def __init__(self, num_classes=10, class_emb_size=4):
    super().__init__()
    # 这个网络层会把数字所属的类别映射到一个长度为class_emb_size的特征向量上
    self.class_emb = nn.Embedding(num_classes, class_emb_size)

    # self.model是一个不带生成条件的UNet模型,这里,给他添加了额外的输入通道,用于接收条件信息
    self.model = UNet2DModel(
        sample_size = 28, # 所生成图片的尺寸
        in_channels = 1 + class_emb_size, # 加入额外的输入通道
        out_channels = 1, # 输出结果的通道数
        layers_per_block=2, # 残差层个数
        block_out_channels=(32, 64, 64),
        down_block_types=(
            "DownBlock2D", # 常规的ResNet下采样模块
            "AttnDownBlock2D", # 含有spatial self-attention的ResNet下采样模块
            "AttnDownBlock2D", 
        ),
        up_block_types=(
            "AttnUpBlock2D", 
            "AttnUpBlock2D", # 含有spatil self-attention的ResNet上采样模块
            "UpBlock2D",  # 上采样模块
        ),
    )
  
  # 此时扩散模型的前向计算就会含有额外的类别标签作为输入了
  def forward(self, x, t, class_labels):
    bs, ch, w, h = x.shape
    # 类别条件将会以额外通道的形式输入
    class_cond = self.class_emb(class_labels)  # 将类别映射为向量形式,
    # 并扩展成类似于(bs, 4, 28, 28)的张量形式
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
    # 将原始输入和类别条件信息拼接到一起
    net_input = torch.cat((x, class_cond), 1) # (bs, 5, 28, 28)
    # 使用模型进行预测
    return self.model(net_input, t).sample # (bs, 1, 28, 28)

3. 训练和采样

这里使用prediction = unet(x, t, y)在训练时把正确的标签作为第三个输入发送给模型。如果一切正常,模型将会输出与之相匹配的图片。y在这里的范围是0~9.

# 创建一个调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

# 定义数据加载器
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

n_epochs = 10
loss_fn = nn.MSELoss()
net = ClassConditionedUnet().to(device)
opt = torch.optim.Adam(net.parameters(), lr=1e-3) 
losses = []

# 训练开始
for epoch in range(n_epochs):
  for x, y in tqdm(train_dataloader):
    # 获取数据并添加噪声
    x = x.to(device) * 2 - 1 # 数据被归一化到区间(-1, 1)
    y = y.to(device)
    noise = torch.randn_like(x)
    timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)
    noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

    # 预测
    pred = net(noisy_x, timesteps, y)  # 注意这里输入了类别信息
    # 计算损失值
    loss = loss_fn(pred, noise) 
    opt.zero_grad()
    loss.backward()
    opt.step()

    losses.append(loss.item())
  
  avg_loss = sum(losses[-100:])/100
  print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')
plt.plot(losses)

Finished epoch 0. Average of the last 100 loss values: 0.053393
Finished epoch 1. Average of the last 100 loss values: 0.047172
Finished epoch 2. Average of the last 100 loss values: 0.045227
Finished epoch 3. Average of the last 100 loss values: 0.043402
Finished epoch 4. Average of the last 100 loss values: 0.041524
Finished epoch 5. Average of the last 100 loss values: 0.040847
Finished epoch 6. Average of the last 100 loss values: 0.040252
Finished epoch 7. Average of the last 100 loss values: 0.040134
Finished epoch 8. Average of the last 100 loss values: 0.038976
Finished epoch 9. Average of the last 100 loss values: 0.039234
损失曲线:
loss
训练结束后,可以通过输入不同的标签作为条件来采样图片:

# 准备一个随机噪声作为起点,并准备想要的图片标签
x = torch.randn(80, 1, 28, 28).to(device)
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)
print(y)
# 采样循环
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
  with torch.no_grad():
    residual = net(x, t, y)
  x = noise_scheduler.step(residual, t, x).prev_sample

# 显示结果
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')

这里,我们的y标签为:

tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
        6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8,
        9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0'

因此对应生成的图片为:
生长指定标签的图片
至此,已经实现了对输出图片的控制。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1195057.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Conda executable is not found 三种问题解决

如果在PyCharm中配置Python解释器时显示“conda executable is not found”错误消息,这意味着PyCharm无法找到您的Conda可执行文件。您可以按照以下步骤解决此问题: 1.方法一 确认Conda已正确安装。请确保您已经正确安装了Anaconda或Miniconda&#xff…

数字化工厂管理系统的三个关键技术是什么

随着科技的飞速发展,数字化工厂管理系统已经成为了现代制造业的重要发展方向。数字化工厂管理系统通过充分运用建模技术、仿真技术和单一数据源技术,实现了产品设计和生产的虚拟化,为制造业带来了前所未有的效率和创新能力。本文将深入探讨这…

Matlab的多项式留数与极点的计算

Matlab的多项式留数与极点的计算 以下面的多项式为例: 运算代码: clc clear closesyms p % 定义多项式 Zp(5*p^571*p^370*p)/(2*p^635*p^4117*p^236); % 提取分子与分母 [I,D]numden(Zp); Idouble(coeffs(I,p,"All"));%分子 Ddouble(coeffs…

轻量日志管理方案-[EFK]

使用FileBeat进行日志文件的数据收集,并发送到ES进行存储,最后Kibana进行查看展示; 这个应该是最简单,轻量的日志收集方案了。 最总方案为:FileBeatESKibana ; 【Kibana过于强大,感觉可以无限扩展】 文章目…

css:文本对齐属性vertical-align实现化学元素上标下标的显示

文档 https://developer.mozilla.org/zh-CN/docs/Web/CSS/vertical-align 语法 vertical-align: <value>;可选值&#xff1a; sub&#xff1a;使元素的基线与父元素的下标基线对齐。 super&#xff1a;使元素的基线与父元素的上标基线对齐。 text-top&#xff1a;使…

24张宇八套卷数一复盘(六)

张八&#xff08;六&#xff09;11/10107选择45填空20高数大题22线代大题12概率大题8 前言 临近考试冲刺阶段&#xff0c;感觉做过的卷子很难再提起精神去复盘&#xff0c;于是在这里进行一下复盘。 主要是对于整体试卷结构的把握&#xff0c;以及考试状态的复盘。 简单的卷子把…

说说React render方法的原理?在什么时候会被触发?

一、原理 首先&#xff0c;render函数在react中有两种形式&#xff1a; 在类组件中&#xff0c;指的是render方法&#xff1a; class Foo extends React.Component { render() { return <h1> Foo </h1>; } } 在函数组件中&#xff0c;指的是函…

虹科分享 | 2023温控生物技术和医药物流前景展望专题报告

2023温控生物技术和医药物流前景展望专题报告 全球供应链正在发生根本性的变化&#xff0c;而制药业对供应链的使用也在不断发展。突破性疗法和个性化药品有望带来崭新的未来&#xff0c;这也改变了我们如今的行医方式。然而&#xff0c;在监管和基础设施方面还面临着许多挑战…

常见排序算法之插入排序类

插入排序&#xff0c;是一种简单直观的排序算法&#xff0c;工作原理是将一个记录插入到已经排好序的有序表中&#xff0c;从而形成一个新的、记录数增1的有序表。在实现过程中&#xff0c;它使用双层循环&#xff0c;外层循环对除了第一个元素之外的所有元素&#xff0c;内层循…

【canvas】在Vue3+ts中实现 canva内的矩形拖动操作。

前言 canvas内的显示内容如何拖动&#xff1f; 这里提供一个 canvas内矩形移动的解决思路。 描述 如何选中canvas里的某部分矩形内容&#xff0c;然后进行拖动&#xff1f; 我的解决思路&#xff1a; **画布搭建。**用一个div将canvas元素包裹&#xff0c;设置宽高&#xf…

漏洞扫描-nuclei-poc编写

0x00 nuclei Nuclei是一款基于YAML语法模板的开发的定制化快速漏洞扫描器。它使用Go语言开发&#xff0c;具有很强的可配置性、可扩展性和易用性。 提供TCP、DNS、HTTP、FILE 等各类协议的扫描&#xff0c;通过强大且灵活的模板&#xff0c;可以使用Nuclei模拟各种安全检查。 …

vue+iView实现下载zip文件导出多个excel表格

1&#xff0c;需求&#xff1a;在vue项目中&#xff0c;实现分月份导出多个Excel表格。 点击导出&#xff0c;下载zip文件&#xff0c;解压出多张表数据。 2&#xff0c;关键代码&#xff1a; <Button class"export button-style button-space" click"ex…

MPC-模型预测控制笔记

线性mpc 凸优化 二次优化问题 1&#xff1a;建立预测模型 2&#xff1a;问题模型 3&#xff1a;求解优化问题 4&#xff1a;得到的优化控制驱动系统 上述方法与qp解一样 硬约束 硬约束 四组约束条件 二次规划求解 matlab代码&#xff1a; 软约束 可以用指数函数 加入…

Python爬虫抓取微博数据及热度预测

首先我们需要安装 requests 和 BeautifulSoup 库&#xff0c;可以使用以下命令进行安装&#xff1a; pip install requests pip install beautifulsoup4然后&#xff0c;我们需要导入 requests 和 BeautifulSoup 库&#xff1a; import requests from bs4 import BeautifulSou…

csv文件导入mysql指定表中

csv文件导入mysql指定表中 mysql数据库准备指定表 准备导入的csv数据如下&#xff1a; sepaLengthsepalWidthpetalLengthpetalWidthlabel5.13.51.40.204.931.40.204.73.21.30.20…………… 准备导入的数据为151行5列的数据&#xff0c;其中第一行为标题行。 因此&#xff0…

什么是Node.js的调试器(debugger)工具?

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 欢迎来到前端入门之旅&#xff01;感兴趣的可以订阅本专栏哦&#xff01;这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

演示文稿制作软件 Deckset mac中文版介绍

Deckset mac是一款Mac上的演示文稿制作软件&#xff0c;它可以让你使用Markdown语言快速地创建演示文稿。与传统的演示文稿制作软件相比&#xff0c;Deckset采用了全新的设计理念&#xff0c;旨在让用户更加专注于内容的创作&#xff0c;而不是花费过多的时间在排版和设计上。 …

[100天算法】-颜色分类(day 69)

题目描述 给定一个包含红色、白色和蓝色&#xff0c;一共 n 个元素的数组&#xff0c;原地对它们进行排序&#xff0c;使得相同颜色的元素相邻&#xff0c;并按照红色、白色、蓝色顺序排列。此题中&#xff0c;我们使用整数 0、 1 和 2 分别表示红色、白色和蓝色。注意: 不能使…

LeetCode(4)删除有序数组中的重复项 II【数组/字符串】【中等】

目录 1.题目2.答案3.提交结果截图 链接&#xff1a; 80. 删除有序数组中的重复项 II 1.题目 给你一个有序数组 nums &#xff0c;请你** 原地** 删除重复出现的元素&#xff0c;使得出现次数超过两次的元素只出现两次 &#xff0c;返回删除后数组的新长度。 不要使用额外的数…

基恩士软件的基本操作(一)

今天就来学习基恩士软件的基础操作&#xff0c;欢迎大家的指正&#xff01;&#xff01;&#xff01; 基本操作 KV STUDIO 基恩士编程软件的名称就KV STUDIO。安装软件地址KV STUDIO的安装与实践 项目的创建 1&#xff0c;双击KV STUDIO. 2&#xff0c;新建项目 单元编辑器…