《Classifier-Free Diffusion Guidance》的核心观点与方法

news2025/3/18 12:32:20

介绍《Classifier-Free Diffusion Guidance》的核心观点与方法

在扩散模型(Diffusion Models)的研究中,如何在生成样本的质量与多样性之间找到平衡一直是核心挑战之一。传统的生成模型(如GANs或Glow)通过截断(truncation)或低温采样(low temperature sampling)来实现这一目标,但扩散模型在这方面的尝试却往往效果不佳。Dhariwal 和 Nichol 在 2021 年提出了“分类器引导”(Classifier Guidance),通过引入额外的分类器来提升样本质量,但这增加了训练复杂性,并引发了是否必须依赖分类器的问题。Jonathan Ho 和 Tim Salimans 在论文《Classifier-Free Diffusion Guidance》中提出了一种新颖的替代方法——“无分类器引导”(Classifier-Free Guidance),旨在以纯生成模型的方式实现类似的效果。本文将为熟悉扩散模型的深度学习研究者介绍其核心观点、方法及关键数学公式,并加以解释。


核心观点

论文的核心贡献在于证明了扩散模型无需依赖外部分类器即可实现样本质量与多样性的权衡。传统的分类器引导通过结合扩散模型的分数估计(score estimate)和分类器的梯度来调整采样方向,而无分类器引导则通过联合训练一个条件扩散模型和一个无条件扩散模型,并在采样时混合两者的分数估计来达到类似目的。这种方法不仅简化了训练流程,还避免了分类器梯度可能带来的对抗性解释(如对分类器基于指标的优化)。

主要观点包括:

  1. 纯生成模型的能力:无分类器引导表明,扩散模型本身足以生成高质量样本,无需借助分类器。
  2. 训练与采样的简单性:通过在训练时随机丢弃条件信息,以及在采样时线性组合条件与无条件分数,方法实现起来非常直观。
  3. 效果验证:实验表明,无分类器引导能在 FID(Fréchet Inception Distance)和 IS(Inception Score)之间实现与分类器引导相似的权衡曲线。

方法详解
1. 背景:扩散模型的训练与采样

扩散模型通过正向过程逐步向数据添加噪声,并在逆向过程中从噪声中恢复数据。给定数据 ( x ∼ p ( x ) \mathbf{x} \sim p(\mathbf{x}) xp(x)),正向过程定义为:
q ( z λ ∣ x ) = N ( α λ x , σ λ 2 I ) , q(\mathbf{z}_\lambda \mid \mathbf{x}) = \mathcal{N}(\alpha_\lambda \mathbf{x}, \sigma_\lambda^2 \mathbf{I}), q(zλx)=N(αλx,σλ2I),
其中 ( α λ = 1 / ( 1 + e − λ ) \alpha_\lambda = \sqrt{1 / (1 + e^{-\lambda})} αλ=1/(1+eλ) ),( σ λ 2 = 1 − α λ 2 \sigma_\lambda^2 = 1 - \alpha_\lambda^2 σλ2=1αλ2),( λ \lambda λ) 是信噪比的对数(log signal-to-noise ratio)。逆向过程则通过学习一个参数化的模型 ( p θ ( z λ ′ ∣ z λ ) p_\theta(\mathbf{z}_{\lambda'} \mid \mathbf{z}_\lambda) pθ(zλzλ)) 来近似数据的分布,通常使用去噪分数匹配目标:
E ϵ , λ [ ∥ ϵ θ ( z λ ) − ϵ ∥ 2 2 ] , \mathbb{E}_{\epsilon, \lambda} \left[ \left\| \epsilon_\theta(\mathbf{z}_\lambda) - \epsilon \right\|_2^2 \right], Eϵ,λ[ϵθ(zλ)ϵ22],
其中 ( z λ = α λ x + σ λ ϵ \mathbf{z}_\lambda = \alpha_\lambda \mathbf{x} + \sigma_\lambda \epsilon zλ=αλx+σλϵ),( ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) ϵN(0,I)),( ϵ θ ( z λ ) \epsilon_\theta(\mathbf{z}_\lambda) ϵθ(zλ)) 是模型预测的噪声。

对于条件生成(如类条件图像生成),只需将条件 ( c \mathbf{c} c) 输入模型,变为 ( ϵ θ ( z λ , c ) \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) ϵθ(zλ,c))。

2. 分类器引导的局限

分类器引导通过调整分数估计来提升样本质量:
ϵ ~ θ ( z λ , c ) = ϵ θ ( z λ , c ) − w σ λ ∇ z λ log ⁡ p ϕ ( c ∣ z λ ) , \tilde{\epsilon}_\theta(\mathbf{z}_\lambda, \mathbf{c}) = \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) - w \sigma_\lambda \nabla_{\mathbf{z}_\lambda} \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda), ϵ~θ(zλ,c)=ϵθ(zλ,c)wσλzλlogpϕ(czλ),
其中 ( w w w) 是引导强度,( ∇ z λ log ⁡ p ϕ ( c ∣ z λ ) \nabla_{\mathbf{z}_\lambda} \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) zλlogpϕ(czλ)) 是分类器对 ( z λ \mathbf{z}_\lambda zλ) 的梯度。这相当于采样近似分布(推导见下文):
p ~ θ ( z λ ∣ c ) ∝ p θ ( z λ ∣ c ) p ϕ ( c ∣ z λ ) w 。 \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) \propto p_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda)^w。 p~θ(zλc)pθ(zλc)pϕ(czλ)w
然而,这需要额外训练一个分类器 ( p ϕ p_\phi pϕ),且分类器必须在噪声数据上训练,无法直接使用预训练模型。此外,这种方法可能被视为对分类器的对抗性优化,引发对结果真实性的质疑。

3. 无分类器引导的核心方法

无分类器引导提出了一种替代方案,通过联合训练条件模型 ( ϵ θ ( z λ , c ) \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) ϵθ(zλ,c)) 和无条件模型 ( ϵ θ ( z λ ) \epsilon_\theta(\mathbf{z}_\lambda) ϵθ(zλ)) 来实现引导。具体步骤如下:

  • 联合训练:使用单一神经网络同时建模条件和无条件分布。在训练时,以概率 ( p uncond p_{\text{uncond}} puncond) 随机将条件 ( c \mathbf{c} c) 替换为无条件标识符(如 ( ∅ \varnothing )),从而同时优化:

    • 条件分数 ( ϵ θ ( z λ , c ) \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) ϵθ(zλ,c)),
    • 无条件分数 ( ϵ θ ( z λ ) = ϵ θ ( z λ , c = ∅ ) \epsilon_\theta(\mathbf{z}_\lambda) = \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c} = \varnothing) ϵθ(zλ)=ϵθ(zλ,c=))。

    训练算法如下(伪代码摘自论文):
    repeat
    \quad ( x , c )   p ( x , c ) (x, c) ~ p(x, c) (x,c) p(x,c) // 从数据集中采样带条件的数据
    \quad c ← ∅ with probability p_uncond // 随机丢弃条件
    \quad λ ∼ p ( λ ) λ \sim p(λ) λp(λ) // 采样信噪比
    \quad ε ∼ N ( 0 , I ) ε \sim N(0, I) εN(0,I) // 采样噪声
    \quad z λ = α λ x + σ λ ε z_λ = α_λ x + σ_λ ε zλ=αλx+σλε // 添加噪声
    \quad 优化 ∇ θ ∣ ∣ ε θ ( z λ , c ) − ε ∣ ∣ 2 ∇_θ ||ε_θ(z_λ, c) - ε||² θ∣∣εθ(zλ,c)ε2 // 更新模型参数
    until converged

  • 采样时的分数混合:在采样时,通过线性组合条件和无条件分数来调整生成方向:
    ϵ ~ θ ( z λ , c ) = ( 1 + w ) ϵ θ ( z λ , c ) − w ϵ θ ( z λ ) , \tilde{\epsilon}_\theta(\mathbf{z}_\lambda, \mathbf{c}) = (1 + w) \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) - w \epsilon_\theta(\mathbf{z}_\lambda), ϵ~θ(zλ,c)=(1+w)ϵθ(zλ,c)wϵθ(zλ),
    其中 ( w w w) 是引导强度。这等价于在条件分布上施加额外的引导信号,而不依赖分类器梯度。

    采样算法如下:

( w w w ):引导强度。
( c c c ):条件采样的条件信息
( λ 1 , … , λ T \lambda_1, \dots, \lambda_T λ1,,λT ):对数信噪比(SNR)的递增序列,其中 ( λ 1 = λ min \lambda_1 = \lambda_{\text{min}} λ1=λmin ),( λ T = λ max \lambda_T = \lambda_{\text{max}} λT=λmax )

  1. ( z 1 ∼ N ( 0 , I ) z_1 \sim \mathcal{N}(0, I) z1N(0,I) )
  2. 对于 ( t = 1 , … , T t = 1, \dots, T t=1,,T ) 执行
    • 在对数信噪比 ( λ t \lambda_t λt ) 处形成无分类器引导的得分
    • ( ϵ ^ t = ( 1 + w ) ϵ θ ( z t , c ) − w ϵ θ ( z t ) \hat{\epsilon}_t = (1 + w) \epsilon_\theta (z_t, c) - w \epsilon_\theta (z_t) ϵ^t=(1+w)ϵθ(zt,c)wϵθ(zt) )
    • 采样步骤(可以被其他采样器替换,例如 DDIM)
    • ( x ^ t = ( z t − σ t ϵ ^ t ) / α t \hat{x}_t = (z_t - \sigma_t \hat{\epsilon}_t) / \alpha_t x^t=(ztσtϵ^t)/αt )
    • ( z t + 1 ∼ N ( μ λ t + 1 ∣ λ t ( z t , x ^ t ) , ( σ λ t + 1 ∣ λ t 2 ) I ) 1 − v ( σ λ t ∣ λ t + 1 2 I ) v ) z_{t+1} \sim \mathcal{N}(\mu_{\lambda_{t+1}|\lambda_t}(z_t, \hat{x}_t), (\sigma^2_{\lambda_{t+1}|\lambda_t})I)^{1-v} (\sigma^2_{\lambda_t|\lambda_{t+1}}I)^v) zt+1N(μλt+1λt(zt,x^t),(σλt+1λt2)I)1v(σλtλt+12I)v) ) 如果 ( t < T t < T t<T ) 否则 ( z t + 1 = x ^ t z_{t+1} = \hat{x}_t zt+1=x^t )
  3. 结束循环
  4. 返回 ( z T + 1 z_{T+1} zT+1 )

4. 数学解释

无分类器引导的灵感来源于隐式分类器 ( p i ( c ∣ z λ ) ∝ p ( z λ ∣ c ) / p ( z λ ) p^i(\mathbf{c} \mid \mathbf{z}_\lambda) \propto p(\mathbf{z}_\lambda \mid \mathbf{c}) / p(\mathbf{z}_\lambda) pi(czλ)p(zλc)/p(zλ))。解释见下文。若有精确分数:
∇ z λ log ⁡ p i ( c ∣ z λ ) = − 1 σ λ [ ϵ ∗ ( z λ , c ) − ϵ ∗ ( z λ ) ] , \nabla_{\mathbf{z}_\lambda} \log p^i(\mathbf{c} \mid \mathbf{z}_\lambda) = -\frac{1}{\sigma_\lambda} [\epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - \epsilon^*(\mathbf{z}_\lambda)], zλlogpi(czλ)=σλ1[ϵ(zλ,c)ϵ(zλ)],
将其代入分类器引导公式可得:
ϵ ~ ∗ ( z λ , c ) = ( 1 + w ) ϵ ∗ ( z λ , c ) − w ϵ ∗ ( z λ ) 。 \tilde{\epsilon}^*(\mathbf{z}_\lambda, \mathbf{c}) = (1 + w) \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - w \epsilon^*(\mathbf{z}_\lambda)。 ϵ~(zλ,c)=(1+w)ϵ(zλ,c)wϵ(zλ)
这与无分类器引导的形式一致。然而,由于 ( ϵ θ \epsilon_\theta ϵθ) 是神经网络的输出,不一定对应某个标量势函数的梯度,因此 ( ϵ ~ θ \tilde{\epsilon}_\theta ϵ~θ) 并非严格的分类器引导,而是通过分数差间接模拟了条件分布的增强。

这种方法的直观解释是:条件分数 ( ϵ θ ( z λ , c ) \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) ϵθ(zλ,c)) 推动样本朝特定条件方向移动,而无条件分数 ( ϵ θ ( z λ ) \epsilon_\theta(\mathbf{z}_\lambda) ϵθ(zλ)) 提供全局分布的约束,二者混合后增强了条件方向的“确定性”,从而提升样本质量。


实验验证

论文在 ImageNet 数据集上验证了方法的有效性:

  • 64x64 分辨率:在 ( w = 0.1 w=0.1 w=0.1) 时获得最佳 FID(1.55),在 ( w = 4.0 w=4.0 w=4.0) 时获得最佳 IS(260.2)。
  • 128x128 分辨率:在 ( w = 0.3 w=0.3 w=0.3) 时 FID 达 2.43,优于分类器引导的 ADM-G;在 ( w = 4.0 w=4.0 w=4.0) 时 IS 达 422.29,超越 BigGAN-deep。
  • 超参数影响:( p uncond = 0.1 或 0.2 p_{\text{uncond}} = 0.1 或 0.2 puncond=0.10.2) 时效果最佳,表明只需少量无条件训练即可实现有效引导。

样本图像显示,随着 ( w w w) 增加,样本多样性降低,但个体质量(如颜色饱和度)显著提升。


讨论与意义

无分类器引导的优势在于其简单性和纯生成性,避免了分类器训练的复杂性,同时证明了扩散模型自身的潜力。相比之下,其采样速度可能较慢(需要两次前向传播),但这可以通过网络结构优化(如延迟条件注入)来缓解。

对于深度学习研究者,这一方法提供了一个新思路:通过分数估计的组合,扩散模型可以在不引入外部监督的情况下实现灵活的生成控制。未来可探索其在多模态数据或高维条件上的应用,以及如何在提升质量的同时保持多样性。


结语

《Classifier-Free Diffusion Guidance》展示了一种优雅而高效的扩散模型优化策略。通过联合训练和分数混合,它不仅简化了流程,还深化了我们对生成模型能力的理解。对于研究者而言,这篇文章是值得深入挖掘的宝藏,或许能启发更多创新的生成方法。

“相当于采样近似分布”的数学推导过程

详细解释分类器引导中“相当于采样近似分布”的数学推导过程,特别是公式 ( p ~ θ ( z λ ∣ c ) ∝ p θ ( z λ ∣ c ) p ϕ ( c ∣ z λ ) w \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) \propto p_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda)^w p~θ(zλc)pθ(zλc)pϕ(czλ)w) 的由来。这部分涉及扩散模型的采样过程和分数估计的含义,适合熟悉扩散模型的读者深入理解。


背景:扩散模型与分数估计

扩散模型的核心是通过逆向过程从噪声分布逐步生成数据。逆向过程依赖于分数估计(score estimate),即数据分布的对数密度的梯度 ( ∇ z λ log ⁡ p ( z λ ) \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda) zλlogp(zλ))。在条件扩散模型中,分数估计变为 ( ∇ z λ log ⁡ p ( z λ ∣ c ) \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda \mid \mathbf{c}) zλlogp(zλc)),由模型 ( ϵ θ ( z λ , c ) \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) ϵθ(zλ,c)) 近似,表示为:
ϵ θ ( z λ , c ) ≈ − σ λ ∇ z λ log ⁡ p ( z λ ∣ c ) , \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) \approx -\sigma_\lambda \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda \mid \mathbf{c}), ϵθ(zλ,c)σλzλlogp(zλc),
其中 ( σ λ \sigma_\lambda σλ) 是噪声尺度,( z λ \mathbf{z}_\lambda zλ) 是给定信噪比 ( λ \lambda λ) 时的噪声数据。

在采样时,扩散模型通过 Langevin 动力学或类似方法,利用分数估计逐步更新 ( z λ \mathbf{z}_\lambda zλ),以逼近目标分布 ( p ( z λ ∣ c ) p(\mathbf{z}_\lambda \mid \mathbf{c}) p(zλc))。


分类器引导的分数调整

分类器引导引入了一个额外的分类器 ( p ϕ ( c ∣ z λ ) p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) pϕ(czλ)),并调整原始分数估计:
ϵ ~ θ ( z λ , c ) = ϵ θ ( z λ , c ) − w σ λ ∇ z λ log ⁡ p ϕ ( c ∣ z λ ) , \tilde{\epsilon}_\theta(\mathbf{z}_\lambda, \mathbf{c}) = \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) - w \sigma_\lambda \nabla_{\mathbf{z}_\lambda} \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda), ϵ~θ(zλ,c)=ϵθ(zλ,c)wσλzλlogpϕ(czλ),
其中 ( w w w) 是引导强度参数。目标是理解这一调整如何影响采样的分布。

将 ( ϵ θ ( z λ , c ) \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) ϵθ(zλ,c)) 的定义代入:
ϵ ~ θ ( z λ , c ) = − σ λ ∇ z λ log ⁡ p ( z λ ∣ c ) − w σ λ ∇ z λ log ⁡ p ϕ ( c ∣ z λ ) 。 \tilde{\epsilon}_\theta(\mathbf{z}_\lambda, \mathbf{c}) = -\sigma_\lambda \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda \mid \mathbf{c}) - w \sigma_\lambda \nabla_{\mathbf{z}_\lambda} \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda)。 ϵ~θ(zλ,c)=σλzλlogp(zλc)wσλzλlogpϕ(czλ)
提取公共因子 ( σ λ \sigma_\lambda σλ):
ϵ ~ θ ( z λ , c ) = − σ λ [ ∇ z λ log ⁡ p ( z λ ∣ c ) + w ∇ z λ log ⁡ p ϕ ( c ∣ z λ ) ] 。 \tilde{\epsilon}_\theta(\mathbf{z}_\lambda, \mathbf{c}) = -\sigma_\lambda \left[ \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda \mid \mathbf{c}) + w \nabla_{\mathbf{z}_\lambda} \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) \right]。 ϵ~θ(zλ,c)=σλ[zλlogp(zλc)+wzλlogpϕ(czλ)]
根据梯度的线性性质:
∇ z λ log ⁡ p ( z λ ∣ c ) + w ∇ z λ log ⁡ p ϕ ( c ∣ z λ ) = ∇ z λ [ log ⁡ p ( z λ ∣ c ) + w log ⁡ p ϕ ( c ∣ z λ ) ] 。 \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda \mid \mathbf{c}) + w \nabla_{\mathbf{z}_\lambda} \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) = \nabla_{\mathbf{z}_\lambda} \left[ \log p(\mathbf{z}_\lambda \mid \mathbf{c}) + w \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) \right]。 zλlogp(zλc)+wzλlogpϕ(czλ)=zλ[logp(zλc)+wlogpϕ(czλ)]
因此:
ϵ ~ θ ( z λ , c ) = − σ λ ∇ z λ [ log ⁡ p ( z λ ∣ c ) + w log ⁡ p ϕ ( c ∣ z λ ) ] 。 \tilde{\epsilon}_\theta(\mathbf{z}_\lambda, \mathbf{c}) = -\sigma_\lambda \nabla_{\mathbf{z}_\lambda} \left[ \log p(\mathbf{z}_\lambda \mid \mathbf{c}) + w \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) \right]。 ϵ~θ(zλ,c)=σλzλ[logp(zλc)+wlogpϕ(czλ)]
这表明调整后的分数 ( ϵ ~ θ ( z λ , c ) \tilde{\epsilon}_\theta(\mathbf{z}_\lambda, \mathbf{c}) ϵ~θ(zλ,c)) 是某个新分布对数密度的梯度乘以 ( − σ λ -\sigma_\lambda σλ)。


新分布的定义

在扩散模型中,分数 ( ϵ θ ( z λ , c ) \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) ϵθ(zλ,c)) 定义了采样分布的对数梯度。假设 ( ϵ ~ θ ( z λ , c ) \tilde{\epsilon}_\theta(\mathbf{z}_\lambda, \mathbf{c}) ϵ~θ(zλ,c)) 是近似正确的分数(即训练充分且误差较小),我们可以将其视为某个目标分布 ( p ~ θ ( z λ ∣ c ) \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) p~θ(zλc)) 的分数:
ϵ ~ θ ( z λ , c ) ≈ − σ λ ∇ z λ log ⁡ p ~ θ ( z λ ∣ c ) 。 \tilde{\epsilon}_\theta(\mathbf{z}_\lambda, \mathbf{c}) \approx -\sigma_\lambda \nabla_{\mathbf{z}_\lambda} \log \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c})。 ϵ~θ(zλ,c)σλzλlogp~θ(zλc)
结合上式:
− σ λ ∇ z λ log ⁡ p ~ θ ( z λ ∣ c ) = − σ λ ∇ z λ [ log ⁡ p ( z λ ∣ c ) + w log ⁡ p ϕ ( c ∣ z λ ) ] 。 -\sigma_\lambda \nabla_{\mathbf{z}_\lambda} \log \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) = -\sigma_\lambda \nabla_{\mathbf{z}_\lambda} \left[ \log p(\mathbf{z}_\lambda \mid \mathbf{c}) + w \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) \right]。 σλzλlogp~θ(zλc)=σλzλ[logp(zλc)+wlogpϕ(czλ)]
两边除以 ( − σ λ -\sigma_\lambda σλ)(假设 ( σ λ ≠ 0 \sigma_\lambda \neq 0 σλ=0)):
∇ z λ log ⁡ p ~ θ ( z λ ∣ c ) = ∇ z λ [ log ⁡ p ( z λ ∣ c ) + w log ⁡ p ϕ ( c ∣ z λ ) ] 。 \nabla_{\mathbf{z}_\lambda} \log \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) = \nabla_{\mathbf{z}_\lambda} \left[ \log p(\mathbf{z}_\lambda \mid \mathbf{c}) + w \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) \right]。 zλlogp~θ(zλc)=zλ[logp(zλc)+wlogpϕ(czλ)]
由于梯度相等,( log ⁡ p ~ θ ( z λ ∣ c ) \log \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) logp~θ(zλc)) 和 ( log ⁡ p ( z λ ∣ c ) + w log ⁡ p ϕ ( c ∣ z λ ) \log p(\mathbf{z}_\lambda \mid \mathbf{c}) + w \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) logp(zλc)+wlogpϕ(czλ)) 在数学上应相差一个与 ( z λ \mathbf{z}_\lambda zλ) 无关的常数(归一化常数)。因此:
log ⁡ p ~ θ ( z λ ∣ c ) = log ⁡ p ( z λ ∣ c ) + w log ⁡ p ϕ ( c ∣ z λ ) + C , \log \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) = \log p(\mathbf{z}_\lambda \mid \mathbf{c}) + w \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) + C, logp~θ(zλc)=logp(zλc)+wlogpϕ(czλ)+C,
其中 ( C C C) 是归一化常数。对两边取指数:
p ~ θ ( z λ ∣ c ) = e log ⁡ p ( z λ ∣ c ) + w log ⁡ p ϕ ( c ∣ z λ ) + C = e C ⋅ p ( z λ ∣ c ) ⋅ p ϕ ( c ∣ z λ ) w 。 \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) = e^{\log p(\mathbf{z}_\lambda \mid \mathbf{c}) + w \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) + C} = e^C \cdot p(\mathbf{z}_\lambda \mid \mathbf{c}) \cdot p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda)^w。 p~θ(zλc)=elogp(zλc)+wlogpϕ(czλ)+C=eCp(zλc)pϕ(czλ)w
由于 ( p ~ θ ( z λ ∣ c ) \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) p~θ(zλc)) 是概率密度,需满足归一化条件 ( ∫ p ~ θ ( z λ ∣ c ) d z λ = 1 \int \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) d\mathbf{z}_\lambda = 1 p~θ(zλc)dzλ=1)。令 ( Z = e C Z = e^C Z=eC) 为归一化因子:
p ~ θ ( z λ ∣ c ) = p ( z λ ∣ c ) p ϕ ( c ∣ z λ ) w ∫ p ( z λ ′ ∣ c ) p ϕ ( c ∣ z λ ′ ) w d z λ ′ 。 \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) = \frac{p(\mathbf{z}_\lambda \mid \mathbf{c}) p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda)^w}{\int p(\mathbf{z}_\lambda' \mid \mathbf{c}) p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda')^w d\mathbf{z}_\lambda'}。 p~θ(zλc)=p(zλc)pϕ(czλ)wdzλp(zλc)pϕ(czλ)w
在实践中,扩散模型采样时通常不显式计算 ( Z Z Z),而是通过分数直接逼近分布。因此,论文中简化为比例形式:
p ~ θ ( z λ ∣ c ) ∝ p θ ( z λ ∣ c ) p ϕ ( c ∣ z λ ) w 。 \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) \propto p_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda)^w。 p~θ(zλc)pθ(zλc)pϕ(czλ)w
这里的 ( p θ ( z λ ∣ c ) p_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) pθ(zλc)) 是模型近似的条件分布,替换了理论上的 ( p ( z λ ∣ c ) p(\mathbf{z}_\lambda \mid \mathbf{c}) p(zλc))。


直观解释

  • 原始分布 ( p θ ( z λ ∣ c ) p_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) pθ(zλc)):扩散模型试图采样的条件分布。
  • 分类器项 ( p ϕ ( c ∣ z λ ) w p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda)^w pϕ(czλ)w):分类器对 ( z λ \mathbf{z}_\lambda zλ) 属于条件 ( c \mathbf{c} c) 的置信度,( w w w) 控制其影响强度。乘以 ( p ϕ ( c ∣ z λ ) w p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda)^w pϕ(czλ)w) 相当于对更符合条件 ( c \mathbf{c} c) 的样本赋予更高权重。
  • 效果:新分布 ( p ~ θ ( z λ ∣ c ) \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) p~θ(zλc)) 倾向于生成分类器高置信度的样本,提升了样本质量(如 IS),但可能减少多样性。

为什么是“近似”分布?

  1. 模型误差:( ϵ θ ( z λ , c ) \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) ϵθ(zλ,c)) 和 ( p ϕ ( c ∣ z λ ) p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) pϕ(czλ)) 都是近似估计,而非精确分布的分数。
  2. 采样过程:扩散模型通过有限步数的迭代逼近目标分布,实际采样的分布可能偏离理论上的 ( p ~ θ ( z λ ∣ c ) \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) p~θ(zλc))。
  3. 非保守场:神经网络输出的 ( ϵ ~ θ \tilde{\epsilon}_\theta ϵ~θ) 不一定对应某个标量势函数的梯度,因此严格来说 ( p ~ θ \tilde{p}_\theta p~θ) 可能不是一个可精确定义的分布。

总结

“相当于采样近似分布 ( p ~ θ ( z λ ∣ c ) ∝ p θ ( z λ ∣ c ) p ϕ ( c ∣ z λ ) w \tilde{p}_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) \propto p_\theta(\mathbf{z}_\lambda \mid \mathbf{c}) p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda)^w p~θ(zλc)pθ(zλc)pϕ(czλ)w)”的推导源于分数调整的数学性质:通过将分类器梯度融入分数估计,采样过程被引导向一个新的分布,这个分布增强了条件一致性。这种形式揭示了分类器引导的本质——通过外部监督调整生成方向,而无分类器引导则试图用纯生成模型模拟这一效果。

详细解释“隐式分类器”(implicit classifier)

详细解释“隐式分类器”(implicit classifier)是什么,以及它在《Classifier-Free Diffusion Guidance》论文中的数学推导和意义。这部分内容面向熟悉扩散模型的深度学习研究者,帮助理解无分类器引导的灵感来源及其与分类器引导的联系。


什么是隐式分类器?

在机器学习中,“隐式分类器”并不是一个直接训练得到的分类模型,而是通过生成模型的概率分布,利用贝叶斯规则间接推导出的分类器。具体来说,给定生成模型的条件分布 ( p ( z λ ∣ c ) p(\mathbf{z}_\lambda \mid \mathbf{c}) p(zλc)) 和无条件分布 ( p ( z λ ) p(\mathbf{z}_\lambda) p(zλ)),隐式分类器 ( p i ( c ∣ z λ ) p^i(\mathbf{c} \mid \mathbf{z}_\lambda) pi(czλ)) 定义为:
p i ( c ∣ z λ ) = p ( z λ ∣ c ) p ( c ) p ( z λ ) , p^i(\mathbf{c} \mid \mathbf{z}_\lambda) = \frac{p(\mathbf{z}_\lambda \mid \mathbf{c}) p(\mathbf{c})}{p(\mathbf{z}_\lambda)}, pi(czλ)=p(zλ)p(zλc)p(c),
其中:

  • ( p ( z λ ∣ c ) p(\mathbf{z}_\lambda \mid \mathbf{c}) p(zλc)) 是条件生成模型的概率密度,表示在条件 ( c \mathbf{c} c) 下生成 ( z λ \mathbf{z}_\lambda zλ) 的似然;
  • ( p ( c ) p(\mathbf{c}) p(c)) 是条件 ( c \mathbf{c} c) 的先验概率;
  • ( p ( z λ ) = ∑ c ′ p ( z λ ∣ c ′ ) p ( c ′ ) p(\mathbf{z}_\lambda) = \sum_{\mathbf{c}'} p(\mathbf{z}_\lambda \mid \mathbf{c}') p(\mathbf{c}') p(zλ)=cp(zλc)p(c)) 是 ( z λ \mathbf{z}_\lambda zλ) 的边际分布。

由于 ( p ( z λ ) p(\mathbf{z}_\lambda) p(zλ)) 在给定 ( z λ \mathbf{z}_\lambda zλ) 时是一个常数(不依赖于 ( c \mathbf{c} c)),论文中将其简化为比例形式:
p i ( c ∣ z λ ) ∝ p ( z λ ∣ c ) p ( z λ ) , p^i(\mathbf{c} \mid \mathbf{z}_\lambda) \propto \frac{p(\mathbf{z}_\lambda \mid \mathbf{c})}{p(\mathbf{z}_\lambda)}, pi(czλ)p(zλ)p(zλc),
这里的比例符号 ( ∝ \propto ) 表示忽略了归一化因子 ( p ( c ) / p ( z λ ) p(\mathbf{c}) / p(\mathbf{z}_\lambda) p(c)/p(zλ)),因为在计算梯度时,常数因子不会影响结果。

直观含义

隐式分类器是通过生成模型“反推”得到的分类器。它没有显式训练一个独立的分类模型,而是利用生成模型已经学习到的分布信息,通过贝叶斯规则间接判断 ( z λ \mathbf{z}_\lambda zλ) 属于某个条件 ( c \mathbf{c} c) 的概率。这种方法在生成模型研究中常见,因为生成模型天然提供了 ( p ( z λ ∣ c ) p(\mathbf{z}_\lambda \mid \mathbf{c}) p(zλc)),而边际分布 ( p ( z λ ) p(\mathbf{z}_\lambda) p(zλ)) 可以通过条件分布整合得到。


隐式分类器在论文中的作用

在无分类器引导的数学解释中,作者提出其方法灵感来源于隐式分类器。具体来说,他们考虑如果用这个隐式分类器的梯度来引导扩散模型,会得到与无分类器引导形式相似的分数调整。让我们逐步推导。

1. 计算隐式分类器的梯度

假设我们有精确的分数(score),即:

  • ( ϵ ∗ ( z λ , c ) = − σ λ ∇ z λ log ⁡ p ( z λ ∣ c ) \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) = -\sigma_\lambda \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda \mid \mathbf{c}) ϵ(zλ,c)=σλzλlogp(zλc)),是条件分布的分数;
  • ( ϵ ∗ ( z λ ) = − σ λ ∇ z λ log ⁡ p ( z λ ) \epsilon^*(\mathbf{z}_\lambda) = -\sigma_\lambda \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda) ϵ(zλ)=σλzλlogp(zλ)),是无条件分布的分数。

隐式分类器的对数为:
log ⁡ p i ( c ∣ z λ ) = log ⁡ ( p ( z λ ∣ c ) p ( c ) p ( z λ ) ) = log ⁡ p ( z λ ∣ c ) + log ⁡ p ( c ) − log ⁡ p ( z λ ) 。 \log p^i(\mathbf{c} \mid \mathbf{z}_\lambda) = \log \left( \frac{p(\mathbf{z}_\lambda \mid \mathbf{c}) p(\mathbf{c})}{p(\mathbf{z}_\lambda)} \right) = \log p(\mathbf{z}_\lambda \mid \mathbf{c}) + \log p(\mathbf{c}) - \log p(\mathbf{z}_\lambda)。 logpi(czλ)=log(p(zλ)p(zλc)p(c))=logp(zλc)+logp(c)logp(zλ)
对其求梯度:
∇ z λ log ⁡ p i ( c ∣ z λ ) = ∇ z λ log ⁡ p ( z λ ∣ c ) + ∇ z λ log ⁡ p ( c ) − ∇ z λ log ⁡ p ( z λ ) 。 \nabla_{\mathbf{z}_\lambda} \log p^i(\mathbf{c} \mid \mathbf{z}_\lambda) = \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda \mid \mathbf{c}) + \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{c}) - \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda)。 zλlogpi(czλ)=zλlogp(zλc)+zλlogp(c)zλlogp(zλ)
由于 ( log ⁡ p ( c ) \log p(\mathbf{c}) logp(c)) 是 ( c \mathbf{c} c) 的先验,与 ( z λ \mathbf{z}_\lambda zλ) 无关,其梯度为零,因此:
∇ z λ log ⁡ p i ( c ∣ z λ ) = ∇ z λ log ⁡ p ( z λ ∣ c ) − ∇ z λ log ⁡ p ( z λ ) 。 \nabla_{\mathbf{z}_\lambda} \log p^i(\mathbf{c} \mid \mathbf{z}_\lambda) = \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda \mid \mathbf{c}) - \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda)。 zλlogpi(czλ)=zλlogp(zλc)zλlogp(zλ)
将分数定义代入:
∇ z λ log ⁡ p ( z λ ∣ c ) = − 1 σ λ ϵ ∗ ( z λ , c ) , ∇ z λ log ⁡ p ( z λ ) = − 1 σ λ ϵ ∗ ( z λ ) , \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda \mid \mathbf{c}) = -\frac{1}{\sigma_\lambda} \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}), \quad \nabla_{\mathbf{z}_\lambda} \log p(\mathbf{z}_\lambda) = -\frac{1}{\sigma_\lambda} \epsilon^*(\mathbf{z}_\lambda), zλlogp(zλc)=σλ1ϵ(zλ,c),zλlogp(zλ)=σλ1ϵ(zλ),
于是:
∇ z λ log ⁡ p i ( c ∣ z λ ) = − 1 σ λ ϵ ∗ ( z λ , c ) − ( − 1 σ λ ϵ ∗ ( z λ ) ) = − 1 σ λ [ ϵ ∗ ( z λ , c ) − ϵ ∗ ( z λ ) ] 。 \nabla_{\mathbf{z}_\lambda} \log p^i(\mathbf{c} \mid \mathbf{z}_\lambda) = -\frac{1}{\sigma_\lambda} \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - \left( -\frac{1}{\sigma_\lambda} \epsilon^*(\mathbf{z}_\lambda) \right) = -\frac{1}{\sigma_\lambda} \left[ \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - \epsilon^*(\mathbf{z}_\lambda) \right]。 zλlogpi(czλ)=σλ1ϵ(zλ,c)(σλ1ϵ(zλ))=σλ1[ϵ(zλ,c)ϵ(zλ)]
这就是论文中给出的公式:
∇ z λ log ⁡ p i ( c ∣ z λ ) = − 1 σ λ [ ϵ ∗ ( z λ , c ) − ϵ ∗ ( z λ ) ] 。 \nabla_{\mathbf{z}_\lambda} \log p^i(\mathbf{c} \mid \mathbf{z}_\lambda) = -\frac{1}{\sigma_\lambda} [\epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - \epsilon^*(\mathbf{z}_\lambda)]。 zλlogpi(czλ)=σλ1[ϵ(zλ,c)ϵ(zλ)]

2. 将隐式分类器梯度用于引导

在分类器引导中,分数调整为:
ϵ ~ ∗ ( z λ , c ) = ϵ ∗ ( z λ , c ) − w σ λ ∇ z λ log ⁡ p ϕ ( c ∣ z λ ) , \tilde{\epsilon}^*(\mathbf{z}_\lambda, \mathbf{c}) = \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - w \sigma_\lambda \nabla_{\mathbf{z}_\lambda} \log p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda), ϵ~(zλ,c)=ϵ(zλ,c)wσλzλlogpϕ(czλ),
其中 ( p ϕ ( c ∣ z λ ) p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) pϕ(czλ)) 是显式训练的分类器。现在,如果我们用隐式分类器 ( p i ( c ∣ z λ ) p^i(\mathbf{c} \mid \mathbf{z}_\lambda) pi(czλ)) 替换 ( p ϕ ( c ∣ z λ ) p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) pϕ(czλ)),代入其梯度:
ϵ ~ ∗ ( z λ , c ) = ϵ ∗ ( z λ , c ) − w σ λ ( − 1 σ λ [ ϵ ∗ ( z λ , c ) − ϵ ∗ ( z λ ) ] ) 。 \tilde{\epsilon}^*(\mathbf{z}_\lambda, \mathbf{c}) = \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - w \sigma_\lambda \left( -\frac{1}{\sigma_\lambda} [\epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - \epsilon^*(\mathbf{z}_\lambda)] \right)。 ϵ~(zλ,c)=ϵ(zλ,c)wσλ(σλ1[ϵ(zλ,c)ϵ(zλ)])
化简:
ϵ ~ ∗ ( z λ , c ) = ϵ ∗ ( z λ , c ) + w [ ϵ ∗ ( z λ , c ) − ϵ ∗ ( z λ ) ] 。 \tilde{\epsilon}^*(\mathbf{z}_\lambda, \mathbf{c}) = \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) + w [\epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - \epsilon^*(\mathbf{z}_\lambda)]。 ϵ~(zλ,c)=ϵ(zλ,c)+w[ϵ(zλ,c)ϵ(zλ)]
整理:
ϵ ~ ∗ ( z λ , c ) = ϵ ∗ ( z λ , c ) + w ϵ ∗ ( z λ , c ) − w ϵ ∗ ( z λ ) = ( 1 + w ) ϵ ∗ ( z λ , c ) − w ϵ ∗ ( z λ ) 。 \tilde{\epsilon}^*(\mathbf{z}_\lambda, \mathbf{c}) = \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) + w \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - w \epsilon^*(\mathbf{z}_\lambda) = (1 + w) \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - w \epsilon^*(\mathbf{z}_\lambda)。 ϵ~(zλ,c)=ϵ(zλ,c)+wϵ(zλ,c)wϵ(zλ)=(1+w)ϵ(zλ,c)wϵ(zλ)
这正是论文中给出的形式:
ϵ ~ ∗ ( z λ , c ) = ( 1 + w ) ϵ ∗ ( z λ , c ) − w ϵ ∗ ( z λ ) 。 \tilde{\epsilon}^*(\mathbf{z}_\lambda, \mathbf{c}) = (1 + w) \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - w \epsilon^*(\mathbf{z}_\lambda)。 ϵ~(zλ,c)=(1+w)ϵ(zλ,c)wϵ(zλ)

3. 与无分类器引导的联系

无分类器引导直接定义:
ϵ ~ θ ( z λ , c ) = ( 1 + w ) ϵ θ ( z λ , c ) − w ϵ θ ( z λ ) , \tilde{\epsilon}_\theta(\mathbf{z}_\lambda, \mathbf{c}) = (1 + w) \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) - w \epsilon_\theta(\mathbf{z}_\lambda), ϵ~θ(zλ,c)=(1+w)ϵθ(zλ,c)wϵθ(zλ),
形式上与隐式分类器引导的结果完全一致。这表明,无分类器引导可以看作是用生成模型自身的条件和无条件分数,模拟了隐式分类器的引导效果。


隐式分类器的意义与局限

意义
  • 灵感来源:隐式分类器提供了一个理论依据,解释了为什么条件分数和无条件分数的线性组合能起到引导作用。它本质上是利用生成模型的分布差异(( ϵ ∗ ( z λ , c ) − ϵ ∗ ( z λ ) \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - \epsilon^*(\mathbf{z}_\lambda) ϵ(zλ,c)ϵ(zλ)))来增强条件一致性。
  • 纯生成性:不像分类器引导需要额外训练 ( p ϕ ( c ∣ z λ ) p_\phi(\mathbf{c} \mid \mathbf{z}_\lambda) pϕ(czλ)),隐式分类器完全依赖生成模型已有信息,与无分类器引导的“无外部监督”理念契合。
局限
  • 理论与实践的差异:论文指出,( ϵ θ ( z λ , c ) \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) ϵθ(zλ,c)) 和 ( ϵ θ ( z λ ) \epsilon_\theta(\mathbf{z}_\lambda) ϵθ(zλ)) 是神经网络输出,不一定精确对应 ( log ⁡ p ( z λ ∣ c ) \log p(\mathbf{z}_\lambda \mid \mathbf{c}) logp(zλc)) 和 ( log ⁡ p ( z λ ) \log p(\mathbf{z}_\lambda) logp(zλ)) 的梯度,因此 ( ϵ ~ θ \tilde{\epsilon}_\theta ϵ~θ) 并非严格的隐式分类器引导,而是近似模拟。
  • 性能不确定性:文献(如 Grandvalet & Bengio, 2004)表明,隐式分类器在模型misspecification的情况下可能不一致,这也解释了为什么无分类器引导的效果需要实验验证,而非理论保证。

总结

隐式分类器 ( p i ( c ∣ z λ ) ∝ p ( z λ ∣ c ) / p ( z λ ) p^i(\mathbf{c} \mid \mathbf{z}_\lambda) \propto p(\mathbf{z}_\lambda \mid \mathbf{c}) / p(\mathbf{z}_\lambda) pi(czλ)p(zλc)/p(zλ)) 是通过生成模型分布反推出的分类器,其梯度 ( ϵ ∗ ( z λ , c ) − ϵ ∗ ( z λ ) \epsilon^*(\mathbf{z}_\lambda, \mathbf{c}) - \epsilon^*(\mathbf{z}_\lambda) ϵ(zλ,c)ϵ(zλ)) 捕捉了条件分布相对于全局分布的“方向性”。无分类器引导借鉴这一思想,用神经网络预测的分数差来实现类似效果,避免了显式分类器的训练。这种方法既优雅又实用,为扩散模型的优化提供了一个新视角。

代码实现

下面将提供一个简化的、可运行的 PyTorch 代码实现,用于演示《Classifier-Free Diffusion Guidance》的核心思想。这个实现基于一个简单的二维高斯混合分布,便于理解和可视化,同时保留了论文中的关键步骤(联合训练条件与无条件模型、分数混合采样)。会详细解释每个部分的逻辑和实现细节,面向熟悉扩散模型的深度学习研究者。


代码目标

  • 数据:生成一个二维高斯混合分布(例如两个类别的点云)。
  • 模型:训练一个扩散模型,同时支持条件生成(给定类别)和无条件生成(丢弃类别信息)。
  • 采样:实现无分类器引导,调整引导强度 ( w w w),生成并可视化结果。

完整代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# 设置随机种子
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. 数据生成:二维高斯混合分布
def generate_data(n_samples=1000):
    # 两个类别的高斯分布
    mean1, cov1 = torch.tensor([2.0, 2.0]), torch.eye(2) * 0.5
    mean2, cov2 = torch.tensor([-2.0, -2.0]), torch.eye(2) * 0.5
    data1 = torch.distributions.MultivariateNormal(mean1, cov1).sample((n_samples // 2,))
    data2 = torch.distributions.MultivariateNormal(mean2, cov2).sample((n_samples // 2,))
    x = torch.cat([data1, data2], dim=0)
    c = torch.cat([torch.zeros(n_samples // 2), torch.ones(n_samples // 2)]).long()
    return x.to(device), c.to(device)

# 2. 噪声调度
def get_alpha_sigma(t):
    # t 是 [0, 1] 之间的归一化时间步
    lambda_t = -10 + 20 * t  # λ 从 -10 到 10
    alpha_t = torch.sqrt(1 / (1 + torch.exp(-lambda_t)))
    sigma_t = torch.sqrt(1 - alpha_t**2)
    return alpha_t, sigma_t

# 3. 模型定义
class SimpleDiffusionModel(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=128, n_classes=2):
        super().__init__()
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.class_embed = nn.Embedding(n_classes + 1, hidden_dim)  # +1 用于无条件(类别 -1)
        self.net = nn.Sequential(
            nn.Linear(input_dim + hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, x, t, c):
        # 时间嵌入
        t = t.view(-1, 1)
        t_embed = self.time_embed(t)
        # 类别嵌入,无条件时 c = -1
        c_embed = self.class_embed(c + 1)  # 将类别从 [0, 1] 映射到 [1, 2],-1 映射到 0
        # 输入拼接
        combined = torch.cat([x, t_embed, c_embed], dim=-1)
        return self.net(combined)

# 4. 训练函数
def train_model(model, x, c, n_steps=1000, n_epochs=200, p_uncond=0.2):
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        for _ in range(n_steps):
            # 随机时间步
            t = torch.rand(x.shape[0], device=device)
            alpha_t, sigma_t = get_alpha_sigma(t)
            # 添加噪声
            epsilon = torch.randn_like(x)
            z_t = alpha_t[:, None] * x + sigma_t[:, None] * epsilon
            # 随机丢弃条件
            mask = (torch.rand(x.shape[0], device=device) < p_uncond).long()
            c_masked = c * (1 - mask) + (-1) * mask  # 无条件时 c = -1
            # 预测噪声
            epsilon_pred = model(z_t, t, c_masked)
            loss = torch.mean((epsilon_pred - epsilon) ** 2)
            # 优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if (epoch + 1) % 50 == 0:
            print(f"Epoch {epoch + 1}, Loss: {total_loss / n_steps:.4f}")
    return model

# 5. 无分类器引导采样
def sample_with_guidance(model, n_samples=500, n_steps=100, w=1.0, c_target=0):
    model.eval()
    with torch.no_grad():
        # 从纯噪声开始
        z = torch.randn(n_samples, 2, device=device)
        for i in range(n_steps - 1, -1, -1):
            t = torch.full((n_samples,), i / n_steps, device=device)
            alpha_t, sigma_t = get_alpha_sigma(t)
            # 计算条件和无条件分数
            c_cond = torch.full((n_samples,), c_target, dtype=torch.long, device=device)
            c_uncond = torch.full((n_samples,), -1, dtype=torch.long, device=device)
            epsilon_cond = model(z, t, c_cond)
            epsilon_uncond = model(z, t, c_uncond)
            # 无分类器引导分数
            epsilon_guided = (1 + w) * epsilon_cond - w * epsilon_uncond
            # 更新 z
            x_tilde = (z - sigma_t[:, None] * epsilon_guided) / alpha_t[:, None]
            if i > 0:
                t_next = torch.full((n_samples,), (i - 1) / n_steps, device=device)
                alpha_next, sigma_next = get_alpha_sigma(t_next)
                mu = (alpha_next / alpha_t) * z + (sigma_next**2 / sigma_t) * (x_tilde - z / alpha_t)
                z = mu + torch.randn_like(z) * torch.sqrt(sigma_next**2 - sigma_t**2 * (1 - sigma_next**2 / sigma_t**2))
            else:
                z = x_tilde
    return z

# 6. 主程序
if __name__ == "__main__":
    # 生成数据
    x, c = generate_data(n_samples=1000)
    # 初始化并训练模型
    model = SimpleDiffusionModel().to(device)
    model = train_model(model, x, c, n_steps=100, n_epochs=200, p_uncond=0.2)
    
    # 采样并可视化
    plt.figure(figsize=(12, 4))
    for i, w in enumerate([0.0, 1.0, 3.0]):
        samples = sample_with_guidance(model, n_samples=500, n_steps=100, w=w, c_target=0)
        samples = samples.cpu().numpy()
        plt.subplot(1, 3, i + 1)
        plt.scatter(samples[:, 0], samples[:, 1], s=5, alpha=0.5)
        plt.title(f"Guidance Strength w={w}")
        plt.xlim(-5, 5)
        plt.ylim(-5, 5)
    plt.tight_layout()
    plt.show()

代码详细解释

1. 数据生成 (generate_data)
  • 目的:生成一个简单的二维高斯混合分布,模拟条件生成任务。
  • 实现:两个类别(( c = 0 c=0 c=0) 和 ( c = 1 c=1 c=1)),分别以均值 ( [ 2 , 2 ] [2, 2] [2,2]) 和 ( [ − 2 , − 2 ] [-2, -2] [2,2]) 为中心,方差为 ( 0.5 0.5 0.5) 的高斯分布。
  • 输出:数据 ( x \mathbf{x} x)(形状 ( [ 1000 , 2 ] [1000, 2] [1000,2]))和类别标签 ( c \mathbf{c} c)(形状 ([1000]))。
2. 噪声调度 (get_alpha_sigma)
  • 目的:定义扩散过程中的 ( α λ \alpha_\lambda αλ) 和 ( σ λ \sigma_\lambda σλ),控制信号和噪声的比例。
  • 实现:基于论文中的 ( α λ = 1 / ( 1 + e − λ ) \alpha_\lambda = \sqrt{1 / (1 + e^{-\lambda})} αλ=1/(1+eλ) ),( σ λ = 1 − α λ 2 \sigma_\lambda = \sqrt{1 - \alpha_\lambda^2} σλ=1αλ2 )。这里用归一化时间步 ( t ∈ [ 0 , 1 ] t \in [0, 1] t[0,1]) 映射到 ( λ ∈ [ − 10 , 10 ] \lambda \in [-10, 10] λ[10,10])。
  • 解释:( λ \lambda λ) 模拟信噪比的变化,( t = 0 t=0 t=0) 时接近纯噪声,( t = 1 t=1 t=1) 时接近原始数据。
3. 模型定义 (SimpleDiffusionModel)
  • 结构
    • 时间嵌入:将时间步 ( t t t) 映射到隐藏维度(128)。
    • 类别嵌入:支持 (n_classes + 1) 个类别(包括无条件类别 ( − 1 -1 1))。
    • 主网络:输入为 ( z λ \mathbf{z}_\lambda zλ)、时间嵌入和类别嵌入的拼接,输出预测噪声 ( ϵ θ \epsilon_\theta ϵθ)。
  • 输入
    • ( x \mathbf{x} x):噪声数据 ( z λ \mathbf{z}_\lambda zλ)。
    • ( t t t):当前时间步。
    • ( c \mathbf{c} c):类别(( − 1 -1 1) 表示无条件)。
  • 解释:模型同时学习条件分数 ( ϵ θ ( z λ , c ) \epsilon_\theta(\mathbf{z}_\lambda, \mathbf{c}) ϵθ(zλ,c)) 和无条件分数 ( ϵ θ ( z λ , − 1 ) \epsilon_\theta(\mathbf{z}_\lambda, -1) ϵθ(zλ,1))。
4. 训练函数 (train_model)
  • 算法:实现论文 Algorithm 1(联合训练)。
  • 步骤
    1. 随机采样时间步 ( t t t)。
    2. 计算 ( α t \alpha_t αt) 和 ( σ t \sigma_t σt),生成噪声数据 ( z t = α t x + σ t ϵ \mathbf{z}_t = \alpha_t \mathbf{x} + \sigma_t \epsilon zt=αtx+σtϵ)。
    3. 以概率 ( p uncond = 0.2 p_{\text{uncond}}=0.2 puncond=0.2) 丢弃条件(( c = − 1 \mathbf{c} = -1 c=1))。
    4. 预测噪声 ( ϵ θ ( z t , c ) \epsilon_\theta(\mathbf{z}_t, \mathbf{c}) ϵθ(zt,c)),计算均方误差损失。
    5. 优化模型参数。
  • 解释:通过随机丢弃条件,模型学会同时拟合条件和无条件分布。
5. 无分类器引导采样 (sample_with_guidance)
  • 算法:实现论文 Algorithm 2。
  • 步骤
    1. 从纯噪声 ( z ∼ N ( 0 , I ) \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) zN(0,I)) 开始。
    2. 迭代 (n_steps) 次:
      • 计算条件分数 ( ϵ θ ( z , t , c ) \epsilon_\theta(\mathbf{z}, t, \mathbf{c}) ϵθ(z,t,c)) 和无条件分数 ( ϵ θ ( z , t , − 1 ) \epsilon_\theta(\mathbf{z}, t, -1) ϵθ(z,t,1));
      • 混合分数:( ϵ ~ = ( 1 + w ) ϵ cond − w ϵ uncond \tilde{\epsilon} = (1 + w) \epsilon_{\text{cond}} - w \epsilon_{\text{uncond}} ϵ~=(1+w)ϵcondwϵuncond);
      • 更新 ( x ~ t = ( z − σ t ϵ ~ ) / α t \tilde{\mathbf{x}}_t = (\mathbf{z} - \sigma_t \tilde{\epsilon}) / \alpha_t x~t=(zσtϵ~)/αt);
      • 计算下一时间步的均值 ( μ \mu μ) 和噪声,更新 ( z \mathbf{z} z)。
  • 解释:( w w w) 控制引导强度,( w = 0 w=0 w=0) 时退化为普通条件采样,( w > 0 w>0 w>0) 时增强条件一致性。
6. 主程序
  • 训练:训练 200 轮,每轮 100 次更新。
  • 采样:以 ( w = 0 , 1 , 3 w=0, 1, 3 w=0,1,3) 采样类别 0 的样本,绘制散点图。
  • 可视化:观察引导强度对样本分布的影响。

运行结果

  • ( w = 0 w=0 w=0):样本分布接近原始条件分布(类别 0 的高斯),但可能有少量混杂。
  • ( w = 1 w=1 w=1):样本更集中于类别 0 的中心,质量提升。
  • ( w = 3 w=3 w=3):样本高度集中,失去部分多样性,但更符合条件。

注意事项

  1. 简化:这是一个玩具示例,未使用复杂的 U-Net 或 ImageNet 数据,仅展示核心概念。
  2. 超参数:( n s t e p s = 100 n_steps=100 nsteps=100)、( p uncond = 0.2 p_{\text{uncond}}=0.2 puncond=0.2)、学习率等可调整以优化效果。
  3. 扩展:实际应用中需加入 DDIM 等高级采样器,或处理更高维数据(如图像)。

总结

这个实现展示了无分类器引导的关键步骤:联合训练条件与无条件模型、在采样时混合分数。通过调整 ( w w w),你可以看到样本质量与多样性的权衡,与论文中的理论和实验一致。希望这个代码和解释能帮助你深入理解并实验这一方法!

后记

2025年3月17日21点04分于上海,在Grok 3大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2317221.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是数学建模?数学建模是将实际问题转化为数学问题

数学建模是将实际问题转化为数学问题&#xff0c;并通过数学工具进行分析、求解和验证的过程。 一、数学建模的基本流程 问题分析 • 明确目标&#xff1a;确定需要解决的核心问题。 • 简化现实&#xff1a;识别关键变量、忽略次要因素。 • 定义输入和输出&#xff1a;明确模…

唤起“队列”的回忆

又来博客记录自己的学习心得了&#xff0c;嘿嘿嘿(^&#xff5e;^) 目录 队列的概念和结构&#xff1a; 队列的创建和初始化&#xff1a; 队列入栈&#xff1a; 队列出栈&#xff1a; 队列的销毁&#xff1a; 取队头和队尾数据&#xff1a; 结语&#xff1a; 队列的概念…

Linux(8.4)NFS

文章目录 一、概念二、详解NFS1&#xff09;软件名2&#xff09;服务名3&#xff09;配置文件4&#xff09;端口号5&#xff09;相关命令 三、部署NFS一、NFS服务端1&#xff09;**配置源&#xff08;本地或者网络源&#xff09;**2&#xff09;2、安装NFS**3&#xff09;启动服…

【位运算】速算密钥:位运算探秘

文章目录 前言例题一、判定字符是否唯一二、丢失的数字三、两整数之和四、只出现⼀次的数字 II五、消失的两个数字 结语 前言 什么是位运算算法呢&#xff1f; 位运算算法是以位运算为核心操作&#xff0c;设计用来高效解决特定问题的一系列计算步骤集合。它巧妙利用位运算直接…

STM32G070CBT6读写FLASH中的数据

向FLASH中写入数据函数 /*函数说明&#xff1a;向FLASH中写数据形参&#xff1a;addr-要写入数据的起始地址 data-准备写入数据 len-数据大小返回值&#xff1a;1-成功&#xff0c;0-失败 */ uint8_t FlashWriteData(uint64_t addr,uint8_t data[],size_t len) {uint32_t Fir…

AD绘图基本操作

一、基本操作 注意&#xff1a;快捷键都要在英文模式下才能生效 1、移动 按住鼠标右键移动 2、切换桌面栅格距离 G 3、英寸和毫米 尺寸切换 Q 4、元件在3D模式下的移动 3D视角鼠标左键只起到选择元器件并移动之的功能&#xff0c; 单纯鼠标右键只能平移桌面 shift鼠…

dfs(十二)21. 合并两个有序链表 递归解决

21. 合并两个有序链表 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1&#xff1a; 输入&#xff1a;l1 [1,2,4], l2 [1,3,4] 输出&#xff1a;[1,1,2,3,4,4]示例 2&#xff1a; 输入&#xff1a;l1 [], l2 [] …

51单片机指令系统入门

目录 基本概念讲解 一、机器指令​ 二、汇编指令​ &#xff08;一&#xff09;汇编指令的一般格式 &#xff08;二&#xff09;按字节数分类的指令 三、高级指令 总结​ 基本概念讲解 指令是计算机&#xff08;或单片机&#xff09;中 CPU 能够识别并执行的基本操作命令…

安全无事故连续天数计算,python 时间工具的高效利用

安全天数计算&#xff0c;数据系统时间直取&#xff0c;安全标准高效便捷好用。 笔记模板由python脚本于2025-03-17 23:50:52创建&#xff0c;本篇笔记适合对python时间工具有研究欲的coder翻阅。 【学习的细节是欢悦的历程】 博客的核心价值&#xff1a;在于输出思考与经验&am…

如何玩DeepSeek!15分钟快速创建GIS动态数据可视化仪表盘

DeepSeek最近火遍全球&#xff0c;大家用的都用的不亦乐乎。国外呢&#xff1f;当然也是&#xff0c;最近一上YouTube、X等都是deepseek的推送。 今天介绍一下&#xff0c;我在YouTube上看到的GIS行业与DeepSeek结合的一个案例&#xff1a; 快速轻松构建交互式地图仪表盘&…

课上测试:MIRACL共享库使用测试

MIRACL(MultiprecisionIntegerandRationalArithmeticC/cLibrary)是著名的密码算法库&#xff0c;设法去官网下载安装MIRACL&#xff0c;提交安装过程截图或过程文本&#xff08;3分&#xff09;. 去github官网下载.zip文件 使用如下命令进行解压 unzip -j -aa -L MIRACL-mast…

网络编程知识预备阶段

1. OSI七层模型 OSI&#xff08;Open System Interconnect&#xff09;七层模型是一种将计算机网络通信协议划分为七个不同层次的标准化框架。每一层都负责不同的功能&#xff0c;从物理连接到应用程序的处理。这种模型有助于不同的系统之间进行通信时&#xff0c;更好地理解和…

STM32微控制器_03_GPIO原理与应用

核心内容 STM32 GPIO基本原理&#xff08;熟悉&#xff09;GPIO输出功能HAL库编程实现的应用&#xff08;重点&#xff09;GPIO输入功能HAL库编程实现的应用&#xff08;重点&#xff09; 一.STM32 GPIO基本原理 1.GPIO简介 STM32的GPIO相当于STM32的四肢&#xff0c;一个S…

零拷贝分析

kafka 零拷贝 请求 - 网口 - socket - 用户态 - 内核缓存区 - 内核态&#xff08;磁盘信息&#xff09; 磁盘 - 内核缓存区 - 用户缓存区 - 网络缓存区 零拷贝&#xff08;Zero-Copy&#xff09; 是一种高效的数据传输技术&#xff0c;旨在减少数据在内存中的拷贝次数&#x…

从Instagram到画廊:社交平台如何改变艺术家的展示方式

从Instagram到画廊&#xff1a;社交平台如何改变艺术家的展示方式 在数字时代&#xff0c;艺术家的展示方式正在经历一场革命。社交平台&#xff0c;尤其是Instagram&#xff0c;已经成为艺术家展示作品、与观众互动和建立品牌的重要渠道。本文将探讨社交平台如何改变艺术家的…

✎ 一次有趣的经历

&#x1f4c6;2025年3月17日 | 周一 | ☀️晴 &#x1f4cd;今天路过学院楼7&#xff0c;见到了满园盛开的花&#x1f33a;&#xff0c;心情瞬间明朗&#xff01; &#x1f4cc;希望接下来的日子也能像这些花一样&#xff0c;充满活力&#x1f525;&#xff01; &#x1…

快!快!快!NDPP时延测试数据公布!

在全方位认识NDPP第3期《NDPP在金融场景的应用》中&#xff0c;我们重点介绍了NDPP的典型应用场景行情解码硬件加速和策略计算加速&#xff0c;并帮助某百亿私募用户基于NDPP实现期货业务加速的案例。 近期&#xff0c;中科驭数凭借低时延产品荣获信创“大比武”行业融合赛道三…

激光雷达“开卷”2.0,头部Tier1入局

高阶智驾的普及&#xff0c;正在催生激光雷达市场的巨大潜在增长空间。 本周&#xff0c;汽车激光雷达主力供应商之一的禾赛科技发布财报&#xff0c;去年第四季度激光雷达总交付量为222,054台&#xff0c;同比增长153.1%&#xff0c;超过2023年全年。2024全年激光雷达总交付量…

力扣No.376.摆动序列

题目&#xff1a; 链接&#xff1a; https://leetcode.cn/problems/wiggle-subsequence/description/ 代码&#xff1a; class Solution {public int wiggleMaxLength(int[] nums) {int nnums.length;//状态表示:int[] fnew int[n];int[] gnew int[n];//初始化:for(int i0;i…

C语言中qsort函数的详解,以及模拟

引言 C语言中qsort函数的详解和模拟实现qsort函数&#xff0c;这里为了使用冒泡排序来模拟qsort函数 一、详解qsort函数 在 C 语言中&#xff0c;qsort 函数是一个标准库函数&#xff0c;用于对数组进行快速排序&#xff08;Quick Sort&#xff09;。它位于 <stdlib.h>…