扩散模型 DDPM 核心代码梳理

news2024/11/17 21:27:11

参考内容:

大白话AI | 图像生成模型DDPM | 扩散模型 | 生成模型 | 概率扩散去噪生成模型
AIGC 基础,从VAE到DDPM 原理、代码详解
全网最简单的扩散模型DDPM教程
The Annotated Diffusion Model
LaTeX公式编辑器

备注: 具体公式的推导请查看参考链接,本文只记录核心步骤的几个核心公式。

什么是扩散模型?

与Normalizing Flows、GAN或VAEs等生成模型一样,它们都将噪声从一些简单分布转换为数据样本。这也是使用神经网络学习从纯噪声开始逐渐去噪进行内容生成的过程。扩散模型主要包括以下两个过程:

  • 前向加噪: 前向加噪过程是一个固定的、预定义的过程,通过逐步的往一张真实图像上添加高斯噪声,最终得到一个完全的高斯噪声图像
  • 反向去噪: 反向去噪过程通过训练学习一个神经网络模型,模型的输入是一张带有噪声的图像,模型的输出是预测得到的噪声,逐步减去预测的噪声,最终得到一张真实的图像
    在这里插入图片描述

加噪、去噪、训练、推理阶段相关的数学公式

  • 前向加噪

在前向加噪过程中,逐步的往真实图片上添加高斯噪声,每一步添加高斯噪声的公式表示如下:
x t = 1 − β t x t − 1 + β t ϵ t \begin{equation}x_{t} = \sqrt{1-\beta_{t}}x_{t-1} + \sqrt{\beta_{t}}\epsilon_{t}\end{equation} xt=1βt xt1+βt ϵt
其中, 0 < β 1 < β 2 < ⋯ < β T < 1 0 < \beta_{1} < \beta_{2} < \dots < \beta_{T} < 1 0<β1<β2<<βT<1 ϵ ∼ N ( 0 , 1 ) \epsilon \sim N(0,1) ϵN(0,1) β t \beta_{t} βt的取值可以想神经网络的学习率衰减那样,使用线性的、余弦变化的。由于正态分布的均值和方差具有可加性,从[1, T]时刻逐步添加噪声的过程可以通过一步得到:
x t = α t ˉ x 0 + 1 − α t ˉ ϵ \begin{equation}x_{t} = \sqrt{\bar{\alpha_{t}}}x_{0} + \sqrt{1 - \bar{\alpha_{t}}}\epsilon\end{equation} xt=αtˉ x0+1αtˉ ϵ
其中, α t = 1 − β t \alpha_{t} = 1 - \beta_{t} αt=1βt α t ˉ = α t α t − 1 … α 1 \bar{\alpha_{t}} = \alpha_{t}\alpha_{t-1}\dots\alpha_{1} αtˉ=αtαt1α1

  • 模型训练

在模型训练阶段,对于一个真实的图像数据,随机生成[1, T]之前的整数,表示往真实图片数据中添加噪声的次数,然后将添加噪声后的图片输入到神经网络模型中,预测添加的噪声,基于神经网络预测的噪声和真实添加的噪声,计算损失:
L o s s = ∣ ∣ ϵ − ϵ θ ( α t ˉ x 0 + 1 − α t ˉ ∗ ϵ , t ) ∣ ∣ 2 \begin{equation}Loss = ||\epsilon -\epsilon_{\theta}(\sqrt{\bar{\alpha_{t}}}x_{0} + \sqrt{1 - \bar{\alpha_{t}}}*\epsilon, t)||^{2}\end{equation} Loss=∣∣ϵϵθ(αtˉ x0+1αtˉ ϵ,t)2
其中, ϵ \epsilon ϵ表示在前向加噪过程中,使用公式(2)往真实图片中添加的随机噪声, ϵ θ \epsilon_{\theta} ϵθ表示一个神经网络模型,输入一个带有噪声的图像,以及对应添加噪声的时间步数,输出预测的噪声, x 0 x_{0} x0表示原始的真实图像, t t t表示时间步数。
在这里插入图片描述

  • 反向去噪

在反向去噪过程中,使用神经网络预测输出一个和输入图像一样大小的噪声数据,从输入图像中减去噪声数据,实现去噪。
x t − 1 = 1 α t ( x t − β t β t ˉ ∗ ϵ θ ( x t , t ) ) + δ t ∗ z \begin{equation}x_{t-1} = \frac{1}{\sqrt{\alpha_{t}}}(x_{t} - \frac{\beta_{t}}{\sqrt{\bar{\beta_{t}}}}*\epsilon _{\theta }(x_{t},t)) + \delta_{t}*z\end{equation} xt1=αt 1(xtβtˉ βtϵθ(xt,t))+δtz
其中, ϵ θ \epsilon _{\theta} ϵθ是一个神经网络模型, ϵ θ ( x t , t ) \epsilon _{\theta }(x_{t},t) ϵθ(xt,t)是神经网络模型预测输出的噪声, β t ˉ = 1 − α t ˉ \bar{\beta_{t}} = 1 - \bar{\alpha_{t}} βtˉ=1αtˉ

  • 模型推理

在模型推理阶段,也就是模型训练完之后进行图像的生成阶段,设置好迭代生成的时间步数 t t t,通过一个随机噪声 x t x_{t} xt,不断执行下面的步骤,直到公式(5)中的 t = 1 t = 1 t=1,实现图像的生成:
x t − 1 = 1 α t ( x t − β t β t ˉ ∗ ϵ θ ( x t , t ) ) + δ t ∗ z \begin{equation}x_{t-1} = \frac{1}{\sqrt{\alpha_{t}}}(x_{t} - \frac{\beta_{t}}{\sqrt{\bar{\beta_{t}}}}*\epsilon _{\theta }(x_{t},t)) + \delta_{t}*z\end{equation} xt1=αt 1(xtβtˉ βtϵθ(xt,t))+δtz
x t = x t − 1 \begin{equation}x_{t} = x_{t-1}\end{equation} xt=xt1
t = t − 1 \begin{equation}t = t-1\end{equation} t=t1

当公式(5)中的 t = 1 t = 1 t=1时,也就是最后一轮去噪,不加 δ t ∗ z \delta_{t}*z δtz,最后得到的 x 0 x_{0} x0就是生成的图像内容。
在这里插入图片描述

UNet网络结构

UNet神经网络在特定的时间步 t t t 接收噪声图像并返回预测的噪声。预测的噪声是一个与输入图像具有相同的大小/分辨率的张量。从技术上讲,网络输入和输出相同形状的张量。在DDPM中采用UNet架构的神经网络,UNet网络中主要包括以下部分:
在这里插入图片描述

  • 下采样:使用卷积 + 池化的方式实现图像分辨率的下采样
  • 上采样:使用转置卷积或者线性插值的方式,提升特征图的分辨率
  • Short-cut连接:将下采样和上采样得到的分辨率相同额特征图在通道维度上进行融合,有利于捕捉细粒度的图像特征
  • 注意力机制:使用注意力机制计算特征图上每个位置之间的注意力关系
  • time-embedding:由于DDPM是逐步生成图像的,所以需要一个特征能够标记当前执行到哪个时间步了

DDPM核心代码解释

  1. 基础代码:构造 α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数
  • 使用不同的策略构建 β \beta β 序列
def linear_beta_schedule(timesteps):
    """
        在0.0001到0.02之间,均匀采样timesteps个数值,构造成beta序列
    """
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
  • 根据生成的 β \beta β 序列,生成 α , α ˉ , β ˉ \alpha,\bar{\alpha},\bar{\beta} α,αˉ,βˉ等, α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数的序列长度等于最大的迭代步长timesteps
timesteps = 300

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas 
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
  • 备注
    • betas对应 β \beta β
    • alphas对应 α = 1 − β \alpha = 1 - \beta α=1β
    • alphas_cumprod对应 α ˉ \bar{\alpha} αˉ
    • sqrt_recip_alphas对应 1 α \frac{1}{\sqrt{\alpha}} α 1
    • sqrt_alphas_cumprod对应 1 α ˉ \frac{1}{\sqrt{\bar{\alpha}}} αˉ 1
    • sqrt_one_minus_alphas_cumprod对应 1 − α ˉ \sqrt{1 - \bar{\alpha}} 1αˉ
  • 在训练阶段对于batch中的每个样本,加噪的迭代次数是从[0, T]中进行随机采样的,所以训练阶段每个样本的加噪次数 t ∈ [ 0 , T ] t \in [0, T] t[0,T] 是不同的,使用gather函数获取到每个样本的t对应的 α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数,对应的代码如下:
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
  1. 前向加噪:根据上一步计算得到的 α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数,将一张真实图像 x 0 x_{0} x0 使用公式(2)进行多次加噪,得到加噪后的图像,对应代码如下:
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    # x_start就是前面讲的最原始图像 x_0,根据 t 获取到对应的alpha,beta等参数
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    # 使用公式(2)对图像进行前向加噪
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
  1. UNet模型:将加噪后的样本以及每个样本对应的加噪次数 t 输入到UNet网络模型中,UNet模型预测输出加入的噪声,将UNet的输出结果与加入到图像中的噪声使用公式(3)计算损失,训练UNet网络模型。
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)
    # x_start就是前面讲的最原始图像 x_0,这一步就是往 x_0 中加入t次的噪声
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    # 将加入噪声的图像以及对应的时间步数 t 输入到UNet模型
    predicted_noise = denoise_model(x_noisy, t)

    # 将UNet预测的结果与加入的噪声计算损失
    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss
  1. 模型推理:当训练完UNet之后,在模型推理也就是图像生成阶段执行反向去噪过程。首先生成一张纯噪声的图像,初始时间步设置为timesteps,将噪声图像和时间步数值 t 输入到UNet模型中,预测得到输出结果,然后使用公式(4)计算得到经过去噪之后 t-1时间步的输出,如此迭代,直到 t=0为止。
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # Equation 11 in the paper
    # Use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2 line 4:
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

# Algorithm 2 (including returning all images)

def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs


def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

注意事项:

  • torch.randn生成符合标准正态分布的数据,torch.rand生成符合0-1之间均匀分布的数据
  • UNet有利于细粒度的图像生成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/984743.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BC类电池取代TOPCon?隆基绿能看得很“远”

隆基绿能在业绩会上喊话&#xff0c;BC类电池在未来会逐步取代TOPCON电池&#xff0c;成为晶硅电池的绝对主流&#xff0c;并指出N型TOPCon电池只是短期过渡技术路线。 “All in ”BC类电池的隆基绿能&#xff0c;重新定义技术路线的野望藏不住。 目前来看&#xff0c;相比TO…

数据接口工程对接BI可视化大屏(一)

文章目录 第1章 案例概述1.1 案例目标1.2 BI最终效果1.2.1 PC端显示效果1.2.2 移动端显示效果 后记 第1章 案例概述 1.1 案例目标 此项目以常见的手机零售BI场景为例&#xff0c;介绍如何编写数据接口工程对接BI可视化大屏。 如何从当前常见的主流大数据场景中为后台程序推送…

入行测试一年半的心得体会

成为xx一员测试已经有1年半了&#xff0c;一直没有真正坐下来花些时间将自己的思路理清一下。刚好近期公司落地了OKR&#xff0c;给自己制定了OKR之后思路终于开始清晰起来&#xff0c;朦朦胧胧地开始看清了远方的路&#xff0c;麻着胆子分析一下自己&#xff0c;毕竟摸黑走路的…

2023年9月CSPM-3国标项目管理中级认证报名,找弘博创新

CSPM-3中级项目管理专业人员评价&#xff0c;是中国标准化协会&#xff08;全国项目管理标准化技术委员会秘书处&#xff09;&#xff0c;面向社会开展项目管理专业人员能力的等级证书。旨在构建多层次从业人员培养培训体系&#xff0c;建立健全人才职业能力评价和激励机制的要…

【效率提升】手把手教你如何使用免费的 Amazon Code Whisperer 提升开发效率堪比 GitHub Copilot 平替

说明 GitHub copilot 虽然很强&#xff0c;但是一个月10美金的费用拿来吃个小火锅他不香吗&#xff1f;而身为云计算博主将向你推荐一款可以平替 GitHub copilot 并且免费的支持多种编程语言的 AI 编程助手 Amazon Code Whisperer。 亚马逊云科技开发者社区为开发者们提供全球…

基于文本提示的图像目标检测与分割实践

近年来&#xff0c;计算机视觉取得了显着的进步&#xff0c;特别是在图像分割和目标检测任务方面。 最近值得注意的突破之一是分段任意模型&#xff08;SAM&#xff09;&#xff0c;这是一种多功能深度学习模型&#xff0c;旨在有效地从图像和输入提示中预测对象掩模。 通过利用…

Web安全——Web安全漏洞与利用上篇(仅供学习)

SQL注入 一、SQL 注入漏洞1、与 mysql 注入的相关知识2、SQL 注入原理3、判断是否存在注入回显是指页面有数据信息返回id 1 and 114、三种 sql 注释符5、注入流程6、SQL 注入分类7、接受请求类型区分8、注入数据类型的区分9、SQL 注入常规利用思路&#xff1a;10、手工注入常规…

ansible的安装和简单的块使用

目录 一、概述 二、安装 1、选择源 2、安装ansible 3、模块查看 三、实验 1、拓扑​编辑 2、设置组、ping模块 3、hostname模块 4、file模块 ​编辑 5、stat模块 6、copy模块&#xff08;本地拷贝到远程&#xff09; 7、fetch模块与copy模块类似&#xff0c;但作用…

YOLOv5改进算法之添加CA注意力机制模块

目录 1.CA注意力机制 2.YOLOv5添加注意力机制 送书活动 1.CA注意力机制 CA&#xff08;Coordinate Attention&#xff09;注意力机制是一种用于加强深度学习模型对输入数据的空间结构理解的注意力机制。CA 注意力机制的核心思想是引入坐标信息&#xff0c;以便模型可以更好地…

火热的低代码,是时候系统的来学一学了!

一、前言 低代码诞生至今&#xff0c;大家各抒己见&#xff0c;也不乏有针锋相对的意思。古时的治国之术有百家争鸣&#xff0c;如今的低代码也有“诸子论道”&#xff0c;这本质上是一件有助于推动低代码发展的事情。 业内的朋友们一定知道&#xff0c;关于低代码的热点不止发…

数字内容风控行业首本白皮书正式发布,打造长效安全的数字内容生态

数字内容包含文本、图片、视频等多种形式&#xff0c;起源于计算机问世&#xff0c;并随着互联网、智能手机快速发展&#xff0c;如今&#xff0c;数字内容已经成为个人及企业建立形象、传播价值的必要途径。 2022年起&#xff0c;随着ChatGPT的火爆出圈&#xff0c;AI大模型强…

Kotlin+MVVM 构建todo App 应用

作者&#xff1a;易科 项目介绍 使用KotlinMVVM实现的todo app&#xff0c;功能界面参考微软的Todo软件&#xff08;只实现了核心功能&#xff0c;部分功能未实现&#xff09;。 功能模块介绍 项目模块&#xff1a;添加/删除项目&#xff0c;项目负责管理todo任务任务模块&a…

执行上下文-通俗易懂版

(1) js引擎执行代码时候/前&#xff0c;在堆内存创建一个全局对象&#xff0c;该对象 所有的作用域&#xff08;scope&#xff09;都可以访问&#xff0c;里面会包含Date、Array、String、Number、setTimeout、setInterval等等&#xff0c;其中还有一个window属性指向自己 (2…

C++数组类的自实现,使其可以保存学生成绩,并进行降序排列

类的封装 #ifndef ARRAY_H #define ARRAY_Hclass DoubArray { private:int m_length;double* m_pointer;public:DoubArray(int len);DoubArray(const DoubArray& obj);int length();bool get(int index, double& value);bool set(int index, double value);void sort(…

尚硅谷大数据项目《在线教育之离线数仓》笔记007

视频地址&#xff1a;尚硅谷大数据项目《在线教育之离线数仓》_哔哩哔哩_bilibili 目录 第12章 报表数据导出 P112 01、创建数据表 02、修改datax的jar包 03、ads_traffic_stats_by_source.json文件 P113 P114 P115 P116 P117 P118 P119 P120 P121 P122【122_在…

Hadoop:HDFS--分布式文件存储系统

目录 HDFS的基础架构 VMware虚拟机部署HDFS集群 HDFS集群启停命令 HDFS Shell操作 hadoop 命令体系&#xff1a; 创建文件夹 -mkdir 查看目录内容 -ls 上传文件到hdfs -put 查看HDFS文件内容 -cat 下载HDFS文件 -get 复制HDFS文件 -cp 追加数据到HDFS文件中 -appendTo…

第 3 章 栈和队列(汉诺塔问题递归解法)

1. 背景说明 假设有 3 个分别命名为 X、Y 和 Z 的塔座&#xff0c;在塔座 X 上插有 n 个直径大小各不相同、依小到大编号为 1, 2&#xff0c;…&#xff0c;n 的圆盘。 现要求将 X 轴上的 n 个圆盘移至塔座 Z 上并仍按同样顺序叠排&#xff0c;圆盘移动时必须遵循下列规则&…

伦敦金的走势高低的规律

伦敦金市场是一个流动性很强的市场&#xff0c;其价格走势会在诸多因素的影响下&#xff0c;出现反复的上下波动&#xff0c;如果投资者能够在这些高低走势中找到一定的规律&#xff0c;在相对有利的时机入场和离场&#xff0c;就能够通过不断的交易&#xff0c;累积大量的财富…

浏览器渲染原理及流程

浏览器主要组成与浏览器线程 浏览器组件 浏览器大体上由以下几个组件组成&#xff0c;各个浏览器可能有一点不同。 界面控件 – 包括地址栏&#xff0c;前进后退&#xff0c;书签菜单等窗口上除了网页显示区域以外的部分浏览器引擎 – 查询与操作渲染引擎的接口渲染引擎 – …

记录vite下使用require报错和解决办法

前情提要 我们现在项目用的是vite4react18开发的项目、但是最近公司有个睿智的人让我把webpack中的bpmn组件迁移过来、结果就出现问题啦&#xff1a;因为webpack是commonjs规范、但是vite不是、好像是es吧、可想而知各种报错 废话不多说啦 直接上代码&#xff1a; 注释是之前c…