今天我们继续Diffusion扩散模型的后半部分学习
条件U-Net
网络构建过程如下:
-
首先,将卷积层应用于噪声图像批上,并计算噪声水平的位置
-
接下来,应用一系列下采样级。每个下采样阶段由2个ResNet/ConvNeXT块 + groupnorm + attention + 残差连接 + 一个下采样操作组成
-
在网络的中间,再次应用ResNet或ConvNeXT块,并与attention交织
-
接下来,应用一系列上采样级。每个上采样级由2个ResNet/ConvNeXT块+ groupnorm + attention + 残差连接 + 一个上采样操作组成
-
最后,应用ResNet/ConvNeXT块,然后应用卷积层
最终,神经网络将层堆叠起来,就像它们是乐高积木一样(但重要的是了解它们是如何工作的)。
代码如下
class Unet(nn.Cell):
def __init__(
self,
dim,
init_dim=None,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=3,
with_time_emb=True,
convnext_mult=2,
):
super().__init__()
self.channels = channels
init_dim = default(init_dim, dim // 3 * 2)
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
block_klass = partial(ConvNextBlock, mult=convnext_mult)
if with_time_emb:
time_dim = dim * 4
self.time_mlp = nn.SequentialCell(
SinusoidalPositionEmbeddings(dim),
nn.Dense(dim, time_dim),
nn.GELU(),
nn.Dense(time_dim, time_dim),
)
else:
time_dim = None
self.time_mlp = None
self.downs = nn.CellList([])
self.ups = nn.CellList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(
nn.CellList(
[
block_klass(dim_in, dim_out, time_emb_dim=time_dim),
block_klass(dim_out, dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else nn.Identity(),
]
)
)
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(
nn.CellList(
[
block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Upsample(dim_in) if not is_last else nn.Identity(),
]
)
)
out_dim = default(out_dim, channels)
self.final_conv = nn.SequentialCell(
block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
)
def construct(self, x, time):
x = self.init_conv(x)
t = self.time_mlp(time) if exists(self.time_mlp) else None
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
len_h = len(h) - 1
for block1, block2, attn, upsample in self.ups:
x = ops.concat((x, h[len_h]), 1)
len_h -= 1
x = block1(x, t)
x = block2(x, t)
x = attn(x)
x = upsample(x)
return self.final_conv(x)
正向扩散
-
我们将正向过程方差设置为常数,从𝛽1=10−41=10−4线性增加到𝛽𝑇=0.02=0.02。
-
但是,它在(Nichol et al., 2021)中表明,当使用余弦调度时,可以获得更好的结果。
代码
from mindspore.dataset import ImageFolderDataset
image_size = 128
transforms = [
Resize(image_size, Inter.BILINEAR),
CenterCrop(image_size),
ToTensor(),
lambda t: (t * 2) - 1
]
path = './image_cat'
dataset = ImageFolderDataset(dataset_dir=path, num_parallel_workers=cpu_count(),
extensions=['.jpg', '.jpeg', '.png', '.tiff'],
num_shards=1, shard_id=0, shuffle=False, decode=True)
dataset = dataset.project('image')
transforms.insert(1, RandomHorizontalFlip())
dataset_1 = dataset.map(transforms, 'image')
dataset_2 = dataset_1.batch(1, drop_remainder=True)
x_start = next(dataset_2.create_tuple_iterator())[0]
print(x_start.shape)
文末附上打卡时间