文章目录
- 引言
- 正文
- 生成器损失函数
- 最小二乘损失函数
- 梅尔频谱图损失函数
- 特征匹配损失函数
- 生成器最终损失函数loss
- 生成器loss对应代码
- 鉴定器损失函数
- 鉴定器损失函数代码
- 总结
- 引用
引言
- 这里翻译了HiFi-GAN这篇论文的具体内容,具体链接。
- 这篇文章还是学到了很多东西,从整体上说,学到了生成对抗网络的构建思路,包括生成器和鉴定器。细化到具体实现的细节,如何 实现对于特定周期的数据处理?在细化,膨胀卷积是如何实现的?这些通过文章,仅仅是了解大概的实现原理,但是对于代码的实现细节并不是很了解。如果要加深印象,还是要结合代码来具体看一下实现的细节。
- 本文主要围绕具体的代码实现细节展开,对于相关原理,只会简单引用和讲解。因为官方代码使用的是pytorch,所以是通过pytorch展开的。
- 关于模型其他部分的介绍,链接如下
- 论文代码学习(1)—HiFi-GAN——生成器generator代码
- 论文代码学习—HiFi-GAN(2)——鉴别器discriminator代码
正文
- 关于模型的损失函数,这里总共有两部分损失函数,分别是生成器损失函数和鉴定器损失函数。其中生成器的损失函数,有分为三部分,分别是常规的对抗生成损失、针对特征匹配的损失函数和针对梅尔频谱图的损失函数,后两者是作者自己的加上去的。
生成器损失函数
- 对于生成器损失函数,作者分成了三个部分,分别是基本损失函数、针对特征匹配的损失函数以及梅尔损失函数。
最小二乘损失函数
-
不同于一般的GAN网络使用交叉熵损失函数,这里使用的是最小二乘损失函数,借此来避免梯度丢失的现象。
-
最小二乘损失函数
- 用于衡量模型预测值和真实值的差异,具体特点如下
- 平方项:通过平方差异,扩大误差,模型更加关注于难以拟合的样本
- 连续可微:连续可微,可以有效找到最小值
- 非负:损失函数的值始终非负
- 用于衡量模型预测值和真实值的差异,具体特点如下
- 生成器的损失函数的目的是为了使得生成的数据,经过鉴定器判定,和真的差不多。
- 具体的公式如下
- s s s是梅尔频谱图,输入的条件变量
- x x x是真实数据
- D ( x ) D(x) D(x)是鉴定器对于输入结果的评分,越逼真越接近1
- G ( s ) G(s) G(s)是生成器根据梅尔频谱图生成的结果
- 在上式子中,损失函数越小越好,生成器的效果越好,鉴定器,会将其分辨为1,做差,越靠近零,效果越好。
梅尔频谱图损失函数
- 除了考虑基本的损失函数,这里还增加梅尔频谱图损失函数,用来提高训练效果和生成音频的分辨率,主要是抓住了梅尔频谱图对于感知能力的重视。
- 定义
- 计算合成的波形图和实际波形图的对应采样点的L1距离
- 参数说明
- ∅ \varnothing ∅表示将波形图转为mel频谱图
- 效果:
- 帮助生成器生成和输入相关的实际波形
- 是的对抗训练阶段能够快速稳定下来
特征匹配损失函数
-
特征匹配损失函数是用来衡量真实样本和生成样本在鉴定器上提取出来的特征的差异程度。不同于上一个mel频谱图的特征衡量,这里是直接衡量鉴定器生成的中间特征的差异程度。
-
定义
- 计算真实样本和生成样本分别在鉴定器上生成的中间特征的L1距离
-
参数说明
- T T T表示为鉴定器的层数
- D i D^i Di和 N i N_i Ni分别表示第i层的特征值和特征的数量。
-
效果
- 从鉴定器特征角度使得生成器的样本更加逼真
- 从鉴定器特征角度使得生成器的样本更加逼真
-
注意
- 这里并不是单单一个层的特征,是鉴定器上每一层的输出特征的L1距离累加和的平均值。
生成器最终损失函数loss
- 生成器最终的损失函数,是上述三个损失函数之和,并且特征匹配损失函数和mel频谱图损失函数,加上对应的权重,具体如下
-
λ
f
m
=
2
\lambda_{fm} = 2
λfm=2和
λ
m
e
l
=
45
\lambda_{mel} = 45
λmel=45
-
λ
f
m
=
2
\lambda_{fm} = 2
λfm=2和
λ
m
e
l
=
45
\lambda_{mel} = 45
λmel=45
生成器loss对应代码
def feature_loss(fmap_r, fmap_g):
# 特征损失函数
# fmap_r是真实音频信号的特征图,fmap_g是生成音频信号的特征图
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
# 遍历每一层特征图,计算特征损失,做差,求绝对值,求均值
loss += torch.mean(torch.abs(rl - gl))
# 根据经验,特征损失函数的权重为10
return loss*2
def generator_loss(disc_outputs):
# 生成器的损失函数
# disc_outputs是鉴定器的输出
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean((1-dg)**2)
gen_losses.append(l)
loss += l
# loss是生成器的总损失,用于反向传播来更新生成器的参数
# gen_losses是生成器的损失列表,用于记录鉴定器中每一个元素对应的损失,可以用于调试设备
return loss, gen_losses
- 结合代码来看,并没有将mel频谱图损失记录在内,这里仅仅包含了两个损失函数,generator_loss实现了最小二乘损失函数,feature_loss计算了鉴定器每一层的匹配的损失函数。
- 她是把mel频谱图损失定义在训练过程中了.
鉴定器损失函数
- 我们鉴定器的训练目标:
- 能够将真实数据鉴定为真,标记为1
- 能够将生成器生成的数据鉴定为假,标记为0
- 所以,鉴定器的损失函数应该从两方面进行考虑,分别是鉴定生成数据和鉴定真实数据。
- 具体的公式如下
- s s s是梅尔频谱图,输入的条件变量
- x x x是真实数据
- D ( x ) D(x) D(x)是鉴定器对于输入结果的评分,越逼真越接近1
- G ( s ) G(s) G(s)是生成器根据梅尔频谱图生成的结果
鉴定器损失函数代码
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
# 鉴定器的损失函数
# disc_real_outputs是真实音频信号的鉴定器的输出
# disc_generated_outputs是生成音频信号的鉴定器的输出
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
# 计算真实音频信号的损失
r_loss = torch.mean((1-dr)**2)
# 计算生成音频信号的损失
g_loss = torch.mean(dg**2)
# 将两个损失相加
loss += (r_loss + g_loss)
# 记录各个鉴定器的损失
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
- 这个损失函数实现起来还是比较容易的,只需要分别计算两种数据的损失,然后累加求和即可
总结
- 总的来说,这是第一次接触对抗生成学习,知道了对于鉴定器和生成器要分别定义,损失函数也是分别定义的。除此之外,他们的损失函数也是相互调用的。值得学习。
- 下部分将讲述关于train文件具体内容,这个是模型的具体训练文件,定义了模型的前向传播和反向传播的过程。
引用
- chatGPT-plus
- HiFi-GAN demo
- HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis