diffusion model(三)—— classifier guided diffusion model

news2024/9/26 19:49:29

classifier guided diffusion model

背景

对于一般的DM(如DDPM, DDIM)的采样过程是直接从一个噪声分布,通过不断采样来生成图片。但这个方法生成的图片类别是随机的,如何生成特定类别的图片呢?这就是classifier guide需要解决的问题。

方法大意

为了实现带类别标签 y y y的DM的推导,进行了以下定义
q ^ ( x 0 ) : = q ( x 0 ) q ^ ( y ∣ x 0 ) : = Know labels per sample q ^ ( x t + 1 ∣ x t , y ) : = q ( x t + 1 ∣ x t ) q ^ ( x 1 : T ∣ x 0 , y ) : = ∏ t = 1 T q ^ ( x t ∣ x t − 1 , y ) (1) \begin{aligned} \hat{q}(x_0) &:= q(x_0) \\ \hat{q}(y|x_0) &:= \text{Know labels per sample} \\ \hat{q}(x_{t+1}|x_{t}, y) &:= q(x_{t+1}|x_t) \\ \hat{q}(x_{1:T}|x_0, y)&:= \prod \limits_{t=1}^T\hat{q}(x_t|x_{t-1}, y) \\ \end{aligned} \tag{1} q^(x0)q^(yx0)q^(xt+1xt,y)q^(x1:Tx0,y):=q(x0):=Know labels per sample:=q(xt+1xt):=t=1Tq^(xtxt1,y)(1)
虽然上式定义了以 y y y为条件的噪声过程 q ^ \hat{q} q^,但我们还可以证明当 q ^ \hat{q} q^不以 y y y为条件时的行为与 q q q完全相同,即
q ^ ( x t + 1 ∣ x t ) = ∫ y q ^ ( x t + 1 , y ∣ x t ) d y = ∫ y q ^ ( x t + 1 ∣ x t , y ) q ^ ( y ∣ x t ) d y = ∫ y q ( x t + 1 ∣ x t ) q ^ ( y ∣ x t ) d y = q ( x t + 1 ∣ x t ) ∫ y q ^ ( y ∣ x t ) d y = q ( x t + 1 ∣ x t ) = q ^ ( x t + 1 ∣ x t , y ) (2) \begin{aligned} \hat{q}(x_{t+1}|x_t) &= \int_y \hat{q}(x_{t+1}, y| x_t)dy \\ &= \int_y \hat{q}(x_{t+1}|x_t, y)\hat{q}(y|x_t)dy \\ &= \int_y q(x_{t+1}|x_t)\hat{q}(y|x_t)dy \\ &= q(x_{t+1}|x_t) \int_y \hat{q}(y|x_t)dy \\ &= q(x_{t+1}|x_t) \\ &= \hat{q}(x_{t+1}|x_t, y) \\ \end{aligned}\tag{2} q^(xt+1xt)=yq^(xt+1,yxt)dy=yq^(xt+1xt,y)q^(yxt)dy=yq(xt+1xt)q^(yxt)dy=q(xt+1xt)yq^(yxt)dy=q(xt+1xt)=q^(xt+1xt,y)(2)
同样的思路:
q ^ ( x 1 : T ∣ x 0 ) = ∫ y q ^ ( x 1 : T , y ∣ x 0 ) d y = ∫ y q ^ ( x 1 : T ∣ y , x 0 ) q ( y ∣ x 0 ) d y = ∫ y ∏ t = 1 T q ^ ( x t ∣ x t − 1 , y ) ⏟ q ( x t ∣ x t − 1 ) q ( y ∣ x 0 ) d y = ∏ t = 1 T q ( x t ∣ x t − 1 ) ⏟ q ( x 1 : T ∣ x 0 ) ∫ y q ( y ∣ x 0 ) d y ⏟ = 1 = q ( x 1 : T ∣ x 0 ) (3) \begin{aligned} \hat{q}(x_{1:T}|x_0) &= \int_y \hat{q}(x_{1:T}, y|x_0) d_y \\ &= \int_y \hat{q}(x_{1:T}|y, x_0)q(y| x_0) d_y \\ &= \int_y \prod \limits_{t=1}^T \underbrace{ \hat{q}(x_t|x_{t-1}, y)}_{q(x_t|x_t-1)} q(y| x_0) d_y \\ &= \underbrace{\prod \limits_{t=1}^Tq(x_t|x_{t-1})}_{q(x_{1:T}|x_0)} \underbrace{\int_y q(y| x_0)d_y}_{=1} \\ &= q(x_{1:T}|x_0) \end{aligned}\tag{3} q^(x1:Tx0)=yq^(x1:T,yx0)dy=yq^(x1:Ty,x0)q(yx0)dy=yt=1Tq(xtxt1) q^(xtxt1,y)q(yx0)dy=q(x1:Tx0) t=1Tq(xtxt1)=1 yq(yx0)dy=q(x1:Tx0)(3)
根据上式同样可以推导出
q ^ ( x t ) = ∫ x 0 : t − 1 q ^ ( x 0 , ⋯   , x t ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ^ ( x 0 ) ⏟ q ( x 0 ) q ^ ( x 1 , ⋯   , x t ∣ x 0 ) ⏟ q ( x 1 : T ∣ x 0 ) d x 0 : t − 1 = q ( x t ) (4) \begin{aligned} \hat{q}(x_t) &= \int_{x_{0:t - 1}} \hat{q}(x_0, \cdots, x_t)dx_{0:t-1} \\ &= \int_{x_{0:t - 1}} \underbrace{\hat{q}(x_0)}_{q(x_0)} \underbrace{\hat{q}(x_1, \cdots, x_t|x_0)}_{q(x_{1:T}|x_0)}dx_{0:t-1} \\ &= q(x_t) \end{aligned} \tag{4} q^(xt)=x0:t1q^(x0,,xt)dx0:t1=x0:t1q(x0) q^(x0)q(x1:Tx0) q^(x1,,xtx0)dx0:t1=q(xt)(4)
由上述推导可见带条件的DM的前向过程与DDPM完全相同。并且根据贝叶斯公式,不带逆向过程也满足
p ^ ( x t ∣ x t + 1 ) = p ( x t ∣ x t + 1 ) (5) \hat{p}(x_t|x_{t+1}) = p(x_t|x_{t+1}) \tag{5} p^(xtxt+1)=p(xtxt+1)(5)
与此同时我们可以证明分类分布 q ^ ( y ∣ x t ) \hat{q}(y|x_t) q^(yxt)只和当前时刻的输入 x t x_t xt有关,与 x t + 1 x_{t+1} xt+1无关
q ^ ( y ∣ x t , x t + 1 ) = q ^ ( x t + 1 ∣ x t , y ) ⏞ q ^ ( x t + 1 ∣ x t ) q ^ ( y ∣ x t ) q ^ ( x t + 1 ∣ x t ) = q ^ ( y ∣ x t ) (6) \begin{aligned} \hat{q}(y|x_t, x_{t+1}) & = \frac{ \overbrace{ \hat{q}(x_{t+1}|x_t, y)}^{\hat{q}(x_{t+1}|x_t)} \hat{q}(y|x_t) } {\hat{q}(x_{t+1}|x_t )} \\ & = \hat{q}(y|x_t) \end{aligned} \tag{6} q^(yxt,xt+1)=q^(xt+1xt)q^(xt+1xt,y) q^(xt+1xt)q^(yxt)=q^(yxt)(6)

基于条件的去噪过程

将带类别信息的去噪过程定义为 p ^ ( x t ∣ x t + 1 , y ) \hat{p}(x_t|x_{t+1}, y) p^(xtxt+1,y)

p ^ ( x t ∣ x t + 1 , y ) = p ^ ( x t , x t + 1 , y ) p ^ ( y ∣ x t + 1 ) p ^ ( x t + 1 ) = p ^ ( x t , y ∣ x t + 1 ) p ^ ( y ∣ x t + 1 ) = p ^ ( y ∣ x t , x t + 1 ) ⏞ p ^ ( y ∣ x t ) p ^ ( x t ∣ x t + 1 ) ⏞ p ( x t ∣ x t + 1 ) p ^ ( y ∣ x t + 1 ) = p ^ ( y ∣ x t ) p ( x t ∣ x t + 1 ) p ^ ( y ∣ x t + 1 ) (7) \begin{aligned} \hat{p} (x_t| x_{t+1}, y) & = \frac{\hat{p} (x_t, x_{t+1}, y) }{\hat{p} (y|x_{t+1}) \hat{p} (x_{t+1}) } \\ & = \frac{\hat{p} (x_t, y | x_{t+1}) }{\hat{p} (y|x_{t+1}) } \\ & = \frac{\overbrace{\hat{p} (y|x_t, x_{t+1})}^{\hat{p}(y|x_t)} \overbrace{\hat{p}(x_t | x_{t+1})}^{p(x_t|x_{t+1})} }{\hat{p} (y|x_{t+1}) } \\ & = \frac{\hat{p} (y|x_t) p(x_t | x_{t+1}) }{\hat{p} (y|x_{t+1}) } \end{aligned} \tag{7} p^(xtxt+1,y)=p^(yxt+1)p^(xt+1)p^(xt,xt+1,y)=p^(yxt+1)p^(xt,yxt+1)=p^(yxt+1)p^(yxt,xt+1) p^(yxt)p^(xtxt+1) p(xtxt+1)=p^(yxt+1)p^(yxt)p(xtxt+1)(7)
由于 x t + 1 x_{t+1} xt+1是已知的, p ^ ( y ∣ x t + 1 ) \hat{p} (y|x_{t+1}) p^(yxt+1)这个概率分布与 x t x_t xt无关,可以将 p ^ ( y ∣ x t + 1 ) \hat{p} (y|x_{t+1}) p^(yxt+1)视为常数 Z Z Z。此时上式可以表述为
p ^ ( x t ∣ x t + 1 , y ) = Z p ^ ( y ∣ x t ) p ( x t ∣ x t + 1 ) (8) \hat{p} (x_t| x_{t+1}, y) = Z \hat{p} (y|x_t) p(x_t | x_{t+1}) \tag{8} p^(xtxt+1,y)=Zp^(yxt)p(xtxt+1)(8)
上式的右边第二项 p ^ ( y ∣ x t ) \hat{p} (y|x_t) p^(yxt)很容易得到,我们可以根据 x t , y x_t, y xt,y的pair对训练一个分类模型 p ^ ϕ ( y ∣ x t ) \hat{p}_\phi(y|x_t) p^ϕ(yxt)

上式的右边第三项 p ( x t ∣ x t + 1 ) p(x_t | x_{t+1}) p(xtxt+1)在DDPM中也能够通过一个neural network进行估计 p ( x t ∣ x t + 1 ) ≈ p θ ( x t ∣ x t + 1 ) p(x_t | x_{t+1}) \approx p_\theta(x_t|x_{t+1}) p(xtxt+1)pθ(xtxt+1)

故采样分布
p ^ ( x t ∣ x t + 1 , y ) ≈ p ^ ϕ , θ ( x t ∣ x t + 1 , y ) = Z p ^ ϕ ( y ∣ x t ) p θ ( x t ∣ x t + 1 ) (9) \begin{aligned} \hat{p} (x_t| x_{t+1}, y) &\approx \hat{p}_{\phi, \theta} (x_t| x_{t+1}, y) \\ &= Z \hat{p}_{\phi} (y|x_t) p_{\theta}(x_t | x_{t+1}) \end{aligned} \tag{9} p^(xtxt+1,y)p^ϕ,θ(xtxt+1,y)=Zp^ϕ(yxt)pθ(xtxt+1)(9)
下面来看有了上面这个式子如何进行采样

直接对上面的式子进行采样是很难解决的。论文参考文献1将上式近似为perturbed Gaussian distribution。

根据前文DM的推导可知 p θ ( x t ∣ x t + 1 ) = N ( μ , Σ ) = 1 2 π Σ exp ⁡ ( − ( x − μ ) 2 2 Σ ) p_{\theta}(x_t | x_{t+1}) = \mathcal{N}(\mu, \Sigma)=\frac{1}{\sqrt{2\pi} \sqrt{\Sigma} } \exp \left ({- \frac{(x - \mu)^2}{2\Sigma}} \right) pθ(xtxt+1)=N(μ,Σ)=2π Σ 1exp((xμ)2) ,对其取对数
log ⁡ p θ ( x t ∣ x t + 1 ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + C (10) \log p_{\theta}(x_t|x_{t+1}) = - \frac{1}{2} (x_t - \mu)^T \Sigma^{-1} (x_t - \mu) + C \tag{10} logpθ(xtxt+1)=21(xtμ)TΣ1(xtμ)+C(10)
对于 log ⁡ p ^ ϕ ( y ∣ x t ) \log \hat{p}_{\phi} (y|x_t) logp^ϕ(yxt) 作者假设其curvature比 Σ − 1 \Sigma^{-1} Σ1低。这个假设是合理的,对于当diffusion steps足够大时, ∥ Σ ∥ → 0 \parallel \Sigma \parallel \rightarrow 0 Σ∥→0。在该情况下,对 log ⁡ p ^ ϕ ( y ∣ x t ) \log\hat{p}_{\phi} (y|x_t) logp^ϕ(yxt) x t = μ x_t = \mu xt=μ处进行泰勒展开
log ⁡ p ^ ϕ ( y ∣ x t ) ≈ log ⁡ p ^ ϕ ( y ∣ x t ) ∣ x t = μ + ( x t − μ ) ∇ x t log ⁡ p ϕ ( y ∣ x t ) ∣ x t = μ = ( x t − μ ) g + C 1 where:  g = ∇ x t log ⁡ p ϕ ( y ∣ x t ) ∣ x t = μ , C 1  is a contant. (11) \begin{aligned} \log \hat{p}_{\phi} (y|x_t) & \approx \log \hat{p}_{\phi} (y|x_t) | _{x_t = \mu} + (x_t - \mu) \nabla_{x_t} \log p_{\phi} (y|x_t)|_{x_t = \mu} \\ &= (x_t - \mu) g + C_1 \\ \text{where: } g &= \nabla_{x_t} \log p_{\phi} (y|x_t)|_{x_t = \mu}, C_1\text{ is a contant.} \end{aligned} \tag{11} logp^ϕ(yxt)where: glogp^ϕ(yxt)xt=μ+(xtμ)xtlogpϕ(yxt)xt=μ=(xtμ)g+C1=xtlogpϕ(yxt)xt=μ,C1 is a contant.(11)

log ⁡ ( p ^ ϕ ( y ∣ x t ) p θ ( x t ∣ x t + 1 ) ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) g + C 2 = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + C 3 = log ⁡ p ( z ) + C 4 , z ∼ N ( μ + Σ g , Σ ) (12) \begin{aligned} \log (\hat{p}_{\phi} (y|x_t) p_{\theta}(x_t | x_{t+1})) & = - \frac{1}{2} (x_t - \mu)^T \Sigma^{-1} (x_t - \mu) + (x_t - \mu) g + C_2 \\ & = - \frac{1}{2} (x_t - \mu - \Sigma g)^T \Sigma^{-1} (x_t - \mu- \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ & = - \frac{1}{2} (x_t - \mu - \Sigma g)^T \Sigma^{-1} (x_t - \mu- \Sigma g) + C_3 \\ & = \log p(z) + C_4, z \sim \mathcal{N}(\mu + \Sigma g, \Sigma) \end{aligned} \tag{12} log(p^ϕ(yxt)pθ(xtxt+1))=21(xtμ)TΣ1(xtμ)+(xtμ)g+C2=21(xtμΣg)TΣ1(xtμΣg)+21gTΣg+C2=21(xtμΣg)TΣ1(xtμΣg)+C3=logp(z)+C4,zN(μ+Σg,Σ)(12)

(附录给出了验证性证明)

通过上述推导,我们得到了带类别条件的采样过程也可以用高斯分布来近似,只是均值需要加上 Σ g \Sigma g Σg。具体的算法如下
在这里插入图片描述

代码实现

p_mean_var_ddpm是DDPM对高斯分布均值、方差的计算函数

p_mean_var_ddpm_with_classifier是引入类别控制后的对高斯分布均值、方差的计算函数

有了均值方差就可以进行采样了

def p_mean_var_ddpm(self, noise_model, x, t):
    """
    Math:
    \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} x_t -
        \frac{1 - \alpha_t }{\sqrt{\alpha_t}\sqrt{1 - \overline{\alpha}_t}}f_\theta(x_t, t) \tag{30}
    """
    betas_t = extract(self.betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        self.sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)
    model_mean_t = sqrt_recip_alphas_t * (
        x - betas_t * noise_model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = extract(self.posterior_variance, t, x.shape)
    return model_mean_t, posterior_variance_t

  
def p_mean_var_ddpm_with_classifier(classifier, noise_model, x, t, y=None, cfs=1):
    def cond_fn(x: torch.Tensor, t: torch.Tensor, y: torch.Tensor): 
        assert y is not None
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            return torch.autograd.grad(selected.sum(), x_in)[0].float()   # gradient descend
    grad = cond_fn(x_temp, t, y=y) * cfs 
    model_mean_t, posterior_variance_t = p_mean_var_ddpm(noise_model, x, t)
    new_mean = model_mean_t + posterior_variance_t * grad
    return new_mean, posterior_variance_t

DDIM 中基于条件的去噪过程

上述条件抽样推导仅对随机扩散采样过程有效,不能应用于DDIM2等确定性采样方法(因为DDIM中设定了方差为0,故无法推导出式19)。为此,作者在研究中采用score-based的思路,参考了Song等人[^ 3]的方法,并利用了扩散模型和score matching之间的联系3

首先根据贝叶斯公式
p ( x t ∣ y ) = p ( y ∣ x t ) p ( x t ) p ( y ) ⇒ log ⁡ p ( x t ∣ y ) = log ⁡ p ( y ∣ x t ) + log ⁡ p ( x t ) − log ⁡ p ( y ) ⇒ 对 x t 求导 ∇ x t log ⁡ p ( x t ∣ y ) = ∇ x t log ⁡ p ( y ∣ x t ) + ∇ x t log ⁡ p ( x t ) − ∇ x t log ⁡ p ( y ) ⏟ = 0 ⇒ ∇ x t log ⁡ p ( x t ∣ y ) = ∇ x t log ⁡ p ( y ∣ x t ) + ∇ x t log ⁡ p ( x t ) (13) \begin{aligned} p (x_t| y) & = \frac{p (y|x_t) p(x_t) }{p (y) } \\ \Rightarrow \log{p (x_t| y) } &= \log{p (y|x_t)} + \log{p(x_t)} - \log{p (y) } \\ \stackrel{对x_t求导} \Rightarrow \nabla_{x_t}\log{p (x_t|y)} &= \nabla_{x_t}\log{p (y|x_t)} + \nabla_{x_t}\log{p(x_t)} - \underbrace{\nabla_{x_t}\log{p(y) }}_{=0} \\ \Rightarrow \nabla_{x_t}\log{p(x_t| y)} &= \nabla_{x_t}\log{p(y|x_t)} + \nabla_{x_t}\log{p(x_t)} \\ \end{aligned} \tag{13} p(xty)logp(xty)xt求导xtlogp(xty)xtlogp(xty)=p(y)p(yxt)p(xt)=logp(yxt)+logp(xt)logp(y)=xtlogp(yxt)+xtlogp(xt)=0 xtlogp(y)=xtlogp(yxt)+xtlogp(xt)(13)
具体来说,如果我们有一个模型 ϵ θ ( x t ) \epsilon_\theta(x_t) ϵθ(xt)来预测添加到样本中的噪声,那么可以利用它来推导出一个score function:
∇ x t log ⁡ p θ ( x t ) = − 1 1 − α ‾ t ϵ θ ( x t ) (14) \nabla_{x_t} \log p_\theta (x_t) = - \frac{1}{\sqrt{1 - \overline{\alpha}_t}} \epsilon_\theta(x_t) \tag{14} xtlogpθ(xt)=1αt 1ϵθ(xt)(14)
代入式(20)得
∇ x t log ⁡ p ( x t ∣ y ) = ∇ x t log ⁡ p ( y ∣ x t ) − 1 1 − α ‾ t ϵ θ ( x t ) ⇒ 1 − α ‾ t ∇ x t log ⁡ p ( x t ∣ y ) = 1 − α ‾ t ∇ x t log ⁡ p ( y ∣ x t ) − ϵ θ ( x t ) (15) \begin{aligned} \nabla_{x_t}\log{p(x_t| y)} &= \nabla_{x_t}\log{p(y|x_t)} - \frac{1}{\sqrt{1 - \overline{\alpha}_t}} \epsilon_\theta(x_t) \\ \Rightarrow \sqrt{1 - \overline{\alpha}_t} \nabla_{x_t}\log{p(x_t| y)} &= \sqrt{1 - \overline{\alpha}_t} \nabla_{x_t}\log{p(y|x_t)} - \epsilon_\theta(x_t) \end{aligned} \tag{15} xtlogp(xty)1αt xtlogp(xty)=xtlogp(yxt)1αt 1ϵθ(xt)=1αt xtlogp(yxt)ϵθ(xt)(15)
定义在条件 y y y下的估计噪声 ϵ ^ ( x t ∣ y ) \hat{\epsilon}(x_t|y) ϵ^(xty)为:
ϵ ^ ( x t ∣ y ) : = ϵ θ ( x t ) − 1 − α ‾ t ∇ x t log ⁡ p ϕ ( y ∣ x t ) (16) \hat{\epsilon}(x_t|y) := \epsilon_\theta(x_t) - \sqrt{1 - \overline{\alpha}_t}\nabla_{x_t} \log{p_\phi(y|x_t)} \tag{16} ϵ^(xty):=ϵθ(xt)1αt xtlogpϕ(yxt)(16)
只需将DDIM中的$ \epsilon_\theta(x_t) 替换为 替换为 替换为\hat{\epsilon}(x_t|y)$就得到了基于条件的去噪过程。

在这里插入图片描述

代码上也很直观

def p_sample_ddim(self, model, x, t):
    """
    x_{t-1} &=  \sqrt{\overline{\alpha}_{t-1}} \frac{x_t - \sqrt{1 - \overline{\alpha}_{t}}\boldsymbol{\epsilon}_\theta(x_t, t)}
        {\sqrt{\overline{\alpha}_{t}}} +  \sqrt{1 - \overline{\alpha}_{t-1} } \boldsymbol{\epsilon}_\theta(x_t, t)
    """
    sqrt_alphas_cumprod_prev_t = extract(self.sqrt_alphas_cumprod_prev, t, x.shape) 
    sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_one_minus_alphas_cumprod_prev_t = extract(self.sqrt_one_minus_alphas_cumprod_prev, t, x.shape) 
    sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x.shape) 
    pred_noise = model(x, t)
    pred_x0 = sqrt_alphas_cumprod_prev_t * (x - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t
    x0_direction = sqrt_one_minus_alphas_cumprod_prev_t * pred_noise 
    return pred_x0 + x0_direction
  
  
def p_sample_with_classifier(self, model, x, t, t_index, y=None, **kwargs):
    if y is None:
        return self.p_sample_ddim(model, x, t, t_index=t_index)
    cfs = kwargs.get("cfs", 1) 
    sqrt_alphas_cumprod_prev_t = extract(self.sqrt_alphas_cumprod_prev, t, x.shape) 
    sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_one_minus_alphas_cumprod_prev_t = extract(self.sqrt_one_minus_alphas_cumprod_prev, t, x.shape) 
    sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x.shape) 
    pred_noise = model(x, t)
    score = self.cond_fn(x, t, y=y) * cfs
    pred_noise = pred_noise - sqrt_one_minus_alphas_cumprod_t * score  # update noise 
    pred_x0 = sqrt_alphas_cumprod_prev_t * (x - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t
    x0_direction = sqrt_one_minus_alphas_cumprod_prev_t * pred_noise 
    return pred_x0 + x0_direction

一些细节

classifier的训练

classifier的训练与扩散模型的训练可以是独立的。在训练classifier的时候可以噪声预测模型(Unet)的encode部分作为主干,在后面接了一个分类层。并且需要与相应的扩散模型相同的噪声分布对classifier进行训练。训练数据集如 [ ( x 1 t , t , y 1 ) , ( x 2 t , t , y 2 ) , . . . , ( x N t , t , y N ) ] [(x_1^t,t, y_1), (x_2^t,t, y_2), ..., (x_N^t,t, y_N)] [(x1t,t,y1),(x2t,t,y2),...,(xNt,t,yN)] t t t是对时间步的采样, x t x^t xt x x x在时间步 t t t的输出。训练完成后,采用上面的算法集成到采样过程中。

gradient score的作用

在上面的采样算法我们看到有一个gradient scale s s s来对梯度进行拉伸。

实验视角

一般来说当 s = 1 s=1 s=1时,大约能保证生成的图片50%是想要的类别4,随着 s s s的增大,这个比例也能够增加。如下图,当 s s s增加到10,此时生成的图片都是期望的类别。因此 s s s也称之为guidance scale。
在这里插入图片描述

其实理解这个scale还有另一个视角

s ∇ x t log ⁡ ( p ϕ ( y ∣ x t ) ) = ∇ x t log ⁡ ( p ϕ ( y ∣ x t ) s ) s\nabla_{x_t} \log (p_\phi(y|x_t)) = \nabla_{x_t} \log (p_\phi(y|x_t)^s) sxtlog(pϕ(yxt))=xtlog(pϕ(yxt)s),当 s > 1 s>1 s>1他相当于对分布 p ϕ ( y ∣ x t ) p_\phi(y|x_t) pϕ(yxt)进行了一个指数拉升,从而带来更大的梯度更新收益。

根据DM的采样过程,当没有classifier guided时,在时刻 t t t,的采样过程应当是
x t − 1 = μ θ ( x t , t ) + σ ( t ) ϵ , 其中 ϵ ∈ N ( ϵ ; 0 , I ) = 1 α t ( x t − 1 − α t 1 − α ‾ t ϵ θ ( x t , t ) ) ⏟ μ θ ( x t , t ) + σ ( t ) ϵ (17) \begin{aligned} x_{t-1} &= \mu_{\theta}(x_t, t) + \sigma(t) \epsilon,其中 \epsilon \in \mathcal{N}(\epsilon; 0, \textbf{I}) \\ & = \underbrace{\frac{1}{\sqrt{\alpha_t}} (x_t - \frac{1 - \alpha_t }{\sqrt{1 - \overline{\alpha}_t}}\epsilon_\theta(x_t, t))}_{\mu_\theta(x_t, t)} + \sigma(t) \epsilon \end{aligned} \tag{17} xt1=μθ(xt,t)+σ(t)ϵ,其中ϵN(ϵ;0,I)=μθ(xt,t) αt 1(xt1αt 1αtϵθ(xt,t))+σ(t)ϵ(17)
当加了classifier guided相当于将 μ θ ( x t , t ) \mu_{\theta}(x_t, t) μθ(xt,t)向预测类别为 y y y的方向更新了一小步。 s s s是控制更新的幅值。
x t − 1 = μ θ ( x t , t ) + s ∇ x t log ⁡ p ϕ ( y ∣ x t ) ∣ x t = μ θ ( x t , t ) + σ ( t ) ϵ , 其中 ϵ ∈ N ( ϵ ; 0 , I ) \begin{align} x_{t-1} &=& \mu_{\theta}(x_t, t) + s\nabla_{x_t} \log p_{\phi} (y|x_t)|_{x_t = \mu_{\theta}(x_t, t)} + \sigma(t) \epsilon,其中 \epsilon \in \mathcal{N}(\epsilon; 0, \textbf{I}) \tag{18} \end{align} xt1=μθ(xt,t)+sxtlogpϕ(yxt)xt=μθ(xt,t)+σ(t)ϵ,其中ϵN(ϵ;0,I)(18)

参考文献

附录

式12推导验证
− 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T − μ T − g T Σ T ) Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T − μ T − g T Σ T ) Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T Σ − 1 − μ T Σ − 1 − g T Σ T Σ − 1 ⏟ g T ) ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T Σ − 1 ( x t − μ − Σ g ) − μ T Σ − 1 ( x t − μ − Σ g ) − g T ( x t − μ − Σ g ) ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T Σ − 1 ( x t − μ ) − μ T Σ − 1 ( x t − μ ) ) ⏟ ( x t − μ ) T Σ − 1 ( x t − μ ) − 1 2 ( − g T ( x t − μ − Σ g ) + ( − x t T Σ − 1 Σ g ) ⏟ − x t T g + μ T Σ − 1 Σ g ⏟ μ T g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) g + C 2 \begin{align*} &- \frac{1}{2} (x_t - \mu - \Sigma g)^T \Sigma^{-1} (x_t - \mu- \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ = &- \frac{1}{2} (x_t^T - \mu^T - g^T \Sigma^T) \Sigma^{-1} (x_t - \mu - \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ = &- \frac{1}{2} (x_t^T - \mu^T - g^T \Sigma^T) \Sigma^{-1} (x_t - \mu - \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ \\ = & - \frac{1}{2} (x_t^T \Sigma^{-1} - \mu^T \Sigma^{-1} - \underbrace{g^T \Sigma^T \Sigma^{-1}}_{g^T} )(x_t - \mu - \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ = & - \frac{1}{2} (x_t^T \Sigma^{-1} (x_t - \mu - \Sigma g) - \mu^T \Sigma^{-1} (x_t - \mu - \Sigma g) - g^T (x_t - \mu - \Sigma g)) + \frac{1}{2}g^T\Sigma g + C_2 \\ = & - \frac{1}{2} \underbrace{(x_t^T \Sigma^{-1} (x_t - \mu ) - \mu^T \Sigma^{-1} (x_t - \mu))}_{(x_t - \mu)^T \Sigma^{-1} (x_t - \mu)} - \frac{1}{2} ( - g^T (x_t - \mu - \Sigma g) + \underbrace{(- x_t^T \Sigma^{-1}\Sigma g)}_{-x_t^Tg} + \underbrace{\mu^T \Sigma^{-1}\Sigma g}_{\mu^Tg}) + \frac{1}{2}g^T\Sigma g + C_2 \\ = & - \frac{1}{2} (x_t - \mu)^T \Sigma^{-1} (x_t - \mu) + (x_t - \mu) g + C_2 \\ \end{align*} ======21(xtμΣg)TΣ1(xtμΣg)+21gTΣg+C221(xtTμTgTΣT)Σ1(xtμΣg)+21gTΣg+C221(xtTμTgTΣT)Σ1(xtμΣg)+21gTΣg+C221(xtTΣ1μTΣ1gT gTΣTΣ1)(xtμΣg)+21gTΣg+C221(xtTΣ1(xtμΣg)μTΣ1(xtμΣg)gT(xtμΣg))+21gTΣg+C221(xtμ)TΣ1(xtμ) (xtTΣ1(xtμ)μTΣ1(xtμ))21(gT(xtμΣg)+xtTg (xtTΣ1Σg)+μTg μTΣ1Σg)+21gTΣg+C221(xtμ)TΣ1(xtμ)+(xtμ)g+C2


  1. Deep unsupervised learning using nonequilibrium thermodynamics ↩︎

  2. [Denoising Diffusion Implicit Models (DDIM) Sampling](https://arxiv.org/abs/2010.02502) ↩︎

  3. Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. arXiv:arXiv:1907.05600, 2020. ↩︎

  4. Diffusion Models Beat GANs on Image Synthesis ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/688623.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前沿重器[35] | 提示工程和提示构造技巧

前沿重器 栏目主要给大家分享各种大厂、顶会的论文和分享,从中抽取关键精华的部分和大家分享,和大家一起把握前沿技术。具体介绍:仓颉专项:飞机大炮我都会,利器心法我还有。(算起来,专项启动已经…

MySQL数据库主从复制与读写分离(图文详解!)

目录 前言 一:MySQL数据库主从复制与读写分离 1、什么是读写分离? 2、为什么要读写分离呢? 3、什么时候要读写分离? 4、主从复制与读写分离 5、mysql支持的复制类型 (1)STATEMENT (2&…

SLAM面试笔记(5) — C++面试题

目录 第1章 C基础 1 C中static静态变量有什么作用,在什么情况下会用? 2 类中的this指针指向哪里? 3 说一下const的作用。 4 std::string类型为啥不能memset? 5 emplace_back( )和push_back( )有什么区别? 6 tra…

【状态估计】基于无味卡尔曼滤波模拟倾斜传感器研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

通过Redisson的管道批量操作来提高Redis Io效率

一、背景 当在对redis进行刷数操作时,大部分的redis框架对于单次执行的效率差不多,但我们有时需要一次性写入大量的redis key时,一次一次的操作速度就很慢。尤其是处于跨region的环境,一次的redis io就高达数十毫秒&#xff0…

Android aar包的生成与使用

前言 最近用Android Studio开发Android时,会经常接触到aar包(Java Archive),aar包含所有资源,class以及res资源文件全部包含。 优势 Android通过aar方式把代码和资源打成一个包,提供给第三方使用或者是开…

什么是AOP?

目录 一、AOP简介 1、AOP简介和作用 2、AOP的概念 二、AOP的基本实现 三、AOP工作流程 1 、AOP工作流程 2、AOP核心概念 四、AOP切入点表达式 1、语法格式 2、通配符 五、AOP通知类型 1、AOP通知分类 2、AOP通知详解 (1)前置通知 &#xf…

Java Web JDBC(1)23.6.25

JDBC 1,JDBC概述 在开发中我们使用的是java语言,那么势必要通过java语言操作数据库中的数据。这就是接下来要学习的JDBC。 1.1 JDBC概念 JDBC 就是使用Java语言操作关系型数据库的一套API 全称:( Java DataBase Connectivity ) Java 数据库…

vue3-实战-13-管理后台-数据大屏解决方案-顶部组件搭建-实时游客统计

目录 1-数据大屏解决方案vw和vh 2-数据大屏解决方案scale 3-数据大屏原型需求图 4-数据大屏顶部搭建 4.1-顶部原型需求 4.2-顶部模块父组件的结构和逻辑 4.3-顶部模块子组件结构和逻辑 5-数据大屏游客统计 5.1-原型需求图分析 5.2-结构样式逻辑开发 1-数据大屏解决方…

视觉与多模态大模型前沿进展 | 2023智源大会精彩回顾

导读 6 月 9 日下午,智源大会「视觉与多模态大模型」专题论坛如期举行。随着 stable diffusion、midjourney、SAM 等爆火应用相继问世,AIGC 和计算机视觉与大模型的结合成为了新的「风口」。本次研讨会由智源研究院访问首席科学家颜水成和马尔奖获得者曹…

在UE5编辑器环境中使用Python

UE有很多Python方案,本文所讲述的Python为UE5官方内嵌版本方案,并且只能在编辑器环境下使用,使用该功能可以编写编辑器下的辅助工具,提升开发效率。 1.调用Python的几种方式 讲一讲UE5中调用Python的几种方式,首先是…

rust abc(5): 常量

文章目录 1. 目的2. 基本用法2.1 说明2.2 运行结果 3. 不推荐或不正确用法3.1 不推荐用小写字母作为常量名字3.2 常量名称中含有小写字母就会报warning3.3 定义常量时,不指定数据类型会编译报错 4. const 和 immutable 的区别4.1 const 可以在函数外声明&#xff0c…

三、决策树 四、随机森林

三、决策树1.决策树模型的原理1)什么是决策树2)决策树模型原理3.构建决策树的目的4)决策树的优缺点 2.决策树的典型生成算法1)常用的特征选择有信息增益、信息增益率、基尼系数2)基于信息增益的ID3算法3)基…

JAVAWEB 30-

JAVAWEB 30- 快速入门DriverManagerConnectionresultsetPreparedStatement增删改查查询所有添加 修改 MAVEN坐标MyBatis代理开发mybatis查询条件查询添加删除参数传递 快速入门 public static void main(String[] args) throws Exception { /1.注册驱动 Class.forName("co…

【TA100】Bloom算法

一、什么是Bloom算法 1、首先看一下Bloom效果长什么样 2、什么是Bloom ● Bloom,也称辉光,是一种常见的屏幕效果 ● 模拟摄像机的一种图像效果,让画面中较亮的区域“扩散”到周围的区域中,造成一种朦胧的效果 ● 可以让物体具有…

[JVM]再聊 CMS 收集器

题目之所以是再聊,是因为以前聊过: [JVM]聊聊 CMS 收集器 最近又看了下这块的知识,打算把 CMS/标记-清除/GC Roots/引用 这些知识串起来 我依旧可能写的不是很好,降低下期待 GC 算法 CMS 是基于 标记-清除 算法来做的,那我们就先从 GC 算法开始聊 GC 算法有: 标记-清除 标…

一篇博客教会你使用Docker部署Redis哨兵

文章目录 主数据库配置文件启动实例容器虚拟IP 从数据库配置文件启动实例 主从数据库查看主数据库查看从数据库 哨兵配置文件启动哨兵查看哨兵 哨兵机制哨兵选举选举日志重启主数据库 今天我们学习使用 Docker 部署 Redis 的主从复制,并部署 Redis 哨兵,…

Linux学习之grub配置文件介绍

grub配置文件 /etc/default/grub这个文件里边有一些简单的grub配置。 可以看到/etc/default/grub文件里有GRUB_CMDLINE_LINUX"crashkernelauto rhgb quiet idlehalt biosdevname0 net.ifnames0 consoletty0 consolettyS0,115200n8 noibrs nvme_core.io_timeout429496729…

全网独家--【图像色彩增强】方法梳理和问题分析

文章目录 图像增强图像色彩增强问题可视化比较 难点色彩空间大,难以准确表征?不同场景差异大,难以自适应?计算量大,但应用场景往往实时性要求高? 方法传统方法深度学习逐像素预测3D LUT模仿ISP 个人思考批判…

2.数据的类型、数据的输入输出

2.数据的类型、数据的输入输出 2.1 数据类型-常量-变量(整型-浮点-字符)2.1.1 数据类型2.1.2 常量2.1.3 变量2.1.4 整型类型2.1.5 浮点型数据2.1.6 字符型数据字符型常量字符型变量 2.1.7 字符串型常量 2.2 混合运算-printf讲解 2.1 数据类型-常量-变量(整型-浮点-字符) 2.1.1…