扩散模型基础

news2025/1/20 3:57:06

扩散模型发展至今日,早已成为各大机器学习顶会的香饽饽。本文简记扩散模型入门相关代码,主要参阅李忻玮、苏步升等人所编著的《扩散模型从原理到实战》

文章目录

    • 1. 简单去噪模型
      • 1.1 简单噪声可视化
      • 1.2 去噪模型
      • 1.3 小结
    • 2 扩散模型
      • 2.1 采样过程
      • 2.2 上科技
        • 2.2.1 升级模型表征模块
        • 2.2.2 升级加噪过程
        • 2.2.3 改变预测目标
      • 2.3 小结

1. 简单去噪模型

这一小节中,我们将尝试设计一个去噪模型。首先,我们将展示如何给图片加噪。然后,我们将训练一个模型,对加噪图片进行去噪。

1.1 简单噪声可视化

第一步,导入环境

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt

第二步,获取训练集样本。此处使用 MNIST 数据集

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())

train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True) # download=True for the first time
x, y = next(iter(train_dataloader))
print(f'Input shape: {x.shape}')
print(f'Label: {y}')

第三步,设计简单噪声函数

def corrupt(x, amount):
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)
    return x*(1-amount) + noise*amount

该函数以原始图像为输入,生成与该输入同样维度的随机噪声,并根据参数 amount 将噪声和原始图像混合。
最后,我们在 (0, 1) 之间采样 8 个 amount,看看不同加噪程度下的图片。
集合上面所有代码,如下:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt


def corrupt(x, amount):
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)
    return x*(1-amount) + noise*amount


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())

train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print(f'Input shape: {x.shape}')
print(f'Label: {y}')

fig, axs = plt.subplots(2, 1, figsize=(9, 4))
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')

amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys')

plt.savefig('visualize_corrupt.png', dpi=400)

结果如下:
请添加图片描述

1.2 去噪模型

所谓去噪模型,就是给模型加噪的图片,让模型直接预测真实图片。
我们搭建一个最简单的CV模型。

class BasicUNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
        ])
        self.act = nn.ReLU(inplace=True)
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))
            if i < 2:
                h.append(x)
                x = self.downscale(x)

        for i, l in enumerate(self.up_layers):
            if i > 0:
                x = self.upscale(x)
                x += h.pop()
            x = self.act(l(x))
        return x

初始化该模型,并查看参数:

net = BasicUNet()
print(sum([p.numel() for p in net.parameters()]))

输出为30w,可见该模型很小。
下面我们训练该模型:

batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

n_epochs = 3
net = BasicUNet()
net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []

for epoch in range(n_epochs):
    for x, y in train_dataloader:
        x = x.to(device)
        noise_amount = torch.rand(x.shape[0]).to(device)
        noisy_x = corrupt(x, noise_amount)
        pred = net(noisy_x)
        loss = loss_fn(pred, x)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss: 05f}')
plt.plot(losses)
plt.show()
plt.close()

模型训练完以后,我们跟上面加噪过程保持一致,分别设计 8 个不同程度损坏的照片,并让模型预测真实照片

x, y = next(iter(train_dataloader))
x = x[:8]


fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')

amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)


with torch.no_grad():
    preds = net(noised_x.to(device)).detach().cpu()

axs[2].set_title('Network prediction')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys')


axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')

本节完整代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt


def corrupt(x, amount):
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)
    return x*(1-amount) + noise*amount


class BasicUNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
        ])
        self.act = nn.ReLU(inplace=True)
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))
            if i < 2:
                h.append(x)
                x = self.downscale(x)

        for i, l in enumerate(self.up_layers):
            if i > 0:
                x = self.upscale(x)
                x += h.pop()
            x = self.act(l(x))
        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())

batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

n_epochs = 3
net = BasicUNet()
net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []

for epoch in range(n_epochs):
    for x, y in train_dataloader:
        x = x.to(device)
        noise_amount = torch.rand(x.shape[0]).to(device)
        noisy_x = corrupt(x, noise_amount)
        pred = net(noisy_x)
        loss = loss_fn(pred, x)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss: 05f}')
plt.plot(losses)
plt.show()
plt.close()


# simple check noise
x, y = next(iter(train_dataloader))
x = x[:8]


fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')

amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)


with torch.no_grad():
    preds = net(noised_x.to(device)).detach().cpu()

axs[2].set_title('Network prediction')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys')


axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')


plt.savefig('test_v0.png', dpi=400)

结果如下:
请添加图片描述

1.3 小结

本节中,我们可视化了简单加噪过程,并搭建了简单去噪模型。在经过简单训练后,我们的模型可以成功识别加噪程度较低的图片,令人欣慰。但如何从虚无(完全随机照片)中,生成一张可辨别的图片呢?或许我们可以将预测真实图片的过程设计为迭代过程,一步步去噪(上图从右往左),这就是扩散模型的核心。

2 扩散模型

2.1 采样过程

在得到 1.2 节去噪结果后,一个很自然的想法是,我们可以将去噪过程分成多步,每次预测结果和输入结果进行叠合,如此多步迭代后,期望能变成最左边清晰的图像。该多步去噪的过程叫做采样过程。

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt


def corrupt(x, amount):
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)
    return x*(1-amount) + noise*amount

class BasicUNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
        ])
        self.act = nn.ReLU(inplace=True)
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))
            if i < 2:
                h.append(x)
                x = self.downscale(x)

        for i, l in enumerate(self.up_layers):
            if i > 0:
                x = self.upscale(x)
                x += h.pop()
            x = self.act(l(x))
        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())



batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

n_epochs = 3
net = BasicUNet()
net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []

for epoch in range(n_epochs):
    for x, y in train_dataloader:
        x = x.to(device)
        noise_amount = torch.rand(x.shape[0]).to(device)
        noisy_x = corrupt(x, noise_amount)
        pred = net(noisy_x)
        loss = loss_fn(pred, x)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss: 05f}')
plt.plot(losses)
plt.show()
plt.close()


n_steps = 5
x = torch.rand(8, 1, 28, 28).to(device)
step_history = [x.detach().cpu()]
pred_output_history = []
for i in range(n_steps):
    with torch.no_grad():
        pred = net(x)
    pred_output_history.append(pred.detach().cpu())

    min_factor = 1/(n_steps - i)
    x = x*(1-min_factor)+pred*min_factor
    step_history.append(x.detach().cpu())
print(len(step_history))

fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)

axs[0, 0].set_title('x (model input)')
axs[0, 1].set_title('model prediction')
for i in range(n_steps):
    axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap='Greys')
    axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i], cmap='Greys')[0].clip(0, 1), cmap='Greys')

plt.savefig('test_real_v1.png', dpi=400)

结果如下:
请添加图片描述
左侧为每次迭代时,模型的输入。右侧为基于加噪数据预测真实数据的结果。
可以看到,虽然从虚无中,模型一步步迭代,形成了相对清晰的结果,但这些结果不像是数字。这说明我们的模型还有改进的空间。一个最简单的方法就是增加迭代的步数,期望能有所改善。
我们将采样步数增加到40(并微调噪声代码),采样过程更改为:

n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):
    noise_amount = torch.ones((x.shape[0],)).to(device) * (1-(i/n_steps))
    with torch.no_grad():
        pred = net(x)
    min_factor = 1/(n_steps - i)
    x = x*(1-min_factor) + pred*min_factor
fig, ax = plt.subplots(1, 1, figsize=(12,12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')

下图为迭代40次结果,在 64 个样例中,已经可以依稀看到相对清晰的结果。
请添加图片描述

2.2 上科技

在健美运动中,只靠纯饮食是很难长出超乎常人的大块肌肉的。对此,健美圈大佬会摄入大量睾酮等激素,刺激身体发育。这种走捷径的方法常被叫做“上科技”。此处不是想教各位健身知识,而是借喻 AI 模型设计中走捷径的调包大法。

例如,上述结果很不理想,怎么快速提高模型表现呢?
我们可以 import 进来现成的模型嘛,大家是怎么训练的,咱也照做,保证短期内快速提高。照这个思路,我们需要升级以下几处:

2.2.1 升级模型表征模块

前面说了,我们的 BasicUNet 只有 30w 参数,只是一个玩具模型。我们可以直接 import 一个成熟的业内常用的模型,例如,UNet2DModel

from diffusers import UNet2DModel

net = UNet2DModel(
    sample_size=28,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(32, 64, 64),
    down_block_types=(
        'DownBlock2D',
        "AttnDownBlock2D",
        "AttnDownBlock2D"
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D"
    )
)

总代码:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt


def corrupt(x, amount):
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)
    return x*(1-amount) + noise*amount


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())

a=1

batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

n_epochs = 3
net = UNet2DModel(
    sample_size=28,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(32, 64, 64),
    down_block_types=(
        'DownBlock2D',
        "AttnDownBlock2D",
        "AttnDownBlock2D"
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D"
    )
)

print(sum([p.numel() for p in net.parameters()]))
print(net)



net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []

for epoch in range(n_epochs):
    for x, y in train_dataloader:
        x = x.to(device)
        noise_amount = torch.rand(x.shape[0]).to(device)
        noisy_x = corrupt(x, noise_amount)
        pred = net(sample=noisy_x, timestep=0).sample
        loss = loss_fn(pred, x)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss: 05f}')
plt.plot(losses)
plt.show()
plt.close()


n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):
    noise_amount = torch.ones((x.shape[0],)).to(device) * (1-(i/n_steps))
    with torch.no_grad():
        pred = net(x)
    min_factor = 1/(n_steps - i)
    x = x*(1-min_factor) + pred*min_factor
fig, ax = plt.subplots(1, 1, figsize=(12,12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')

plt.savefig('test_v2.png', dpi=400)

效果如下:

请添加图片描述

2.2.2 升级加噪过程

其实有很多现成的加噪函数可以直接调。可以实现,前期加快点,后期加慢点的操作。

我们可以将加噪系数进行可视化:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt


noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r'${\sqrt{\bar{\alpha}_t}}$')
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu())**0.5, label=r'${1-\sqrt{\bar{\alpha}_t}}$')
plt.legend(fontsize="x-large")
plt.savefig('scheduler.png', dpi=400)

请添加图片描述

可视化加噪后的图片:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt


noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
x, y = next(iter(train_dataloader))
x = x.to(device)[:8]
x = x*2. - 1.
print(f'X shape {x.shape}')
fig, axs = plt.subplots(3, 1, figsize=(16, 10))

axs[0].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0], cmap='Greys')
axs[0].set_title('clean X')

timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(x)
noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
print(f'Noisy X shape {noisy_x.shape}')

axs[1].imshow(torchvision.utils.make_grid(noisy_x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')
axs[1].set_title('Noisy X (clipped to (-1, 1))')

axs[2].imshow(torchvision.utils.make_grid(noisy_x.detach().cpu(), nrow=8)[0], cmap='Greys')
axs[2].set_title('Noisy X')

plt.savefig('visualize_noise.png', dpi=400)

请添加图片描述

需要注意的是,很多时候的随机噪声是以 0 为期望,1 为方差的高斯分布。所以需要将原先的灰度范围(0~1之间)缩放到误差所在区间,即我们的 clean x,需要盖上一层灰(上图最下层最左侧)

2.2.3 改变预测目标

之前,所有的模型都是 去噪 模型。这些模型接收一个加噪的图片,输出真实图片。但实践表明,接收一个加噪图片,输出噪声,这样的预测目标有助于提升模型表现。
之前的推理过程:
加噪图片->模型->真实图片
现在是:
加噪图片->模型->噪声,
加噪图片-噪声->真实图片
相当于多了一步。至于为什么要多着一步,问就是这样精度高,大家都这样做。
下面将上述三条改进措施进行集成:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt


def corrupt(x, amount):
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)
    return x*(1-amount) + noise*amount


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
dataset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())

a=1

batch_size = 100
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

n_epochs = 3
net = UNet2DModel(
    sample_size=28,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(32, 64, 64),
    down_block_types=(
        'DownBlock2D',
        "AttnDownBlock2D",
        "AttnDownBlock2D"
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D"
    )
)

print(sum([p.numel() for p in net.parameters()]))
print(net)



net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []


noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

for epoch in range(n_epochs):
    for x, y in train_dataloader:
        x = x.to(device)
        x = x * 2. - 1.

        timesteps = torch.linspace(0, 999, batch_size).long().to(device)
        noise = torch.randn_like(x).to(device)
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

        # noise_amount = torch.rand(x.shape[0]).to(device)
        # noisy_x = corrupt(x, noise_amount)
        pred = net(sample=noisy_x, timestep=timesteps).sample
        loss = loss_fn(pred, noise)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss: 05f}')
plt.plot(losses)
plt.show()
plt.close()

torch.save(net, f='model.pt')



# net = torch.load('model.pt')
#
# x, y = next(iter(train_dataloader))
# x = x[:100].to(device)
# x = x * 2. - 1.
#
# timesteps = torch.linspace(0, 999, 100).long().to(device)
# noise = torch.randn_like(x).to(device)
# noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
# with torch.no_grad():
#     pred_noise = net(noisy_x, timesteps).sample
#     real_pred = noisy_x - pred_noise
#     real_pred = (real_pred + 1)/2
#
# fig, ax = plt.subplots(1, 1, figsize=(12,12))
# ax.imshow(torchvision.utils.make_grid(real_pred.detach().cpu(), nrow=10)[0].clip(0, 1), cmap='Greys')
#
# plt.savefig('test_final_real.png', dpi=400)


net = torch.load('model.pt')
x = torch.rand(100, 1, 28, 28).to(device)
x = x * 2. - 1.

timesteps = torch.linspace(0, 999, 100).long().to(device)
noise = torch.randn_like(x).to(device)
noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
with torch.no_grad():
    pred_noise = net(noisy_x, timesteps).sample
    real_pred = noisy_x - pred_noise
    real_pred = (real_pred + 1)/2

fig, ax = plt.subplots(1, 1, figsize=(12,12))
ax.imshow(torchvision.utils.make_grid(real_pred.detach().cpu(), nrow=10)[0].clip(0, 1), cmap='Greys')

plt.savefig('test_final_random.png', dpi=400)

此处采样过程设计了两种。
第一种对应真实的情况:

  • sample 100 张图片
  • 模拟真实的训练环境,先对输入进行伸缩,映射到期望为 0 ,方差为 1 的区域
  • 对这些图片加噪,噪声比例随序列号增加
  • 使用模型预测噪声,并将加噪后的图片减去该噪声
  • 将照片映射回原来的 (0, 1) 区间

结果如下:
请添加图片描述

可以看到,在前排,加噪比例较低的情况下,模型能一步作出很好的预测。但后面图片加噪比例变高以后,模型预测质量变差,这凸显了迭代的意义。

此外,模型中的 timesteps 决定了模型在不同阶段对噪声的预测置信度。我们在设计模型时希望模型知晓自己处于迭代的什么阶段。刚开始时,模型更多的,可能只是复原一些背景噪声,随后直接预测具有语义信息的数字,最后对数字进行描边,逐渐精细化。

为验证这一点,我们可以模型同样随机的噪声,并告知模型所处的 timesteps,初始随机噪声减去模型预测噪声,即可看到模型真正想还原的语义:
请添加图片描述

可以看到:

  • timesteps 接近去噪初始阶段的右下角只是一些背景信息
  • 中间位置含有部分语义信息
  • timesteps 接近去噪最后阶段的左上角,模型预测显然集中于中心,进行精修,而不是右下角的普遍噪声。

2.3 小结

本节中,我们介绍了常用扩散模型相比简单去噪模型的主要改进之处。其中,迭代采样是最核心的思想,通过迭代采样,我们能够从虚无的噪声背景中逐渐还原出具有语义信息的图片。随后,我们通过调包,对去噪模型进行了全方位升级(又叫上科技)。最后,我们对上述改进之处进行了综合,并通过两个案例向大家介绍扩散模型在不同加噪阶段的还原能力及其蕴含的语义信息。

这篇博客写得囫囵吞枣,只是希望大家能快速上手,感受扩散模型的魅力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1348062.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

线性代数基础知识

计算机视觉一些算法中常会用到线性代数的一些知识&#xff0c;为了便于理解和快速回忆&#xff0c;博主这边对常用的一些知识点做下整理&#xff0c;主要来源于如下这本书籍。 1. 矩阵不仅仅是数字排列而已&#xff0c;不然也不会有那么大精力研究它。其可以表示一种映射 关于…

《PCI Express体系结构导读》随记 —— 第I篇 第2章 PCI总线的桥与配置(1)

前言中曾提到&#xff1a;本章重点介绍PCI桥。 在PCI体系结构中含有两类桥&#xff1a;一类是HOST主桥&#xff1b;另一类是PCI桥。在每一个PCI设备中&#xff08;包括PCI桥&#xff09;&#xff0c;都含有一个配置空间。这个配置空间由HOST主桥管理&#xff0c;而PCI桥可以转…

CycleGAN 是如何工作的?

一、说明 CycleGAN即循环对抗网络&#xff0c;是图像翻译成图像的模型&#xff1b;是Pix2Pix模型的扩展&#xff0c;区别在于&#xff0c;Pix2Pix模型需要输入图像和目标图像成对给出训练&#xff0c;CycleGAN则不需要&#xff0c;例如&#xff1a;从 SAR 生成 RGB 图像、从 RG…

使用ASP.NET MiniAPI 调试未匹配请求路径

本文将介绍如何在使用ASP.NET MiniAPI时调试未匹配到的请求路径。我们将详细讨论使用MapFallback方法、中间件等工具来解决此类问题。 1. 引言 ASP.NET MiniAPI是一个轻量级的Web API框架&#xff0c;它可以让我们快速地构建和部署RESTful服务。然而&#xff0c;在开发过程中如…

S7-1200 PLC回原方式详细解读(SCL代码)

S7-1200PLC脉冲轴位置控制功能块的介绍,可以查看下面链接文章: https://rxxw-control.blog.csdn.net/article/details/135299302https://rxxw-control.blog.csdn.net/article/details/135299302脉冲轴工艺对象组态设置介绍 https://rxxw-control.blog.csdn.net/article/det…

算法(3)——二分查找

一、什么是二分查找 二分查找也称折半查找&#xff0c;是在一组有序(升序/降序)的数据中查找一个元素&#xff0c;它是一种效率较高的查找方法。 二、二分查找的原理 1、查找的目标数据元素必须是有序的。没有顺序的数据&#xff0c;二分法就失去意义。 2、数据元素通常是数值…

推荐系统中 排序策略 CTR 动态加权平均法

CTR&#xff08;Click-Through Rate&#xff09;动态加权平均法是一种用于计算广告点击率的方法&#xff0c;其中每个点击率被赋予一个权重&#xff0c;这个权重可以随着时间、事件或其他因素而动态调整。这种方法旨在更灵活地反映广告点击率的变化&#xff0c;使得最近的数据更…

HTML与CSS

目录 1、HTML简介 2、CSS简介 2.1选择器 2.1.1标签选择器 2.1.2类选择器 2.1.3层级选择器(后代选择器) 2.1.4id选择器 2.1.5组选择器 2.1.6伪类选择器 2.2样式属性 2.2.1布局常用样式属性 2.2.2文本常用样式属性 1、HTML简介 超文本标记语言HTML是一种标记语言&…

【GoLang】Go语言几种标准库介绍(三)

文章目录 前言几种库debug 库 (各种调试文件格式访问及调试功能)相关的包和工具&#xff1a;示例 encoding (常见算法如 JSON、XML、Base64 等)常用的子包和其主要功能&#xff1a;示例 flag(命令行解析)关键概念&#xff1a;示例示例执行 总结专栏集锦写在最后 前言 上一篇&a…

Leetcode 剑指 Offer II 059. 数据流中的第 K 大元素

题目难度: 简单 原题链接 今天继续更新 Leetcode 的剑指 Offer&#xff08;专项突击版&#xff09;系列, 大家在公众号 算法精选 里回复 剑指offer2 就能看到该系列当前连载的所有文章了, 记得关注哦~ 题目描述 设计一个找到数据流中第 k 大元素的类&#xff08;class&#xf…

使用keepalived时虚拟IP漂移注意事项

什么是Keepalived服务 keepalived是一个开源的软件项目&#xff0c;用于实现高可用性&#xff08;HA&#xff09;的网络服务器负载均衡和故障转移。它允许将多台服务器组合在一起&#xff0c;形成一个虚拟服务器集群&#xff0c;实现负载均衡和故障转移。 keepalived的核心功…

【力扣100】46.全排列

添加链接描述 class Solution:def permute(self, nums: List[int]) -> List[List[int]]:# 思路是使用回溯if not nums:return []def dfs(path,depth,visited,res):# 出递归的条件是当当前的深度已经和nums的长度一样了&#xff0c;把path加入数组&#xff0c;然后出递归if …

echarts 二分图布局_力向导图_关系图

Echarts 常用各类图表模板配置 注意&#xff1a; 这里主要就是基于各类图表&#xff0c;更多的使用 Echarts 的各类配置项&#xff1b; 以下代码都可以复制到 Echarts 官网&#xff0c;直接预览&#xff1b; 图标模板目录 Echarts 常用各类图表模板配置一、力向导图(二分图布局…

摩尔线程S80对于软件的支持

摩尔线程对软件的支持 时间&#xff1a;2024年1月1日 显卡型号&#xff1a;MTT S80 主板型号&#xff1a;七彩虹 igame z590 火神 V20 CPU&#xff1a; intel core i5 10400f 内存&#xff1a; 海盗船3600 16*2 存储&#xff1a; 致态1Tb nvme 显卡的驱动是最新的。 游戏 S…

从 MySQL 的事务 到 锁机制 再到 MVCC

其他系列文章导航 Java基础合集数据结构与算法合集 设计模式合集 多线程合集 分布式合集 ES合集 文章目录 其他系列文章导航 文章目录 前言 一、事务 1.1 含义 1.2 ACID 二、锁机制 2.1 锁分类 2.2 隔离级别 三、MVCC 3.1 介绍 3.2 隔离级别 3.3 原理 四、总结 前…

关于Python里xlwings库对Excel表格的操作(二十五)

这篇小笔记主要记录如何【如何使用xlwings库的“Chart”类创建一个新图表】。 前面的小笔记已整理成目录&#xff0c;可点链接去目录寻找所需更方便。 【目录部分内容如下】【点击此处可进入目录】 &#xff08;1&#xff09;如何安装导入xlwings库&#xff1b; &#xff08;2…

LeetCode刷题--- 不同路径 III

个人主页&#xff1a;元清加油_【C】,【C语言】,【数据结构与算法】-CSDN博客 个人专栏 力扣递归算法题 http://t.csdnimg.cn/yUl2I 【C】 ​​​​​​http://t.csdnimg.cn/6AbpV 数据结构与算法 ​​​http://t.csdnimg.cn/hKh2l 前言&#xff1a;这个专栏主要讲述递…

二叉树详解(深度优先遍历、前序,中序,后序、广度优先遍历、二叉树所有节点的个数、叶节点的个数)

目录 一、树概念及结构(了解) 1.1树的概念 1.2树的表示 二、二叉树概念及结构 2.1概念 2.2现实中的二叉树&#xff1a; 2.3数据结构中的二叉树&#xff1a; 2.4特殊的二叉树&#xff1a; 2.5 二叉树的存储结构 2.51 顺序存储&#xff1a; 2.5.2 链式存储&…

Apache Flink连载(二十三):Flink HA - Flink基于Yarn HA

🏡 个人主页:IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 🚩 私聊博主:加入大数据技术讨论群聊,获取更多大数据资料。 🔔 博主个人B栈地址:豹哥教你大数据的个人空间-豹哥教你大数据个人主页-哔哩哔哩视频 目录 1. Yarn HA配置 ​​​​…

婴幼儿家庭护理百科知识,宝宝健康成长育儿实用课程

一、教程描述 本套教程由具有丰富育儿经验的多名专家精心打造而成&#xff0c;也是专门提供给准爸妈们学习的实用课程&#xff0c;可以解决宝宝的日常护理、日常喂养、饮食调理、疾病防治、意外护理等多方面问题。课程不仅可以丰富你的育儿知识&#xff0c;而且能够让你把这些…