【生成模型】Diffusion Models:概率扩散模型

news2024/10/6 6:52:30

---

  • 前言
  • 一、Diffusion Model 基本介绍
  • 二、生成模型对比
  • 三、直观理解Diffusion model
  • 四、形式化解析Diffusion model
  • 五、详解 Diffusion Model(数学推导)
    • 1.前向过程(扩散过程)
    • 2.逆扩散过程
    • 3.逆扩散条件概率推导
    • 4.训练损失
  • 六、训练、测试伪代码
    • 1. 训练
    • 2.测试
  • 总结

前言

AI 作画从 18 年的 DeepDream噩梦中惊醒过来,在 2022 年 OpenAI 的 DALL·E 2达到惊人效果,见图:
在这里插入图片描述
AI + 艺术涉及到 Transformer、VAE、ELBO、Diffusion Model 等一系列跟数学相关的知识。Diffusion Models 跟 VAE 一样原理很复杂


一、Diffusion Model 基本介绍

扩散模型(论文:DDPM 即 Denoising Diffusion Probabilistic Model)2020年发表以来关注较少,因为他不像 GAN 那样简单粗暴好理解,但最近爆火以至于ICRL会议相关投稿一半以上,其最先进的两个文本生成图像——OpenAI 的 DALL·E 2 和 Google 的 Imagen,都是基于扩散模型来完成的。

二、生成模型对比

先横向对一下几个重要生成模型 GAN、VAE、Flow-based Models、Diffusion Models。

GAN 由一个生成器(generator)和判别器(discriminator)组成,generator 负责生成逼真数据以 “骗” 过 discriminator,而 discriminator 负责判断一个样本是真实的还是 “造” 出来的。GAN 的训练其实就是两个模型在相互学习,能不能不叫“对抗”,和谐一点。

VAE 同样希望训练一个生成模型 x=g(z),这个模型能够将采样后的概率分布映射到训练集的概率分布,生成隐变量 z ,并且 z 是既含有数据信息又含有噪声,除了还原输入的样本数据以外,还可以用于生成新的数据。
在这里插入图片描述

Diffusion Models 的灵感来自non-equilibrium thermodynamics (非平衡热力学)。理论首先定义扩散步骤的马尔可夫链,以缓慢地将随机噪声添加到数据中,然后学习逆向扩散过程以从噪声中构造所需的数据样本。与 VAE 或流模型不同,扩散模型是通过固定过程学习,并且隐空间 z 具有比较高的维度。

三、直观理解Diffusion model

生成式模型本质上是一组概率分布。如图所示,左边是一个训练数据集,里面所有的数据都是从某个数据 pdata 中独立同分布取出的随机样本。右边就是其生成式模型(概率分布),在这种概率分布中,找出一个分布 pθ 使得它离的 pdata 距离最近。接着在 pθ 上采新的样本,可以获得源源不断的新数据。
在这里插入图片描述
但是往往 pdata 的形式是非常复杂的,而且图像的维度很高,我们很难遍历整个空间,同时我们能观测到的数据样本也有限。

Diffusion作用

我们可以将任意分布,当然也包括我们感兴趣的 pdata ,不断加噪声,使得他最终变成一个纯噪声分布 N(0,I)。怎么理解呢?

从概率分布的角度来看,考虑下图瑞士卷形状的二维联合概率分布 p(x,y) ,扩散过程q非常直观,本来集中有序的样本点,受到噪声的扰动,向外扩散,最终变成一个完全无序的噪声分布。
在这里插入图片描述
从单个图像样来看这个过程,扩散过程q就是不断往图像上加噪声直到图像变成一个纯噪声,逆扩散过程p就是从纯噪声生成一张图像的过程。样本变化:
在这里插入图片描述

四、形式化解析Diffusion model

既然叫生成模型,这意味着 Diffusion Models 用于生成与训练数据相似的数据。从根本上说,Diffusion Models 的工作原理,是通过连续添加高斯噪声来破坏训练数据,然后通过反转这个噪声过程,来学习恢复数据。

测试时,可以使用 Diffusion Models 将随机采样的噪声传入模型中,通过学习去噪过程来生成数据。也就是下面图中所对应的基本原理。

在这里插入图片描述
更具体地说,扩散模型是一种隐变量模型(latent variable model),使用马尔可夫链(Markov Chain, MC)映射到 latent space。通过马尔可夫链,在每一个时间步 t 中逐渐将噪声添加到数据 xi 中以获得后验概率 q(x1:T | x0) ,其中 x1…xT 代表输入的数据同时也是 latent space。也就是说 Diffusion Models 的 latent space与输入数据具有相同维度。

后验概率:在贝叶斯统计中,一个随机事件或者一个不确定事件的后验概率(Posterior probability)是在考虑和给出相关证据或数据后所得到的条件概率。wiki

马尔可夫链为状态空间中经过从一个状态到另一个状态的转换的随机过程。该过程要求具备“无记忆”的性质:下一状态的概率分布只能由当前状态决定,在时间序列中它前面的事件均与之无关

Diffusion Models 分为正向的扩散过程和反向的逆扩散过程。下图为扩散过程,从 到最后的 就是一个马尔可夫链,表示状态空间中经过从一个状态到另一个状态的转换的随机过程。而下标则是 Diffusion Models 对应的图像扩散过程。

在这里插入图片描述
最终,从 x0 输入的真实图像,经过 Diffusion Models 后被渐近变换为纯高斯噪声的图片 xT

模型训练主要集中在逆扩散过程。训练扩散模型的目标是,学习正向的反过程:即训练概率分布 pθ(xt-1 | xt) 。通过沿着马尔可夫链向后遍历,可以重新生成新的数据 x0

Diffusion Models 跟 GAN 或者 VAE 的最大区别在于不是通过一个模型来进行生成的,而是基于马尔可夫链,通过学习噪声来生成数据。
在这里插入图片描述
除了生成高质量图片之外呢,Diffusion Models 另一个好处是训练过程中没有对抗,对于 GAN 网络模型来说,对抗性训练其实是很不好调试的,因为对抗训练过程互相博弈的两个模型,对我们来说是个黑盒子。另外在训练效率方面,扩散模型还具有可扩展性和可并行性,那这里面如何加速训练过程,如何添加更多数学规则和约束,扩展到语音、文本、三维领域就很好玩了,可以出很多新文章。

五、详解 Diffusion Model(数学推导)

上面已经清晰表示了 Diffusion Models 由正向过程(或扩散过程)和反向过程(或逆扩散过程)组成,其中输入数据逐渐被噪声化,然后噪声被转换回源目标分布的样本。 原理即 马尔可夫链 + 条件概率分布核心在于如何使用神经网络模型,来求解马尔可夫过程的概率分布。

1.前向过程(扩散过程)

在这里插入图片描述
在实现和推导过程中要用到的两个重要特性:

特性 1:重参数(reparameterization trick)
重参数技巧在很多工作(gumbel softmax, VAE)中有所引用。如果我们要从某个分布中随机采样 (高斯分布) 一个样本,这个过程是无法反传梯度的。而这个通过高斯噪声采样得到 xt 的过程在 diffusion 中到处都是,因此我们需要通过重参数技巧来使得他可微:
在这里插入图片描述

特性 2:任意时刻的 xt 可以由 x0 和 βt 表示
在这里插入图片描述

2.逆扩散过程

如果说前向过程 (forward) 是加噪的过程,那么逆向过程(reverse) 就是diffusion 的去噪推断过程。

如果我们能够逆转上述过程并从 q(xt-1|xt) 采样,就可以从高斯噪声 xT ~N( 0, I )还原出原图分布 x0 ~q(x) 。在文献7中证明了如果q(xt|xt-1) 满足高斯分布且 βt 足够小, q(xt-1|xt) 仍然是一个高斯分布。然而我们无法简单推断 q(xt-1|xt) ,因此我们使用深度学习模型(参数为 θ,目前主流是 U-Net+attention 的结构)去预测这样的一个逆向的分布 pθ(类似 VAE):

在这里插入图片描述
然而在论文中,作者把条件概率 pθ(xt-1|xt) 的方差直接取了 βt ,而不是上面说的需要网络去估计的 Σθ(xt, t),所以说实际上只有均值需要网络去估计。

正向扩散和逆扩散过程都是马尔可夫,然后正态分布,然后一步一步的条件概率,唯一的区别就是正向扩散里每一个条件概率的高斯分布的均值和方差都是已经确定的(依赖于 βt 和 x0),而逆扩散过程里面的均值和方差是我们网络要学出来。

3.逆扩散条件概率推导

虽然我们无法得到逆转过程的概率分布 q(xt-1|xt),但是如果知道 x0, q(xt-1|xt, x0)就可以直接写出,这个玩意儿大概是这么个形式

在这里插入图片描述

贝叶斯公式:
在这里插入图片描述
带入公式得到:
在这里插入图片描述

在这里插入图片描述
7-1带入了贝叶斯公式2;7-2带入乘法公式1,再整理一下就能得到7-3

单变量正态分布概率密度函数定义为:
在这里插入图片描述,代入得到式 7.4

式 7.5 可整理为 1 2 \frac{1}{2} 21 (ax2+bx+c)的形式,即 1 2 \frac{1}{2} 21a (x+ b 2 a \frac{b}{2a} 2ab)2+C,其均值为- b 2 a \frac{b}{2a} 2ab,方差为 1 a \frac{1}{a} a1,因此稍加整理我们可以得到 (6) 中的方差和均值为:
在这里插入图片描述
根据特性2的公式(2),我们得知在这里插入图片描述,带入上式:在这里插入图片描述

可以看出,在给定 x0 的条件下,后验条件高斯分布的均值只和超参数,xt、εt 有关,方差只与超参数有关。

通过以上的方差和均值,我们就得到了q(xt-1|xt, x0) 的解析形式

4.训练损失

如何训练 Diffusion Models 以求得公式 (3) 中的均值 μθ(xt,t) 和方差 Σθ (xt,t) 呢? 在 VAE 中我们学过极大似然估计的作用:对于真实的训练样本数据已知,要求模型的参数,可以使用极大似然估计。

统计学中,似然函数是一种关于统计模型参数的函数。给定输出x时,关于参数θ的似然函数L(θ|x)(在数值上)等于给定参数θ后变量X的概率:L(θ|x)=P(X=x|θ)。

Diffusion Models 通过极大似然估计,来找到逆扩散过程中马尔可夫链转换的概率分布,这就是 Diffusion Models 的训练目的。即最大化模型预测分布的对数似然,从Loss下降的角度就是最小化负对数似然:
在这里插入图片描述

这个过程很像VAE,即 可以使用变分下界(VLB)来优化负对数似然

KL 散度是一种不对称统计距离度量,用于衡量一个概率分布 P 与另外一个概率分布 Q 的差异程度。连续分布的 KL 散度的数学形式是:
在这里插入图片描述
KL散度的性质:
在这里插入图片描述

由KL散度可知:
在这里插入图片描述

进一步可以写出上式的交叉熵的上界,进一步对其上界进行化简:

在这里插入图片描述
接下来我们对这三种情况进行分类讨论:

首先,由于前向过程 q 没有可学习参数,而 xT 则是纯高斯噪声,因此 LT 可以当做常量忽略。

然后,Lt-1 是KL散度,则可以看做拉近 2 个分布的距离:

  1. 第一个分布 q(xt-1|xT,x0,) 我们已经在上一节推导出其解析形式,这是一个高斯分布,其均值和方差为
    在这里插入图片描述
  2. 第二个分布 pθ(xt-1,xt) 是我们网络期望拟合的目标分布,也是一个高斯分布,均值用网络估计,方差被设置为了一个和 βt 有关的常数。
    在这里插入图片描述
    如果有两个分布 p,q 都是高斯分布,则他们的KL散度为
    在这里插入图片描述
    然后因为这两个分布的方差全是常数,和优化无关,所以其实优化目标就是两个分布均值的二范数
    在这里插入图片描述
    把这个公式,带入到 上一公式中得到:
    在这里插入图片描述
    经过这样一番推导之后就是个 L2 loss。网络的输入是一张和噪声线性组合的图片,然后要估计出来这个噪声:
    在这里插入图片描述

六、训练、测试伪代码

在这里插入图片描述

1. 训练

在这里插入图片描述

2.测试

在这里插入图片描述

总结

  1. Diffusion Model 通过参数化的方式表示为马尔科夫链,这意味着隐变量 x1,…xT 都满足当前时间步 t 只依赖于上一个时间步 t-1,这样对后续计算很有帮助。
  2. 马尔科夫链中的转变概率分布 pθ(xt-1|xt) 服从高斯分布,在正向扩散过程当中高斯分布的参数是直接设定的,而逆向过程中的高斯分布参数是通过学习得到的。
  3. Diffusion Model 网络模型扩展性和鲁棒性比较强,可以选择输入和输出维度相同的网络模型,例如类似于UNet的架构,保持网络模型的输入和输出 Tensor dims 相等。
  4. Diffusion Model 的目的是对输入数据求极大似然函数,实际表现为通过训练来调整模型参数以最小化数据的负对数似然的变分上限
  5. 在概率分布转换过程中,因为通过马尔科夫假设,目标函数第4点中的变分上限都可以转变为利用 KL 散度来计算,因此避免了采用蒙特卡洛采样的方式。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/30608.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

鲲鹏devkit编译调试工具——《sudoku》作业解析

《sudoku》作业解析 本次实验以sudoku项目为例介绍鲲鹏编译调试插件的基本使用方法 本次实验的步骤主要为 获取源码安装鲲鹏编译调试插件服务器配置进行代码同步配置配置测试任务进行编译调试 接下来我们先获取本次实验所需要的源码 获取源码 sudoku项目已经上传到github使…

stata外部命令大全(包含面板门槛、系统GMM、空间计量、Pvar、中介效应等)

1、数据来源:自主整理 2、时间跨度:无 3、区域范围:无 4、指标说明: 该些外部命令包含面板门槛、系统GMM、空间计量、pvar、中介效应等涵盖全部 以下是部分命令截图: 空间计量: 系统GMM(动…

Allure使用手册

一. 简介 Allure是一款支持多语言的测试结果可视化软件,支持Java、Python,搭配Junit、pytest等测试框架食用更香。本文主要讲解搭配Junit4。 二. 下载、安装部署 2.1 下载 百度搜索Allure2!!! 敲重点:…

基于Qlearning强化学习的倒立摆控制系统matlab仿真

目录 1.算法描述 2.仿真效果预览 3.MATLAB部分代码预览 4.完整MATLAB程序 1.算法描述 强化学习通常包括两个实体agent和environment。两个实体的交互如下,在environment的statestst下,agent采取actionatat进而得到rewardrtrt 并进入statest1st1。Q-l…

【头歌实验】五、Python循环结构

文章目录>>>第1关:达依尔的麦子数任务描述案例分析相关知识for循环测试说明参考答案>>>第2关:四级单词查询任务描述案例分析相关知识如何处理文件文件打开文件循环文件关闭遍历文件测试说明第3关:出租车车费计算任务描述案…

Monaco Editor教程(十八):使用api来完成某些键盘操作,格式化,查找,显示右侧菜单等。

背景 在一般的Web IDE中,我们需要将经常用到的一些操作放到顶部操作栏里,类似语雀的文档编辑。 代码编辑器,一般也会放一些查找,格式化,撤销,恢复。有些人喜欢用快捷键来进行这些操作,但由于mo…

Packet Tracer - 配置 OSPF 高级功能

地址分配表 设备 接口 IPv4 地址 子网掩码 默认网关 R1 G0/0 172.16.1.1 255.255.255.0 不适用 S0/0/0 172.16.3.1 255.255.255.252 不适用 S0/0/1 192.168.10.5 255.255.255.252 不适用 R2 G0/0 172.16.2.1 255.255.255.0 不适用 S0/0/0 172.16.3.2 …

论文笔记: 全波形反演的无监督学习: 将 CNN 与偏微分方程做成一个环

摘要: 分享对论文的理解, 原文见 Peng Jin, Xitong Zhang, Yinpeng Chen, Sharon Xiaolei Huang, Zicheng Liu, Youzuo Lin, Unsupervised learning of full-waveform inversion: connecting CNN and partial differential equation in a loop. 论文发表于计算机方面的顶会 ICL…

(续)SSM整合之SSM整合笔记(Spring整合MyBatis)(P179-188)

一 准备工作 1 新建模块ssm com.atguigu.ssm 2 导入依赖 <packaging>war</packaging><properties><spring.version>5.3.1</spring.version> </properties><dependencies><dependency><groupId>org.springframewo…

Linux:进程(二)

文章目录前言一、环境变量1.概念2.常见环境变量3.一个疑问4.通过系统调用获取或设置环境变量二、地址空间1.引入2.分页&进程地址空间1.页表2.写时拷贝3.为什么要有地址空间总结前言 今天我们继续学习进程相关知识。 一、环境变量 1.概念 环境变量(environment variables)…

从理解路由到实现一套Router(路由)

平时在Vue项目中经常用到路由&#xff0c;但是也仅仅处于会用的层面&#xff0c;很多基础知识并不是真正的理解。于是就趁着十一”小长假“查阅了很多资料&#xff0c;总结下路由相关的知识&#xff0c;查缺不漏&#xff0c;加深自己对路由的理解。 路由 在 Web 开发过程中&a…

Redis中最简单的存储类型:String

String类型&#xff0c;也就是字符串类型&#xff0c;是Redis中最简单的存储类型。 其value是字符串&#xff0c;不过根据字符串的格式不同&#xff0c;又可以分为3类&#xff1a; string&#xff1a;普通字符串 int&#xff1a;整数类型&#xff0c;可以做自增、自减操作 f…

CentOS虚拟机装完了,不能粘贴window命令行?不能上网?

文章目录前言关于CentOS安装版本如何实现粘贴命令行CentOS命令行模式下如何联网&#xff1f;结尾前言 最近想系统学习Linux环境下系统运维&#xff0c;所以安装了CentOS 7虚拟机&#xff0c;但是个人笔记本上和工作电脑上无意中下载了不同镜像进行安装&#xff0c;有一台机器无…

Nerf三维重建Pytorch使用Pycharm运行0基础教程

Nerf三维重建Pytorch使用Pycharm运行0基础教程 你好&#xff01; 这里是“出门吃三碗饭”本人&#xff0c;本文章接下来将介绍如何从0运行2020会议Nerf的Pytorch版本&#xff0c;让你自己动手渲染第一个三维模型。视频解说可以关注B站&#xff0c;搜索 出门吃三碗饭 &#xff…

Improving Inductive Link Prediction Using Hyper-Relational Facts

摘要 多年来,知识图(KGs)上的链接预测一直是一个纯粹的转换任务,不允许对看不见的实体进行推理。最近,越来越多的努力被投入到探索半和全归纳场景,使推理能够对不可见的和新兴的实体。然而,所有这些方法都只考虑基于三元组的kg,而它们更丰富的对应,超关系KG(如Wikidata…

OWASP ZAP mac chrome代理配置取消URL强制Https【已解决】

1.OWASP ZAP OWASP Zed攻击代理&#xff08;ZAP&#xff09;是世界上最受欢迎的免费安全审计工具之一&#xff0c;由数百名国际志愿者积极维护。它可以帮助你在开发和测试应用程序时自动查找Web应用程序中的安全漏洞。 也可以说ZAP是一个中间人代理。它能够获取你对Web应用程…

2022亚太赛题浅评

2022年亚太今日已经正式开赛&#xff0c;为了帮助大家更好的选题建模&#xff0c;这里首先对ABC三道题目进行浅要评析&#xff0c;以方便大家更好的择题。同时相关资料也会后续进行补充。预计明日公布各题统计选题人数以及较为完善的资料。今天作为第一天重要的是择好题&#x…

XCTF1-web easyupload

easyupload 题目描述 一名合格的黑客眼中&#xff0c;所有的上传点都是开发者留下的后门 进入场景 是个文件上传的页面&#xff0c;测试上传的文件类型&#xff0c;发现是图片上传点 上传正常图片&#xff0c;会回显文件上传的路径 尝试推测文件上传检测点 测试后缀名php、…

Flutter高仿微信-第30篇-单聊-文本

Flutter高仿微信系列共59篇&#xff0c;从Flutter客户端、Kotlin客户端、Web服务器、数据库表结构、Xmpp即时通讯服务器、视频通话服务器、腾讯云服务器全面讲解。 详情请查看 效果图&#xff1a; 详情请参考Flutter高仿微信-第29篇-单聊 &#xff0c; 这里只是提取文本实现的部…

Linux系统中使用汇编初始化外设方法

大家好&#xff0c;我是ST。 今天主要和大家聊一聊&#xff0c;如何使用汇编语言来实现芯片外设的初始化功能。 ​ 目录 第一步&#xff1a;硬件原理分析 第二&#xff1a;实验程序编写方法 第三&#xff1a;汇编代码具体实现 第四&#xff1a;编译与下载 第五&#xff…