注:加粗+下划线名词详解见文章末
了解VQGAN之前,还学习了VQVAE(Vector QuantisedVariational AutoEncoder))这篇论文Neural Discrete Representation Learning,看了几个不错的学习视频 进行了深入了解
VQVAE的思想来源 vector quantisation (VQ) 向量矢量化
VQVAE的核心思想是:在无监督预训练中学到有用的表征,即学习一个离散的潜在表征,用于规避VAE framework中出现的posterior collapse 后验坍塌问题
VQVAE和VAE 的不同点在于
1编码器的输出是离散编码的而不是连续编码
2 先验是学习的(一个离散的类别分布)而不是静态的(vae 的先验(p(z)是一个标准的正态分布)
VQ-VAE(Vector Quantized Variational Autoencoder)模型的主要贡献,具体可以分解为以下几点:
- 引入VQ-VAE模型:
- 简单性:VQ-VAE模型设计得相对简单。
- 使用离散隐变量:与许多其他生成模型(如传统的VAE)使用连续隐变量不同,VQ-VAE采用离散隐变量。这种离散性有助于模型学习到更加稳定和可解释的表示。
- 避免“后验坍塌”:VQ-VAE由于其特殊的离散性和训练机制,有效避免了在训练过程中常见的“后验坍塌”问题,即编码器输出的后验分布过早地坍塌到先验分布,导致隐变量无法有效编码输入数据的信息。
- 无方差问题:由于使用离散隐变量,VQ-VAE在训练过程中不会遇到与连续隐变量相关的方差问题。
- 性能表现:
- 与连续模型相当的对数似然度:研究表明,尽管VQ-VAE使用离散隐变量,但其在对数似然度这一指标上的表现与连续隐变量的模型(如传统的VAE)相当,甚至可能更优。
- 高质量的样本生成:
- 搭配强大的先验分布:当VQ-VAE模型与强大的先验分布(如自回归模型或流模型)结合使用时,能够在多种应用(如语音和视频生成)中生成连贯且高质量的样本。
- 无监督学习的应用:
- 通过原始语音学习语言:研究表明,VQ-VAE模型能够在没有任何监督的情况下,通过原始语音数据学习到语言结构,展示了其强大的无监督学习能力。
- 无监督说话人转换:此外,VQ-VAE还被应用于无监督说话人转换任务中,展示了其在实际应用中的潜力和价值。
VQ-VAE 方法
3.1Discrete Latent variables
离散潜变量(Discrete Latent Variables)在生成模型中,特别是在变分自编码器(VAE)的变种如Vector Quantized Variational Autoencoder(VQ-VAE)中,扮演着重要角色。为了更好地理解这一概念,我们可以从以下几个方面深入:
1. 潜在嵌入空间 e
-
定义:首先,我们定义了一个潜在嵌入空间 e∈RK×D,其中 K 是嵌入向量的数量(即离散潜变量的类别数),D 是每个嵌入向量 ei 的维度。这个空间包含了 K 个可能的嵌入向量,每个向量都是 D 维的。
-
物理意义:这些嵌入向量可以被视为“代码本”中的条目,模型通过学习这些条目来有效地表示输入数据 x 的关键特征。
2. 编码器与离散潜变量的计算
-
编码器输出:给定输入 x,编码器 ze(x) 产生一个连续的表示(通常是一个向量),这个表示在训练过程中会尝试接近某个嵌入向量。
-
最近邻查找:然后,通过比较编码器输出与嵌入空间 e 中的所有向量,找到最近的嵌入向量 ek。这个过程通常通过计算欧几里得距离或余弦相似度来实现。
-
离散潜变量 z:找到的最近嵌入向量的索引 k 被用作离散潜变量 z 的值。因此,z 是一个整数,表示选择了哪个嵌入向量。
3. 解码器与重构
-
解码器输入:解码器的输入是找到的嵌入向量 ek,而不是编码器的原始输出。这意味着解码器直接基于离散的、量化的表示来重构输入 x。
-
重构过程:解码器通过学习如何根据嵌入向量 ek 生成接近原始输入 x 的输出来训练。
4. 后验分类分布 q(z|x)
-
定义:在VQ-VAE中,后验分布 q(z∣x) 通常被简化为一个one-hot分布,其中只有一个元素为1(对应于最近的嵌入向量的索引),其余元素为0。
-
解释:这种简化的后验分布反映了模型对输入 x 潜在表示的确定性选择。在训练过程中,模型试图最大化这种选择的准确性,即确保编码器输出尽可能接近其最近的嵌入向量。
5. 模型的完整性与训练
-
参数集:模型的完整参数集包括编码器的参数、解码器的参数以及嵌入空间 e 中的嵌入向量本身。
-
训练目标:训练目标通常包括两部分:重构损失(确保解码器能够准确重构输入)和嵌入损失(鼓励编码器输出接近某个嵌入向量)。此外,还可能包括正则化项来防止过拟合。
3.2 学习
总训练目标为:
我们详细探讨了VQ-VAE(Vector Quantized Variational Autoencoder)模型的学习过程和训练目标。以下是对这段描述的逐点理解:
1. 梯度估计
- 直通梯度估计:由于VQ-VAE中的量化操作(将连续编码器输出映射到最近的嵌入向量)是不可微的,因此不能直接计算梯度。为了绕过这个问题,作者使用了类似于直通估计器(Straight-Through Estimator, STE)的技巧。在前向传播时,使用量化后的嵌入向量 zq(x) 作为解码器的输入;在后向传播时,将梯度 ∇zL(关于解码器输入的梯度)直接传递给编码器的输出 ze(x),就好像量化操作是可微的一样。这种方法允许梯度通过量化层进行反向传播,尽管它实际上并不准确反映真实的梯度。
2. 损失函数
- 总损失函数:VQ-VAE的训练目标由三个部分组成,每个部分针对不同的模型组件进行优化。
- 重构损失(第一项):优化解码器和编码器(通过直通梯度估计)。这个损失项鼓励解码器根据量化后的嵌入向量 zq(x) 重构输入 x,同时编码器通过调整其输出 ze(x) 来最小化这个重构误差。
- 嵌入损失(第二项):仅用于更新嵌入向量 ei。这个损失项使用 l2 误差将嵌入向量拉向编码器的输出 ze(x),从而优化嵌入空间。由于这个损失项只影响嵌入向量,因此在计算时使用了
stopgradient
运算符来阻止梯度流向编码器。 - 承诺损失(第三项):确保编码器的输出不会过度偏离嵌入空间。这个损失项防止编码器输出的方差过大,从而确保编码器能够“承诺”到嵌入空间上。
3. 训练细节
- β 值:承诺损失的权重 β 在不同实验中表现出很强的稳健性,作者在所有实验中使用 β = 0.25。然而,这个值通常取决于重构损失的规模,可能需要针对特定任务进行调整。
- KL 项的忽略:由于假设 z 有一个统一的先验,并且量化操作使得后验 q(z∣x) 实际上是一个 one-hot 分布,因此 KL 散度项在训练过程中是常数,可以忽略不计。
4. 总结
VQ-VAE通过结合直通梯度估计和特定的损失函数,有效地解决了离散潜变量在训练中的梯度问题,并实现了对编码器、解码器和嵌入空间的联合优化。这种方法不仅避免了后验坍塌的问题,还能够在多种应用中生成高质量、连贯的样本。
3.3 先验
离散潜在变量 p(z) 上的先验分布是一种类别分布,并且可以通过依赖于特征图中的其他 z 来实现自回归。在训练 VQ-VAE 时,先验保持恒定且均匀。训练后,我们在 z 上拟合自回归分布 p(z),以便我们可以通过ancestral sampling生成 x。我们对图像的离散潜在特征使用 PixelCNN,对原始音频使用 WaveNet。联合训练先验和 VQ-VAE,这可以加强我们的结果,留待未来的研究。
后验坍塌
原论文Using the VQ method allows the model to circumvent issues of “posterior collapse” -— where the latents are ignored when they are paired with a powerful autoregressive decoder -— typically observed in the VAE framework.
后验坍塌指的是在训练过程中,VAE的编码器(Encoder)产生的后验分布q(z|x)(用于近似真实的后验分布p(z|x))逐渐退化为一个与先验分布p(z)非常接近的分布,甚至在某些极端情况下,两者几乎完全相同。这导致隐变量z无法有效地捕捉到输入数据x的信息,进而使得解码器(Decoder)在重构输入时几乎不依赖于隐变量z,而是直接基于输入数据x进行重构。
后验坍塌的原因主要可以归结为以下几点:
-
ELBO目标函数的优化:VAE通过优化证据下界(ELBO)来训练模型。ELBO由两部分组成:重构项(Reconstruction term)和KL散度项(KL term)。重构项鼓励解码器生成与输入相似的输出,而KL散度项则鼓励编码器的后验分布与先验分布相似。然而,在实际训练中,KL散度项容易变得非常小甚至为0,导致后验分布坍塌到先验分布。
-
强大的解码器:当解码器非常强大时,它可能能够仅通过输入数据x的重构误差来优化模型,而不需要依赖隐变量z。这导致KL散度项在优化过程中逐渐被忽略,进而引发后验坍塌。
-
模型容量不平衡:编码器和解码器之间的容量不平衡也可能导致后验坍塌。如果解码器过于强大而编码器相对较弱,解码器可能会忽略编码器的输出,直接基于输入数据进行重构。
学习视频
1 简要了解[论文简析]VQ-VAE:Neural discrete representation learning[1711.00937]_哔哩哔哩_bilibili
2 详细了解
68、VQVAE预训练模型的论文原理及PyTorch代码逐行讲解_哔哩哔哩_bilibili