目录
一、说明
二、从头开始实施
三、线性噪声调度器
四、时间嵌入
五、下层DownBlock类块
六、中间midBlock类块
七、UpBlock上层类块
八、UNet 架构
九、训练
十、采样
十一、配置(Default.yaml)
十二、数据集 (MNIST)
keyword: Diffusion Probabilistic Models
一、说明
扩散过程由前向阶段组成,其中图像通过在每个步骤中添加高斯噪声逐渐损坏。经过许多步骤后,图像实际上变得与从正态分布中采样的随机噪声无法区分。这是通过在每个时间步骤 xₜ 应用过渡函数来实现的,其中 β 表示在 t-1 时添加到图像中的预定噪声量,以产生 t 时的图像。
在前面的讨论中,我们确定设置 α=1−β 并计算每个时间步骤中这些 α 值的累积乘积,使我们能够在任何给定步骤 t 直接从原始图像过渡到噪声版本。在反向过程中,模型被训练以近似反向分布。由于正向和反向过程都是高斯的,因此目标是让模型预测反向分布的均值和方差。
通过详细的推导,从最大化观测数据的对数似然性这一目标出发,我们得出需要最小化真实去噪分布(以 x₀ 为条件)与模型预测分布之间的 KL 散度(以特定均值和方差为特征)。方差固定为与目标分布的方差匹配,而均值则以相同形式重写。最小化 KL 散度简化为最小化预测噪声与实际噪声样本之间的平方差。
训练过程包括对图像进行采样、选择时间步长 t,以及添加从正态分布中采样的噪声。然后将 t 处的噪声图像传递给模型。从噪声时间表得出的累积乘积项确定随时间增加的噪声。损失函数是原始噪声样本与模型预测之间的均方误差 (MSE)。
二、从头开始实施
对于图像生成,我们从学习到的反向分布中进行采样,从正态分布中的随机噪声样本 xₜ 开始。使用与 xₜ 和预测噪声相同的公式计算平均值,方差与地面真实去噪分布相匹配。使用重新参数化技巧,我们反复从这个反向分布中采样以生成 x₀。在 x₀ 处,没有添加额外的噪声;相反,平均值直接作为最终输出返回。
为了实现扩散过程,我们需要处理正向和反向阶段的计算。我们将创建一个噪声调度程序来管理这些任务。在正向过程中,给定一个图像、一个噪声样本和一个时间步长 t,调度程序将使用正向方程返回图像的噪声版本。为了优化效率,它将预先计算并存储 α(1−β) 的值以及所有时间步长中 α 的累积乘积。
作者采用了线性噪声调度,其中 β 在 1,000 个时间步骤内从 1×10⁻⁴ 线性缩放到 0.02。调度程序还处理反向过程:给定 xt 和模型预测的噪声,它将通过从反向分布中采样来计算 xₜ₋₁。这涉及使用各自的方程计算均值和方差,并通过重新参数化技巧生成样本。
为了支持这些计算,调度程序还将存储 1-αₜ、1-累积乘积项以及该项的平方根的预先计算的值。
三、线性噪声调度器
import torch
class LinearNoiseScheduler:
def __init__(self, num_timesteps, beta_start, beta_end):
self.num_timesteps = num_timesteps
self.beta_start = beta_start
self.beta_end = beta_end
self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
self.alphas = 1. - self.betas
self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
使用传递给此类的参数初始化所有参数后,我们将定义 β 值从起始范围到结束范围线性增加,确保 βₜ 从 0 进展到最后的时间步骤。接下来,我们将设置正向和反向过程方程所需的所有变量。
def add_noise(self, original, noise, t):
original_shape = original.shape
batch_size = original_shape[0]
sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
# Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
for _ in range(len(original_shape) - 1):
sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
for _ in range(len(original_shape) - 1):
sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
# Apply and Return Forward process equation
return (sqrt_alpha_cum_prod.to(original.device) * original
+ sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)
该add_noise()
函数表示正向过程。它以原始图像、噪声样本和时间步长 ttt 作为输入。图像和噪声的维度为 b×h×w,而时间步长为大小为 b 的一维张量。对于正向过程,我们计算给定时间步长的累积乘积项的平方根和 1-累积乘积项。这些值被重新整形为维度 b×1×1×1。最后,我们应用正向过程方程来生成噪声图像。
def sample_prev_timestep(self, xt, noise_pred, t):
x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
x0 = torch.clamp(x0, -1., 1.)
mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
if t == 0:
return mean, x0
else:
variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
variance = variance * self.betas.to(xt.device)[t]
sigma = variance ** 0.5
z = torch.randn(xt.shape).to(xt.device)
return mean + sigma * z, x0
调度程序类中的下一个函数处理反向过程。它使用噪声图像 xₜ、模型的噪声预测和时间步长 t 作为输入,从学习到的反向分布中生成样本。我们保存原始图像预测 x₀ 以供可视化,它是通过重新排列正向过程方程以使用噪声预测而不是实际噪声来计算 x₀ 获得的。
对于逆向过程中的采样,我们使用逆均值方程计算均值。在 t=0 时,我们只需返回均值。对于其他时间步骤,噪声会添加到均值中,方差与以 x₀ 为条件的地面真实去噪分布的方差相同。最后,我们使用计算出的均值和方差从高斯分布中采样,应用重新参数化技巧来生成结果。
这样就完成了噪声调度程序,它管理添加噪声的正向过程和采样的反向过程。对于扩散模型,我们可以灵活地选择任何架构,只要它满足两个关键要求。第一,输入和输出形状必须相同,第二,必须有一种方法可以整合时间步长信息。
作者图片
无论是在训练期间还是采样期间,时间步长信息始终是可访问的。包含此信息有助于模型更好地预测原始噪声,因为它表明输入图像中有多少是噪声。我们不仅向模型提供图像,还提供相应的时间步长。
对于模型架构,我们将使用 UNet,这也是原作者的选择。为了确保一致性,我们将复制 Hugging Face 的 Diffusers 管道中使用的稳定扩散 UNet 中实现的块、激活、规范化和其他组件的精确规格。
作者图片
时间步长由时间嵌入块处理,该块采用大小为b(批次大小)的时间步长的一维张量,并输出批次中每个时间步长的大小为t_emb_dim的表示。此块首先通过嵌入空间将整数时间步长转换为矢量表示。然后,此嵌入通过中间带有激活函数的两个线性层,产生最终的时间步长表示。对于嵌入空间,作者使用了 Transformers 中常用的正弦位置嵌入方法。在整个架构中,使用的激活函数是 S 形线性单元 (SiLU),但也可以选择其他激活函数。
作者图片
UNet架构遵循简单的编码器-解码器设计。编码器由多个下采样块组成,每个块都会减少输入的空间维度(通常减半),同时增加通道数量。最终下采样块的输出由中间块的几层处理,所有层都以相同的空间分辨率运行。随后,解码器采用上采样块,逐步增加空间维度并减少通道数量,最终匹配原始输入大小。在解码器中,上采样块通过残差跳过连接以相同的分辨率集成相应下采样块的输出。虽然大多数扩散模型都遵循这种通用的 UNet 架构,但它们在各个块内的具体细节和配置上有所不同。
作者图片
大多数变体中的下行块通常由ResNet 块、后跟自注意力块和下采样层组成。每个 ResNet 块都使用一系列操作构建:组归一化、激活层和卷积层。此序列的输出将通过另一组归一化、激活和卷积层。通过将第一个归一化层的输入与第二个卷积层的输出相结合来添加残差连接。这个完整的序列形成ResNet 块,可以将其视为通过残差连接连接的两个卷积块。
在 ResNet 块之后,有一个规范化步骤、一个自注意力层和另一个残差连接。虽然模型通常使用多个 ResNet 层和自注意力层,但为简单起见,我们的实现将只使用每个层的一层。
为了整合时间信息,每个 ResNet 块都包含一个激活层,后面跟着一个线性层,用于处理时间嵌入表示。时间嵌入表示为大小为t_emb_dim的张量,通过此线性层将其投影到与卷积层输出具有相同大小和通道数的张量中。这样就可以通过在空间维度上复制时间步长表示,将时间嵌入添加到卷积层的输出中。
作者图片
另外两个块使用相同的组件,只是略有不同。上块完全相同,只是它首先将输入上采样为两倍空间大小,然后在整个通道维度上集中相同空间分辨率的下块输出。然后我们有相同的 resnet 层和自注意力块。中间块的层始终将输入保持为相同的空间分辨率。hugging face 版本首先有一个 resnet 块,然后是自注意力层和 resnet 层。对于这些 resnet 块中的每一个,我们都有一个时间步长投影层。现有的时间步长表示会经过这些块,然后被添加到 resnet 的第一个卷积层的输出中。
四、时间嵌入
import torch
import torch.nn as nn
def get_time_embedding(time_steps, temb_dim):
assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
# factor = 10000^(2i/d_model)
factor = 10000 ** ((torch.arange(
start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
)
# pos / factor
# timesteps B -> B, 1 -> B, temb_dim
t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
return t_emb
第一个函数为给定的时间步长get_time_embedding
生成时间嵌入。它受到 Transformer 模型中使用的正弦位置嵌入的启发。time_steps
:时间步长值的张量(形状:[B]
其中B
是批次大小)。每个值代表批次元素的一个离散时间步长。temb_dim
:时间嵌入的维数。这决定了每个时间步长的生成嵌入的大小。
确保这temb_dim
是均匀的,因为正弦嵌入需要将嵌入分成两半,分别表示正弦和余弦分量。无缝扩展以处理任何批量大小或嵌入维度。
五、下层DownBlock类块
class DownBlock(nn.Module):
def __init__(self, in_channels, out_channels, t_emb_dim,
down_sample=True, num_heads=4, num_layers=1):
super().__init__()
self.num_layers = num_layers
self.down_sample = down_sample
self.resnet_conv_first = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, in_channels if i == 0 else out_channels),
nn.SiLU(),
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
kernel_size=3, stride=1, padding=1),
)
for i in range(num_layers)
]
)
self.t_emb_layers = nn.ModuleList([
nn.Sequential(
nn.SiLU(),
nn.Linear(t_emb_dim, out_channels)
)
for _ in range(num_layers)
])
self.resnet_conv_second = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels,
kernel_size=3, stride=1, padding=1),
)
for _ in range(num_layers)
]
)
self.attention_norms = nn.ModuleList(
[nn.GroupNorm(8, out_channels)
for _ in range(num_layers)]
)
self.attentions = nn.ModuleList(
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)]
)
self.residual_input_conv = nn.ModuleList(
[
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
for i in range(num_layers)
]
)
self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
4, 2, 1) if self.down_sample else nn.Identity()
def forward(self, x, t_emb):
out = x
for i in range(self.num_layers):
# Resnet block of Unet
resnet_input = out
out = self.resnet_conv_first[i](out)
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
out = self.resnet_conv_second[i](out)
out = out + self.residual_input_conv[i](resnet_input)
# Attention block of Unet
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
out = self.down_sample_conv(out)
return out
DownBlock 类结合了ResNet 块、自注意力块和可选的下采样,并集成了时间嵌入来整合时间步长信息。将卷积层与残差连接相结合,以实现更好的梯度流和更高效的学习。将时间步长表示投影到特征空间中,使模型能够整合时间相关信息。通过对所有空间位置之间的关系进行建模来捕获长距离依赖关系。减少空间维度以专注于更深层中更大规模的特征。
参数:
in_channels
:输入通道数。out_channels
:输出通道数。t_emb_dim
:时间嵌入的维度。down_sample
:布尔值,确定是否在块末尾应用下采样。num_heads
:多头注意力层中的注意力头的数量。num_layers
:此块中的 ResNet + 注意力层的数量。
ResNet块:
resnet_conv_first
:ResNet 块的第一个卷积层。t_emb_layers
:时间嵌入投影层。resnet_conv_second
:ResNet 块的第二个卷积层。residual_input_conv
:用于残差连接的 1x1 卷积。
自注意力模块:
attention_norms
:在注意力机制之前对规范化层进行分组。attentions
:多头注意力层。
下采样:
down_sample_conv
:应用卷积来减少空间维度(如果down_sample=True
)。
Forward Pass 方法定义了如何x
通过块处理输入张量:out
初始化为输入x
。对于每一层,我们都有 ResNet Block 和 Self-Attention Block。
在 ResNet Block 中,我们有第一个 卷积层,它应用 GroupNorm、SiLU 激活和 3x3 卷积,以及一个时间嵌入函数,它将时间嵌入传递t_emb
到线性层(投影到out_channels
),并将此投影时间嵌入添加到out
(在空间维度上广播)。然后我们有第二个卷积和一个残差连接,它将原始输入(resnet_input
)添加到第二个卷积的输出。
在自注意力模块中,我们将空间维度扁平化为一个维度(h * w
)以用于注意力机制。规范化输入并转置以匹配注意力层输入格式。多头注意力in_attn
使用查询、键和值执行自注意力。重塑回转置并重塑回原始空间维度。残差连接和下采样。
六、中间midBlock类块
class MidBlock(nn.Module):
def __init__(self, in_channels, out_channels, t_emb_dim, num_heads=4, num_layers=1):
super().__init__()
self.num_layers = num_layers
self.resnet_conv_first = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, in_channels if i == 0 else out_channels),
nn.SiLU(),
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
padding=1),
)
for i in range(num_layers+1)
]
)
self.t_emb_layers = nn.ModuleList([
nn.Sequential(
nn.SiLU(),
nn.Linear(t_emb_dim, out_channels)
)
for _ in range(num_layers + 1)
])
self.resnet_conv_second = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
)
for _ in range(num_layers+1)
]
)
self.attention_norms = nn.ModuleList(
[nn.GroupNorm(8, out_channels)
for _ in range(num_layers)]
)
self.attentions = nn.ModuleList(
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)]
)
self.residual_input_conv = nn.ModuleList(
[
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
for i in range(num_layers+1)
]
)
def forward(self, x, t_emb):
out = x
# First resnet block
resnet_input = out
out = self.resnet_conv_first[0](out)
out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
out = self.resnet_conv_second[0](out)
out = out + self.residual_input_conv[0](resnet_input)
for i in range(self.num_layers):
# Attention Block
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
# Resnet Block
resnet_input = out
out = self.resnet_conv_first[i+1](out)
out = out + self.t_emb_layers[i+1](t_emb)[:, :, None, None]
out = self.resnet_conv_second[i+1](out)
out = out + self.residual_input_conv[i+1](resnet_input)
return out
该类MidBlock
是位于扩散模型中 U-Net 架构中间的模块。它由ResNet 块和自注意力层组成,并集成了时间嵌入来处理时间信息。这是用于去噪扩散等任务的模型的重要组成部分。此外,我们还有:
- 时间嵌入:通过将时间信息(例如,扩散模型中的去噪步骤)投影到特征空间并将其添加到卷积特征中来合并时间信息。
- 层迭代:在注意力和ResNet 块之间交替,按
num_layers
这些组合的顺序处理输入。
七、UpBlock上层类块
class UpBlock(nn.Module):
def __init__(self, in_channels, out_channels, t_emb_dim, up_sample=True, num_heads=4, num_layers=1):
super().__init__()
self.num_layers = num_layers
self.up_sample = up_sample
self.resnet_conv_first = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, in_channels if i == 0 else out_channels),
nn.SiLU(),
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
padding=1),
)
for i in range(num_layers)
]
)
self.t_emb_layers = nn.ModuleList([
nn.Sequential(
nn.SiLU(),
nn.Linear(t_emb_dim, out_channels)
)
for _ in range(num_layers)
])
self.resnet_conv_second = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(8, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
)
for _ in range(num_layers)
]
)
self.attention_norms = nn.ModuleList(
[
nn.GroupNorm(8, out_channels)
for _ in range(num_layers)
]
)
self.attentions = nn.ModuleList(
[
nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)
]
)
self.residual_input_conv = nn.ModuleList(
[
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
for i in range(num_layers)
]
)
self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
4, 2, 1) \
if self.up_sample else nn.Identity()
def forward(self, x, out_down, t_emb):
x = self.up_sample_conv(x)
x = torch.cat([x, out_down], dim=1)
out = x
for i in range(self.num_layers):
resnet_input = out
out = self.resnet_conv_first[i](out)
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
out = self.resnet_conv_second[i](out)
out = out + self.residual_input_conv[i](resnet_input)
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
return out
该类UpBlock
是 U-Net 类架构的解码器阶段的一部分,通常用于扩散模型或其他图像生成/分割任务。它结合了上采样、跳过连接、ResNet 块和自注意力来重建输出图像,同时保留早期编码器阶段的细粒度细节。
- 上采样:通过转置卷积(
ConvTranspose2d
)实现,以增加特征图的空间分辨率。 - 跳过连接:允许解码器重用编码器的详细特征,帮助重建。
- ResNet Block:使用卷积层处理输入,集成时间嵌入,并添加残差连接以实现高效的梯度流。
- 自我注意力:捕获远程空间依赖关系以保留全局上下文。
- 时间嵌入:对时间信息进行编码并将其注入特征图,这对于处理动态数据的模型(如扩散模型)至关重要。
八、UNet 架构
class Unet(nn.Module):
def __init__(self, model_config):
super().__init__()
im_channels = model_config['im_channels']
self.down_channels = model_config['down_channels']
self.mid_channels = model_config['mid_channels']
self.t_emb_dim = model_config['time_emb_dim']
self.down_sample = model_config['down_sample']
self.num_down_layers = model_config['num_down_layers']
self.num_mid_layers = model_config['num_mid_layers']
self.num_up_layers = model_config['num_up_layers']
assert self.mid_channels[0] == self.down_channels[-1]
assert self.mid_channels[-1] == self.down_channels[-2]
assert len(self.down_sample) == len(self.down_channels) - 1
# Initial projection from sinusoidal time embedding
self.t_proj = nn.Sequential(
nn.Linear(self.t_emb_dim, self.t_emb_dim),
nn.SiLU(),
nn.Linear(self.t_emb_dim, self.t_emb_dim)
)
self.up_sample = list(reversed(self.down_sample))
self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
self.downs = nn.ModuleList([])
for i in range(len(self.down_channels)-1):
self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i+1], self.t_emb_dim,
down_sample=self.down_sample[i], num_layers=self.num_down_layers))
self.mids = nn.ModuleList([])
for i in range(len(self.mid_channels)-1):
self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i+1], self.t_emb_dim,
num_layers=self.num_mid_layers))
self.ups = nn.ModuleList([])
for i in reversed(range(len(self.down_channels)-1)):
self.ups.append(UpBlock(self.down_channels[i] * 2, self.down_channels[i-1] if i != 0 else 16,
self.t_emb_dim, up_sample=self.down_sample[i], num_layers=self.num_up_layers))
self.norm_out = nn.GroupNorm(8, 16)
self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)
def forward(self, x, t):
# Shapes assuming downblocks are [C1, C2, C3, C4]
# Shapes assuming midblocks are [C4, C4, C3]
# Shapes assuming downsamples are [True, True, False]
# B x C x H x W
out = self.conv_in(x)
# B x C1 x H x W
# t_emb -> B x t_emb_dim
t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
t_emb = self.t_proj(t_emb)
down_outs = []
for idx, down in enumerate(self.downs):
down_outs.append(out)
out = down(out, t_emb)
# down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
# out B x C4 x H/4 x W/4
for mid in self.mids:
out = mid(out, t_emb)
# out B x C3 x H/4 x W/4
for up in self.ups:
down_out = down_outs.pop()
out = up(out, down_out, t_emb)
# out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]
out = self.norm_out(out)
out = nn.SiLU()(out)
out = self.conv_out(out)
# out B x C x H x W
return out
该类是U-Net 架构Unet
的实现,专为图像处理任务而设计,例如分割或生成,通常用于扩散模型。该网络包括下采样、中级处理和上采样阶段。它利用时间嵌入执行动态任务(例如扩散模型),利用跳过连接保留空间信息,利用 GroupNorm 进行归一化。
作者图片
- 时间嵌入:实现时间动态。
- 跳过连接:通过连接将细粒度的空间细节集成到解码器中。
- 灵活的架构:允许通过
model_config
不同的深度、分辨率和功能丰富度进行定制。 - 规范化和激活:GroupNorm 确保稳定的训练,而 SiLU 激活则改善非线性。
- 输出一致性:确保输出图像保留原始的空间尺寸和通道数。
九、训练
import torch
import yaml
import argparse
import os
import numpy as np
from tqdm import tqdm
from torch.optim import Adam
from dataset.mnist_dataset import MnistDataset
from torch.utils.data import DataLoader
from models.unet_base import Unet
from scheduler.linear_noise_scheduler import LinearNoiseScheduler
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def train(args):
with open(args.config_path, 'r') as file:
try:
config = yaml.safe_load(file)
except yaml.YAMLError as exc:
print(exc)
print(config)
diffusion_config = config['diffusion_params']
dataset_config = config['dataset_params']
model_config = config['model_params']
train_config = config['train_params']
# Create the noise scheduler
scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
beta_start=diffusion_config['beta_start'],
beta_end=diffusion_config['beta_end'])
# Create the dataset
mnist = MnistDataset('train', im_path=dataset_config['im_path'])
mnist_loader = DataLoader(mnist, batch_size=train_config['batch_size'], shuffle=True, num_workers=4)
# Instantiate the model
model = Unet(model_config).to(device)
model.train()
# Create output directories
if not os.path.exists(train_config['task_name']):
os.mkdir(train_config['task_name'])
# Load checkpoint if found
if os.path.exists(os.path.join(train_config['task_name'],train_config['ckpt_name'])):
print('Loading checkpoint as found one')
model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
train_config['ckpt_name']), map_location=device))
# Specify training parameters
num_epochs = train_config['num_epochs']
optimizer = Adam(model.parameters(), lr=train_config['lr'])
criterion = torch.nn.MSELoss()
# Run training
for epoch_idx in range(num_epochs):
losses = []
for im in tqdm(mnist_loader):
optimizer.zero_grad()
im = im.float().to(device)
# Sample random noise
noise = torch.randn_like(im).to(device)
# Sample timestep
t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)
# Add noise to images according to timestep
noisy_im = scheduler.add_noise(im, noise, t)
noise_pred = model(noisy_im, t)
loss = criterion(noise_pred, noise)
losses.append(loss.item())
loss.backward()
optimizer.step()
print('Finished epoch:{} | Loss : {:.4f}'.format(
epoch_idx + 1,
np.mean(losses),
))
torch.save(model.state_dict(), os.path.join(train_config['task_name'],
train_config['ckpt_name']))
print('Done Training ...')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Arguments for ddpm training')
parser.add_argument('--config', dest='config_path',
default='config/default.yaml', type=str)
args = parser.parse_args()
train(args)
加载配置:从 YAML 文件读取训练配置(如数据集路径、超参数和模型设置)。
设置组件:
- 初始化噪声调度器,用于在不同的时间步添加噪声。
- 创建一个MNIST 数据集加载器。
- 实例化U-Net模型。
检查点管理:检查现有检查点,如果可用则加载。创建保存检查点和输出所需的目录。
训练循环:每个时期:
- 遍历数据集,根据采样的时间步长向图像添加噪声。
- 使用模型预测噪声并计算损失(预测噪声和实际噪声之间的 MSE)。
- 使用反向传播更新模型参数并保存模型检查点。
优化:使用 Adam 优化器和 MSE 损失函数来训练模型。
完成:打印 epoch 损失并在每个 epoch 结束时保存模型。
十、采样
import torch
import torchvision
import argparse
import yaml
import os
from torchvision.utils import make_grid
from tqdm import tqdm
from models.unet_base import Unet
from scheduler.linear_noise_scheduler import LinearNoiseScheduler
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def sample(model, scheduler, train_config, model_config, diffusion_config):
xt = torch.randn((train_config['num_samples'],
model_config['im_channels'],
model_config['im_size'],
model_config['im_size'])).to(device)
for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
# Get prediction of noise
noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
# Use scheduler to get x0 and xt-1
xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
# Save x0
ims = torch.clamp(xt, -1., 1.).detach().cpu()
ims = (ims + 1) / 2
grid = make_grid(ims, nrow=train_config['num_grid_rows'])
img = torchvision.transforms.ToPILImage()(grid)
if not os.path.exists(os.path.join(train_config['task_name'], 'samples')):
os.mkdir(os.path.join(train_config['task_name'], 'samples'))
img.save(os.path.join(train_config['task_name'], 'samples', 'x0_{}.png'.format(i)))
img.close()
def infer(args):
# Read the config file #
with open(args.config_path, 'r') as file:
try:
config = yaml.safe_load(file)
except yaml.YAMLError as exc:
print(exc)
print(config)
diffusion_config = config['diffusion_params']
model_config = config['model_params']
train_config = config['train_params']
# Load model with checkpoint
model = Unet(model_config).to(device)
model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
train_config['ckpt_name']), map_location=device))
model.eval()
# Create the noise scheduler
scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
beta_start=diffusion_config['beta_start'],
beta_end=diffusion_config['beta_end'])
with torch.no_grad():
sample(model, scheduler, train_config, model_config, diffusion_config)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Arguments for ddpm image generation')
parser.add_argument('--config', dest='config_path',
default='config/default.yaml', type=str)
args = parser.parse_args()
infer(args)
加载配置:从 YAML 文件读取模型、扩散和训练参数。
模型设置:加载训练好的 U-Net 模型检查点。初始化噪声调度程序以指导反向扩散过程。
采样过程:
- 从随机噪声开始,并在指定的时间步内迭代地对其进行去噪。
- 在每个时间步:
- 使用模型预测噪音。
- 使用调度程序计算去噪图像(
x0
)并更新当前噪声图像(xt
)。 - 将中间去噪图像作为 PNG 文件保存在输出目录中。
推理:执行采样过程并保存结果而不改变模型。
十一、配置(Default.yaml)
dataset_params:
im_path: 'data/train/images'
diffusion_params:
num_timesteps : 1000
beta_start : 0.0001
beta_end : 0.02
model_params:
im_channels : 1
im_size : 28
down_channels : [32, 64, 128, 256]
mid_channels : [256, 256, 128]
down_sample : [True, True, False]
time_emb_dim : 128
num_down_layers : 2
num_mid_layers : 2
num_up_layers : 2
num_heads : 4
train_params:
task_name: 'default'
batch_size: 64
num_epochs: 40
num_samples : 100
num_grid_rows : 10
lr: 0.0001
ckpt_name: 'ddpm_ckpt.pth'
该配置文件提供了扩散模型的训练和推理的设置。
数据集参数im_path
:指定训练图像的路径( )。
扩散参数:设置扩散过程的时间步数和噪声参数的范围(beta_start
和beta_end
)。
模型参数:
- 定义模型架构,包括:
- 输入图像通道(
im_channels
)和大小(im_size
)。 - 下采样、中间处理和上采样的通道数。
- 每一级是否发生下采样(
down_sample
)。 - 各种块的嵌入尺寸和层数。
训练参数:
- 指定训练配置,如任务名称、批量大小、时期、学习率和检查点文件名。
- 包括采样设置,例如用于可视化的样本数量和网格行数。
十二、数据集 (MNIST)
import glob
import os
import torchvision
from PIL import Image
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
class MnistDataset(Dataset):
self.split = split
self.im_ext = im_ext
self.images, self.labels = self.load_images(im_path)
def load_images(self, im_path):
assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
ims = []
labels = []
for d_name in tqdm(os.listdir(im_path)):
for fname in glob.glob(os.path.join(im_path, d_name, '*.{}'.format(self.im_ext))):
ims.append(fname)
labels.append(int(d_name))
print('Found {} images for split {}'.format(len(ims), self.split))
return ims, labels
def __len__(self):
return len(self.images)
def __getitem__(self, index):
im = Image.open(self.images[index])
im_tensor = torchvision.transforms.ToTensor()(im)
# Convert input to -1 to 1 range.
im_tensor = (2 * im_tensor) - 1
return im_tensor
初始化:采用分割名称、图像文件扩展名(im_ext
)和图像路径(im_path
)。调用load_images
以加载图像路径及其相应的标签。
图像加载:load_images
遍历 处的目录结构im_path
,假设子目录已标记(例如,数字类别的0
、1
、...)。收集图像文件路径并根据文件夹名称分配标签。
数据集长度:__len__
返回图像的总数。
数据检索:__getitem__
通过索引检索图像,将其转换为张量,并将像素值缩放到范围 -1,1-1,1-1,1。