Generating High-Resolution Images with the Diffusion Model DDPM (a hands-on high-fidelity image generation project)
Generating high-resolution images with DDPM on the MindSpore framework | (1) About the denoising diffusion probabilistic model (DDPM)
Generating high-resolution images with DDPM on the MindSpore framework | (2) Dataset preparation and processing
Generating high-resolution images with DDPM on the MindSpore framework | (3) Model training and inference in practice
1. Training the DDPM Diffusion Model
A diffusion model generates new images by reversing the diffusion process: starting from pure Gaussian noise, the trained network removes a little noise at each step until, ideally, we end up with an image that looks as if it were drawn from the real data distribution.
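Concretely, a single reverse step (Equation 11 and Algorithm 2 in Ho et al., 2020) computes

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I),$$

where $\epsilon_\theta$ is the noise-prediction network, $\bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s$, $\sigma_t^2$ is the posterior variance, and no noise $z$ is added at the final step $t = 0$. The p_sample function below implements exactly this update: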
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    # Equation 11 in the DDPM paper: use the noise predictor to compute
    # the mean of the denoised image at step t
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
    if t_index == 0:
        # No noise is added at the final step
        return model_mean
    posterior_variance_t = extract(posterior_variance, t, x.shape)
    noise = randn_like(x)
    # Algorithm 2, line 4: x_{t-1} = mean + sigma_t * z
    return model_mean + ops.sqrt(posterior_variance_t) * noise
def p_sample_loop(model, shape):
    b = shape[0]
    # Algorithm 2: start from pure noise (one sample per batch element)
    img = randn(shape, dtype=None)
    imgs = []
    # Run the reverse process from t = timesteps - 1 down to 0,
    # recording the intermediate image at every step
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, ms.numpy.full((b,), i, dtype=mstype.int32), i)
        imgs.append(img.asnumpy())
    return imgs
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
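The helper extract(a, t, x_shape) used above was defined in Part 2 of this series: it picks out the schedule coefficient for each sample's timestep t and reshapes it so that it broadcasts over an image batch. For readers starting from this part, here is a minimal sketch of the usual implementation (following The Annotated Diffusion Model; details may differ slightly from Part 2):

from mindspore import Tensor

def extract(a, t, x_shape):
    # a is a 1-D schedule array (e.g. betas); gather a[t] for each batch
    # element, then reshape to (batch, 1, 1, ...) for broadcasting
    b = t.shape[0]
    out = Tensor(a).gather(t, 0)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))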
1.1 Model Initialization
# Define a dynamic learning rate: cosine decay from max_lr to min_lr
# over 10 epochs of 3750 steps each
lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)
# Define the U-Net model
unet_model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
# Copy the registered parameter names onto the trainable parameters, so
# that every trainable parameter carries a unique, readable name
name_list = []
for (name, par) in list(unet_model.parameters_and_names()):
    name_list.append(name)
i = 0
for item in list(unet_model.trainable_params()):
    item.name = name_list[i]
    i += 1
# Define the optimizer
optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
# Dynamic loss scaler from mindspore.amp (useful for mixed-precision training)
loss_scaler = DynamicLossScaler(65536, 2, 1000)

# Define the forward pass: the diffusion loss on a noised batch
def forward_fn(data, t, noise=None):
    loss = p_losses(unet_model, data, t, noise)
    return loss

# Compute the loss and its gradients w.r.t. the optimizer's parameters
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

# One training step: forward, backward, parameter update
def train_step(data, t, noise):
    loss, grads = grad_fn(data, t, noise)
    optimizer(grads)
    return loss
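forward_fn relies on p_losses from Part 2, which noises a clean batch at timestep t and trains the U-Net to predict that noise. As a reminder, here is a minimal sketch of that loss under the usual formulation, assuming the q_sample forward-diffusion helper and the sqrt_alphas_cumprod / sqrt_one_minus_alphas_cumprod schedule arrays from Part 2 (the smooth L1 loss is one common choice; plain MSE also works):

def q_sample(x_start, t, noise):
    # Closed-form forward diffusion:
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

def p_losses(unet_model, x_start, t, noise=None):
    if noise is None:
        noise = randn_like(x_start)
    x_noisy = q_sample(x_start, t, noise)      # noise the clean images at step t
    predicted_noise = unet_model(x_noisy, t)   # U-Net predicts the added noise
    return nn.SmoothL1Loss()(noise, predicted_noise)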
1.2 Model Training
import time

epochs = 10
iterator = dataset.create_tuple_iterator(num_epochs=epochs)
for epoch in range(epochs):
    begin_time = time.time()
    unet_model.set_train()
    for step, batch in enumerate(iterator):
        batch_size = batch[0].shape[0]
        # Algorithm 1, line 3: sample a timestep t uniformly for every
        # example in the batch, plus the Gaussian noise to add
        t = randint(0, timesteps, (batch_size,), dtype=ms.int32)
        noise = randn_like(batch[0])
        loss = train_step(batch[0], t, noise)
        if step % 500 == 0:
            print(" epoch: ", epoch, " step: ", step, " Loss: ", loss)
    end_time = time.time()
    times = end_time - begin_time
    print("training time:", times, "s")
# Show one randomly sampled image after training
unet_model.set_train(False)
samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
plt.imshow(samples[-1][5].reshape(image_size, image_size, channels), cmap="gray")
print("Training Success!")
2. Model Inference and Prediction
2.1 Inference Process
To generate new images, we sample from the trained model:
# Sample 64 images
unet_model.set_train(False)
samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
# Show one of them at random
# random_index = 5
# plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")
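To inspect the whole batch rather than a single image, one simple option (not part of the original tutorial) is to tile the 64 fully denoised samples into an 8x8 grid; as with the single-image display above, this assumes grayscale images:

import matplotlib.pyplot as plt

# samples[-1] is the final denoising step: an array of shape (64, channels, H, W)
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for ax, img in zip(axes.flat, samples[-1]):
    ax.imshow(img.reshape(image_size, image_size, channels), cmap="gray")
    ax.axis("off")
plt.show()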
Next, create an animation of the denoising process:
import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
# One animation frame per diffusion timestep, from pure noise to the
# final denoised image
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=100)
animate.save('diffusion.gif')
plt.show()
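With one frame per timestep, the GIF can get large when timesteps is big (e.g. 1000). A simple variant is to keep only every k-th frame plus the final one; the stride of 10 below is an arbitrary choice:

fig = plt.figure()
ims = []
stride = 10  # arbitrary: keep every 10th frame
for i in list(range(0, timesteps, stride)) + [timesteps - 1]:
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])
animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=100).save('diffusion_short.gif')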
3. References
- The Annotated Diffusion Model
- 由浅入深了解Diffusion Model
- Diffusion扩散模型