加噪声的分数匹配
def anneal_dsm_score_estimation(scorenet, samples, labels, sigmas, anneal_power=2.):
# 取出每个样本对应噪声级别下的噪声分布的标准差,即公式中的sigma_i,
# 这里的 labels 是用于标识每个样本的噪声级别的,就是 i,实际是一种索引标识
# (bs,)->(bs,1,1,1) 扩展至与图像一致的维度数
used_sigmas = sigmas[labels].view(samples.shape[0], *([1] * len(samples.shape[1:])))
# 加噪:x' = x + sigma * z (z ~ N(0,1))
perturbed_samples = samples + torch.randn_like(samples) * used_sigmas
# 目标score,本质是对数条件概率密度 log(p(x'|x)) 对噪声数据 x' 的梯度
# 由于这里建模为高斯分布,因此可计算出结果最终如下,见前文公式(vii)
target = - 1 / (used_sigmas ** 2) * (perturbed_samples - samples)
# 模型预测的 score
scores = scorenet(perturbed_samples, labels)
target = target.view(target.shape[0], -1)
scores = scores.view(scores.shape[0], -1)
# 先计算每个样本在所有维度下分数估计的误差总和,再对所有样本求平均
# 见前文公式(vii)
loss = 1 / 2. * ((scores - target) ** 2).sum(dim=-1) * used_sigmas.squeeze() ** anneal_power
return loss.mean(dim=0)
采样生成:
def anneal_Langevin_dynamics(self, x_mod, scorenet, sigmas, n_steps_each=100, step_lr=0.00002):
images = []
with torch.no_grad():
# 依次在每个噪声级别下进行朗之万动力学采样生成,噪声强度递减
for c, sigma in tqdm.tqdm(enumerate(sigmas), total=len(sigmas), desc='annealed Langevin dynamics sampling'):
# 噪声级别
labels = torch.ones(x_mod.shape[0], device=x_mod.device) * c
labels = labels.long()
# 这个步长并非 Algorithm 1 中的 alpha,而是其中第6步的 alpha/2
# 对应朗之万动力学采样公式(见公式(vi))的 epsilon/2
step_size = step_lr * (sigma / sigmas[-1]) ** 2
# 每个噪声级别下进行一定步数的朗之万动力学采样生成
for s in range(n_steps_each):
images.append(torch.clamp(x_mod, 0.0, 1.0).to('cpu'))
# 对应公式(vi)最后一项
noise = torch.randn_like(x_mod) * np.sqrt(step_size * 2)
# 网络估计的分数
grad = scorenet(x_mod, labels)
# 朗之万动力方程
x_mod = x_mod + step_size * grad + noise
return images
详细的解释(强推):https://zhuanlan.zhihu.com/p/597490389