FastSpeech2中文语音合成就步解析:TTS数据训练实战篇

news2024/10/6 2:44:43
  1. 参考github网址:

GitHub - roedoejet/FastSpeech2: An implementation of Microsoft’s “FastSpeech 2: Fast and High-Quality End-to-End Text to Speech”

  1. 数据训练所用python 命令:

python3 train.py -p config/AISHELL3/preprocess.yaml -m config/AISHELL3/model.yaml -t config/AISHELL3/train.yaml

  1. 数据训练代码解析

3.1 代码架构overview:

通过 if __name__ == "__main__"运行整个py文件:

调用 “train.txt"和dataset.py加载数据,

调用utils文件夹下的model.py加载模型,声码器,

调用model文件夹下的loss.py中的FastSpeech2Loss class 设置损失函数,

用前面加载的模型和损失函数开始训练模型,导出结果并记录日志。

3.2 按训练步骤分解代码:

Step 0 : 定义可控训练参数, 调动main函数

if __name__ == "__main__":

    #Define Args
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=0)
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    args = parser.parse_args() #args为可控训练参数

    # Read Config
    preprocess_config = yaml.load(
        open(args.preprocess_config, "r"), Loader=yaml.FullLoader
    )
    model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
    train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)

    #Run _main_ function
    main(args, configs)

Step 1 : 启动main函数,加载可控训练参数

def main(args, configs): 
    print("Prepare training ...")

    #加载可控训练参数
    preprocess_config, model_config, train_config = configs

Step 2 : 从train.txt加载数据,并经由dataset.py和torch里的Dataloader处理

def main(args, configs):

    # Get dataset
    dataset = Dataset(
        "train.txt", preprocess_config, train_config, sort=True, drop_last=True
    ) #从 train.txt 中获取dataset
    batch_size = train_config["optimizer"]["batch_size"]
    group_size = 4  # Set this larger than 1 to enable sorting in Dataset,初始值为4

    assert batch_size * group_size < len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

Step 3 : 定义模型,声码器,损失函数

def main(args, configs):

    # Prepare model
    model, optimizer = get_model(args, configs, device, train=True) #设置优化器

    # 将模型并行训练并移入计算设备中
    model = nn.DataParallel(model) # Model Has Been Defined

    # 计算模型参数量
    num_param = get_param_num(model) # Number of TTS Parameters: num_param
    print("Number of FastSpeech2 Parameters:", num_param)

    # 设置损失函数
    Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)

    # 加载声码器
    vocoder = get_vocoder(model_config, device)

Step 4 : 加载日志,在"./output/log/AISHELL3"目录建立train, val两个文件夹来记录日志

def main(args, configs):

    # Init logger
    for p in train_config["path"].values():
        os.makedirs(p, exist_ok=True)
    train_log_path = os.path.join(train_config["path"]["log_path"], "train")
    val_log_path = os.path.join(train_config["path"]["log_path"], "val")
    os.makedirs(train_log_path, exist_ok=True)
    os.makedirs(val_log_path, exist_ok=True)
    train_logger = SummaryWriter(train_log_path)
    val_logger = SummaryWriter(val_log_path)

Step 5 : 准备训练,加载可控训练参数

def main(args, configs):

    # Training
    step = args.restore_step + 1
    epoch = 1
    grad_acc_step = train_config["optimizer"]["grad_acc_step"]
    grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
    total_step = train_config["step"]["total_step"]
    log_step = train_config["step"]["log_step"]
    save_step = train_config["step"]["save_step"]
    synth_step = train_config["step"]["synth_step"]
    val_step = train_config["step"]["val_step"]

    outer_bar = tqdm(total=total_step, desc="Training", position=0)
    outer_bar.n = args.restore_step
    outer_bar.update()

Step 6 : 准备训练,加载进度条,调动utils文件夹下tools.py中的to_device function来提取数据

    while True:
        inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
        for batchs in loader:
            for batch in batchs:
                batch = to_device(batch, device)

Step 7 :开始训练,前向传播,计算损失,反向传播,梯度剪枝,更新模型权重参数

    #Load Data
            for batch in batchs:
                batch = to_device(batch, device)
                
                # Forward
                output = model(*(batch[2:]))

                # Cal Loss
                losses = Loss(batch, output)
                total_loss = losses[0]

                # Backward
                total_loss = total_loss / grad_acc_step
                total_loss.backward()
                if step % grad_acc_step == 0:
                    # Clipping gradients to avoid gradient explosion
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)

                    # Update weights
                    optimizer.step_and_update_lr()
                    optimizer.zero_grad()

Step 8 : 当训练步数到达预先设定的log_step时,调动utils文件夹下tool.py里的log function,记录loss和step

                if step % log_step == 0:
                    losses = [l.item() for l in losses]
                    message1 = "Step {}/{}, ".format(step, total_step)
                    message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
                        *losses
                    )

                    with open(os.path.join(train_log_path, "log.txt"), "a") as f:
                        f.write(message1 + message2 + "\n")

                    outer_bar.write(message1 + message2)

                    log(train_logger, step, losses=losses)

Step 9 : 当训练步数到达预先设定的synth_step时,调动utils文件夹下tool.py里的log function 和 synth_one_sample function(具体用来干什么没看懂)

                if step % synth_step == 0:
                    fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
                        batch,
                        output,
                        vocoder,
                        model_config,
                        preprocess_config,
                    )
                    log(
                        train_logger,
                        fig=fig,
                        tag="Training/step_{}_{}".format(step, tag),
                    )
                    sampling_rate = preprocess_config["preprocessing"]["audio"][
                        "sampling_rate"
                    ]
                    log(
                        train_logger,
                        audio=wav_reconstruction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_reconstructed".format(step, tag),
                    )
                    log(
                        train_logger,
                        audio=wav_prediction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_synthesized".format(step, tag),
                    )

Step 10 : 当训练步数到达预先设定的val_step时,调动evaluate.py里的evaluate function来进行evaluation,并记录在log/AISHELL3/val/log.txt

                if step % val_step == 0:
                    model.eval()
                    message = evaluate(model, step, configs, val_logger, vocoder)
                    with open(os.path.join(val_log_path, "log.txt"), "a") as f:
                        f.write(message + "\n")
                    outer_bar.write(message)

                    model.train()

Step 11 : 当训练步数到达预先设定的save_step时,保存训练模型

                if step % save_step == 0:
                    torch.save(
                        {
                            "model": model.module.state_dict(),
                            "optimizer": optimizer._optimizer.state_dict(),
                        },
                        os.path.join(
                            train_config["path"]["ckpt_path"],
                            "{}.pth.tar".format(step),
                        ),
                    )

Step 12 : 当训练步数到达预先设定的total_step时,退出训练

                if step == total_step:
                    quit()
                step += 1
                outer_bar.update(1)

            inner_bar.update(1)
        epoch += 1

  1. 数据训练代码的输出

在train_log_path和val_log_path输出日志

在ckpt_path输出训练过程中按照save_step存储的模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1903924.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

增强安全防护,解读智慧校园系统的登录日志功能

在构建智慧校园系统时&#xff0c;登录日志功能扮演着不可或缺的角色&#xff0c;它不仅是系统安全的守护者&#xff0c;也是提升管理效率和确保合规性的有力工具。这一机制详细记录每次登录尝试的方方面面&#xff0c;涵盖了时间戳、用户身份、登录来源的IP地址乃至使用的设备…

python脚本“文档”撰写——“诱骗”ai撰写“火火的动态”python“自动”脚本文档

“火火的动态”python“自动”脚本文档&#xff0c;又从ai学习搭子那儿“套”来&#xff0c;可谓良心质量&#x1f44d;&#x1f44d;。 (笔记模板由python脚本于2024年07月07日 15:15:33创建&#xff0c;本篇笔记适合喜欢钻研python和页面源码的coder翻阅) 【学习的细节是欢悦…

Python遥感开发之批量对TIF数据合并

Python遥感开发之批量对TIF数据合并 1 原始数据展示2 按照年份合成春夏秋冬季节数据3 完整代码实现 前言&#xff1a;通常对遥感数据按照月和年或者季节进行分析&#xff0c;我们需要对我们下载的8天或者16天数据按照需求进行合并&#xff0c;对数据的合并一般可以采取均值法、…

移动校园(5):课程表数据获取及展示

首先写下静态页面&#xff0c;起初打算做成一周的课表&#xff0c;由于是以小程序的形式展现&#xff0c;所以显示一周的话会很拥挤&#xff0c;所以放弃下面的方案&#xff0c;改作一次显示一天 改后结果如下&#xff0c;后期还会进行外观优化 真正困难的部分是数据获取 大家大…

高德地图 key 和安全密钥使用

参考高德地图&#xff1a;JS API 安全密钥使用 高德地图 key 和安全密钥使用 一、通过明文方式设置参数查看如下成功后返回的信息 二、通过代理服务器转发实验&#xff1a;通过本地地址转发返回错的错误信息&#xff0c;如下通过正确的项目的的服务地址&#xff0c;返回正常参数…

Flink,spark对比

三&#xff1a;az 如何调度Spark、Flink&#xff0c;MR 任务 首先&#xff0c;使用java编写一个spark任务&#xff0c;定义一个类&#xff0c;它有main方法&#xff0c;里面写好逻辑&#xff0c;sparkConf 和JavaSparkContext 获取上下文&#xff0c;然后打成一个jar包&#xf…

深度学习之网络构建

目标 选择合适的神经网络 卷积神经网络&#xff08;CNN&#xff09;&#xff1a;我们处理图片、视频一般选择CNN 循环神经网络&#xff08;RNN&#xff09;&#xff1a;我们处理时序数据一般选择RNN 超参数的设置 为什么训练的模型的错误率居高不下 如何调测出最优的超参数 …

Node.js介绍 , 安装与使用

1.Node.js 1 什么是Node.js 官网&#xff1a;https://nodejs.org/zh-cn/ 中文学习网&#xff1a;http://nodejs.cn/learn1.Node.js 是一个基于 Chrome V8 引擎的 JavaScript 运行环境。Node.js 使用了一个事件驱动、非阻塞式 I/O 的模型,使其轻量又高效。 2.前端的底层 html…

C++、QT企业管理系统

目录 一、项目介绍 二、项目展示 三、源码获取 一、项目介绍 人事端&#xff1a; 1、【产品中心】产品案列、新闻动态的发布&#xff1b; 2、【员工管理】新增、修改、删除、搜索功能&#xff1b;合同以图片的方式上传 3、【考勤总览】根据日期显示所有员工上班、下班时间…

下载nvm 管理多个node版本并切换

nvm管理多个node版本并切换 安装nvm时不能安装任何node版本&#xff08;先删除node再安装nvm&#xff09;&#xff0c;保证系统无任何node.js残留 1. 卸载node 控制面板中卸载nodejs 删除以下文件夹&#xff1a; C:\Program Files (x86)\Nodejs C:\Program Files\Nodejs C…

聚类分析方法(一)

目录 一、聚类分析原理&#xff08;一&#xff09;聚类分析概述&#xff08;二&#xff09;聚类的数学定义&#xff08;三&#xff09;簇的常见类型&#xff08;四&#xff09;聚类框架及性能要求&#xff08;五&#xff09;簇的距离 二、划分聚类算法&#xff08;一&#xff0…

【matlab】状态空间模型与传递函数模型的建立与转换

目录 SISO系统 MIMO系统 状态空间模型 状态空间模型到传递函数模型的转换 传递函数模型到状态空间模型的转换 (1) 转换函数ss() (2) 规范形转换函数canon() (3) 常微分方程(传递函数)转换为状态空间模型函数dif2ss() 状态空间模型的变换 特征值、特征向量与广义特征向量的计算…

进程控制-wait和waitpid进程回收

wait 阻塞函数 函数作用&#xff1a; 1. 阻塞并等待子进程退出 2. 回收子进程残留资源 3. 获取子进程结束状态&#xff08;退出原因&#xff09; pid_t wait(int *wstatus); 返回值&#xff1a; ‐1 : 回收失败&#xff0c;已经没有子进程了 >0 : 回收子进程对应的…

线上问题---反思与回顾

线上问题一&#xff1a;麦哲伦2.0 人群配置不生效 发现背景&#xff1a;产品发现三层模型部分计划个别测试计划圈选人群特征与数仓统计数据的人群不一致&#xff0c;向值班人员反馈 根因定位&#xff1a; &#xff08;1&#xff09;用户配置三层模型计划时&#xff0c;配置单…

文心一言 VS 讯飞星火 VS chatgpt (297)-- 算法导论22.1 1题

一、给定有向图的邻接链表&#xff0c;需要多长时间才能计算出每个结点的出度(发出的边的条数)&#xff1f;多长时间才能计算出每个结点的入度(进入的边的条数)&#xff1f;如果要写代码&#xff0c;请用go语言。 文心一言&#xff1a; 计算出度 对于有向图的邻接链表表示&a…

rufus-4.5 制作 Clonezilla(再生龙)启动盘报syslinux-6.04下载错误

1、官网下载rufus 官网下载rufus-4.5&#xff0c;下载地址&#xff1a;https://rufus.ie/downloads/ 2、下载再生龙&#xff08;Clonezilla&#xff09; 下载最新版本&#xff1a; Clonezilla live 版本: 3.1.2-22&#xff1a;https://sourceforge.net/projects/clonezill…

在 Docker 容器中运行 Vite 开发环境,有这两个问题要注意

容器化开发给我们带来了很多便捷&#xff0c;但是在开发环境下也有一些问题要注意&#xff0c;如果不解决这些问题&#xff0c;你的开发体验不会很好。 容器启动正常&#xff0c;却无法访问 我们用 Docker 启动一个 Vite Vue3 项目的开发环境后&#xff0c;发现端口日志一切…

Springboot使用WebSocket发送消息

1. 创建springboot项目&#xff0c;引入spring-boot-starter-websocket依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId></dependency>完整项目依赖 <?xml ver…

【游戏客户端】大话版本slg玩法正式上线~~

【游戏客户端】制作率土之滨Like玩法 大家好&#xff0c;我是Lampard家杰~~ 好久好久没有更新博客了&#xff0c;有不少大佬都在后台私信我催更&#xff0c;但是很悲伤这段时间都忙的不行QAQ 那在忙什么呢&#xff1f;就是在制作一个SLG类的玩法【帮派纷争】啦 &#xff0c;布…

RNN、LSTM与GRU循环神经网络的深度探索与实战

循环神经网络RNN、LSTM、GRU 一、引言1.1 序列数据的迷宫探索者&#xff1a;循环神经网络&#xff08;RNN&#xff09;概览1.2 深度探索的阶梯&#xff1a;LSTM与GRU的崛起1.3 撰写本博客的目的与意义 二、循环神经网络&#xff08;RNN&#xff09;基础2.1 定义与原理2.1.1 RNN…