gumbel-softmax如何实现离散分布可微+torch代码+原理+证明_gumbel softmax-CSDN博客
https://zhuanlan.zhihu.com/p/678930684
gumbelsoftmax 是为了防止丢失其他类别的梯度。
相当于不止选择概率大的那个类别被更新,其他类别的梯度也被更新了。
def sample_gumbel(shape, eps=1e-20):
U = torch.rand(shape)
U = U.cuda()
return -torch.log(-torch.log(U + eps) + eps)
def gumbel_softmax_sample(logits, temperature=0.5):
y = torch.log(lo