目录
- Gumbel-Softmax分布
- Gumbel-Softmax Estimator
- Straight-Through (ST) Gumbel-Softmax Estimator
- Straight-Through Estimator (STE)
- Straight-Through (ST) Gumbel-Softmax Estimator
- 参考
Gumbel-Softmax分布
Gumbel-Softmax分布是一个定义在单纯形(simplex)上的连续分布。
Gumbel-Softmax分布可以近似categorical分布。
用
z
z
z表示为服从
π
=
(
π
1
,
…
,
π
k
)
\pi = (\pi_1,\ldots,\pi_k)
π=(π1,…,πk)的categorical随机变量。categorical分布的样本表示为
k
k
k维的one-hot向量,在
k
−
1
k-1
k−1维的单纯形空间
△
k
−
1
\bigtriangleup^{k-1}
△k−1中。
Gumbel-Max trick是reparametrization tricks的一个特例,其提供了一个简单有效的从categorical分布采样的方法:
z
=
one-hot
(
argmax
i
[
g
i
+
log
π
i
]
)
(1)
z = \text{one-hot}\left(\operatorname{argmax}_i[g_i + \log \pi_i]\right) \tag{1}
z=one-hot(argmaxi[gi+logπi])(1)其中
g
i
∼
G
u
m
b
e
l
(
0
,
1
)
g_i \sim Gumbel(0,1)
gi∼Gumbel(0,1)。Gumbel分布用于对各种分布的多个样本的最大值(或最小值)的分布进行建模。Gumbel分布的概率密度是:
G
u
m
b
e
l
(
μ
,
β
)
=
1
β
exp
(
−
x
−
μ
β
+
exp
(
−
x
−
μ
β
)
)
Gumbel(\mu, \beta) = \frac{1}{\beta}\exp(-\frac{x - \mu}{\beta} + \exp(-\frac{x - \mu}{\beta}))
Gumbel(μ,β)=β1exp(−βx−μ+exp(−βx−μ))
softmax函数是可导的,使用softmax函数去近似公式(1)中的argmax,可以得到样本
y
∈
△
k
−
1
y\in\bigtriangleup^{k-1}
y∈△k−1:
y
i
=
softmax
[
g
i
+
log
π
i
]
=
exp
(
log
π
i
+
g
i
τ
)
∑
j
=
1
k
exp
(
log
π
j
+
g
j
τ
)
y_i = \operatorname{softmax}[g_i + \log \pi_i] = \frac{\exp(\frac{\log \pi_i + g_i}{\tau})}{\sum_{j = 1}^k \exp(\frac{\log \pi_j + g_j}{\tau})}
yi=softmax[gi+logπi]=∑j=1kexp(τlogπj+gj)exp(τlogπi+gi)Gumbel-Softmax分布的概率密度函数是:
随着
τ
\tau
τ趋近于0,Gumbel-Softmax分布的样本逐渐变成one-hot的,Gumbel-Softmax分布也逐渐变成了categorical分布。如下图所示:
Gumbel-Softmax Estimator
Gumbel-Softmax分布的
∂
y
∂
π
\frac{\partial y}{\partial \pi}
∂π∂y是有定义的。
通过用Gumbel-Softmax样本替换categorical样本,我们可以使用反向传播来计算梯度。
把在训练阶段,用可导的Gumbel-Softmax样本替代不可导的categorical样本的过程称为Gumbel-Softmax Estimator。
在温度
τ
\tau
τ小时,样本接近单热但梯度方差大,在温度
τ
\tau
τ大时,样本平滑但梯度方差小。
实际中,我们从高温
τ
\tau
τ开始,然后退火到一个很小但非零的温度。
Straight-Through (ST) Gumbel-Softmax Estimator
Straight-Through Estimator (STE)
首先介绍下Straight-Through Estimator (STE)。
STE是量化(quantization)中常见的求导方式。
比如有sign函数:
w
b
=
sign
(
w
)
=
{
+
1
,
if
w
≥
0
−
1
,
otherwise
w_{b}=\operatorname{sign}(w)=\left\{ \begin{array}{ll}{+1,}{\text { if } w \geq 0} \\ {-1,}{\text { otherwise }}\end{array}\right.
wb=sign(w)={+1, if w≥0−1, otherwise 这个sign函数在定义域范围内导数都是0。STE就是用来解决sign函数梯度无法反传的问题的。
二值网络训练过程可以是这样:模型中每个参数其实都是一个浮点型的数,每次迭代其实都是在更新这个浮点型数。但是,在前向传播的过程中,先用sign函数对浮点型参数二值化处理然后再参与到运算,而此时并没有把这个浮点型数值抛弃掉,而是暂时在内存中保存起来。前向传播完之后,网络得到一个输出,就可以接着通过反向传播算出二值参数的梯度,再直接用这个梯度来更新对应的浮点型参数。这样,前向反向就跑通了。等训练的差不多了,就最后对模型的这些浮点型参数做一次二值化处理形成最终的二值网络,此时浮点型的参数就完成了任务,可以被抛弃掉了。
Straight-Through (ST) Gumbel-Softmax Estimator
在前向的时候使用argmax离散化 y y y,但在梯度反传的时候,使用连续近似 ∇ θ z ≈ ∇ θ y \nabla_\theta z \approx \nabla_\theta y ∇θz≈∇θy。
参考
ICLR 2017 Categorical Reparameterization with Gumbel-Softmax
Emma Benjaminson blog
二值网络,围绕STE的那些事儿
gumbel-max-trick的数学证明