【GPT-SOVITS-03】SOVITS 模块-生成模型解析

news2024/11/20 20:22:12

说明:该系列文章从本人知乎账号迁入,主要原因是知乎图片附件过于模糊。

知乎专栏地址:
语音生成专栏

系列文章地址:
【GPT-SOVITS-01】源码梳理
【GPT-SOVITS-02】GPT模块解析
【GPT-SOVITS-03】SOVITS 模块-生成模型解析
【GPT-SOVITS-04】SOVITS 模块-鉴别模型解析
【GPT-SOVITS-05】SOVITS 模块-残差量化解析
【GPT-SOVITS-06】特征工程-HuBert原理

1.概述

SOVIT 模块的主要功能是生成最终的音频文件。

GPT-SOVITS的核心与SOVITS差别不大,仍然是分了两个部分:

  • 基于 VAE + FLOW 的生成器,源代码为 SynthesizerTrn
  • 基于多尺度分类器的鉴别器,源代码为 SynthesizerTrn

针对鉴别器相较于SOVITS5做了一些简化,主要的差异是在在生成模型处引入了残差量化层。

在训练时进入先验编码器的是经过残差量化层的 quatized 数据。

在推理时,用的是AR模块推理出的 code,然后用code直接生成 quatized 数据,再进入先验编码器。

训练所涉及特征包括:
在这里插入图片描述

2.训练流程

在这里插入图片描述

  • 如概述所注,在训练时SSL特征经过残差量化层中会产生量化编码 code 和数据 quatized。
  • 这个 code 也会作为 AR,即GPT模块训练的特征
  • 在推理时,这个code 就由 GPT 模块生成
  • 损失函数如下:
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
with autocast(enabled=False):
    loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
    loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

    loss_fm = feature_loss(fmap_r, fmap_g)
    loss_gen, losses_gen = generator_loss(y_d_hat_g)
    loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl

3.推理流程

在这里插入图片描述
推理时直接通过先验编码器,通过FLOW的逆,进入解码器后输出推理音频

4.调试代码参考

import os,sys
import json
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from torch.utils.data import DataLoader

from vof.vits.data_utils import (
    TextAudioSpeakerLoader,
    TextAudioSpeakerCollate,
    DistributedBucketSampler,
)
from vof.vits.models import SynthesizerTrn
from vof.script.utils import HParams

now_dir   = os.getcwd()
root_dir  = os.path.dirname(now_dir)
prj_name  = 'project01'               # 项目名称
prj_dir   = root_dir + '/res/' + prj_name + '/'

with open(root_dir + '/res/configs/s2.json') as f:
    data = f.read()
    data = json.loads(data)

# 新增其他参数
s2_dir = prj_dir + 'logs'  # gpt 训练用目录
os.makedirs("%s/logs_s2" % (s2_dir), exist_ok=True)

data["train"]["batch_size"]             = 3
data["train"]["epochs"]                 = 15
data["train"]["text_low_lr_rate"]       = 0.4
data["train"]["pretrained_s2G"]         = root_dir + '/res/pretrained_models/s2G488k.pth'
data["train"]["pretrained_s2D"]         = root_dir + '/res/pretrained_models/s2D488k.pth'
data["train"]["if_save_latest"]         = True
data["train"]["if_save_every_weights"]  = True
data["train"]["save_every_epoch"]       = 5
data["train"]["gpu_numbers"]            = 0
data["data"]["exp_dir"]                 = data["s2_ckpt_dir"] = s2_dir
data["save_weight_dir"]                 = root_dir + '/res/weight/sovits'
data["name"]                            = prj_name
data['exp_dir']                         = s2_dir

hps = HParams(**data)
print(hps)
"""
self.path2 = "%s/2-name2text-0.txt" % exp_dir
self.path4 = "%s/4-cnhubert" % exp_dir
self.path5 = "%s/5-wav32k" % exp_dir
"""
train_dataset = TextAudioSpeakerLoader(hps.data)
"""
ssl  hubert 特征 [1,768,195]
spec [1025,195]
wav  [1,124800]
text [14,]
"""
train_sampler = DistributedBucketSampler(
    train_dataset,
    hps.train.batch_size,
    [
        32,
        300,
        400,
        500,
        600,
        700,
        800,
        900,
        1000,
        1100,
        1200,
        1300,
        1400,
        1500,
        1600,
        1700,
        1800,
        1900,
    ],
    num_replicas=1,
    rank=0,
    shuffle=True,
)
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    collate_fn=collate_fn,
    batch_sampler=train_sampler
)

def _model_forward(ssl, y, y_lengths, text, text_lengths):

    net_g = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    net_g.forward(ssl, y, y_lengths, text, text_lengths)

for data in train_loader:

    ssl_padded   = data[0]
    ssl_lengths  = data[1]
    spec_padded  = data[2]
    spec_lengths = data[3]
    wav_padded   = data[4]
    wav_lengths  = data[5]
    text_padded  = data[6]
    text_lengths = data[7]

    _model_forward(ssl_padded, spec_padded, spec_lengths, text_padded, text_lengths)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1525105.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

分布式搜索引擎(3)

1.数据聚合 **[聚合(](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations.html)[aggregations](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations.html)[)](https://www.ela…

旅游系统-软件与环境

一. 软件 1.Navicat、phpstudy、Idea、Vsode 参考 网盘链接 二.配置文件 1.NodeJS、JDK、Mysql 参考 网盘链接 注意点: 1.Mysql 切记需要环境变量配置 2.数据库密码要好记点的,别乱设 3.环境变量配置的路径要能找到 三.安装运行 1.下载网盘内的软件&am…

html系列:按钮被样式图片挡着了,无法点击怎么办

​ 背景 在开发中会遇到一些奇奇怪怪的需求,比如在按钮上要显示一个样式图片,同时还要能不影响按钮的点击使用;这时候,设置好了样式,按钮无法点击怎么办? 在查阅资料的时候找到了解决方案。 解决方案 …

kafka集群介绍

介绍 kafka是一个高性能、低延迟、分布式的消息传递系统,特点在于实时处理数据。集群由多个成员节点broker组成,每个节点都可以独立处理消息传递和存储任务。 路由策略 发布消息由key、value组成,真正的消息是value,key是标识路…

前端面试题01(css)

前端面试题01(css) 文章目录 前端面试题01(css)1、CSS选择器的优先级2、隐藏元素的方法有哪些3、px和rem的区别4、重绘和重排的区别5、水平垂直居中的方式6、CSS的那些属性可以继承7、预处理器 🎉写在最后 hello hello…

2023安洵杯 ezjava

2023安洵杯 ezjava 附件地址&#xff1a;https://github.com/D0g3-Lab/i-SOON_CTF_2023 先看依赖&#xff1a; <dependency><groupId>org.postgresql</groupId><artifactId>postgresql</artifactId><version>42.3.1</version><…

Google云计算原理与应用(三)

目录 五、分布式存储系统Megastore&#xff08;一&#xff09;设计目标及方案选择&#xff08;二&#xff09;Megastore数据模型&#xff08;三&#xff09;Megastore中的事务及并发控制&#xff08;四&#xff09;Megastore基本架构&#xff08;五&#xff09;核心技术——复制…

Transformer的前世今生 day02(神经网络语言模型

神经网络语言模型 使用神经网络的方法&#xff0c;去完成语言模型的两个问题&#xff0c;下图为两层感知机的神经网络语言模型&#xff1a; 以下为预备概念 感知机 线性模型可以用下图来表示&#xff1a;输入经过线性层得到输出 线性层 / 全连接层 / 稠密层&#xff1a;假…

【C++ leetcode 】双指针问题

1. 183. 移动零 题目 给定一个数组 nums&#xff0c;编写一个函数将所有 0 移动到数组的末尾&#xff0c;同时保持非零元素的相对顺序。 请注意 &#xff0c;必须在不复制数组的情况下原地对数组进行操作。 题目链接 . - 力扣&#xff08;LeetCode&#xff09; 画图 和 文字 分…

无尘室设计常用参数与选型

无尘车间(Clean Room)是指空气无尘度达到规定级别的受控空间。其功能是把空气中的微粒子、有害空气、细菌等污染物排除室外,并将室内的无尘度、温度、湿度、室内压力、气流速度与气流分布、噪音、振动、照明及静电控制在某一需求范围内。无尘车间最主要的作用在于控制产品所…

《前端系列》之前端学习路线

目录 1 前言2 前端学习路线2.1 入门阶段2.1.1 HTML2.1.2 CSS2.1.3 JavaScript2.1.4 网络基础 2.2 基础阶段2.2.1 前端框架2.2.2 深入JavaScript2.2.3 ES62.2.4 工程化知识 2.3 进阶阶段2.3.1 CSS2.3.2 Javascript2.3.3 单元测试2.3.4 性能优化 3 总结 1 前言 在技术更新迭代发…

如何在没有备份的情况下恢复 Android 上已删除的照片?

丢失 Android 设备上的珍贵照片可能是一场噩梦&#xff0c;尤其是在没有备份的情况下。无论是意外删除图像还是由于Android 崩溃而丢失图像&#xff0c;一想到它们可能会永远消失就令人沮丧。幸运的是&#xff0c;有多种方法可以在 Android 上恢复已删除的照片。 如何在没有备份…

C语言中内存函数的使用

memcpy函数的使用和模拟实现 memcpy的使用 函数使用说明&#xff1a; • 函数memcpy从source的位置开始向后复制num个字节的数据到destination指向的内存位置。 • 这个函数在遇到 \0 的时候并不会停下来。 • 如果source和destination有任何的重叠&#xff0c;复制的结…

CSS案例-2.简单版侧边栏练习

效果 知识点 标签显示模式 块级元素 block-level 常见元素:<h1>~<h6>、<p>、<div>、<ul>、<ol>、<li>等。 特点: 独占一行长度、宽度、边距都可以控制宽度默认是容器(父级宽度)的100%是一个容器及盒子,里面可以放行内或者…

matplotlib画堆叠、并列直方图

在用 matplotlib.pyplot.hist 画分布图时&#xff0c;若总分布由几个分量组成&#xff08;如高斯混合&#xff09;&#xff0c;想用不同颜色标识出来&#xff0c;方便看到各分量占比&#xff0c;参考 [1]。 效果&#xff1a; 分布由两个分量&#xff08;x、y&#xff09;组成…

Web入门

一Spring简单介绍&#xff1a; Spring Boot 是基于Spring的但是&#xff0c;Spring更为简单高效。 1.2Spring Boot快速入门&#xff1a; 二HTTP协议&#xff1a; 2.1HTTP协议概述 2.2请求协议 <!DOCTYPE html> <html lang"en"> <head><meta ch…

ArkTS 基础组件

目录 一、常用组件 二、文本显示&#xff08;Text/Span) 2.1 创建文本 2.2 属性 2.3 添加子组件(Span) 2.4 添加事件 三、按钮&#xff08;Button&#xff09; 3.1 创建按钮 3.2 设置按钮类型 3.3 悬浮按钮 四、文本输入&#xff08;TextInput/TextArea&#xff09;…

十四、GPT

在GPT-1之前&#xff0c;传统的 NLP 模型往往使用大量的数据对有监督的模型进行任务相关的模型训练&#xff0c;但是这种有监督学习的任务存在两个缺点&#xff1a;预训练语言模型之GPT 需要大量的标注数据&#xff0c;高质量的标注数据往往很难获得&#xff0c;因为在很多任务…

数据结构和算法:哈希表

哈希表 哈希表&#xff08;hash table&#xff09;&#xff0c;又称散列表&#xff0c;它通过建立键 key 与值 value 之间的映射&#xff0c;实现高效的元素查询。具体而言&#xff0c;向哈希表中输入一个键 key &#xff0c;则可以在 &#x1d442;(1) 时间内获取对应的值 va…

PyCharm实现一个简单的注册登录Django项目

之前已经实现了一个简单的Django项目&#xff0c;今天我们j基于之前的项目来实现注册、登录以及登录成功之后跳转到StuList页面。 1、连接数据库 1.1 配置数据库信息&#xff1a; 首先在myweb的settings.py 文件中设置MySQL数据库连接信息&#xff1a; DATABASES {default…