DDPM扩散模型
一、前置知识
1. 条件概率知识
P ( A ∣ B ) = P ( A B ) P ( B ) P(A|B) = \frac{P(AB)}{P(B)} P(A∣B)=P(B)P(AB)
P ( A B C ) = P ( C ∣ B A ) P ( B A ) = P ( C ∣ B A ) P ( B ∣ A ) P ( A ) P(A B C) = P(C|B A)P(BA) = P(C|B A)P(B|A)P(A) P(ABC)=P(C∣BA)P(BA)=P(C∣BA)P(B∣A)P(A)
P ( B C ∣ A ) = P ( B ∣ A ) P ( C ∣ A , B ) P(B C|A) = P(B|A)P(C|A, B) P(BC∣A)=P(B∣A)P(C∣A,B)
P ( C ∣ A , B ) = P ( B C ∣ A ) P ( B ∣ A ) P(C|A, B) = \frac{P(BC| A)}{P(B|A)} P(C∣A,B)=P(B∣A)P(BC∣A)
2. 基于马尔科夫假设的条件概率
如果满足马尔科夫链关系 A − > B − > C A -> B -> C A−>B−>C那么有
P ( A B C ) = P ( C ∣ B A ) P ( B A ) = P ( C ∣ B ) P ( B ∣ A ) P ( A ) P(A B C) = P(C|BA)P(B A) = P(C|B)P(B|A)P(A) P(ABC)=P(C∣BA)P(BA)=P(C∣B)P(B∣A)P(A)
P ( B C ∣ A ) = P ( B ∣ A ) P ( C ∣ B ) P(B C|A) = P(B|A)P(C|B) P(BC∣A)=P(B∣A)P(C∣B)
3. 高斯分布的KL散度公式
对于两个单一变量的高斯分布 P 和 Q 而言,它们的 KL 散度为: K L ( P , Q ) = log σ 1 σ 2 + σ 1 2 + ( μ 1 − μ 2 ) 2 2 σ 2 2 − 1 2 KL(P, Q) = \log{\frac{\sigma_1}{\sigma_2}} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2 \sigma_2^2} - \frac{1}{2} KL(P,Q)=logσ2σ1+2σ22σ12+(μ1−μ2)2−21
KL 散度,又称为相对熵,描述两个概率分布P和Q的差异和相似性,用 D K L ( P ∣ ∣ Q ) D_{KL}(P||Q) DKL(P∣∣Q)表示
显然,散度越小,说明概率Q与概率P之间越接近,那么估计的概率分布与真实的概率分布也就越接近。
KL 散度的性质:
- 非对称性: D K L ≠ D K L ( Q ∣ ∣ P ) D_{KL} \neq D_{KL}(Q || P) DKL=DKL(Q∣∣P)
- D K L ( P ∣ ∣ Q ) ≥ 0 D_{KL}(P || Q) \geq 0 DKL(P∣∣Q)≥0,仅在 P = Q P = Q P=Q时等于0
4. 参数重整化
如果希望从高斯分布 N ( μ , σ 2 ) N(\mu, \sigma^2) N(μ,σ2)中采样( μ \mu μ:表示均值, σ 2 \sigma^2 σ2:表示方差),可以先从标准分布 N ( 0 , 1 ) N(0, 1) N(0,1)采样处 z z z,再得到 σ × z + μ \sigma \times z + \mu σ×z+μ ,这样做的好处是将随机性转移到了 z z z这个常量上了,而 σ \sigma σ和 μ \mu μ则当做仿射变换网络的一部分。
二、Diffusion Model流程
x 0 x_0 x0是初始数据(一个初始的图片), x T x_T xT是最终的图片(一个纯噪声的图片)。
- x 0 ∼ x T x_0 \sim x_T x0∼xT的过程是一个加噪过程,每次从 q ( x t ∣ x t − 1 ) q(x_t|x_{t - 1}) q(xt∣xt−1)分布中取噪声,然后添加到前一个时间步的图片数据中,这样经过T个时间步,我们就能得到一个纯噪声的图片了。
- x T ∼ x 0 x_T \sim x_0 xT∼x0的过程是一个去噪过程,每次从 p θ ( x t − 1 ∣ x t ) p_\theta(x_{t - 1}|x_t) pθ(xt−1∣xt)分布中取噪声,然后使前一个时间步的图片数据减去该噪声,这样经过T个时间步,我们就能得到原始的图片了。
其中 q ( x t ∣ x t − 1 ) q(x_t|x_{t - 1}) q(xt∣xt−1)是自己设定的一个加噪分布,而 p θ ( x t − 1 ∣ x t ) p_\theta(x_{t - 1}|x_t) pθ(xt−1∣xt)是需要神经网络去学习的一个分布,我们会使用参数分布来去对该分布做估计,由于使用了参数重整化的思想( σ × ϵ + μ \sigma \times \epsilon + \mu σ×ϵ+μ,其中 σ \sigma σ是分布的方差, μ \mu μ表示的是分布的均值, ϵ \epsilon ϵ是从标准正态分布中随机采样的一个值),我们加噪过程是从一个标准正态分布中随机采样一个值,然后再进行参数重整化,依据 μ \mu μ和 σ \sigma σ得到特定分布下的噪声,而去噪过程是利用神经网络学习这个噪声,然后在每个时间步上减去预测出的噪声。
三、加噪过程
给定初始数据分布 x 0 ∼ q ( x ) x_0 \sim q(x) x0∼q(x),可以不断地向分布中添加高斯噪声,该噪声的方差是以固定值 β t \beta_t βt而确定的,均值是以固定值 β t \beta_t βt和当前 t t t时刻的数据 x t x_t xt决定的。这个过程是一个马尔科夫链过程,随着 t t t的不断增大,不断的向数据中添加噪声,最终数据分布 x t x_t xt变成了一个各向独立的高斯分布。
噪声的分布可以表示如下:
q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t ⋅ x t − 1 , β t I ) q(x_t|x_{t - 1}) = N(x_t; \sqrt{1 - \beta_t}\cdot x_{t - 1}, \beta_tI) q(xt∣xt−1)=N(xt;1−βt⋅xt−1,βtI) 其中 1 − β t ⋅ x t − 1 \sqrt{1 - \beta_t}\cdot x_{t- 1} 1−βt⋅xt−1表示该分布的均值, β t I \beta_t I βtI 表示方差( I I I表示单位矩阵)。
那么有: x t = 1 − β t x t − 1 + β t z t x_t = \sqrt{1 - \beta_t}x_{t - 1} + \sqrt{\beta_t} z_t xt=1−βtxt−1+βtzt,其中 1 − β t \sqrt{1 - \beta_t} 1−βt是一个控制噪声强度的系数, β t \beta_t βt是一个添加噪声比例的系数,论文中说明,当分布越来月接近噪声分布的时候,可以将 β t \beta_t βt变得大一点,这样做可以再初始的时候 β t \beta_t βt很小,那么添加的噪声也就很小,而 1 − β t \sqrt{1 - \beta_t} 1−βt会很大,那么将会保留更多原来数据的特征,再最后的时候 β t \beta_t βt很大,那么添加的噪声也会更大,而 1 − β t \sqrt{1 - \beta_t} 1−βt也就会更大,那么将会去除掉更多原来数据的特征。
设定噪声的方差 β t ∈ ( 0 , 1 ) \beta_t \in (0, 1) βt∈(0,1),并且 β t \beta_t βt 随着 t t t 的增大而增大。
通过上述的分布,我们可以将原始数据图片 x 0 x_0 x0通过 q ( x 1 ∣ x 0 ) q(x_1| x_0) q(x1∣x0)分布进行加噪,从而得到 x 1 x_1 x1,然后再通过 q ( x 2 ∣ x 1 ) q(x_2| x_1) q(x2∣x1)分布进行对 x 1 x_1 x1加噪,从而得到 x 2 x_2 x2,并以此类推,我们可以得到最终纯噪声的高斯分布。
上述过程时比较麻烦的,因为我们需要得到 t − 1 t - 1 t−1时刻的数据分布才能得到 t t t 时刻加噪后的数据分布,其实任意时刻的 q ( x t ) q(x_t) q(xt)数据分布可以直接基于 x 0 x_0 x0和 β t \beta_t βt而计算出来,而不需要一步一步的迭代,其推导过程如下:
正态分布叠加性质:正态分布 X ∼ N ( μ 1 , σ 1 2 ) X \sim N(\mu_1, \sigma_1^2) X∼N(μ1,σ12)和 Y ∼ N ( μ 2 , σ 2 ) Y \sim N(\mu_2, \sigma_2) Y∼N(μ2,σ2)线性叠加后的分布为 Z = a X + b Y Z = aX + bY Z=aX+bY,则叠加后分布的均值为 a μ 1 + b μ 2 a\mu_1 + b\mu_2 aμ1+bμ2,方差为 a 2 σ 1 2 + b 2 σ 2 2 a^2\sigma_1^2 + b^2\sigma_2^2 a2σ12+b2σ22,即 Z ∼ N ( a μ 1 + b μ 2 , a 2 σ 1 2 + b 2 σ 2 2 ) Z \sim N(a\mu_1 + b\mu_2, a^2\sigma_1^2 + b^2\sigma_2^2) Z∼N(aμ1+bμ2,a2σ12+b2σ22)。
推导公式:
有正态分布的叠加性质可知: α t − α t α t − 1 ⋅ z t − 2 + 1 − α t ⋅ z t − 1 \sqrt{\alpha_t - \alpha_t\alpha_{t - 1}} \cdot z_{t - 2} + \sqrt{1 - \alpha_t}\cdot z_{t - 1} αt−αtαt−1⋅zt−2+1−αt⋅zt−1可以参数重整化成只含一个随机变量 z z z构成的 1 − α t ⋅ z t − 1 \sqrt{1 - \alpha_t}\cdot z_{t - 1} 1−αt⋅zt−1可以参数重整化为只含一个随机变量 z z z构成的 1 − α t α t − 1 ⋅ z \sqrt{1 - \alpha_t\alpha_{t - 1}}\cdot z 1−αtαt−1⋅z 的形式,以此类推可以化简为一个最终的结果。
上述公式中的 z t − 1 、 z t − 2 、 ⋯ z_{t -1}、z_{t - 2}、\cdots zt−1、zt−2、⋯都是一个从正态分布中随其采样的数据。
最终可以得到: q ( x t ∣ x 0 ) = N ( x t ; α t ˉ ⋅ x 0 , ( 1 − α t ˉ ) I ) q(x_t | x_0) = N(x_t; \sqrt{\bar{\alpha_{t}}}\cdot x_0, (1 - \bar{\alpha_{t}}) I) q(xt∣x0)=N(xt;αtˉ⋅x0,(1−αtˉ)I) , 此时我们只需要知道初始的数据分布即可直接计算处任意时刻加噪后的数据分布,而不需要一个一个迭代求得。
四、去噪过程
去噪过程是加噪过程的逆过程,是从高斯噪声中恢复原始数据的过程,我们可以假设去噪的噪声也是取自一个高斯分布,我们无法逐步地去直接拟合分布,因此需要构建一个参数分布来去做估计,逆扩散过程仍然是一个马尔科夫链过程。
从 x T x_T xT(纯噪声数据)恢复到初始图片数据 x 0 x_0 x0的公式: p θ ( x 0 ⋯ T ) = p ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) p_\theta(x_{0\cdots T}) = p(x_T)\prod\limits_{t = 1}^{T}p_\theta(x_{t - 1}| x_t) pθ(x0⋯T)=p(xT)t=1∏Tpθ(xt−1∣xt)
其中 p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , σ θ 2 ( x t , t ) ) p_\theta(x_{t - 1}| x_t) = N(x_{t - 1};\mu_\theta(x_t, t), \sigma_\theta^2(x_t, t)) pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),σθ2(xt,t)),里面有两个未知的参数,分别是 t t t时刻的 μ θ \mu_\theta μθ和 σ θ \sigma_\theta σθ,这两个参数就是需要神经网络需要拟合的参数。
我们无法直接知道 q ( x t − 1 ∣ x t ) q(x_{t -1}|x_t) q(xt−1∣xt),但是 q ( x t − 1 ∣ x t , x 0 ) q(x_{t - 1}| x_t, x_0) q(xt−1∣xt,x0)分布是可以用 q ( x t ∣ x 0 ) q(x_{t}|x_0) q(xt∣x0)和 p ( x t ∣ x t − 1 ) p(x_t|x_{t - 1}) p(xt∣xt−1)进行表示,也就是说知道了 x t x_t xt和 x 0 x_0 x0,我们是可以计算出 x t − 1 x_{t - 1} xt−1
知识回顾
高斯分布的概率密度函数
f ( x ) = 1 2 π ⋅ σ ⋅ e − ( x − μ ) 2 2 σ 2 f(x) = \frac{1}{\sqrt{2 \pi}\cdot \sigma}\cdot e^{- \frac{(x - \mu)^2}{2 \sigma^2}} f(x)=2π⋅σ1⋅e−2σ2(x−μ)2
其中一个重要的转换: exp ( − ( x − μ ) 2 2 σ 2 ) \exp{(-\frac{(x - \mu)^2}{2\sigma_2})} exp(−2σ2(x−μ)2) = exp ( − 1 2 ( 1 σ 2 x 2 ) − 2 μ σ 2 x + μ 2 σ 2 ) \exp{(-\frac{1}{2}(\frac{1}{\sigma^2}x^2) - \frac{2\mu}{\sigma^2}x + \frac{\mu^2}{\sigma^2})} exp(−21(σ21x2)−σ22μx+σ2μ2)二次函数的转换:
a x 2 + b x = a ( x + b 2 a ) 2 + C ax^2 + bx = a(x + \frac{b}{2a})^2 + C ax2+bx=a(x+2ab)2+C
转换后的数据最后 + C +C +C表示数据转换后的一些常数项,其中 b 2 a \frac{b}{2a} 2ab是二次函数的对称轴部分,高斯分布中为均值部分。
我们假设
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t - 1}| x_t, x_0)
q(xt−1∣xt,x0)也是一个高斯分布,并且其分布如下:均值是一个与
x
t
x_t
xt和
x
0
x_0
x0相关的数据,方差是一个与
β
t
\beta_t
βt相关的数据。
q
(
x
t
−
1
∣
x
t
,
x
0
)
=
N
(
x
t
−
1
;
μ
~
(
x
t
,
x
0
)
,
β
t
~
I
)
q(x_{t - 1}| x_t, x_0) = N(x_{t - 1};\widetilde{\mu}(x_t, x_0), \widetilde{\beta_t}I)
q(xt−1∣xt,x0)=N(xt−1;μ
(xt,x0),βt
I)
我们可以进行如下推导:
上述公式的具体推导过程如下:
我们已知:
将已知条件带入到公式中可得:
根据高斯分布重要的转换公式,我们可以得到方差: β ~ t = 1 ( α t β t + 1 1 − α t − 1 ˉ ) = 1 − α t − 1 ˉ 1 − α t ˉ ⋅ β t \widetilde\beta_t = \frac{1}{(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha_{t-1}}}) } = \frac{1 - \bar{\alpha_{t - 1}}}{1 - \bar{\alpha_{t}}} \cdot \beta_t β t=(βtαt+1−αt−1ˉ1)1=1−αtˉ1−αt−1ˉ⋅βt
根据二次函数的转换,我们可以得到均值: μ ~ ( x t , x 0 ) = ( α t β t x t + α t ˉ 1 − α ˉ t x 0 ) / ( α t β t + 1 1 − α t − 1 ˉ ) = α t ( 1 − α t − 1 ˉ ) 1 − α t ˉ x t + α t − 1 ˉ β t 1 − α t ˉ x 0 \widetilde{\mu}(x_t, x_0) = (\frac{\sqrt{\alpha_t}}{\beta_t}x_t + \frac{\sqrt{\bar{\alpha_t}}}{1 - \bar\alpha_t}x_0)/ (\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha_{t - 1}}}) = \frac{\sqrt{\alpha_t(1 - \bar{\alpha_{t - 1}})}}{1 - \bar{\alpha_t}}x_t + \frac{\sqrt{\bar{\alpha_{t - 1}}} \beta_t}{1 - \bar{\alpha_t}}x_0 μ (xt,x0)=(βtαtxt+1−αˉtαtˉx0)/(βtαt+1−αt−1ˉ1)=1−αtˉαt(1−αt−1ˉ)xt+1−αtˉαt−1ˉβtx0
此时可以将
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t - 1}|x_t, x_0)
q(xt−1∣xt,x0)的分布写为:
q
(
x
t
−
1
∣
x
t
,
x
0
)
∼
N
(
x
t
−
1
;
(
1
−
α
t
−
1
ˉ
)
α
t
1
−
α
t
ˉ
x
t
+
β
t
α
t
−
1
ˉ
1
−
α
t
ˉ
x
0
,
1
−
α
t
−
1
ˉ
1
−
α
t
ˉ
β
t
)
q(x_{t - 1}|x_t, x_0) \sim N(x_{t - 1};\frac{(1 - \bar{\alpha_{t - 1}})\sqrt{\alpha_t}}{1 - \bar{\alpha_t}}x_t + \frac{\beta_t \sqrt{\bar{\alpha_{t - 1}}}}{1 - \bar{\alpha_t}}x_0, \frac{1 - \bar{\alpha_{t - 1}}}{1 - \bar{\alpha_t}}\beta_t)
q(xt−1∣xt,x0)∼N(xt−1;1−αtˉ(1−αt−1ˉ)αtxt+1−αtˉβtαt−1ˉx0,1−αtˉ1−αt−1ˉβt)
可以看到,
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t - 1}|x_t, x_0)
q(xt−1∣xt,x0)的分布方差是知道的,我们只需要求得均值即可求出该参数分布,我们继续求解均值。
根据前面的
x
0
x_0
x0与
x
t
x_t
xt之间的关系(
x
t
=
α
t
ˉ
⋅
x
0
+
1
−
α
t
ˉ
⋅
z
t
x_t = \sqrt{\bar{\alpha_t}}\cdot x_0 + \sqrt{1 - \bar{\alpha_t}}\cdot z_t
xt=αtˉ⋅x0+1−αtˉ⋅zt),我们可以知道:
x
0
=
1
α
t
ˉ
(
x
t
−
1
−
α
t
ˉ
⋅
z
t
)
x_0 = \frac{1}{\sqrt{\bar{\alpha_{t}}}}(x_t - \sqrt{1 - \bar{\alpha_t}}\cdot z_t)
x0=αtˉ1(xt−1−αtˉ⋅zt)
将
x
0
x_0
x0的表达式代入到
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t - 1} | x_t, x_0)
q(xt−1∣xt,x0)的分布式中,可以重新给出该分布的均值表达式,也就是说,在给定
x
0
x_0
x0的条件下,后验条件高斯分布的均值计算只与
x
t
x_t
xt和
z
t
z_t
zt有关。
z
t
z_t
zt是从第
t
t
t个时间步的正态分布中采样出来的样本。
μ
t
~
=
α
t
(
1
−
α
t
−
1
ˉ
)
1
−
α
t
ˉ
x
t
+
α
t
−
1
ˉ
β
t
1
−
α
t
ˉ
1
α
t
ˉ
(
x
t
−
1
−
α
t
ˉ
z
t
)
=
1
α
t
(
x
t
−
β
t
1
−
α
t
ˉ
z
t
)
\widetilde{\mu_t} = \frac{\sqrt{\alpha_t}(1-\bar{\alpha_{t - 1}})}{1 - \bar{\alpha_t}}x_t + \frac{\sqrt{\bar{\alpha_{t - 1}}}\beta_t}{1 - \bar{\alpha_t}} \frac{1}{\sqrt{\bar{\alpha_t}}}(x_t - \sqrt{1 - \bar{\alpha_t}}z_t) = \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha_t}}}z_t)
μt
=1−αtˉαt(1−αt−1ˉ)xt+1−αtˉαt−1ˉβtαtˉ1(xt−1−αtˉzt)=αt1(xt−1−αtˉβtzt)
得到最终的
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t - 1}|x_t, x_0)
q(xt−1∣xt,x0)分布为:
q
(
x
t
−
1
∣
x
t
,
x
0
)
=
N
(
x
t
−
1
;
1
α
t
(
x
t
−
β
t
1
−
α
t
ˉ
z
t
)
,
1
−
α
t
−
1
ˉ
1
−
α
t
ˉ
β
t
)
q(x_{t - 1}|x_t, x_0) = N(x_{t - 1}; \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha_t}}}z_t), \frac{1 - \bar{\alpha_{t - 1}}}{1 - \bar{\alpha_t}}\beta_t)
q(xt−1∣xt,x0)=N(xt−1;αt1(xt−1−αtˉβtzt),1−αtˉ1−αt−1ˉβt)
五、损失函数
我们可以在负对数似然函数的基础上加一个KL散度,于是就构成了负对数似然的上界了,上界越小,负对数似然自然也就越小,那么对数似然就越大了。
等式两边都加上 E q ( x 0 ) E_{q(x_0)} Eq(x0)可以得到得到: E q ( x 0 : T ) [ log q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] ≥ − E q ( x 0 ) log p θ ( x 0 ) E_{q(x_{0:T})}[\log{\frac{q(x_{1:T} | x_0)}{p_\theta(x_{0:T})}}] \geq -E_{q(x_0)}\log{p_\theta(x_0)} Eq(x0:T)[logpθ(x0:T)q(x1:T∣x0)]≥−Eq(x0)logpθ(x0)
我们令 L V L B = E q ( x 0 : T ) [ log q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] ≥ − E q ( x 0 ) log p θ ( x 0 ) L_{VLB} = E_{q(x_{0:T})}[\log{\frac{q(x_{1:T} | x_0)}{p_\theta(x_{0:T})}}] \geq -E_{q(x_0)}\log{p_\theta(x_0)} LVLB=Eq(x0:T)[logpθ(x0:T)q(x1:T∣x0)]≥−Eq(x0)logpθ(x0)
现在我们只需要简化交叉熵上界即可,对 L V B L L_{VBL} LVBL进行化简:
已知:
最终化简为: L V L B = E q [ D K L ( q ( x T ∣ x 0 ) ∣ ∣ p θ ( x T ) ) ⏟ L T + ∑ t = 2 T D K L ( q ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p θ ( x t − 1 ∣ x t ) ⏟ L t − 1 − log p θ ( x 0 ∣ x 1 ) ⏟ L 0 ] L_{VLB}= \underbrace{E_q[D_{KL}(q(x_T|x_0) || p_\theta(x_T))}_{L_T}+ \sum\limits_{t = 2}^T \underbrace{D_{KL}(q(x_{t-1}|x_t, x_0)|| p_\theta(x_{t - 1}| x_t)}_{L_{t - 1}} -\underbrace{\log{p_\theta(x_0|x_1)}}_{L_0}] LVLB=LT Eq[DKL(q(xT∣x0)∣∣pθ(xT))+t=2∑TLt−1 DKL(q(xt−1∣xt,x0)∣∣pθ(xt−1∣xt)−L0 logpθ(x0∣x1)]
L T L_T LT部分: q ( x T ∣ x 0 ) q(x_T|x_0) q(xT∣x0)是不含参的,可以由 β t \beta_t βt计算出来,是一个完全高斯分布。
L 0 L_0 L0部分:这一部分是 log p θ ( x 0 ∣ x 1 ) \log{p_\theta(x_0|x_1)} logpθ(x0∣x1),也就是说从 x 1 x_1 x1分布推理出 x 0 x_0 x0分布的一个分布,现在我们要使 − log p θ ( x 0 ∣ x 1 ) -\log{p_\theta(x_0|x_1)} −logpθ(x0∣x1)尽可能小,也就是使 log p θ ( x 0 ∣ x 1 ) \log{p_\theta(x_0|x_1)} logpθ(x0∣x1)的值越大,说明要从 x 1 x_1 x1推理出 x 0 x_0 x0的概率更大,也就是 x 0 x_0 x0的分布要与 x 1 x_1 x1的分布越相似,这一步我们在设定参数时就已经考虑了这一情况,即( β t \beta_t βt随时间 t t t而增大)。
L t − 1 L_{t -1} Lt−1部分:是涉及参数的主要部分。
可以知道参数主要存在于 L t − 1 L_{t - 1} Lt−1中,这里论文将 p θ ( x t − 1 ∣ x t ) p_\theta(x_{t - 1} | x_t) pθ(xt−1∣xt)分布的方差设置成了一个与 β \beta β相关的常数,因此可训练的参数只存在于其均值中,对于两个单一变量的高斯分布 p p p和 q q q而言,它们的 KL 散度为: K L ( p , q ) = log σ 1 σ 2 + σ 2 + ( μ 1 − μ 2 ) 2 2 σ 2 2 − 1 2 KL(p,q) = \log{\frac{\sigma_1}{\sigma_2}} + \frac{\sigma^2 + (\mu_1 - \mu_2)^2}{2 \sigma_2^2} - \frac{1}{2} KL(p,q)=logσ2σ1+2σ22σ2+(μ1−μ2)2−21,我们将 L t − 1 L_{t - 1} Lt−1展开,并且只取 ( μ 1 − μ 2 ) 2 (\mu_1 - \mu_2)^2 (μ1−μ2)2部分,其它的部分都用一个常数 C C C来表示。
我们将 μ θ \mu_\theta μθ同样写为 μ ~ \widetilde\mu μ 的形式,这样参数从均值 μ \mu μ转移到了变量 ϵ \epsilon ϵ上。
将上述 μ θ ( x t , t ) \mu_\theta(x_t, t) μθ(xt,t)带入到 L t − 1 − C L_{t - 1} - C Lt−1−C中可得:
E
x
0
,
ϵ
[
β
t
2
2
σ
t
2
α
t
(
1
−
α
ˉ
t
)
∣
∣
ϵ
−
ϵ
θ
(
α
t
ˉ
x
0
+
1
−
α
t
ˉ
ϵ
,
t
)
∣
∣
2
]
E_{x_0, \epsilon}[\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1 - \bar\alpha_t)}|| \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha_t}}x_0 + \sqrt{1 - \bar{\alpha_t}}\epsilon, t)||^2]
Ex0,ϵ[2σt2αt(1−αˉt)βt2∣∣ϵ−ϵθ(αtˉx0+1−αtˉϵ,t)∣∣2]
在论文中,作者声明可以将系数部分完全丢掉,这样训练会更加稳定,质量会更好,因此最终的损失函数可以写为:
L
s
i
m
p
l
e
(
θ
)
=
E
t
,
x
0
,
ϵ
[
∣
∣
ϵ
−
ϵ
θ
(
α
t
ˉ
x
0
+
1
−
α
t
ˉ
ϵ
,
t
)
∣
∣
2
]
L_{simple}(\theta) = E_{t, x_0, \epsilon}[|| \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha_t}}x_0 + \sqrt{1 - \bar{\alpha_t}}\epsilon, t)||^2]
Lsimple(θ)=Et,x0,ϵ[∣∣ϵ−ϵθ(αtˉx0+1−αtˉϵ,t)∣∣2]
这样我们只需要将加噪过程中的
ϵ
\epsilon
ϵ与预测过程中的
ϵ
θ
\epsilon_\theta
ϵθ的误差不断减小即可。
六、总结
算法训练与采样流程:
在训练过程中,我们要让模型去学习加噪过程中的每次从正态分布取得的随机噪声 ϵ \epsilon ϵ,我们通过这个噪声可以推理出每一个时刻 t t t的数据分布的均值。
在推理过程中,我们通过模型输出 t t t时刻的随机噪声 ϵ θ \epsilon_\theta ϵθ来计算出 t t t时刻数据分布的均值,然后再通过该均值与方差来进行参数重整化,得到 t − 1 t - 1 t−1时刻的数据分布。
七、附录
DDPM模型结构图
其中resblockattn
模块的具体结构如下:
该模块中Attention
部分不是必要的。