TTS | 轻量级VITS2的项目实现以及API设置

news2025/2/25 21:34:37

 本文主要是实现了MB-iSTFT-VITS2语音合成模型的训练,相比于VITS模型,MB-iSTFT-VITS模型相对来说会小一点,最重要的是在合成结果来看,MB-iSTFT-VITS模型推理更快,更加自然(个人经验).项目地址如下:

FENRlR/MB-iSTFT-VITS2: Application of MB-iSTFT-VITS components to vits2_pytorch (github.com)

 目前项目还未来得及发表论文,且项目还在完善中(截止到2023.10.18)。

目录

 0.环境设置

1.项目设置及数据处理

2.训练

3.设置API

过程中遇到的问题及解决(PS) 

[PS1]训练时出现cuda out of m

[PS2]训练后出现线程相关错误

[PS3]API设置时出现,RuntimeError: Invalid device, must be cuda device

[PS4]RuntimeError: Error(s) in loading state_dict for SynthesizerTrn:        size mismatch for enc_q.pre.weight: copying a param with shape torch.Size([192, 80, 1]) from checkpoint, the shape in current model is torch.Size([192, 513, 1]).

[PS5]TypeError: load_checkpoint() got an unexpected keyword argument 'skip_optimizer'


 0.环境设置

docker镜像容器(Linux20.04+Pytorch1.13.1+torchvision0.14.1+cuda11.7+python3.8),

1.项目设置及数据处理

# 克隆项目到本地
git clone https://github.com/FENRlR/MB-iSTFT-VITS2

cd MB-iSTFT-VITS2

#安装所需要的库
pip install -r requirements.txt
apt-get install espeak

# 文本预处理

## 选择1 : 单人数据集
python preprocess.py --text_index 1 --filelists PATH_TO_train.txt --text_cleaners CLEANER_NAME
python preprocess.py --text_index 1 --filelists PATH_TO_val.txt --text_cleaners CLEANER_NAME


## 选择2 : 多人数据集 
python preprocess.py --text_index 2 --filelists PATH_TO_train.txt --text_cleaners CLEANER_NAME
python preprocess.py --text_index 2 --filelists PATH_TO_val.txt --text_cleaners CLEANER_NAME

# 设置MAS
cd monotonic_align
mkdir monotonic_align
python setup.py build_ext --inplace

前期设置与vits/vits2基本相同

编辑配置文件

2.训练


# 选择1 : 单人数据集训练
python train.py -c configs/mb_istft_vits2_base.json -m models/test



# 选择2 : 多人数据集训练 
python train_ms.py -c configs/mb_istft_vits2_base.json -m models/test

训练后生成

 训练过程

3.设置API

webui.py


import sys, os
import logging
import re

logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

logging.basicConfig(
    level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
)

logger = logging.getLogger(__name__)

import torch
import argparse
import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
#from text import cleaned_text_to_sequence, get_bert
from text import text_to_sequence
#from text.cleaner import clean_text
import gradio as gr
import webbrowser
import numpy as np


net_g = None

if sys.platform == "darwin" and torch.backends.mps.is_available():
    device = "mps"
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
else:
    device = "cuda"

   

def get_text(text, hps):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm

def langdetector(text):  # from PolyLangVITS
    try:
        lang = langdetect.detect(text).lower()
        if lang == 'ko':
            return f'[KO]{text}[KO]'
        elif lang == 'ja':
            return f'[JA]{text}[JA]'
        elif lang == 'en':
            return f'[EN]{text}[EN]'
        elif lang == 'zh-cn':
            return f'[ZH]{text}[ZH]'
        else:
            return text
    except Exception as e:
        return text

def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
    global net_g
    fltstr = re.sub(r"[\[\]\(\)\{\}]", "", text)
    stn_tst = get_text(fltstr, hps)

    speed = 1
    output_dir = 'output'
    sid = 0
    with torch.no_grad():
        x_tst = stn_tst.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1 / speed)[0][
                0, 0].data.cpu().float().numpy()

    return audio

def tts_fn(
    text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale
):
    slices = text.split("|")
    audio_list = []
    with torch.no_grad():
        for slice in slices:
            audio = infer(
                slice,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,

            )
            audio_list.append(audio)
            silence = np.zeros(hps.data.sampling_rate)  # 生成1秒的静音
            audio_list.append(silence)  # 将静音添加到列表中
    audio_concat = np.concatenate(audio_list)
    return "Success", (hps.data.sampling_rate, audio_concat)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m", "--model", default="/workspace/tts/MB-iSTFT-VITS2/logs/models/G_94000.pth", help="path of your model"
    )
    parser.add_argument(
        "-c",
        "--config",
        default="/workspace/tts/MB-iSTFT-VITS2/logs/models/config.json",
        help="path of your config file",
    )
    parser.add_argument(
        "--share", default=False, help="make link public", action="store_true"
    )
    parser.add_argument(
        "-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log"
    )

    args = parser.parse_args()
    if args.debug:
        logger.info("Enable DEBUG-LEVEL log")
        logging.basicConfig(level=logging.DEBUG)
    hps = utils.get_hparams_from_file(args.config)
    
    if "use_mel_posterior_encoder" in hps.model.keys() and hps.model.use_mel_posterior_encoder == True:
        print("Using mel posterior encoder for VITS2")
        posterior_channels = 80  # vits2
        hps.data.use_mel_posterior_encoder = True
    else:
        print("Using lin posterior encoder for VITS1")
        posterior_channels = hps.data.filter_length // 2 + 1
        hps.data.use_mel_posterior_encoder = False
    device = (
        "cuda:0"
        if torch.cuda.is_available()
        else (
            "mps"
            if sys.platform == "darwin" and torch.backends.mps.is_available()
            else "cpu"
        )
    )
    net_g = SynthesizerTrn(
        len(symbols),
        posterior_channels,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers, #- >0 for multi speaker
        **hps.model
    ).to(device)
    _ = net_g.eval()

    #_ = utils.load_checkpoint(args.model, net_g, None, skip_optimizer=True)
    _ = utils.load_checkpoint(path_to_model, net_g, None)

    #speaker_ids = hps.data.spk2id
    #speakers = list(speaker_ids.keys())
    speakers = hps.data.n_speakers
    languages = ["KO"]
    with gr.Blocks() as app:
        with gr.Row():
            with gr.Column():
                text = gr.TextArea(
                    label="Text",
                    placeholder="Input Text Here",
                    value="测试文本.",
                )
                '''speaker = gr.Dropdown(
                    choices=speakers, value=speakers[0], label="Speaker"
                )'''
                speaker = gr.Slider(
                    minimum=0, maximum=speakers-1, value=0, step=1, label="Speaker"
                )
                sdp_ratio = gr.Slider(
                    minimum=0, maximum=1, value=0.2, step=0.1, label="SDP Ratio"
                )
                noise_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise Scale"
                )
                noise_scale_w = gr.Slider(
                    minimum=0.1, maximum=2, value=0.8, step=0.1, label="Noise Scale W"
                )
                length_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=1, step=0.1, label="Length Scale"
                )
                language = gr.Dropdown(
                    choices=languages, value=languages[0], label="Language"
                )
                btn = gr.Button("Generate!", variant="primary")
            with gr.Column():
                text_output = gr.Textbox(label="Message")
                audio_output = gr.Audio(label="Output Audio")

        btn.click(
            tts_fn,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
            ],
            outputs=[text_output, audio_output],
        )

    webbrowser.open("http://127.0.0.1:7860")
    app.launch(share=True)

运行后实现

 

过程中遇到的问题及解决(PS) 

[PS1]训练时出现cuda out of m

解决办法:

  • 修改batch size
  • 修改num_workers=0

[PS2]训练后出现线程相关错误

Traceback (most recent call last):
  File "/opt/miniconda3/envs/vits/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/workspace/tts/MB-iSTFT-VITS2/train.py", line 240, in run
    train_and_evaluate(rank, epoch, hps, [net_g, net_d, net_dur_disc], [optim_g, optim_d, optim_dur_disc],
  File "/workspace/tts/MB-iSTFT-VITS2/train.py", line 358, in train_and_evaluate
    scaler.scale(loss_gen_all).backward()
  File "/opt/miniconda3/envs/vits/lib/python3.8/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/opt/miniconda3/envs/vits/lib/python3.8/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Detected mismatch between collectives on ranks. Rank 1 is running collective: CollectiveFingerPrint(OpType=ALLREDUCE, TensorShape=[5248002], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 0 is running collective: CollectiveFingerPrint(OpType=ALLREDUCE).

原因分析:

在jupyter lab时调用gpu后出现的错误。

解决方案:

重新再次训练后就解决了,再次训练时,会加载上次训练的权重文件。

[PS3]API设置时出现,RuntimeError: Invalid device, must be cuda device

Traceback (most recent call last):
  File "aoi.py", line 55, in <module>
    model.build_wav(0, "안녕하세요", "./test.wav")
  File "aoi.py", line 50, in build_wav
    audio = self.net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()
  File "/workspace/tts/MB-iSTFT-VITS-multilingual/models.py", line 718, in infer
    o, o_mb = self.dec((z * y_mask)[:,:,:max_len], g=g)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/workspace/tts/MB-iSTFT-VITS-multilingual/models.py", line 344, in forward
    pqmf = PQMF(x.device)
  File "/workspace/tts/MB-iSTFT-VITS-multilingual/pqmf.py", line 78, in __init__
    analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1).cuda(device)
RuntimeError: Invalid device, must be cuda device

原因分析

1.在不支持cuda(GPU)的机器上,把模型或者数据放到GPU中。

2.因为在训练别的程序,大概率是把卡所有的显存都用上了,所以导致显存不足。

解决办法,停掉正在训练的程序,改小batch size,减少显存占用量。

[PS4]RuntimeError: Error(s) in loading state_dict for SynthesizerTrn:
        size mismatch for enc_q.pre.weight: copying a param with shape torch.Size([192, 80, 1]) from checkpoint, the shape in current model is torch.Size([192, 513, 1]).

解决办法

net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    ).to(device)

改为

    net_g = SynthesizerTrn(
    len(symbols),
    posterior_channels,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers, #- >0 for multi speaker
    **hps.model).to(device)

[PS5]TypeError: load_checkpoint() got an unexpected keyword argument 'skip_optimizer'

解决办法

将原本的

_ = utils.load_checkpoint(args.model, net_g, None, skip_optimizer=True)

改为

    _ = utils.load_checkpoint(args.model, net_g, None)

[PS6]api设置后运行是空白,编写gradio程序时候,发现任何代码运行起来都是一直显示Loading也不报错。

解决方案:

pip install gradio==3.12.0

pip install gradio==3.23.0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1126776.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

冒泡排序:了解原理与实现

目录 原理 实现 性能分析 结论 冒泡排序&#xff08;Bubble Sort&#xff09;是一种简单但效率较低的排序算法。它重复地比较相邻的元素并交换位置&#xff0c;直到整个序列有序为止。虽然冒泡排序的时间复杂度较高&#xff0c;但在小规模数据集上仍然具有一定的实际应用价…

Unity 通过jar包形式接入讯飞星火SDK

最近工作上遇到了要接入gpt相关内容的需求&#xff0c;简单实现了一个安卓端接入讯飞星火的UnitySDK。 或者也可以接入WebSocket接口的。本文只讲安卓实现 我使用的Unity版本为2021.3.27f1c2 Android版本为4.2.2 1.下载SDK 登陆讯飞开放平台下载如图所示SDK 2.新建安卓工程…

【Tensorflow 2.12 简单智能商城商品推荐系统搭建】

Tensorflow 2.12 简单智能商城商品推荐系统搭建 前言架构数据召回排序部署调用结尾 前言 基于 Tensorflow 2.12 搭建一个简单的智能商城商品推荐系统demo~ 主要包含6个部分&#xff0c;首先是简单介绍系统架构&#xff0c;接着是训练数据收集、处理&#xff0c;然后是召回模型、…

一个小的图文编辑软件 -- 采用winform开发

本人用winform开发了一款图文编辑软件&#xff0c;实现了图片、文字、图形混合排版; 可以对图元调整大小、设置角度、添加剪切区间等操作。本人以前也写过一款类似的软件《WinForm版图像编辑小程序》&#xff1b; 最近几年&#xff0c;本人一直从事图形处理方面的开发&#xff…

雷达开发的基本概念fft,cfar,以及Clutter, CFAR,AoA

CFAR Constant False-Alarm Rate的缩写。在雷达信号检测中&#xff0c;当外界干扰强度变化时&#xff0c;雷达能自动调整其灵敏度&#xff0c;使雷达的虚警概率保持不变。具有这种特性的接收机称为恒虚警接收机。雷达信号的检测总是在干扰背景下进行的&#xff0c;这些干扰包括…

SAP PO/PI 设置字段或静态参数到URL

文章目录 需求一、字段内容设置到URL中二、使用静态值三、测试总结 需求 通过PO/PI访问第三方接口并把字段或静态参数设置在URL中 一、字段内容设置到URL中 首先我们在MassageMapping中需要把字段内容发送到DynamicConfiguration中去&#xff0c;利用UDF UDF代码 这里面需要…

编译工具链 之一 基本概念、组成部分、编译过程、命名规则

编译工具链将程序源代码翻译成可以在计算机上运行的可执行程序。编译过程是由一系列的步骤组成的&#xff0c;每一个步骤都有一个对应的工具。这些工具紧密地工作在一起&#xff0c;前一个工具的输出是后一个工具的输入&#xff0c;像一根链条一样&#xff0c;我们称这一系列工…

【汇编】第一个汇编程序(学习笔记)

一、程序从编写到执行的过程 1、编写 Notepad / UltraEdit 汇编语言 2、编译、连接 MASM.EXE&#xff1a;编译产生目标文件 LINK.EXE&#xff1a;连接&#xff0c;产生可执行文件 连接作用&#xff1a;源程序分为多个子程序编译后&#xff0c;连接在一起。或程序调用其他…

【JavaEE】网络编程---UDP数据报套接字编程

一、UDP数据报套接字编程 1.1 DatagramSocket API DatagramSocket 是UDP Socket&#xff0c;用于发送和接收UDP数据报。 DatagramSocket 构造方法&#xff1a; DatagramSocket 方法&#xff1a; 1.2 DatagramPacket API DatagramPacket是UDP Socket发送和接收的数据报。…

SQL查询优化---单表使用索引及常见索引失效优化

如何避免索引失效 1、全值匹配 系统中经常出现的sql语句如下&#xff1a; EXPLAIN SELECT SQL_NO_CACHE * FROM emp WHERE emp.age30 EXPLAIN SELECT SQL_NO_CACHE * FROM emp WHERE emp.age30 and deptid4EXPLAIN SELECT SQL_NO_CACHE * FROM emp WHERE emp.age30 and dept…

美团真题解析

文章目录 &#x1f31f; 美团真题解析&#x1f34a; 美团面试真题-美团招聘简介&#x1f34a; 美团面试真题-介绍一下MyBatis的缓存机制&#x1f389; 一级缓存&#x1f389; 二级缓存 &#x1f34a; 美团面试真题-谈谈jvm的内存模型&#x1f34a; 美团面试真题-谈谈你知道的垃…

手写 Promise(1)核心功能的实现

一&#xff1a;什么是 Promise Promise 是异步编程的一种解决方案&#xff0c;其实是一个构造函数&#xff0c;自己身上有all、reject、resolve这几个方法&#xff0c;原型上有then、catch等方法。 Promise对象有以下两个特点。 &#xff08;1&#xff09;对象的状态不受…

[SQL开发笔记]WHERE子句 : 用于提取满足指定条件的记录

SELECT DISTINCT语句用户返回列表的唯一值&#xff1a;这是一个很特定的条件&#xff0c;假设我需要考虑很多中限制条件进行查询呢&#xff1f;这时我们就可以使用WHERE子句进行条件的限定 一、功能描述&#xff1a; WHERE子句用于提取满足指定条件的记录&#xff1b; 二、WH…

nginx快速部署一个网站服务 + 多域名 + 多端口

&#x1f468;‍&#x1f393;博主简介 &#x1f3c5;云计算领域优质创作者   &#x1f3c5;华为云开发者社区专家博主   &#x1f3c5;阿里云开发者社区专家博主 &#x1f48a;交流社区&#xff1a;运维交流社区 欢迎大家的加入&#xff01; &#x1f40b; 希望大家多多支…

C#,数值计算——分类与推理Phylo_slc的计算方法与源程序

1 文本格式 using System; using System.Collections.Generic; namespace Legalsoft.Truffer { public class Phylo_slc : Phylagglom { public override void premin(double[,] d, int[] nextp) { } public override double dminfn(double[…

OS 处理机调度

目录 处理机调度的层次 高级调度 作业 作业控制块 JCB 作业调度的主要任务 低级调度 中级调度 进程调度 进程调度时机 进程调度任务 进程调度机制 排队器 分派器 上下文切换器 进程调度方式 非抢占调度方式 抢占调度方式 调度算法 处理机调度算法的目标 处理…

UE5 虚幻引擎中UI、HUD和UMG的区别与联系

目录 0 引言1 UI 用户界面2 HUD 用户界面3 UMG4 总结 &#x1f64b;‍♂️ 作者&#xff1a;海码007&#x1f4dc; 专栏&#xff1a;UE虚幻引擎专栏&#x1f4a5; 标题&#xff1a;UE5 虚幻引擎中UI、HUD和UMG的区别与联系❣️ 寄语&#xff1a;加油&#xff0c;一次专注一件事…

Java编写图片转base64

图片转成base64 url &#xff0c; 在我们的工作中也会经常用到&#xff0c;比如说导出 word,pdf 等功能&#xff0c;今天我们尝试写一下。 File file new File("");byte[] data null;InputStream in null;ByteArrayOutputStream out null;try{URL url new URL(&…

C++初阶 入门(2)

目录 一、缺省函数 1.1什么是缺省函数 1.2为什么要有缺省函数 1.3使用缺省函数 1.4测试代码 二、函数重载 2.1什么是函数重载 2.2为什么要有函数重载 2.3什么情况构成函数重载 2.4函数重载例子及代码 三、引用 3.1什么是引用 3.2如何引用 ​3.3常引用(可略过) 3…

【宝塔面板建站】本地连接云服务器的数据库 以阿里云服务器为例子(保姆级图文)

目录 实现效果实现过程1. 获取云服务的数据库root密码 2.尝试本地连接2.1 端口放行2.2 云服务器授权本地访问MySQL权限 实现代码总结 『宝塔面板建站』分享宝塔面板从安装到实战的宝塔面板本机免云服务器免域名搭建网站等内容。 欢迎关注 『宝塔面板建站』 系列&#xff0c;持续…