扩散模型实战(六):Diffusers DDPM初探

news2024/10/6 16:18:35

推荐阅读列表:

扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

扩散模型实战(四):从零构建扩散模型

扩散模型实战(五):采样过程

       之前的五篇文章主要是为了解释扩散模型的基本概念和流程,使读者更容易理解扩散模型的工作原理,但与实际工作中使用的模型差异较大,从本文开始,我们将初步使用DDPM模型的开源实现库Diffusers,在Diffusers库中DDPM模型的实现库是UNet2DModel

UNet2DModel模型实战

UNet2DModel模型比之前介绍的BasicUNet模型有一些改进,具体如下:

  • 退化过程的处理方式不同,UNet2DModel通过调节时间步来调节噪声量,t作为一个额外参数被传入前向过程;
  • 训练目标不同,UNet2DModel旨在预测不带缩放系数的噪声(也就是单位正太分布的噪声)而不是”去噪“的图像;
  • UNet2DModel有更多的采样策略可供选择;

下面我们来看一下UNet2DModel的模型参数以及结构,代码如下:

model = UNet2DModel(    sample_size=28,            # 目标图像的分辨率    in_channels=1,         # 输入图像的通道数,RGB图像的通道数为3    out_channels=1,        # 输出图像的通道数    layers_per_block=2,    # 设置要在每一个UNet块中使用多少个ResNet层    block_out_channels=(32, 64, 64), # 与BasicUNet模型的配置基本相同    down_block_types=(         "DownBlock2D",      # 标准的ResNet下采样模块        "AttnDownBlock2D",  # 带有空域维度self-att的ResNet下采样模块        "AttnDownBlock2D",    ),     up_block_types=(        "AttnUpBlock2D",         "AttnUpBlock2D",    # 带有空域维度self-att的ResNet上采样模块        "UpBlock2D",        # 标准的ResNet上采样模块       ),) # 输出模型结构(看起来虽然冗长,但非常清晰)print(model)

我们继续来查看一下UNet2DModel模型的参数量,代码如下:

sum([p.numel() for p in model.parameters()]) # UNet2DModel模型使用了大约170万个参数,BasicUNet模型则使用了30多万个参数
# 输出1707009

       下面是我们使用UNet2DModel代替BasicUNet模型,重复前面展示的训练以及采样过程(这里t=0,以表明模型是在没有时间步的情况下训练的),完整的代码如下:

#@markdown Trying UNet2DModel instead of BasicUNet:# Dataloader (you can mess with batch size)batch_size = 128train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# How many runs through the data should we do?n_epochs = 3# Create the networknet = UNet2DModel(    sample_size=28,  # the target image resolution    in_channels=1,  # the number of input channels, 3 for RGB images    out_channels=1,  # the number of output channels    layers_per_block=2,  # how many ResNet layers to use per UNet block    block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example    down_block_types=(        "DownBlock2D",  # a regular ResNet downsampling block        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention        "AttnDownBlock2D",    ),    up_block_types=(        "AttnUpBlock2D",        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention        "UpBlock2D",   # a regular ResNet upsampling block      ),) #<<<net.to(device)# Our loss finctionloss_fn = nn.MSELoss()# The optimizeropt = torch.optim.Adam(net.parameters(), lr=1e-3)# Keeping a record of the losses for later viewinglosses = []# The training loopfor epoch in range(n_epochs):    for x, y in train_dataloader:        # Get some data and prepare the corrupted version        x = x.to(device) # Data on the GPU        noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts        noisy_x = corrupt(x, noise_amount) # Create our noisy x        # Get the model prediction        pred = net(noisy_x, 0).sample #<<< Using timestep 0 always, adding .sample        # Calculate the loss        loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?        # Backprop and update the params:        opt.zero_grad()        loss.backward()        opt.step()        # Store the loss for later        losses.append(loss.item())    # Print our the average of the loss values for this epoch:    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')# Plot losses and some samplesfig, axs = plt.subplots(1, 2, figsize=(12, 5))# Lossesaxs[0].plot(losses)axs[0].set_ylim(0, 0.1)axs[0].set_title('Loss over time')# Samplesn_steps = 40x = torch.rand(64, 1, 28, 28).to(device)for i in range(n_steps):  noise_amount = torch.ones((x.shape[0], )).to(device) * (1-(i/n_steps)) # Starting high going low  with torch.no_grad():    pred = net(x, 0).sample  mix_factor = 1/(n_steps - i)  x = x*(1-mix_factor) + pred*mix_factoraxs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')axs[1].set_title('Generated Samples');
# 输出Finished epoch 0. Average loss for this epoch: 0.020033Finished epoch 1. Average loss for this epoch: 0.013243Finished epoch 2. Average loss for this epoch: 0.011795

可以看出,比BasicUNet网络生成的结果要好一些。

DDPM原理

论文名称:《Denoising Diffusion Probabilistic Models》

论文地址:https://arxiv.org/pdf/2006.11239.pdf

      下面是DDPM论文中的公式,Training步骤其实是退化过程,给原始图像逐渐添加噪声的过程,预测目标是拟合每个时间步的采样噪声。

       还有一点非常重要:我们都知道在前向过程中是不断添加噪声的,其实这个噪声的系数不是固定的,而是与时间t线性增加的(也成为扩散率),这样的好处是在后向过程开始过程先把"明显"的噪声给去除,对应着较大的扩散率;当去到一定程度,逐渐逼近真实真实图像的时候,去噪速率逐渐减慢,开始微调,也就是对应着较小的扩散率。

下面我们使用代码来看一下输入数据与噪声在不同迭代周期的变化:

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${ \sqrt{\bar{\alpha}_t}}$")plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5,  label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")plt.legend(fontsize="x-large");

生成的结果,如下图所示:

       下面我们来看一下,噪声系数不变与DDPM中的噪声方式在MNIST数据集上的加噪效果:

# 可视化:DDPM加噪过程中的不同时间步# 对一批图片加噪,看看效果fig, axs = plt.subplots(3, 1, figsize=(16, 10))xb, yb = next(iter(train_dataloader))xb = xb.to(device)[:8]xb = xb * 2. - 1. # 映射到(-1,1)print('X shape', xb.shape) # 展示干净的原始输入axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().    cpu(), cmap='Greys')axs[0].set_title('Clean X') # 使用调度器加噪timesteps = torch.linspace(0, 999, 8).long().to(device)noise = torch.randn_like(xb) # <<注意是使用randn而不是randnoisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)print('Noisy X shape', noisy_xb.shape) # 展示“带噪”版本(使用或不使用截断函数clipping)axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1), cmap='Greys')axs[1].set_title('Noisy X (clipped to (-1, 1))')axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].   detach().cpu(), cmap='Greys')axs[2].set_title('Noisy X');X shape torch.Size([8, 1, 28, 28])Noisy X shape torch.Size([8, 1, 28, 28])

结果如下图所示:

采样补充

       采样在扩散模型中扮演非常重要的角色,我们可以输入纯噪声,然后期待模型能一步输出不带噪声的图像吗?根据前面的所学内容,这显然行不通。那么针对采样会有哪些改进的思路呢?

  • 可以使用模型多预测几次,以通过估计一个更高阶的梯度来更新得到更准确的结果(更高阶的方法和一些离散的ODE处理器);
  • 保留一些历史的预测值来尝试指导当前步的更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/928470.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《Zookeeper》源码分析(二十一)之 客户端启动过程

目录 ZooKeeperMain数据结构初始化启动1. 解析启动参数MyCommandOptions数据结构构造参数 parseOptions() 2. 创建客户端实例3. 处理客户端命令1.解析命令字符串2. 处理命令 ZooKeeperMain 客户端的启动类为ZooKeeperMain 数据结构 commandMap&#xff1a;存放zookeeper支持的…

什么是住宅ip,静态和动态怎么选?

上文我们介绍了数据中心代理&#xff0c;这次我们来介绍下住宅代理ip&#xff0c;住宅代理ip分类两种类型&#xff1a;静态住宅代理和动态住宅代理&#xff0c;他们有什么区别又能用在什么场景呢&#xff1f;我们先从他们是如何运作开始。 一、什么是住宅代理ip isp住宅代理i…

GPIO口输出-点亮LED灯

前言 &#xff08;1&#xff09;本系列是基于STM32的项目笔记&#xff0c;内容涵盖了STM32各种外设的使用&#xff0c;由浅入深。 &#xff08;2&#xff09;小编使用的单片机是STM32F105RCT6&#xff0c;项目笔记基于小编的实际项目&#xff0c;但是博客中的内容适用于各种单片…

大二总结,记录下自己的收获。

一&#xff1a;从大一到大二结束每个学期的学习时间 二&#xff1a;成长历程 三&#xff1a;学习知识 3.1&#xff1a;大一学习知识 3.1&#xff1a;大二学习知识 四&#xff1a;接下来的路 学习时间…

IT运维软件的费用是多少?

正常一套IT运维软件费用一般在5千-50万之间不等&#xff0c;而且分为一次性付费或年付费模式&#xff0c;付费方式导致的价格也不同。 正常情况下IT运维软件的具体价格&#xff0c;是需要根据企业的实际需求来进行综合评估&#xff0c;一般来说&#xff0c;影响具体价格费用有以…

【每日易题】数组下标的逆天用法——你见过把数组存储的值当作数组下标来解题的吗?

君兮_的个人主页 勤时当勉励 岁月不待人 C/C 游戏开发 Hello,米娜桑们&#xff0c;这里是君兮_&#xff0c;在最近是刷题中&#xff0c;遇到了一种非常新奇的数组下标的用法&#xff0c;今天想来给大家分享一下这种神奇的思路和方法&#xff0c;希望能在你遇到类似问题时能通…

《剑指Offer》模块4 栈和队列

栈和队列 1. 用两个栈实现队列 原题链接 补充&#xff1a;copy(a,b) 把a赋值给b class MyQueue { public:/** Initialize your data structure here. */stack<int> stk, cache;MyQueue() {}/** Push element x to the back of queue. */void push(int x) {stk.push(x)…

“分布式”与“集群”初学者的技术总结

一、“分布式”与“集群”的解释&#xff1a; 分布式&#xff1a;把一个囊肿的系统分成无数个单独可运行的功能模块 集群&#xff1a; 把相同的项目复制进行多次部署&#xff08;可以是一台服务器多次部署&#xff0c;例如使用8080部署一个&#xff0c;8081部署一个&#xff0c…

vue拖拽div盒子实现上下拖动互换

vue拖拽div盒子实现上下拖动互换 <div v-for"(item, index) in formList" :key"index" draggable"true"dragstart"handleDragStart($event, item)"dragenter"handleDragEnter($event, item)"dragover.prevent"han…

如何在vscode导入下载的插件安装包

点击vscode插件 --> 点击3个点 --> 选择从VSIX安装 点击更新报 Cannot update while running on a read-only volume. The application is on a read-only volume. Please move the application and try again. If you’re on macOS Sierra or later, you’ll need to m…

清吧全面解析,从此不再困惑

清吧&#xff08;Bar&#xff09;也叫休闲酒吧&#xff0c;是以轻音乐为主、比较安静的酒吧&#xff0c;比较适合和朋友一起谈天说地、喝酒聊天。清吧的装修风格偏向营造氛围&#xff0c;不如其他酒吧炫目。通常清吧这一类的酒吧不提供食品&#xff0c;仅提供酒水和饮料。通常清…

性价比神机!南卡新品OE CC开放式耳机上线,彻底把门焊死了!

开放式耳机,作为今年大热的产品,以其高舒适性在蓝牙耳机市场大放异彩,但是同时用户也反应出存在音质不足、漏音、通话质量差等多项问题&#xff0c;最重要的还是成熟的开放式产品价格偏贵&#xff0c;导致入门门槛相对较高。针对这些问题,开放式音频专业品牌南卡则是以产品力革…

keepalive+haproxy实现高可用

1&#xff0c;两台主机安装keepalived 配置keepalived 安装haproxy make PREFIX/usr/local/haproxy TARGETlinux2628 make install PREFIX/usr/local/haproxy 创建配置文件 配置haproxy vim /etc/haproxy/haproxy.cfg 添加为系统服务 cp /root/haproxy-1.7.2/examples/hapro…

解决抖音semi-ui的Input无法获取到onChange事件

最近在使用semi-ui框架的Input实现一个上传文件功能时遇到了坑&#xff0c;就是无法获取到onChange事件&#xff0c;通过console查看只是拿到了一个文件名。但若是把<Input>换成原生的<input>&#xff0c;就可以正常获取到事件。仔细看了下官方文档&#xff0c;发现…

【Linux】动态库和静态库

动态库和静态库 软链接硬链接硬链接要注意 自定义实现一个静态库(.a)解决、使用方法静态库的内部加载过程 自定义实现一个动态库&#xff08;.so&#xff09;动态库加载过程 静态库和动态库的特点 软链接 命令:ln -s 源文件名 目标文件名 软链接是独立连接文件的&#xff0c;他…

10.Redis数据结构之跳表

sortedSet sortedSet是Redis提供的一个非常特别的数据结构&#xff0c;常用作排行榜等功能&#xff0c;将用户id作为value&#xff0c;关注时间或者分数作为score进行排序。 与其他数据结构相似&#xff0c;zset也有两种不同的实现&#xff0c;分别是zipList和(hashskipList)。…

打造跨境电商新引擎:揭秘跨境电商系统商城软件平台源码

跨境电商系统商城软件平台的意义与需求 跨境电商正成为全球贸易发展的重要驱动力&#xff0c;然而&#xff0c;建立和运营一个成功的跨境电商平台并非易事。在这个过程中&#xff0c;跨境电商系统商城软件平台的作用日益凸显。 跨境电商系统商城软件平台源码是构建一个完整、高…

关于 ts 这一篇文章就够了

你好 文章目录 一、js 和 ts二、TypeScript的特点三、了解 ts , 爱上 ts &#x1f923; 一、js 和 ts 随着近几年前端领域的快速发展&#xff0c;JavaScript 迅速被普及和受广大开发者的喜爱&#xff0c;借助于 JavaScript 本身的强大&#xff0c;也让使用JavaScript开发的人员…

java八股文面试[JVM]——JVM内存结构2

知识来源&#xff1a; 【2023年面试】JVM内存模型如何分配的_哔哩哔哩_bilibili

PRACK消息

概述 PRACK消息是sip协议的扩展&#xff0c;在RFC3262中定义&#xff0c;标准的名称是sip协议中的可靠临时响应。 本文简单介绍标准中对PRACK消息流程的描述&#xff0c;以及fs配置PRACK的方式。 环境 centos&#xff1a;CentOS release 7.0 (Final)或以上版本 freeswitc…