一文弄懂 Diffusion Model

news2025/2/26 9:31:44

什么是 Diffusion Model

一、前向 Diffusion 过程

Diffusion Model 首先定义了一个前向扩散过程,总共包含T个时间步,如下图所示:

最左边的蓝色圆圈 x0 表示真实自然图像,对应下方的狗子图片。

最右边的蓝色圆圈 xT 则表示纯高斯噪声,对应下方的噪声图片。

最中间的蓝色圆圈 xt 则表示加了噪声的 x0 ,对应下方加了噪声的狗子图片。

箭头下方的 q(xt|xt-1) 则表示一个以前一个状态 xt-1 为均值的高斯分布,xt  从这个高斯分布中采样得到。

所谓前向扩散过程可以理解为一个马尔可夫链[7],即通过逐步对一张真实图片添加高斯噪声直到最终变成纯高斯噪声图片。

(1)利用前一时刻的 xt-1 得到任意时刻的噪声图片 xt重参数化技巧

那么具体是怎么添加噪声呢,公式表示如下:

★★★ 也就是每一时间步的 xt 是从一个,以 1-βt 开根号乘以 xt-1 为均值,βt为方差的高斯分布中采样得到的。其中βt, t ∈ [1, T] 是一系列固定的值,由一个公式生成。

在参考资料 [2] 中设置 T=1000, β1=0.0001, βT=0.02,并通过一句代码生成所有 βt 的值:

# https://pytorch.org/docs/stable/generated/torch.linspace.html
betas = torch.linspace(start=0.0001, end=0.02, steps=1000)

然后在采样得到 xt 的时候并不是直接通过高斯分布 q(xt|xt-1) 采样,而是用了一个重参数化的技巧(详见参考资料[4]第5页)。

★★★ 简单来说就是,如果想要从一个任意的均值 μ 方差 σ^2 的高斯分布中采样得到xt

1)可以首先从一个标准高斯分布(均值0,方差1)中进行采样得到噪声 ε

noise = torch.randn_like(x_0)

2)然后利用 μ + σ·ε 就等价于从任意均值 μ 方差 σ^2 的高斯分布中采样(首先从标准高斯分布中采样得到噪声 ε,接着乘以标准差再加上均值)。公式表示如下:

xt = sqrt(1-betas[t]) * xt-1 + sqrt(betas[t]) * noise

 完整代码:

# https://pytorch.org/docs/stable/generated/torch.randn_like.html
betas = torch.linspace(start=0.0001, end=0.02, steps=1000)
noise = torch.randn_like(x_0)
xt = sqrt(1-betas[t]) * xt-1 + sqrt(betas[t]) * noise

(2)直接从 x0 采样得到中间任意一个时间步的噪声图片 xt

然后前向扩散过程还有个属性,就是可以直接从 x0 采样得到中间任意一个时间步的噪声图片 xt,公式如下:

其中的 αt 表示:

具体怎么推导出来的可以看参考资料[4] 第11页,伪代码表示如下:

betas = torch.linspace(start=0.0001, end=0.02, steps=1000)
alphas = 1 - betas
# cumprod 相当于为每个时间步 t 计算一个数组 alphas 的前缀乘结果
# https://pytorch.org/docs/stable/generated/torch.cumprod.html
alphas_cum = torch.cumprod(alphas, 0)
alphas_cum_s = torch.sqrt(alphas_cum)
alphas_cum_sm = torch.sqrt(1 - alphas_cum)

# 应用重参数化技巧采样得到 xt
noise = torch.randn_like(x_0)
xt = alphas_cum_s[t] * x_0 + alphas_cum_sm[t] * noise

通过上述的讲解,读者应该对 Diffusion Model 的前向扩散过程有比较清晰的理解了。

不过我们的目的不是要做图像生成吗?现在只是从数据集中的真实图片得到一张噪声图,那具体是怎么做图像生成呢?

二、反向 Diffusion 过程

反向扩散过程 q(xt-1|xt, x0) (看粉色箭头)是前向扩散过程 q(xt|xt-1) 的后验概率分布。

和前向过程相反是从最右边的纯高斯噪声图,逐步采样得到真实图像 x0

后验概率 q(xt-1|xt, x0) 的形式可以根据贝叶斯公式推导得到(推导过程详见参考资料[4]第12页):

也是一个高斯分布。

(1)方差:

其方差从公式上看是个常量,所有时间步的方差值都是可以提前计算得到的

betas = torch.linspace(start=0.0001, end=0.02, steps=1000)
alphas = 1 - betas
alphas_cum = torch.cumprod(alphas, 0)
alphas_cum_prev = torch.cat((torch.tensor([1.0]), alphas_cum[:-1]), 0)
posterior_variance = betas * (1 - alphas_cum_prev) / (1 - alphas_cum)

(2)均值:

对于反向扩散过程,在采样生成 xt-1 的时候 xt 是已知的,而其他系数都是可以提前计算得到的常量。

但是现在问题来了,在真正通过反向过程生成图像的时候,x0 我们是不知道的,因为这是待生成的目标图像。

好像变成了鸡生蛋,蛋生鸡的问题,那该怎么办呢?

(3)Diffusion Model 训练目标

当一个概率分布q 求解困难的时候,我们可以换个思路(详见参考资料[5,6])。

通过人为构造一个新的分布 p,然后目标就转为缩小分布 p 和  q 之间差距。通过不断修改 p  的参数去缩小差距,当 p 和 q 足够相似的时候就可以替代 q 了。

然后回到反向 Diffusion 过程,由于后验分布 q(xt-1|xt, x0) 没法直接求解。

那么我们就构造一个高斯分布 p(xt-1|xt)(见绿色箭头),让其方差和后验分布  q(xt-1|xt, x0) 一致:

而其均值则设为:

和 q(xt-1|xt, x0) 的区别在于,x0 改为 xθ(xt, t) 由一个深度学习模型预测得到,模型输入是噪声图像 xt 和时间步 t 。

然后缩小分布  p(xt-1|xt) 和  q(xt-1|xt, x0) 之间差距,变成优化以下目标函数(推导过程详见参考资料[4]第13页):

但是如果让模型直接从 xt 去预测 x0,这个拟合难度太高了,我们再继续换个思路。

前面介绍前向扩散过程提到,xt 可以直接从 x0 得到:

将上面的公式变换一下形式:

代入上面  q(xt-1|xt, x0) 的均值式子中可得(推导过程详见参考资料[4]第15页):

观察上述变换后的式子,发现后验概率 q(xt-1|xt, x0) 的均值只和 xt 和前向扩散时候时间步 t 所加的噪声有关。

所以我们同样对构造的分布 p(xt-1|xt) 的均值做一下修改:

将模型改为去预测在前向时间步 t 所添加的高斯噪声 ε,模型输入是 xt 和 时间步 t

接着优化的目标函数就变为(推导过程详见参考资料[4]第15页):

然后训练过程算法描述如下,最终的目标函数前面的系数都去掉了,因为是常量:

★ 可以看到虽然前面的推导过程很复杂,但是训练过程却很简单:

  1. 首先每个迭代就是从数据集中取真实图像 x0,并从均匀分布中采样一个时间步 t
  2. 然后从标准高斯分布中采样得到噪声 ε,并根据公式计算得到前向过程的 xt
  3. 接着将 xt 和 t 输入到模型让其输出去拟合预测噪声 ε,并通过梯度下降更新模型,一直循环直到模型收敛。
  4. 而采用的深度学习模型是类似 UNet 的结构(详见参考资料[2]附录B)。

训练过程的伪代码如下:

betas = torch.linspace(start=0.0001, end=0.02, steps=1000)
alphas = 1 - betas
alphas_cum = torch.cumprod(alphas, 0)
alphas_cum_s = torch.sqrt(alphas_cum)
alphas_cum_sm = torch.sqrt(1 - alphas_cum)

def diffusion_loss(model, x0, t, noise):
    # 根据公式计算 xt
    xt = alphas_cum_s[t] * x0 + alphas_cum_sm[t] * noise
    # 模型预测噪声
    predicted_noise = model(xt, t)
    # 计算Loss
    return mse_loss(predicted_noise, noise)

for i in len(data_loader):
    # 从数据集读取一个 batch 的真实图片
    x0 = next(data_loader)
    # 采样时间步
    t = torch.randint(0, 1000, (batch_size,))
    # 生成高斯噪声
    noise = torch.randn_like(x_0)
    loss = diffusion_loss(model, x0, t, noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

三、Diffusion Model 生成图像过程

模型训练好之后,在真实的推理阶段就必须从时间步 T 开始往前逐步生成图片,算法描述如下:

步骤:

  1. 一开始先生成一个从标准高斯分布生成噪声,
  2. 然后每个时间步 t,将上一步生成的图片 xt 输入模型模型预测出噪声。
  3. 接着从标准高斯分布中采样一个噪声,根据重参数化技巧,后验概率的均值和方差公式,计算得到 xt-1,直到时间步 1 为止。

四、改进 Diffusion Model

文章 [3] 中对 Diffusion Model 提出了一些改进点。

(1)对方差 βt  的改进

前面提到 βt 的生成是将一个给定范围均匀的分成 T 份,然后每个时间步对应其中的某个点:

betas = torch.linspace(start=0.0001, end=0.02, steps=1000)

然后文章 [3] 通过实验观察发现,采用这种方式生成方差 βt 会导致一个问题,就是做前向扩散的时候到靠后的时间步噪声加的太多了。

这样导致的结果就是在前向过程靠后的时间步,在反向生成采样的时候并没有产生太大的贡献,即使跳过也不会对生成结果有多大的影响。

接着论文[3] 中就提出了新的 βt 生成策略,和原策略在前向扩散的对比如下图所示:

第一行就是原本的生成策略,可以看到还没到最后的时间步就已经变成纯高斯噪声了,而第二行改进的策略,添加噪声的速度慢一些,看起来也更合理。

实验结果表明,针对 imagenet 数据集 64x64 的图片,原始的策略在做反向扩散的时候,即使跳过开头的 20% 的时间步,都不会对指标有很大的影响。

然后看下新提出的策略公式:

其中 s 设置为 0.008同时限制 βt最大值为 0.999,伪代码如下:

T = 1000
s = 8e-3
ts = torch.arange(T + 1, dtype=torch.float64) / T + s
alphas = ts / (1 + s) * math.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = betas.clamp(max=0.999)

(2)对生成过程时间步数的改进

原本模型训练的时候是假定在 T个时间步下训练的,在生成图像的时候,也必须从 T 开始遍历到 1 。而论文 [3] 中提出了一种不需要重新训练就可以减少生成步数的方法,从而显著提升生成的速度。

这个方法简单描述就是,原来是 T 个时间步现在设置一个更小的时间步数 S ,将 S 时间序列中的每一个时间步 s 和 T时间序列中的步数 t 对应起来,伪代码如下:

T = 1000
S = 100
start_idx = 0
all_steps = []
frac_stride = (T - 1) / (S - 1)
cur_idx = 0.0
s_timesteps = []
for _ in range(S):
    s_timesteps.append(start_idx + round(cur_idx))
    cur_idx += frac_stride

接着计算新的 β ,St 就是上面计算得到的 s_timesteps

伪代码如下:

alphas = 1 - betas
alphas_cum = torch.cumprod(alphas, 0)
last_alpha_cum = 1.0
new_betas = []
# 遍历原来的 alpha 前缀乘序列
for i, alpha_cum in enumerate(alphas_cum):
    # 当原序列 T 的索引 i 在新序列 S 中时,计算新的 beta
    if i in s_timesteps:
        new_betas.append(1 - alpha_cum / last_alpha_cum)
        last_alpha_cum = alpha_cum

简单看下实验结果:

关注画蓝线的红色和绿色实线,可以看到采样步数从 1000 缩小到 100 指标也没有降多少。

参考资料

  • [1] https://www.assemblyai.com/blog/diffusion-models-for-machine-learning-introduction/

  • [2] https://arxiv.org/pdf/2006.11239.pdf

  • [3] https://arxiv.org/pdf/2102.09672.pdf

  • [4] https://arxiv.org/pdf/2208.11970.pdf

  • [5] https://www.zhihu.com/question/41765860/answer/1149453776

  • [6] https://www.zhihu.com/question/41765860/answer/331070683

  • [7] https://zh.wikipedia.org/wiki/%E9%A9%AC%E5%B0%94%E5%8F%AF%E5%A4%AB%E9%93%BE

  • [8] https://github.com/rosinality/denoising-diffusion-pytorch

  • [9] https://github.com/openai/improved-diffusion

一文弄懂 Diffusion Model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/39443.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Tomcat安装及配置和常见的问题(2022最新详解、图文教程)

Tomcat的配置安装1. 关于WEB服务器软件2. 配置Tomcat的服务器第一步:配置Java的运行环境第二步:Tomcat的安装第三步:启动Tomcat3. 问题一:解决Tomcat服务器在DOS命令窗口中的乱码问题(控制台乱码)4. 测试To…

问题盘点|使用 Prometheus 监控 Kafka,我们该关注哪些指标

Kafka 作为当前广泛使用的中间件产品,承担了重要/核心业务数据流转,其稳定运行关乎整个业务系统可用性。本文旨在分享阿里云 Prometheus 在阿里云 Kafka 和自建 Kafka 的监控实践。01Kafka 简介Aliware01Kafka 是什么?Kafka 是分布式、高吞吐…

算法选修(J.琴和可莉)(为选修画上句号)

可莉又去池塘炸鱼啦!琴团长决定亲自捉拿可莉将其关禁闭。琴团长不断地追,可莉不断地跑。 琴团长和可莉的行动路线可以看做是一个有n个节点的无根树,初始时琴团长在A点,可莉在B点,她们互相知道对方的位置。 琴团长想尽…

P8869 [传智杯 #5 初赛] A-莲子的软件工程学

import java.util.Scanner;public class Main {public static void main(String[] args) {Scanner sc new Scanner(System.in);long a sc.nextLong();long b sc.nextLong();System.out.println(Math.abs(a)*(b>0?1:-1));}} 题目背景 在宇宙射线的轰击下,莲子…

Day13--搜索建议-自动获取焦点与防抖处理

1.定义如下的 UI 结构: 我的操作: 第一次尝试:【出现轮廓】 官方文档: 1》在search.vue中: 效果图:【还是和博主的搜索框有区别的】 第二次尝试:【加上圆角】 官方文档: 第三次尝试…

58、ElasticSearch DSL Bucket聚合

1、聚合的分类 2、DSL实现Bucket聚合 # 集合, 1、bucket terms GET /hotel/_search { "size": 0, "aggs": { "brandAgg": { "terms": { "field": "brand", "size": 20 …

10.前端笔记-CSS-盒子模型-border和padding

页面布局的三大核心: 盒子模型浮动定位 1、盒子模型 1.1 盒子模型组成 盒子模型本质还是一个盒子,包括边框border、外边距margin、内边距padding和实际内容content 1.1.1 边框border 组成 组成:颜色border-color、边框宽度border-wid…

信息论与编码:线性分组码与性能参数

文章目录1.1 线性分组码(n,k)定义1.2 信道编码性能参数1.3基本线性分组码a.奇偶监督码b.恒比码c.汉明码1.4 差错控制类型对信道编码的要求1.5信道编码主要涉及的数学知识:有限域运算、矩阵运算1.1 线性分组码(n,k)定义 线性分组码是由 (n, k) 形式表示。编码器将一…

WEB安全技能树-安全漏洞类型-命令执行漏洞

题目类型 环境:CentOSApachePHPMySQL 题目:ping主机 考点分析 1.过滤 ; && || 等多条命令连接符; 2.过滤cat more less等文件读取命令; 解题思路 第一步 ping 127.0.0.1 看看命令是否能够正确执行 linux如果不指定-…

【Java第35期】:Bean的生命周期

作者:有只小猪飞走啦 博客地址:https://blog.csdn.net/m0_62262008?typeblog 内容:1,这篇博客要分析Bean生命周期有几个阶段? 2,每个阶段的效果是什么? 3,PostConstruct 和 PreDestroy 各自的效果是什…

如果线性变换可以模仿

🍿*★,*:.☆欢迎您/$:*.★* 🍿 正文 如何模仿一个 行为 假设这个行为是线性变换 A 通过权重w 变换为 B 假设可以通过 如下方式 模仿 A变换到B 线性变换 让 C 变换 D首先 计算A C 的距离 dx 计算 B D 的距离 dy假设 w 是通过等差求解权重的方…

(附源码)计算机毕业设计Java搬家预约系统

项目运行 环境配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: Springboot mybatis Maven Vue 等等组成,B/…

Elon Musk 与开发者分享他的第一份代码评审

Elon Musk 比以往任何时候都更致力于 Twitter 2.0 的成功,与开发者分享他的第一份代码评审。 原文 https://ssaurel.medium.com/more-committed-than-ever-to-making-twitter-2-0-succeed-elon-musk-shares-his-first-code-review-a565e8df5e2f 前言 Elon Musk 也是…

第8讲:Python中列表的概念与基本使用

文章目录1.列表的概念1.什么是列表1.2.列表中元素的索引概念2.列表的简单定义3.获取列表中某个元素的索引3.1.如何获取列表中某个元素的索引3.2.各种场景获取列表中元素的索引4.使用运算符in检查列表中是否存在指定元素1.列表的概念 1.什么是列表 Python中的列表其实就是其他…

骨传导蓝牙耳机哪个品牌好,骨传导蓝牙耳机品牌推荐

在选择骨传导耳机时还不知道选择什么品牌好?下面小编就给大家推荐几款做的不错的骨传导耳机,大家要注意,在选择骨传导耳机时,还是要选择一些较大的骨传导品牌,这样无论是耳机体验还是售后服务都有保证。 1、南卡Runne…

【机器学习入门项目10例】(八):贝叶斯网络-拼写检查器

🌠 『精品学习专栏导航帖』 🐳最适合入门的100个深度学习实战项目🐳🐙【PyTorch深度学习项目实战100例目录】项目详解 + 数据集 + 完整源码🐙🐶【机器学习入门项目10例目录】项目详解 + 数据集 + 完整源码🐶🦜【机器学习项目实战10例目录】项目详解 + 数据集 +

pagination分页插件的getResult明明有数据,但是getTotal方法为0

最近把之前毕设的SSM项目改成SpringBoot项目时遇到了明明后端数据库查询到了数据,但是page的getTotal方法却是0的bug 解决办法: 先导入需要的依赖,这里注意ssm项目的依赖和SpringBoot的依赖是不一样的,这个只要导入极少启动依赖…

视频 | 扩增子文库拆分和16S序列合并

点击阅读原文跳转完整教案。基因组中的趣事(二)- 最长的基因2.7 million,最短的基因只有8 nt却能编码基因组中的趣事(一):这个基因编码98种转录本1 Linux初探,打开新世界的大门1.1 Linux系统简介…

Strimzi Kafka Bridge(桥接)实战之二:生产和发送消息

欢迎访问我的GitHub 这里分类和汇总了欣宸的全部原创(含配套源码):https://github.com/zq2599/blog_demos 本篇概览 本文是《Strimzi Kafka Bridge(桥接)实战之》系列的第二篇,咱们直奔bridge的重点:常用接口,用实际操作体验如何用…

27. Ubuntu 20.04 开机自动挂载文件/etc/fstab

自动挂载文件/etc/fstab1.fstab2. 参数含义3.开机自动挂载3.1 查看要挂载的磁盘UUID3.2 向fstab文件中添加不同于热插拔的设备,对于硬盘可能需要长期挂载在系统下,所以如果每次开机都去手动mount是非常痛苦的,当然Ubuntu系统的GNOME桌面自带的…