Table of contents
- 1. UNet network structure
- 1.1 Details of the residual and attention blocks
- 1.2 The role of t
- 1.3 How the Positional Embedding is used in DDPM
- 1.4 Positional Embedding code in DDPM
- 1.5 residual block
- 1.6 attention block
- 1.7 UNet structure
- 2. Command-line argument parsing
- 3. Data loading and preprocessing
- 4. The training framework
- References:
1. UNet network structure
The overall framework of the UNet is shown below: on the right is the overall UNet architecture, and on the left are the residual block and the attention block.
Below is the detailed structure diagram of the UNet. The left (encoder) side applies residual blocks, downsampling, and attention in a regular pattern, and the right (decoder) side likewise applies residual blocks, upsampling, and attention in a regular pattern; the corresponding code is annotated in the figure.
1.1 Details of the residual and attention blocks
Readers familiar with CNNs should be able to follow most of the process in the figure below. Here t is a random value between 0 and 1000; suppose it is 888. After the Positional Embedding it becomes a vector of length 128, which then passes through fully connected layers, a SiLU activation, and so on. The Positional Embedding, the residual block, and the attention block are explained in detail below.
1.2 The role of t
1. Together with the original image, it is used to compute the noised image at time step t: $x_t=\sqrt{1-\overline{\alpha_t}}\,\epsilon+\sqrt{\overline{\alpha_t}}\,x_0$ (a code sketch follows this list).
2. t is encoded, and the encoding is added into the model so that the model learns which time step it is currently at.
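For intuition, here is a minimal sketch of this noising step (the variable names and the linear beta schedule are illustrative assumptions, not the exact implementation from the reference repository):

import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)               # linear noise schedule beta_1 ... beta_T
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # \overline{alpha_t} for every time step

def q_sample(x_0, t, noise):
    # x_t = sqrt(1 - alpha_bar_t) * eps + sqrt(alpha_bar_t) * x_0
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return torch.sqrt(1.0 - a_bar) * noise + torch.sqrt(a_bar) * x_0

x_0 = torch.randn(2, 3, 32, 32)        # a batch of images already scaled to [-1, 1]
t = torch.randint(0, T, (2,))          # one random time step per image
x_t = q_sample(x_0, t, torch.randn_like(x_0))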
1.3 How the Positional Embedding is used in DDPM
The left figure shows the Transformer's Positional Embedding, where the row index is the word position and each row is that word's feature vector; the right figure shows DDPM's Positional Embedding. One difference is that DDPM's Positional Embedding does not have to encode every position of a sequence: for a sampled t it simply takes the single corresponding row out of the 1000 rows. Another difference is that DDPM's Positional Embedding does not interleave sin and cos at even and odd indices; instead it concatenates all the sin values first and all the cos values after them. Although the concatenation order differs, the effect is the same, as shown in the figure below.
A positional encoding only needs to guarantee that every row is unique and that the relationships between rows are preserved.
1.4 Positional Embedding code in DDPM
Code:
import math
import torch
import torch.nn as nn

class PositionalEmbedding(nn.Module):
    __doc__ = r"""..."""

    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim      # length of the embedding vector
        self.scale = scale  # scale applied to t; 1.0 leaves the sin/cos periods unchanged

    def forward(self, x):
        # x holds the time steps t, random values from 0-1000;
        # with batch_size=2 it might look like tensor([645, 958])
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
Code explanation:
In the figure below, the 2 in $emb_{2\times64}$ indicates that the batch size equals 2.
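As a quick shape check, here is a small usage sketch (it assumes the PositionalEmbedding class above, dim=128, and the example time steps used in this article):

pe = PositionalEmbedding(dim=128)     # half_dim = 64
t = torch.tensor([645.0, 958.0])      # batch_size = 2, two random time steps
emb = pe(t)
print(emb.shape)                      # torch.Size([2, 128]): sin half (2x64) + cos half (2x64)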
Where it is used:
1.5 residual block
Original code:
class ResidualBlock(nn.Module):
    __doc__ = r"""Applies two conv blocks with residual connection. Adds time and class conditioning by adding bias after first convolution.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        time_emb: time embedding tensor of shape (N, time_emb_dim) or None if the block doesn't use time conditioning
        y: classes tensor of shape (N) or None if the block doesn't use class conditioning
    Output:
        tensor of shape (N, out_channels, H, W)
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None
        num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
        activation (function): activation function. Default: torch.nn.functional.relu
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
        use_attention (bool): if True applies AttentionBlock to the output. Default: False
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout,
        time_emb_dim=None,
        num_classes=None,
        activation=F.relu,
        norm="gn",
        num_groups=32,
        use_attention=False,
    ):
        super().__init__()

        self.activation = activation

        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

        self.residual_connection = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.attention = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)

    def forward(self, x, time_emb=None, y=None):
        out = self.activation(self.norm_1(x))
        out = self.conv_1(out)

        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            out += self.time_bias(self.activation(time_emb))[:, :, None, None]

        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")
            out += self.class_bias(y)[:, :, None, None]

        out = self.activation(self.norm_2(out))
        out = self.conv_2(out) + self.residual_connection(x)
        out = self.attention(out)

        return out
Code explanation:
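A small usage sketch of the block (the shapes are illustrative; it assumes the ResidualBlock class above together with the get_norm and AttentionBlock helpers from the same repository):

block = ResidualBlock(
    in_channels=128,
    out_channels=256,
    dropout=0.1,
    time_emb_dim=512,
)
x = torch.randn(2, 128, 16, 16)   # (N, in_channels, H, W)
time_emb = torch.randn(2, 512)    # output of the time MLP
out = block(x, time_emb)          # the time bias is broadcast over H and W via [:, :, None, None]
print(out.shape)                  # torch.Size([2, 256, 16, 16])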
1.6 attention block
The UNet contains 5 attention blocks in total; each attention block takes an input of size 256x16x16, and its output has the same size as its input.
Original code:
class AttentionBlock(nn.Module):
    __doc__ = r"""Applies QKV self-attention with a residual connection.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
    Output:
        tensor of shape (N, in_channels, H, W)
    Args:
        in_channels (int): number of input channels
    """

    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()

        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)

        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)

        attention = torch.softmax(dot_products, dim=-1)
        out = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)
        out = out.view(b, h, w, c).permute(0, 3, 1, 2)

        return self.to_out(out) + x
Code explanation:
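A shape walkthrough that matches the 256x16x16 input mentioned above (a sketch assuming the AttentionBlock class and the get_norm helper from the same repository):

attn = AttentionBlock(in_channels=256)
x = torch.randn(2, 256, 16, 16)   # (N, C, H, W)
# inside forward: q and v are reshaped to (N, H*W, C) = (2, 256, 256),
# k to (N, C, H*W) = (2, 256, 256), and the attention matrix is (N, H*W, H*W) = (2, 256, 256)
out = attn(x)
print(out.shape)                  # torch.Size([2, 256, 16, 16]) -- same shape as the input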
1.7 UNet structure
The UNet takes two inputs. One is the time embedding introduced earlier, which has to be added into every residual block; the other is the noised data $x_t$, whose spatial size is unchanged by adding noise.
Original UNet code:
class UNet(nn.Module):
    __doc__ = """UNet model used to estimate noise.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        time_emb: time embedding tensor of shape (N, time_emb_dim) or None if the block doesn't use time conditioning
        y: classes tensor of shape (N) or None if the block doesn't use class conditioning
    Output:
        tensor of shape (N, out_channels, H, W)
    Args:
        img_channels (int): number of image channels
        base_channels (int): number of base channels (after first convolution)
        channel_mults (tuple): tuple of channel multipliers. Default: (1, 2, 4, 8)
        time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None
        time_emb_scale (float): linear scale to be applied to timesteps. Default: 1.0
        num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
        activation (function): activation function. Default: torch.nn.functional.relu
        dropout (float): dropout rate at the end of each residual block
        attention_resolutions (tuple): list of relative resolutions at which to apply attention. Default: ()
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
        initial_pad (int): initial padding applied to image. Should be used if height or width is not a power of 2. Default: 0
    """

    def __init__(
        self,
        img_channels,
        base_channels,
        channel_mults=(1, 2, 4, 8),
        num_res_blocks=2,
        time_emb_dim=None,
        time_emb_scale=1.0,
        num_classes=None,
        activation=F.relu,
        dropout=0.1,
        attention_resolutions=(),
        norm="gn",
        num_groups=32,
        initial_pad=0,
    ):
        super().__init__()

        self.activation = activation
        self.initial_pad = initial_pad
        self.num_classes = num_classes

        self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None

        self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        channels = [base_channels]
        now_channels = base_channels

        for i, mult in enumerate(channel_mults):
            out_channels = base_channels * mult
            for _ in range(num_res_blocks):
                self.downs.append(ResidualBlock(
                    now_channels,
                    out_channels,
                    dropout,
                    time_emb_dim=time_emb_dim,
                    num_classes=num_classes,
                    activation=activation,
                    norm=norm,
                    num_groups=num_groups,
                    use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels
                channels.append(now_channels)
            if i != len(channel_mults) - 1:
                self.downs.append(Downsample(now_channels))
                channels.append(now_channels)

        self.mid = nn.ModuleList([
            ResidualBlock(
                now_channels,
                now_channels,
                dropout,
                time_emb_dim=time_emb_dim,
                num_classes=num_classes,
                activation=activation,
                norm=norm,
                num_groups=num_groups,
                use_attention=True,
            ),
            ResidualBlock(
                now_channels,
                now_channels,
                dropout,
                time_emb_dim=time_emb_dim,
                num_classes=num_classes,
                activation=activation,
                norm=norm,
                num_groups=num_groups,
                use_attention=False,
            ),
        ])

        for i, mult in reversed(list(enumerate(channel_mults))):
            out_channels = base_channels * mult
            for _ in range(num_res_blocks + 1):
                self.ups.append(ResidualBlock(
                    channels.pop() + now_channels,
                    out_channels,
                    dropout,
                    time_emb_dim=time_emb_dim,
                    num_classes=num_classes,
                    activation=activation,
                    norm=norm,
                    num_groups=num_groups,
                    use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels
            if i != 0:
                self.ups.append(Upsample(now_channels))

        assert len(channels) == 0

        self.out_norm = get_norm(norm, base_channels, num_groups)
        self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)

    def forward(self, x, time=None, y=None):
        ip = self.initial_pad
        if ip != 0:
            x = F.pad(x, (ip,) * 4)

        if self.time_mlp is not None:
            if time is None:
                raise ValueError("time conditioning was specified but time is not passed")
            time_emb = self.time_mlp(time)
        else:
            time_emb = None

        if self.num_classes is not None and y is None:
            raise ValueError("class conditioning was specified but y is not passed")

        x = self.init_conv(x)
        skips = [x]

        for layer in self.downs:
            x = layer(x, time_emb, y)
            skips.append(x)

        for layer in self.mid:
            x = layer(x, time_emb, y)

        for layer in self.ups:
            if isinstance(layer, ResidualBlock):
                x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, time_emb, y)

        x = self.activation(self.out_norm(x))
        x = self.out_conv(x)

        if self.initial_pad != 0:
            return x[:, :, ip:-ip, ip:-ip]
        else:
            return x
Code explanation:
The overall structure of the explanation is as follows.
This is the explanation of the time encoding: the input to self.time_mlp is t, a random value between 0 and 1000.
This is the explanation of the downsampling module.
This is the explanation of the middle part.
This is the explanation of the up (decoder) part.
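To tie the pieces together, here is a minimal forward-pass sketch (the channel settings and attention_resolutions below are illustrative choices for 32x32 CIFAR-10 images, and it assumes the Downsample and Upsample helpers from the same repository):

model = UNet(
    img_channels=3,
    base_channels=128,
    channel_mults=(1, 2, 2, 2),
    time_emb_dim=128 * 4,
    attention_resolutions=(1,),   # attention at the 16x16 level, where channels = 256
)
x_t = torch.randn(2, 3, 32, 32)              # noised images
t = torch.randint(0, 1000, (2,)).float()     # one time step per image
eps_pred = model(x_t, time=t)
print(eps_pred.shape)                        # torch.Size([2, 3, 32, 32]) -- predicted noise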
2. Command-line argument parsing
Original code:
'''
code from https://github.com/abarankab/DDPM/tree/main
'''
import argparse
import datetime
import torch
import wandb

from torch.utils.data import DataLoader
from torchvision import datasets
from ddpm import script_utils


def main():
    args = create_argparser().parse_args()
    device = args.device

    try:
        diffusion = script_utils.get_diffusion_from_args(args).to(device)
        optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.learning_rate)

        if args.model_checkpoint is not None:
            diffusion.load_state_dict(torch.load(args.model_checkpoint))
        if args.optim_checkpoint is not None:
            optimizer.load_state_dict(torch.load(args.optim_checkpoint))

        if args.log_to_wandb:
            if args.project_name is None:
                raise ValueError("args.log_to_wandb set to True but args.project_name is None")
            run = wandb.init(
                project=args.project_name,
                # entity='treaptofun',  # specifies the team or organization the run belongs to
                config=vars(args),
                name=args.run_name,
            )
            wandb.watch(diffusion)

        batch_size = args.batch_size

        train_dataset = datasets.CIFAR10(
            root='./cifar_train',
            train=True,
            download=True,
            transform=script_utils.get_transform(),
        )
        test_dataset = datasets.CIFAR10(
            root='./cifar_test',
            train=False,
            download=True,
            transform=script_utils.get_transform(),
        )

        train_loader = script_utils.cycle(DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=0,
        ))
        test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=0)

        acc_train_loss = 0

        for iteration in range(1, args.iterations + 1):
            diffusion.train()

            x, y = next(train_loader)
            x = x.to(device)
            y = y.to(device)

            if args.use_labels:
                loss = diffusion(x, y)
            else:
                loss = diffusion(x)

            acc_train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            diffusion.update_ema()

            if iteration % args.log_rate == 0:
                test_loss = 0
                with torch.no_grad():
                    diffusion.eval()
                    for x, y in test_loader:
                        x = x.to(device)
                        y = y.to(device)

                        if args.use_labels:
                            loss = diffusion(x, y)
                        else:
                            loss = diffusion(x)

                        test_loss += loss.item()

                if args.use_labels:
                    samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
                else:
                    samples = diffusion.sample(10, device)

                samples = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()

                test_loss /= len(test_loader)
                acc_train_loss /= args.log_rate

                wandb.log({
                    "test_loss": test_loss,
                    "train_loss": acc_train_loss,
                    "samples": [wandb.Image(sample) for sample in samples],
                })

                acc_train_loss = 0

            if iteration % args.checkpoint_rate == 0:
                model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth"
                optim_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth"
                torch.save(diffusion.state_dict(), model_filename)
                torch.save(optimizer.state_dict(), optim_filename)

        if args.log_to_wandb:
            run.finish()
    except KeyboardInterrupt:
        if args.log_to_wandb:
            run.finish()
        print("Keyboard interrupt, run finished early")


def create_argparser():
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M")
    defaults = dict(
        learning_rate=2e-4,
        batch_size=2,
        iterations=800000,
        log_to_wandb=True,
        log_rate=1000,
        checkpoint_rate=1000,
        log_dir="~/ddpm_logs",
        project_name="Enzo_ddpm",
        run_name=run_name,
        model_checkpoint=None,
        optim_checkpoint=None,
        schedule_low=1e-4,
        schedule_high=0.02,
        device=device,
    )
    defaults.update(script_utils.diffusion_defaults())

    parser = argparse.ArgumentParser()
    script_utils.add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()
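Since add_dict_to_argparser registers one command-line flag per key of the defaults dict, the script can presumably be launched with flags that mirror those keys, for example: python train_cifar.py --batch_size 128 --learning_rate 2e-4 --iterations 800000 (the script file name here is a guess, and the exact set of diffusion-related flags comes from script_utils.diffusion_defaults, which is not shown in this article).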
3. Data loading and preprocessing
Why map the pixel values of the image $x_0$ into [-1, 1]?
- Because the noise added later follows a distribution with mean 0 and variance 1, and the noising process takes a weighted sum of the original image's pixel values and the noise, the original image also needs to be brought to a zero-centered range (a sketch of such a transform follows).
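A sketch of such a transform (the article's actual get_transform in script_utils may differ; this is just one common way to map a [0, 1] tensor to [-1, 1]):

from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),                    # PIL image -> float tensor in [0, 1]
    transforms.Lambda(lambda t: 2 * t - 1),   # [0, 1] -> [-1, 1], zero-centered like the noise
])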
4. The training framework
The figure below shows the code that generates the betas, together with the overall framework of the training code.
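As a rough stand-in, here is a minimal sketch of a linear beta schedule and of the training step performed inside diffusion(x, y) (the schedule values follow the schedule_low and schedule_high defaults above; the training-step code is a generic DDPM objective, not the repository's exact GaussianDiffusion implementation):

import torch
import torch.nn.functional as F

T = 1000                                            # num_timesteps
betas = torch.linspace(1e-4, 0.02, T)               # schedule_low ... schedule_high
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # \overline{alpha_t}

def diffusion_loss(model, x_0):
    b = x_0.shape[0]
    t = torch.randint(0, T, (b,), device=x_0.device)              # random time step per image
    noise = torch.randn_like(x_0)                                 # epsilon ~ N(0, I)
    a_bar = alphas_cumprod.to(x_0.device)[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(1.0 - a_bar) * noise + torch.sqrt(a_bar) * x_0   # noising formula from section 1.2
    eps_pred = model(x_t, time=t.float())                         # the UNet estimates the added noise
    return F.mse_loss(eps_pred, noise)                            # simple DDPM training objective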
References:
1. Bilibili video
2. https://github.com/Enzo-MiMan/cv_related_collections/tree/main/diffusion