【论文阅读】Improved Denoising Diffusion Probabilistic Models

news2024/11/19 0:47:47

Improved Denoising Diffusion Probabilistic Models

文章目录

引用: Nichol A Q, Dhariwal P. Improved denoising diffusion probabilistic models[C]//International conference on machine learning. PMLR, 2021: 8162-8171.

论文链接: https://arxiv.org/abs/2102.09672

代码链接: https://github.com/openai/improved-diffusion

概述

去噪扩散概率模型 (DDPM) 是一类生成模型,最近已被证明可以产生出色的样本。实验表明,通过一些简单的修改,DDPM还可以在保持高样品质量的同时实现竞争性的对数似然。为了更紧密地优化变分下界 (VLB),我们使用简单的重新参数化和混合学习目标来学习逆向过程方差,该目标将 VLB 与 Ho 等人[1]的简化目标相结合,允许采样前向传递减少一个数量级,样本质量差异可以忽略不计,这对于这些模型的实际部署非常重要。使用混合目标,模型获得了比直接优化对数似然获得的对数似然更好的对数似然,并发现后一个目标在训练过程中具有更多的梯度噪声。与混合目标相比,一个简单的重要性采样技术可以减少这种噪声,并能够获得更好的对数似然。此外,论文还使用精确度和召回率来比较 DDPM 和 GAN 对目标分布的覆盖程度。最后,我们表明,这些模型的样本质量和可能性可以随着模型容量和训练计算而平滑扩展,使其易于扩展。

Improving the Log-likelihood

虽然Ho等人[1]发现DDPM可以根据FID[2]和Inception Score[3]生成高保真样本,但他们无法通过这些模型实现竞争对数可能性。对数似然是生成建模中广泛使用的指标,人们普遍认为,优化对数似然会迫使生成模型捕获数据分布的所有模式。此外,最近的工作[4]表明,对数似然的微小改进可以对样本质量和学习的特征表示产生巨大影响。因此,重要的是要探讨为什么 DDPM 似乎在这个指标上表现不佳,因为这可能表明一个根本性的缺点,例如模式覆盖率差。

为了研究不同修改的影响,在ImageNet 64×64和CIFAR-10数据集上训练具有固定超参数的固定模型架构。虽然 CIFAR-10 在此类模型中的应用更多,但论文选择研究 ImageNet 64 × 64,因为它在多样性和分辨率之间提供了良好的权衡,能够快速训练模型而不必担心过度拟合。此外,ImageNet 64×64 已在生成建模的背景下进行了广泛研究,能够将 DDPM 直接与许多其他生成模型进行比较。

Ho等人[1]的设置(在设置 σ t 2 = β t σ^2_t = β_t σt2=βt T = 1000 T = 1000 T=1000 的同时优化 L s i m p l e L_{simple} Lsimple )在 200K 训练迭代后,在 ImageNet 64 × 64 64 × 64 64×64 上实现了 3.99 3.99 3.99 b i t s / d i m bits/dim bits/dim) 的对数似然。论文在早期的实验中发现,可以通过将 T T T 1000 1000 1000 增加到 4000 4000 4000 来提高对数似然;通过此更改,对数似然提高到 3.77 3.77 3.77

Learning ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t)

在这里插入图片描述

Ho等人[1]将 ∑ θ ( x t , t ) = σ t 2 I \sum_{\theta}(x_{t}, t) = \sigma_{t}^{2}I θ(xt,t)=σt2I,其中 σ t σ_t σt 不是学习的。奇怪的是,他们发现将 σ t 2 σ^2_t σt2 固定到 β t β_t βt 产生的样品质量与将其固定到 β ~ t \tilde { \beta } _ { t } β~t 大致相同。考虑到 β t β_t βt β ~ t \tilde { \beta } _ { t } β~t 代表两个相反的极端,有理由问为什么这种选择不会影响样本。图 1 给出了一个线索,**它表明 β t β_t βt β ~ t \tilde { \beta } _ { t } β~t 几乎相等(除了接近 t = 0 t = 0 t=0),即模型正在处理难以察觉的细节。此外,随着扩散步骤数量的增加, β t β_t βt和β ̃t似乎在更多的扩散过程中彼此靠近。这表明,在无限扩散步骤的极限下, σ t σ_t σt的选择对样品质量可能完全无关紧要。换句话说,当添加更多的扩散步骤时,模型平均值 μ θ ( x t , t ) \mu _ { \theta } ( x _ { t } , t ) μθ(xt,t) ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t)更能决定分布。虽然上述论点表明,为了样本质量,固定 σ t σ_t σt 是一个合理的选择,但它并没有说明对数似然性。事实上,图2显示,扩散过程的前几步对变分下限的贡献最大。因此,似乎可以通过使用更好的 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t) 选择来提高对数似然。为了实现这一目标,必须学习 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),而不会遇到 Ho 等人遇到的不稳定性。

由于图 1 显示 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t)的理想范围非常小,因此神经网络很难直接预测 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),即使在对数域中也是如此。相反,我们发现最好将方差参数化为在log域 β t β_t βt β ~ t \tilde { \beta } _ { t } β~t之间的插值。具体而言,模型输出一个向量 v v v,每个维度包含一个分量,将此输出转换为方差,如下所示:

∑ θ ( x t , t ) = e x p ( v log ⁡ β t + ( 1 − v ) log ⁡ β ~ t ) \sum _ { \theta } ( x _ { t } , t ) = e x p ( v \log \beta _ { t } + ( 1 - v ) \log \tilde { \beta } _ { t } ) θ(xt,t)=exp(vlogβt+(1v)logβ~t)

没有对 v v v 施加任何约束,理论上允许模型预测插值范围之外的方差。由于 Lsimple 不依赖于 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),因此定义了一个新的混合目标:

L h y b r i d = L s i m p l e + λ L v l b L _ { h y b r i d } = L _ { s i m p l e } + \lambda L _ { v l b } Lhybrid=Lsimple+λLvlb

对于实验,设置 λ = 0.001 λ = 0.001 λ=0.001 以防止 L v l b L_{vlb} Lvlb 压倒 L s i m p l e L_{simple} Lsimple。按照同样的推理思路,还对 L v l b L_{vlb} Lvlb项的 μ θ ( x t , t ) \mu _ { \theta } ( x _ { t } , t ) μθ(xt,t)输出应用了停止梯度。这样,$L_{vlb} $可以引导 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),而 L s i m p l e L_{simple} Lsimple 仍然是影响 μ θ ( x t , t ) \mu _ { \theta } ( x _ { t } , t ) μθ(xt,t)的主要来源。

在这里插入图片描述

Improving the Noise Schedule

虽然Ho等人中使用的线性噪声调度对于高分辨率图像效果良好,但对于分辨率为64×64和32×32的图像来说,它是次优的。特别地,前向噪声处理的末尾噪声太大,因此对样本质量没有太大贡献。这可以在图3中直观地看到。这种影响的结果在图4中进行了研究,当跳过高达20%的反向扩散过程时,用线性时间表训练的模型不会变得更糟(通过FID测量)。为了解决这个问题,根据 α t ˉ \bar { \alpha _ { t } } αtˉ构建了一个不同的噪声表:

α t ˉ = f ( t ) f ( 0 ) , f ( t ) = cos ⁡ ( t / T + s 1 + s ⋅ π 2 ) 2 \bar { \alpha _ { t } } = \frac { f ( t ) } { f ( 0 ) } , f ( t ) = \cos \left( \frac { t / T + s } { 1 + s } \cdot \frac { \pi } { 2 } \right) ^ { 2 } αtˉ=f(0)f(t),f(t)=cos(1+st/T+s2π)2

β t = 1 − α ‾ t α ‾ t − 1 \beta _ { t } = 1 - \frac { \overline { \alpha } _ { t } } { \overline { \alpha } _ { t - 1 } } βt=1αt1αt

在实践中,将 β t \beta_t βt 裁剪为不大于 0.999,以防止在扩散过程结束时接近 $t = T $的奇点。

在这里插入图片描述

提出的余弦时间表被设计为在过程中具有 α t ˉ \bar { \alpha _ { t } } αtˉ的线性下降,同时在$ t = 0 $和 t = T t = T t=T 的极端附近变化很小,以防止噪声水平的突然变化。图 5 显示了两个计划的 α α α进展情况。可以看到,线性时间表以更快的速度趋向于零,破坏信息的速度比必要的要快得多。使用较小的偏移量 s s s 来防止 β t β_t βt 在$ t = 0 附近太小,因为在过程开始时有少量的噪声会使网络难以足够准确地预测。 ∗ ∗ 特别是,选择了 附近太小,因为在过程开始时有少量的噪声会使网络难以足够准确地预测。**特别是,选择了 附近太小,因为在过程开始时有少量的噪声会使网络难以足够准确地预测。特别是,选择了 s ,使得 ,使得 ,使得\sqrt { \beta _ { 0 } }$略小于像素箱大小 1 / 127.5 1/127.5 1/127.5,因此 s = 0.008 s = 0.008 s=0.008。我们特别选择使用 c o s 2 cos^2 cos2,因为它是一个具有我们正在寻找的形状的通用数学函数。这种选择是任意的,我们预计许多其他具有类似形状的函数也可以使用。**

Reducing Gradient Noise

在这里插入图片描述

在这里插入图片描述

我们希望通过直接优化 L v l b L_{vlb} Lvlb 而不是优化 L h y b r i d L_{hybrid} Lhybrid 来实现最佳的对数似然。然而, L v l b L_{vlb} Lvlb在实践中实际上很难优化,至少在多样化的 ImageNet 64×64 数据集上是这样。图 6 显示了 $L_{vlb} $和 L h y b r i d L{hybrid} Lhybrid 的学习曲线。两条曲线都是嘈杂的,但在训练时间相同的情况下,混合目标显然在训练集上实现了更好的对数似然。通过评估使用两个目标训练的模型的梯度噪声标度证实了 L v l b L_{vlb} Lvlb 的梯度比 L h y b r i d L_{hybrid} Lhybrid 的梯度大得多,如图7所示。因此,我们寻找一种方法来减少 L v l b L_{vlb} Lvlb 的方差,以便直接优化对数似然性。注意到 L v l b L_{vlb} Lvlb的不同项具有很大差异的幅度(图 2),假设采样$ t $在 $L_{vlb} $中均匀地产生不必要的噪声。为了解决这个问题,采用了重要性抽样:

L v l b = E t ∼ p t [ L t p t ] , w h e r e p t ∝ E [ L t 2 ] a n d ∑ p t = 1 L_{vlb} = E_{ t \sim p_{t} } \left[ \frac { L_{t} } { p_{t} } \right] , where p_{t} \propto \sqrt { E \left[ L_{t} ^ {2} \right] } and \sum p_{t} = 1 Lvlb=Etpt[ptLt],whereptE[Lt2] andpt=1

由于 E [ L t 2 ] E \left[ L _ { t } ^ { 2 } \right] E[Lt2] 事先是未知的,并且可能在整个训练过程中发生变化,因此我们维护每个损失项的前 10 个值的历史记录,并在训练期间动态更新。在训练开始时,均匀地采样 t t t,直到为每个 $t ∈ [0, T −1] $抽取 10 个样本。有了这个重要性抽样目标,就能够通过优化 L v l b L_{vlb} Lvlb 来实现最佳的对数似然。如图 6 所示,即 L v l b L_{vlb} Lvlb(重采样)曲线。该图还显示,重要性采样物镜的噪声比原始的均匀采样要小得多。可以发现,在直接优化噪声较小的L_{{hybrid}时,重要性采样技术没有帮助。

Improving Sampling Speed

在这里插入图片描述

为了减少从 T T T K K K 的采样步骤数,使用$ K$ 个介于 1 1 1 T T T(含)之间的均匀分布的实数,然后将每个结果数字四舍五入到最接近的整数。在图 8 中,评估了使用 4000 扩散步骤,使用 25、50、100、200、400、1000 和 4000 个采样步骤训练的$ L_{hybrid}$ 模型和 L s i m p l e L_{simple} Lsimple 模型的 FID。.我们既针对训练有素的检查点,也针对培训中途的检查点。对于 CIFAR-10,使用了 200K 和 500K 的训练迭代,对于 ImageNet 64,使用了 500K 和 1500K 的训练迭代。可以发现,当使用较少的采样步骤时,具有固定sigmas的 L s i m p l e L_{simple} Lsimple 模型在样本质量方面受到的影响要大得多,而学习sigmas的 L h y b r i d L_{hybrid} Lhybrid模型保持了较高的样本质量。使用此模型,100 个采样步骤足以为完全训练的模型实现近乎最佳的 FID。

Scaling Model Size

在这里插入图片描述

为了衡量性能如何通过训练计算进行扩展,我们在 ImageNet 64 × 64 上训练了四个不同的模型,并使用 L h y b r i d L_{hybrid} Lhybrid 目标。为了改变模型容量,在所有层上应用深度乘法器,使得第一层有 64、96、128 或 192 个通道。请注意,之前的实验在第一层中使用了 128 个通道。由于每一层的深度都会影响初始权重的规模,因此将每个模型的Adam学习率按 1 / c h a n n e l m u l t i p l i e r 1 / \sqrt{channel multiplier} 1/channelmultiplier 缩放,因此128通道模型的学习率为0.0001。图 10 显示了 FID 和 NLL 相对于理论训练计算的改进情况。FID 曲线在对数-对数图上看起来近似线性,表明 FID 根据幂律(绘制为黑色虚线)进行缩放。NLL曲线不能完全拟合幂律,这表明验证NLL的扩展方式不如FID。这可能是由多种因素引起的,例如 1) 这种类型的扩散模型出乎意料的高不可约损失,或 2) 模型过度拟合到训练分布。还注意到,这些模型通常无法实现最佳对数似然,因为它们是使用 L h y b r i d L_{hybrid} Lhybrid 而不是直接使用 L v l b L_{vlb} Lvlb 进行训练的,以保持良好的对数似然性和样本质量。

实验

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

参考文献

[1] Ho, J., Jain, A., and Abbeel, P. Denoising diffusion probabilistic models, 2020.
[2] Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., and Hochreiter, S. Gans trained by a two time-scale update rule converge to a local nash equilibrium. Advances in Neural Information Processing Systems 30 (NIPS 2017), 2017.
[3] Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., and Chen, X. Improved techniques for training gans, 2016.
[4] Kaplan, J., McCandlish, S., Henighan, T., Brown, T. B., Chess, B., Child, R., Gray, S., Radford, A., Wu, J., and Amodei, D. Scaling laws for neural language models, 2020.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1521458.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数字多空策略(实盘+回测+数据)

数量技术宅团队在CSDN学院推出了量化投资系列课程 欢迎有兴趣系统学习量化投资的同学,点击下方链接报名: 量化投资速成营(入门课程) Python股票量化投资 Python期货量化投资 Python数字货币量化投资 C语言CTP期货交易系统开…

2023年总结:聊一聊我这10年做过的副业

以下是我的2023年终总结,我从公众号同步到CSDN,大家可以看看我这10年的副业经验,希望对大家有所帮助。 今天是2023年最后一天,今年是不平凡的一年,也是变动最大的一年,大家也知道,嘟嘟最近离职了…

深度学习模型部署(十)模型部署配套工具二

上篇blog讲了trtexec和onnx_graphsurgeon两个工具,一个用于将onnx转化为trt模型,另一个用于对onnx模型进行修改。这篇blog讲polygraphy和nsight systems,前者用于进行模型优化以及结果验证,后者用于性能分析。 polygraph polygra…

Python程序设计基础——代码习题

1 __name__属性 import demodef main():if __name__ __main__:print(这个程序被直接运行。)elif __name__demo:print(这个程序作为模块被使用。) main()3.3 编写程序,生成包含1000个0~100之间的随机整数,并统计每个元素出现的次数。 import randomx[r…

WebGIS之实现查询地区天气并让地区高亮

一.预览>> 二.思路>> 根据搜索框的内容来进行页面视角的切换,对应的地区高亮,右边有关天气的地方实时更新,并且因为代码体量非常小,并没有选择在框架下完成。直接一个html文件搞定了,但实际上还是有一些坑…

C#集合和数据结构,随笔记录

C#集合和数据结构 System.Collections命名空间包含接口和类,这些接口和类定义各种对象(如列表/链表、位数组、哈希表、队列和堆栈)的集合 System.Collections.Generic命名空间: 所有集合都直接或间接基于ICollection接口 列表类集…

Vue3+TS+Vite 找不到模块“@/components/xxx/xxx”或其相应的类型声明

引入vue文件时文件是存在的,引入路径也是对的,报找不到模块,有一些解决方案是在tsconfig.json里面做一些配置,大家可以自行百度(不知道是不是我百度的不对,我的没有解决)还有一种是在项目根目录…

免密ssh密钥登录Linux该如何设置

我们在使用ssh客户端远程连接Linux服务器时,为了考虑安全方面的因素,通常使用密钥的方式来登录。密钥分为公钥和私钥,这两把密钥可以互为加解密。公钥是公开的,私钥是由个人自己持有,并且必须妥善保管和注意保密。 Li…

安装python、pycharm,打好基础,准备飞起

python安装使用 安装python安装包 以下为自定义安装python安装包,无特殊要求可直接进行安装。 勾选Add Python 3.6 to PATH, 然后点击 Customize installation,进行自定义安装。 所有的都勾上,然后点击Next。 可选择自己需要…

perl 用 XML::DOM 解析 Freeplane.mm文件,生成测试用例.csv文件

Perl 官网 www.cpan.org 从 https://strawberryperl.com/ 下载网速太慢了 建议从 https://download.csdn.net/download/qq_36286161/87892419 下载 strawberry-perl-5.32.1.1-64bit.zip 约105MB 解压后安装.msi,装完后有520MB,建议安装在D:盘。 运行 …

鸿蒙Harmony应用开发—ArkTS声明式开发(容器组件:ListItemGroup)

该组件用来展示列表item分组,宽度默认充满List组件,必须配合List组件来使用。 说明: 该组件从API Version 9开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。该组件的父组件只能是List。 使用说明 当List…

拿捏指针(二)

个人主页:秋邱博客 所属栏目:C语言 (感谢您的光临,您的光临蓬荜生辉) 目录 前言 数组与指针 数组名的理解 指针数组与数组指针 指针数组 数组指针 数组传参 一维数组传参的本质 二维数组传参的本质 二维数组…

Spring源码流程图

1.IOC源码 流程图地址:https://www.processon.com/view/link/626ce8dc0e3e742d46229977 2.AOP源码 流程图地址:https://www.processon.com/view/link/627134571efad45d06d6a1de 3.事务源码 流程图地址:https://www.processon.com/view/li…

Android中compile,implementation和api的区别,以及gradle-wrapper的详解

前些天发现了一个蛮有意思的人工智能学习网站,8个字形容一下"通俗易懂,风趣幽默",感觉非常有意思,忍不住分享一下给大家。 👉点击跳转到教程 前言: compile,implementation和api的区别和其作用 compile和api会进行传递…

AI赋能写作:AI大模型高效写作一本通

❤️作者主页:小虚竹 ❤️作者简介:大家好,我是小虚竹。2022年度博客之星评选TOP 10🏆,Java领域优质创作者🏆,CSDN博客专家🏆,华为云享专家🏆,掘金年度人气作…

15.7k stars一个实用型OCR,支持80多种语言

一个实用型 OCR,支持 80 多种语言和所有流行的书写脚本,包括:拉丁文、中文、阿拉伯文、梵文、西里尔文等。 特点 支持本地或云/API部署 准确度提高到 99% 以上 完全可定制,支持 80 多种语言 支持表格识别 二维码/条码提取识别 GitHub数据 15.7k s…

PS学习-放大图片保持清晰

快捷键冲突所以有的不能截屏 500就是原图的5倍 还很清晰

如何本地搭建hMailServer邮件服务

文章目录 前言1. 安装hMailServer2. 设置hMailServer3. 客户端安装添加账号4. 测试发送邮件5. 安装cpolar6. 创建公网地址7. 测试远程发送邮件8. 固定连接公网地址9. 测试固定远程地址发送邮件 前言 hMailServer 是一个邮件服务器,通过它我们可以搭建自己的邮件服务,通过cpola…

获取远程管理软件保存的凭据

点击星标,即时接收最新推文 本文选自《内网安全攻防:红队之路》 扫描二维码五折购书 内网敏感数据的发现 内网的核心敏感数据,不仅包括数据库、电子邮件,还包括个人数据及组织的业务数据、技术数据等。可以说,价值较高…

C语言数据结构基础笔记——树、二叉树简介

1.树 树是一种 非线性 的数据结构,它是由 n ( n>0 )个有限结点组成一个具有层次关系的集合。 把它叫做树是因 为它看起来像一棵倒挂的树,也就是说它是根朝上,而叶朝下的。 (图片来源于网络)…