6. Diffusion-based Image Translation using Disentangled Style and Content Representation
本文介绍了一种基于扩散模型的图像转换方法,图像转换就是根据文本引导或者图像的引导,将源图像转换到目标域中,如下图所示。
在图像转换中待解决的一个关键问题就是如何在将语义特征转换到目标域中时保留源图像的结构特征,而不是随机的生成一张目标域中的图像。而扩散模型本身具备很强的随机性,因此必须引入一定的约束,控制其生成的过程。作者受到流形约束梯度(Manifold Constrained Gradient,MCG)的启发,将损失梯度引入到生成过程中,实现方法如下
x
t
−
1
′
=
1
α
t
(
x
t
−
1
−
α
t
1
−
α
ˉ
t
ϵ
θ
(
x
t
,
t
)
)
+
σ
t
ϵ
\boldsymbol{x}_{t-1}^{\prime}=\frac{1}{\sqrt{\alpha_{t}}}\left(\boldsymbol{x}_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \boldsymbol{\epsilon}_{\theta}\left(\boldsymbol{x}_{t}, t\right)\right)+\sigma_{t} \boldsymbol{\epsilon}
xt−1′=αt1(xt−1−αˉt1−αtϵθ(xt,t))+σtϵ
x
t
−
1
=
x
t
−
1
′
−
∇
x
t
ℓ
total
(
x
^
0
(
x
t
)
)
\boldsymbol{x}_{t-1}=\boldsymbol{x}_{t-1}^{\prime}-\nabla_{\boldsymbol{x}_{t}} \ell_{\text {total }}\left(\hat{\boldsymbol{x}}_{0}\left(\boldsymbol{x}_{t}\right)\right)
xt−1=xt−1′−∇xtℓtotal (x^0(xt))其中
x
t
−
1
′
\boldsymbol{x}_{t-1}^{\prime}
xt−1′沿用了DDPM中的计算方法,在此基础上又根据损失梯度
∇
x
t
ℓ
total
(
x
^
0
(
x
t
)
)
\nabla_{\boldsymbol{x}_{t}} \ell_{\text {total }}\left(\hat{\boldsymbol{x}}_{0}\left(\boldsymbol{x}_{t}\right)\right)
∇xtℓtotal (x^0(xt))进一步更新得到
x
t
−
1
\boldsymbol{x}_{t-1}
xt−1,
x
^
0
(
x
t
)
\hat{\boldsymbol{x}}_{0}\left(\boldsymbol{x}_{t}\right)
x^0(xt)是根据
x
t
\boldsymbol{x}_{t}
xt估计得到的生成结果,后续简写为
x
\boldsymbol{x}
x,其计算过程如下
x
^
0
(
x
t
)
:
=
x
t
α
ˉ
t
−
1
−
α
ˉ
t
α
ˉ
t
ϵ
θ
(
x
t
,
t
)
\hat{\boldsymbol{x}}_{0}\left(\boldsymbol{x}_{t}\right):=\frac{\boldsymbol{x}_{t}}{\sqrt{\bar{\alpha}_{t}}}-\frac{\sqrt{1-\bar{\alpha}_{t}}}{\sqrt{\bar{\alpha}_{t}}} \boldsymbol{\epsilon}_{\theta}\left(\boldsymbol{x}_{t}, t\right)
x^0(xt):=αˉtxt−αˉt1−αˉtϵθ(xt,t)则接下来的任务就是如何定义损失函数
ℓ
total
\ell_{\text {total }}
ℓtotal 了。
1. 结构损失
正如上文所述,我们希望输出的结果能够保持源图像的结构信息,而Slicing Vision Transformer这篇论文指出ViT中多头自注意力层的key值 k l k^l kl包含了结构信息,而最后一层的类别Token保留了语义信息。基于此,提出了一种损失通过匹配key值之间的自相似性矩阵 S l S^l Sl来保持输入和输出之间的结构一致性,计算过程如下 ℓ s s i m ( x s r c , x ) = ∥ S l ( x s r c ) − S l ( x ) ∥ F \ell_{ssim}(x_{src},x)=\left \| S^l(x_{src})-S^l(x)\right\|_F ℓssim(xsrc,x)= Sl(xsrc)−Sl(x) F其中 [ S l ( x ) ] i , j = cos ( k i l ( x ) , k j l ( x ) ) [S^l(x)]_{i,j}=\cos(k^l_i(x),k^l_j(x)) [Sl(x)]i,j=cos(kil(x),kjl(x)), k i l ( x ) k^l_i(x) kil(x)表示在ViT的第 l l l个多头自注意力层中的第 i i i个key值。虽然自相似损失可以保留输入和输出之间的内容信息,但对于DDPM扩散任务只能提供非常有限的约束。这是因为第 i i i个key值对应着图像中第 i i i个图块的位置,其与其他位置处的key值可能区别很大。为此,作者引入一种对比学习损失,让相同位置的图块之间更相似,而增大不同位置之间的距离,计算过程如下 ℓ cont ( x s r c , x ) = − ∑ i log ( exp ( sim ( k i l ( x ) , k i l ( x s r c ) ) / τ ) exp ( sim ( k i l ( x ) , k i l ( x s r c ) ) / τ + ∑ j ≠ i exp ( sim ( k i l ( x ) , k j l ( x s r c ) ) / τ ) ) \ell_{\text {cont }}\left(\boldsymbol{x}_{s r c}, \boldsymbol{x}\right)=-\sum_{i} \log \left(\frac{\exp \left(\operatorname{sim}\left(k_{i}^{l}(\boldsymbol{x}), k_{i}^{l}\left(\boldsymbol{x}_{s r c}\right)\right) / \tau\right)}{\exp \left(\operatorname{sim}\left(k_{i}^{l}(\boldsymbol{x}), k_{i}^{l}\left(\boldsymbol{x}_{s r c}\right)\right) / \tau+\sum_{j \neq i} \exp \left(\operatorname{sim}\left(k_{i}^{l}(\boldsymbol{x}), k_{j}^{l}\left(\boldsymbol{x}_{s r c}\right)\right) / \tau\right)\right.}\right) ℓcont (xsrc,x)=−i∑log exp(sim(kil(x),kil(xsrc))/τ+∑j=iexp(sim(kil(x),kjl(xsrc))/τ)exp(sim(kil(x),kil(xsrc))/τ) 其中 sim ( ⋅ , ⋅ ) \operatorname{sim}(\cdot,\cdot) sim(⋅,⋅)表示规范化的余弦相似性函数。
2. 风格损失
风格损失的目的是增加生成结果和引导之间的语义相似性。对于文本引导,作者提出了 ℓ C L I P \ell_{CLIP} ℓCLIP损失,计算过程如下 ℓ C L I P ( x ; d t r g , x s r c , d s r c ) : = − sim ( v t r g , v s r c ) \ell_{C L I P}\left(\boldsymbol{x} ; \boldsymbol{d}_{t r g}, \boldsymbol{x}_{s r c}, \boldsymbol{d}_{s r c}\right):=-\operatorname{sim}\left(\boldsymbol{v}_{t r g}, \boldsymbol{v}_{s r c}\right) ℓCLIP(x;dtrg,xsrc,dsrc):=−sim(vtrg,vsrc)其中 v t r g : = E T ( d t r g ) + λ i E I ( x s r c ) − λ s E T ( d s r c ) , v s r c : = E I ( aug ( x ) ) \boldsymbol{v}_{t r g}:=E_{T}\left(\boldsymbol{d}_{t r g}\right)+\lambda_{i} E_{I}\left(\boldsymbol{x}_{s r c}\right)-\lambda_{s} E_{T}\left(\boldsymbol{d}_{s r c}\right), \quad \boldsymbol{v}_{s r c}:=E_{I}(\operatorname{aug}(\boldsymbol{x})) vtrg:=ET(dtrg)+λiEI(xsrc)−λsET(dsrc),vsrc:=EI(aug(x)) d s r c , d t r g \boldsymbol{d}_{src},\boldsymbol{d}_{t r g} dsrc,dtrg和 x s r c \boldsymbol{x}_{src} xsrc分别表示源图像的文本描述,目标文本描述(文本引导)和源图像, E T , E I E_T,E_I ET,EI分别表示CLIP模型中的文本和图像编码器, aug \operatorname{aug} aug表示CLIP中为了防止出现对抗伪影的增强操作。通过调节 λ s \lambda_s λs和 λ i \lambda_i λi可以去除源域的文本信息和增加源域的图像信息。而对于图像引导,作者提出一种语义风格损失 ℓ s t y \ell_{sty} ℓsty,正如前文所说ViT的最后一层的类别Token保留了语义信息,因此可以借此来保持源图像和生成图像之间的语义一致性。此外,作者发现仅用语义一致性作为约束,会导致生成的图像之间存在颜色差异,因此作者又引入了图像之间的MSE损失来减少颜色上的差距,义风格损失 ℓ s t y \ell_{sty} ℓsty如下 ℓ s t y ( x t r g , x ) = ∥ e [ C L S ] L ( x t r g ) − e [ C L S ] L ( x ) ∥ 2 + λ m s e ∥ x t r g − x ∥ 2 \ell_{s t y}\left(\boldsymbol{x}_{t r g}, \boldsymbol{x}\right)=\left\|\boldsymbol{e}_{[C L S]}^{L}\left(\boldsymbol{x}_{t r g}\right)-\boldsymbol{e}_{[C L S]}^{L}(\boldsymbol{x})\right\|_{2}+\lambda_{m s e}\left\|\boldsymbol{x}_{t r g}-\boldsymbol{x}\right\|_{2} ℓsty(xtrg,x)= e[CLS]L(xtrg)−e[CLS]L(x) 2+λmse∥xtrg−x∥2其中 e [ C L S ] L \boldsymbol{e}_{[C L S]}^{L} e[CLS]L表示最后一层的类别Token。
3. 语义差异损失
为了加快生成的过程,减少反向去噪的步骤,作者设计了语义差异损失 ℓ s e m \ell_{sem} ℓsem, ℓ s e m ( x t ; x t + 1 ) = − ∥ e [ C L S ] L ( x ^ 0 ( x t ) ) − e [ C L S ] L ( x ^ 0 ( x t + 1 ) ) ∥ 2 \ell_{s e m}\left(\boldsymbol{x}_{t} ; \boldsymbol{x}_{t+1}\right)=-\left\|\boldsymbol{e}_{[C L S]}^{L}\left(\hat{\boldsymbol{x}}_{0}\left(\boldsymbol{x}_{t}\right)\right)-\boldsymbol{e}_{[C L S]}^{L}\left(\hat{\boldsymbol{x}}_{0}\left(\boldsymbol{x}_{t+1}\right)\right)\right\|_{2} ℓsem(xt;xt+1)=− e[CLS]L(x^0(xt))−e[CLS]L(x^0(xt+1)) 2通过增加相邻两次反向去噪步骤估计的生成结果之间的语义差异,促使生成结果能够更快的变换到目标域中。最终的损失函数 ℓ t o t a l \ell_{total} ℓtotal如下 ℓ total = λ 1 ℓ cont + λ 2 ℓ ssim + λ 3 ℓ C L I P + λ 4 ℓ sem + λ 5 ℓ r n g \ell_{\text {total }}=\lambda_{1} \ell_{\text {cont }}+\lambda_{2} \ell_{\text {ssim }}+\lambda_{3} \ell_{C L I P}+\lambda_{4} \ell_{\text {sem }}+\lambda_{5} \ell_{r n g} ℓtotal =λ1ℓcont +λ2ℓssim +λ3ℓCLIP+λ4ℓsem +λ5ℓrng其中 ℓ r n g \ell_{rng} ℓrng是一个正则化损失项,为了防止反向去噪过程中出现不合理的步骤。如果采用图像引导,上式中的 ℓ C L I P \ell_{C L I P} ℓCLIP应改为 ℓ s t y \ell_{sty} ℓsty。
为了进一步加快生成过程,作者还提出一种重采样策略。作者发现一个好的生成起始点 X T X_T XT能够有效的提升生成效果,缩短反向去噪步骤。因此作者重复 N N N次对 X T − 1 X_{T-1} XT−1的采样步骤,并根据前向扩散公式计算得到 X T X_T XT, x T = 1 − β T − 1 x T − 1 + β T − 1 ϵ \boldsymbol{x}_{T}=\sqrt{1-\beta_{T-1}} \boldsymbol{x}_{T-1}+\beta_{T-1} \boldsymbol{\epsilon} xT=1−βT−1xT−1+βT−1ϵ然后,从中选取梯度最容易受到损失函数影响的 X T X_T XT作为起始点。整个算法的计算流程如下所示
实验表明,在文本引导和图像引导的图像转换任务中,DiffuseIT都取得了不错的效果。