参考内容:
大白话AI | 图像生成模型DDPM | 扩散模型 | 生成模型 | 概率扩散去噪生成模型
AIGC 基础,从VAE到DDPM 原理、代码详解
全网最简单的扩散模型DDPM教程
The Annotated Diffusion Model
LaTeX公式编辑器
备注: 具体公式的推导请查看参考链接,本文只记录核心步骤的几个核心公式。
什么是扩散模型?
与Normalizing Flows、GAN或VAEs等生成模型一样,它们都将噪声从一些简单分布转换为数据样本。这也是使用神经网络学习从纯噪声开始逐渐去噪进行内容生成的过程。扩散模型主要包括以下两个过程:
- 前向加噪: 前向加噪过程是一个固定的、预定义的过程,通过逐步的往一张真实图像上添加高斯噪声,最终得到一个完全的高斯噪声图像
- 反向去噪: 反向去噪过程通过训练学习一个神经网络模型,模型的输入是一张带有噪声的图像,模型的输出是预测得到的噪声,逐步减去预测的噪声,最终得到一张真实的图像
加噪、去噪、训练、推理阶段相关的数学公式
- 前向加噪
在前向加噪过程中,逐步的往真实图片上添加高斯噪声,每一步添加高斯噪声的公式表示如下:
x
t
=
1
−
β
t
x
t
−
1
+
β
t
ϵ
t
\begin{equation}x_{t} = \sqrt{1-\beta_{t}}x_{t-1} + \sqrt{\beta_{t}}\epsilon_{t}\end{equation}
xt=1−βtxt−1+βtϵt
其中,
0
<
β
1
<
β
2
<
⋯
<
β
T
<
1
0 < \beta_{1} < \beta_{2} < \dots < \beta_{T} < 1
0<β1<β2<⋯<βT<1,
ϵ
∼
N
(
0
,
1
)
\epsilon \sim N(0,1)
ϵ∼N(0,1),
β
t
\beta_{t}
βt的取值可以想神经网络的学习率衰减那样,使用线性的、余弦变化的。由于正态分布的均值和方差具有可加性,从[1, T]时刻逐步添加噪声的过程可以通过一步得到:
x
t
=
α
t
ˉ
x
0
+
1
−
α
t
ˉ
ϵ
\begin{equation}x_{t} = \sqrt{\bar{\alpha_{t}}}x_{0} + \sqrt{1 - \bar{\alpha_{t}}}\epsilon\end{equation}
xt=αtˉx0+1−αtˉϵ
其中,
α
t
=
1
−
β
t
\alpha_{t} = 1 - \beta_{t}
αt=1−βt,
α
t
ˉ
=
α
t
α
t
−
1
…
α
1
\bar{\alpha_{t}} = \alpha_{t}\alpha_{t-1}\dots\alpha_{1}
αtˉ=αtαt−1…α1。
- 模型训练
在模型训练阶段,对于一个真实的图像数据,随机生成[1, T]之前的整数,表示往真实图片数据中添加噪声的次数,然后将添加噪声后的图片输入到神经网络模型中,预测添加的噪声,基于神经网络预测的噪声和真实添加的噪声,计算损失:
L
o
s
s
=
∣
∣
ϵ
−
ϵ
θ
(
α
t
ˉ
x
0
+
1
−
α
t
ˉ
∗
ϵ
,
t
)
∣
∣
2
\begin{equation}Loss = ||\epsilon -\epsilon_{\theta}(\sqrt{\bar{\alpha_{t}}}x_{0} + \sqrt{1 - \bar{\alpha_{t}}}*\epsilon, t)||^{2}\end{equation}
Loss=∣∣ϵ−ϵθ(αtˉx0+1−αtˉ∗ϵ,t)∣∣2
其中,
ϵ
\epsilon
ϵ表示在前向加噪过程中,使用公式(2)往真实图片中添加的随机噪声,
ϵ
θ
\epsilon_{\theta}
ϵθ表示一个神经网络模型,输入一个带有噪声的图像,以及对应添加噪声的时间步数,输出预测的噪声,
x
0
x_{0}
x0表示原始的真实图像,
t
t
t表示时间步数。
- 反向去噪
在反向去噪过程中,使用神经网络预测输出一个和输入图像一样大小的噪声数据,从输入图像中减去噪声数据,实现去噪。
x
t
−
1
=
1
α
t
(
x
t
−
β
t
β
t
ˉ
∗
ϵ
θ
(
x
t
,
t
)
)
+
δ
t
∗
z
\begin{equation}x_{t-1} = \frac{1}{\sqrt{\alpha_{t}}}(x_{t} - \frac{\beta_{t}}{\sqrt{\bar{\beta_{t}}}}*\epsilon _{\theta }(x_{t},t)) + \delta_{t}*z\end{equation}
xt−1=αt1(xt−βtˉβt∗ϵθ(xt,t))+δt∗z
其中,
ϵ
θ
\epsilon _{\theta}
ϵθ是一个神经网络模型,
ϵ
θ
(
x
t
,
t
)
\epsilon _{\theta }(x_{t},t)
ϵθ(xt,t)是神经网络模型预测输出的噪声,
β
t
ˉ
=
1
−
α
t
ˉ
\bar{\beta_{t}} = 1 - \bar{\alpha_{t}}
βtˉ=1−αtˉ。
- 模型推理
在模型推理阶段,也就是模型训练完之后进行图像的生成阶段,设置好迭代生成的时间步数
t
t
t,通过一个随机噪声
x
t
x_{t}
xt,不断执行下面的步骤,直到公式(5)中的
t
=
1
t = 1
t=1,实现图像的生成:
x
t
−
1
=
1
α
t
(
x
t
−
β
t
β
t
ˉ
∗
ϵ
θ
(
x
t
,
t
)
)
+
δ
t
∗
z
\begin{equation}x_{t-1} = \frac{1}{\sqrt{\alpha_{t}}}(x_{t} - \frac{\beta_{t}}{\sqrt{\bar{\beta_{t}}}}*\epsilon _{\theta }(x_{t},t)) + \delta_{t}*z\end{equation}
xt−1=αt1(xt−βtˉβt∗ϵθ(xt,t))+δt∗z
x
t
=
x
t
−
1
\begin{equation}x_{t} = x_{t-1}\end{equation}
xt=xt−1
t
=
t
−
1
\begin{equation}t = t-1\end{equation}
t=t−1
当公式(5)中的
t
=
1
t = 1
t=1时,也就是最后一轮去噪,不加
δ
t
∗
z
\delta_{t}*z
δt∗z,最后得到的
x
0
x_{0}
x0就是生成的图像内容。
UNet网络结构
UNet神经网络在特定的时间步
t
t
t 接收噪声图像并返回预测的噪声。预测的噪声是一个与输入图像具有相同的大小/分辨率的张量。从技术上讲,网络输入和输出相同形状的张量。在DDPM中采用UNet架构的神经网络,UNet网络中主要包括以下部分:
- 下采样:使用卷积 + 池化的方式实现图像分辨率的下采样
- 上采样:使用转置卷积或者线性插值的方式,提升特征图的分辨率
- Short-cut连接:将下采样和上采样得到的分辨率相同额特征图在通道维度上进行融合,有利于捕捉细粒度的图像特征
- 注意力机制:使用注意力机制计算特征图上每个位置之间的注意力关系
- time-embedding:由于DDPM是逐步生成图像的,所以需要一个特征能够标记当前执行到哪个时间步了
DDPM核心代码解释
- 基础代码:构造 α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数
- 使用不同的策略构建 β \beta β 序列
def linear_beta_schedule(timesteps):
"""
在0.0001到0.02之间,均匀采样timesteps个数值,构造成beta序列
"""
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start, beta_end, timesteps)
def cosine_beta_schedule(timesteps, s=0.008):
"""
cosine schedule as proposed in https://arxiv.org/abs/2102.09672
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0.0001, 0.9999)
def quadratic_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
def sigmoid_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
betas = torch.linspace(-6, 6, timesteps)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
- 根据生成的 β \beta β 序列,生成 α , α ˉ , β ˉ \alpha,\bar{\alpha},\bar{\beta} α,αˉ,βˉ等, α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数的序列长度等于最大的迭代步长timesteps
timesteps = 300
# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)
# define alphas
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
- 备注:
- betas对应 β \beta β
- alphas对应 α = 1 − β \alpha = 1 - \beta α=1−β
- alphas_cumprod对应 α ˉ \bar{\alpha} αˉ
- sqrt_recip_alphas对应 1 α \frac{1}{\sqrt{\alpha}} α1
- sqrt_alphas_cumprod对应 1 α ˉ \frac{1}{\sqrt{\bar{\alpha}}} αˉ1
- sqrt_one_minus_alphas_cumprod对应 1 − α ˉ \sqrt{1 - \bar{\alpha}} 1−αˉ
- 在训练阶段对于batch中的每个样本,加噪的迭代次数是从[0, T]中进行随机采样的,所以训练阶段每个样本的加噪次数 t ∈ [ 0 , T ] t \in [0, T] t∈[0,T] 是不同的,使用gather函数获取到每个样本的t对应的 α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数,对应的代码如下:
def extract(a, t, x_shape):
batch_size = t.shape[0]
out = a.gather(-1, t.cpu())
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
- 前向加噪:根据上一步计算得到的 α , β , α ˉ , β ˉ \alpha,\beta,\bar{\alpha},\bar{\beta} α,β,αˉ,βˉ等参数,将一张真实图像 x 0 x_{0} x0 使用公式(2)进行多次加噪,得到加噪后的图像,对应代码如下:
def q_sample(x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
# x_start就是前面讲的最原始图像 x_0,根据 t 获取到对应的alpha,beta等参数
sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
sqrt_one_minus_alphas_cumprod, t, x_start.shape
)
# 使用公式(2)对图像进行前向加噪
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
- UNet模型:将加噪后的样本以及每个样本对应的加噪次数 t 输入到UNet网络模型中,UNet模型预测输出加入的噪声,将UNet的输出结果与加入到图像中的噪声使用公式(3)计算损失,训练UNet网络模型。
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
if noise is None:
noise = torch.randn_like(x_start)
# x_start就是前面讲的最原始图像 x_0,这一步就是往 x_0 中加入t次的噪声
x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
# 将加入噪声的图像以及对应的时间步数 t 输入到UNet模型
predicted_noise = denoise_model(x_noisy, t)
# 将UNet预测的结果与加入的噪声计算损失
if loss_type == 'l1':
loss = F.l1_loss(noise, predicted_noise)
elif loss_type == 'l2':
loss = F.mse_loss(noise, predicted_noise)
elif loss_type == "huber":
loss = F.smooth_l1_loss(noise, predicted_noise)
else:
raise NotImplementedError()
return loss
- 模型推理:当训练完UNet之后,在模型推理也就是图像生成阶段执行反向去噪过程。首先生成一张纯噪声的图像,初始时间步设置为timesteps,将噪声图像和时间步数值 t 输入到UNet模型中,预测得到输出结果,然后使用公式(4)计算得到经过去噪之后 t-1时间步的输出,如此迭代,直到 t=0为止。
def p_sample(model, x, t, t_index):
betas_t = extract(betas, t, x.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
sqrt_one_minus_alphas_cumprod, t, x.shape
)
sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
# Equation 11 in the paper
# Use our model (noise predictor) to predict the mean
model_mean = sqrt_recip_alphas_t * (
x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
)
if t_index == 0:
return model_mean
else:
posterior_variance_t = extract(posterior_variance, t, x.shape)
noise = torch.randn_like(x)
# Algorithm 2 line 4:
return model_mean + torch.sqrt(posterior_variance_t) * noise
# Algorithm 2 (including returning all images)
def p_sample_loop(model, shape):
device = next(model.parameters()).device
b = shape[0]
# start from pure noise (for each example in the batch)
img = torch.randn(shape, device=device)
imgs = []
for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
imgs.append(img.cpu().numpy())
return imgs
def sample(model, image_size, batch_size=16, channels=3):
return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
注意事项:
- torch.randn生成符合标准正态分布的数据,torch.rand生成符合0-1之间均匀分布的数据
- UNet有利于细粒度的图像生成