原理+代码:Diffusion Model 直观理解

news2024/11/20 4:47:18

原理部分

直观理解

数学形式

网络如何训练?

训练一个怎样的网络?

代码部分

Network helpers

Positional embeddings

ResNet/ConvNeXT block

Attention module

Conditional U-Net

定义前向扩散过程

用一个实例来说明前向加噪过程

损失函数

定义数据集 PyTorch Dataset 和 DataLoader

采样

模型训练

相关文章

Stable Diffusion 代码解读

初始化图像:扩散模型中被忽视的控制机制

本文基于The Annotated Diffusion Model

原理部分

扩散模型:和其他生成模型一样,实现从噪声(采样自简单的分布)到数据样本的转换。

扩散模型的两个步骤:

  • 一个固定的(预先定义好的)前向扩散过程 � :逐步向图片增加噪声直到最终得到一张纯噪声。
  • 一个学习得到的去噪声过程 �� (a learned reverse denoising diffusion process): 训练一个神经网络去逐渐地从一张纯噪声中消除噪声,直到得到一张真正的图片。

前向与后向的步数由下标 � 定义,并且有预先定义好的总步数 � (DDPM原文中为1000)。

�=0 时为从数据集中采样得到的一张真实图片, �=� 时近似为一张纯粹的噪声。

直观理解

为了看懂扩散模型查了很多资料,但是要么就是大量的数学公式,一行行公式推完了还是不知道它想干啥。要么就是高视角,上来就和能量模型,VAE放一块儿对比说共同点和不同点,看完还是云里雾里。然而事实上下面几句话就能把扩散模型说明白了

扩散模型的目的是什么?

学习从纯噪声生成图片的方法

扩散模型是怎么做的?

训练一个U-Net,接受一系列加了噪声的图片,学习预测所加的噪声

前向过程在干啥?

逐步向真实图片添加噪声最终得到一个纯噪声

对于训练集中的每张图片,都能生成一系列的噪声程度不同的加噪图片

在训练时,这些 【不同程度的噪声图片 + 生成它们所用的噪声】 是实际的训练样本

反向过程在干啥?

训练好模型后,采样、生成图片

数学形式

我们需要的,是一个可供神经网络学习的损失函数! (a tractable loss function which our neural network needs to optimize)

�(�0) 是真实数据分布(也就是真实的大量图片),从这个分布中采样即可得到一张真实图片。( �0∼�(�0) )

我们定义前向扩散过程为 �(��|��−1) ,也就是每一个step向图片添加噪声的过程。

我们预先定义好一系列的参数(schedule): 0<�1<�2<...<��<1

那么下一步的分布就是 �(��|��−1)=�(��;1−����−1,��I)

下一步的(加了噪声之后的图片所属于的)分布的均值,是基于上一张图片的均值轻微偏移后得到的

下一步的(加了噪声之后的图片所属于的)分布的方差,是预先定义好的

这里可以应用一个高斯分布的性质:

  • 从这样一个分布 �(��|��−1)中采样得到样本 ��
  • 先从正太分布 �(0,1) 中采样一个 � ,并计算 1−����−1+��� 的结果

这两个过程是等效的。

注意: 不同 � 的 �� 是不同的,类似于训练中学习率衰减的预先定义方法, � 的衰减方法也多种多样(Linear, cosine等等)

因此,如果我们合理的设置schedule,那么从 �0 到 ��就完成了从真实图片到纯噪声的转换过程。

那么问题的核心就是如何得到 �(��|��−1) 的逆过程 �(��−1|��) 。而真实的 �(��−1|��) 是不可能求出来的,所以我们使用神经网络去拟合这一分布。我们使用一个具有参数 � 的神经网络去计算 ��(��−1|��) 。

总结一下:我们需要一个神经网络去表示反向过程的条件概率分布。我们假设反向的条件概率分布也是高斯分布,且高斯分布实际上只有两个参数:均值 �� 和方差 Σ� ,那么神经网络需要计算的实际上是

��(��−1|��)=�(��−1;��(��,�),Σ�(��,�))

这里的 � 是加噪(或者去噪)的步数,同时也是噪声的程度。

注意,神经网络需要去学习均值和方差,但是,DDPM的作者决定保持方差固定,并且让神经网络只去学习均值。DDPM原文中的描述是:

First, we set Σ�(��,�)=��2�  to untrained time dependent constants. Experimentally, both ��2=��  and ��2=�~� ​ (see paper) had similar results.

(在后面的improved diffusion models 中,方差也是由学习得到的。)

总之,我们定义这么一个过程:给一张图片逐步加噪声直到变成纯粹的噪声,然后对噪声进行去噪得到真实的图片。所谓的扩散模型就是让神经网络学习这个去除噪声的方法。
所谓的加噪声,就是基于稍微干净的图片计算一个(多维)高斯分布(每个像素点都有一个高斯分布,且均值就是这个像素点的值,方差是预先定义的 � ),然后从这个多维分布中抽样一个数据出来,这个数据就是加噪之后的结果。显然,如果方差非常非常小,那么每个抽样得到的像素点就和原本的像素点的值非常接近,也就是加了一个非常非常小的噪声。如果方差比较大,那么抽样结果就会和原本的结果差距较大。
去噪声也是同理,我们基于稍微噪声的图片 �� 计算一个条件分布,我们希望从这个分布中抽样得到的是相比于 �� 更加接近真实图片的稍微干净的图片。我们假设这样的条件分布是存在的,并且也是个高斯分布,那么我们只需要知道均值和方差就可以了。问题是这个均值和方差是无法直接计算的,所以用神经网络去学习近似这样一个高斯分布。

网络如何训练?

(详细数学推导可以参考其他文章,这里不贴了)

我们最终要训练的实际上是一个噪声预测器

神经网络输出的噪声是 ��(��,�) ,而真实的噪声 � 取自于正态分布 �(0,1)

要优化的距离就是 ||�−��(��,�)||2=||�−��(�¯��0+(1−�¯�)�,�)||2

也就是说:

  • 我们接受一个随机的样本 �0 ,这一样本来自一个未知的,且可能非常复杂的真实数据分布 �(�0)
  • 我们随机从 1 到 � 采样一个 �
  • 我们从高斯分布采样一些噪声并且施加在输入上
  • 网络从被影响过后的噪声图片学习其被施加了的噪声

训练一个怎样的网络?

网络的输入是一张有噪声的图片,输出是预测的噪声(与图片相同的shape)

U-Net:输入与输出有相同的shape

代码部分

Network helpers

定义几个基本的Function,

def exists(x):
    return x is not None

# 有val时返回val,val为None时返回d
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# 残差模块,将输入加到输出上
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

# 上采样(反卷积)
def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

# 下采样
def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)

Positional embeddings

类似于Transformer的positional embedding,为了让网络知道当前处理的是一系列去噪过程中的哪一个step,我们需要将步数 � 也编码并传入网络之中。DDPM采用正弦位置编码(Sinusoidal Positional Embeddings)

这一方法的输入是shape为 (batch_size, 1) 的 tensor,也就是batch中每一个sample所处的 � ,并将这个tensor转换为shape为 (batch_size, dim) 的 tensor。这个tensor会被加到每一个残差模块中。

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

总之就是将 � 编码为embedding,和原本的输入一起进入网络,让网络“知道”当前的输入属于哪个step。

ResNet/ConvNeXT block

接下来就是正式的U-Net的实现。ConvNeXT 似乎可以(并不一定)取得比ResNet更好的效果。

class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ResnetBlock(nn.Module):
    """Deep Residual Learning for Image Recognition"""
    
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, "b c -> b c 1 1") + h

        h = self.block2(h)
        return h + self.res_conv(x)
    
class ConvNextBlock(nn.Module):
    """A ConvNet for the 2020s"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        Get an email address at self.net. It's ad-free, reliable email that's based on your own name | self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = Get an email address at self.net. It's ad-free, reliable email that's based on your own name | self.net(h)
        return h + self.res_conv(x)

Attention module

两种attention模块,一个是常规的 multi-head self-attention,一个是 linear attention variant

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

Group normalization

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

Conditional U-Net

现在,我们已经定义了所有的组件,接下来就是定义完整的网络了。

网络需要的输入:噪声图片的batch+这些图片各自的 �

输出:预测每个图片上所添加的噪声

写得更正式一点的话:

  • Input:a batch of noisy images of shape ( batch_size, num_channels, h, w ) and a batch of steps of shape ( batch_size, 1 )
  • output: a tensor of shape ( batch_size, num_channels, h, w )

具体的网络结构:

  1. 首先,输入通过一个卷积层,同时计算step � 所对应得embedding
  2. 通过一系列的下采样stage,每个stage都包含:2个ResNet/ConvNeXT blocks + groupnorm + attention + residual connection + downsample operation
  3. 在网络中间,应用一个带attention的ResNet或者ConvNeXT
  4. 通过一系列的上采样stage,每个stage都包含:2个ResNet/ConvNeXT blocks + groupnorm + attention + residual connection + upsample operation
  5. 最终,通过一个ResNet/ConvNeXT blocl和一个卷积层。
class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time):
        x = self.init_conv(x)
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        h = []

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsample
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)

定义前向扩散过程

DDPM中使用linear schedule定义 � 。后续的研究指出使用cosine schedule可能会有更好的效果。

接下来是一些简单的对于 � schedule 的定义,从当中选一个使用即可。

def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

DDPM中用的是第二种linear,我们首先也尝试使用这一种。

将 � 设置为200,并将每个 � 下的各种参数提前计算好

timesteps = 200

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas 
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

用一个实例来说明前向加噪过程

from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image

这里对图片先进行一些简单的处理

这里定义一个transform,接受一张图片( [ 0, 255] ),输出一个tensor( [ -1, 1] )

from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

image_size = 128
transform = Compose([
    Resize(image_size),
    CenterCrop(image_size),
    ToTensor(), # turn into Numpy array of shape HWC, divide by 255
    Lambda(lambda t: (t * 2) - 1),
])

x_start = transform(image).unsqueeze(0)
x_start.shape  # 输出的结果是 torch.Size([1, 3, 128, 128])

同时也定义出逆变换

import numpy as np

reverse_transform = Compose([
     Lambda(lambda t: (t + 1) / 2),
     Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
     Lambda(lambda t: t * 255.),
     Lambda(lambda t: t.numpy().astype(np.uint8)),
     ToPILImage(),
])

验证一下

reverse_transform(x_start.squeeze())

准备齐全,接下来就可以定义正向扩散过程了

# forward diffusion (using the nice property)
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

def get_noisy_image(x_start, t):
  # add noise
  x_noisy = q_sample(x_start, t=t)
  # turn back into PIL image
  noisy_image = reverse_transform(x_noisy.squeeze())
  return noisy_image

定义采样多个 � 并且展示图片的方法

import matplotlib.pyplot as plt

# use seed for reproducability
torch.manual_seed(0)

# source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [image] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

跑一下看看效果

plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])

损失函数

注意理解这里的顺序

  1. 先采样噪声
  2. 用这个噪声去加噪图片
  3. 根据加噪了的图片去预测第一步中采样的噪声
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    # 先采样噪声
    if noise is None:
        noise = torch.randn_like(x_start)
    
    # 用采样得到的噪声去加噪图片
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)
    
    # 根据加噪了的图片去预测采样的噪声
    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss

定义数据集 PyTorch Dataset 和 DataLoader

from datasets import load_dataset

# load dataset from the hub
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 128


from torchvision import transforms
from torch.utils.data import DataLoader

transform = Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1)
])

def transforms(examples):
   examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
   del examples["image"]

   return examples

transformed_dataset = dataset.with_transform(transforms).remove_columns("label")
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)
batch = next(iter(dataloader))
print(batch.keys())

总之,我们使用现有的数据集构造了一个简单的 DataLoader,每个batch由128张 normalize 过的 image 组成。

采样

采样过程发生在反向去噪时。对于一张纯噪声,扩散模型一步步地去除噪声最终得到真实图片,采样事实上就是定义的去除噪声这一行为。 观察上图中第四行, �−1 步的图片是由 � 步的图片减去一个噪声得到的,只不过这个噪声是由 � 网络拟合出来,并且 rescale 过的而已。 这里要注意第四行式子的最后一项,采样时每一步也都会加上一个从正态分布采样的纯噪声。

理想情况下,最终我们会得到一张看起来像是从真实数据分布中采样得到的图片。

我们将上述过程写成代码,这里的代码相比于原论文的代码略有简化但是效果接近。

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # Equation 11 in the paper
    # Use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2 line 4:
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

# Algorithm 2 (including returning all images)
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

模型训练

 from pathlib import Path

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000

接下来实例化模型

from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)

开始训练!

from torchvision.utils import save_image

epochs = 5

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad()

      batch_size = batch["pixel_values"].shape[0]
      batch = batch["pixel_values"].to(device)

      # Algorithm 1 line 3: sample t uniformally for every example in the batch
      t = torch.randint(0, timesteps, (batch_size,), device=device).long()

      loss = p_losses(model, batch, t, loss_type="huber")

      if step % 100 == 0:
        print("Loss:", loss.item())

      loss.backward()
      optimizer.step()

      # save generated images
      if step != 0 and step % save_and_sample_every == 0:
        milestone = step // save_and_sample_every
        batches = num_to_groups(4, batch_size)
        all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
        all_images = torch.cat(all_images_list, dim=0)
        all_images = (all_images + 1) * 0.5
        save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)

Inference

# sample 64 images
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)

# show a random one
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

或者也可以生成动图

import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()

动图封面

编辑于 2024-01-12 09:03・IP 属地日本

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1514292.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于SSH框架的电子商城的设计

目录 摘要 2 Abstract 3 第一章 前言 4 1.1 课题研究意义 4 1.2 国外研究现状 4 方案一&#xff1a; 4 方案二&#xff1a; 4 方案三&#xff1a; 5 1.3 课题研究内容 5 &#xff08;1&#xff09;商品浏览模块 5 &#xff08;2&#xff09;订单管理模块 5 &#xff08;3&…

基于 llvm 3.4 的C++重构工具

还未测试&#xff0c;存个档&#xff0c;未完待续 1,源码 Makefile LLVM_CONFIG?llvm-configifndef VERBOSE QUIET: endifSRC_DIR?$(PWD) LDFLAGS$(shell $(LLVM_CONFIG) --ldflags) COMMON_FLAGS-Wall -Wextra CXXFLAGS$(COMMON_FLAGS) $(shell $(LLVM_CONFIG) --cxxflags…

【机器学习300问】36、什么是集成学习?

一、什么是集成学习&#xff1f; &#xff08;1&#xff09;它的出现是为了解决什么问题&#xff1f; 提高准确性&#xff1a;单个模型可能对某些数据敏感或者有概念偏见&#xff0c;而集成多个模型可以提高预测的准确性。让模型变稳定&#xff1a;一些模型&#xff0c;如决策…

【JavaScript】数据类型转换 ① ( 隐式转换 和 显式转换 | 常用的 数据类型转换 | 转为 字符串类型 方法 )

文章目录 一、 JavaScript 数据类型转换1、数据类型转换2、隐式转换 和 显式转换3、常用的 数据类型转换4、转为 字符串类型 方法 一、 JavaScript 数据类型转换 1、数据类型转换 在 网页端 使用 HTML 表单 和 浏览器输入框 prompt 函数 , 接收的数据 是 字符串类型 变量 , 该…

[密码学]OpenSSL实践篇

背景 最近在写Android abl阶段fastboot工具&#xff0c;需要我在Android代码中实现一些鉴权加解密相关的fastboot命令&#xff0c;里面用到了OpenSSL。我们先来实践一下OpenSSL在Linux系统中的指令。 OpenSSL官方网站&#xff1a;OpenSSL 中文手册 | OpenSSL 中文网 1. 查看…

【变量提升】关于JavaScript变量提升的理解,它导致了什么问题?

&#x1f601; 作者简介&#xff1a;一名大四的学生&#xff0c;致力学习前端开发技术 ⭐️个人主页&#xff1a;夜宵饽饽的主页 ❔ 系列专栏&#xff1a;JavaScript小贴士 &#x1f450;学习格言&#xff1a;成功不是终点&#xff0c;失败也并非末日&#xff0c;最重要的是继续…

带你摸透C语言相关内存函数

c语言中的小小白-CSDN博客c语言中的小小白关注算法,c,c语言,贪心算法,链表,mysql,动态规划,后端,线性回归,数据结构,排序算法领域.https://blog.csdn.net/bhbcdxb123?spm1001.2014.3001.5343 给大家分享一句我很喜欢我话&#xff1a; 知不足而奋进&#xff0c;望远山而前行&am…

vue2中如何实现添加一个空标签的效果,</>

前言&#xff1a; 众所周知&#xff0c;vue3突破了每一个vue文件中只能有一个根标签的限制&#xff0c;但是我们还有很多项目都是vue2的项目&#xff0c;如果让vue2中实现一个类似</>的效果呢&#xff0c;像react的16.2.0的版本之后&#xff0c;可以直接用<></&…

电脑音频显示红叉怎么办?这里提供四种方法

前言 如果你在系统托盘中看到音量图标上的红色X,则表示你无法使用音频设备。即使音频设备未被禁用,当你运行音频设备疑难解答时,仍然会看到此错误。 你的电脑将显示已安装高清音频设备,但当你将鼠标悬停在图标上时,它将显示未安装音频输出设备。这是一个非常奇怪的问题,…

C语言 指针(2)

文章目录 前言 一、数组名的理解 二、指针访问数组 三、一维数组传参的本质 四、冒泡排序 五、二级指针 六、指针数组 七、指针数组模拟二维数组 总结 前言 我们今天继续来了解指针的内容&#xff0c;让大家更加细致的理解到数组的含义 一、数组名的理解 之前我们在学习指针时…

王道OnlineJudge 14

题目 二叉树层次建树就是一层一层的建树&#xff0c;从左到右。随着纵向层次的深入&#xff0c;结点的数量变化规律为&#xff1a;1→2→4→8→16→32。 先画图&#xff0c;然后看图可闭眼写代码 右边为辅助队列&#xff0c;有多少个二叉树结点&#xff0c;就有多少个辅助队…

一个简单的Web UI自动化测试框架Java实现

&#x1f525; 交流讨论&#xff1a;欢迎加入我们一起学习&#xff01; &#x1f525; 资源分享&#xff1a;耗时200小时精选的「软件测试」资料包 &#x1f525; 教程推荐&#xff1a;火遍全网的《软件测试》教程 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1…

口才教育:如何提升沟通技巧与表达能力

口才教育&#xff1a;如何提升沟通技巧与表达能力 口才教育在现代社会中扮演着越来越重要的角色。拥有良好的沟通技巧和表达能力对于个人的职业发展、人际交往乃至生活质量都至关重要。因此&#xff0c;如何有效地提升口才能力成为了许多人关注的焦点。本文将探讨口才教育的重…

java-可变参数

可变参数是什么&#xff1f; 可变参数就是指传入的参数个数是可变的&#xff0c;不是固定的 为什么要可变参数&#xff1f; 当我们要传入大量的形参时&#xff0c;我们就可以用到可变参数了 定义格式 数据类型...变量名; 例如int ...a; 可变参数的细节&#xff1a; &…

Vue2(五):收集表单数据、过滤器、内置指令和自定义指令

一、回顾 总结Vue监视数据 1.Vue监视数据的原理&#xff1a; 1.vue会监视data中所有层次的数据。 2.如何监测对象中的数据?通过setter实现监视&#xff0c;且要在new Vue时就传入要监测的数据。(1&#xff09;.对象中后追加的属性&#xff0c;Vue默认不做响应式处理(2&#…

java拷贝数组

package com.mohuanan.exercise;public class Exercise {public static void main(String[] args) {int[] arr {1, 2, 3, 4, 5, 6, 7, 8, 8}; //格式化快捷键 CTRL 加 Alt 加 L键// F1截图 F3贴图//调用 copyOfRangeint[] ints copyOfRange(arr, 3, 7);for (int i 0; i &l…

学习网络编程No.13【网络层IP协议理解】

引言&#xff1a; 北京时间&#xff1a;2024/3/5/8:38&#xff0c;早六加早八又是生不如死的一天&#xff0c;不过好在喝两口热水提口气手指还能跳动。当然起关键性作用的还是思维跟上了课程脑袋较为清晰&#xff0c;假如是听学校老师在哪里磨过来磨过去&#xff0c;那我倒头就…

三、HarmonyOS 应用开发入门之运行Hello World

目录 1、课程对象 1.1、有移动端开发经验 1.2、无移动端开发经验 1.3、对 HarmonyOS 感兴趣 2、DevEco Studio 的使用 2.1、DevEco Studio 的关键特性 智能代码编辑 低代码开发 多段双向实时预览 多端模拟仿真 2.2、安装配置 DevEco Studio 2.2.1、官网开发工具下载地…

Vue-Vben-Admin:中大型项目后台解决方案及如何实现页面反向传值

Vue-Vben-Admin&#xff1a;中大型项目后台解决方案及如何实现页面反向传值 摘要&#xff1a; Vue-Vben-Admin是一个基于Vue3.0、Vite、Ant-Design-Vue和TypeScript的开源项目&#xff0c;旨在为开发中大型项目提供一站式的解决方案。它涵盖了组件封装、实用工具、钩子函数、动…

Arduino ESP8266 SSD1306 硬件I2C+LittleFS存储GBK字库实现中文显示

Arduino ESP8266 SSD1306 硬件I2C+LittleFS存储GBK字库实现中文显示 📍相关篇《Arduino esp8266 软件I2C SSD1306 +LittleFS存储GBK字库实现中文显示》 🌼显示效果: ✨将部分函数重构,和上面相关篇的软件I2C通讯相关接口函数移植过来,除了汉字显示采用自己写的API函数外…