【源码解读】扩散模型核心:DDPM专题-结合源码讲解

news2025/1/9 13:23:51

目录

本次训练采用的是cifar数据集,代码和下载好的数据将打包上传在百度网盘

1. 训练

image-20230628172617191

1.1 Uniform({1,…,T})

image-20230628172711622

训练过程, t是随机采样获得的, 这一步是核心之一, 相当于伪代码中的 Step3: t ∼ Uniform ⁡ ( { 1 , … , T } ) t \sim \operatorname{Uniform}(\{1, \ldots, T\}) tUniform({1,,T})

1.2 ϵ ∼ N ( 0 , I ) \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) ϵN(0,I)

代码142行:生成均值为1,方差为0的标准高斯分布噪声

注意一个细节,t的维度是128,表示一个batchsize一起进行加噪

image-20230628174504572

1.3 加噪

主要的函数代码在144行,这里将随机采样的加噪时间t,生成的noise和一个bath的image一起放入perturb函数中。

这里的加噪公式对应论文中的: q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)=\mathcal{N}\left(\mathbf{x}_{t} ; \sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0},\left(1-\bar{\alpha}_{t}\right) \mathbf{I}\right) q(xtx0)=N(xt;αˉt x0,(1αˉt)I)

采用重参数化技巧后得到如下伪代码表示:

betas = torch.linspace(start=0.0001, end=0.02, steps=1000)
alphas = 1 - betas
# cumprod 相当于为每个时间步 t 计算一个数组 alphas 的前缀乘结果
# https://pytorch.org/docs/stable/generated/torch.cumprod.html
alphas_cum = torch.cumprod(alphas, 0)
alphas_cum_s = torch.sqrt(alphas_cum)
alphas_cum_sm = torch.sqrt(1 - alphas_cum)
 
# 应用重参数化技巧采样得到 xt
noise = torch.randn_like(x_0)
xt = alphas_cum_s[t] * x_0 + alphas_cum_sm[t] * noise

image-20230628182143289

以下为 extract 函数的具体代码:

image-20230628204308134

注意:这里的alpha和beta参数是为了控制扩散程度,即改变高斯噪声的均值和方差,他们是一次性生成的,原始DDPM论文是设置 T=1000, β1=0.0001, βT=0.02,代码中对应如下代码:**Line167&185 ** 这里设置了两种超参数的生成方式

image-20230628204049422

这里也可以参考博客 一文弄懂 Diffusion Model DDPM架构图解

1.4 加噪图片送入UNet预测加入的噪声

这部分的代码核心是时间信息如何加入到UNet中,可以参考代码Lin357&369

image-20230628205704053

这里每一个time[128]信息会被self.time_mlp编码为一个embeeeding,即time_emb[128,512]

image-20230628205928278

可以参考代码的Line260,这里time会被TimeEmbedding层采用和Transformer一致的三角函数位置编码,将常数转变为向量

image-20230628210118739

以下是DDMP的UNet整体代码,关键在于理解这个UNet是如何把时间信息和x融合起来的

image-20230628205234631

DDPM中的Unet架构

image-20230628205301459

DownBlock和UpBlock

可以参考以下的基础block,Line199,ResidualBlock,就是把time_embedding经过一层nn.Linear,x经过一层nn.Conv2d,然后相加即可融合二者信息

image-20230628210734945

1.5 预测的噪声和加入的噪声进行损失计算

参考Line150,这样UNet模型拥有了预测图片中的噪声分布的能力

image-20230628211051218

2. 采样

image-20230628214047097

训练1000步,执行一次采样

image-20230628212929419

采样函数的具体细节,注意,这时候的t就不是随机生成的t了,而是从1000逐步递减下来的

image-20230628211710736

注意,采样函数会过一遍UNet模型,得到UNet预测到的当前时间步的noise,然后用x-noise得到当前时间步的去噪图片,可以参考如下代码Line109

image-20230628213556933

总结:可以看到虽然前面的推导过程很复杂,但是训练过程却很简单:

  • 首先每个迭代就是从数据集中取真实图像 x0,并从均匀分布中采样一个时间步 t
  • 然后从标准高斯分布中采样得到噪声 ε,并根据公式计算得到前向过程的 xt。
  • 接着将 xt 和 t 输入到模型让其输出去拟合预测噪声 ε,并通过梯度下降更新模型,一直循环直到模型收敛。
  • 而采用的深度学习模型是类似 UNet 的结构。

3. 推理

推理过程很简单,给一个随机噪声,使用预训练模型权重,直接过一遍模型参数进行采样,即可得到所需图像。注意,这里label标志着采样的步长。步长越长,去噪效果越好,生成的图片质量越佳。

image-20230628212543247

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/696961.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

15天涨粉50万!B站有900万人看过都说“震撼”

“卷”是内容创作者对现在互联网竞争最大的评价,创作者之间复制力极强,导致赛道竞争力大,创作者亟待不断地推动自己找到一个又一个新的内容差异、流量风口。 所以不管是创作者还是品牌,只要是涉及到内容运营的都只有一个目标&…

【C#】反射机制,动态加载类文件

系列文章 【C#】编号生成器(定义单号规则、固定字符、流水号、业务单号) 本文链接:https://blog.csdn.net/youcheng_ge/article/details/129129787 【C#】日期范围生成器(开始日期、结束日期) 本文链接:h…

网络解析----yolov3网络解析

Yolov3是一种针对目标检测任务的神经网络模型,其网络结构主要由三个部分组成:特征提取网络、检测头和非极大值抑制(NMS)模块。特征提取网络采用Darknet-53作为骨干网络,Darknet-53由53个卷积层和5个Max-Pooling层组成&…

死锁的产生

死锁的产生:有一个公共区的玩具,A和B想玩,但是A先得到了玩具,A玩完玩具之后又去干别的事情,但是并没有把玩具还回去,此时B就玩不到了玩具,在无限期的等待。 如下图所示: 线程1把num资…

php压缩一个文件夹,php下载多个图片

$area_id 100;$area_name 一百;shell_exec("cd /www/wwwroot/api/public/images/ && zip -r " . $area_name . ".zip " . $area_id . "/"); 把 100/ 这个文件夹,压缩成 一百.zip 然后得到zip所在的下载url 这个功能&…

百度编辑器(Ueditor)视频上传到阿里云 + 预览不支持FLASH问题解决 + 输入框不展示视频播放页面问题解决

目前需求方提出的问题是以下四个: 1.百度编辑器(Ueditor)视频上传到阿里云 2.解决不支持FLASH问题 3.视频上传后可以预览 4.修改视频封面 看一下原始的功能是什么样的 上传视频: 视频上传完成 上传视频保存的路径&#xff1…

使用vant组件库

参考网址 Vant Weapp - 轻量、可靠的小程序 UI 组件库 1.在小程序中右键打开外部终端窗口 2.npm init -y 生成package.json 如果没有npm指令则需安装node.js 地址:https://nodejs.org/dist/v18.16.1/node-v18.16.1-x64.msi 3.npm i vant/weapp1.3.3 -S --pro…

PHP 论坛系统mysql数据库web结构apache计算机软件工程网页wamp

一、源码特点 PHP 论坛系统 是一套完善的web设计系统,对理解php编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为PHP APACHE,数据库为 mysql5.0,使用php语言开发。 下载链接…

oracle服务器的CPU占用率一直100%排查方式

背景说明 公司开发、测试、演示环境,三个环境的oracle服务器无论服务器是否空闲, CPU的占用率一直是100%, 一直也没有找到问题原因,今天就花了一整天时间研究这个问题。 通过AWR报告查看oracle运行情况 awr报告是oracle 10g下提…

一种人体属性识别的网络结构

0、前言 人体属性识别,是一个典型的多标签分类场景。每个人体有多个标签,如年龄、性别、衣着颜色等,而每个属性又有多种类别,如年龄分儿童青年老人、性别分男女、颜色分红绿青蓝紫... 本文提供了一个网络结构来执行这种任务。 …

【KVM】命令行安装kvm

命令行安装kvm 一、准备镜像文件 mkdir /home/iso cd /home/iso rz ls CentOS-7-x86_64-Minimal-2009.iso二、使用命令行安装虚拟机 virt-install --virt-typekvm --nameKVM_01 --vcpus4 --memory6000 --location/opt/CentOS-7-x86_64-Minimal-2009.iso --disk path/data/kv…

2021年全国硕士研究生入学统一考试管理类专业学位联考数学试题——纯题目版

2021 年 1 月份管综初数真题 一、问题求解(本大题共 5 小题,每小题 3 分,共 45 分)下列每题给出 5 个选项中,只有一个是符合要求的,请在答题卡上将所选择的字母涂黑。 1.某便利店第一天售出50种商品&…

联邦学习中的模型聚合

目录 联邦学习中的模型聚合 1.client-server 算法 2. fully decentralized(完全去中心化)算法 联邦学习中的模型聚合 在联邦学习的情景下引入了多任务学习,其采用的手段是使每个client/task节点的训练数据分布不同,从而使各任务节点学习到不同的模型…

[python] 进度条使用

from tqdm import tqdm# 创建一个示例字典 my_dict {a: 1, b: 2, c: 3}# 使用tqdm遍历字典的键 for key in tqdm(my_dict.keys()):# 在这里编写你的代码# 这部分代码将会在进度条中显示pass# 使用tqdm遍历字典的值 for value in tqdm(my_dict.values()):# 在这里编写你的代码#…

查看 git的 config 配置

git config --list // 查看全部配置信息git config user.name // 查看指定配置信息 查看某一个配置信息 git config --global user.email 参考 如何查看gitconfig配置_笔记大全_设计学院

牛客BM21 旋转数组的最小数字

描述 有一个长度为 n 的非降序数组,比如[1,2,3,4,5],将它进行旋转,即把一个数组最开始的若干个元素搬到数组的末尾,变成一个旋转数组,比如变成了[3,4,5,1,2],或者[4,5,1,2,3]这样的。请问,给定…

IDEA远程DeBug调试

1. 介绍 当我们在开发过程中遇到一些复杂的问题或需要对代码进行调试时,远程调试是一种非常有用的工具。使用 IntelliJ IDEA 进行远程调试可以让你在远程服务器上的应用程序中设置断点、查看变量和执行调试操作。 远程调试的好处如下: 提供更方便的调试…

大众汽车车载娱乐系统曝安全漏洞,可被远程控制

根据GitHub的一份报告,大众汽车Discover Media信息娱乐系统的漏洞是在2023年2月28日发现的。 该漏洞可能会使未打补丁的系统遭到拒绝服务(DoS)攻击。该漏洞起初是由大众汽车的用户发现的,随后大众汽车方面确认了该漏洞&#xff0…

Golang 一个支持错误堆栈, 错误码, 错误链的工具库

介绍 来腾讯之后主要使用go, 在业务开发中需要一个支持错误码对外返回, 堆栈打印等能力的错误工具库, 先开始使用pkg/errors, 但该项目已经只读, 上次更新是几年前, 而且有一些点比如调整堆栈深度等没有支持, 后续根据业务的需要抽取了一个通用库, 且做了一些优化, 详见下方. …

Apikit 自学日记:发起文档测试-RPC

以DUBBO接口为例,进入某个DUBBO协议的API文档详情页,点击文档上方 测试 标签,即可进入 API 测试页,系统会根据API文档的定义的请求报文自动生成测试界面并且填充测试数据。 对RPC/DUBBO接口发起测试 填写请求报文参数值 此测试D…