HW2: LibriSpeech phoneme classification

news2024/12/27 13:25:39

任务描述

在这里插入图片描述

音位分类预测(Phoneme classification),通过语音数据,预测音位。音位(phoneme),是人类某一种语言中能够区别意义的最小语音单位,是音位学分析的基础概念。每种语言都有一套自己的音位系统。

一帧(frame)设定为长25ms的音段,每次滑动10ms截得一个frame。每个frame经过MFCC处理,变成长度为39的向量。对于每个frame向量,数据集都提供了标签。标签有41类, 每个类代表一个phoneme

在这里插入图片描述

通常一个音位会跨越好几帧,所以训练时会结合前n帧和后n帧来对当前这一帧进行判断。
在这里插入图片描述

助教给出的代码:https://colab.research.google.com/drive/1wzeiVy2g7HpSjlidUr0Gi50NnHBWTkvN#scrollTo=KVUGfWTo7_Oj
数据(kaggle):https://www.kaggle.com/competitions/ml2023spring-hw2/data

数据说明

  • train_split.txt: 其中每一行对应一个训练数据,其所对应的文件在feat/train/中

在这里插入图片描述

  • train_labels.txt: 由训练数据和labels组成,格式为: filename labels。其中,label 为 frame 对应的 phoneme

在这里插入图片描述

  • test_split.txt: 其中每一行对应一个要求预测的数据,其所对应的文件在feat/test/中
  • feat/train/{id}.ptfeat/test/{id}.pt: 音频对应的 MFCC 文件,维度为39,这些文件可以通过torch.load()直接导入,导入后的shape为(T, 39)。

代码详解

导包

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
import random
import gc

定义数据集

class LibriDataset(Dataset):

    def __init__(self, X, y=None):
        self.data = X
        if y is not None:
            self.label = torch.LongTensor(y)
        else:
            self.label = None

    def __getitem__(self, idx):
        if self.label is not None:
            return self.data[idx], self.label[idx]
        else:
            return self.data[idx]

    def __len__(self):
        return len(self.data)

torch.LongTensor将y进行转换成Long类型的

定义模型

class BasicBlock(nn.Module):

    def __init__(self, input_dim, output_dim):
        super().__init__()

        # TODO: 应用 batch normalization 和 dropout
        self.block = nn.Sequential(nn.Linear(input_dim, output_dim), nn.ReLU(),
                                   nn.BatchNorm1d(output_dim), nn.Dropout(0.3))

    def forward(self, x):
        x = self.block(x)
        return x


class Classifier(nn.Module):

    def __init__(self,
                 input_dim,
                 output_dim=41,
                 hidden_layers=1,
                 hidden_dim=256):
        super().__init__()

        # *用于解包列表
        self.fc = nn.Sequential(
            BasicBlock(input_dim, hidden_dim), *[
                BasicBlock(hidden_dim, hidden_dim)
                for _ in range(hidden_layers)
            ], nn.Linear(hidden_dim, output_dim))

    def forward(self, x):
        x = self.fc(x)
        return x

其中,BasicBlock为自定义的一个基本单元,便于Classsifier应用。
Classifier的Sequential中*[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)]利用列表生成器快速生成多个隐层,但Sequential的参数不是列表,所以用*进行解包。

一些工具函数

def load_feat(path):
    feat = torch.load(path)
    return feat


def shift(x, n):
    if n < 0:
        left = x[0].repeat(-n, 1)
        right = x[:n]
    elif n > 0:
        right = x[-1].repeat(n, 1)
        left = x[n:]
    else:
        return x

    return torch.cat((left, right), dim=0)


def concat_feat(x, concat_n):
    '''
    concat_n: 连接帧数
    '''
    assert concat_n % 2 == 1  # n必须为奇数
    if concat_n < 2:
        return x
    seq_len, feature_dim = x.size(0), x.size(1)
    x = x.repeat(1, concat_n)
    x = x.view(seq_len, concat_n,
               feature_dim).permute(1, 0, 2)  # concat_n, seq_len, feature_dim
    mid = (concat_n // 2)
    for r_idx in range(1, mid + 1):
        x[mid + r_idx, :] = shift(x[mid + r_idx], r_idx)
        x[mid - r_idx, :] = shift(x[mid - r_idx], -r_idx)

    return x.permute(1, 0, 2).view(seq_len, concat_n * feature_dim)


def preprocess_data(split,
                    feat_dir,
                    phone_path,
                    concat_nframes,
                    train_ratio=0.8,
                    random_seed=1213):
    '''
    split:用于区分训练集、验证集、预测集
    concat_nframes: 连接帧数
    '''
    class_num = 41  # NOTE: pre-computed, should not need change

    if split == 'train' or split == 'val':
        mode = 'train'
    elif split == 'test':
        mode = 'test'
    else:
        raise ValueError(
            'Invalid \'split\' argument for dataset: PhoneDataset!')

    label_dict = {}
    if mode == 'train':
        for line in open(os.path.join(phone_path,
                                      f'train_labels.txt')).readlines():
            line = line.strip('\n').split(' ')
            label_dict[line[0]] = [int(p) for p in line[1:]]

        # 划分训练集和验证集
        usage_list = open(os.path.join(phone_path,
                                       'train_split.txt')).readlines()
        random.seed(random_seed)  # 设置种子
        random.shuffle(usage_list)  # 打乱
        train_len = int(len(usage_list) * train_ratio)  #训练集大小
        # 如果为训练集就分割前面的,反之则为验证集,分割后面的
        usage_list = usage_list[:train_len] if split == 'train' else usage_list[
            train_len:]

    elif mode == 'test':
        usage_list = open(os.path.join(phone_path,
                                       'test_split.txt')).readlines()

    usage_list = [line.strip('\n') for line in usage_list]
    print('[Dataset] - # phone classes: ' + str(class_num) +
          ', number of utterances for ' + split + ': ' + str(len(usage_list)))

    max_len = 3000000
    X = torch.empty(max_len, 39 * concat_nframes)
    if mode == 'train':
        y = torch.empty(max_len, dtype=torch.long)

    idx = 0
    for i, fname in tqdm(enumerate(usage_list)):
        feat = load_feat(os.path.join(feat_dir, mode, f'{fname}.pt'))
        cur_len = len(feat)
        feat = concat_feat(feat, concat_nframes)
        if mode == 'train':
            label = torch.LongTensor(label_dict[fname])

        X[idx:idx + cur_len, :] = feat
        if mode == 'train':
            y[idx:idx + cur_len] = label

        idx += cur_len

    X = X[:idx, :]
    if mode == 'train':
        y = y[:idx]

    print(f'[INFO] {split} set')
    print(X.shape)
    if mode == 'train':
        print(y.shape)
        return X, y
    else:
        return X

超参数

# 数据参数
concat_nframes = 17  # 连接帧数n必须为奇数(总共2k+1=n帧)
train_ratio = 0.9  # 用于训练的数据比例,其余将用于验证

# 训练参数
seed = 1213  # 随机数种子
batch_size = 512  # 分组大小
num_epoch = 50  # 训练轮数
learning_rate = 1e-3  # 学习率
model_path = './model.ckpt'  # 模型保存的路径

# 模型参数
input_dim = 39 * concat_nframes  # 模型输入维度
hidden_layers = 15  # 隐层层数
hidden_dim = 2048  # 隐层维度

读取数据

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'DEVICE: {device}')

# 预处理数据
train_X, train_y = preprocess_data(split='train',
                                   feat_dir='./feat',
                                   phone_path='./',
                                   concat_nframes=concat_nframes,
                                   train_ratio=train_ratio,
                                   random_seed=seed)
val_X, val_y = preprocess_data(split='val',
                               feat_dir='./feat',
                               phone_path='./',
                               concat_nframes=concat_nframes,
                               train_ratio=train_ratio,
                               random_seed=seed)

# 获取数据集
train_set = LibriDataset(train_X, train_y)
val_set = LibriDataset(val_X, val_y)

# 移除原始数据以节省内存
del train_X, train_y, val_X, val_y
gc.collect()

gc.collect()用于垃圾回收:对已经销毁的对象(这就是前一行del 的原因),Python不会自动释放其占据的内存空间。为了能够充分地利用分配的内存,避免程序跑到一半停止,要时不时地进行内存回收。
垃圾回收开始的时候当前所有线程都将被挂起,开始收集托管堆上的垃圾,收集完了还要压缩内存,然后等待垃圾回收结束以后再恢复这些线程,从这个角度来说,还是少调用垃圾回收,但是不是不能调,要视情况而定。

创建模型

# 获取数据加载器
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

# 创建模型,定义损失函数和优化器
model = Classifier(input_dim=input_dim,
                   hidden_layers=hidden_layers,
                   hidden_dim=hidden_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

训练

best_acc = 0.0  # 最高准确率
for epoch in range(num_epoch):
    train_acc = 0.0
    train_loss = 0.0
    val_acc = 0.0
    val_loss = 0.0

    # 训练
    model.train()  # 将模型设置为训练模式
    for i, batch in enumerate(tqdm(train_loader)):
        features, labels = batch
        features = features.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()  # 梯度清零
        outputs = model(features)  # 获取模型输出

        loss = criterion(outputs, labels)  # 计算偏差
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数

        _, train_pred = torch.max(outputs, 1)  # 获取具有最高概率的类别索引
        train_acc += (train_pred.detach() == labels.detach()).sum().item()
        train_loss += loss.item()

    # 验证
    model.eval()  # 将模型设置为评估模式
    with torch.no_grad():
        for i, batch in enumerate(tqdm(val_loader)):
            features, labels = batch
            features = features.to(device)
            labels = labels.to(device)
            outputs = model(features)

            loss = criterion(outputs, labels)

            _, val_pred = torch.max(outputs, 1)
            val_acc += (
                val_pred.cpu() == labels.cpu()).sum().item()  # 获取具有最高概率的类别索引
            val_loss += loss.item()

    print(
        f'[{epoch+1:03d}/{num_epoch:03d}] Train Acc: {train_acc/len(train_set):3.5f} Loss: {train_loss/len(train_loader):3.5f} | Val Acc: {val_acc/len(val_set):3.5f} loss: {val_loss/len(val_loader):3.5f}'
    )

    # 如果模型有改进,在此时保存一个检查点
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), model_path)
        print(f'saving model with acc {best_acc/len(val_set):.5f}')

del train_set, val_set
del train_loader, val_loader
gc.collect()

预测

# 加载预测数据
test_X = preprocess_data(split='test',
                         feat_dir='./feat',
                         phone_path='./',
                         concat_nframes=concat_nframes)
test_set = LibriDataset(test_X, None)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# 加载模型
model = Classifier(input_dim=input_dim,
                   hidden_layers=hidden_layers,
                   hidden_dim=hidden_dim).to(device)
model.load_state_dict(torch.load(model_path))

# 存储预测结果
pred = np.array([], dtype=np.int32)

# 预测
model.eval()
with torch.no_grad():
    for i, batch in enumerate(tqdm(test_loader)):
        features = batch
        features = features.to(device)

        outputs = model(features)

        _, test_pred = torch.max(outputs, 1)  # 获取具有最高概率的类别索引
        pred = np.concatenate((pred, test_pred.cpu().numpy()), axis=0)

# 将预测结果写入CSV文件
with open('prediction.csv', 'w') as f:
    f.write('Id,Class\n')
    for i, y in enumerate(pred):
        f.write('{},{}\n'.format(i, y))

运行结果

在这里插入图片描述

参数量上来之后自己的电脑很难跑得动了,基本就是colab+Kaggle跑

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/820675.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【CesiumJS材质】(2)图片横向移动

效果示例 要素说明&#xff1a; 代码 /** Date: 2023-07-19 11:15:22* LastEditors: ReBeX 420659880qq.com* LastEditTime: 2023-07-28 12:08:58* FilePath: \cesium-tyro-blog\src\utils\Material\FlowPictureMaterialProperty.js* Description: 流动纹理/图片材质*/ imp…

PySpark 数据操作(综合案例)

搜索引擎日志分析 要求&#xff1a; 读取文件转换成RDD&#xff0c;并完成&#xff1a; 打印输出&#xff1a;热门搜索时间段&#xff08;小时精度&#xff09;Top3打印输出&#xff1a;热门搜索词Top3打印输出&#xff1a;统计黑马程序员关键字在哪个时段被搜索最多将数据转…

02|Oracle学习(数据类型、DDL)

1. 数据类型&#xff1a; 通常为&#xff1a;字符型、数值型、日期型以及大字段型大字段型&#xff1a;存放大数据及文件。 存储大数据时&#xff0c;基本上blob就能满足。 2. DDL&#xff08;数据库定义语言&#xff09; 主要包括对数据库对象的创建、删除及修改的操作。…

16. Spring Boot 统一功能处理

目录 1. 用户登录权限校验 1.1 最初用户登录验证 1.2 Spring AOP 用户统一登陆验证 1.3 Spring 拦截器 1.3.1 创建自定义拦截器 1.3.2 将自定义拦截器加入系统配置 1.4 练习&#xff1a;登录拦截器 1.5 拦截器实现原理 1.6 统一访问前缀添加 2. 统一异常处理 3. 统…

Redis篇

文章目录 Redis-使用场景1、缓存穿透2、缓存击穿3、缓存雪崩4、双写一致5、Redis持久化6、数据过期策略7、数据淘汰策略 Redis-分布式锁1、redis分布式锁&#xff0c;是如何实现的&#xff1f;2、redisson实现的分布式锁执行流程3、redisson实现的分布式锁-可重入4、redisson实…

AMEYA详解松下Panasonic HF SSOP 1 Form A AQY PhotoMOS继电器

Panasonic HF SSOP 1 Form A AQY PhotoMOS继电器采用微型SSOP封装&#xff0c;具有600V的负载电压和1500Vrms 的I/O隔离电压 这些继电器具有8Ω的低导通电阻和高速运行的特点&#xff0c;SSOP封装旨在实现高密度安装。Panasonic HF SSOP AQY PhotoMOS继电器适用于从测试和测量设…

【python】冒泡法--详细讲解(python实现)

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

简单工厂模式(Simple Factory)

简单工厂模式&#xff0c;又称为静态工厂方法(Static Factory Method)模式。在简单工厂模式中&#xff0c;可以根据参数的不同返回不同类的实例。简单工厂模式专门定义一个类来负责创建其他类的实例&#xff0c;被创建的实例通常都具有共同的父类。简单工厂模式不属于GoF的23个…

iPhone 7透明屏的显示效果怎么样?

iPhone 7是苹果公司于2016年推出的一款智能手机&#xff0c;它采用了4.7英寸的Retina HD显示屏&#xff0c;分辨率为1334x750像素。 虽然iPhone 7的屏幕并不是透明的&#xff0c;但是苹果公司在设计上采用了一些技术&#xff0c;使得用户在使用iPhone 7时可以有一种透明的感觉…

虚拟个家用服务器集群(3):更换 PVE 软件源

风无痕 July 31,2023 前言 很多人想建个人博客类的网站&#xff0c;这就需要网站服务器&#xff1b;需要管理手机、电脑中积累的照片&#xff0c;每张照片可都是人生一个片段的记录&#xff0c;需要管理微信中收发的各种文档等等&#xff0c;这就需要一台 NAS 即 Network Att…

教师工作量管理系统Springmvc+Spring+Mybatis课程工作量教室java源代码mysql

本项目为前几天收费帮学妹做的一个项目&#xff0c;Java EE JSP项目&#xff0c;在工作环境中基本使用不到&#xff0c;但是很多学校把这个当作编程入门的项目来做&#xff0c;故分享出本项目供初学者参考。 一、项目描述 教师工作量管理系统SpringmvcSpringMybatis 系统有1权…

800V电驱动产品和技术汇总

文章来源&#xff1a; 赵老师——国汽战略院 汽车电动化研究中心 副主任研究员 需要样件请联&#xff1a;shbinzer 拆车邦 德国采埃孚 采埃孚于2022年量产800V电驱系统&#xff0c;采埃孚电驱传动技术事业部亚太区研发副总裁王岳在《采埃孚新一代超紧凑电驱动系统》报告中展…

【入门SpringCloud(一)】什么是SpringCloud?

一、概述 集群&#xff08;Cluster&#xff09;&#xff1a;同一种软件服务的多个服务节点共同为系统提供服务过程&#xff0c;称之为该软件服务集群。 分布式&#xff08;Distribute&#xff09;&#xff1a;分布式是一种系统架构&#xff0c;是将系统中的不同组件分布在不同…

计算机网络期末复习简答题、综合题、实验题答案整理汇总详细(持续更新中)

文章目录 简答题一、第一章&#xff1a;计算机网络概述1. TCP/IP 与 OSI 相结合的五层体系结构将计算机网络划分成哪几个层次&#xff1f;各层的主要功能是什么 二、第二章&#xff1a;物理层1. 交换机、路由器、网卡、网桥、集线器、中继器分别工作在哪一层2. 简述交换机、集线…

10.类型声明文件

类型声明文件的作用是 为已存在的JS库提供类型信息 目录 1 axios中的类型声明文件 2 类型声明文件与普通ts文件的区别 3 vscode中内置的类型声明文件 4 第三方库内置的类型声明文件 5 DefinitelyTyped 提供类型声明文件 6 自定义类型声明文件 6.1 创建给ts用的类…

同为科技(TOWE)带热插拔功能机柜PDU插座的应用

所谓热插拔&#xff08;hot-plugging或Hot Swap&#xff09;&#xff0c;即带电插拔&#xff0c;指的是在不关闭系统电源的情况下&#xff0c;将模块、板卡插入或拔出系统而不影响系统的正常工作&#xff0c;从而提高了系统的可靠性、快速维修性、冗余性和对灾难的及时恢复能力…

JMeter 的使用

文章目录 1. JMeter下载2. JMeter的使用2.1 JMeter中文设置2.2 JMeter的使用2.2.1 创建线程组2.2.2 HTTP请求2.2.3 监听器 1. JMeter下载 官网地址 https://jmeter.apache.org/download_jmeter.cgi https://dlcdn.apache.org//jmeter/binaries/apache-jmeter-5.6.2.zip 下载解…

Vue2 第十二节 Vue组件化编程 (二)

1. VueComponent 2. 单文件组件 一. VueComponent 组件本质上是一个名为VueComponent的构造函数&#xff0c;不是程序员定义的&#xff0c;是Vue.extend生成的只需要写<school/>或者<school><school/>&#xff0c;Vue解析时&#xff0c;会帮我们创建schoo…

ThinkPHP 6 添加跳转提示扩展 liliuwei/thinkphp-jump

liliuwei/thinkphp-jump 是 TP5 中经典跳转提示&#xff0c;在 TP6 中已经取消&#xff0c;通过 composer 下载该扩展可以在 TP6 中使用 TP5 的跳转提示操作。 安装扩展 在应用根目录执行: composer require liliuwei/thinkphp-jump引入扩展 在全局配置目录生成 jump.php 文件…

Activity的生存期

以下内容摘自郭霖《第一行代码》第三版 Activity的生存期 Activity类中定义了7个回调方法&#xff0c;覆盖了Activity生命周期的每一个环节&#xff1a; onCreate()。这个方法你已经看到过很多次了&#xff0c;我们在每个Activity中都重写了这个方法&#xff0c;它会在Activit…