深度学习——(生成模型)DDPM

news2024/11/20 7:17:53

前置数学知识

1、先验概率和后验概率

先验概率:根据以往经验和分析得到的概率,它往往作为“由因求果”问题中的“因”出现,如 q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}) q(xtxt1)

后验概率:指在得到“结果”的信息后重新修正的概率,是“执果寻因”问题中的“因", 如 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt1xt)

2、条件概率:设 A A A B B B为任意两个事件,若 P ( A ) > 0 P(A)>0 P(A)>0,称在已知事件 A A A发生的条件下,事件 B B B发生的概率为条件概率,记为 P ( B ∣ A ) P(B|A) P(BA)
P ( B ∣ A ) = P ( A , B ) P ( A ) P(B|A)=\frac{P(A,B)} {P(A)} P(BA)=P(A)P(A,B)

3、乘法公式:
P ( A , B ) = P ( B ∣ A ) P ( A ) P(A,B)=P(B|A)P(A) P(A,B)=P(BA)P(A)

4、乘法公式一般形式:
P ( A , B , C ) = P ( C ∣ B , A ) P ( B , A ) = P ( C ∣ B , A ) P ( B ∣ A ) P ( A ) P(A,B,C)=P(C|B,A)P(B,A)=P(C|B,A)P(B|A)P(A)\\ P(A,B,C)=P(CB,A)P(B,A)=P(CB,A)P(BA)P(A)

5、贝叶斯公式:
P ( A ∣ B ) = P ( B ∣ A ) P ( A ) P ( B ) P(A|B)=\frac{P(B|A)P(A)}{P(B)} P(AB)=P(B)P(BA)P(A)
6、多元贝叶斯公式:
P ( A ∣ B , C ) = P ( A , B , C ) P ( B , C ) = P ( B ∣ A , C ) P ( A , C ) P ( B , C ) = P ( B ∣ A , C ) P ( A ∣ C ) P ( C ) P ( B ∣ C ) P ( C ) = P ( B ∣ A , C ) P ( A ∣ C ) ) P ( B ∣ C ) P(A|B,C)=\frac{P(A,B,C)}{P(B,C)}=\frac{P(B|A,C)P(A,C)}{P(B,C)}=\frac{P(B|A,C)P(A|C)P(C)}{P(B|C)P(C)}=\frac{P(B|A,C)P(A|C))}{P(B|C)} P(AB,C)=P(B,C)P(A,B,C)=P(B,C)P(BA,C)P(A,C)=P(BC)P(C)P(BA,C)P(AC)P(C)=P(BC)P(BA,C)P(AC))

7、正态分布的叠加性:当有两个独立的正态分布变量 N 1 N_{1} N1 N 2 N_{2} N2,它们的均值和方差分别为 μ 1 \mu_{1} μ1, μ 2 \mu_{2} μ2 σ 1 2 \sigma_{1}^2 σ12, σ 2 2 \sigma_{2}^2 σ22它们的和为 N = a N 1 + b N 2 N=a N_{1}+b N_{2} N=aN1+bN2的均值和方差可以表示如下:
E ( N ) = E ( a N 1 + b N 2 ) = a μ 1 + b μ 2 V a r ( N ) = V a r ( a N 1 + b N 2 ) = a 2 σ 1 2 + b 2 σ 2 2 E(N)=E(aN_{1}+bN_{2})=a\mu_{1}+b\mu_{2}\\ Var(N)=Var(aN_{1}+bN_{2})=a^2\sigma_{1}^2+b^2\sigma_{2}^2 E(N)=E(aN1+bN2)=aμ1+bμ2Var(N)=Var(aN1+bN2)=a2σ12+b2σ22
相减时:
E ( N ) = E ( a N 1 − b N 2 ) = a μ 1 − b μ 2 V a r ( N ) = V a r ( a N 1 − b N 2 ) = a 2 σ 1 2 + b 2 σ 2 2 E(N)=E(aN_{1}-bN_{2})=a\mu_{1}-b\mu_{2}\\ Var(N)=Var(aN_{1}-bN_{2})=a^2\sigma_{1}^2+b^2\sigma_{2}^2 E(N)=E(aN1bN2)=aμ1bμ2Var(N)=Var(aN1bN2)=a2σ12+b2σ22

8、重参数化:从 N ( μ , σ 2 ) N(\mu,\sigma^2) N(μ,σ2) 采样等价于从 N ( 0 , 1 ) N(0,1) N(0,1)采样一个 ϵ \epsilon ϵ, ϵ ⋅ σ + μ \epsilon\cdot\sigma+\mu ϵσ+μ

9、高斯分布的概率密度函数
f ( x ) = 1 2 π σ e − ( x − μ ) 2 2 σ 2 f(x)=\frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}} f(x)=2π σ1e2σ2(xμ)2
10、高斯分布的KL散度公式
K L ( p ∣ q ) = l o g σ 2 σ 1 + σ 2 + ( μ 1 − μ 2 ) 2 2 σ 2 2 − 1 2 KL(p|q)=log\frac{\sigma_2}{\sigma_1}+\frac{\sigma^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2} KL(pq)=logσ1σ2+2σ22σ2+(μ1μ2)221
11、二次函数配方
a x 2 + b x = a ( x + b 2 a ) 2 + c ax^2+bx=a(x+\frac{b}{2a})^2+c ax2+bx=a(x+2ab)2+c
12、随机变量的期望公式
X X X是随机变量, Y = g ( X ) Y=g(X) Y=g(X),则:
E ( Y ) = E [ g ( X ) ] = { ∑ k = 1 ∞ g ( x k ) p k ∫ − ∞ ∞ g ( x ) p ( x ) d x E(Y)=E[g(X)]= \begin{cases} \displaystyle\sum_{k=1}^\infty g(x_k)p_k\\ \displaystyle\int_{-\infty}^{\infty}g(x)p(x)dx \end{cases} E(Y)=E[g(X)]= k=1g(xk)pkg(x)p(x)dx

13、KL散度公式
K L ( p ( x ) ∣ q ( x ) ) = E x ∼ p ( x ) [ p ( x ) q ( x ) ] = ∫ p ( x ) p ( x ) q ( x ) d x KL(p(x)|q(x))=E_{x \sim p(x)}[\frac{p(x)}{q(x)}]=\int p(x) \frac{p(x)}{q(x)}dx KL(p(x)q(x))=Exp(x)[q(x)p(x)]=p(x)q(x)p(x)dx

介绍DDPM

2020年Berkeley提出DDPM(Denoising Diffusion Probabilistic Models),简称扩散模型,是AIGC的核心算法,在生成图像的真实性和多样性方面均超越了GAN,而且训练过程稳定。缺点是计算成本较高,实时推理比较困难,但也有相关技术在时间和空间维度上降低计算量。

扩散模型包括两个过程:前向扩散过程(前向加噪过程)反向去噪过程

img

前向过程和反向过程都是马尔可夫链,全过程大约需要1000步,其中反向过程用来生成数据,它的推导过程可以描述成:

img

前向扩散的过程

前向扩散过程是对原始数据逐渐增加高斯噪声,直至变成标准高斯分布的过程。

img

从原始数据集采样 x 0 ∼ q ( x 0 ) x_0\sim q(x_0) x0q(x0),按照预定义的noise schedule策略添加随机噪声,得到一系列噪声图像 x 1 , x 2 , … , x T x_1,x_2,\dots,x_T x1,x2,,xT,用概率表示为:
q ( x 1 : T ∣ x 0 ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q ( x t ∣ x t − 1 ) = N ( x t ; α t x t − 1 , β t I ) \begin{aligned} q(x_{1:T}|x_{0})&=\prod_{t=1}^{T}q(x_t|x_{t-1}) \\q(x_{t}|x_{t-1})&=\mathcal{N}(x_t;\sqrt{\alpha_t}x_{t-1},\beta_{t}I)\\ \end{aligned} q(x1:Tx0)q(xtxt1)=t=1Tq(xtxt1)=N(xt;αt xt1,βtI)
进行重参数化(前置知识数学知识8),得到
x t = α t x t − 1 + β t ϵ t      ϵ t ∼ N ( 0 , I ) α t = 1 − β t \begin{aligned} x_{t}&=\sqrt{\alpha_{t}}x_{t-1}+\sqrt{\beta_{t}}\epsilon_{t} \space \space \space \space \epsilon_{t}\sim \mathcal{N}(0,I) \\ \alpha_{t}&=1-\beta_{t} \end{aligned} xtαt=αt xt1+βt ϵt    ϵtN(0,I)=1βt

利用上述公式进行迭代推导
x t = α t x t − 1 + β t ϵ t = α t ( α t − 1 x t − 2 + β t − 1 ϵ t − 1 ) + β t ϵ t = ( α t … α 1 ) x 0 + ( α t … α 2 ) β 1 ϵ 1 + ( α t … α 3 ) β 2 ϵ 2 + ⋯ + α t β t − 1 ϵ t − 1 + β t ϵ t \begin{aligned} x_{t}&=\sqrt{\alpha_{t}} x_{t-1}+\sqrt{\beta_{t}}\epsilon_{t}\\ &=\sqrt{\alpha_{t}}(\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{\beta_{t-1}}\epsilon_{t-1})+\sqrt{\beta_{t}}\epsilon_{t}\\ &=\sqrt{(\alpha_{t}\dots\alpha_{1})}x_{0}+\sqrt{(\alpha_{t}\dots\alpha_{2})\beta_{1}}\epsilon_{1}+\sqrt{(\alpha_{t}\dots\alpha_{3})\beta_{2}}\epsilon_{2}+\dots+\sqrt{\alpha_{t}\beta_{t-1}}\epsilon_{t-1}+\sqrt{\beta_{t}}\epsilon_{t} \end{aligned} xt=αt xt1+βt ϵt=αt (αt1 xt2+βt1 ϵt1)+βt ϵt=(αtα1) x0+(αtα2)β1 ϵ1+(αtα3)β2 ϵ2++αtβt1 ϵt1+βt ϵt

设: α t ˉ = α 1 α 2 … α t \bar{\alpha_{t}}=\alpha_{1}\alpha_{2}\dots\alpha_{t} αtˉ=α1α2αt

根据正态分布的叠加性得到

x t = α t ˉ x 0 + 1 − α t ˉ ϵ     ϵ ∼ N ( 0 , I ) q ( x t ∣ x 0 ) = N ( x t ; α t ˉ x 0 , 1 − α t ˉ I ) x_{t}=\sqrt{\bar{\alpha_{t}}}x_{0}+\sqrt{1-\bar{\alpha_{t}}}\epsilon \space \space\space \epsilon\sim \mathcal{N}(0,I)\\ \textcolor{REd}{q(x_{t}|x_{0})=\mathcal{N}(x_{t};\sqrt{\bar{\alpha_{t}}}x_{0},\sqrt{1-\bar{\alpha_{t}}}I)} xt=αtˉ x0+1αtˉ ϵ   ϵN(0,I)q(xtx0)=N(xt;αtˉ x0,1αtˉ I)
这个公式表示任意步骤 t t t的噪声图像 x t x_t xt ,都可以通过 x 0 x_0 x0直接加噪得到,后面需要用到。

注:上述前向过程在代码实现时是一步到位的!!!!!

反向去噪过程,神经网络拟合过程

反向去噪过程就是数据生成过程,它首先是从标准高斯分布中采样得到一个噪声样本,再一步步地迭代去噪,最后得到数据分布中的一个样本。

img

如果知道反向过程的每一步真实的条件分布 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt),那么从一个随机噪声开始,逐步采样就能生成一个真实的样本。但是真实的条件分布利用贝叶斯公式
q ( x t − 1 ∣ x t ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ) q ( x t ) q(x_{t-1}|x_{t}) =\frac{q(x_{t}|x_{t-1})q(x_{t-1})}{q(x_{t})} q(xt1xt)=q(xt)q(xtxt1)q(xt1)
无法直接求解,原因是其中 q ( x t − 1 ) q(x_{t-1}) q(xt1) , q ( x t ) q(x_{t}) q(xt) 未知,因此无法从 x t x_{t} xt 推导到 x t − 1 {x_{t-1}} xt1,所以必须通过神经网络** p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) pθ(xt1xt)来近似。为了简化起见,将反向过程也定义为一个马尔卡夫链,且服从高斯分布**,建模如下:

p θ ( x 0 : T ) = p ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , ∑ θ ( x t , t ) ) p_\theta(x_{0:T})=p(x_T)\prod_{t=1}^Tp_\theta(x_{t-1}|x_t)\\ p_\theta(x_{t-1}|x_t)=N(x_{t-1};\mu_\theta(x_t,t),\sum_\theta(x_t,t)) pθ(x0:T)=p(xT)t=1Tpθ(xt1xt)pθ(xt1xt)=N(xt1;μθ(xt,t),θ(xt,t))

--------------------下面这段讲解与上面有些跳脱,是为损失函数做铺垫------------------------------

虽然真实条件分布 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt)无法直接求解,但是加上已知条件 x 0 x_0 x0的后验分布$q(x_{t-1}|x_{t},x_{0}) $却可以通过贝叶斯公式求解,再结合前向马尔科夫性质可得
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(x_{t-1}|x_{t},x_{0}) =\frac{q(x_{t}|x_{t-1},x_{0})q(x_{t-1}|x_{0})}{q(x_{t}|x_{0})}=\frac{q(x_{t}|x_{t-1})q(x_{t-1}|x_{0})}{q(x_{t}|x_{0})} q(xt1xt,x0)=q(xtx0)q(xtxt1,x0)q(xt1x0)=q(xtx0)q(xtxt1)q(xt1x0)
因此可以得到:
q ( x t − 1 ∣ x 0 ) = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 ϵ ∼ N ( α ˉ t − 1 x 0 , ( 1 − α ˉ t − 1 ) I ) q ( x t ∣ x 0 ) = α ˉ t x 0 + 1 − α ˉ t ϵ ∼ N ( α ˉ t x 0 , ( 1 − α ˉ t ) I ) q ( x t ∣ x t − 1 ) = α t x t − 1 + β t ϵ ∼ N ( α t x t − 1 , β t I ) \begin{aligned} q(x_{t-1}|x_{0})&=\sqrt{\bar{\alpha}_{t-1}}x_{0}+\sqrt{1-\bar{\alpha}_{t-1}}\epsilon\sim \mathcal{N}(\sqrt{\bar{\alpha}_{t-1}}x_{0},(1-\bar{\alpha}_{t-1})I)\\ q(x_{t}|x_{0})&=\sqrt{\bar{\alpha}_{t}}x_{0}+\sqrt{1-\bar{\alpha}_{t}}\epsilon\sim \mathcal{N}(\sqrt{\bar{\alpha}_{t}}x_{0},(1-\bar{\alpha}_{t})I)\\ q(x_{t}|x_{t-1})&=\sqrt{\alpha}_{t}x_{t-1}+\beta_{t}\epsilon\sim \mathcal{N}(\sqrt{\alpha}_{t}x_{t-1},\beta_{t}I) \end{aligned} q(xt1x0)q(xtx0)q(xtxt1)=αˉt1 x0+1αˉt1 ϵN(αˉt1 x0,(1αˉt1)I)=αˉt x0+1αˉt ϵN(αˉt x0,(1αˉt)I)=α txt1+βtϵN(α txt1,βtI)
所以
q ( x t − 1 ∣ x t , x 0 ) ∝ e x p ( − 1 2 ( ( x t − α t x t − 1 ) 2 β t ) + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) = e x p ( − 1 2 ( α t β t + 1 1 − α ˉ t − 1 ) x t − 1 2 − ( 2 α t β t x t + 2 α t ˉ 1 − α t ˉ x 0 ) x t − 1 + C ( x t , x 0 ) ) \begin{aligned} q(x_{t-1}|x_{t},x_{0}) &\propto exp(-\frac{1}{2}(\frac{(x_{t}-\sqrt{\alpha_{t}}x_{t-1})^2}{\beta_{t}})+\frac{(x_{t-1}-\sqrt{\bar{\alpha}}_{t-1}x_{0})^2}{1-\bar{\alpha}_{t-1}}-\frac{(x_{t}-\sqrt{\bar{\alpha}_{t}}x_{0})^2}{1-\bar{\alpha}_{t}})\\ &=exp(-\frac{1}{2}(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}})x_{t-1}^2-(\frac{2\sqrt{\alpha_{t}}}{\beta_{t}}x_{t}+\frac{2\sqrt{\bar{\alpha_{t}}}}{1-\bar{\alpha_{t}}}x_{0})x_{t-1}+C(x_{t},x_{0})) \end{aligned} q(xt1xt,x0)exp(21(βt(xtαt xt1)2)+1αˉt1(xt1αˉ t1x0)21αˉt(xtαˉt x0)2)=exp(21(βtαt+1αˉt11)xt12(βt2αt xt+1αtˉ2αtˉ x0)xt1+C(xt,x0))

通过配方就可以得到
β ~ t = 1 / ( α t β t + 1 1 − α ˉ t − 1 ) = 1 − α ˉ t − 1 1 − α ˉ t β t μ ~ t = ( α t β t x t + α ˉ t 1 − α t ˉ x 0 ) / ( α t β t + 1 1 − α ˉ t − 1 ) = α t ( 1 − α ˉ t − 1 ) 1 − α t ˉ x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \widetilde{\beta}_t=1/(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}})=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}}\beta_{t}\\ \widetilde{\mu}_t=(\frac{\sqrt\alpha_{t}}{\beta_{t}}x_{t}+\frac{\sqrt{\bar{\alpha}_{t}}}{1-\bar{\alpha_{t}}}x_{0})/(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}})=\frac{\sqrt{\alpha_{t}}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha_{t}}}x_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_{t}}{1-\bar{\alpha}_{t}}x_{0} β t=1/(βtαt+1αˉt11)=1αˉt1αˉt1βtμ t=(βtα txt+1αtˉαˉt x0)/(βtαt+1αˉt11)=1αtˉαt (1αˉt1)xt+1αˉtαˉt1 βtx0

又因为
x 0 = 1 α ˉ t ( x t − β t 1 − α ˉ t ϵ ) x_0= \frac{1}{\sqrt{\bar\alpha_t}}(x_t- \frac{\beta_t}{\sqrt{1-\bar \alpha_t} }\epsilon)\\ x0=αˉt 1(xt1αˉt βtϵ)
可以得
μ ~ t = 1 α t ( x t − β t ( 1 − α t ) ϵ ) \widetilde{\mu}_t=\frac{1}{\sqrt{\alpha_t} }(x_t-\frac{\beta_t}{\sqrt{(1-\alpha_t)}}\epsilon) μ t=αt 1(xt(1αt) βtϵ)

----------------------------------------------------------------------------------------------

采样过程(模型训练完后的预测过程)

μ θ ( x t , t ) = 1 α t ( x t − β t ( 1 − α t ) ϵ θ ( x t , t ) ) x t − 1 ∼ p θ ( x t − 1 ∣ x t ) x t − 1 = 1 α t ( x t − β t ( 1 − α t ) ϵ θ ( x t , t ) ) + β ~ t z      z ∼ N ( 0 , I ) \mu_\theta(x_t,t)=\frac{1}{\sqrt{\alpha_t} }(x_t-\frac{\beta_t}{\sqrt{(1-\alpha_t)}}\epsilon_\theta(x_t,t))\\ x_{t-1}\sim p_\theta(x_{t-1}|x_t)\\ x_{t-1}=\frac{1}{\sqrt{\alpha_t} }(x_t-\frac{\beta_t}{\sqrt{(1-\alpha_t)}}\epsilon_\theta(x_t,t))+\sqrt{\widetilde{\beta}_t}z \space \space\space\space z\sim N(0,I) μθ(xt,t)=αt 1(xt(1αt) βtϵθ(xt,t))xt1pθ(xt1xt)xt1=αt 1(xt(1αt) βtϵθ(xt,t))+β t z    zN(0,I)
这里用z是为了和之前的 ϵ \epsilon ϵ区别开

损失函数

https://blog.csdn.net/weixin_45453121/article/details/131223653

Code

import torch
import torchvision
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np
from torch.optim import Adam
from torch import nn
import math
from torchvision.utils import save_image


def show_images(data, num_samples=20, cols=4):
    """ Plots some samples from the dataset """
    plt.figure(figsize=(15,15))
    for i, img in enumerate(data):
        if i == num_samples:
            break
        plt.subplot(int(num_samples/cols) + 1, cols, i + 1)
        plt.imshow(img[0])


def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)


def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    #print("out:",out)
    #print("out.shape:",out.shape)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
    + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)






def load_transformed_dataset(IMG_SIZE):
    data_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(), # Scales data into [0,1]
        transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1]
    ]
    data_transform = transforms.Compose(data_transforms)

    train = torchvision.datasets.MNIST(root="./Data",transform=data_transform,train=True)
    test = torchvision.datasets.MNIST(root="./Data", transform=data_transform, train=False)

    return torch.utils.data.ConcatDataset([train, test])

def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    #Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))





class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU()

    def forward(self, x, t):
        #print("ttt:",t.shape)
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """
    def __init__(self):
        super().__init__()
        image_channels =1   #灰度图为1,彩色图为3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1   #灰度图为1 ,彩色图为3
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
                                    time_emb_dim) \
                    for i in range(len(down_channels)-1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
                                        time_emb_dim, up=True) \
                    for i in range(len(up_channels)-1)])

        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        # Embedd time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)




def get_loss(model, x_0, t):
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    return F.l1_loss(noise, noise_pred)



@torch.no_grad()
def sample_timestep(x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )

    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    if t == 0:
        # As pointed out by Luis Pereira (see YouTube comment)
        # The t's are offset from the t's in the paper
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def sample_plot_image(IMG_SIZE):
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 1, img_size, img_size), device=device)   #生成第T步的图片
    plt.figure(figsize=(15,15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T/num_images)

    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        #print("t:",t)
        img = sample_timestep(img, t)
        # Edit: This is to maintain the natural range of the distribution
        img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0:
            plt.subplot(1, num_images, int(i/stepsize)+1)
            plt.title(str(i))
            show_tensor_image(img.detach().cpu())
    plt.show()


if __name__ =="__main__":

    # Define beta schedule
    T = 300
    betas = linear_beta_schedule(timesteps=T)

    # Pre-calculate different terms for closed form
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, axis=0)
    # print(alphas_cumprod.shape)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    # print(alphas_cumprod_prev)
    # print(alphas_cumprod_prev.shape)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    # print(posterior_variance.shape)


    IMG_SIZE = 32
    BATCH_SIZE = 16

    data = load_transformed_dataset(IMG_SIZE)
    dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    model = SimpleUnet()
    print("Num params: ", sum(p.numel() for p in model.parameters()))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = Adam(model.parameters(), lr=0.001)
    epochs = 1 # Try more!

    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):  #由于batch 是包含标签的所以取batch[0]
            #print(batch[0].shape)
            optimizer.zero_grad()

            t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
            loss = get_loss(model, batch[0], t)
            loss.backward()
            optimizer.step()

            if epoch % 1 == 0 and step %5== 0:
                print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
                sample_plot_image(IMG_SIZE)

参考文献

https://zhuanlan.zhihu.com/p/630354327](https://zhuanlan.zhihu.com/p/630354327)

https://blog.csdn.net/weixin_45453121/article/details/131223653

https://www.cnblogs.com/risejl/p/17448442.html

https://zhuanlan.zhihu.com/p/569994589?utm_id=0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1225632.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据库】数据库连接池导致系统吞吐量上不去-复盘

在实际的开发中,我们会使用数据库连接池,但是如果不能很好的理解其中的含义,那么就可以出现生产事故。 HikariPool-1 - Connection is not available, request timed out after 30001ms.当系统的调用量上去,就出现大量这样的连接…

OpenGL 坐标投影与反投影(Qt)

文章目录 一、简介1.1投影1.2反投影二、应用代码三、实现效果参考资料一、简介 在学习OpenGL一段时间之后,我们都会了解坐标的转换过程,如下图所示: 1.1投影 正如图中所述,OpenGL将一个3D坐标投影到一个2D空间主要有以下几个步骤,这也是我们比较熟知的几个步骤: 现实局部…

Android SdkManager简介

关于作者:CSDN内容合伙人、技术专家, 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 ,擅长java后端、移动开发、商业变现、人工智能等,希望大家多多支持。 目录 一、导读二、概览三、 安装使用3.1 安装3.2 使用3.3 选项…

05-Spring Boot工程中简化开发的方式Lombok和dev-tools

简化开发的方式Lombok和dev-tools Lombok常用注解 Lombok用标签方式代替构造器、getter/setter、toString()等重复代码, 在程序编译的时候自动生成这些代码 注解名功能NoArgsConstructor生成无参构造方法AllArgsConstructor生产含所有属性的有参构造方法,如果不希望含所有属…

注解【元数据,自定义注解等概念详解】(超简单的好吧)

注解的理解与使用 注解的释义元数据的含义基础阶段常见的注解注解的作用(包括但不限于)教你读懂注解内部代码内容五种元注解尝试解读简单注解我当时的疑惑点 自定义注解自定义注解举例 注解的原理总结 注解的释义 我们都知道注释是拿来给程序员看的&…

asp.net学生成绩评估系统VS开发sqlserver数据库web结构c#编程计算机网页项目

一、源码特点 asp.net 学生成绩评估系统 是一套完善的web设计管理系统,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。 系统运行视频连接:https://www.bilibili.com/video/BV1Wz4y1A7CG/ 二、功能介绍 本系统使用Microsof…

创建谷歌账号 绕过手机验证(2023.11亲测有效)

如何成功注册谷歌账号:一个详细实用指南 写在最前面谷歌注册全流程环境配置切换至全英文环境 开通foxmail.com邮箱在英文环境下注册验证邮箱注册过程中的注意事项完成!总结 写在最前面 在这个数字化迅速发展的时代,谷歌账号几乎成为了我们日…

【计算机网络笔记】IPv6简介

系列文章目录 什么是计算机网络? 什么是网络协议? 计算机网络的结构 数据交换之电路交换 数据交换之报文交换和分组交换 分组交换 vs 电路交换 计算机网络性能(1)——速率、带宽、延迟 计算机网络性能(2)…

二分查找算法合集

二分查找也称折半查找(Binary Search),它是一种效率较高的查找方法。但是,折半查找要求线性表必须采用顺序存储结构,而且表中元素按关键字有序排列。 时间复杂度 O(logn) 自己写二分算法 左闭右开 左开右闭C算法&a…

机器学习笔记 - Ocr识别中的文本检测EAST网络概述

一、文本检测 文本检测简单来说就是找到图像中可以出现文本的区域。例如,请参见下图,其中在检测到的文本周围绘制了绿色边框。 在进行文本检测时,你可能会遇到两种情况 具有结构化文本的图像:这是指具有干净/均匀背景和常规字体的图像。文本大多密集,行结构正确,…

Linux shell编程学习笔记27:tputs

除了stty命令,我们还可以使用tput命令来更改终端的参数和功能。 1 tput 命令的功能 tput 命令的主要功能有:移动更改光标、更改文本显示属性(如颜色、下划线、粗体),清除屏幕特定区域等。 2 tput 命令格式 tput [选…

使用 Python进行量化交易:前向验证分析

运行环境:Google Colab 1. 利用 yfinance 下载数据 import yfinance as yfticker AAPL df yf.download(ticker) df下载苹果的股票数据 df df.loc[2018-01-01:].copy()dfdf[change_tomorrow] df[Adj Close].pct_change(-1) df.change_tomorrow df.change_tom…

C++二分查找算法:查找和最小的 K 对数字

相关专题 二分查找相关题目 题目 给定两个以 非递减顺序排列 的整数数组 nums1 和 nums2 , 以及一个整数 k 。 定义一对值 (u,v),其中第一个元素来自 nums1,第二个元素来自 nums2 。 请找到和最小的 k 个数对 (u1,v1), (u2,v2) … (uk,vk) 。 示例 1:…

【C语言的秘密】密探—深究C语言中多组输入的秘密!

场景引入: 你是否在刷题过程中,经常遇到以下场景呢? 场景一: 场景二: 从这些题上都能看见输入描述中提出了一条多组输入,那啥是多组输入?如何实现它呢? 多组输入:在输入…

Centos(Linux)服务器安装Dotnet8 及 常见问题解决

1. 下载dotnet8 sdk 下载 .NET 8.0 SDK (v8.0.100) - Linux x64 Binaries 拿到 dotnet-sdk-8.0.100-linux-x64.tar.gz 文件 2. 把文件上传到 /usr/local/software 目录 mkdir -p /usr/local/software/dotnet8 把文件拷贝过去 mv dotnet-sdk-8.0.100-linux-x64.tar.gz /usr/loc…

【Linux】 uptime命令使用

uptime 正常运行时间提供以下信息的单行显示。当前时间、系统运行的时间、当前登录的用户数量以及过去1、5和15分钟的系统平均负载。 语法 uptimeuptime命令 -Linux手册页 作者 由Larry Greenfield编写和迈克尔K约翰逊编写。 命令选项及作用 执行令 man uptime 执行命令结…

【linux】进行间通信——共享内存+消息队列+信号量

共享内存消息队列信号量 1.共享内存1.1共享内存的原理1.2共享内存的概念1.3接口的认识1.4实操comm.hppservice.cc (写)clint.cc (读) 1.5共享内存的总结1.6共享内存的内核结构 2.消息队列2.1原理2.2接口 3.信号量3.1信号量是什么3…

goland 远程调试 remote debug

1、远程服务器装好go环境,并设置国内源 linux go安装 参考: 如何在 Debian / Ubuntu 上安装 Go 开发环境 - 知乎 设置国内源 go env -w GOPROXYhttps://goproxy.cn,direct 2、远程服务器安装dlv git clone https://github.com/derekparker/delve.gi…

移动端表格分页uni-app

使用uni-app提供的uni-table表格 网址&#xff1a;https://uniapp.dcloud.net.cn/component/uniui/uni-table.html#%E4%BB%8B%E7%BB%8D <uni-table ref"table" :loading"loading" border stripe type"selection" emptyText"暂无更多数据…

C++ STL之string初始

我最近开了几个专栏&#xff0c;诚信互三&#xff01; > |||《算法专栏》&#xff1a;&#xff1a;刷题教程来自网站《代码随想录》。||| > |||《C专栏》&#xff1a;&#xff1a;记录我学习C的经历&#xff0c;看完你一定会有收获。||| > |||《Linux专栏》&#xff1…