生成模型:扩散模型(DDPM, DDIM, 条件生成)

news2025/1/30 8:13:31

扩散模型的理论较为复杂,论文公式与开源代码都难以理解。现有的教程大多侧重推导公式。为此,本文通过精简代码(约300行),更多以代码运行角度讲解扩散模型。

本代码包括扩散模型的主流技术复现:

1.DDPM (Denoising Diffusion Probabilistic Models,去噪扩散概率模型,SDE-包含训练与推理)

2.DDIM (Denoising Diffusion Implicit Models,ODE-加速推理)

3.Classifier_free(以标签为条件的控制生成)

1. 训练-加噪过程

1.1 参数设置

  • timesteps是加噪次数,默认为1000次,具体某一次视为一个时间步 t {t} t

步数越多过程越稳定,Mnist实验中100次以上能保证效果。但t在训练中是随机的,并没有时序性,即值为[0,timesteps]内的随机正整数

注:本文time_steps = 300

  • α \alpha α是接近但小于1的递减时序序列,有timesteps个, α t \alpha_t αt α \alpha α序列第t个元素的值

本文 α \alpha α的取值范围是:[0.9997, 0.97]

  • β \beta β是接近但大于0的递增时序序列,有 β = 1 − α \beta = 1- \alpha β=1α. 同理, β t \beta_t βt β \beta β序列第t个元素的值

范围是:[0.0003, 0.03]

  • α ˉ \bar{\alpha} αˉ α \alpha α的累积值,是一个1到0的递减序列(序列元素值不包含边界0和1).

假设alphas = [0.9, 0.8, 0.7] (即timesteps = 3), 则 α ˉ \bar{\alpha} αˉ = [0.9, 0.9 * 0.8, 0.9 * 0.8 * 0.7] = [0.9, 0.72, 0.504]

α ˉ t \bar{\alpha}_t αˉt则是第t步的序列值

  • 1 − α ˉ \sqrt{1-\bar{\alpha}} 1αˉ ,与 α ˉ \bar{\alpha} αˉ相反,是一个0到1的递增序列(序列元素同样不包含边界0与1),

区别于 β ˉ \bar{\beta} βˉ,即 β \beta β的累积值,它是一个接近0的递减序列(代码中没有用到)

  • x 0 x_0 x0是数据集样本,即训练集的图像

  • ϵ \epsilon ϵ是服从标准正态分布的高维张量(tensor)样本, ϵ ∼ N ( 0 , 1 ) \epsilon \sim \mathcal{N}(0,1) ϵN(0,1),即随机噪声

1.2 加噪方程

顺序加噪过程是一个Markov Chain如下:

x t = α t x t − 1 + 1 − α t ϵ x_t = \sqrt{\alpha_t}x_{t-1} + \sqrt{1 - \alpha_t}\epsilon xt=αt xt1+1αt ϵ

上式表示从 x 0 x_0 x0 x t x_t xt是一个加噪的高斯过程,即对图像进行timesteps次加噪,通过 α ˉ t \bar{\alpha}_t αˉt 1 − α ˉ t 1-\bar{\alpha}_t 1αˉt控制图像和噪声的比例,

最终随着t的增长, α t \alpha_t αt的值变小,图像在信号的比例降低,而噪声的比例增大, 最终图像变为全高斯噪声。

但是,这个训练过程每次需要迭代timesteps次,训练效率太低(类似RNN), 这里可以简化,概率表述如下:

q ( x t ∣ x 0 ) ∼ N ( α ˉ t   x 0 , ( 1 − α ˉ t )   I ) q(x_t \mid x_0) \sim \mathcal{N}(\sqrt{\bar{\alpha}_t} \, x_0, (1 - \bar{\alpha}_t) \, \mathbf{I}) q(xtx0)N(αˉt x0,(1αˉt)I)

即只需要 x 0 x_0 x0即可完成训练,其中每个sample都可以是随机的t,即有下述公式:

x t = α ˉ t ⋅ x 0 + 1 − α ˉ t ⋅ ϵ x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon xt=αˉt x0+1αˉt ϵ

该公式描述了训练过程,随机取batch_size个t,每个t加到上述公式即可得到前向过程的 x t x_t xt, 并与模型的输出做MSE即可。

1.3 加噪过程

1.3.1 数据预处理

  • 训练模型每次送入batch_size个图片,每个图片分一个随机的正数t,t的值域为[0, timesteps],

因此,根据timesteps和batch_size可算出t值。计算代码如下:

t = torch.randint(0, timesteps, (batch_size,), device=device).long() 
  • 根据timesteps, 算出 β \beta β, α \alpha α,进一步算出 α ˉ \bar{\alpha} αˉ (alphas_cumprod), 其shape = [timesteps].
def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = 0.0003 * scale # 该值过小,去燥不充分
    beta_end = 0.03 * scale # 该值过小,生成胡乱条纹
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
    
betas = linear_beta_schedule(timesteps)

tensor([0.0010, 0.0013, 0.0017, 0.0020, 0.0023, 0.0027, 0.0030, 0.0033, 0.0036,
        0.0040, 0.0043, 0.0046, 0.0050, 0.0053, 0.0056, 0.0060, 0.0063, 0.0066,
        0.0070, 0.0073, 0.0076, 0.0080, 0.0083, 0.0086, 0.0089, 0.0093, 0.0096,
        ...
        0.0934, 0.0937, 0.0940, 0.0944, 0.0947, 0.0950, 0.0954, 0.0957, 0.0960,
        0.0964, 0.0967, 0.0970, 0.0974, 0.0977, 0.0980, 0.0983, 0.0987, 0.0990,
        0.0993, 0.0997, 0.1000], dtype=torch.float64)

alphas = 1. - betas

tensor([0.9990, 0.9987, 0.9983, 0.9980, 0.9977, 0.9973, 0.9970, 0.9967, 0.9964,
        0.9960, 0.9957, 0.9954, 0.9950, 0.9947, 0.9944, 0.9940, 0.9937, 0.9934,
        0.9930, 0.9927, 0.9924, 0.9920, 0.9917, 0.9914, 0.9911, 0.9907, 0.9904,
         ...
        0.9066, 0.9063, 0.9060, 0.9056, 0.9053, 0.9050, 0.9046, 0.9043, 0.9040,
        0.9036, 0.9033, 0.9030, 0.9026, 0.9023, 0.9020, 0.9017, 0.9013, 0.9010,
        0.9007, 0.9003, 0.9000], dtype=torch.float64)

alphas_cumprod = torch.cumprod(self.alphas, axis=0)

tensor([9.9900e-01, 9.9767e-01, 9.9601e-01, 9.9403e-01, 9.9172e-01, 9.8908e-01,
        9.8613e-01, 9.8286e-01, 9.7927e-01, 9.7537e-01, 9.7117e-01, 9.6666e-01,
        9.6185e-01, 9.5675e-01, 9.5136e-01, 9.4568e-01, 9.3973e-01, 9.3350e-01,
         ...
        8.8150e-07, 7.9802e-07, 7.2218e-07, 6.5331e-07, 5.9079e-07, 5.3406e-07,
        4.8260e-07, 4.3594e-07, 3.9364e-07, 3.5532e-07, 3.2061e-07, 2.8919e-07,
        2.6075e-07, 2.3502e-07, 2.1175e-07, 1.9072e-07, 1.7171e-07, 1.5454e-07],
       dtype=torch.float64)
  • 得到公式系数: α ˉ \sqrt{\bar{\alpha}} αˉ (sqrt_alphas_cumprod), 1 − α ˉ \sqrt{1-\bar{\alpha}} 1αˉ (sqrt_one_minus_alphas_cumprod)
sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
tensor([9.9900e-01, 9.9767e-01, 9.9601e-01, 9.9403e-01, 9.9172e-01, 9.8908e-01,
        9.8613e-01, 9.8286e-01, 9.7927e-01, 9.7537e-01, 9.7117e-01, 9.6666e-01,
        9.6185e-01, 9.5675e-01, 9.5136e-01, 9.4568e-01, 9.3973e-01, 9.3350e-01,
         ...
        8.8150e-07, 7.9802e-07, 7.2218e-07, 6.5331e-07, 5.9079e-07, 5.3406e-07,
        4.8260e-07, 4.3594e-07, 3.9364e-07, 3.5532e-07, 3.2061e-07, 2.8919e-07,
        2.6075e-07, 2.3502e-07, 2.1175e-07, 1.9072e-07, 1.7171e-07, 1.5454e-07],
       dtype=torch.float64)

sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
sqrt_one_minus_alphas_cumprod
tensor([0.0316, 0.0483, 0.0632, 0.0773, 0.0910, 0.1045, 0.1178, 0.1309, 0.1440,
        0.1569, 0.1698, 0.1826, 0.1953, 0.2080, 0.2205, 0.2331, 0.2455, 0.2579,
        0.2702, 0.2824, 0.2946, 0.3067, 0.3187, 0.3306, 0.3424, 0.3542, 0.3658,
        ...
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000], dtype=torch.float64)

  • 计算t步的 α ˉ t \sqrt{\bar{\alpha}_t} αˉt (sqrt_alphas_cumprod_t) 和 1 − α ˉ t \sqrt{1-\bar{\alpha}_t} 1αˉt (sqrt_one_minus_alphas_cumprod_t)

由于一个batch_size中的图片t的值是随机的,且需要取每个t在 α ˉ \bar{\alpha} αˉ的值, 类似key-value索引,因此构造extract函数

该函数取每个样本的t(当作key)在 α ˉ \sqrt{\bar{\alpha}} αˉ 对应的值(value),并reshape:

def _extract(a: torch.FloatTensor, t: torch.LongTensor, x_shape):
        # get the param of given timestep t
        batch_size = t.shape[0]
        out = a.to(t.device).gather(0, t).float()
        out = out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
        return out

这里:

gather()函数实现t到噪声强度值的key-value索引: sqrt_alphas_cumprod.gather(0,t)

得到的强度值reshape为 [batch_size, channels, 1 , 1], 方便加到图像中。

最后通过q_sample函数算出 α ˉ t \sqrt{\bar{\alpha}_t} αˉt (sqrt_alphas_cumprod_t) 和 1 − α ˉ t \sqrt{1-\bar{\alpha}_t} 1αˉt (sqrt_one_minus_alphas_cumprod_t)

函数代码如下:

def q_sample(x_start: torch.FloatTensor, t: torch.LongTensor, noise=None): 
        # 前向加噪过程:forward diffusion (using the nice property): q(x_t | x_0)
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

举例:batch_size = 5, 随机得到一组t, 可从[0, timesteps]内随机采样得到 = [ 75, 112, 268, 207, 90], 索引+reshape后的值为:

sqrt_alphas_cumprod_t = 

tensor([[[[0.5979]]],


        [[[0.3268]]],


        [[[0.0018]]],


        [[[0.0234]]],


        [[[0.4814]]]])

以上每个数值对应一个样本的t,t.shape = [5,1,1,1]; 假设图像x.shape = [5, 3, 32, 32], 即一个batch_size有5张32x32的3通道RGB彩色图像

而x+t会触发广播机制:即一个样本(图像)的每个元素(像素点)都会加上对应的t,[1,1,1] broadcast to [3,32,32]

1.3.2 训练

  • 构造x_T, 即一个与x_0相同shape的高斯噪声 ϵ \epsilon ϵ

noise = torch.randn_like(x_start) # random noise ~ N(0, 1)

  • 通过q_sample函数算出batch_size个样本从 x 0 x_0 x0 x t x_t xt的高斯噪声估计

就是前面提到的公式:

x t = α ˉ t ⋅ x 0 + 1 − α ˉ t ⋅ ϵ x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon xt=αˉt x0+1αˉt ϵ

  • 两者做MSE

像u-net模型送入 x t x_t xt, 与对应的t,输出得到估计噪声 ϵ θ \epsilon_\theta ϵθ,

ϵ \epsilon ϵ ϵ θ \epsilon_\theta ϵθ做MSE, 完整代码如下:

def train_losses(model, x_start: torch.FloatTensor, t: torch.LongTensor):
    noise = torch.randn_like(x_start)  # random noise ~ N(0, 1)
    x_noisy = self.q_sample(x_start, t, noise=noise)  # x_t ~ q(x_t | x_0)
    predicted_noise = model(x_noisy, t)  # predict noise from noisy image
    loss = F.mse_loss(noise, predicted_noise)
    return loss

2. DDPM推理 - 去噪过程

去噪过程即图像生成过程,将相同shape的高斯噪声去燥成为图。

与训练不同,这里的t是逆序的, 且DDPM的timesteps不能跳步,需要执行timesteps次,即由 x T x_T xT x 0 x_0 x0

2.1 参数设置

  • α ˉ t − 1 \bar{\alpha}_{t-1} αˉt1 (alphas_cumprod_prev),时间步 t − 1 t-1 t1 的累积 α \alpha α 值,这里首位元素填充1,保证长度和 α ˉ t \bar{\alpha}_{t} αˉt一致。

用于计算后验分布 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1} | x_t, x_0) q(xt1xt,x0) 的均值( μ t − 1 \mu_{t-1} μt1)和方差( σ t − 1 2 \sigma_{t-1}^2 σt12)

具体实现代码:

alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)
  • μ t − 1 \mu_{t-1} μt1 (posterior_mean)

计算公式如下:

μ t − 1 = α ˉ t − 1 β t 1 − α ˉ t x 0 + ( 1 − α ˉ t − 1 ) α t 1 − α ˉ t x t \mu_{t-1} = \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1 - \bar{\alpha}_t} x_0 + \frac{(1 - \bar{\alpha}_{t-1})\sqrt{\alpha_t}}{1 - \bar{\alpha}_t} x_t μt1=1αˉtαˉt1 βtx0+1αˉt(1αˉt1)αt xt

其中 α ˉ t − 1 β t 1 − α ˉ t \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1 - \bar{\alpha}_t} 1αˉtαˉt1 βt 记为posterior_mean_coef1 ( x 0 的系数 x_0的系数 x0的系数), ( 1 − α ˉ t − 1 ) α t 1 − α ˉ t ) \frac{(1 - \bar{\alpha}_{t-1})\sqrt{\alpha_t}}{1 - \bar{\alpha}_t}) 1αˉt(1αˉt1)αt ) 记为posterior_mean_coef2 ( x t 的系数 x_t的系数 xt的系数

这里的 x 0 x_0 x0是当前时刻t算出的估计值,不是数据集样本,计算两个coef的代码如下:

posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
  • σ t − 1 2 \sigma_{t-1}^2 σt12 (posterior_variance)

σ t − 1 2 = β t ( 1 − α ˉ t − 1 ) 1 − α ˉ t \sigma_{t-1}^2 = \frac{\beta_t (1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} σt12=1αˉtβt(1αˉt1)

posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

这里 σ t − 1 2 \sigma_{t-1}^2 σt12非常小,直接存储可能造成下溢,导致无法计算梯度,因此计算 log ⁡ σ t − 1 2 \log{\sigma^2_{t-1}} logσt12

posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min=1e-20))
  • x 0 x_0 x0 是在t时刻图像的估计值,最后时刻t的去噪图像才是最终结果(img)

x 0 = x t α ˉ t − 1 − α ˉ t α ˉ t ⋅ ϵ θ x_0 = \frac{x_t}{\sqrt{\bar{\alpha}_t}} - \frac{\sqrt{1 - \bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}} \cdot \epsilon_\theta x0=αˉt xtαˉt 1αˉt ϵθ

代码是:

pre_x_0 = _extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - _extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * model(x_t, t) # pred_noise = model(x_t, t)

其中 ϵ θ \epsilon_\theta ϵθ为模型的输出,即t步时预测的噪声,如果将上式的 x 0 x_0 x0带入开始的 μ t − 1 \mu_{t-1} μt1公式,可以得到 x t x_t xt x t − 1 x_{t-1} xt1完整公式:

x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ⋅ ϵ θ ( x t , t ) ) + σ t ⋅ z x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \cdot \epsilon_\theta(x_t, t) \right) + \sigma_t \cdot z xt1=αt 1(xt1αˉt 1αtϵθ(xt,t))+σtz

简化为:

x t − 1 = μ t − 1 + σ t − 1 ⋅ ϵ x_{t-1} = \mu_{t-1} + \sigma_{t-1} \cdot \epsilon xt1=μt1+σt1ϵ

这里的 ϵ \epsilon ϵ是随机噪声

2.2 去噪方程

在代码中通过对 x t x_t xt去燥,得到 x t − 1 x_{t-1} xt1的表达式为:

$x_{t-1} = \mu_{t-1} + mask \cdot e{\frac{1}{2}\log{\sigma_{t-1}2}} \cdot \epsilon $

其中:

  • 为有效训练, 用 e 1 2 log ⁡ σ t − 1 2 e^{\frac{1}{2}\log{\sigma_{t-1}^2}} e21logσt12 替换 σ t − 1 \sigma_{t-1} σt1 (两者理论值相等),

目的是避免直接用 σ t − 1 \sigma_{t-1} σt1造成数值过小而产生下溢(nan),无法计算梯度。

  • mask是一个掩码矩阵,shape = [batch_size, 1, 1, 1],等t=0时,其元素的全部值为0,t不为0时,元素值为1

目的是t=0时不加噪,即 x 0 = μ 0 x_{0} = \mu_0 x0=μ0

mask代码是:

nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))) # no noise when t == 0

2.3 去噪过程

  • 通过噪声累积值 α t ˉ \bar{\alpha_t} αtˉ填充得到 α ˉ t − 1 \bar{\alpha}_{t-1} αˉt1, 算出方差 σ t − 1 2 \sigma^2_{t-1} σt12, 并转为 log ⁡ σ t − 1 2 \log{\sigma^2_{t-1}} logσt12

假设初始噪音为 x_t.shape = (batch_size, channels, image_size, image_size),

得到batch_size个 α ˉ t − 1 \bar{\alpha}_{t-1} αˉt1 log ⁡ σ t − 1 2 \log{\sigma^2_{t-1}} logσt12代码如下:

def q_posterior_mean_variance(self, x_start: torch.FloatTensor, x_t: torch.FloatTensor, t: torch.LongTensor):
    # Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0)
    posterior_mean = (self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t)
    posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
    posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
    return posterior_mean, posterior_variance, posterior_log_variance_clipped
  • 通过初始噪音 x t x_t xt α t ˉ \bar{\alpha_t} αtˉ, 以及模型输出 ϵ θ \epsilon_\theta ϵθ, 估计出 x 0 x_0 x0, 这里 x 0 x_0 x0需要裁剪到合理值域{-1,1}

这里调用了上一步的q_posterior_mean_variance函数,因为要先算 x 0 x_0 x0 (x_start)才能得到 μ t − 1 \mu_{t-1} μt1 (model_mean)

def p_mean_variance(self, model, x_t: torch.FloatTensor, t: torch.LongTensor):
    # compute x_0 from x_t and pred noise: the reverse of `q_sample`, 估计值,包含部分残留噪声
    pre_x_0 = self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * model(x_t, t) # pred_noise = model(x_t, t)
    pre_x_0 = torch.clamp(pre_x_0, min=-1., max=1.) # clip_denoised
    model_mean, posterior_variance, posterior_log_variance = self.q_posterior_mean_variance(pre_x_0, x_t, t) ## compute predicted mean and variance of p(x_{t-1} | x_t), predict noise using model
    return model_mean, posterior_variance, posterior_log_variance
  • 通过 x t x_t xt x 0 x_0 x0,以及参数 α t ˉ , β t ˉ , α ˉ t − 1 \bar{\alpha_t},\bar{\beta_t},\bar{\alpha}_{t-1} αtˉ,βtˉ,αˉt1组成的系数, 算出均值 μ t − 1 \mu_{t-1} μt1

这个放在了第一个函数里:

posterior_mean = (self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t)

  • 通过公式计算t-1时刻的 x t − 1 x_{t-1} xt1, 并重复timesteps次到 x 0 x_0 x0

单次代码为:

def p_sample(self, model, x_t: torch.FloatTensor, t: torch.LongTensor):
    # denoise_step: sample x_{t-1} from x_t and pred_noise, predict mean and variance
    model_mean, _, model_log_variance = self.p_mean_variance(model, x_t, t)
    noise = torch.randn_like(x_t)
    nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))) # no noise when t == 0
    pred_img = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise # # compute x_{t-1}
    return pred_img

从T时刻,即构造高斯噪声 x T x_T xT开始,执行timesteps次的循环为 (注意这里时逆向reverse,即从T到0):

def sample(self, model: nn.Module, image_size, batch_size=8, channels=3):
    shape = (batch_size, channels, image_size, image_size) # denoise: reverse diffusion
    device = next(model.parameters()).device
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)  # x_T ~ N(0, 1)
    imgs = []
    for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
        t = torch.full((batch_size,), i, device=device, dtype=torch.long)
        img = self.p_sample(model, img, t)
        imgs.append(img.cpu().numpy())
    return imgs

3.DDIM

总的来说,DDIM是DDPM的跳步采样,简化了采样公式,可以加快采样速度。

为简化代码,这里省略了论文公式中DDPM+DDIM混合采样的参数(即去掉了公式中DDPM的可选随机项),仅保留DDIM采样参数。

3.1 去噪方程

DDPM每一步都需要一个随机噪声 ϵ \epsilon ϵ,即SDE随机采样,引入随机噪声的好处是增加多样性,但timesteps次数多(通常要100-1000步)采样慢。

x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ⋅ ϵ θ ( x t , t ) ) + σ t ⋅ ϵ = α ˉ t − 1 ⋅ x 0 + 1 − α ˉ t − 1 ⋅ x t − α ˉ t ⋅ x 0 1 − α ˉ t + σ t ⋅ ϵ x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \cdot \epsilon_\theta(x_t, t) \right) + \sigma_t \cdot \epsilon = \sqrt{\bar{\alpha}_{t-1}} \cdot x_0 + \sqrt{1 - \bar{\alpha}_{t-1}} \cdot \frac{x_t - \sqrt{\bar{\alpha}_t} \cdot x_0}{\sqrt{1 - \bar{\alpha}_t}} + \sigma_t \cdot \epsilon xt1=αt 1(xt1αˉt 1αtϵθ(xt,t))+σtϵ=αˉt1 x0+1αˉt1 1αˉt xtαˉt x0+σtϵ

DDIM每一步不需要引入噪声,是一种确定性的ODE采样,类似跳步采样的方法,可以自定义步数,通常10-50步就能生成效果。

x t − 1 = α ˉ t − 1 ⋅ x 0 + 1 − α ˉ t − 1 ⋅ x t − α ˉ t ⋅ x 0 1 − α ˉ t x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \cdot x_0 + \sqrt{1 - \bar{\alpha}_{t-1}} \cdot \frac{x_t - \sqrt{\bar{\alpha}_t} \cdot x_0}{\sqrt{1 - \bar{\alpha}_t}} xt1=αˉt1 x0+1αˉt1 1αˉt xtαˉt x0

其中 x 0 x_0 x0是通过模型输出的t时刻 ϵ θ \epsilon_\theta ϵθ计算得到的估计值

3.2 去噪过程

3.2.1 采样参数

  • x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xTN(0,I), 随机噪声
shape = (batch_size, channels, image_size, image_size)
x_T = torch.randn(shape, device=self.betas.device) # start from pure noise
  • 时间间隔c, 用于从原ddpm的timesteps中抽取等距的跳步数索引(index)
c = ddpm_timesteps / ddim_timesteps
  • 根据c,抽取间隔序列 T T T (ddim_timestep_seq) 与 T p r e T_{pre} Tpre (ddim_timestep_prev_seq) 序列。

其中 T p r e T_{pre} Tpre序列是移除 T T T的最后一个元素后,在序列首位补上第一个元素(值为0)

ddim_timestep_seq = torch.tensor(list(range(0, self.timesteps, c))) + 1 # one from first scale to data during sampling
ddim_timestep_prev_seq = torch.cat((torch.tensor([0]), ddim_timestep_seq[:-1])) # previous sequence

3.2.2 单次去噪

  • T T T T p r e T_{pre} Tpre 索引去噪循环的第 t 与 t-1步
t = torch.full((batch_size,), ddim_timestep_seq[i], device=x_T.device, dtype=torch.long)
next_t = torch.full((batch_size,), ddim_timestep_prev_seq[i], device=x_T.device, dtype=torch.long)
  • 将根据 α ˉ \bar{\alpha} αˉ α t ˉ \bar{\alpha_t} αtˉ, α t − 1 ˉ \bar{\alpha_{t-1}} αt1ˉ
alpha_cumprod_t = self._extract(self.alphas_cumprod, t, x_T.shape) #1. get current and previous alpha_cumprod
alpha_cumprod_t_prev = self._extract(self.alphas_cumprod, next_t, x_T.shape)
  • 输出U-Net模型的预测噪声 ϵ θ \epsilon_\theta ϵθ, 第一次是输入 T T T x T x_T xT
pred_noise = model(x_t, t)
  • 根据t时刻的 ϵ θ \epsilon_\theta ϵθ估计 x 0 x_0 x0

公式与DDPM一致:

x 0 = x t α ˉ t − 1 − α ˉ t α ˉ t ⋅ ϵ θ x_0 = \frac{x_t}{\sqrt{\bar{\alpha}_t}} - \frac{\sqrt{1 - \bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}} \cdot \epsilon_\theta x0=αˉt xtαˉt 1αˉt ϵθ

代码:

pred_x0 = (xs[-1] - torch.sqrt(1 - alpha_cumprod_t) * pred_noise) / torch.sqrt(alpha_cumprod_t)
pred_x0 = torch.clamp(pred_x0, min=-1., max=1.) # 3. get the predicted x_0, 预测 x_0
  • 根据 x 0 x_0 x0, x t x_t xt α t ˉ \bar{\alpha_t} αtˉ, α t − 1 ˉ \bar{\alpha_{t-1}} αt1ˉ 计算出去燥值 x t − 1 x_{t-1} xt1
pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev) * pred_noise # 5. compute "direction pointing to x_t" of formula (12)
x_t_pre = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt # 6. compute x_{t-1} of formula (12)

3.3 代码

循环单次去燥代码即可完成DDIM去噪过程,

完整的去噪代码如下:

#DDIM Inference/Reverse
def ddim_sample(self, model, image_size, ddim_timesteps=100, batch_size=8, channels=3):
    shape = (batch_size, channels, image_size, image_size)
    x_T = torch.randn(shape, device=self.betas.device) # start from pure noise
    xs = [x_T]
    c = self.timesteps // ddim_timesteps # make ddim timestep sequence
    ddim_timestep_seq = torch.tensor(list(range(0, self.timesteps, c))) + 1 # one from first scale to data during sampling
    ddim_timestep_prev_seq = torch.cat((torch.tensor([0]), ddim_timestep_seq[:-1])) # previous sequence

    for i in tqdm(reversed(range(0,ddim_steps)), desc='ddpm sampling loop time step', total=ddim_steps):
        t = torch.full((batch_size,), ddim_timestep_seq[i], device=x_T.device, dtype=torch.long)
        next_t = torch.full((batch_size,), ddim_timestep_prev_seq[i], device=x_T.device, dtype=torch.long)

        alpha_cumprod_t = self._extract(self.alphas_cumprod, t, x_T.shape) #1. get current and previous alpha_cumprod
        alpha_cumprod_t_prev = self._extract(self.alphas_cumprod, next_t, x_T.shape)

        pred_noise = model(xs[-1], t) # 2. predict noise using model, 模型预测噪声
        pred_x0 = (xs[-1] - torch.sqrt(1 - alpha_cumprod_t) * pred_noise) / torch.sqrt(alpha_cumprod_t)
        pred_x0 = torch.clamp(pred_x0, min=-1., max=1.) # 3. get the predicted x_0, 预测 x_0
        pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev) * pred_noise # 5. compute "direction pointing to x_t" of formula (12)
        x_t_pre = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt # 6. compute x_{t-1} of formula (12)
        xs.append(x_t_pre) 
        # omit 4. compute variance: "sigma_t(η)" -> see formula (16) / σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    return xs

4. 模型结构

多数开源代码的U-Net结构较为复杂,包含较多的Attention, ResNet等。

本文设计一个简化的U-Net结构,仅保留必要部分:

4.1 下采样块 (Downsample)

4 blocks: 每个block由 1个 conv layer 构成,每两个conv layer 完成一次下采样

这里有两种conv layer,一种保持特征shape,一种下采样,具体代码如下:

class Upsample(nn.Module):
    def __init__(self, channels, num_groups=32):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.num_groups = num_groups

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest") #        # 上采样
        x = self.conv(x) #        # 卷积 + GroupNorm
        return x  # 激活函数
        
down_block1 = nn.Conv2d(io_channels, model_channels, kernel_size=3, padding=1) # down blocks
down_block2 = Downsample(model_channels)

4.2 中间块 (Middle)

  • block

这里仅用1个block,包含2个conv layer,并且用resnet结构:

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.shortcut = (nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity())

    def forward(self, x):
        h = F.relu(F.group_norm(self.conv1(x), num_groups=32)) # 第一层卷积 + GroupNorm + 激活
        h = F.relu(F.group_norm(self.conv2(h), num_groups=32))
        return h + self.shortcut(x)  # 残差连接

middle_block = ResidualBlock(model_channels*2, model_channels*2) # middle block
  • 注入时序信息

U-Net在输入 x x x时,需要对特征每一个元素广播对应的时序 t t t的值, 否则无法实现生成效果

多数开源代码是将t注入到每一个layer。这里为了简化模型,仅将 t t t注入到中间层(middle layer)部分,

具体是将t嵌入三角函数

4.3 时序注入模块(time_embedding)

4.3.1 代码

具体输入t和dim,输出embedding,代码如下:

def timestep_embedding(t, dim, max_period=10000):
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim // 2, dtype=torch.float32) / (dim // 2)).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return embedding

输入:t.shape = [batch_size], dim = layer_channels,

输出: embedding.shape = [batch_size, layer_channels]

4.3.2 计算过程

假定输入

  • timesteps = torch.tensor([0, 1, 2, 3])
  • dim = 8
1. 计算 half

half = dim 2 = 8 2 = 4 \text{half} = \frac{\text{dim}}{2} = \frac{8}{2} = 4 half=2dim=28=4

2. 计算频率 freqs

生成频率值:

freqs [ i ] = e − log ⁡ ( max_period ) ⋅ i half \text{freqs}[i] = e^{-\log(\text{max\_period}) \cdot \frac{i}{\text{half}}} freqs[i]=elog(max_period)halfi

对于 half=4max_period=10000

freqs = [ e − log ⁡ ( 10000 ) ⋅ 0 / 4 ,   e − log ⁡ ( 10000 ) ⋅ 1 / 4 ,   e − log ⁡ ( 10000 ) ⋅ 2 / 4 ,   e − log ⁡ ( 10000 ) ⋅ 3 / 4 ] \text{freqs} = \left[ e^{-\log(10000)\cdot 0/4 }, \ e^{-\log(10000) \cdot 1/4}, \ e^{-\log(10000) \cdot 2/4}, \ e^{-\log(10000) \cdot 3/4} \right] freqs=[elog(10000)0/4, elog(10000)1/4, elog(10000)2/4, elog(10000)3/4]

结果:

freqs = [ 1.0 , 0.1 , 0.01 , 0.001 ] \text{freqs} = [1.0, 0.1, 0.01, 0.001] freqs=[1.0,0.1,0.01,0.001]


3. 计算 args

timestepsfreqs 生成输入参数 args

args [ i , j ] = timesteps [ i ] ⋅ freqs [ j ] \text{args}[i, j] = \text{timesteps}[i] \cdot \text{freqs}[j] args[i,j]=timesteps[i]freqs[j]

args = torch.tensor([0, 1, 2, 3])[:, None] * torch.tensor([1.0, 0.1, 0.01, 0.001])[None, :]

结果:

args = [ 0.00 0.00 0.00 0.000 1.00 0.10 0.01 0.001 2.00 0.20 0.02 0.002 3.00 0.30 0.03 0.003 ] \text{args} = \begin{bmatrix} 0.00 & 0.00 & 0.00 & 0.000 \\ 1.00 & 0.10 & 0.01 & 0.001 \\ 2.00 & 0.20 & 0.02 & 0.002 \\ 3.00 & 0.30 & 0.03 & 0.003 \end{bmatrix} args= 0.001.002.003.000.000.100.200.300.000.010.020.030.0000.0010.0020.003


4. 计算正弦和余弦嵌入

分别对 args 应用正弦和余弦函数:

  • 余弦部分

    cos_part = cos ⁡ ( args ) \text{cos\_part} = \cos(\text{args}) cos_part=cos(args)

  • 正弦部分

    sin_part = sin ⁡ ( args ) \text{sin\_part} = \sin(\text{args}) sin_part=sin(args)

具体计算:

余弦部分

cos_part = [ cos ⁡ ( 0.00 ) cos ⁡ ( 0.00 ) cos ⁡ ( 0.00 ) cos ⁡ ( 0.000 ) cos ⁡ ( 1.00 ) cos ⁡ ( 0.10 ) cos ⁡ ( 0.01 ) cos ⁡ ( 0.001 ) cos ⁡ ( 2.00 ) cos ⁡ ( 0.20 ) cos ⁡ ( 0.02 ) cos ⁡ ( 0.002 ) cos ⁡ ( 3.00 ) cos ⁡ ( 0.30 ) cos ⁡ ( 0.03 ) cos ⁡ ( 0.003 ) ] = [ 1.0000 1.0000 1.0000 1.0000 0.5403 0.9950 0.9999 1.0000 − 0.4161 0.9801 0.9998 1.0000 − 0.9899 0.9553 0.9996 1.0000 ] \text{cos\_part} = \begin{bmatrix} \cos(0.00) & \cos(0.00) & \cos(0.00) & \cos(0.000) \\ \cos(1.00) & \cos(0.10) & \cos(0.01) & \cos(0.001) \\ \cos(2.00) & \cos(0.20) & \cos(0.02) & \cos(0.002) \\ \cos(3.00) & \cos(0.30) & \cos(0.03) & \cos(0.003) \end{bmatrix}= \begin{bmatrix} 1.0000 & 1.0000 & 1.0000 & 1.0000 \\ 0.5403 & 0.9950 & 0.9999 & 1.0000 \\ -0.4161 & 0.9801 & 0.9998 & 1.0000 \\ -0.9899 & 0.9553 & 0.9996 & 1.0000 \end{bmatrix} cos_part= cos(0.00)cos(1.00)cos(2.00)cos(3.00)cos(0.00)cos(0.10)cos(0.20)cos(0.30)cos(0.00)cos(0.01)cos(0.02)cos(0.03)cos(0.000)cos(0.001)cos(0.002)cos(0.003) = 1.00000.54030.41610.98991.00000.99500.98010.95531.00000.99990.99980.99961.00001.00001.00001.0000

正弦部分

sin_part = [ sin ⁡ ( 0.00 ) sin ⁡ ( 0.00 ) sin ⁡ ( 0.00 ) sin ⁡ ( 0.000 ) sin ⁡ ( 1.00 ) sin ⁡ ( 0.10 ) sin ⁡ ( 0.01 ) sin ⁡ ( 0.001 ) sin ⁡ ( 2.00 ) sin ⁡ ( 0.20 ) sin ⁡ ( 0.02 ) sin ⁡ ( 0.002 ) sin ⁡ ( 3.00 ) sin ⁡ ( 0.30 ) sin ⁡ ( 0.03 ) sin ⁡ ( 0.003 ) ] = [ 0.0000 0.0000 0.0000 0.0000 0.8415 0.0998 0.0100 0.0010 0.9093 0.1987 0.0200 0.0020 0.1411 0.2955 0.0300 0.0030 ] \text{sin\_part} = \begin{bmatrix} \sin(0.00) & \sin(0.00) & \sin(0.00) & \sin(0.000) \\ \sin(1.00) & \sin(0.10) & \sin(0.01) & \sin(0.001) \\ \sin(2.00) & \sin(0.20) & \sin(0.02) & \sin(0.002) \\ \sin(3.00) & \sin(0.30) & \sin(0.03) & \sin(0.003) \end{bmatrix}= \begin{bmatrix} 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ 0.8415 & 0.0998 & 0.0100 & 0.0010 \\ 0.9093 & 0.1987 & 0.0200 & 0.0020 \\ 0.1411 & 0.2955 & 0.0300 & 0.0030 \end{bmatrix} sin_part= sin(0.00)sin(1.00)sin(2.00)sin(3.00)sin(0.00)sin(0.10)sin(0.20)sin(0.30)sin(0.00)sin(0.01)sin(0.02)sin(0.03)sin(0.000)sin(0.001)sin(0.002)sin(0.003) = 0.00000.84150.90930.14110.00000.09980.19870.29550.00000.01000.02000.03000.00000.00100.00200.0030


5. 拼接结果

cos_partsin_part 沿最后一个维度拼接:

embedding = [ cos_part , sin_part ] \text{embedding} = [\text{cos\_part}, \text{sin\_part}] embedding=[cos_part,sin_part]

6. 输出

embedding = [ 1.0000 1.0000 1.0000 1.0000 0.0000 0.0000 0.0000 0.0000 0.5403 0.9950 0.9999 1.0000 0.8415 0.0998 0.0100 0.0010 − 0.4161 0.9801 0.9998 1.0000 0.9093 0.1987 0.0200 0.0020 − 0.9899 0.9553 0.9996 1.0000 0.1411 0.2955 0.0300 0.0030 ] \text{embedding} = \begin{bmatrix} 1.0000 & 1.0000 & 1.0000 & 1.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ 0.5403 & 0.9950 & 0.9999 & 1.0000 & 0.8415 & 0.0998 & 0.0100 & 0.0010 \\ -0.4161 & 0.9801 & 0.9998 & 1.0000 & 0.9093 & 0.1987 & 0.0200 & 0.0020 \\ -0.9899 & 0.9553 & 0.9996 & 1.0000 & 0.1411 & 0.2955 & 0.0300 & 0.0030 \end{bmatrix} embedding= 1.00000.54030.41610.98991.00000.99500.98010.95531.00000.99990.99980.99961.00001.00001.00001.00000.00000.84150.90930.14110.00000.09980.19870.29550.00000.01000.02000.03000.00000.00100.00200.0030

总结layer 的time_embedding操作:

通过三角函数,输出t时刻下,不同尺度的两类三角函数值: α cos ⁡ ( t ) + α sin ⁡ ( t ) \alpha\cos(t)+\alpha\sin(t) αcos(t)+αsin(t), α ∈ [ 1 , 0 , 10 , 001 , . . . ] \alpha \in [1, 0,1 0,001,...] α[1,0,10,001,...] (尺度数量由layer channels决定)

4.3.3 时序嵌入

将时序t的嵌入特征t_embedding广播到x的特征值,

由于middle层的特征通道数是原通道的2倍,因此这里用一个fc (noise_embedding) 将t_embedding特征映射为2倍,具体代码如下:

noise_embedding = nn.Linear(model_channels, model_channels*2) # noise block
middle = middle_block(x) # Middle block
noise_t = F.relu(self.noise_embedding(timestep_embedding(t,self.model_channels)))
middle = middle + noise_t[:, :, None, None]

4.4. 下采样块及总结

下采样块和上采样块类似,整型的U-Net结构总结如下:

  • 实现U-Net核心部分,即下采样块链接下采样块的等尺寸特征
  • 仅1个中间特征块 (middle block),且仅在该block处使用time_embedding 和 resnet
  • 未使用attention结构

完整代码如下:

class Upsample(nn.Module):
    def __init__(self, channels, num_groups=32):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.num_groups = num_groups

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest") #        # 上采样
        x = self.conv(x) #        # 卷积 + GroupNorm
        return x  # 激活函数

class Downsample(nn.Module):
    def __init__(self, channels, num_groups=32):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
        self.num_groups = num_groups

    def forward(self, x):
        x = self.conv(x) #卷积 + GroupNorm
        return x # 激活函数

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.shortcut = (nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity())

    def forward(self, x):
        h = F.relu(F.group_norm(self.conv1(x), num_groups=32)) # 第一层卷积 + GroupNorm + 激活
        h = F.relu(F.group_norm(self.conv2(h), num_groups=32))
        return h + self.shortcut(x)  # 残差连接

def timestep_embedding(t, dim, max_period=10000):
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim // 2, dtype=torch.float32) / (dim // 2)).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return embedding

class UNetModel(nn.Module):
    def __init__(self, io_channels=3, model_channels=128):
        super().__init__()
        self.model_channels = model_channels

        self.down_block1 = nn.Conv2d(io_channels, model_channels, kernel_size=3, padding=1) # down blocks
        self.down_block2  =  Downsample(model_channels)
        self.down_block3 = nn.Conv2d(model_channels, model_channels*2, kernel_size=3, padding=1)
        self.down_block4  =  Downsample(model_channels*2)

        self.middle_block = ResidualBlock(model_channels*2, model_channels*2) # middle block
        self.noise_embedding = nn.Linear(model_channels, model_channels*2) # noise block

        self.up_block1 = Upsample(model_channels*2) #  # up blocks
        self.up_block2 =  nn.Conv2d(model_channels*2, model_channels, kernel_size=3, padding=1)
        self.up_block3 = Upsample(model_channels)
        self.up_block4 =  nn.Conv2d(model_channels, io_channels, kernel_size=3, padding=1)
        
    def forward(self, x, t):
        x1 = F.relu(F.group_norm(self.down_block1(x), num_groups=32)) #Encode-Downsampling 
        x2 = F.relu(F.group_norm(self.down_block2(x1), num_groups=32))
        x3 = F.relu(F.group_norm(self.down_block3(x2), num_groups=32))
        x4 = F.relu(F.group_norm(self.down_block4(x3), num_groups=32))

        middle = self.middle_block(x4) # Middle block
        noise_t = F.relu(self.noise_embedding(timestep_embedding(t,self.model_channels)))
        middle = middle + noise_t[:, :, None, None]

        x5 = F.relu(F.group_norm(self.up_block1(middle + x4), num_groups=32)) # Decode-Upsampling
        x6 = F.relu(F.group_norm(self.up_block2(x5 + x3 ), num_groups=32)) # 
        x7 = F.relu(F.group_norm(self.up_block3(x6 + x2), num_groups=32)) #  
        out = self.up_block4(x7 + x1) #  

        return out

5. 条件生成 (Classifier-Free)

条件生成在生成对抗网络GAN中就有多种实现方式, 包括CGAN,ACGAN/PCGAN,InfoGAN等,通过语义标签生成对应内容,训练都需要将标签 label 作为输入。

大体可以分为三种方法:

  • 强监督:模型输出标签的预判,通过添加多标签分类损失训练
  • 弱监督:仅添加标签作为特征输入模型,嵌入图像特征(相加),但不改变训练过程(不额外输出、不添加损失函数)
  • 无监督:通过输入假标签(pseudo label), 这种标签是按某一特征规律自动生成的,且通常需要对应的损失函数(分类、聚类、对比学习等)

Classifier-Free是用的弱监督方法,即仅将标签作为特征嵌入进图像 x x x即可,这里用了最简单的方式,类似t的嵌入, 直接将label_embedding注入到U-Net的middle层,
再广播到每一个元素即可, 即:

x = x + t_embedding + label_embedding

由于label是离散序列,这里用embedding layer 而非 fc, 具体代码如下:

class UNetModel(nn.Module):
        ...
        self.down_block4  =  Downsample(model_channels*2)
        
        self.middle_block = ResidualBlock(model_channels*2, model_channels*2) # middle block
        self.noise_embedding = nn.Linear(model_channels, model_channels*2) # noise block
        self.class_emb = nn.Embedding(class_num, model_channels*2)

        self.up_block1 = Upsample(model_channels*2) #  # up blocks
        ...
        
    def forward(self, x, t, label=None):
        ...
        x4 = F.relu(F.group_norm(self.down_block4(x3), num_groups=32))

        middle = self.middle_block(x4) # Middle block
        noise_t = F.relu(self.noise_embedding(timestep_embedding(t,self.model_channels)))
        c_emb = F.relu(self.class_emb(label))
        middle = middle + noise_t[:, :, None, None] + c_emb[:, :, None, None]

        x5 = F.relu(F.group_norm(self.up_block1(middle + x4), num_groups=32)) # Decode-Upsampling
        ...
        out = self.up_block4(x7 + x1) #  

        return out

代码中的c_emb就是图像对应的label嵌入变量

6. 实验结果

本文测试了 MNIST, Fashion-MNIST, Cifar-10三个数据集

6.1 参数设置

  • MNIST, Fashion-MNIST

    • epoch = 200
    • timesteps =300
  • Cifar-10

    • epoch = 500
    • timesteps = 1000
    • 训练Cifar-10增大了模型层数(下采样、中间块、上采样各加2 cnn layers)

6.1 DDPM

  • MNIST
Image 1 Image 2
  • Fashion-MNIST
Image 1 Image 2
  • Cifar-10

由于Cifar-10数据集本身类别的特征差异较大(有飞机、青蛙、骑车。。),且图像质量不清晰(32x32),因此合成Cifar-10在GAN中一直是挑战,这也是扩散模型的亮点。

由于本文的模型较小,尽管这里训练Cifar-10增大了模型层数,但效果还是不佳,450epoch的结果如图所示:

Image 1 Image 2

6.2 DDIM

这里DDIM仅设置为50步, 实验发现最少10步可以就有效果,但效果不如DDPM。

Fashion-MNIST结果如下:

Image 1 Image 2

6.3 Classifier-Free

以下是增加Classifier-Free的改动代码后,输入label的条件生成结果(DDIM-10步去噪):

Image 2 Image 1

7.参考文献

7.1 本文代码

  • DDPM(不到300行):

https://github.com/disanda/GM/blob/main/DDPM-DDIM-ClassifierFree/ddpm.py

  • 预训练模型(300 timesteps的 MNIST 以及 FashionMNIST):

https://github.com/disanda/GM/tree/main/DDPM-DDIM-ClassifierFree/pre-trained-models

  • 条件生成以及Cifar网络加层

https://github.com/disanda/GM/tree/main/DDPM-DDIM-ClassifierFree

7.2 参考代码

  • https://github.com/ermongroup/ddim/blob/main/functions/denoising.py
  • https://github.com/LinXueyuanStdio/PyTorch-DDPM
  • https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm
  • https://github.com/BastianChen/ddpm-demo-pytorch
  • https://github.com/tatakai1/classifier_free_ddim/blob/main/Classifier_Free_DDIM_Mnist.ipynb

7.3 知乎讲解

  • https://zhuanlan.zhihu.com/p/666552214
  • https://zhuanlan.zhihu.com/p/656757576

7.4 原论文

  • https://arxiv.org/abs/2006.11239, Denoising Diffusion Probabilistic Models, 2020
  • https://arxiv.org/abs/2102.09672, Improved Denoising Diffusion Probabilistic Models, 2021
  • https://arxiv.org/abs/2010.02502, Denoising Diffusion Implicit Models, 2022
  • https://arxiv.org/abs/2207.12598, Classifier-Free Diffusion Guidance

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2286341.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【hot100】刷题记录(7)-除自身数组以外的乘积

题目描述: 给你一个整数数组 nums,返回 数组 answer ,其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法&#x…

鸢尾花书01---基本介绍和Jupyterlab的上手

文章目录 1.致谢和推荐2.py和.ipynb区别3.Jupyterlab的上手3.1入口3.2页面展示3.3相关键介绍3.4代码的运行3.5重命名3.6latex和markdown说明 1.致谢和推荐 这个系列是关于一套书籍,结合了python和数学,机器学习等等相关的理论,总结的7本书籍…

可扩展架构:如何打造一个善变的柔性系统?

系统的构成:模块 + 关系 我们天天和系统打交道,但你有没想过系统到底是什么?在我看来,系统内部是有明确结构 的,它可以简化表达为: 系统 = 模块 + 关系 在这里,模块是系统的基本组成部分,它泛指子系统、应用、服务或功能模块。关系指模块 之间的依赖关系,简单…

C++并发:C++内存模型和原子操作

C11引入了新的线程感知内存模型。内存模型精确定义了基础构建单元应当如何被运转。 1 内存模型基础 内存模型牵涉两个方面:基本结构和并发。 基本结构关系到整个程序在内存中的布局。 1.1 对象和内存区域 C的数据包括: 内建基本类型:int&…

宝塔mysql数据库容量限制_宝塔数据库mysql-bin.000001占用磁盘空间过大

磁盘空间占用过多,排查后发现网站/www/wwwroot只占用7G,/www/server占用却高达8G,再深入排查发现/www/server/data目录下的mysql-bin.000001和mysql-bin.000002两个日志文件占去了1.5G空间。 百度后学到以下知识,做个记录。 mysql…

2859.计算K置位下标对应元素的和

示例 1:输入:nums [5,10,1,5,2], k 1 输出:13 解释:下标的二进制表示是: 0 0002 1 0012 2 0102 3 0112 4 1002 下标 1、2 和 4 在其二进制表示中都存在 k 1 个置位。 因此,答案为 nums[1] nums[…

8. 网络编程

网络的基本概念 TCP/IP协议概述 OSI和TCP/IP模型 socket(套接字) 创建socket 字节序 字节序转换函数 通用地址结构 因特网地址结构 IPV4地址族和字符地址间的转换(点分十进制->网络字节序) 填写IPV4地址族结构案例 掌握TCP协议网络基础编程 相关函数 …

关于opencv环境搭建问题:由于找不到opencv_worldXXX.dll,无法执行代码,重新安装程序可能会解决此问题

方法一:利用复制黏贴方法 打开opencv文件夹目录找到\opencv\build\x64\vc15\bin 复制该目录下所有文件,找到C:\Windows\System32文件夹(注意一定是C盘)黏贴至该文件夹重新打开VS。 方法二:直接配置环境 打开opencv文…

Git Bash 配置 zsh

博客食用更佳 博客链接 安装 zsh 安装 Zsh 安装 Oh-my-zsh github仓库 sh -c "$(curl -fsSL https://install.ohmyz.sh/)"让 zsh 成为 git bash 默认终端 vi ~/.bashrc写入: if [ -t 1 ]; thenexec zsh fisource ~/.bashrc再重启即可。 更换主题 …

DeepSeek-R1 本地部署模型流程

DeepSeek-R1 本地部署模型流程 ***************************************************** 环境准备 操作系统:Windows11 内存:32GB RAM 存储:预留 300GB 可用空间 显存: 16G 网络: 100M带宽 ********************************************…

C++ unordered_map和unordered_set的使用,哈希表的实现

文章目录 unordered_map,unorder_set和map ,set的差异哈希表的实现概念直接定址法哈希冲突哈希冲突举个例子 负载因子将关键字转为整数哈希函数除法散列法/除留余数法 哈希冲突的解决方法开放定址法线性探测二次探测 开放定址法代码实现 哈希表的代码 un…

C#通过3E帧SLMP/MC协议读写三菱FX5U/Q系列PLC数据案例

C#通过3E帧SLMP/MC协议读写三菱FX5U/Q系列PLC数据案例,仅做数据读写报文测试。附带自己整理的SLMP/MC通讯协议表。 SLMP以太网读写PLC数据20191206/.vs/WindowsFormsApp7/v15/.suo , 73216 SLMP以太网读写PLC数据20191206/SLMP与MC协议3E帧通讯协议表.xlsx , 10382…

Unity|小游戏复刻|见缝插针1(C#)

准备 创建Scenes场景,Scripts脚本,Prefabs预制体文件夹 修改背景颜色 选中Main Camera 找到背景 选择颜色,一种白中透黄的颜色 创建小球 将文件夹里的Circle拖入层级里 选中Circle,位置为左右居中,偏上&…

数据结构的队列

一.队列 1.队列(Queue)的概念就是先进先出。 2.队列的用法,红色框和绿色框为两组,offer为插入元素,poll为删除元素,peek为查看元素红色的也是一样的。 3.LinkedList实现了Deque的接口,Deque又…

HTML-新浪新闻-实现标题-排版

标题排版 图片标签&#xff1a;<img> src&#xff1a;指定图片的url&#xff08;绝对路径/相对路径&#xff09; width&#xff1a;图片的宽度&#xff08;像素/相对于父元素的百分比&#xff09; heigth&#xff1a;图片的高度&#xff08;像素/相对于父元素的百分比&a…

C语言二级题解:查找字母以及其他字符个数、数字字符串转双精度值、二维数组上下三角区域数据对调

目录 一、程序填空题 --- 查找字母以及其他字符个数 题目 分析 二、程序修改 --- 数字字符串转双精度值 题目 分析 小数位字符串转数字 三、程序设计 --- 二维数组上下三角区域数据对调 题目 分析 前言 本文来讲解&#xff1a; 查找字母以及其他字符个数、数字字符串…

VPR概述、资源

SOTA网站&#xff1a; Visual Place Recognition | Papers With Code VPR&#xff08;Visual Place Recognition&#xff09; 是计算机视觉领域的一项关键任务&#xff0c;旨在通过图像匹配和分析来识别场景或位置。它的目标是根据视觉信息判断某个场景是否与数据库中的场景匹…

Electron学习笔记,安装环境(1)

1、支持win7的Electron 的版本是18&#xff0c;这里node.js用的是14版本&#xff08;node-v14.21.3-x86.msi&#xff09;云盘有安装包 Electron 18.x (截至2023年仍在维护中): Chromium: 96 Node.js: 14.17.0 2、安装node环境&#xff0c;node-v14.21.3-x86.msi双击运行选择安…

58.界面参数传递给Command C#例子 WPF例子

界面参数的传递&#xff0c;界面参数是如何从前台传送到后台的。 param 参数是从界面传递到命令的。这个过程通常涉及以下几个步骤&#xff1a; 数据绑定&#xff1a;界面元素&#xff08;如按钮&#xff09;的 Command 属性绑定到视图模型中的 RelayCommand 实例。同时&#x…

Git图形化工具【lazygit】

简要介绍一下偶然发现的Git图形化工具——「lazygit」 概述 Lazygit 是一个用 Go 语言编写的 Git 命令行界面&#xff08;TUI&#xff09;工具&#xff0c;它让 Git 操作变得更加直观和高效。 Github地址&#xff1a;https://github.com/jesseduffield/lazygit 主要特点 主要…