扩散模型代码剖析

news2024/11/26 16:54:35

前言

相信大家对扩散模型早有耳闻,其着实大火了一把,效果也确实是好。今天写这篇博客的主要动机就是想真正进入到代码层面去看看其到底是怎么实现的。

其实在看完代码后,会觉得其实现的非常简单,而且也会对原理的理解有一个更好的正反馈。

多说一句,在扩散模型能够生成这么惊艳的图片大背景下,已经有大批研究员悄然开始了研究生成视频的方向,笔者之前也写过一篇,感兴趣的可以穿梭:

https://zhuanlan.zhihu.com/p/570332906

另外其实网上还有很多扩散的代码,大体上核心的地方都一样,笔者在文末也贴出了一些相关的博客,都讲的很好,大家觉得不过瘾的话,可以多看几次。

paper: https://arxiv.org/pdf/2006.11239.pdf

TF版:https://github.com/hojonathanho/diffusion

pytorch版: https://github.com/lucidrains/denoising-diffusion-pytorch

本篇以pytorch版为demo介绍其实现~

原理

在看代码之前,还是得先大体了解一下原理,即了解一下到底我们要实现什么东西?如果大家实在想看细节推导,可以直接看文末的一些资料,这里仅从宏观上帮大家理一下我们要做的事。

alt

首先上图的算法1和2就是具体的train和inference过程,其实要得到这两个公式是有一系列复杂的数学推导的。

但是 我们需要从大的方面来理顺整个逻辑也即我们到底是要做什么:

  • 训练阶段:

扩散模型的原理是先扩散再恢复,而我们的目标也即loss就是在恢复阶段使得最终的图像尽可能的接近原图,假设生成原图的概率是p,那loss用数学公式来写的话就是logp,记住哈,一切是从这个源头出发进行推导,经过一系列复杂的推导就得到了上图左边的最终公式。

其实这里面是有两个过程的,一个是扩散的前向过程,另外一个反向过程,一般博主的讲解基本都是会进行推导,这里说实话有点麻烦,大家可以耐心去看,但是总的来说我们最终就是会得到左边的这个优化loss,这也是整个最后的结论,也是落实到代码层面真正要编写的逻辑。

再看一下最后这个公式,其实 就是对标准正态分布的一个采样, 就是原始图经过t时刻扩散后的样子即 , 就是我们的网络,也即要优化的参数,具体的是一个Unet网络,可以看到其把 和t作为输入,来预测当前时刻的噪声,进而和真正的采样噪声 做loss。

至于公式里面的 其实是个常数,其是一系列 连乘,而 是由 得到的,而 在论文中是0.0001~0.002,在前向扩散中不断增大,所以 越来越小,也即越到后面加噪声的力度越来越大。

alt
  • 推理阶段:

从上面可以看到,训练阶段本身是在训练一个噪声预测模型,具体的就是给定 和t,其就能预测出由 这一过程所叠加的噪声。

有了这个噪声预测模型我们就可以知道任意时刻t的噪声,然后再恢复阶段就不断的减噪即一步一步的去噪直到恢复到原图就行啦。

所以看一下上图右边的公式,其实是一个for循环,在一步一步去噪。

大家在看公式的时候,可能还会注意到 ,它其实是

alt

另外再看一下inference阶段对应的这个最后结论性公式的形式

alt

可以把红色框的看成一个均值, 是方差

我们知道假设有一个变量z服从标准的正态分布 ,如果 , 那x也是个正态分布且服从

所以 服从

这个公式看着复杂,其实 就是模型预测出来的噪声,其他的都是定值,都是 变化过来的值。

所以可以这么理解:inference阶段的每一步 其实都是一个均值+方差的过程,这个均值其实是上一步 减去模型预测的噪声得到的,当然了还需要加上点方差扰动,可以看到当t==0的时候,也就是最后一步的时候,z=0了,就不需要加方差扰动了,因为已经是最终的清晰照片啦。

其实在看最后代码实现的时候会发现,在inference这里,其实是使用了一个推导的中间过程的(具体可以看文末第一个视频),实现均值的时候是用了绿色框的部分,当然了方差还是用了上面讲的。

alt
  • 小结

原理这里我们并没有讲解复杂的推导过程,而是集中精力梳理了一下最终的结论或者说最终落地代码要实现的公式。

总的来说训练阶段就是在训练一个噪声预测器,inference阶段就是在不断的减噪声(训练好的模型预测出来的),且本质上每一步都是个正太分布。

其实要看懂扩散模型复杂的理论部分和代码实现,需要牢牢铭记两个参数一个分布即可,参数是 、分布是标准正态分布。其实 ,所以真正的变量只有一个那就是 ,它是什么呢?它就是个序列,一般有两种即线性和余弦,代码最后用了线性递增序列,代表着随着时间所加噪声的力度越来越大,而具体加的噪声就是标准正态分布。

上面最后这个结论的公式中所有的符号含义都是由 变化出来的或者说是 的一个函数,一切的一切都是由 变化得到的,是一个定值,是一个常数!!!

代码

终于要实际看看怎么实现的了。

核心代码全部在这个py文件下看到

https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

这里比较关键的就是GaussianDiffusion和Unet这两个类,前者可以实际看到loss的实现以及推理阶段的实现,后面就是具体的噪声预测模型也即Unet。

  • train loss

看GaussianDiffusion的forward方法,可以看到其是随机抽取了一个t,即打算扩散多少步,然后调用了self.p_losses了

alt

接着我们看self.p_losses,这个就是核心代码实现的逻辑了,我们仔细看。

首先第673行的noise就是真正的采样噪声 ,深入到default函数就可以看到其就是从标准正态分布随机采样得到的。

677行self.q_sample的返回值x其实就是 。同时可以看到691行把x和t作为输入给模型self.model,模型输出model_out其实就是模型预测的噪声。703就是具体的loss啦,就是把预测当前时刻的噪声和的采真正样噪声进行l1作为loss。

alt

到目前为止都比较清晰,其中self.model就是Unet,下面我们来看看self.q_sample是怎么实现的,657行其实就是 的实现也即 ,其中x_start就是 也即原始图片;self.sqrt_alphas_cumprod就是 ,self.sqrt_one_minus_alphas_cumprod就是 ,是不是和论文中的公式严丝合缝的对上啦! alt

那么我们再看看self.sqrt_alphas_cumprod和self.sqrt_one_minus_alphas_cumprod具体是什么?self.sqrt_alphas_cumprod是458行的alphas_cumprod的开方,其实alphas_cumprod就是 ,从这里也可以实际看到 就是一系列alphas即 的连乘

alt

alphas具体如下,可以看到是由1-betas得到的,这里的betas就是论文中的

alt

其中 具体是在linear_beta_schedule实现的,可以清楚的看到是一个线性递增序列torch.linspace

alt
  • inference

看GaussianDiffusion的sample方法, 本质上是看self.p_sample_loop这个函数

alt

从这里可以清晰的看到,确实是个for函数,且每一步主要调用了 self.p_sample,所以主逻辑是写在 self.p_sample里面的,同时从585行也可以看到 其实就是从标准正态分布里面随机抽样一个作为恢复阶段最初始的图像,然后一步步去噪。

alt

self.p_sample函数如下,从578行一眼就可以看到其实就是个均值+方差的过程(和我们原理一节中讲的一模一样),其中的noise就是个标准的正态分布采样也即论文中的 ,model_mean和(0.5 * model_log_variance).exp()分布代表论文中的 .

同时577行也可以看到当是最后一步后就会把方差置0(和理论部分讲的一样)

alt

接着我们看model_mean和(0.5 * model_log_variance).exp()的实现,可以看到关键都是由self.p_mean_variance这个函数生成的model_mean和model_log_variance,其中self.p_mean_variance里面的核心函数是q_posterior,可以看到model_mean和model_log_variance最后对应的其实是posterior_mean和posterior_log_variance_clipped。

其中posterior_mean的代码在523-353

alt
alt

那么x_start到底是什么呢?我们还是再把开头的原理这张图搬过来,其实就是图中绿色框中的 ,其具体实现是在predict_start_from_noise这个函数,可以看到是和天蓝色框实现一一对应的

alt
alt

接着我们回到代码的q_posterior函数即523-353看,posterior_mean_coef1其实就是上面公式的黄色箭头,posterior_mean_coef2就是上面公式的灰色箭头,具体实现为,可以看到都是一一对应的:

alt

最后看posterior_log_variance_clipped即论文的中 最后可以追踪到posterior_variance,可以看到和论文中的公式是一模一样的。

alt
alt

以上就是均值+方差的具体实现啦

  • 小结

到此我们已经看完了扩散模型真正核心那部分代码的全部实现了,最后应该还有个Unet网络和一些train流程的代码,这个应该不难,感兴趣的小伙伴可以自行看看~

一些解读博客

https://www.bilibili.com/video/av601295714/?vd_source=247c686ab5fac4b46ead87ac455ab963

https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/124641910

https://zhuanlan.zhihu.com/p/572770333

https://blog.csdn.net/sunningzhzh/article/details/125118688

总结

(1)扩散模型的原理不难(一点点加噪再一点点去噪),代码实现也不难,难的是数学推导即理论那部分,推导出了最后的结论性公式,代码直接对着写就可以啦。

(2)我们知道最终的文生图其实是根据文字生成图,也就是说在inference阶段其实是有条件地去噪的,怎么把这个“有条件”加进去是关键,甚至我们这里的文字可以替换成其他的外部信号,这部分逻辑大家可以看Unet网络,后面有时间再写写这块吧。

(3)大家可以把这块代码消化消化,可以不管细节,但是模块的大致逻辑要清楚,起码知道它是在求啥,然后往自己的场景套一套这套代码(扩散模型)试试效果。

关注

欢迎关注,下期再见啦~

知乎,csdn,github,微信公众号

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/90173.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何快速构建企业级数据湖仓?

更多技术交流、求职机会,欢迎关注字节跳动数据平台微信公众号,回复【1】进入官方交流群 本文整理自火山引擎开发者社区技术大讲堂第四期演讲,主要介绍了数据湖仓开源趋势、火山引擎 EMR 的架构及特点,以及如何基于火山引擎 EMR 构…

Python+Yolov5人脸口罩识别

程序示例精选 PythonYolov5人脸口罩识别 如需安装运行环境或远程调试,见文章底部微信名片,由专业技术人员远程协助! 前言 Yolov5比较Yolov4,Yolov3等其他识别框架,速度快,代码结构简单,识别效率高&#xf…

【王道计算机网络笔记】网络层-网络层协议

文章目录地址解析协议ARP动态主机配置协议DHCP国际控制报文协议ICMPICMP差错报文ICMP询问报文ICMP的应用地址解析协议ARP 由于在实际网络的链路上传送数据帧时,最终必须使用MAC地址 ARP协议:完成主机或路由器IP地址到MAC地址的映射。解决下一跳走哪的问…

Metal每日分享,海报画滤镜效果

本案例的目的是理解如何用Metal实现海报画效果滤镜,主要就是改变颜色级别数量从而获取到新的像素颜色; Demo HarbethDemo地址 实操代码 // 海报画滤镜 let filter C7Posterize.init(colorLevels: 2.3)// 方案1: ImageView.image try? BoxxIO(eleme…

不用虚拟机也能在Windows下使用Linux

想学习热门的Linux系统,可是一开始就需要安装虚拟机软件,这样很容易消耗Linux初学者的热情。比如常用的VMWare虚拟机,虽然步骤并不复杂,但是一开始的搭建和配置过程, 容易劝退一部分新手。我认为学习新的操作系统&…

看了这篇文章后,面试官再也不敢问你非结构化存储的原理了

那么你可能会说,是不是我无限制地增加从库的数量就可以抵抗大量的并发呢? 实际上并不是的。因为随着从库数量增加,从库连接上来的 IO 线程比较多,主库也需要创建同样多的 log dump 线程来处理复制的请求,对于主库资源消…

[附源码]Python计算机毕业设计飞羽羽毛球馆管理系统Django(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程 项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等…

人工智能课后作业_python实现深度优先遍历搜索(DFS算法)(附源码)

1 深度优先遍历搜索(DFS) 1.1算法介绍1.2实验代码1.3实验结果1.4实验总结 1.1算法介绍 深度优先搜索算法(Depth-First-Search,DFS)是一种用于遍历或搜索树或图的算法。沿着树的深度遍历树的节点,尽可能深的搜索树的分支。当节点…

知识图谱-KGE-第三方库:OpenKE库【清华开源】

GitHub - thunlp/OpenKE: An Open-Source Package for Knowledge Embedding (KE) OpenKE是THUNLP基于TensorFlow、PyTorch开发的用于将知识图谱嵌入到低维连续向量空间进行表示的开源框架。在OpenKE中,我们提供了快速且稳定的各类接口,也实现了诸多经典…

生态流量智能终端机介绍 功能 特点

平升电子生态流量智能终端机是一款集人机交互、视频叠加、4G路由、数据采集、逻辑运算与远程传输功能于一体的多媒体智能终端设备。 此款产品为水电站生态流量监测项目的专用产品,便于监管单位及时掌握水电站的流量下泄情况,以保障河湖生态用水&#xf…

Java序列化_unknown object tag -126

项目场景: 第一次进入获取员工信息的方法时,会先通过序列化数据库的对应员工信息并保存到 Redis 中。 第二次进入获取员工信息的方法时,直接取出 Redis 里序列化后员工信息,进行反序列化后返回。 问题描述 这里是第一次保存成功…

重温经典,推箱子游戏,你能闯到第几关?可自行添加关卡

🎈 作者:Linux猿 🎈 简介:CSDN博客专家🏆,华为云享专家🏆,Linux、C/C、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊! &…

WebDAV之葫芦儿·派盘+厚墨

厚墨 支持WebDAV方式连接葫芦儿派盘。 如果你喜欢看电子书又时常书荒,搜索不到想要的电子书,那就快来试试厚墨阅读APP吧。与你一同搜索极简阅读中的最佳体验。 厚墨是目前网络上非常方便的一款电子阅读软件,采用独家数据采集分析技术,汇合了移动互联网各种资源网站大数据…

【JavaSE成神之路】可变参数

哈喽,我是兔哥呀,今天就让我们继续这个JavaSE成神之路! 这一节啊,咱们要学习的内容是Java的可变参数。 1.什么是可变参数 首先来看下概念。 Java的可变参数指的是在方法中设置不定数量的参数。可变参数使得代码更加简洁&#x…

用cocos creator实现《我的世界》

摘要 《我的世界》是一款非常流行的游戏,不过网上大多都是用unity还原实现的。那么用cocos实现一版,会是怎样的开发体验呢? 使用版本 使用最新的cocos creator 3.6.2版本 目前主要功能 生成地形方块创建与销毁角色移动、碰撞、重力和简单…

Java-MySQL-SQL函数

SQL函数 函数介绍 函数是 SQL 的一个非常强有力的特性,函数能够用于下面的目的: ● 执行数据计算 ● 修改单个数据项 ● 操纵输出进行行分组 ● 格式化显示的日期和数字 ● 转换列数据类型 SQL 函数有输入参数,并且总有一个返回值。 …

【云原生系列CKA备考】Kubernetes架构

目录前言一、Kubernetes架构1.1Master节点1.2 Node节点1.3 Add-ons1.3 Kubeadm二、相关命令2.1 查看组件运行状态2.2 kubeadm容器化组件三、总结前言 ​ OpenStack是管理虚拟机的,底层依靠虚拟化技术;kubernetes是管理容器的,底层也是依靠虚…

juery笔记

文章目录Jquery一、什么是 jQuery二、如何使用 jQuery三、如何选择 jQuery 版本四、jQuery 的运行原理实例方法1、一般通过一个字符串来标识匹配的元素2、支持多个选择器任意组合使用3、jQuery 特有的选择器,当然也可以和其他选择器任意组合使用4、元素筛选&#xf…

基于OpenGL的地形建模技术的研究与实现

毕业论文 基于OpenGL的地形建模技术的研究与实现 诚信声明 本人郑重声明:本设计(论文)及其研究工作是本人在指导教师的指导下独立完成的,在完成设计(论文)时所利用的一切资料均已在参考文献中列出。 本人…