扩散模型实战(四):从零构建扩散模型

news2024/9/24 13:17:01

推荐阅读列表:

扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

本文以MNIST数据集为例,从零构建扩散模型,具体会涉及到如下知识点:

  • 退化过程(向数据中添加噪声)
  • 构建一个简单的UNet模型
  • 训练扩散模型
  • 采样过程分析

下面介绍具体的实现过程:

一、环境配置&python包的导入
     

最好有GPU环境,比如公司的GPU集群或者Google Colab,下面是代码实现:

# 安装diffusers库!pip install -q diffusers# 导入所需要的包import torchimport torchvisionfrom torch import nnfrom torch.nn import functional as Ffrom torch.utils.data import DataLoaderfrom diffusers import DDPMScheduler, UNet2DModelfrom matplotlib import pyplot as pltdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f'Using device: {device}')
# 输出Using device: cuda

此时会输出运行环境是GPU还是CPU

载MNIST数据集

       MNIST数据集是一个小数据集,存储的是0-9手写数字字体,每张图片都28X28的灰度图片,每个像素的取值范围是[0,1],下面加载该数据集,并展示部分数据:

dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)x, y = next(iter(train_dataloader))print('Input shape:', x.shape)print('Labels:', y)plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');
# 输出Input shape: torch.Size([8, 1, 28, 28])Labels: tensor([7, 8, 4, 2, 3, 6, 0, 2])

扩散模型的退化过程

       所谓退化过程,其实就是对输入数据加入噪声的过程,由于MNIST数据集的像素范围在[0,1],那么我们加入噪声也需要保持在相同的范围,这样我们可以很容易的把输入数据与噪声进行混合,代码如下:

def corrupt(x, amount):  """Corrupt the input `x` by mixing it with noise according to `amount`"""  noise = torch.rand_like(x)  amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works  return x*(1-amount) + noise*amount

接下来,我们看一下逐步加噪的效果,代码如下:

# Plotting the input datafig, axs = plt.subplots(2, 1, figsize=(12, 5))axs[0].set_title('Input data')axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')# Adding noiseamount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruptionnoised_x = corrupt(x, amount)# Plottinf the noised versionaxs[1].set_title('Corrupted data (-- amount increases -->)')axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');

       从上图可以看出,从左到右加入的噪声逐步增多,当噪声量接近1时,数据看起来像纯粹的随机噪声。

构建一个简单的UNet模型

       UNet模型与自编码器有异曲同工之妙,UNet最初是用于完成医学图像中分割任务的,网络结构如下所示:

代码如下:

class BasicUNet(nn.Module):    """A minimal UNet implementation."""    def __init__(self, in_channels=1, out_channels=1):        super().__init__()        self.down_layers = torch.nn.ModuleList([             nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),            nn.Conv2d(32, 64, kernel_size=5, padding=2),            nn.Conv2d(64, 64, kernel_size=5, padding=2),        ])        self.up_layers = torch.nn.ModuleList([            nn.Conv2d(64, 64, kernel_size=5, padding=2),            nn.Conv2d(64, 32, kernel_size=5, padding=2),            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),         ])        self.act = nn.SiLU() # The activation function        self.downscale = nn.MaxPool2d(2)        self.upscale = nn.Upsample(scale_factor=2)    def forward(self, x):        h = []        for i, l in enumerate(self.down_layers):            x = self.act(l(x)) # Through the layer and the activation function            if i < 2: # For all but the third (final) down layer:              h.append(x) # Storing output for skip connection              x = self.downscale(x) # Downscale ready for the next layer                      for i, l in enumerate(self.up_layers):            if i > 0: # For all except the first up layer              x = self.upscale(x) # Upscale              x += h.pop() # Fetching stored output (skip connection)            x = self.act(l(x)) # Through the layer and the activation function                    return x

我们来检验一下模型输入输出的shape变化是否符合预期,代码如下:

net = BasicUNet()x = torch.rand(8, 1, 28, 28)net(x).shape
# 输出torch.Size([8, 1, 28, 28])

再来看一下模型的参数量,代码如下:

sum([p.numel() for p in net.parameters()])
# 输出309057

至此,已经完成数据加载和UNet模型构建,当然UNet模型的结构可以有不同的设计。

、扩散模型训练

        扩散模型应该学习什么?其实有很多不同的目标,比如学习噪声,我们先以一个简单的例子开始,输入数据为带噪声的MNIST数据,扩散模型应该输出对应的最佳数字预测,因此学习的目标是预测值与真实值的MSE,训练代码如下:

# Dataloader (you can mess with batch size)batch_size = 128train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# How many runs through the data should we do?n_epochs = 3# Create the networknet = BasicUNet()net.to(device)# Our loss finctionloss_fn = nn.MSELoss()# The optimizeropt = torch.optim.Adam(net.parameters(), lr=1e-3) # Keeping a record of the losses for later viewinglosses = []# The training loopfor epoch in range(n_epochs):    for x, y in train_dataloader:        # Get some data and prepare the corrupted version        x = x.to(device) # Data on the GPU        noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts        noisy_x = corrupt(x, noise_amount) # Create our noisy x        # Get the model prediction        pred = net(noisy_x)        # Calculate the loss        loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?        # Backprop and update the params:        opt.zero_grad()        loss.backward()        opt.step()        # Store the loss for later        losses.append(loss.item())    # Print our the average of the loss values for this epoch:    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')# View the loss curveplt.plot(losses)plt.ylim(0, 0.1);
# 输出Finished epoch 0. Average loss for this epoch: 0.024689Finished epoch 1. Average loss for this epoch: 0.019226Finished epoch 2. Average loss for this epoch: 0.017939

训练过程的loss曲线如下图所示:

六、扩散模型效果评估

我们选取一部分数据来评估一下模型的预测效果,代码如下:

#@markdown Visualizing model predictions on noisy inputs:# Fetch some datax, y = next(iter(train_dataloader))x = x[:8] # Only using the first 8 for easy plotting# Corrupt with a range of amountsamount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruptionnoised_x = corrupt(x, amount)# Get the model predictionswith torch.no_grad():  preds = net(noised_x.to(device)).detach().cpu()# Plotfig, axs = plt.subplots(3, 1, figsize=(12, 7))axs[0].set_title('Input data')axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')axs[1].set_title('Corrupted data')axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')axs[2].set_title('Network Predictions')axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');

从上图可以看出,对于噪声量较低的输入,模型的预测效果是很不错的,当amount=1时,模型的输出接近整个数据集的均值,这正是扩散模型的工作原理。

Note:我们的训练并不太充分,读者可以尝试不同的超参数来优化模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/890420.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Linux】POSIX信号量和基于环形队列的生产消费者模型

目录 写在前面的话 什么是POSIX信号量 POSIX信号量的使用 基于环形队列的生产消费者模型 写在前面的话 本文章主要先介绍POSIX信号量&#xff0c;以及一些接口的使用&#xff0c;然后再编码设计一个基于环形队列的生产消费者模型来使用这些接口。 讲解POSIX信号量时&#x…

阿里云服务器镜像大全_Linux和Windows操作系统清单

阿里云服务器操作系统大全&#xff0c;阿里云提供的镜像均为正版授权&#xff0c;正版镜像可以在云服务器ECS上运行的应用程序提供安全、稳定的运行环境系统&#xff0c;阿里云服务器以公共镜像为例分享阿里云服务器操作系统大全&#xff0c;包括Alibaba Cloud Linux镜像、Linu…

洁净室气流流型分类、检测及相关法规要求概览

洁净室气流形式可以分为单向流或非单向流两种。如果综合利用两种气流&#xff0c;通常叫做混合气流。单向流可以是垂直的或水平的。两种单向流都以最终过滤的送风和回风入口&#xff0c;它们几乎是相对设置&#xff0c;这样才能使气流的流动形式保持尽可能的呈直线状。确认气流…

java全局捕捉异常并返回统一返回值

1. 首先定义全局异常处理类 import org.springframework.validation.BindException; import org.springframework.web.bind.annotation.ControllerAdvice; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.…

vcsa6.7 vsan超融合,虚拟机突然识别不到加密狗

问题现象&#xff1a; vcsa6.7 vsan超融合&#xff0c;虚拟机突然识别不到加密狗&#xff0c;加密狗直通给虚拟机&#xff0c;打开设备管理器&#xff0c;有未知的usb设备&#xff0c;重启虚拟机&#xff0c;卸载tools&#xff0c;重新安装tools&#xff0c;都不行 问题解决…

面向对象编程(OOP):Python中的抽象与封装

文章目录 &#x1f340;引言&#x1f340; 类与对象&#x1f340;封装&#x1f340;继承&#x1f340;多态&#x1f340;面向对象编程的优势&#x1f340;使用面向对象编程的场景&#x1f340;实例化与构造函数&#x1f340; 成员属性和类属性&#x1f340;魔术方法&#x1f34…

opencv基础:几个常用窗口方法

开始说了一些opencv中的一些常用方法。 namedWindow方法 在OpenCV中&#xff0c;namedWindow函数用于创建一个窗口&#xff0c;并给它指定一个名字。这个函数的基本语法如下&#xff1a; import cv2cv2.namedWindow(窗口名称, 标识 )窗口名称&#xff1a;其实窗口名称&…

SqlServer Management工具格式化代码、美化SQL

1、下载simple but powerful SQL formatter 插件下载地址&#xff1a;地址 2、安装 点击next—> finish 3、重新打开SqlServer Management即可看到 SQL Beautifier 点击Format ALL SQL 或者选中sql点击Format Selected SQL即可

Linux的基本权限(文件,目录)

文章目录 前言一、Linux权限的概念二、Linux权限管理 1.文件访问者分类2.文件类型和访问类型3.文件访问权限的相关设置方法三、目录的权限四、权限的总结 前言 Linux下一切皆文件&#xff0c;指令的本质就是可执行文件&#xff0c;直接安装到了系统的某种路径下 一、Linux权限的…

交换实验一

题目 交换机上接口配置 SW1 interface GigabitEthernet0/0/1 port hybrid tagged vlan 2 port hybrid untagged vlan 3 to 6 interface Ethernet0/0/2 port hybrid pvid vlan 3 port hybrid untagged vlan 2 to 6 interface Ethernet0/0/3 port link-type access port d…

【NAS群晖drive异地访问】使用cpolar远程访问内网Synology Drive「内网穿透」

文章目录 前言1.群晖Synology Drive套件的安装1.1 安装Synology Drive套件1.2 设置Synology Drive套件1.3 局域网内电脑测试和使用 2.使用cpolar远程访问内网Synology Drive2.1 Cpolar云端设置2.2 Cpolar本地设置2.3 测试和使用 3. 结语 前言 群晖作为专业的数据存储中心&…

第 3 章 稀疏数组和队列(2)

3.2 队列 3.2.1队列的一个使用场景 银行排队的案例: 3.2.2队列介绍 队列是一个有序列表&#xff0c;可以用数组或是链表来实现。遵循先入先出的原则。即:先存入队列的数据&#xff0c;要先取出。后存入的要后取出示意图:(使用数组模拟队列示意图) 3.2.3数组模拟队列思路…

[RDMA] 高性能异步的消息传递和RPC :Accelio

1. Introduce Accelio是一个高性能异步的可靠消息传递和RPC库&#xff0c;能优化硬件加速。 RDMA和TCP / IP传输被实现&#xff0c;并且其他的传输也能被实现&#xff0c;如共享存储器可以利用这个高效和方便的API的优点。Accelio 是 Mellanox 公司的RDMA中间件&#xff0c;用…

苹果支付的实现

由于app经常需要用到支付功能&#xff0c;然而ios用户&#xff0c;是无法用支付宝、微信进行支付&#xff0c;这时候只能用到苹果支付。苹果支付是苹果公司推出的一种在线支付方式&#xff0c;用户可以通过苹果支付购买应用、内购道具等等。 原理 苹果支付的实现原理是通过在…

VHDL记录

文章目录 使用function名称作为“常量”numeric_std包集中使用乘法的注意项variable的使用对于entity设置属性的方法在entity声明中嵌入function的定义VHDL仿真读写文件File declaration/File handingFile readingFile writing小例子 使用函数 模块中打印出调试信息 使用functi…

(搜索) 剑指 Offer 12. 矩阵中的路径 ——【Leetcode每日一题】

❓剑指 Offer 12. 矩阵中的路径 难度&#xff1a;中等 给定一个 m * n 二维字符网格 board 和一个字符串单词 word 。如果 word 存在于网格中&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 单词必须按照字母顺序&#xff0c;通过相邻的单元格内的字母构…

.netcore grpc客户端工厂及依赖注入使用

一、客户端工厂概述 gRPC 与 HttpClientFactory 的集成提供了一种创建 gRPC 客户端的集中方式。可以通过依赖包Grpc.Net.ClientFactory中的AddGrpcClient进行gRPC客户端依赖注入AddGrpcClient函数提供了许多配置项用于处理一些其他事项&#xff1b;例如AOP、重试策略等 二、案…

CS5523规格书|MIPI转EDP方案设计|替代LT8911芯片电路原理|ASL集睿致远CS替代龙讯

ASL芯片&#xff08;集睿致远&#xff09; CS5523是一款MIPI DSI输入&#xff0c;DP/e DP输出转换芯片&#xff0c;可pin to pin替代LT8911龙讯芯片。 MIPI DSI 最多支持 4 个通道&#xff0c;每个通道的最大运行速度为 1.5Gps。对于DP 1.2输出&#xff0c;它支持1.62Gbps和2.…

一文揭露AI聊天机器人到底是怎么实现自助应答的

现在很多的企业都会使用客服系统&#xff0c;主要是想通过它们来解决企业的一些问题和需求。所有就衍生了——AI聊天机器人这个新工具&#xff0c;它是把AI人工智能运用到客户服务当中&#xff0c;让AI来帮助我们完成一些解答客户问题的操作。下面我们就来说一下AI聊天机器人是…

创建和使用分区

创建和使用分区 创建一个名为 /home/curtis/ansible/partition.yml 的 playbook&#xff1a; 该palybook包含一个paly&#xff0c;该paly在balancers主机组的主机上运行&#xff1a; 在设备vdb上创建单个主分区&#xff0c;编号为1&#xff0c;大小为1500MiB 使用ext4文件系统…