【AI理论学习】手把手利用PyTorch实现扩散模型DDPM

news2024/11/25 21:44:31

手把手利用PyTorch实现扩散模型DDPM

  • DDPM代码实现
    • 神经网络
      • 定义辅助函数
      • 位置嵌入
      • ResNet block
      • 注意力模块
      • 分组归一化
      • Conditional U-Net
    • 定义前向扩散过程
    • 定义PyTorch数据集+DataLoader
    • 采样
    • 训练模型
    • 采样
    • 后续阅读
  • 参考链接

上一篇博文已经手把手推导了扩散模型DDPM,本文利用PyTorch在Google Colab notebook中实现扩散模型。

DDPM代码实现

注意,扩散模型有几种观点。在这里,我们采用discrete-time(潜变量模型)的观点,但请务必查看其他观点。

神经网络

神经网络需要在特定的时间步接收噪声图像,并返回预测的噪声。需要注意的是,预测的噪声是一个与输入图像具有相同大小/分辨率的张量。因此从技术上讲,网络接收和输出具有相同形状的张量。在这种情况下,可以使用什么类型的神经网络呢?

在这里通常使用的方法与自编码器(Autoencoder)非常相似,你可能还记得它出现在典型的“intro to deep learning门”教程中。Autoencoders在encoder和decoder之间具有一个所谓的“bottleneck”层。编码器首先将图像编码为较小的隐藏表示,称为“bottleneck”,然后解码器将该隐表示解码回实际图像。这迫使网络在瓶颈层中仅保留最重要的信息。

在体系结构方面,DDPM 的作者采用了一个 U-Net 结构,该结构由(Ronneberger等人,2015)引入,当时在医学图像分割领域取得了最先进的结果。这个网络,像任何自编码器一样,由一个位于中间的瓶颈层组成,确保网络仅学习最重要的信息。重要的是,它在编码器和解码器之间引入了残差连接,大大改善了梯度流动(灵感来自于 He等人,2015年的 ResNet)。
U-net
如图所示,U-Net 模型首先对输入进行下采样(即在空间分辨率方面使输入变小),然后进行上采样

接下来,我们逐步实现这个网络。

!pip install -q -U einops datasets matplotlib tqdm

导入相关依赖库

import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F

定义辅助函数

首先,定义一些在实现神经网络时将使用的辅助函数和类。重要的是,定义了一个残差模块,它将输入简单地添加到特定函数的输出中(换句话说,将残差连接添加到特定函数中)。

def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

我们还为上采样和下采样操作定义了别名。

def Upsample(dim, dim_out=None):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv2d(dim, default(dim_out, dim), 3, padding=1),
    )


def Downsample(dim, dim_out=None):
    # 不再有阶梯卷积或池
    return nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1),
    )

位置嵌入

由于神经网络的参数在不同时间(噪声水平)之间共享,作者采用了受 Transformer(Vaswani et al., 2017)启发的正弦位置嵌入(sinusoidal position embeddings)来编码 t t t。这使得神经网络可以“know”它正在处理批次中的每个图像的特定时间步(噪声水平)。

SinusoidalPositionEmbeddings 模块接受形状为(batch_size,1)的张量作为输入(即批次中多个带噪声图像的噪声水平),并将其转换为形状为(batch_size,dim)的张量,其中 dim 是位置嵌入的维度。然后将其添加到每个残差块中,我们将在后面看到。

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

总之就是将 t t t 编码为embedding,和原本的输入一起进入网络,让网络“知道”当前的输入属于哪个step。

ResNet block

接下来,定义 U-Net 模型的核心构建块。DDPM 的作者使用了 Wide ResNet block(Zagoruyko et al., 2016),但 Phil Wang 将标准卷积层替换为“weight standardized”版本,这与分组归一化的结合效果更好(有关详细信息,请参见(Kolesnikov et al., 2019))。

class WeightStandardizedConv2d(nn.Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        mean = reduce(weight, "o ... -> o 1 1 1", "mean")
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.strides,
            self.padding,
            self.dilation,
            self.groups,
        )

class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)

注意力模块

现在,定义注意力模块,这是 DDPM 的作者在卷积块之间添加的。注意力是著名的 Transformer 架构(Vaswani et al., 2017)的构建块,在人工智能的各个领域,从自然语言处理和视觉到蛋白质折叠都取得了巨大的成功。Phil Wang 使用了两种注意力的变体:一种是常规的多头自注意力(multi-head self-attention)(就像在 Transformer 中使用的那样),另一种是线性注意力变体(linear attention variant)(Shen et al., 2018),其时间和内存要求与序列长度线性缩放,而不是常规注意力的二次缩放。

关于注意力机制的详细解释,请参阅 Jay Allamar 的精彩博客文章。

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q.softmax(dim=2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

分组归一化

DDPM 的作者在 U-Net 的卷积/注意力层之间交错使用了分组归一化group normalization)(Wu et al., 2018)。在下面,定义了一个 PreNorm 类,该类将在注意力层之前应用分组归一化,正如我们将在后面看到的。值得注意的是,关于在 Transformer 中是在注意力之前还是之后应用归一化一直存在争议。

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

Conditional U-Net

现在,我们已经定义了所有构建块(position embeddings,ResNet blocks,attention和group normalization),现在该定义整个神经网络了。回想一下网络 ϵ θ ( x t , t ) \boldsymbol{\epsilon}_\theta\left(\mathbf{x_t}, t\right) ϵθ(xt,t)的工作是获取一批有噪声的图像及其各自的噪声水平,并输出添加到输入的噪声。更正式地说:

  • 网络采集一批形状为(batch_size, num_channels, height, width)的噪声图像和一批形状为 (batch_size, 1)的噪声水平作为输入,并返回一个形状为 (batch_size, num_channels, height, width)的张量

网络构建如下:

  • 首先,在一批有噪声的图像上应用卷积层,并计算噪声水平的位置嵌入(position embeddings)
  • 然后,执行一系列的下采样阶段(downsampling stages)。每个下采样阶段由2个ResNet blocks+ groupnorm+ attention+residual connection+a downsample operation组成
  • 在网络的中间,再次应用ResNet block,与attention交错
  • 接下来,执行一系列上采样阶段(upsampling stages)。每个上采样阶段由2个ResNet blocks + groupnorm +attention + residual connection + an upsample operation组成
  • 最后,在一个卷积层后面应用一个ResNet block。

最终,神经网络就像乐高积木一样层层堆叠(但了解它们是如何工作的很重要)。

class Unet(nn.Module):
    def __init__(self, dim, init_dim=None, out_dim=None, dim_mults=(1, 2, 4, 8), channels=3, self_condition=False,
                 resnet_block_groups=4):
        super().__init__()
        # determine dimensions
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0)  #  changed to 1 and 0 from 7,3

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        time_dim = dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Downsample(dim_in, dim_out)
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )
        self.out_dim = default(out_dim, channels)
        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)
        x = self.final_res_block(x, t)
        return self.final_conv(x)

定义前向扩散过程

forward diffusion process在 T T T个时间步内逐渐将噪声从真实分布添加到图像中,这是根据variance schedule发生的。最初的DDPM作者采用了linear schedule

我们将前向过程的方差设置为线性增加的常数 from β 1 = 1 0 − 4 \beta_1=10^{-4} β1=104 to β T = 0.02 \beta_T=0.02 βT=0.02.

然而,在(Nichol et al.,2021)中表明,使用cosine schedule可以获得更好的结果。下面,我们定义 T T T个时间步的不同的schedule(我们稍后会选择一个):

def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

首先,使用T=300个时间步的linear schedule,并从 β t \beta_t βt中定义我们需要的变量,例如,方差的累积乘积KaTeX parse error: Undefined control sequence: \bat at position 1: \̲b̲a̲t̲{\alpha}_t。下面的每个变量都只是一维张量,存储从 t t t T T T的数值。注意,我们还定义了一个extract函数,它允许我们按照 t t t提取一个批次的索引。

timesteps = 300

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

我们将用猫图像说明如何在扩散过程的每个时间步中添加噪声:

from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw) # PIL image of shape HWC
image

cat
将噪声添加到Pytorch张量,而不是Pillow Images中。首先定义能够将PIL图像转换为Pytorch张量(可以在上面添加噪声)的图像转换,反之亦然。

这些转换非常简单:我们首先通过除以255(其结果能在[0,1]范围),然后确保它们在[-1,1]范围。DDPM 文中提到:

我们假设图像数据由在集合{ 0 ,1 、. . . , 255}中的整数组成,然后线性缩放到[−1, 1]。 这确保了神经网络逆向过程能够从标准正态先验 p ( x T ) p(x_T) p(xT)开始且一致缩放的输入上运行。"

from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

image_size = 128
transform = Compose([
    Resize(image_size),
    CenterCrop(image_size),
    ToTensor(),  # turn into torch Tensor of shape CHW, divide by 255
    Lambda(lambda t: (t * 2) - 1),

])

x_start = transform(image).unsqueeze(0)
x_start.shape

输出结果:
torch.Size([1, 3, 128, 128])

另外,还定义了反向变换(reverse transform),它接收一个PyTorch张量,其中包含[-1,1],并将它们重新转换回PIL图像:

import numpy as np

reverse_transform = Compose([
     Lambda(lambda t: (t + 1) / 2),
     Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
     Lambda(lambda t: t * 255.),
     Lambda(lambda t: t.numpy().astype(np.uint8)),
     ToPILImage(),
])
reverse_transform(x_start.squeeze())

cat
现在,可以像论文中定义前向扩散过程:

# forward diffusion (using the nice property)
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

在特定的时间步中进行测试:

def get_noisy_image(x_start, t):
  # add noise
  x_noisy = q_sample(x_start, t=t)

  # turn back into PIL image
  noisy_image = reverse_transform(x_noisy.squeeze())

  return noisy_image
# take time step
t = torch.tensor([40])

get_noisy_image(x_start, t)

noise cat
可视化不同时间步的结果:

import matplotlib.pyplot as plt

# use seed for reproducability
torch.manual_seed(0)

# source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [image] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])

不同时间步结果可视化
在给定模型的情况下定义损失函数:

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss

dnoise_model就是上面定义的U-Net。在真实噪声和预测噪声之间使用Huber损失

定义PyTorch数据集+DataLoader

这里定义一个常规的PyTorch数据集。该数据集仅由真实数据集(如Fashion、MNIST、CIFAR-10或ImageNet)的图像组成,线性缩放至 [ − 1 , 1 ] [-1,1] [1,1]
每个图像都被调整为相同的大小,同时是随机水平翻转的。从论文中:

我们在 CIFAR10 的训练过程中使用了随机水平翻转;我们尝试了有翻转和没有翻转的训练,发现翻转可以稍微提高样本质量。

在这里,使用Datasets库轻松地从hub加载 Fashion MNIST 数据集。该数据集由已经具有相同分辨率的图像组成,即 28x28。

from datasets import load_dataset

# load dataset from the hub
dataset = load_dataset('fashion_mnist')
image_size = 28
channels = 1
batch_size = 128

接下来,定义一个函数,将在整个数据集上即时应用它。为此使用该with_transform功能。该函数只是应用了一些基本的图像预处理:随机水平翻转、重新缩放并最终使它们在[-1,1]范围。

from torchvision import transforms
from torch.utils.data import DataLoader

# define image transformations(e.g. using torchvision)
transform = Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) -1)
])

# define function
def transforms(examples):
   examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
   del examples["image"]

   return examples

transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)
batch = next(iter(dataloader))
print(batch.keys())  # dict_keys(['pixel_values'])

采样

由于将在训练期间从模型中采样(来跟踪进度),因此定义了下面的代码。采样方法总结如下:
Sampling
从扩散模型生成新图像是通过逆扩散过程来实现的:从 T T T开始,从高斯分布中采样纯噪声,然后使用神经网络逐渐去噪(使用它学到的条件概率),直到在时间步 t = 0 t=0 t=0结束。如上所示,可以得到一个稍微降噪的图像 x t − 1 x_{t-1} xt1通过使用我们的噪声预测器插入均值的重新参数化。注意,方差是提前知道的。

理想情况下,最终会得到一张看起来像是来自真实数据分布的图像。下面的代码实现了这一点。

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # Equation 11 in the paper
    # Use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2 line 4:
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

# Algorithm 2 (including returning all images)
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

训练模型

接下来,以常规 PyTorch 方式训练模型。我们还定义了一些逻辑来定期保存生成的图像,使用上面定义的sample定义的方法。

from pathlib import Path

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000

下面,定义模型,并将其移动到GPU,还定义了一个标准优化器(Adam)。

from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)

开始训练:

from torchvision.utils import save_image

epochs = 6

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad()

      batch_size = batch["pixel_values"].shape[0]
      batch = batch["pixel_values"].to(device)

      # Algorithm 1 line 3: sample t uniformally for every example in the batch
      t = torch.randint(0, timesteps, (batch_size,), device=device).long()

      loss = p_losses(model, batch, t, loss_type="huber")

      if step % 100 == 0:
        print("Loss:", loss.item())

      loss.backward()
      optimizer.step()

      # save generated images
      if step != 0 and step % save_and_sample_every == 0:
        milestone = step // save_and_sample_every
        batches = num_to_groups(4, batch_size)
        all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
        all_images = torch.cat(all_images_list, dim=0)
        all_images = (all_images + 1) * 0.5
        save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)

训练过程:

Loss: 0.5570111274719238
Loss: 0.06583500653505325
Loss: 0.06006840616464615
Loss: 0.051015421748161316
Loss: 0.0394190177321434
Loss: 0.04075610265135765
Loss: 0.039987701922655106
Loss: 0.03415030241012573
Loss: 0.030019590631127357
Loss: 0.036297883838415146
Loss: 0.037256866693496704
Loss: 0.03864285722374916
Loss: 0.03298967331647873
Loss: 0.03331328555941582
Loss: 0.027535393834114075
Loss: 0.03803558647632599
Loss: 0.03721949830651283
Loss: 0.03478413075208664
Loss: 0.03918925300240517
Loss: 0.03608154132962227
Loss: 0.027622627094388008
Loss: 0.02948344498872757
Loss: 0.029868196696043015
Loss: 0.03154699504375458
Loss: 0.029723389074206352
Loss: 0.039195798337459564
Loss: 0.032130151987075806
Loss: 0.031276602298021317
Loss: 0.03440115600824356
Loss: 0.030476151034235954

采样

要从模型中采样,可以使用上面定义的采样函数:

# sample 64 images
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)

# show a random one
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

采样结果
看起来模型能够生成一件漂亮的 T 恤!请记住,用来训练的数据集的分辨率非常低(28x28)。还可以创建去噪过程的 gif 图像:

import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()

gif图像

后续阅读

注意,DDPM 论文表明,扩散模型是(无)条件图像生成的一个promising的方向。从DDPM提出到现在已经(极大地)得到了改进,尤其是在文本条件图像生成方面。下面,列出了一些截至到2022年6月7日之前的重要的(但远非详尽的)后续工作:

  • Improved Denoising Diffusion Probabilistic Models (Nichol et al., 2021): 发现学习条件分布的方差(除均值外)有助于提高性能
  • Cascaded Diffusion Models for High Fidelity Image Generation (Ho et al., 2021): 引入了级联扩散,它包含多个扩散模型的pipeline,可生成分辨率不断提高的图像,用于高保真图像合成
  • Diffusion Models Beat GANs on Image Synthesis (Dhariwal et al., 2021): 表明扩散模型可以通过改进 U-Net 架构以及引入分类器引导,达到比SOTA生成式模型更好的效果。
  • Classifier-Free Diffusion Guidance (Ho et al., 2021): 表明不需要使用分类器来指导扩散模型,只需要使用单个神经网络联合训练条件扩散模型和无条件扩散模型
  • Hierarchical Text-Conditional Image Generation with CLIP Latents (DALL-E 2) (Ramesh et al., 2022): 使用先验将文字说明转换为 CLIP 图像嵌入,然后使用扩散模型将其解码为图像
  • Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding (ImageGen) (Saharia et al., 2022): 表明将大型预训练语言模型(例如 T5)与级联扩散相结合非常适用于文本到图像合成

参考链接

  1. The Annotated Diffusion Model
  2. 带你深入理解扩散模型DDPM
  3. 扩散模型全新课程:扩散模型从0到1实现!
  4. Denoising Diffusion Probabilitistic Models
  5. 《Diffusion Models Beat GANs on Image Synthesis》阅读笔记
  6. How Diffusion Models Work
  7. DDPM交叉熵损失函数推导
  8. DDPM(Denoising Diffusion Probabilistic Models)扩散模型简述
  9. What are Diffusion Models?
  10. 由浅入深了解Diffusion Model
  11. 什么是Diffusion模型?
  12. Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读
  13. Denoising Diffusion Probabilistic Model, in Pytorch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/867892.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

1. 消息队列

消息队列 1.1. MQ 的相关概念 1.1.1. 什么是 MQ MQ(message queue),从字面意思上看,本质是个队列,FIFO 先入先出,只不过队列中存放的内容是message 而已,还是一种跨进程的通信机制,用于上下游传递消息。在…

性能测试—Jmeter工具

文章目录 性能测试1. 术语介绍2. 方法3. 应用场景4. 工具(Jmeter)4.1 介绍4.2 元件和组件4.2.2 元件4.2.1 组件 4.3 作用域4.4 参数化4.5 执行脚本 性能测试 1. 术语介绍 响应时间(Response time):对请求作出响应所需要的时间。 在互联网上对…

Tik Tok娱乐+电商MCN怎么做?

在美国外的热门市场中,TikTok 主要做的区域市场包括中东、拉美、欧洲和东亚,而这里面适合做电商的其实并不多。 欧洲、东亚都属于成熟市场,且 TikTok 本身在欧洲面临 DSA 法案更严格的审查,与在英国相同,欧洲各市场消…

【二分+贪心】CF1622 C

Problem - 1622C - Codeforces 题意: 思路: 首先,观察样例可知,肯定是把原本的最小值减到某个值,然后再复制几次 复制的时候肯定是从大到小复制 那把最小值减到哪个值是不确定的,考虑枚举这个值&#x…

并发编程--------JUC集合

并发集合 一、ConcurrentHashMap 1.1 存储结构 ConcurrentHashMap是线程安全的HashMap ConcurrentHashMap在JDK1.8中是以CASsynchronized实现的线程安全 CAS:在没有hash冲突时(Node要放在数组上时) synchronized:在出现hash…

STM32CubeMX之freeRTOS事件组

当有多个判断,才会执行的时候,就会有事件组 事件组就是24个标志位的组合,或操作或者与操作就可以操作其 例如发射导弹 需要很多人都同意才能发送 V1版本无法自动添加事件组 这里手动创建事件组 这里是基本的使用 置1操作 这里进行事件组的…

初创体验版彩虹知识付费商城源码 V3.4支持二级分类,多级分销,秒杀,砍价,团购,首页继续浏览

最新彩虹知识付费商城初创体验版,支持二级分类,多级分销,秒杀,砍价,团购,首页继续浏览,分站个人虚拟余额自定义,最新批量对接,批量下载图片,批量替换标题&…

工厂物流管理:提升生产效率的关键驱动力

工厂物流管理在现代制造业中扮演着至关重要的角色。它涉及到物料的采购、生产过程中的物料运输和仓储管理,以及最终产品的分发。 1. 定义和重要性: 工厂物流管理是指通过合理规划、组织和控制物流活动,确保物料和产品在生产过程中的高效流动。它的目标是…

【Java多线程学习7】Java线程池技术

线程池技术 一、什么是线程池 线程池顾名思义是管理一组线程的池子。当有任务要处理时,直接从线程池中获取线程来处理,处理完之后线程不会立即销毁,而是等待下一个任务。 二、为什么要使用线程池? 线程池的作用? 1、降低资源…

【大数据】Flink 详解(二):核心篇 Ⅲ

Flink 详解(二):核心篇 Ⅲ 29、Flink 通过什么实现可靠的容错机制? Flink 使用 轻量级分布式快照,设计检查点(checkpoint)实现可靠容错。 30、什么是 Checkpoin 检查点? Checkpoint …

四化智造MES(API)与金蝶云星空对接集成派工作业打通生产订单新增

四化智造MES(API)与金蝶云星空对接集成派工作业打通生产订单新增 对接系统:四化智造MES(API) MES建立统一平台上通过物料防错防错、流程防错、生产统计、异常处理、信息采集和全流程追溯等精益生产和精细化管理&#x…

mybatis如何生成和执行动态sql

文章目录 1. 相关代码2. SQL 语句解析全流程2.1 涉及到的重要类2.2 解析标签2.2.1 \<include>2.2.2 \<selectKey>2.2.3 处理 SQL 语句 3. 获取真正执行的sql 1. 相关代码 package com.boge.mapper;import com.boge.pojo.User;import java.util.List;public interf…

情感资源日记Resource Diary

什么是 Resource Diary &#xff1f; Resource Diary 是自托管的情感资源日记&#xff0c;允许用户跟踪特定任务的情感反应以进行趋势识别。 按照作者的设计&#xff0c;对每个任务&#xff0c;通过 /- 号来表达任务完成后的情绪&#xff0c;等级从 0 到 5 &#xff0c;其中&am…

大语言模型之三 InstructGPT训练过程

大语言模型 GPT历史文章中简介的大语言模型的的发展史&#xff0c;并且简要介绍了大语言模型的训练过程&#xff0c;本篇文章详细阐述训练的细节和相关的算法。 2020年后全球互联网大厂、AI创业公司研发了不少AI超大模型&#xff08;百亿甚至千亿参数&#xff09;&#xff0c;…

2023年中国政务云行业发展概况及发展趋势分析:政务云由基础设施建设向云服务运营转变[图]

政务云是指运用云计算技术&#xff0c;统筹利用已有的机房、计算、存储、网络、安全、应用支撑、信息资源等&#xff0c;发挥云计算虚拟化、高可靠性、高通用性、高可扩展性及快速、按需、弹性服务等特征&#xff0c;为政府行业提供基础设施、支撑软件、应用系统、信息资源、运…

QGIS3.28的二次开发六:VS不借助QT插件创建UI界面

上一篇博客我们说了在VS中如何使用QT插件来创建UI界面&#xff0c;但是我们二次开发QGIS的第一篇博客就说了&#xff0c;最好使用OSGeo4W中自动下载的QT进行QGIS二次开发&#xff0c;这样兼容性是最好的&#xff0c;那么该如何在VS中不使用外部安装的QT以及QT的VS插件情况下进行…

使用 LangChain 构建 LLM 应用详细教程(附python代码演练)

介绍 欢迎来到语言处理的未来&#xff01;在一个语言是连接人与技术的桥梁的世界中&#xff0c;自然语言处理&#xff08;NLP&#xff09;的进步为我们带来了令人难以置信的机会。其中一个重要的进步是革命性的语言模型&#xff0c;即大型语言模型&#xff08;LLM&#xff09;&…

图解java.util.concurrent并发包源码系列——Condition条件等待队列深入详解

图解java.util.concurrent并发包源码系列——Condition条件等待队列深入详解 Condition的作用Condition的原理Condition源码Condition的定义和Condition对象的获取await方法addConditionWaiter方法unlinkCancelledWaiters方法 fullyRelease方法isOnSyncQueue方法checkInterrupt…

一台阿里云服务器怎么部署多个网站?以CentOS系统为例

本文阿里云百科介绍如何在CentOS 7系统的ECS实例上使用Nginx搭建多个Web站点。本教程适用于熟悉Linux操作系统&#xff0c;希望合理利用资源、统一管理站点以提高运维效率的用户。比如&#xff0c;您可以在一台云服务器上配置多个不同分类的博客平台或者搭建多个Web站点实现复杂…

为新手和非技术人员提供扩展Web网站提供一个升级指南

本指南总结了扩展的基本原则&#xff0c;从一台服务器扩展到能够服务数百万用户的Web应用程序。它面向在技术领域工作的新手和非开发人员。因此&#xff0c;如果您刚刚部署了您的多云平台VPN设置&#xff0c;那么本文并不适合您。 话不多说&#xff0c;那就让我们开始吧&#x…