AIGC专栏1——Pytorch搭建DDPM实现图片生成

news2025/2/8 4:48:29

AIGC专栏1——Pytorch搭建DDPM实现图片生成

  • 学习前言
  • 源码下载地址
  • 网络构建
    • 一、什么是Diffusion
      • 1、加噪过程
      • 2、去噪过程
    • 二、DDPM网络的构建(Unet网络的构建)
    • 三、Diffusion的训练思路
  • 利用DDPM生成图片
    • 一、数据集的准备
    • 二、数据集的处理
    • 三、模型训练

学习前言

我又死了我又死了我又死了!
在这里插入图片描述

源码下载地址

https://github.com/bubbliiiing/ddpm-pytorch

喜欢的可以点个star噢。

网络构建

一、什么是Diffusion

在这里插入图片描述
如上图所示。DDPM模型主要分为两个过程:
1、Forward加噪过程(从右往左),数据集的真实图片中逐步加入高斯噪声,最终变成一个杂乱无章的高斯噪声,这个过程一般发生在训练的时候。加噪过程满足一定的数学规律。
2、Reverse去噪过程(从左往右),指对加了噪声的图片逐步去噪,从而还原出真实图片,这个过程一般发生在预测生成的时候。尽管在这里说的是加了噪声的图片,但实际去预测生成的时候,是随机生成一个高斯噪声来去噪。去噪的时候不断根据 X t X_t Xt的图片生成 X t − 1 X_{t-1} Xt1的噪声,从而实现图片的还原。

1、加噪过程

在这里插入图片描述
Forward加噪过程主要符合如下的公式:
x t = α t x t − 1 + 1 − α t z 1 x_t=\sqrt{\alpha_t} x_{t-1}+\sqrt{1-\alpha_t} z_{1} xt=αt xt1+1αt z1
其中 α t \sqrt{\alpha_t} αt 是预先设定好的超参数,被称为Noise schedule,通常是小于1的值,在论文中 α t \alpha_t αt的值从0.9999到0.998。 ϵ t − 1 ∼ N ( 0 , 1 ) \epsilon_{t-1} \sim N(0, 1) ϵt1N(0,1)是高斯噪声。由公式(1)迭代推导。

x t = a t ( a t − 1 x t − 2 + 1 − α t − 1 z 2 ) + 1 − α t z 1 = a t a t − 1 x t − 2 + ( a t ( 1 − α t − 1 ) z 2 + 1 − α t z 1 ) x_t=\sqrt{a_t}\left(\sqrt{a_{t-1}} x_{t-2}+\sqrt{1-\alpha_{t-1}} z_2\right)+\sqrt{1-\alpha_t} z_1=\sqrt{a_t a_{t-1}} x_{t-2}+\left(\sqrt{a_t\left(1-\alpha_{t-1}\right)} z_2+\sqrt{1-\alpha_t} z_1\right) xt=at (at1 xt2+1αt1 z2)+1αt z1=atat1 xt2+(at(1αt1) z2+1αt z1)

其中每次加入的噪声都服从高斯分布 z 1 , z 2 , … ∼ N ( 0 , 1 ) z_1, z_2, \ldots \sim \mathcal{N}(0, 1) z1,z2,N(0,1),两个高斯分布的相加高斯分布满足公式: N ( 0 , σ 1 2 ) + N ( 0 , σ 2 2 ) ∼ N ( 0 , ( σ 1 2 + σ 2 2 ) ) \mathcal{N}\left(0, \sigma_1^2 \right)+\mathcal{N}\left(0, \sigma_2^2 \right) \sim \mathcal{N}\left(0,\left(\sigma_1^2+\sigma_2^2\right) \right) N(0,σ12)+N(0,σ22)N(0,(σ12+σ22)),因此,得到 x t x_t xt的公式为:
x t = a t a t − 1 x t − 2 + 1 − α t α t − 1 z 2 x_t = \sqrt{a_t a_{t-1}} x_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} z_2 xt=atat1 xt2+1αtαt1 z2
因此不断往里面套,就能发现规律了,其实就是累乘
可以直接得出 x 0 x_0 x0 x t x_t xt的公式:
x t = α t ‾ x 0 + 1 − α t ‾ z t x_t=\sqrt{\overline{\alpha_t}} x_0+\sqrt{1-\overline{\alpha_t}} z_t xt=αt x0+1αt zt

其中 α t ‾ = ∏ i t α i \overline{\alpha_t}=\prod_i^t \alpha_i αt=itαi,这是随Noise schedule设定好的超参数, z t − 1 ∼ N ( 0 , 1 ) z_{t-1} \sim N(0, 1) zt1N(0,1)也是一个高斯噪声。通过上述两个公式,我们可以不断的将图片进行破坏加噪。

2、去噪过程

在这里插入图片描述
反向过程就是通过估测噪声,多次迭代逐渐将被破坏的 x t x_t xt恢复成 x 0 x_0 x0,在恢复时刻,我们已经知道的是 x t x_t xt,这是图片在 t t t时刻的噪声图。一下子从 x t x_t xt恢复成 x 0 x_0 x0是不可能的,我们只能一步一步的往前推,首先从 x t x_t xt恢复成 x t − 1 x_{t-1} xt1。根据贝叶斯公式,已知 x t x_t xt反推 x t − 1 x_{t-1} xt1
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q\left(x_{t-1} \mid x_t, x_0\right)=q\left(x_t \mid x_{t-1}, x_0\right) \frac{q\left(x_{t-1} \mid x_0\right)}{q\left(x_t \mid x_0\right)} q(xt1xt,x0)=q(xtxt1,x0)q(xtx0)q(xt1x0)
右边的三个东西都可以从x_0开始推得到:
q ( x t − 1 ∣ x 0 ) = a ˉ t − 1 x 0 + 1 − a ˉ t − 1 z ∼ N ( a ˉ t − 1 x 0 , 1 − a ˉ t − 1 ) q\left(x_{t-1} \mid x_0\right)=\sqrt{\bar{a}_{t-1}} x_0+\sqrt{1-\bar{a}_{t-1}} z \sim \mathcal{N}\left(\sqrt{\bar{a}_{t-1}} x_0, 1-\bar{a}_{t-1}\right) q(xt1x0)=aˉt1 x0+1aˉt1 zN(aˉt1 x0,1aˉt1)
q ( x t ∣ x 0 ) = a ˉ t x 0 + 1 − α ˉ t z ∼ N ( a ˉ t x 0 , 1 − α ˉ t ) q\left(x_t \mid x_0\right) = \sqrt{\bar{a}_t} x_0+\sqrt{1-\bar{\alpha}_t} z \sim \mathcal{N}\left(\sqrt{\bar{a}_t} x_0 , 1-\bar{\alpha}_t\right) q(xtx0)=aˉt x0+1αˉt zN(aˉt x0,1αˉt)
q ( x t ∣ x t − 1 , x 0 ) = a t x t − 1 + 1 − α t z ∼ N ( a t x t − 1 , 1 − α t ) q\left(x_t \mid x_{t-1}, x_0\right)=\sqrt{a_t} x_{t-1}+\sqrt{1-\alpha_t} z \sim \mathcal{N}\left(\sqrt{a_t} x_{t-1}, 1-\alpha_t\right) \\ q(xtxt1,x0)=at xt1+1αt zN(at xt1,1αt)
因此,由于右边三个东西均满足正态分布, q ( x t − 1 ∣ x t , x 0 ) q\left(x_{t-1} \mid x_t, x_0\right) q(xt1xt,x0)满足分布如下:
∝ exp ⁡ ( − 1 2 ( ( x t − α t x t − 1 ) 2 β t + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) ) \propto \exp \left(-\frac{1}{2}\left(\frac{\left(x_t-\sqrt{\alpha_t} x_{t-1}\right)^2}{\beta_t}+\frac{\left(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}} x_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(x_t-\sqrt{\bar{\alpha}_t} x_0\right)^2}{1-\bar{\alpha}_t}\right)\right) exp(21(βt(xtαt xt1)2+1αˉt1(xt1αˉt1 x0)21αˉt(xtαˉt x0)2))
把标准正态分布展开后,乘法就相当于加,除法就相当于减,把他们汇总
接下来继续化简,咱们现在要求的是上一时刻的分布
∝ exp ⁡ ( − 1 2 ( ( x t − α t x t − 1 ) 2 β t + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) ) = exp ⁡ ( − 1 2 ( x t 2 − 2 α t x t x t − 1 + α t x t − 1 2 β t + x t − 1 2 − 2 α ˉ t − 1 x 0 x t − 1 + α ˉ t − 1 x 0 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) ) = exp ⁡ ( − 1 2 ( ( α t β t + 1 1 − α ˉ t − 1 ) x t − 1 2 − ( 2 α t β t x t + 2 α ˉ t − 1 1 − α ˉ t − 1 x 0 ) x t − 1 + C ( x t , x 0 ) ) ) \begin{aligned} & \propto \exp \left(-\frac{1}{2}\left(\frac{\left(x_t-\sqrt{\alpha_t} x_{t-1}\right)^2}{\beta_t}+\frac{\left(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}} x_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(x_t-\sqrt{\bar{\alpha}_t} x_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\frac{x_t^2-2 \sqrt{\alpha_t} x_t x_{t-1}+\alpha_t x_{t-1}^2}{\beta_t}+\frac{x_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}} x_0 x_{t-1}+\bar{\alpha}_{t-1} x_0^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(x_t-\sqrt{\bar{\alpha}_t} x_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) x_{t-1}^2-\left(\frac{2 \sqrt{\alpha_t}}{\beta_t} x_t+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} x_0\right) x_{t-1}+C\left(x_t, x_0\right)\right)\right) \end{aligned} exp(21(βt(xtαt xt1)2+1αˉt1(xt1αˉt1 x0)21αˉt(xtαˉt x0)2))=exp(21(βtxt22αt xtxt1+αtxt12+1αˉt1xt122αˉt1 x0xt1+αˉt1x021αˉt(xtαˉt x0)2))=exp(21((βtαt+1αˉt11)xt12(βt2αt xt+1αˉt12αˉt1 x0)xt1+C(xt,x0)))
正态分布满足公式, exp ⁡ ( − ( x − μ ) 2 2 σ 2 ) = exp ⁡ ( − 1 2 ( 1 σ 2 x 2 − 2 μ σ 2 x + μ 2 σ 2 ) ) \exp \left(-\frac{(x-\mu)^2}{2 \sigma^2}\right)=\exp \left(-\frac{1}{2}\left(\frac{1}{\sigma^2} x^2-\frac{2 \mu}{\sigma^2} x+\frac{\mu^2}{\sigma^2}\right)\right) exp(2σ2(xμ)2)=exp(21(σ21x2σ22μx+σ2μ2)),其中 σ \sigma σ就是方差, μ \mu μ就是均值,配方后我们就可以获得均值和方差。

此时的均值为: μ ~ t ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \tilde{\mu}_t\left(x_t, x_0\right)=\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} x_t+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} x_0 μ~t(xt,x0)=1αˉtαt (1αˉt1)xt+1αˉtαˉt1 βtx0。根据之前的公式, x t = α t ‾ x 0 + 1 − α t ‾ z t x_t=\sqrt{\overline{\alpha_t}} x_0+\sqrt{1-\overline{\alpha_t}} z_t xt=αt x0+1αt zt,我们可以使用 x t x_t xt反向估计 x 0 x_0 x0得到 x 0 x_0 x0满足分布 x 0 = 1 α ˉ t ( x t − 1 − α ˉ t z t ) x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(\mathrm{x}_t-\sqrt{1-\bar{\alpha}_t} z_t\right) x0=αˉt 1(xt1αˉt zt)。最终得到均值为 μ ~ t = 1 a t ( x t − β t 1 − a ˉ t z t ) \tilde{\mu}_t=\frac{1}{\sqrt{a_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{a}_t}} z_t\right) μ~t=at 1(xt1aˉt βtzt) z t z_t zt代表t时刻的噪音是什么。由 z t z_t zt无法直接获得,网络便通过当前时刻的 x t x_t xt经过神经网络计算 z t z_t zt ϵ θ ( x t , t ) \epsilon_\theta\left(x_t, t\right) ϵθ(xt,t)也就是上面提到的 z t z_t zt ϵ θ \epsilon_\theta ϵθ代表神经网络。
x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z x_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta\left(x_t, t\right)\right)+\sigma_t z xt1=αt 1(xt1αˉt 1αtϵθ(xt,t))+σtz
由于加噪过程中的真实噪声 ϵ \epsilon ϵ在复原过程中是无法获得的,因此DDPM的关键就是训练一个由 x t x_t xt t t t估测橾声的模型 ϵ θ ( x t , t ) \epsilon_\theta\left(x_t, t\right) ϵθ(xt,t),其中 θ \theta θ就是模型的训练参数, σ t \sigma_t σt 也是一个高斯噪声 σ t ∼ N ( 0 , 1 ) \sigma_t \sim N(0,1) σtN(0,1),用于表示估测与实际的差距。在DDPM中,使用U-Net作为估测噪声的模型。

本质上,我们就是训练这个Unet模型,该模型输入为 x t x_t xt t t t,输出为 x t x_t xt时刻的高斯噪声。即利用 x t x_t xt t t t预测这一时刻的高斯噪声。这样就可以一步一步的再从噪声回到真实图像。

二、DDPM网络的构建(Unet网络的构建)

在这里插入图片描述
上图是典型的Unet模型结构,仅仅作为示意图,里面具体的数字同学们无需在意,和本文的学习无关。在本文中,Unet的输入和输出shape相同,通道均为3(一般为RGB三通道),宽高相同。

本质上,DDPM最重要的工作就是训练Unet模型,该模型输入为 x t x_t xt t t t,输出为 x t − 1 x_{t-1} xt1时刻的高斯噪声。即利用 x t x_t xt t t t预测上一时刻的高斯噪声。这样就可以一步一步的再从噪声回到真实图像。

假设我们需要生成一个[64, 64, 3]的图像,在 t t t时刻,我们有一个 x t x_t xt噪声图,该噪声图的的shape也为[64, 64, 3],我们将它和 t t t一起输入到Unet中。Unet的输出为 x t − 1 x_{t-1} xt1时刻的[64, 64, 3]的噪声。

实现代码如下,代码中的特征提取模块为残差结构,方便优化:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_norm(norm, num_channels, num_groups):
    if norm == "in":
        return nn.InstanceNorm2d(num_channels, affine=True)
    elif norm == "bn":
        return nn.BatchNorm2d(num_channels)
    elif norm == "gn":
        return nn.GroupNorm(num_groups, num_channels)
    elif norm is None:
        return nn.Identity()
    else:
        raise ValueError("unknown normalization type")
    
#------------------------------------------#
#   计算时间步长的位置嵌入。
#   一半为sin,一半为cos。
#------------------------------------------#
class PositionalEmbedding(nn.Module):
    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device      = x.device
        half_dim    = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # x * self.scale和emb外积
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

#------------------------------------------#
#   下采样层,一个步长为2x2的卷积
#------------------------------------------#
class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)
    
    def forward(self, x, time_emb, y):
        if x.shape[2] % 2 == 1:
            raise ValueError("downsampling tensor height should be even")
        if x.shape[3] % 2 == 1:
            raise ValueError("downsampling tensor width should be even")

        return self.downsample(x)

#------------------------------------------#
#   上采样层,Upsample+卷积
#------------------------------------------#
class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
        )
        
    def forward(self, x, time_emb, y):
        return self.upsample(x)

#------------------------------------------#
#   使用Self-Attention注意力机制
#   做一个全局的Self-Attention
#------------------------------------------#
class AttentionBlock(nn.Module):
    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()
        
        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w  = x.shape
        q, k, v     = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)

        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)

        attention   = torch.softmax(dot_products, dim=-1)
        out         = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)
        out         = out.view(b, h, w, c).permute(0, 3, 1, 2)

        return self.to_out(out) + x
    
#------------------------------------------#
#   用于特征提取的残差结构
#------------------------------------------#
class ResidualBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=F.relu,
        norm="gn", num_groups=32, use_attention=False,
    ):
        super().__init__()

        self.activation = activation

        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout), 
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.time_bias  = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

        self.residual_connection    = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.attention              = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)
    
    def forward(self, x, time_emb=None, y=None):
        out = self.activation(self.norm_1(x))
        # 第一个卷积
        out = self.conv_1(out)
        
        # 对时间time_emb做一个全连接,施加在通道上
        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            out += self.time_bias(self.activation(time_emb))[:, :, None, None]

        # 对种类y_emb做一个全连接,施加在通道上
        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")

            out += self.class_bias(y)[:, :, None, None]

        out = self.activation(self.norm_2(out))
        # 第二个卷积+残差边
        out = self.conv_2(out) + self.residual_connection(x)
        # 最后做个Attention
        out = self.attention(out)
        return out

#------------------------------------------#
#   Unet模型
#------------------------------------------#
class UNet(nn.Module):
    def __init__(
        self, img_channels, base_channels=128, channel_mults=(1, 2, 2, 2),
        num_res_blocks=2, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=F.silu,
        dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,
    ):
        super().__init__()
        # 使用到的激活函数,一般为SILU
        self.activation = activation
        # 是否对输入进行padding
        self.initial_pad = initial_pad
        # 需要去区分的类别数
        self.num_classes = num_classes
        
        # 对时间轴输入的全连接层
        self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None
    
        # 对输入图片的第一个卷积
        self.init_conv  = nn.Conv2d(img_channels, base_channels, 3, padding=1)

        # self.downs用于存储下采样用到的层,首先利用ResidualBlock提取特征
        # 然后利用Downsample降低特征图的高宽
        self.downs      = nn.ModuleList()
        self.ups        = nn.ModuleList()
        
        # channels指的是每一个模块处理后的通道数
        # now_channels是一个中间变量,代表中间的通道数
        channels        = [base_channels]
        now_channels    = base_channels
        for i, mult in enumerate(channel_mults):
            out_channels = base_channels * mult
            for _ in range(num_res_blocks):
                self.downs.append(
                    ResidualBlock(
                        now_channels, out_channels, dropout,
                        time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                        norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                    )
                )
                now_channels = out_channels
                channels.append(now_channels)
            
            if i != len(channel_mults) - 1:
                self.downs.append(Downsample(now_channels))
                channels.append(now_channels)

        # 可以看作是特征整合,中间的一个特征提取模块
        self.mid = nn.ModuleList(
            [
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=True,
                ),
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, 
                    norm=norm, num_groups=num_groups, use_attention=False,
                ),
            ]
        )

        # 进行上采样,进行特征融合
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_channels = base_channels * mult

            for _ in range(num_res_blocks + 1):
                self.ups.append(ResidualBlock(
                    channels.pop() + now_channels, out_channels, dropout, 
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, 
                    norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels
            
            if i != 0:
                self.ups.append(Upsample(now_channels))
        
        assert len(channels) == 0
        
        self.out_norm = get_norm(norm, base_channels, num_groups)
        self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)
    
    def forward(self, x, time=None, y=None):
        # 是否对输入进行padding
        ip = self.initial_pad
        if ip != 0:
            x = F.pad(x, (ip,) * 4)

        # 对时间轴输入的全连接层
        if self.time_mlp is not None:
            if time is None:
                raise ValueError("time conditioning was specified but tim is not passed")
            time_emb = self.time_mlp(time)
        else:
            time_emb = None
        
        if self.num_classes is not None and y is None:
            raise ValueError("class conditioning was specified but y is not passed")
        
        # 对输入图片的第一个卷积
        x = self.init_conv(x)

        # skips用于存放下采样的中间层
        skips = [x]
        for layer in self.downs:
            x = layer(x, time_emb, y)
            skips.append(x)
        
        # 特征整合与提取
        for layer in self.mid:
            x = layer(x, time_emb, y)
        
        # 上采样并进行特征融合
        for layer in self.ups:
            if isinstance(layer, ResidualBlock):
                x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, time_emb, y)

        # 上采样并进行特征融合
        x = self.activation(self.out_norm(x))
        x = self.out_conv(x)
        
        if self.initial_pad != 0:
            return x[:, :, ip:-ip, ip:-ip]
        else:
            return x

三、Diffusion的训练思路

Diffusion的训练思路比较简单,首先随机给每个batch里每张图片都生成一个t,代表我选择这个batch里面第t个时刻的噪声进行拟合。代码如下:

t = torch.randint(0, self.num_timesteps, (b,), device=device)

生成batch_size个噪声,计算施加这个噪声后模型在t个时刻的噪声图片是怎么样的,如下所示:

def perturb_x(self, x, t, noise):
    return (
        extract(self.sqrt_alphas_cumprod, t,  x.shape) * x +
        extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
    )   

def get_losses(self, x, t, y):
    # x, noise [batch_size, 3, 64, 64]
    noise           = torch.randn_like(x)

    perturbed_x     = self.perturb_x(x, t, noise)

之后利用这个噪声图片、t和网络模型计算预测噪声,利用预测噪声和实际噪声进行拟合。

def get_losses(self, x, t, y):
    # x, noise [batch_size, 3, 64, 64]
    noise           = torch.randn_like(x)

    perturbed_x     = self.perturb_x(x, t, noise)
    estimated_noise = self.model(perturbed_x, t, y)

    if self.loss_type == "l1":
        loss = F.l1_loss(estimated_noise, noise)
    elif self.loss_type == "l2":
        loss = F.mse_loss(estimated_noise, noise)
    return loss

利用DDPM生成图片

DDPM的库整体结构如下:
在这里插入图片描述

一、数据集的准备

在训练前需要准备好数据集,数据集保存在datasets文件夹里面。
在这里插入图片描述

二、数据集的处理

打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。
此时生成根目录下面的train_lines.txt。
在这里插入图片描述

三、模型训练

在完成数据集处理后,运行train.py即可开始训练。
在这里插入图片描述
训练过程中,可在results文件夹内查看训练效果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1336213.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

iOS设备信息详解

文章目录 ID 体系iOS设备信息详解IDFA介绍特点IDFA新政前世今生获取方式 IDFV介绍获取方式 UUID介绍特点获取方式 UDID介绍获取方式 OpenUDID介绍 Bundle ID介绍分类其他 IP地址介绍获取方式 MAC地址介绍获取方式正常获取MAC地址获取对应Wi-Fi的MAC地址 系统版本获取方式 设备型…

Java基于TCP网络编程的群聊功能

服务端 import java.net.ServerSocket; import java.net.Socket; import java.util.ArrayList; import java.util.List;public class Server2 {public static List<Socket> onlineList new ArrayList<>();public static void main(String[] args) throws Except…

在做题中学习:三数之和

15. 三数之和 - 力扣&#xff08;LeetCode&#xff09;15. 三数之和 - 力扣&#xff08;LeetCode&#xff09; 解释&#xff1a;不能重复也就是说不能和前一个三元组的元素完全相同 思路&#xff1a;通过做 两数之和那道题 可以想到&#xff1a; 1.先排序 2.双指针法 3.固定…

分布式核心技术之分布式锁

文章目录 为什么要使用分布锁&#xff1f;分布式锁的三种实现方法基于数据库实现分布式锁基于缓存实现分布式锁基于 ZooKeeper 实现分布式锁知识扩展&#xff1a;如何解决分布式锁的羊群效应问题&#xff1f; 三种实现方式对比 分布式互斥&#xff0c;领悟了其“有你没我&#…

解决 Solidworks2021 报错(-15,10032,0)错误记录

Solidworks2021 报错"-15,10032,0"错误记录 如图所示解决方案步骤1步骤2 个人问题我的没法添加白名单&#xff0c;要是有能解决的大神给个解决方式感激不尽&#xff01;&#xff01; 如图所示 解决方案 步骤1 该问题的解决方式仅对个人有效&#xff0c;不一定通用&…

非对称加密与对称加密的区别是什么?

在数据通信中&#xff0c;加密技术是防止数据被未授权的人访问的关键措施之一。而对称加密和非对称加密是两种最常见的加密技术&#xff0c;它们被广泛应用于数据安全领域&#xff0c;并且可以组合起来以达到更好的加密效果。本文将探讨这两种技术的区别&#xff0c;以及它们在…

C#示例(一):飞行棋游戏

1、先看一下实现效果 输入连个玩家的姓名 两个玩家分别用字母A和字母B表示 按下任意键开始掷骰子、根据骰子走对应的步数… 2、绘制游戏头 /// <summary>/// 画游戏头/// </summary>public static void GameShow(){Console.ForegroundColor ConsoleColor.Blu…

CEEMDAN +组合预测模型(BiLSTM-Attention + ARIMA)

目录 往期精彩内容&#xff1a; 前言 1 风速数据CEEMDAN分解与可视化 1.1 导入数据 1.2 CEEMDAN分解 2 数据集制作与预处理 2.1 划分数据集&#xff0c;按照8&#xff1a;2划分训练集和测试集&#xff0c; 然后再按照前7后4划分分量数据 2.2 设置滑动窗口大小为7&#…

SuperMap iClient3D for WebGL时序影像

文章目录 前言一、加载影像数据二、创建时间条1.这里使用Echarts来创建TimeLine&#xff0c;首先需要引入相关依赖2.初始化Echarts实例 三、设置不同年份影像交替显示四、效果 前言 时序影像可以用于对地球表面的变化进行定量分析和监测。 通过对多时相遥感影像的比较和分析&a…

【开源】基于Vue+SpringBoot的新能源电池回收系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用户档案模块2.2 电池品类模块2.3 回收机构模块2.4 电池订单模块2.5 客服咨询模块 三、系统设计3.1 用例设计3.2 业务流程设计3.3 E-R 图设计 四、系统展示五、核心代码5.1 增改电池类型5.2 查询电池品类5.3 查询电池回…

铁山靠之——HarmonyOS组件 - 2.0

HarmonyOS学习第二章 一、HarmonyOS基础组件的使用1.1 组件介绍1.2 Text1.2.1 文本样式1.2.2 设置文本对齐方式1.2.3 设置文本超长显示1.2.4 设置文本装饰线 1.3 Image1.3.1 设置缩放类型1.3.2 加载网络图片 1.4 TextInput1.4.1 设置输入提示文本1.4.2 设置输入类型1.4.3 设置光…

了解基础魔法函数学会封装和继承新建模块和函数使用异常

一、魔法函数 1.1、概念&#xff1a; 魔法函数&#xff08;magic methods&#xff09;是指以双下划线开头和结尾的特殊方法&#xff0c;用于实现对象的特定行为和操作。这些魔法函数可以让我们自定义对象的行为&#xff0c;例如实现对象的比较、算术运算、属性访问等。常见的…

WPS复选框里打对号,显示小太阳或粗黑圆圈的问题解决方法

问题描述 WPS是时下最流行的字处理软件之一&#xff0c;是目前唯一可以和微软office办公套件相抗衡的国产软件。然而&#xff0c;在使用WPS的过程中也会出现一些莫名其妙的错误&#xff0c;如利用WPS打开docx文件时&#xff0c;如果文件包含复选框&#xff0c;经常会出…

博客摘录「 Apollo安装和基本使用」2023年11月27日

一、常见配置中心对比 Spring Cloud Config: https://github.com/spring-cloud/spring-cloud-configApollo: https://github.com/ctripcorp/apolloNacos: https://github.com/alibaba/nacos 对比项目/配置中心 spring cloud config apollo nacos(重点) 开源时间 2014.9 …

blender scripting 编写

blender scripting 编写 一、查看ui按钮对应的代码二、查看或修改对象名称三、案例&#xff1a;渲染多张图片并导出对应的相机参数 一、查看ui按钮对应的代码 二、查看或修改对象名称 三、案例&#xff1a;渲染多张图片并导出对应的相机参数 注&#xff1a;通过ui交互都设置好…

如何在Window系统下搭建Nginx服务器环境并部署前端项目

1.下载并安装Nginx 在nginx官网nginx: download 下载稳定版本至自己想要的目录。 解压后进入目录 2.启动Nginx服务器 启动方式有两种&#xff1a; &#xff08;1&#xff09;直接进入nginx安装目录下&#xff0c;双击nginx.exe运行&#xff0c;此时命令行窗口一闪而过&…

20231222给NanoPC-T4(RK3399)开发板的适配Android11的挖掘机方案并跑通AP6398SV

20231222给NanoPC-T4(RK3399)开发板的适配Android11的挖掘机方案并跑通AP6398SV 2023/12/22 7:54 简略步骤&#xff1a;rootrootrootroot-X99-Turbo:~/3TB$ cat Android11.0.tar.bz2.a* > Android11.0.tar.bz2 rootrootrootroot-X99-Turbo:~/3TB$ tar jxvf Android11.0.tar.…

【MySQL】数据库中为什么使用B+树不用B树

&#x1f34e;个人博客&#xff1a;个人主页 &#x1f3c6;个人专栏&#xff1a; 数 据 库 ⛳️ 功不唐捐&#xff0c;玉汝于成 目录 前言 正文 B树的特点和应用场景&#xff1a; B树相对于B树的优势&#xff1a; 结论&#xff1a; 结语 我的其他博客 前言 在数据…

超维空间S2无人机使用说明书——31、使用yolov8进行目标识别

引言&#xff1a;为了提高yolo识别的质量&#xff0c;提高了yolo的版本&#xff0c;改用yolov8进行物体识别&#xff0c;同时系统兼容了低版本的yolo&#xff0c;包括基于C的yolov3和yolov4&#xff0c;以及yolov7。 简介&#xff0c;为了提高识别速度&#xff0c;系统采用了G…

C# WPF上位机开发(子窗口通知父窗口更新进度)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 这两天在编写代码的时候&#xff0c;正好遇到一个棘手的问题&#xff0c;解决之后感觉挺有意义的&#xff0c;所以先用blog记录一下&#xff0c;后…