Prerequisite Mathematics
1. Prior and posterior probability
Prior probability: a probability obtained from past experience and analysis; it usually plays the role of the "cause" in a "from cause to effect" problem, e.g. $q(x_t|x_{t-1})$.
Posterior probability: a probability revised after the "effect" has been observed; it is the "cause" in a "from effect back to cause" problem, e.g. $p(x_{t-1}|x_t)$.
2. Conditional probability: let $A$ and $B$ be any two events with $P(A)>0$. The probability that $B$ occurs given that $A$ has occurred is called the conditional probability, written $P(B|A)$:
$$P(B|A)=\frac{P(A,B)}{P(A)}$$
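As a quick worked example (two fair dice, an illustration added here rather than taken from the original text): let $A$ be "the first die shows 6" and $B$ be "the sum of the two dice is at least 10". Then
$$P(B|A)=\frac{P(A,B)}{P(A)}=\frac{3/36}{6/36}=\frac{1}{2}$$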
3. Multiplication rule:
$$P(A,B)=P(B|A)P(A)$$
4. General form of the multiplication rule:
$$P(A,B,C)=P(C|B,A)P(B,A)=P(C|B,A)P(B|A)P(A)$$
5. Bayes' formula:
$$P(A|B)=\frac{P(B|A)P(A)}{P(B)}$$
6. Multivariate Bayes' formula:
$$P(A|B,C)=\frac{P(A,B,C)}{P(B,C)}=\frac{P(B|A,C)P(A,C)}{P(B,C)}=\frac{P(B|A,C)P(A|C)P(C)}{P(B|C)P(C)}=\frac{P(B|A,C)P(A|C)}{P(B|C)}$$
7. Additivity of Gaussian distributions: given two independent Gaussian random variables $N_1$ and $N_2$ with means $\mu_1$, $\mu_2$ and variances $\sigma_1^2$, $\sigma_2^2$, the weighted sum $N=aN_1+bN_2$ has the following mean and variance:
$$E(N)=E(aN_1+bN_2)=a\mu_1+b\mu_2$$
$$Var(N)=Var(aN_1+bN_2)=a^2\sigma_1^2+b^2\sigma_2^2$$
For the difference:
$$E(N)=E(aN_1-bN_2)=a\mu_1-b\mu_2$$
$$Var(N)=Var(aN_1-bN_2)=a^2\sigma_1^2+b^2\sigma_2^2$$
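These identities are easy to sanity-check numerically. Below is a small Monte Carlo sketch using torch (the values of a, b, the means and the standard deviations are illustrative, not from the original):

```python
import torch

# Two independent Gaussians: N1 ~ N(1, 2^2), N2 ~ N(-3, 0.5^2), combined as N = a*N1 + b*N2
a, b = 0.7, 0.3
n1 = 1.0 + 2.0 * torch.randn(1_000_000)
n2 = -3.0 + 0.5 * torch.randn(1_000_000)
n = a * n1 + b * n2

print(n.mean())  # ~ a*mu1 + b*mu2 = 0.7*1 + 0.3*(-3) = -0.2
print(n.var())   # ~ a^2*sigma1^2 + b^2*sigma2^2 = 0.49*4 + 0.09*0.25 = 1.9825
```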
8. Reparameterization: sampling from $N(\mu,\sigma^2)$ is equivalent to sampling $\epsilon$ from $N(0,1)$ and computing $\epsilon\cdot\sigma+\mu$.
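In code the reparameterization trick is a one-liner; a minimal torch sketch (the values of mu and sigma are illustrative):

```python
import torch

mu, sigma = 2.0, 0.5       # target distribution N(mu, sigma^2)
eps = torch.randn(10_000)  # epsilon ~ N(0, 1)
x = eps * sigma + mu       # equivalent to sampling x ~ N(mu, sigma^2)
print(x.mean(), x.std())   # ~ 2.0 and 0.5
```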
9. Probability density function of the Gaussian distribution
$$f(x)=\frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}}$$
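This density can be checked against torch.distributions.Normal (an illustrative sketch; the values of mu, sigma and x are arbitrary):

```python
import math
import torch
from torch.distributions import Normal

mu, sigma, x = 0.3, 1.2, torch.tensor(0.9)
# density evaluated by hand from the formula above
pdf = 1 / (math.sqrt(2 * math.pi) * sigma) * torch.exp(-(x - mu) ** 2 / (2 * sigma ** 2))
print(pdf)
# same value from torch.distributions
print(Normal(mu, sigma).log_prob(x).exp())
```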
10. KL divergence between two Gaussians, with $p=N(\mu_1,\sigma_1^2)$ and $q=N(\mu_2,\sigma_2^2)$:
$$KL(p\|q)=\log\frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2}$$
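A quick check of the closed form against torch.distributions.kl_divergence (illustrative parameter values):

```python
import torch
from torch.distributions import Normal, kl_divergence

mu1, sigma1 = 0.0, 1.0
mu2, sigma2 = 1.0, 2.0
closed_form = (torch.log(torch.tensor(sigma2 / sigma1))
               + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)
               - 0.5)
print(closed_form)                                              # ~ 0.4431
print(kl_divergence(Normal(mu1, sigma1), Normal(mu2, sigma2)))  # same value
```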
11. Completing the square of a quadratic
$$ax^2+bx=a\left(x+\frac{b}{2a}\right)^2+c,\qquad c=-\frac{b^2}{4a}$$
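For example, taking $a=2$, $b=3$ (so $c=-\frac{9}{8}$):
$$2x^2+3x=2\left(x+\frac{3}{4}\right)^2-\frac{9}{8}$$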
12. Expectation of a function of a random variable
Let $X$ be a random variable and $Y=g(X)$; then:
$$E(Y)=E[g(X)]= \begin{cases} \displaystyle\sum_{k=1}^\infty g(x_k)p_k & \text{(discrete case)}\\ \displaystyle\int_{-\infty}^{\infty}g(x)p(x)\,dx & \text{(continuous case)} \end{cases}$$
13. KL divergence (general definition)
$$KL(p(x)\|q(x))=E_{x \sim p(x)}\left[\log\frac{p(x)}{q(x)}\right]=\int p(x)\,\log\frac{p(x)}{q(x)}\,dx$$
Introduction to DDPM
In 2020, researchers at Berkeley proposed DDPM (Denoising Diffusion Probabilistic Models), commonly referred to simply as the diffusion model. It is a core algorithm of AIGC: it surpasses GANs in both the realism and the diversity of generated images, and its training is stable. Its drawback is the high computational cost, which makes real-time inference difficult, although follow-up techniques reduce the cost along both the temporal and spatial dimensions.
A diffusion model consists of two processes: a forward diffusion (noising) process and a reverse denoising process.
Both the forward and the reverse process are Markov chains, and the full chain typically takes on the order of 1000 steps. The reverse process is what actually generates data; its derivation is described below.
Forward diffusion process
The forward diffusion process gradually adds Gaussian noise to the original data until it becomes (approximately) a standard Gaussian distribution.
Sample $x_0\sim q(x_0)$ from the original dataset, then add random noise according to a predefined noise schedule, producing a sequence of noisy images $x_1,x_2,\dots,x_T$. In terms of probabilities:
$$\begin{aligned} q(x_{1:T}|x_{0})&=\prod_{t=1}^{T}q(x_t|x_{t-1}) \\ q(x_{t}|x_{t-1})&=\mathcal{N}(x_t;\sqrt{\alpha_t}\,x_{t-1},\beta_{t}I) \end{aligned}$$
Applying reparameterization (prerequisite 8) gives
$$\begin{aligned} x_{t}&=\sqrt{\alpha_{t}}\,x_{t-1}+\sqrt{\beta_{t}}\,\epsilon_{t}, \qquad \epsilon_{t}\sim \mathcal{N}(0,I) \\ \alpha_{t}&=1-\beta_{t} \end{aligned}$$
Iterating this formula:
$$\begin{aligned} x_{t}&=\sqrt{\alpha_{t}}\, x_{t-1}+\sqrt{\beta_{t}}\,\epsilon_{t}\\ &=\sqrt{\alpha_{t}}(\sqrt{\alpha_{t-1}}\,x_{t-2}+\sqrt{\beta_{t-1}}\,\epsilon_{t-1})+\sqrt{\beta_{t}}\,\epsilon_{t}\\ &=\sqrt{(\alpha_{t}\dots\alpha_{1})}\,x_{0}+\sqrt{(\alpha_{t}\dots\alpha_{2})\beta_{1}}\,\epsilon_{1}+\sqrt{(\alpha_{t}\dots\alpha_{3})\beta_{2}}\,\epsilon_{2}+\dots+\sqrt{\alpha_{t}\beta_{t-1}}\,\epsilon_{t-1}+\sqrt{\beta_{t}}\,\epsilon_{t} \end{aligned}$$
Define $\bar{\alpha}_{t}=\alpha_{1}\alpha_{2}\dots\alpha_{t}$.
By the additivity of Gaussian distributions, this yields
$$x_{t}=\sqrt{\bar{\alpha}_{t}}\,x_{0}+\sqrt{1-\bar{\alpha}_{t}}\,\epsilon, \qquad \epsilon\sim \mathcal{N}(0,I)$$
$$\textcolor{red}{q(x_{t}|x_{0})=\mathcal{N}(x_{t};\sqrt{\bar{\alpha}_{t}}\,x_{0},(1-\bar{\alpha}_{t})I)}$$
This formula says that the noisy image $x_t$ at any step $t$ can be obtained directly by adding noise to $x_0$; it will be needed later.
Note: in the code implementation, the forward process is computed in this single step!
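A minimal sketch of this one-shot noising, written in the same style as forward_diffusion_sample in the Code section below (the schedule values, image shape and timestep are illustrative):

```python
import torch

T = 300
betas = torch.linspace(1e-4, 0.02, T)               # linear noise schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # alpha_bar_t

x0 = torch.randn(1, 1, 32, 32)   # a (fake) clean image scaled to [-1, 1]
t = torch.tensor([200])          # an arbitrary timestep
eps = torch.randn_like(x0)

# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, in one step
x_t = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1) * x0 \
      + (1.0 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1) * eps
```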
Reverse denoising process (the part fitted by a neural network)
The reverse denoising process is the data-generation process: it starts by sampling a noise sample from the standard Gaussian distribution, then denoises it step by step, finally obtaining a sample from the data distribution.
If the true conditional distribution $q(x_{t-1}|x_t)$ of every reverse step were known, then starting from random noise and sampling step by step would produce a realistic sample. However, writing the true conditional distribution with Bayes' formula,
$$q(x_{t-1}|x_{t}) =\frac{q(x_{t}|x_{t-1})\,q(x_{t-1})}{q(x_{t})}$$
it cannot be solved directly, because $q(x_{t-1})$ and $q(x_{t})$ are unknown; hence we cannot go from $x_{t}$ back to $x_{t-1}$ and must instead approximate it with **a neural network $p_\theta(x_{t-1}|x_t)$. For simplicity, the reverse process is also defined as a Markov chain with Gaussian transitions**, modeled as:
$$p_\theta(x_{0:T})=p(x_T)\prod_{t=1}^{T}p_\theta(x_{t-1}|x_t)$$
$$p_\theta(x_{t-1}|x_t)=\mathcal{N}(x_{t-1};\mu_\theta(x_t,t),\Sigma_\theta(x_t,t))$$
-------------------- The following part digresses somewhat from the above; it lays the groundwork for the loss function --------------------
Although the true conditional distribution $q(x_{t-1}|x_t)$ cannot be solved directly, the posterior additionally conditioned on $x_0$, $q(x_{t-1}|x_{t},x_{0})$, can be obtained via Bayes' formula; combining it with the Markov property of the forward process gives:
$$q(x_{t-1}|x_{t},x_{0}) =\frac{q(x_{t}|x_{t-1},x_{0})\,q(x_{t-1}|x_{0})}{q(x_{t}|x_{0})}=\frac{q(x_{t}|x_{t-1})\,q(x_{t-1}|x_{0})}{q(x_{t}|x_{0})}$$
The forward process therefore gives:
$$\begin{aligned} q(x_{t-1}|x_{0})&=\sqrt{\bar{\alpha}_{t-1}}\,x_{0}+\sqrt{1-\bar{\alpha}_{t-1}}\,\epsilon\sim \mathcal{N}(\sqrt{\bar{\alpha}_{t-1}}\,x_{0},(1-\bar{\alpha}_{t-1})I)\\ q(x_{t}|x_{0})&=\sqrt{\bar{\alpha}_{t}}\,x_{0}+\sqrt{1-\bar{\alpha}_{t}}\,\epsilon\sim \mathcal{N}(\sqrt{\bar{\alpha}_{t}}\,x_{0},(1-\bar{\alpha}_{t})I)\\ q(x_{t}|x_{t-1})&=\sqrt{\alpha_{t}}\,x_{t-1}+\sqrt{\beta_{t}}\,\epsilon\sim \mathcal{N}(\sqrt{\alpha_{t}}\,x_{t-1},\beta_{t}I) \end{aligned}$$
Hence
$$\begin{aligned} q(x_{t-1}|x_{t},x_{0}) &\propto \exp\left(-\frac{1}{2}\left(\frac{(x_{t}-\sqrt{\alpha_{t}}\,x_{t-1})^2}{\beta_{t}}+\frac{(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\,x_{0})^2}{1-\bar{\alpha}_{t-1}}-\frac{(x_{t}-\sqrt{\bar{\alpha}_{t}}\,x_{0})^2}{1-\bar{\alpha}_{t}}\right)\right)\\ &=\exp\left(-\frac{1}{2}\left(\left(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)x_{t-1}^2-\left(\frac{2\sqrt{\alpha_{t}}}{\beta_{t}}x_{t}+\frac{2\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}x_{0}\right)x_{t-1}+C(x_{t},x_{0})\right)\right) \end{aligned}$$
Completing the square (prerequisite 11) then gives
$$\widetilde{\beta}_t=1\Big/\left(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}}\,\beta_{t}$$
$$\widetilde{\mu}_t=\left(\frac{\sqrt{\alpha_{t}}}{\beta_{t}}x_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}x_{0}\right)\Big/\left(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)=\frac{\sqrt{\alpha_{t}}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}x_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_{t}}{1-\bar{\alpha}_{t}}x_{0}$$
Furthermore, since
$$x_0= \frac{1}{\sqrt{\bar\alpha_t}}\left(x_t- \sqrt{1-\bar \alpha_t}\,\epsilon\right)$$
we obtain
$$\widetilde{\mu}_t=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon\right)$$
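These are exactly the quantities that the Code section precomputes; a minimal sketch under an illustrative linear schedule:

```python
import torch
import torch.nn.functional as F

T = 300
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)  # alpha_bar_{t-1}

# beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

# coefficients of x_t and x_0 in mu_tilde_t
coef_xt = alphas.sqrt() * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
coef_x0 = alphas_cumprod_prev.sqrt() * betas / (1.0 - alphas_cumprod)
```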
----------------------------------------------------------------------------------------------
Sampling process (inference after the model has been trained)
$$\mu_\theta(x_t,t)=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t,t)\right)$$
$$x_{t-1}\sim p_\theta(x_{t-1}|x_t)$$
$$x_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t,t)\right)+\sqrt{\widetilde{\beta}_t}\,z, \qquad z\sim \mathcal{N}(0,I)$$
Here $z$ is used to distinguish it from the earlier $\epsilon$.
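Written as code, one reverse step can be sketched as follows (it mirrors sample_timestep in the Code section below; model is assumed to be a trained noise predictor $\epsilon_\theta(x_t,t)$ that accepts a batch of images and a batch of timesteps):

```python
import torch

@torch.no_grad()
def reverse_step(model, x_t, t, betas, alphas, alphas_cumprod, posterior_variance):
    """One reverse step x_t -> x_{t-1} for an integer timestep t."""
    t_batch = torch.full((x_t.shape[0],), t, dtype=torch.long, device=x_t.device)
    # mu_theta(x_t, t) = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) / sqrt(alpha_t)
    mean = (x_t - betas[t] / (1.0 - alphas_cumprod[t]).sqrt() * model(x_t, t_batch)) / alphas[t].sqrt()
    if t == 0:
        return mean                      # no extra noise at the final step
    z = torch.randn_like(x_t)            # z ~ N(0, I)
    return mean + posterior_variance[t].sqrt() * z
```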
Loss function
See https://blog.csdn.net/weixin_45453121/article/details/131223653 for the full derivation; in practice the training objective reduces to matching the true noise $\epsilon$ with the predicted noise $\epsilon_\theta(x_t,t)$ (see get_loss in the code below).
Code
```python
import torch
import torchvision
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np
from torch.optim import Adam
from torch import nn
import math
from torchvision.utils import save_image


def show_images(data, num_samples=20, cols=4):
    """ Plots some samples from the dataset """
    plt.figure(figsize=(15, 15))
    for i, img in enumerate(data):
        if i == num_samples:
            break
        plt.subplot(int(num_samples / cols) + 1, cols, i + 1)
        plt.imshow(img[0])


def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)


def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    # print("out:", out)
    # print("out.shape:", out.shape)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
        + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)


def load_transformed_dataset(IMG_SIZE):
    data_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),                    # Scales data into [0,1]
        transforms.Lambda(lambda t: (t * 2) - 1)  # Scale between [-1, 1]
    ]
    data_transform = transforms.Compose(data_transforms)
    # download if not already present locally
    train = torchvision.datasets.MNIST(root="./Data", transform=data_transform, train=True, download=True)
    test = torchvision.datasets.MNIST(root="./Data", transform=data_transform, train=False, download=True)
    return torch.utils.data.ConcatDataset([train, test])


def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])
    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))
class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # print("ttt:", t.shape)
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """
    def __init__(self):
        super().__init__()
        image_channels = 1   # 1 for grayscale images, 3 for RGB
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1          # 1 for grayscale images, 3 for RGB
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )
        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i + 1],
                                          time_emb_dim)
                                    for i in range(len(down_channels) - 1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i + 1],
                                        time_emb_dim, up=True)
                                  for i in range(len(up_channels) - 1)])
        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        # Embedd time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)
def get_loss(model, x_0, t):
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    return F.l1_loss(noise, noise_pred)


@torch.no_grad()
def sample_timestep(x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    if t == 0:
        # As pointed out by Luis Pereira (see YouTube comment)
        # The t's are offset from the t's in the paper
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise


@torch.no_grad()
def sample_plot_image(IMG_SIZE):
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 1, img_size, img_size), device=device)  # generate the image at step T
    plt.figure(figsize=(15, 15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T / num_images)
    for i in range(0, T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        # print("t:", t)
        img = sample_timestep(img, t)
        # Edit: This is to maintain the natural range of the distribution
        img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0:
            plt.subplot(1, num_images, int(i / stepsize) + 1)
            plt.title(str(i))
            show_tensor_image(img.detach().cpu())
    plt.show()
if __name__ == "__main__":
    # Define beta schedule
    T = 300
    betas = linear_beta_schedule(timesteps=T)

    # Pre-calculate different terms for closed form
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, axis=0)
    # print(alphas_cumprod.shape)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    # print(alphas_cumprod_prev)
    # print(alphas_cumprod_prev.shape)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    # print(posterior_variance.shape)

    IMG_SIZE = 32
    BATCH_SIZE = 16
    data = load_transformed_dataset(IMG_SIZE)
    dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    model = SimpleUnet()
    print("Num params: ", sum(p.numel() for p in model.parameters()))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = Adam(model.parameters(), lr=0.001)
    epochs = 1  # Try more!

    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):  # each batch is (images, labels), so take batch[0]
            # print(batch[0].shape)
            optimizer.zero_grad()
            t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
            loss = get_loss(model, batch[0], t)
            loss.backward()
            optimizer.step()
            if epoch % 1 == 0 and step % 5 == 0:
                print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
                sample_plot_image(IMG_SIZE)
```
References
https://zhuanlan.zhihu.com/p/630354327
https://blog.csdn.net/weixin_45453121/article/details/131223653
https://www.cnblogs.com/risejl/p/17448442.html
https://zhuanlan.zhihu.com/p/569994589?utm_id=0