https://github.com/dfdazac/wassdistance/tree/master
Prerequisites
Study Computational Optimal Transport.
Reading as far as the coordinate ascent on the entropic dual is enough.
$$
\mathrm{L}_{\mathbf{C}}^{\varepsilon}(\mathbf{a}, \mathbf{b}) \stackrel{\text{def.}}{=} \min_{\mathbf{P} \in \mathbf{U}(\mathbf{a}, \mathbf{b})}\langle\mathbf{P}, \mathbf{C}\rangle-\varepsilon \mathbf{H}(\mathbf{P})
$$
$$
\mathbf{U}(\mathbf{a}, \mathbf{b}) \stackrel{\text{def.}}{=}\left\{\mathbf{P} \in \mathbb{R}_{+}^{n \times m}: \mathbf{P} \mathbf{1}_m=\mathbf{a} \quad \text{and} \quad \mathbf{P}^{\mathrm{T}} \mathbf{1}_n=\mathbf{b}\right\}
$$
Dual
$$
\mathrm{L}_{\mathbf{C}}^{\varepsilon}(\mathbf{a}, \mathbf{b})=\max_{\mathbf{f} \in \mathbb{R}^n, \mathbf{g} \in \mathbb{R}^m}\langle\mathbf{f}, \mathbf{a}\rangle+\langle\mathbf{g}, \mathbf{b}\rangle-\varepsilon\left\langle e^{\mathbf{f} / \varepsilon}, \mathbf{K} e^{\mathbf{g} / \varepsilon}\right\rangle
$$
$$
(\mathbf{u}, \mathbf{v})=\left(e^{\mathbf{f} / \varepsilon}, e^{\mathbf{g} / \varepsilon}\right)
$$
$$
\mathbf{P}=\operatorname{diag}(\mathbf{u})\,\mathbf{K}\,\operatorname{diag}(\mathbf{v}),\quad \mathbf{K}=\exp\left(-\frac{\mathbf{C}}{\varepsilon}\right)
$$
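As a small illustration of the factorization of P, the sketch below builds the coupling from scaling vectors in two equivalent ways: with explicit diagonal matrices and with broadcasting. The sizes, the random cost matrix, and the values of `u` and `v` are made up for the example; they are arbitrary positive scalings, not a solution of the transport problem.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, eps = 3, 4, 0.5            # illustrative sizes and regularization strength

C = rng.random((n, m))           # hypothetical cost matrix
K = np.exp(-C / eps)             # Gibbs kernel K = exp(-C / eps)

# Arbitrary positive scalings, chosen only to illustrate the formula.
u = rng.random(n) + 0.1
v = rng.random(m) + 0.1

# P = diag(u) K diag(v), written with explicit diagonal matrices ...
P_diag = np.diag(u) @ K @ np.diag(v)

# ... and with broadcasting: P[i, j] = u[i] * K[i, j] * v[j].
P_bcast = u[:, None] * K * v[None, :]

print(np.allclose(P_diag, P_bcast))
```

The broadcasting form avoids building the two diagonal matrices, which matters once n and m are large.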
Coordinate ascent
$$
\begin{aligned}
\mathbf{f}^{(\ell+1)} & =\varepsilon \log \mathbf{a}-\varepsilon \log \left(\mathbf{K} e^{\mathbf{g}^{(\ell)} / \varepsilon}\right), \\
\mathbf{g}^{(\ell+1)} & =\varepsilon \log \mathbf{b}-\varepsilon \log \left(\mathbf{K}^{\mathrm{T}} e^{\mathbf{f}^{(\ell+1)} / \varepsilon}\right).
\end{aligned}
$$
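The two updates can be transcribed almost literally. The following is a minimal NumPy sketch under assumed inputs (uniform marginals, a random cost matrix, `eps = 0.5`, a fixed iteration count); the function name `dual_coordinate_ascent` is hypothetical. It forms K explicitly, so it is only meant for moderate values of C/ε, where the exponentials neither underflow nor overflow.

```python
import numpy as np

def dual_coordinate_ascent(a, b, C, eps, n_iter=200):
    """Alternate the two closed-form updates of the dual potentials f and g."""
    K = np.exp(-C / eps)                      # Gibbs kernel
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    for _ in range(n_iter):
        # f-update uses the current g; g-update uses the freshly updated f.
        f = eps * np.log(a) - eps * np.log(K @ np.exp(g / eps))
        g = eps * np.log(b) - eps * np.log(K.T @ np.exp(f / eps))
    # Recover the coupling P = diag(u) K diag(v) with u = e^{f/eps}, v = e^{g/eps}.
    P = np.exp(f / eps)[:, None] * K * np.exp(g / eps)[None, :]
    return f, g, P

rng = np.random.default_rng(0)
a = np.full(3, 1 / 3)                         # illustrative source marginal
b = np.full(4, 1 / 4)                         # illustrative target marginal
C = rng.random((3, 4))                        # hypothetical cost matrix
f, g, P = dual_coordinate_ascent(a, b, C, eps=0.5)

print(np.abs(P.sum(axis=1) - a).max())        # row-marginal error
print(np.abs(P.sum(axis=0) - b).max())        # column-marginal error
```

Because the g-update is applied last, the column marginals match b up to rounding, while the row marginals match a only in the limit of many iterations.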
The code uses a slightly different form.
Consider $\mathbf{C}\in\mathbb{R}^{n\times m}$, $\mathbf{f}\in\mathbb{R}^n$, $\mathbf{g}\in\mathbb{R}^m$.
$$
\begin{aligned}
&\log \left(\mathbf{K} e^{\mathbf{g} / \varepsilon}\right)\\
=&\log\left(\left[\sum_{j}e^{-\frac{C_{i,j}-g_j}{\varepsilon}}\right]_i\right)\\
=&\log\left(\left[\sum_{j}e^{-\frac{C_{i,j}-g_j}{\varepsilon}}e^{\frac{f_i}{\varepsilon}}e^{-\frac{f_i}{\varepsilon}}\right]_i\right)\\
=&\log\left(\left[\sum_{j}e^{-\frac{C_{i,j}-f_i-g_j}{\varepsilon}}\right]_i\odot e^{-\frac{\mathbf{f}}{\varepsilon}}\right)\\
=&\log\left(\left[\sum_{j}e^{-\frac{C_{i,j}-f_i-g_j}{\varepsilon}}\right]_i\right)-\frac{\mathbf{f}}{\varepsilon}\\
=&\operatorname{logsumexp}\left(-\frac{\mathbf{C}-\mathbf{f}^T-\mathbf{g}}{\varepsilon},\ \text{dim}=-1\right)-\frac{\mathbf{f}}{\varepsilon}
\end{aligned}
$$
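In the last step, adding the vectors to the matrix relies on broadcasting: $\mathbf{f}^T$ is treated as an $n\times 1$ column, so $f_i$ is subtracted from every entry of row $i$, and $\mathbf{g}$ is treated as a $1\times m$ row, so $g_j$ is subtracted from every entry of column $j$.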
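The identity derived above is easy to check numerically. The sketch below compares the direct expression log(K e^{g/ε}) with the logsumexp form; the `logsumexp` helper, the sizes, and the random inputs are assumptions for the example, and NumPy's `axis` plays the role of `dim`.

```python
import numpy as np

def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.sum(np.exp(x - x_max), axis=axis))

rng = np.random.default_rng(0)
n, m, eps = 3, 4, 0.5
C = rng.random((n, m))            # hypothetical cost matrix
f = rng.standard_normal(n)        # arbitrary potentials, only for the check
g = rng.standard_normal(m)

# Left-hand side: direct evaluation through the Gibbs kernel.
K = np.exp(-C / eps)
lhs = np.log(K @ np.exp(g / eps))

# Right-hand side: f broadcast as an (n, 1) column, g as a (1, m) row.
M = -(C - f[:, None] - g[None, :]) / eps
rhs = logsumexp(M, axis=-1) - f / eps

print(np.allclose(lhs, rhs))
```

The value of f does not affect the result, since it is added inside the exponent and subtracted again outside; its role is to keep the exponents in a numerically safe range when f is close to the current dual potential.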
$$
\begin{aligned}
&\log \left(\mathbf{K}^{\mathrm{T}} e^{\mathbf{f} / \varepsilon}\right)\\
=&\log\left(\left[\sum_{i}e^{-\frac{C_{i,j}-f_i}{\varepsilon}}\right]_j\right)\\
=&\log\left(\left[\sum_{i}e^{-\frac{C_{i,j}-f_i}{\varepsilon}}e^{\frac{g_j}{\varepsilon}}e^{-\frac{g_j}{\varepsilon}}\right]_j\right)\\
=&\log\left(\left[\sum_{i}e^{-\frac{C_{i,j}-f_i-g_j}{\varepsilon}}\right]_j\odot e^{-\frac{\mathbf{g}}{\varepsilon}}\right)\\
=&\log\left(\left[\sum_{i}e^{-\frac{C_{i,j}-f_i-g_j}{\varepsilon}}\right]_j\right)-\frac{\mathbf{g}}{\varepsilon}\\
=&\operatorname{logsumexp}\left(-\frac{\mathbf{C}-\mathbf{f}^T-\mathbf{g}}{\varepsilon},\ \text{dim}=-2\right)-\frac{\mathbf{g}}{\varepsilon}\\
=&\operatorname{logsumexp}\left(-\frac{\left(\mathbf{C}-\mathbf{f}^T-\mathbf{g}\right)^T}{\varepsilon},\ \text{dim}=-1\right)-\frac{\mathbf{g}}{\varepsilon}
\end{aligned}
$$
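Substituting the two logsumexp identities into the coordinate-ascent updates gives updates that never form K explicitly: f ← ε(log a − logsumexp(M, dim=−1)) + f and g ← ε(log b − logsumexp(M, dim=−2)) + g, with M = −(C − fᵀ − g)/ε. The following is a NumPy sketch of that scheme under assumed inputs; it is an illustration written for this note, not a copy of the linked repository's code, and the names `sinkhorn_log` and `M` are hypothetical.

```python
import numpy as np

def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.sum(np.exp(x - x_max), axis=axis))

def sinkhorn_log(a, b, C, eps, n_iter=200):
    """Log-domain coordinate ascent on the dual potentials f and g."""
    f = np.zeros_like(a)
    g = np.zeros_like(b)

    def M(f, g):
        # M[i, j] = -(C[i, j] - f[i] - g[j]) / eps, with f as column and g as row.
        return -(C - f[:, None] - g[None, :]) / eps

    for _ in range(n_iter):
        f = eps * (np.log(a) - logsumexp(M(f, g), axis=-1)) + f
        g = eps * (np.log(b) - logsumexp(M(f, g), axis=-2)) + g

    P = np.exp(M(f, g))           # equals diag(e^{f/eps}) K diag(e^{g/eps})
    return f, g, P

rng = np.random.default_rng(0)
a = np.full(3, 1 / 3)             # illustrative source marginal
b = np.full(4, 1 / 4)             # illustrative target marginal
C = rng.random((3, 4))            # hypothetical cost matrix
f, g, P = sinkhorn_log(a, b, C, eps=0.5)

print(np.abs(P.sum(axis=1) - a).max(), np.abs(P.sum(axis=0) - b).max())
```

Reducing over `axis=-2` is the same as transposing M and reducing over `axis=-1`, which is the last line of the derivation above.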