从开始实现扩散概率模型 PyTorch 实现

news2024/12/15 9:26:56

目录

一、说明

二、从头开始实施

三、线性噪声调度器

四、时间嵌入

五、下层DownBlock类块

六、中间midBlock类块

七、UpBlock上层类块

八、UNet 架构

九、训练

十、采样

十一、配置(Default.yaml)

十二、数据集 (MNIST)


keyword:  Diffusion Probabilistic Models 

一、说明

        扩散过程由前向阶段组成,其中图像通过在每个步骤中添加高斯噪声逐渐损坏。经过许多步骤后,图像实际上变得与从正态分布中采样的随机噪声无法区分。这是通过在每个时间步骤 xₜ 应用过渡函数来实现的,其中 β 表示在 t-1 时添加到图像中的预定噪声量,以产生 t 时的图像。

        在前面的讨论中,我们确定设置 α=1−β 并计算每个时间步骤中这些 α 值的累积乘积,使我们能够在任何给定步骤 t 直接从原始图像过渡到噪声版本。在反向过程中,模型被训练以近似反向分布。由于正向和反向过程都是高斯的,因此目标是让模型预测反向分布的均值和方差。

        通过详细的推导,从最大化观测数据的对数似然性这一目标出发,我们得出需要最小化真实去噪分布(以 x₀ 为条件)与模型预测分布之间的 KL 散度(以特定均值和方差为特征)。方差固定为与目标分布的方差匹配,而均值则以相同形式重写。最小化 KL 散度简化为最小化预测噪声与实际噪声样本之间的平方差。

训练过程包括对图像进行采样、选择时间步长 t,以及添加从正态分布中采样的噪声。然后将 t 处的噪声图像传递给模型。从噪声时间表得出的累积乘积项确定随时间增加的噪声。损失函数是原始噪声样本与模型预测之间的均方误差 (MSE)。

二、从头开始实施

        对于图像生成,我们从学习到的反向分布中进行采样,从正态分布中的随机噪声样本 xₜ 开始。使用与 xₜ 和预测噪声相同的公式计算平均值,方差与地面真实去噪分布相匹配。使用重新参数化技巧,我们反复从这个反向分布中采样以生成 x₀。在 x₀ 处,没有添加额外的噪声;相反,平均值直接作为最终输出返回。

        为了实现扩散过程,我们需要处理正向和反向阶段的计算。我们将创建一个噪声调度程序来管理这些任务。在正向过程中,给定一个图像、一个噪声样本和一个时间步长 t,调度程序将使用正向方程返回图像的噪声版本。为了优化效率,它将预先计算并存储 α(1−β) 的值以及所有时间步长中 α 的累积乘积。

        作者采用了线性噪声调度,其中 β 在 1,000 个时间步骤内从 1×10⁻⁴ 线性缩放到 0.02。调度程序还处理反向过程:给定 xt 和模型预测的噪声,它将通过从反向分布中采样来计算 xₜ₋₁。这涉及使用各自的方程计算均值和方差,并通过重新参数化技巧生成样本。

        为了支持这些计算,调度程序还将存储 1-αₜ、1-累积乘积项以及该项的平方根的预先计算的值。

三、线性噪声调度器

import torch


class LinearNoiseScheduler:

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

使用传递给此类的参数初始化所有参数后,我们将定义 β 值从起始范围到结束范围线性增加,确保 βₜ 从 0 进展到最后的时间步骤。接下来,我们将设置正向和反向过程方程所需的所有变量。

  def add_noise(self, original, noise, t):

        original_shape = original.shape
        batch_size = original_shape[0]
        
        sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
        sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
        
        # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
        for _ in range(len(original_shape) - 1):
            sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
        for _ in range(len(original_shape) - 1):
            sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
        
        # Apply and Return Forward process equation
        return (sqrt_alpha_cum_prod.to(original.device) * original
                + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)

add_noise()函数表示正向过程。它以原始图像、噪声样本和时间步长 ttt 作为输入。图像和噪声的维度为 b×h×w,而时间步长为大小为 b 的一维张量。对于正向过程,我们计算给定时间步长的累积乘积项的平方根和 1-累积乘积项。这些值被重新整形为维度 b×1×1×1。最后,我们应用正向过程方程来生成噪声图像。

    def sample_prev_timestep(self, xt, noise_pred, t):

        x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
              torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
        x0 = torch.clamp(x0, -1., 1.)
        
        mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
        mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
        
        if t == 0:
            return mean, x0
        else:
            variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
            variance = variance * self.betas.to(xt.device)[t]
            sigma = variance ** 0.5
            z = torch.randn(xt.shape).to(xt.device)
            
            return mean + sigma * z, x0

        调度程序类中的下一个函数处理反向过程。它使用噪声图像 xₜ、模型的噪声预测和时间步长 t 作为输入,从学习到的反向分布中生成样本。我们保存原始图像预测 x₀​ 以供可视化,它是通过重新排列正向过程方程以使用噪声预测而不是实际噪声来计算 x₀ 获得的。

        对于逆向过程中的采样,我们使用逆均值方程计算均值。在 t=0 时,我们只需返回均值。对于其他时间步骤,噪声会添加到均值中,方差与以 x₀​ 为条件的地面真实去噪分布的方差相同。最后,我们使用计算出的均值和方差从高斯分布中采样,应用重新参数化技巧来生成结果。

        这样就完成了噪声调度程序,它管理添加噪声的正向过程和采样的反向过程。对于扩散模型,我们可以灵活地选择任何架构,只要它满足两个关键要求。第一,输入和输出形状必须相同,第二,必须有一种方法可以整合时间步长信息。

作者图片

        无论是在训练期间还是采样期间,时间步长信息始终是可访问的。包含此信息有助于模型更好地预测原始噪声,因为它表明输入图像中有多少是噪声。我们不仅向模型提供图像,还提供相应的时间步长。

        对于模型架构,我们将使用 UNet,这也是原作者的选择。为了确保一致性,我们将复制 Hugging Face 的 Diffusers 管道中使用的稳定扩散 UNet 中实现的块、激活、规范化和其他组件的精确规格。

作者图片

        时间步长由时间嵌入块处理,该块采用大小为b(批次大小)的时间步长的一维张量,并输出批次中每个时间步长的大小为t_emb_dim的表示。此块首先通过嵌入空间将整数时间步长转换为矢量表示。然后,此嵌入通过中间带有激活函数的两个线性层,产生最终的时间步长表示。对于嵌入空间,作者使用了 Transformers 中常用的正弦位置嵌入方法。在整个架构中,使用的激活函数是 S 形线性单元 (SiLU),但也可以选择其他激活函数。

作者图片

        UNet架构遵循简单的编码器-解码器设计。编码器由多个下采样块组成,每个块都会减少输入的空间维度(通常减半),同时增加通道数量。最终下采样块的输出由中间块的几层处理,所有层都以相同的空间分辨率运行。随后,解码器采用上采样块,逐步增加空间维度并减少通道数量,最终匹配原始输入大小。在解码器中,上采样块通过残差跳过连接以相同的分辨率集成相应下采样块的输出。虽然大多数扩散模型都遵循这种通用的 UNet 架构,但它们在各个块内的具体细节和配置上有所不同。

作者图片

        大多数变体中的下行块通常由ResNet 块、后跟自注意力块和下采样层组成。每个 ResNet 块都使用一系列操作构建:组归一化、激活层和卷积层。此序列的输出将通过另一组归一化、激活和卷积层。通过将第一个归一化层的输入与第二个卷积层的输出相结合来添加残差连接。这个完整的序列形成ResNet 块,可以将其视为通过残差连接连接的两个卷积块。

        在 ResNet 块之后,有一个规范化步骤、一个自注意力层和另一个残差连接。虽然模型通常使用多个 ResNet 层和自注意力层,但为简单起见,我们的实现将只使用每个层的一层。

        为了整合时间信息,每个 ResNet 块都包含一个激活层,后面跟着一个线性层,用于处理时间嵌入表示。时间嵌入表示为大小为t_emb_dim的张量,通过此线性层将其投影到与卷积层输出具有相同大小和通道数的张量中。这样就可以通过在空间维度上复制时间步长表示,将时间嵌入添加到卷积层的输出中。

作者图片

        另外两个块使用相同的组件,只是略有不同。上块完全相同,只是它首先将输入上采样为两倍空间大小,然后在整个通道维度上集中相同空间分辨率的下块输出。然后我们有相同的 resnet 层和自注意力块。中间块的层始终将输入保持为相同的空间分辨率。hugging face 版本首先有一个 resnet 块,然后是自注意力层和 resnet 层。对于这些 resnet 块中的每一个,我们都有一个时间步长投影层。现有的时间步长表示会经过这些块,然后被添加到 resnet 的第一个卷积层的输出中。

四、时间嵌入

import torch
import torch.nn as nn


def get_time_embedding(time_steps, temb_dim):

    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    
    # factor = 10000^(2i/d_model)
    factor = 10000 ** ((torch.arange(
        start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
    )
    
    # pos / factor
    # timesteps B -> B, 1 -> B, temb_dim
    t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb

第一个函数为给定的时间步长get_time_embedding生成时间嵌入。它受到 Transformer 模型中使用的正弦位置嵌入的启发。
time_steps:时间步长值的张量(形状:[B]其中B是批次大小)。每个值代表批次元素的一个离散时间步长。
temb_dim:时间嵌入的维数。这决定了每个时间步长的生成嵌入的大小。

确保这temb_dim是均匀的,因为正弦嵌入需要将嵌入分成两半,分别表示正弦和余弦分量。无缝扩展以处理任何批量大小或嵌入维度。

五、下层DownBlock类块

class DownBlock(nn.Module):

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample=True, num_heads=4, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
                              kernel_size=3, stride=1, padding=1),
                )
                for i in range(num_layers)
            ]
        )
        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels)
            )
            for _ in range(num_layers)
        ])
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels,
                              kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, out_channels)
             for _ in range(num_layers)]
        )
        
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)]
        )
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
                                          4, 2, 1) if self.down_sample else nn.Identity()


    def forward(self, x, t_emb):
        out = x
        for i in range(self.num_layers):
            
            # Resnet block of Unet
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)
            
            # Attention block of Unet
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            
        out = self.down_sample_conv(out)
        return out

DownBlock 类结合了ResNet 块自注意力块和可选的下采样,并集成了时间嵌入来整合时间步长信息。将卷积层与残差连接相结合,以实现更好的梯度流和更高效的学习。将时间步长表示投影到特征空间中,使模型能够整合时间相关信息。通过对所有空间位置之间的关系进行建模来捕获长距离依赖关系。减少空间维度以专注于更深层中更大规模的特征。

参数

  • in_channels:输入通道数。
  • out_channels:输出通道数。
  • t_emb_dim:时间嵌入的维度。
  • down_sample:布尔值,确定是否在块末尾应用下采样。
  • num_heads:多头注意力层中的注意力头的数量。
  • num_layers:此块中的 ResNet + 注意力层的数量。

ResNet块

  • resnet_conv_first:ResNet 块的第一个卷积层。
  • t_emb_layers:时间嵌入投影层。
  • resnet_conv_second:ResNet 块的第二个卷积层。
  • residual_input_conv:用于残差连接的 1x1 卷积。

自注意力模块

  • attention_norms:在注意力机制之前对规范化层进行分组。
  • attentions:多头注意力层。

下采样

  • down_sample_conv:应用卷积来减少空间维度(如果down_sample=True)。

Forward Pass 方法定义了如何x通过块处理输入张量:out初始化为输入x。对于每一层,我们都有 ResNet Block 和 Self-Attention Block。

在 ResNet Block 中,我们有第一个 卷积层,它应用 GroupNorm、SiLU 激活和 3x3 卷积,以及一个时间嵌入函数,它将时间嵌入传递t_emb到线性层(投影到out_channels),并将此投影时间嵌入添加到out(在空间维度上广播)。然后我们有第二个卷积和一个残差连接,它将原始输入(resnet_input)添加到第二个卷积的输出。

在自注意力模块中,我们将空间维度扁平化为一个维度(h * w)以用于注意力机制。规范化输入并转置以匹配注意力层输入格式。多头注意力in_attn使用查询、键和值执行自注意力。重塑回转置并重塑回原始空间维度。残差连接下采样。

六、中间midBlock类块

class MidBlock(nn.Module):

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads=4, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
                              padding=1),
                )
                for i in range(num_layers+1)
            ]
        )
        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels)
            )
            for _ in range(num_layers + 1)
        ])
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers+1)
            ]
        )
        
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, out_channels)
                for _ in range(num_layers)]
        )
        
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)]
        )
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers+1)
            ]
        )
    
    def forward(self, x, t_emb):
        out = x
        
        # First resnet block
        resnet_input = out
        out = self.resnet_conv_first[0](out)
        out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[0](out)
        out = out + self.residual_input_conv[0](resnet_input)
        
        for i in range(self.num_layers):
            
            # Attention Block
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            
            # Resnet Block
            resnet_input = out
            out = self.resnet_conv_first[i+1](out)
            out = out + self.t_emb_layers[i+1](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i+1](out)
            out = out + self.residual_input_conv[i+1](resnet_input)
        
        return out

该类MidBlock是位于扩散模型中 U-Net 架构中间的模块。它由ResNet 块自注意力层组成,并集成了时间嵌入来处理时间信息。这是用于去噪扩散等任务的模型的重要组成部分。此外,我们还有:

  • 时间嵌入:通过将时间信息(例如,扩散模型中的去噪步骤)投影到特征空间并将其添加到卷积特征中来合并时间信息。
  • 层迭代:在注意力ResNet 块之间交替,按num_layers这些组合的顺序处理输入。

七、UpBlock上层类块

class UpBlock(nn.Module):

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample=True, num_heads=4, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
                              padding=1),
                )
                for i in range(num_layers)
            ]
        )
        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels)
            )
            for _ in range(num_layers)
        ])
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )
        
        self.attention_norms = nn.ModuleList(
            [
                nn.GroupNorm(8, out_channels)
                for _ in range(num_layers)
            ]
        )
        
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                                 4, 2, 1) \
            if self.up_sample else nn.Identity()
    
    def forward(self, x, out_down, t_emb):
        x = self.up_sample_conv(x)
        x = torch.cat([x, out_down], dim=1)
        
        out = x
        for i in range(self.num_layers):
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)
            
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn

        return out

该类UpBlock是 U-Net 类架构的解码器阶段的一部分,通常用于扩散模型或其他图像生成/分割任务。它结合了上采样跳过连接ResNet 块自注意力来重建输出图像,同时保留早期编码器阶段的细粒度细节。

  • 上采样:通过转置卷积(ConvTranspose2d)实现,以增加特征图的空间分辨率。
  • 跳过连接:允许解码器重用编码器的详细特征,帮助重建。
  • ResNet Block:使用卷积层处理输入,集成时间嵌入,并添加残差连接以实现高效的梯度流。
  • 自我注意力:捕获远程空间依赖关系以保留全局上下文。
  • 时间嵌入:对时间信息进行编码并将其注入特征图,这对于处理动态数据的模型(如扩散模型)至关重要。

八、UNet 架构

class Unet(nn.Module):

    def __init__(self, model_config):
        super().__init__()
        im_channels = model_config['im_channels']
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.t_emb_dim = model_config['time_emb_dim']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']
        
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1
        
        # Initial projection from sinusoidal time embedding
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )

        self.up_sample = list(reversed(self.down_sample))
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
        
        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels)-1):
            self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i+1], self.t_emb_dim,
                                        down_sample=self.down_sample[i], num_layers=self.num_down_layers))
        
        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels)-1):
            self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i+1], self.t_emb_dim,
                                      num_layers=self.num_mid_layers))
        
        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels)-1)):
            self.ups.append(UpBlock(self.down_channels[i] * 2, self.down_channels[i-1] if i != 0 else 16,
                                    self.t_emb_dim, up_sample=self.down_sample[i], num_layers=self.num_up_layers))
        
        self.norm_out = nn.GroupNorm(8, 16)
        self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)
    
    def forward(self, x, t):
        # Shapes assuming downblocks are [C1, C2, C3, C4]
        # Shapes assuming midblocks are [C4, C4, C3]
        # Shapes assuming downsamples are [True, True, False]
        # B x C x H x W
        out = self.conv_in(x)
        # B x C1 x H x W
        
        # t_emb -> B x t_emb_dim
        t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
        t_emb = self.t_proj(t_emb)
        
        down_outs = []
        
        for idx, down in enumerate(self.downs):
            down_outs.append(out)
            out = down(out, t_emb)
        # down_outs  [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
        # out B x C4 x H/4 x W/4
            
        for mid in self.mids:
            out = mid(out, t_emb)
        # out B x C3 x H/4 x W/4
        
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb)
            # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]
        out = self.norm_out(out)
        out = nn.SiLU()(out)
        out = self.conv_out(out)
        # out B x C x H x W
        return out

该类是U-Net 架构Unet的实现,专为图像处理任务而设计,例如分割或生成,通常用于扩散模型。该网络包括下采样中级处理上采样阶段。它利用时间嵌入执行动态任务(例如扩散模型),利用跳过连接保留空间信息,利用 GroupNorm 进行归一化。

作者图片

  • 时间嵌入:实现时间动态。
  • 跳过连接:通过连接将细粒度的空间细节集成到解码器中。
  • 灵活的架构:允许通过model_config不同的深度、分辨率和功能丰富度进行定制。
  • 规范化和激活:GroupNorm 确保稳定的训练,而 SiLU 激活则改善非线性。
  • 输出一致性:确保输出图像保留原始的空间尺寸和通道数。

九、训练

import torch
import yaml
import argparse
import os
import numpy as np
from tqdm import tqdm
from torch.optim import Adam
from dataset.mnist_dataset import MnistDataset
from torch.utils.data import DataLoader
from models.unet_base import Unet
from scheduler.linear_noise_scheduler import LinearNoiseScheduler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train(args):
    with open(args.config_path, 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
    print(config)
    
    diffusion_config = config['diffusion_params']
    dataset_config = config['dataset_params']
    model_config = config['model_params']
    train_config = config['train_params']
    
    # Create the noise scheduler
    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
                                     beta_start=diffusion_config['beta_start'],
                                     beta_end=diffusion_config['beta_end'])
    
    # Create the dataset
    mnist = MnistDataset('train', im_path=dataset_config['im_path'])
    mnist_loader = DataLoader(mnist, batch_size=train_config['batch_size'], shuffle=True, num_workers=4)
    
    # Instantiate the model
    model = Unet(model_config).to(device)
    model.train()
    
    # Create output directories
    if not os.path.exists(train_config['task_name']):
        os.mkdir(train_config['task_name'])
    
    # Load checkpoint if found
    if os.path.exists(os.path.join(train_config['task_name'],train_config['ckpt_name'])):
        print('Loading checkpoint as found one')
        model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
                                                      train_config['ckpt_name']), map_location=device))
    # Specify training parameters
    num_epochs = train_config['num_epochs']
    optimizer = Adam(model.parameters(), lr=train_config['lr'])
    criterion = torch.nn.MSELoss()
    
    # Run training
    for epoch_idx in range(num_epochs):
        losses = []
        for im in tqdm(mnist_loader):
            optimizer.zero_grad()
            im = im.float().to(device)
            
            # Sample random noise
            noise = torch.randn_like(im).to(device)
            
            # Sample timestep
            t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)
            
            # Add noise to images according to timestep
            noisy_im = scheduler.add_noise(im, noise, t)
            noise_pred = model(noisy_im, t)

            loss = criterion(noise_pred, noise)
            losses.append(loss.item())
            loss.backward()
            optimizer.step()
        print('Finished epoch:{} | Loss : {:.4f}'.format(
            epoch_idx + 1,
            np.mean(losses),
        ))
        torch.save(model.state_dict(), os.path.join(train_config['task_name'],
                                                    train_config['ckpt_name']))
    
    print('Done Training ...')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Arguments for ddpm training')
    parser.add_argument('--config', dest='config_path',
                        default='config/default.yaml', type=str)
    args = parser.parse_args()
    train(args)

加载配置:从 YAML 文件读取训练配置(如数据集路径、超参数和模型设置)。

设置组件

  • 初始化噪声调度器,用于在不同的时间步添加噪声。
  • 创建一个MNIST 数据集加载器
  • 实例化U-Net模型

检查点管理:检查现有检查点,如果可用则加载。创建保存检查点和输出所需的目录。

训练循环:每个时期:

  • 遍历数据集,根据采样的时间步长向图像添加噪声。
  • 使用模型预测噪声并计算损失(预测噪声和实际噪声之间的 MSE)。
  • 使用反向传播更新模型参数并保存模型检查点。

优化:使用 Adam 优化器和 MSE 损失函数来训练模型。

完成:打印 epoch 损失并在每个 epoch 结束时保存模型。

十、采样

import torch
import torchvision
import argparse
import yaml
import os
from torchvision.utils import make_grid
from tqdm import tqdm
from models.unet_base import Unet
from scheduler.linear_noise_scheduler import LinearNoiseScheduler


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def sample(model, scheduler, train_config, model_config, diffusion_config):

    xt = torch.randn((train_config['num_samples'],
                      model_config['im_channels'],
                      model_config['im_size'],
                      model_config['im_size'])).to(device)
    for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
        # Get prediction of noise
        noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
        
        # Use scheduler to get x0 and xt-1
        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
        
        # Save x0
        ims = torch.clamp(xt, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        grid = make_grid(ims, nrow=train_config['num_grid_rows'])
        img = torchvision.transforms.ToPILImage()(grid)
        if not os.path.exists(os.path.join(train_config['task_name'], 'samples')):
            os.mkdir(os.path.join(train_config['task_name'], 'samples'))
        img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i)))
        img.close()


def infer(args):
    # Read the config file #
    with open(args.config_path, 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
    print(config)
    
    diffusion_config = config['diffusion_params']
    model_config = config['model_params']
    train_config = config['train_params']
    
    # Load model with checkpoint
    model = Unet(model_config).to(device)
    model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
                                                  train_config['ckpt_name']), map_location=device))
    model.eval()
    
    # Create the noise scheduler
    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
                                     beta_start=diffusion_config['beta_start'],
                                     beta_end=diffusion_config['beta_end'])
    with torch.no_grad():
        sample(model, scheduler, train_config, model_config, diffusion_config)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Arguments for ddpm image generation')
    parser.add_argument('--config', dest='config_path',
                        default='config/default.yaml', type=str)
    args = parser.parse_args()
    infer(args)

加载配置:从 YAML 文件读取模型、扩散和训练参数。

模型设置:加载训练好的 U-Net 模型检查点。初始化噪声调度程序以指导反向扩散过程。

采样过程

  • 从随机噪声开始,并在指定的时间步内迭代地对其进行去噪。
  • 在每个时间步:
  • 使用模型预测噪音。
  • 使用调度程序计算去噪图像(x0)并更新当前噪声图像(xt)。
  • 将中间去噪图像作为 PNG 文件保存在输出目录中。

推理:执行采样过程并保存结果而不改变模型。

十一、配置(Default.yaml)

dataset_params:
  im_path: 'data/train/images'

diffusion_params:
  num_timesteps : 1000
  beta_start : 0.0001
  beta_end : 0.02

model_params:
  im_channels : 1
  im_size : 28
  down_channels : [32, 64, 128, 256]
  mid_channels : [256, 256, 128]
  down_sample : [True, True, False]
  time_emb_dim : 128
  num_down_layers : 2
  num_mid_layers : 2
  num_up_layers : 2
  num_heads : 4

train_params:
  task_name: 'default'
  batch_size: 64
  num_epochs: 40
  num_samples : 100
  num_grid_rows : 10
  lr: 0.0001
  ckpt_name: 'ddpm_ckpt.pth'

该配置文件提供了扩散模型的训练和推理的设置。

数据集参数im_path:指定训练图像的路径( )。

扩散参数:设置扩散过程的时间步数和噪声参数的范围(beta_startbeta_end)。

模型参数

  • 定义模型架构,包括:
  • 输入图像通道(im_channels)和大小(im_size)。
  • 下采样、中间处理和上采样的通道数。
  • 每一级是否发生下采样(down_sample)。
  • 各种块的嵌入尺寸和层数。

训练参数

  • 指定训练配置,如任务名称、批量大小、时期、学习率和检查点文件名。
  • 包括采样设置,例如用于可视化的样本数量和网格行数。

十二、数据集 (MNIST)

import glob
import os

import torchvision
from PIL import Image
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset


class MnistDataset(Dataset):
        self.split = split
        self.im_ext = im_ext
        self.images, self.labels = self.load_images(im_path)
    
    def load_images(self, im_path):
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        ims = []
        labels = []
        for d_name in tqdm(os.listdir(im_path)):
            for fname in glob.glob(os.path.join(im_path, d_name, '*.{}'.format(self.im_ext))):
                ims.append(fname)
                labels.append(int(d_name))
        print('Found {} images for split {}'.format(len(ims), self.split))
        return ims, labels
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        im = Image.open(self.images[index])
        im_tensor = torchvision.transforms.ToTensor()(im)
        
        # Convert input to -1 to 1 range.
        im_tensor = (2 * im_tensor) - 1
        return im_tensor

初始化:采用分割名称、图像文件扩展名(im_ext)和图像路径(im_path)。调用load_images以加载图像路径及其相应的标签。

图像加载load_images遍历 处的目录结构im_path,假设子目录已标记(例如,数字类别的01、...)。收集图像文件路径并根据文件夹名称分配标签。

数据集长度__len__返回图像的总数。

数据检索__getitem__通过索引检索图像,将其转换为张量,并将像素值缩放到范围 -1,1-1,1-1,1。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2259857.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

航空航天总线协议分析ARINC429

ARINC429是商用飞机和运输机运用最广泛的总线之一,ARINC是美国航空无线电公司(Aeronautical Radio INC.)的缩写,ARINC429总线协议是美国航空电子工程委员会于1977年7月提出发表并获批准使用,它的规范全称是数字式信息传输系统(Digital Inform…

Unity UGUI图片循环列表插件

效果展示: 下载链接:https://gf.bilibili.com/item/detail/1111843026 概述: LoopListView2 是一个与 UGUI ScrollRect 相同的游戏对象的组件。它可以帮助 UGUI ScrollRect 以高效率和节省内存的方式支持任意数量的项目。 对于具有10,000个…

【经验分享】私有云运维的知识点

最近忙于备考没关注,有次点进某小黄鱼发现首页出现了我的笔记还被人收费了 虽然我也卖了一些资源,但我以交流、交换为主,笔记都是免费给别人看的 由于当时刚刚接触写的并不成熟,为了避免更多人花没必要的钱,所以决定公…

解决MAC装win系统投屏失败问题(AMD显卡)

一、问题描述 电脑接上HDMI线后,电脑上能显示有外部显示器接入,但是外接显示器无投屏画面 二、已测试的方法 1 更改电脑分辨,结果无效 2 删除BootCamp,结果无效 3更新电脑系统,结果无效 4 在设备管理器中&#…

huggingface NLP -Transformers库

1 特点 1.1 易于使用:下载、加载和使用最先进的NLP模型进行推理只需两行代码即可完成。 1.2 灵活:所有型号的核心都是简单的PyTorch nn.Module 或者 TensorFlow tf.kears.Model,可以像它们各自的机器学习(ML)框架中的…

1. 机器学习基本知识(2)——机器学习分类

1.4 机器学习分类 1.4.1 训练监督 1. 监督学习:已对训练数据完成标记 分类:根据数据及其分类信息来进行训练,使模型能够对新的数据进行分类 回归:给出一组特征值来预测目标数值 2. 无监督学习:没有对训练数据进行任…

[C#]使用winform部署ddddocr的onnx模型进行验证码识别文字识别文字检测

【算法介绍】 ddddocr是一个强大的Python OCR(光学字符识别)库,特别适用于验证码识别。它利用深度学习技术,如卷积神经网络(CNN)和循环神经网络(RNN),对图像中的文字进行…

day10 电商系统后台API——接口测试(使用postman)

【没有所谓的运气🍬,只有绝对的努力✊】 目录 实战项目简介: 1、用户管理(8个) 1.1 登录 1.2 获取用户数据列表 1.3 创建用户 1.4 修改用户状态 1.5 根据id查询用户 1.6 修改用户信息 1.7 删除单个用户 1.8 …

云服务器搭建lamp的wordpress

Ubuntu系统搭建过程目录 一、检查环境1.1 检查是否安装Apache1.2 检查是否安装Mysql1.3 检查是否安装PHP 二、安装Apache截图 三、安装Mysql3.1 安全安装配置3.2 修改权限和密码3.3 重启MySQL服务 四、安装PHP4.1 测试截图 五、下载并安装wordpress以及配置5.1 下载并解压移动5…

C#速成(GID+图形编程)

常用类 类说明Brush填充图形形状,画刷GraphicsGDI绘图画面,无法继承Pen定义绘制的对象直线等(颜色,粗细)Font定义文本格式(字体,字号) 常用结构 结构说明Color颜色Point在平面中定义点Rectan…

babeltrace与CTF相关学习笔记-4

babeltrace与CTF相关学习笔记-4 写在前面metadata_string 重头开始定位,操作meta的位置bt_ctf_trace_get_metadata_string stream部分内存的问题 写在前面 正在并行做几件事。 在编译过程中,突然想到,还是再详细研究一下之前的例程。 一是详…

多旋翼无人机 :桨叶设计—跷跷板结构

多旋翼无人机 :桨叶设计——跷跷板结构 前言跷跷板结构 前言 2024年11月,大疆发布了最新的农业无人机T70和T100。其中T70不同于以往的机型,在桨夹处采用了翘翘板结构,大疆将其命名为“挥舞桨叶”。 T70 无人机如下 放大其中螺旋…

低通滤波器,高通滤波器,公式

1 低通滤波器 :输出的是电容的电压 1 低通滤波器可以把低频信号上面的高频信号给滤掉 2 100hz正常通过 3 经过低通滤波器后,波形光滑,绿色波形。一致 4 电容充电速度跟不上输入信号的速度(因为加了电阻,限制了电流&…

如何打造个人知识体系?

第一,每个人的基本情况不同。比如我有一个类别跟「设计」相关,这是自己的个人爱好,但不一定适合其他人。再比如我还有一个类别跟「广告文案」相关,因为里面很多表达可以借用到演讲或写作中,这也不适合所有人。 第二&am…

5G中的ATG Band

Air to Ground Networks for NR是R18 NR引入的。ATG很多部分和NTN类似中的内容类似。比较明显不同的是,NTN的RF内容有TS 38.101-5单独去讲,而ATG则会和地面网络共用某些band,这部分在38.101-1中有描述。 所以会存在ATG与地面网络之间的相邻信…

《自动驾驶技术的深度思考:安全与伦理的挑战》

内容概要 在当今这个自动驾驶技术飞速发展的时代,我们生活的节奏恰似一场疾驰的赛车,然而,赛道上并非总是平坦。在这场技术革命中,安全与伦理问题像是潜伏在阴影中的幽灵,轮番考验着我们的道德底线与法律界限。 随着…

hbuilder 安卓app手机调试中基座如何设置

app端使用基座 手机在线预览功能 1.点击运行 2.点击运行到手机或者模拟器 3.制作自定义调试基座 4.先生成证书【可以看我上一篇文档写的有】,点击打包 5.打包出android自定义调试基座【android_debug.apk】,【就跟app打包一样需要等个几分钟】 6.点击运行到手…

【AIGC】如何高效使用ChatGPT挖掘AI最大潜能?26个Prompt提问秘诀帮你提升300%效率的!

还记得第一次使用ChatGPT时,那种既兴奋又困惑的心情吗?我是从一个对AI一知半解的普通用户,逐步成长为现在的“ChatGPT大神”。这一过程并非一蹴而就,而是通过不断的探索和实践,掌握了一系列高效使用的技巧。今天&#…

汽车免拆诊断案例 | 2014款保时捷卡宴车发动机偶尔无法起动

故障现象 一辆2014款保时捷卡宴车,搭载3.0T 发动机,累计行驶里程约为18万km。车主反映,发动机偶尔无法起动。 故障诊断 接车后试车,发动机起动及运转均正常。用故障检测仪检测,发动机控制单元(DME&#x…

aippt:AI 智能生成 PPT 的开源项目

aippt:AI 智能生成 PPT 的开源项目 在现代办公和学习中,PPT(PowerPoint Presentation)是一种非常重要的展示工具。然而,制作一份高质量的PPT往往需要花费大量的时间和精力。为了解决这一问题,aippt项目应运…