基于DTW算法的命令字识别

news2024/11/26 2:36:54

DTW算法介绍

DTW(Dynamic Time Warping):按距离最近原则,构建两个序列之间的对应的关系,评估两个序列的相似性。

要求:

  • 单向对应,不能回头;
  • 一一对应,不能有空;
  • 对应之后,距离最近。

 

DTW代码实现

import numpy as np

def dis_abs(x, y):
    return abs(x - y)[0]

def estimate_twf(A, B, dis_func=dis_abs):
    N_A = len(A)
    N_B = len(B)
    
    D = np.zeros([N_A, N_B])
    D[0, 0] = dis_func(A[0], B[0])
    # 左边一列
    for i in range(1, N_A):
        D[i, 0] = D[i - 1, 0] + dis_func(A[i], B[0])
    # 下边一行
    for j in range(1, N_B):
        D[0, j] = D[0, j-1] + dis_func(A[0], B[j])
    # 中间部分
    for i in range(1, N_A):
        for j in range(1, N_B):
            D[i, j] = dis_func(A[i], B[j]) + min(D[i-1, j], D[i, j-1], D[i-1][j-1])
    
    # 路径回溯
    i = N_A - 1
    j = N_B - 1
    cnt = 0
    d = np.zeros(max(N_A, N_B) * 3)
    path = []
    while True:
        if i > 0 and j > 0:
            path.append((i, j))
            m = min(D[i-1, j], D[i, j-1], D[i-1, j-1])
            if m == D[i-1, j-1]:
                d[cnt] = D[i,j] - D[i-1, j-1]
                i -= 1
                j -= 1
                cnt += 1
            elif m == D[i, j-1]:
                d[cnt] = D[i,j] - D[i, j-1]
                j -= 1
                cnt += 1
            elif m == D[i-1, j]:
                d[cnt] = D[i,j] - D[i-1, j]
                i -= 1
                cnt += 1
        elif i == 0 and j == 0:
            path.append((i, j))
            d[cnt] = D[i, j]
            cnt += 1
            break
        elif i == 0:
            path.append((i, j))
            d[cnt] = D[i, j] - D[i, j-1]
            j -= 1
            cnt += 1
        elif j == 0:
            path.append((i, j))
            d[cnt] = D[i, j] - D[i-1, j]
            i -= 1
            cnt += 1
    mean = np.sum(d) / cnt
    return mean, path[::-1], D
a = np.array([1,3,4,9,8,2,1,5,7,3])
b = np.array([1,6,2,3,0,9,4,1,6,3])
a = a[:, np.newaxis]
b = b[:, np.newaxis]
dis, path, D = estimate_twf(a, b, dis_func=dis_abs)
print(dis, path, D)

>>:
1.0833333333333333
 [(0, 0), (1, 1), (1, 2), (1, 3), (2, 4), (3, 5), (4, 5), (5, 6), (6, 7), (7, 8), (8, 8), (9, 9)] 
[[ 0.  5.  6.  8.  9. 17. 20. 20. 25. 27.]
 [ 2.  3.  4.  4.  7. 13. 14. 16. 19. 19.]
 [ 5.  4.  5.  5.  8. 12. 12. 15. 17. 18.]
 [13.  7. 11. 11. 14.  8. 13. 20. 18. 23.]
 [20.  9. 13. 16. 19.  9. 12. 19. 20. 23.]
 [21. 13.  9. 10. 12. 16. 11. 12. 16. 17.]
 [21. 18. 10. 11. 11. 19. 14. 11. 16. 18.]
 [25. 19. 13. 12. 16. 15. 15. 15. 12. 14.]
 [31. 20. 18. 16. 19. 17. 18. 21. 13. 16.]
 [33. 23. 19. 16. 19. 23. 18. 20. 16. 13.]]

基于DTW算法的命令字识别

utils.py:

# -*- coding:UTF-8 -*-
import streamlit as st
import pyaudio
import wave
import librosa
import soundfile as sf
import numpy as np
import os
import time


# 采用MFCC特征使用mcd距离
def euclideanDistance(a, b):
    diff = a - b
    mcd = 10.0 / np.log(10) * np.sqrt(2.0 * np.sum(diff ** 2))
    return mcd


# DTW算法匹配距离
class DTW:
    def __init__(self, disFunc=euclideanDistance):
        self.disFunc = disFunc

    def compute_distance(self, reference, test):
        DTW_matrix = np.empty([reference.shape[0], test.shape[0]])
        DTW_matrix[:] = np.inf
        DTW_matrix[0, 0] = 0

        for i in range(reference.shape[0]):
            for j in range(test.shape[0]):
                cost = self.disFunc(reference[i, :], test[j, :])
                r_index = i - 1
                c_index = j - 1
                if r_index < 0:
                    r_index = 0
                if c_index < 0:
                    c_index = 0
                DTW_matrix[i, j] = cost + min(DTW_matrix[r_index, j], DTW_matrix[i, c_index],
                                              DTW_matrix[r_index, c_index])
        return DTW_matrix[-1, -1] / (test.shape[0] + reference.shape[0])


# 语音录制
class wordRecorder:
    def __init__(self, samplingFrequency=8000, threshold=20):
        self.samplingFrequency = samplingFrequency
        self.threshold = threshold

    def record(self):
        p = pyaudio.PyAudio()
        stream = p.open(format=pyaudio.paInt16, channels=1, rate=self.samplingFrequency, input=True, output=False,
                        frames_per_buffer=1024)
        frames = []
        for i in range(int(self.samplingFrequency * 4 / 1024)):
            data = stream.read(1024)
            frames.append(data)
        stream.stop_stream()
        stream.close()
        p.terminate()
        return frames

    def record2File(self, path):
        frames = self.record()
        p = pyaudio.PyAudio()
        with wave.open(path, 'wb') as wf:
            wf.setnchannels(1)
            wf.setsampwidth(p.get_sample_size(pyaudio.paInt16))
            wf.setframerate(self.samplingFrequency)
            wf.writeframes(b''.join(frames))
            print('record finished!')


# 提取mfcc特征
def getmfcc(audio, isfile=True):
    if isfile:
        # 读取音频文件
        y, fs = librosa.load(audio, sr=8000)
    else:
        # 音频数据,需要去除静音
        y = np.array(audio)

    intervals = librosa.effects.split(y, top_db=20)
    y = librosa.effects.remix(y, intervals)
    # 预加重
    y = librosa.effects.preemphasis(y)

    fs = 8000
    N_fft = 256
    win_length = 256
    hop_length = 128
    n_mels = 23
    n_mfcc = 14
    # mfcc提取
    mfcc = librosa.feature.mfcc(y=y, sr=fs, n_mfcc=n_mfcc, n_mels=n_mels, n_fft=N_fft, win_length=win_length,
                                hop_length=hop_length)
    mfcc = mfcc[1:, :]
    # 添加差分量
    mfcc_deta = librosa.feature.delta(mfcc)
    mfcc_deta2 = librosa.feature.delta(mfcc, order=2)
    # 特征拼接
    mfcc_d1_d2 = np.concatenate([mfcc, mfcc_deta, mfcc_deta2], axis=0)
    return mfcc_d1_d2.T


# 指定文件夹下文件个数
def check_file(name):
    os.makedirs('data', exist_ok=True)
    save_dir = os.path.join('data', name)
    os.makedirs(save_dir, exist_ok=True)

    n_files = 0
    for roots, dirs, files in os.walk(save_dir):
        for file in files:
            if file.endswith('.wav'):
                n_files += 1
    return n_files


@st.cache_resource  # 防止重载
def model_load():
    model1 = ModelHotWord(os.path.join('data', '向上'))
    model2 = ModelHotWord(os.path.join('data', '向下'))
    model3 = ModelHotWord(os.path.join('data', '向左'))
    model4 = ModelHotWord(os.path.join('data', '向右'))
    models = [model1, model2, model3, model4]
    return models

class ModelHotWord(object):
    def __init__(self, path):
        self.mfccs = get_train_mfcc_list(path)

    def get_score(self, ref_mfcc):
        return get_score(ref_mfcc, self.mfccs)


def get_train_mfcc_list(data_path):
    mfccs = []
    for roots, dirs, files in os.walk(data_path):
        for file in files:
            if file.endswith('wav'):
                file_audio = os.path.join(data_path, file)
                mfcc = getmfcc(file_audio)
                mfccs.append(mfcc)
    return mfccs


def get_score(ref_mfcc, list_mfccs):
    m_dtw = DTW()
    N = len(list_mfccs)
    scores = 0
    for i in range(N):
        dis = m_dtw.compute_distance(ref_mfcc, list_mfccs[i])
        scores = scores + dis
    return scores / N

DTW.py:

# -*- coding:UTF-8 -*-
from utils import *


st.title('基于DTW算法的命令字识别')
tab1, tab2 = st.tabs(['音频录制', '识别演示'])

with tab1:
    list_labs = ['向上', '向下', '向左', '向右']
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        name = st.selectbox('模型选择', list_labs)
    with col2:
        st.write('命令字录制')
        flag_record = st.button(label='录音')
    with col3:
        st.write('命令字重录')
        flag_cancel = st.button(label='撤销')
    with col4:
        st.write('试听')
        flag_show_audios = st.button(label='试听')
        
    info_file_number = st.empty()
    info_file_number.write('命令字---%s--已有%d个样本'%(name, check_file(name)))
     
    
info_audios = st.empty()
info_success = st.empty()
if flag_record:
    info_audios.info('')
    info_success.success('')   
    n_files = check_file(name)
    info_audios.info('开始录制---第%d个命令字---%s--请在2s内完成录制.....'%(n_files + 1, name))
    save_dir = os.path.join('data', name)
    audio_name = os.path.join(save_dir, '%d.wav'%(n_files + 1))
    wRec = wordRecorder()
    wRec.record2File(audio_name)
    info_success.success('录制完成,保存为' + audio_name)

if flag_cancel:
    n_files = check_file(name)
    save_dir = os.path.join('data', name)
    file_del = os.path.join(save_dir, str(n_files)+'.wav')
    os.remove(file_del)
    info_file_number.write('命令字--%s--已有%d个样本'%(name, check_file(name)))
    
if flag_show_audios:
    n_files = check_file(name)
    save_dir = os.path.join('data', name)
    if n_files > 0:
        for i in range(n_files):
            audio_file = open(os.path.join(save_dir, '%d.wav'%(i+1)), 'rb')
            audio_bytes = audio_file.read()
            st.audio(audio_bytes, format='audio/')


with tab2:
    th = 125
    st.write('识别演示')
    if 'run' not in st.session_state:
        st.session_state['run'] = False
    def start_listening():
        st.session_state['run'] = True
    def stop_listening():
        st.session_state['run'] = False

    col1, col2 = st.columns(2)
    with col1:
        st.button('开始检测', on_click=start_listening)
    with col2:
        st.button('停止检测', on_click=stop_listening)

    det_word = st.empty()
    def init_up():
        det_word.write('向上')
    def init_down():
        det_word.write('向下')
    def init_left():
        det_word.write('向左')
    def init_right():
        det_word.write('向右')
    callbacks = [init_up, init_down, init_left, init_right]

    # 加载预测模型,提取好的一些mfcc特征
    models = model_load()
    dic_labs = {'0': '向上', '1': '向下', '2': '向左', '3': '向右', '-1': ''}

    while st.session_state['run']:  # 循环进行检测
        wRec = wordRecorder()
        wRec.record2File('data/test.wav')
        ref_mfcc = getmfcc('data/test.wav', True)
        # 在每个模型上进行打分,扎到最小分数作为检测结果
        scores = [model.get_score(ref_mfcc) for model in models]
        i_word = np.argmin(scores)
        score = np.min(scores)
        print(i_word, score)
        if score < th:
            i_det_word = i_word
            callback = callbacks[i_det_word]
            if callback is not None:
                callback()
            print('---------det word---------', dic_labs[str(i_det_word)])
        else:
            continue 

python命令行运行streamlit run DTW.py即会出现web网页ui,结果如下图所示:

参考DTW关键字检测-代码实现_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1048365.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【图文】IRRA:跨模态隐式关系推理与对齐 | CVPR2023

详细内容指路zhihu&#x1f449;CVPR2023 | IRRA论文阅读 摘要 Text-to-image Person Retrieval的目的是根据给定的文本描述查询确定目标个体。主要的挑战是学习把视觉和文本模态映射到一个公共的潜在空间里。之前的工作尝试通过利用单模态分开预训练来提取图像和文本特征来解…

TOWE工业级多口大功率USB插座,助力多设备同时供电

同为科技&#xff08;TOWE&#xff09;工业级多口大功率USB桌面PDU插座 随着科技的不断进步&#xff0c;人们对电子设备的需求也越来越多样化。在如今的快节奏生活中&#xff0c;我们常常需要同时给多个设备充电&#xff0c;而传统的插座往往无法满足这一需求。为解决这一问题…

JavaSE(三)

3.1 异常 Java 异常类层次结构图概览&#xff1a; 1.Exception 和 Error 有什么区别&#xff1f; 在 Java 中&#xff0c;所有的异常都有一个共同的祖先 java.lang 包中的 Throwable 类。Throwable 类有两个重要的子类&#xff0c;分别是 Exception 和 Error&#xff1a; Ex…

嵌入式中如何用C语言操作sqlite3(07)

sqlite3编程接口非常多&#xff0c;对于初学者来说&#xff0c;我们暂时只需要掌握常用的几个函数&#xff0c;其他函数自然就知道如何使用了。 数据库 本篇假设数据库为my.db,有数据表student。 nonamescore4嵌入式开发爱好者89.0 创建表格语句如下&#xff1a; CREATE T…

数据结构算法--8基数排序

> 多关键字排序&#xff1a;现在有一个员工表&#xff0c;要求按照薪资排序&#xff0c;薪资相同的员工按照年龄排序 >> 先按照年龄排序&#xff0c;再按照薪资进行稳定的排序 > 例如&#xff1a;32&#xff0c;13&#xff0c;94&#xff0c;52&#xff0c;17&am…

C++中指针的概念和声明

C中指针的概念和声明 学习 C 的指针既简单又有趣。通过指针&#xff0c;可以简化一些 C 编程任务的执行&#xff0c;还有一些任务&#xff0c;如动态内存分配&#xff0c;没有指针是无法执行的。所以&#xff0c;想要成为一名优秀的 C 程序员&#xff0c;学习指针是很有必要的…

Unity实现设计模式——责任链模式

Unity实现设计模式——责任链模式 责任链模式定义&#xff1a;将请求的发送和接收解耦&#xff0c;让多个接收对象都有机会处理这个请求。将这些接收对象串成一条链&#xff0c;并沿着这条链传递这个请求&#xff0c;直到链上的某个接收对象能够处理它为止。 在职责链模式中&…

CIP或者EtherNET/IP中的PATH是什么含义?

目录 SegmentPATH举例 最近在学习EtherNET/IP&#xff0c;PATH不太明白&#xff0c;翻了翻规范&#xff0c;在这里记个笔记。下面的叙述可能是中英混合&#xff0c;有一些是规范中的原文我直接搬过来的。我翻译的不准确。 Segment PATH是CIP Segment中的一个分类。要了解PATH…

dataGrip导出导入的方式

导出&#xff1a;选中需要导出的表 导入&#xff1a;选中导出的sql文件

运动控制:为什么高精度的测量都是用大理石平台

一、大理石的应用场景 在一些应用直线电机的场景&#xff0c;以及一些量测性仪器仪表上面&#xff0c;我们都能看到大理石的身影&#xff0c;毫无疑问&#xff0c;只要是精度要求高的地方&#xff0c;就少不了大理石&#xff0c;这和大理石的自身特性是分不开的。 二、天然大理…

【JVM】并发可达性分析-三色标记算法

欢迎访问&#x1f44b;zjyun.cc 可达性分析 为了验证堆中的对象是否为可回收对象&#xff08;Garbage&#xff09;标记上的对象&#xff0c;即是存活的对象&#xff0c;不会被垃圾回收器回收&#xff0c;没有标记的对象会被垃圾回收器回收&#xff0c;在标记的过程中需要stop…

项目集成七牛云存储sdk

以PHP为例 第一步&#xff1a;下载sdk PHP SDK_SDK 下载_对象存储 - 七牛开发者中心 sdk下载成功之后&#xff0c;将sdk放入项目中&#xff0c;目录选择以自己项目实际情况而定。 注意&#xff1a;在examples目录中有各种上传文件的参考示例&#xff0c;这里我们主要参考的是…

Vue 实现表单的增删改查功能及表单的验证

前言&#xff1a; 上一篇我们已经将前端表单的数据和后端的数据交互了&#xff0c;今天我们就继续开发功能来实现表单的增删改查功能及表单的验证 一&#xff0c;表单的增删改查功能 新增 去官网找模版&#xff1a; 1.1添加新增按钮&#xff1a; 1.2添加新增弹窗点击事件&am…

HC32 IIC/I2C读写

IIC状态码 IIC 初始化 void iicInit(uint32_t speed) {stc_gpio_cfg_t stcGpioCfg;DDL_ZERO_STRUCT(stcGpioCfg);Sysctrl_SetPeripheralGate(SysctrlPeripheralGpio, TRUE); //开启GPIO时钟门控stcGpioCfg.enDir GpioDirOut; ///< 端口方向配置…

Kubernetes 上的数据已跨越鸿沟:在 GKE 上运行有状态应用程序的案例

Kubernetes 是当今云原生开发的事实上的标准。长期以来&#xff0c;Kubernetes 主要与无状态应用程序相关&#xff0c;例如 Web 和批处理应用程序。然而&#xff0c;与大多数事物一样&#xff0c;Kubernetes 也在不断发展。如今&#xff0c;我们看到 Kubernetes 上有状态应用程…

MySQL学习笔记19

MySQL日志文件&#xff1a;MySQL中我们需要了解哪些日志&#xff1f; 常见日志文件&#xff1a; 我们需要掌握错误日志、二进制日志、中继日志、慢查询日志。 错误日志&#xff1a; 作用&#xff1a;存放数据库的启动、停止和运行时的错误信息。 场景&#xff1a;用于数据库的…

ubuntu apt工具软件操作

apt工具 -----> 网关 国内网络(仓库源) 美国网络(仓库源)/etc/apt/sources.list https://mirrors.tuna.tsinghua.edu.cn/help/ubuntu/sudo apt-get update sudo apt install sl 安装包 sudo apt-cache show sl 查看包信…

Jquery 复选框全选和反选失灵的问题

页面上有这么一张表格&#xff0c;点击All时将列表中的复选框全部勾选&#xff0c;反之亦然。 表头&#xff1a; <th><input type"checkbox" id"chkAll" onclick"CheckAll(this)" />All </th> 表格数据源绑定&#xff1a; …

TouchGFX界面开发 | 项目代码结构分析

项目代码结构分析 本文介绍TouchGFX项目中TouchGFX Designer自动生成的代码&#xff0c;以及需要用户编写的扩展代码。 一、生成的代码和用户代码 TouchGFX Designer生成的代码将与用户编写的代码完全分离。 事实上&#xff0c;自动生成的代码位于generated/gui_generated文…

【DTEmpower案例操作教程】智能模型预警

DTEmpower是由天洑软件自主研发的一款通用的智能数据建模软件&#xff0c;致力于帮助工程师及工科专业学生&#xff0c;利用工业领域中的仿真、试验、测量等各类数据进行挖掘分析&#xff0c;建立高质量的数据模型&#xff0c;实现快速设计评估、实时仿真预测、系统参数预警、设…