Keras深度学习实战(41)——语音识别

news2024/9/25 23:25:10

Keras深度学习实战(41)——语音识别

    • 0.前言
    • 1. 模型与数据集分析
      • 1.1 数据集分析
      • 1.2 模型分析
    • 2. 语音识别模型
      • 2.1 数据加载与预处理
      • 2.2 模型构建与训练
    • 小结
    • 系列链接

0.前言

语音识别(Automatic Speech Recognition, ASR,或称语音转录文本)使声音变得"可读",让计算机能够"听懂"人类的语言并做出相应的操作,是人工智能实现人机交互的关键技术之一。在《图像字幕生成》一节中,我们已经学习了如何将手写文本图像转录为文本,在本节中,我们将利用类似的端到端模型实现将语音转录文本模型,将语音文件转录为文字。

1. 模型与数据集分析

1.1 数据集分析

为了构建语音转录文本模型,我们所使用的数据集中包含了大约 29000 条语音文件及其对应的文本,相关数据集可以在 openslr 链接中下载,下载完成后解压缩,可以看到文件夹 train-clean-100 中的有若干子目录,每个目录下都有数条音频文件和对应的文本数据。

数据示例

1.2 模型分析

在我们继续实现语音转文字之前,首先简单介绍模型采用的转录语音策略流程:

  • 下载包含音频文件及其相对应的转录文本(真实标签)的数据集
  • 在读取音频文件时指定采样率:
    • 如果采样率为 16000,则每秒可以提取 16000 个数据样本点
  • 提取音频序列的快速傅立叶变换 (Fast Fourier Transformation, FFT):
    • 使用 FFT 可以确保我们仅提取信号最重要的特征
    • 默认情况下,FFT 获取 n / 2 个数据样本点,其中 n 是整个音频记录中的数据样本点数
  • 采样音频的 FFT 特征,一次提取 320 个数据样本点;也就是说,我们一次提取 20 毫秒 (320/16000 = 1/50秒) 的音频数据
  • 此外,我们将每隔 10 毫秒的时间间隔采样 20 毫秒的音频数据
  • 本节中,为了降低模型的复杂度,作为演示目的我们仅使用音频持续时间小于 20 秒的音频记录
  • 将每次采样的 20 毫秒音频数据存储到一个数组中:
    • 每隔 10 毫秒采样 20 毫秒的数据
    • 因此,对于一秒钟的音频剪辑,有 100 x 320 个数据样本点,对于 10 秒钟的音频剪辑,有 1000 x 320 = 320000 个数据样本点
  • 初始化一个包含 160000 个数据样本点的空数组,并用 FFT 值填充这些值——我们已经知道 FFT 值是原始数据样本点数的一半
  • 对于每个 1000 x 320 数据样本点的数组,存储相应的转录文本
  • 为每个字符分配一个索引,然后将输出转换为索引列表
  • 此外,还需要存储输入长度作为预定义的时间戳数以及转录文本长度作为输出中出现的实际字符数
  • 基于实际输出、预测输出、时间戳数(输入长度)和转录文本长度(输出中的字符数)定义 CTC 损失函数
  • 定义模型,该模型综合使用 conv1DGRU,同时在模型中使用批归一化对数据进行归一化,以避免出现梯度消失问题
  • 每次使用 mini batch 训练该模型,首次随机采样一个 batch 数据,将其输入到构建的模型中,以最大程度地减少 CTC 损失
  • 最后,使用 ctc_decode 方法对测试数据样本点上的模型预测进行解码

2. 语音识别模型

接下来,我们实现在上一小节中讨论的语音识别模型。

2.1 数据加载与预处理

(1) 首先,导入相关的软件包,并遍历数据集中所有音频文件及其对应的转录文本,然后将它们存储到列表中:

import librosa
import numpy as np
import os
import re
import random
from matplotlib import pyplot as plt

org_path = 'train-clean-100/LibriSpeech/train-clean-100/'
count = 0
inp = []
k = 0
audio_name = []
audio_trans = []
for dir1 in os.listdir(org_path):
    dir2_path = org_path + dir1 + '/'
    for dir2 in os.listdir(dir2_path):
        dir3_path = dir2_path + dir2 + '/'
        for audio in os.listdir(dir3_path):
            if audio.endswith('.txt'):
                k += 1
                file_path = dir3_path + audio
                with open(file_path) as f:
                    lines = f.readlines()
                    for line in lines:
                        audio_name.append(dir3_path + line.split()[0] + '.flac')
                        words2 = line.split()[1:]
                        words3 = ' '.join(words2)
                        audio_trans.append(words3)

(2) 将转录文本长度存储到列表中,以便我们获取最大转录文本长度:

len_audio_name = []
for i in range(len(audio_name)):
    tt = re.sub(' ','-',audio_trans[i])
    len_audio_name.append(len(tt))

(3) 为了能够在单个 GPU 上训练模型,我们将仅使用转录文本长度小于 100 个字符的音频文件进行训练(如果想要获取性能更加优异的模型,在 GPU 内存允许的情况下可以使用长度更高的音频文件,以提高训练数据集大小):

final_audio_name = []
final_audio_trans = []
for i in range(len(audio_name)):
    if(len_audio_name[i]<100):
        final_audio_name.append(audio_name[i])
        final_audio_trans.append(audio_trans[i])

在以上的代码中,我们仅存储转录文本长度少于 100 个字符的音频记录的音频名称和相应的音频转录文本。

(4) 将输入存储为 2D 数组,并仅存储持续时间少于 10 秒的音频文件的相应输出:

inp = []
inp2 = []
op = []
op2 = []

for j in range(len(final_audio_name)):
    t = librosa.core.load(final_audio_name[j],sr=16000, mono= True) 
    if(t[0].shape[0]<160000):
        t = np.array(t[0])
        t2 = np.zeros(160000)
        t2[:len(t)] = t
        inp = []
        for i in range(t2.shape[0]//160-1):
            k = t2[(i*160):((i*160)+320)]
            fft = np.fft.rfft(k)
            inp.append(np.abs(fft))
        inp2.append(inp)
        op2.append(final_audio_trans[j])

(5) 为数据中的每个不重复字符创建一个索引:

import itertools
list2d = op2
charList = list(set(list(itertools.chain(*list2d))))

(6) 创建用于存储输入和转录文本长度的 Numpy 数组,我们创建的输入长度为 243,因此之后创建模型的输出也将具有 243 个时间戳:

num_audio = len(op2)
y2 = []
input_lengths = np.ones((num_audio,1))*243
label_lengths = np.zeros((num_audio,1))
for i in range(num_audio):
    val = list(map(lambda x: charList.index(x), op2[i]))
    while len(val)<243:
        val.append(len(charList)+1)
    y2.append(val)
    label_lengths[i] = len(op2[i])
    input_lengths[i] = 243

2.2 模型构建与训练

(1) 定义CTC损失函数:

import keras.backend as K
def ctc_loss(args):
    y_pred, labels, input_length, label_length = args
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

(2) 定义语音识别模型:

from keras.layers import Input, BatchNormalization, Conv1D, GRU, concatenate
from keras.layers import TimeDistributed, Dense, Activation, Lambda
from keras.models import Model

input_data = Input(name='the_input', shape = (999,161), dtype='float32')
inp = BatchNormalization(name="inp")(input_data)
conv= Conv1D(filters=220, kernel_size = 11,strides = 2, padding='valid',activation='relu')(inp)
conv = BatchNormalization(name="Normal0")(conv)
conv1= Conv1D(filters=220, kernel_size = 11,strides = 2, padding='valid',activation='relu')(conv)
conv1 = BatchNormalization(name="Normal1")(conv1)
gru_3 = GRU(512, return_sequences = True, name = 'gru_3')(conv1)
gru_4 = GRU(512, return_sequences = True, go_backwards = True, name = 'gru_4')(conv1)
merged = concatenate([gru_3, gru_4])
normalized = BatchNormalization(name="Normal")(merged)
dense = TimeDistributed(Dense(30))(normalized)
y_pred = TimeDistributed(Activation('softmax', name='softmax'))(dense)
Model(inputs = input_data, outputs = y_pred).summary()

(3) 定义优化器以及 CTC 损失函数的输入和输出参数:

from keras.optimizers import Adam
optimizer = Adam(lr = 0.001)
labels = Input(name = 'the_labels', shape=[243], dtype='float32')
input_length = Input(name='input_length', shape=[1],dtype='int64')
label_length = Input(name='label_length',shape=[1],dtype='int64')
output = Lambda(ctc_loss, output_shape=(1,),name='ctc')([y_pred, labels, input_length, label_length])

(4) 构建并编译模型:

model = Model(inputs = [input_data, labels, input_length, label_length], outputs= output)
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer = optimizer, metrics = ['acc'])

该模型的简要架构信息输出如下:

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
the_input (InputLayer)          [(None, 999, 161)]   0                                            
__________________________________________________________________________________________________
inp (BatchNormalization)        (None, 999, 161)     644         the_input[0][0]                  
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, 495, 220)     389840      inp[0][0]                        
__________________________________________________________________________________________________
Normal0 (BatchNormalization)    (None, 495, 220)     880         conv1d[0][0]                     
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 243, 220)     532620      Normal0[0][0]                    
__________________________________________________________________________________________________
Normal1 (BatchNormalization)    (None, 243, 220)     880         conv1d_1[0][0]                   
__________________________________________________________________________________________________
gru_3 (GRU)                     (None, 243, 512)     1127424     Normal1[0][0]                    
__________________________________________________________________________________________________
gru_4 (GRU)                     (None, 243, 512)     1127424     Normal1[0][0]                    
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 243, 1024)    0           gru_3[0][0]                      
                                                                 gru_4[0][0]                      
__________________________________________________________________________________________________
Normal (BatchNormalization)     (None, 243, 1024)    4096        concatenate[0][0]                
__________________________________________________________________________________________________
time_distributed (TimeDistribut (None, 243, 30)      30750       Normal[0][0]                     
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib (None, 243, 30)      0           time_distributed[0][0]           
==================================================================================================
Total params: 3,214,558
Trainable params: 3,211,308
Non-trainable params: 3,250
__________________________________________________________________________________________________

(5) 每次从输入数据中采样一个 mini batch 的数据进行训练,按照以上步骤循环训练,提取了 20000mini batch 的数据,对输入数据进行归一化,并拟合模型:

x = np.asarray(inp2)
y2 = np.asarray(y2)
l_train = []
for i in range(20000):
    samp=random.sample(range(len(inp2)-25),32)
    batch_input=[inp2[i] for i in samp]
    batch_input = np.array(batch_input)
    batch_input = batch_input / np.max(inp2)
    batch_output = [y2[i] for i in samp]
    batch_output = np.array(batch_output)
    input_lengths2 = [input_lengths[i] for i in samp]
    label_lengths2 = [label_lengths[i] for i in samp]
    input_lengths2 = np.array(input_lengths2)
    label_lengths2 = np.array(label_lengths2)
    inputs = {'the_input': batch_input,
            'the_labels': batch_output,
            'input_length': input_lengths2,
            'label_length': label_lengths2}
    outputs = {'ctc': np.zeros([32])} 
    history = model.fit(inputs, outputs, batch_size = 32, epochs=2, verbose =1)
    if i % 100:
        l_train.append(history.history['loss'][0])

此外,由于该数据集和模型组合的 CTC 损失降低较为缓慢,因此需要大量的时间进行训练。

(6) 根据训练完成的模型,预测测试音频。指定模型 model2,输入测试数组并在243个时间戳中的每个时间步中提取模型预测:

model2 = Model(inputs = input_data, outputs = y_pred)

k=-12
pred= model2.predict(np.array(inp2[k]).reshape(1,999,161)/np.max(inp2))

在以上代码中,我们使用输入数组的倒数第 12 个数据样本,并利用训练后的模型预测该数据样本。我们将输入数据传递给训练后的模型,并以与模型训练过程相同的方式对输入数据进行预处理。

(7) 定义函数用于解码模型对测试数据样本点的预测结果,我们使用 ctc_decode 方法对预测进行解码。最后,通过调用定义的函数来解码预测,打印预测结果:

def decoder(pred):
    pred_ints = (K.eval(K.ctc_decode(pred,[243])[0][0])).flatten().tolist()
    out = ""
    for i in range(len(pred_ints)):
        if pred_ints[i]<28:
            out = out+charList[pred_ints[i]]
    print(out)

decoder(pred)

预测的输出如下:

AI YOUN MARN KON MAN SOME FHRTATI NER AUHER

尽管前面的输出看起来较为混乱,但在声音上与实际音频确有类似之处。我们可以使用以下方法进一步提高语音转录的准确率:

  • 使用更多的数据样本进行训练
  • 合并自然语言处理模型以对输出执行模糊匹配,以便校正预测的输出

小结

语音识别 (Automatic Speech Recognition, ASR) 是人工智能领域里一个重要的研究方向,是人机交互的重要方式。对于如何实现语音识别,将语音序列转化为文本序列一直以来都是研究人员关注的重点领域,近年来神经网络技术在语音识别领域的应用快速发展,已经成为语音识别领域中主流的声学建模技术。在本节中,我们利用 Keras 实现了端到端的深度神经网络模型,达到将语音文件转录为文字的目的。

系列链接

Keras深度学习实战(1)——神经网络基础与模型训练过程详解
Keras深度学习实战(2)——使用Keras构建神经网络
Keras深度学习实战(3)——神经网络性能优化技术
Keras深度学习实战(4)——深度学习中常用激活函数和损失函数详解
Keras深度学习实战(5)——批归一化详解
Keras深度学习实战(6)——深度学习过拟合问题及解决方法
Keras深度学习实战(7)——卷积神经网络详解与实现
Keras深度学习实战(8)——使用数据增强提高神经网络性能
Keras深度学习实战(9)——卷积神经网络的局限性
Keras深度学习实战(10)——迁移学习详解
Keras深度学习实战(11)——可视化神经网络中间层输出
Keras深度学习实战(12)——面部特征点检测
Keras深度学习实战(13)——目标检测基础详解
Keras深度学习实战(14)——从零开始实现R-CNN目标检测
Keras深度学习实战(15)——从零开始实现YOLO目标检测
Keras深度学习实战(16)——自编码器详解
Keras深度学习实战(17)——使用U-Net架构进行图像分割
Keras深度学习实战(18)——语义分割详解
Keras深度学习实战(19)——使用对抗攻击生成可欺骗神经网络的图像
Keras深度学习实战(20)——DeepDream模型详解
Keras深度学习实战(21)——神经风格迁移详解
Keras深度学习实战(22)——生成对抗网络详解与实现
Keras深度学习实战(23)——DCGAN详解与实现
Keras深度学习实战(24)——从零开始构建单词向量
Keras深度学习实战(25)——使用skip-gram和CBOW模型构建单词向量
Keras深度学习实战(26)——文档向量详解
Keras深度学习实战(27)——循环神经详解与实现
Keras深度学习实战(28)——利用单词向量构建情感分析模型
Keras深度学习实战(29)——长短时记忆网络详解与实现
Keras深度学习实战(30)——使用文本生成模型进行文学创作
Keras深度学习实战(31)——构建电影推荐系统
Keras深度学习实战(32)——基于LSTM预测股价
Keras深度学习实战(33)——基于LSTM的序列预测模型
Keras深度学习实战(34)——构建聊天机器人
Keras深度学习实战(35)——构建机器翻译模型
Keras深度学习实战(36)——基于编码器-解码器的机器翻译模型
Keras深度学习实战(37)——手写文字识别
Keras深度学习实战(38)——图像字幕生成
Keras深度学习实战(39)——音乐音频分类
Keras深度学习实战(40)——音频生成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/97589.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

openssl加密base64编码

openssl OpenSSL 是一个安全套接字层密码库&#xff0c;囊括主要的密码算法、常用的密钥和证书封装管理功能及SSL协议&#xff0c;并提供丰富的应用程序供测试或其它目的使用。 首先&#xff0c;要安装 openssl: centos命令: sudo yum install openssl-devel ubuntu命令&#x…

WebService基于Baidu OCR和Map API的导航服务

哈尔滨工业大学国家示范性软件学院 《面向服务的软件系统》大作业 项目题目&#xff1a; 基于OCR和地图API的路牌定位与导航服务 项目组成员&#xff1a; 姓名 学号 李启明 120L021920 完成日期&#xff1a; 2022年 12 月 15 日 1.选题 1.1 作业…

NUS CS5477 assignment1

课程链接三维视觉 作业任务任务 课程任务就一个&#xff0c;实现一个Linear Sweep Algorithm&#xff0c;这个算法是用来检测两张图片之间的对应点。 因为SIFT检测如果把检测点的数量增大&#xff0c;可能会存在一些错误错误检测点&#xff0c;所有通常把SIFT检测的点的数量…

内网穿透:在家远程ssh访问学校内部网服务器

注册一个cpolar账号 cpolar官网注册即可&#xff08;邮箱即可&#xff09; cpolar支持http/https/tcp协议&#xff0c;不限制流量&#xff08;花生壳免费只能使用1G流量&#xff09;&#xff0c;也不需要公网ip&#xff0c;只要在服务器上安装客户端即可配置&#xff0c;免费&…

攻防世界-file_include

题目 访问路径获得源码 <?php highlight_file(__FILE__);include("./check.php");if(isset($_GET[filename])){$filename $_GET[filename];include($filename);} ?> 通过阅读php代码&#xff0c;我们明显的可以发现&#xff0c;这个一个文件包含的类型题…

Java项目:ssm校内超市管理系统

作者主页&#xff1a;源码空间站2022 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 项目介绍 本系统分为管理员与普通用户两种角色。采用后端SSM框架&#xff0c;前端BootStrap&#xff08;前后端不分离&#xff09;的系统架构模式&#x…

python中调用命令行执行外部程序

&#x1f31e;欢迎来到python的世界 &#x1f308;博客主页&#xff1a;卿云阁 &#x1f48c;欢迎关注&#x1f389;点赞&#x1f44d;收藏⭐️留言&#x1f4dd; &#x1f31f;本文由卿云阁原创&#xff01; &#x1f320;本阶段属于练气阶段&#xff0c;希望各位仙友顺利完成…

STM32的三种更新固件的方式

说明&#xff1a; stm32有三种更新固件的方式&#xff0c;分别为&#xff08;1&#xff09;DFU模式&#xff08; Development Firmware Upgrade 即“开发固件升级”&#xff09;&#xff1b;&#xff08;2&#xff09;SWD/JLINK 下载 &#xff08;3&#xff09;第三方bootload…

NoSQL数据库原理与应用综合项目——HBase篇

NoSQL数据库原理与应用综合项目——HBase篇 文章目录NoSQL数据库原理与应用综合项目——HBase篇0、 写在前面1、本地数据或HDFS数据导入到HBase2、Hbase数据库表操作2.1 Java API 连接HBase2.2 查询数据2.3 插入数据2.4 修改数据2.5 删除数据3、Windows远程连接HBase4、数据及源…

springboot常用组件集成

今天与大家分享spring-mybatis、reids集成&#xff0c;druid数据库连接池。如果有问题&#xff0c;望指教。 1. 创建项目 File -> New -> project ...Spring Initializr选择项目需要的第三方组件注&#xff1a;可以参考第二次课演示的操作步骤&#xff0c;有详细的拷图…

java药店网站药店系统药店源码刷脸支付源码

简介 首页&#xff0c;搜索商品&#xff0c;详情页&#xff0c;根据不同规格显示不同的商品价格&#xff0c;加入购物车&#xff0c;立即购买&#xff0c;评价列表展示&#xff0c;商品详情展示&#xff0c;商品评分&#xff0c;分类商品&#xff0c;标签查询&#xff0c;更多…

MapReduce 概述原理说明

文章目录MapReduce概述一、MapReduce定义二、MapReduce 优缺点1、MapReduce 优点(1)、MapReduce 易于编程(2)、良好的扩展性(3)、高容错性(4)、适合PB级以上的海量数据的离线处理2、MapReduce 缺点(1)、不擅长实时计算(2)、不擅长流式计算(3)、不擅长DAG(有向图)计算三、MapRed…

二叉树进阶

博主的博客主页&#xff1a;CSND博客 Gitee主页&#xff1a;博主的Gitee 博主的稀土掘金&#xff1a;稀土掘金主页 博主的b站账号&#xff1a;程序员乐 公众号——《小白技术圈》&#xff0c;回复关键字&#xff1a;学习资料。小白学习的电子书籍都在这。 目录根据二叉树创建字…

基于java+springmvc+mybatis+vue+mysql的协同过滤算法的电影推荐系统

项目介绍 基于协同过滤算法的电影推荐系统利用网络沟通、计算机信息存储管理&#xff0c;有着与传统的方式所无法替代的优点。比如计算检索速度特别快、可靠性特别高、存储容量特别大、保密性特别好、可保存时间特别长、成本特别低等。在工作效率上&#xff0c;能够得到极大地…

Hive自定义UDF函数

以下基于hive 3.1.2版本 Hive中自定义UDF函数&#xff0c;有两种实现方式&#xff0c;一是通过继承org.apache.hadoop.hive.ql.exec.UDF类实现&#xff0c;二是通过继承org.apache.hadoop.hive.ql.udf.generic.GenericUDF类实现。 无论是哪种方式&#xff0c;实现步骤都是&…

网上超市系统

开发工具(eclipse/idea/vscode等)&#xff1a; 数据库(sqlite/mysql/sqlserver等)&#xff1a; 功能模块(请用文字描述&#xff0c;至少200字)&#xff1a; 研究内容&#xff1a;设计开发简单购网上超市系统&#xff0c;采用Java语言&#xff0c;使用ySQL数据库&#xff0c; 实…

毕业设计 单片机家用燃气可视化实时监控报警仪 - 物联网 嵌入式 stm32

文章目录0 前言1 简介2 主要器件3 实现效果4 设计原理4.1 硬件部分4.2 软件部分5 部分核心代码6 最后0 前言 &#x1f525; 这两年开始毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的毕设题目缺少创新和亮点&#xff0c;往往达不到毕业答辩的要求&#xff0c;这两…

SAP ABAP 开发管理 代码内存标记 位置使用清单(Mark of memory id)

SAP ABAP 开发管理 代码内存标记 位置使用清单&#xff08;Mark of memory id&#xff09; 引言&#xff1a; 代码内存标记&#xff08;Mark of memory id&#xff09;是我开发中对 ABAP MEMORY ID 使用管理的一种方法&#xff0c;他能有效保障使用了 ABAP MEMORY ID 程序的可…

25岁从运维转向软件开发是选择Python还是Java

25岁的年龄不大&#xff0c;若是有扎实的基础&#xff0c;后期转转向软件开发是个不错的选择&#xff0c;Python是目前最火的编程语言&#xff0c;python作为人工智能的主要编程语言也有着不错的发展前景。 关于编程语言的选择&#xff0c;如果从就业的角度出发应该重点考虑一…

[附源码]Nodejs计算机毕业设计基于框架的校园爱心公益平台的设计与实现Express(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程。欢迎交流 项目运行 环境配置&#xff1a; Node.js Vscode Mysql5.7 HBuilderXNavicat11VueExpress。 项目技术&#xff1a; Express框架 Node.js Vue 等等组成&#xff0c;B/S模式 Vscode管理前后端分…