大模型之二十八-语音识别Whisper进阶

news2024/9/20 1:12:43

在上一篇博客大模型之二十七-语音识别Whisper实例浅析中遗留了几个问题,这里来看一下前两个问题。
1.如果不是Huggingface上可以下载的数据该怎么办?
2.上面的代码是可以训练了,但是训练的时候loss真的会和我们预期一致吗?比如如下怎么办?

进阶内容

在Whisper语音识别fine-tune的例子中,我们使用的是Huggingface封装好的数据加载以及Transformer工具,这将很多底层细节对开发人员屏蔽了,但是对于技术人员而言,这还远远不够,本篇通过一个要解决两个问题:
1.数据集是私有的,并不是Huggingface开源的数据集
2.不使用Huggingface封装好的Training pipeline,在Whisper开源的源代码基础之上fine-tune模型,并验证准确性。

整个框架代码使用pytorch-lightning来实现,目前很多优秀的比较大的开源都是实用pytorch-lightning来实现的。

安装一些python库

首先下载Whisper源代码,并且

! pip install git+https://github.com/openai/whisper.git
! pip install jiwer 
! pip install pytorch-lightning==2.4.0
! pip install -qqq evaluate==0.2.2

导入必要的python包

import os
import glob
import numpy as np

try:
    import tensorflow  # required in Colab to avoid protobuf compatibility issues
except ImportError:
    pass

import torch
from torch import nn
import pandas as pd
import whisper
import torchaudio
import torchaudio.transforms as at

from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from tqdm.notebook import tqdm
import evaluate

from transformers import (
    AdamW,
    get_linear_schedule_with_warmup
)

遗留的第一个问题–数据集

这里的数据集基于清华大学开源的30小时中文照着文本读而录的音频,原下载地址,
为了减小资源的开销,在有限的资源下,多迭代epoch,这里对数据集做了处理:

  • 将数据集缩到了10个小时/30小时,
  • 去掉了txt里音素的标注,只留文本,因为在这个数据集开源的时候,那时语音识别系统还是基于音素的。

可以关注私信我,联系索取处理之后的语料。

数据集处理

import glob

DATASET_DIR = "/kaggle/input/th30-all"
SAMPLE_RATE = 16000
BATCH_SIZE = 4
TRAIN_RATE = 0.85

#whipser的输入是30s,16kHz采样率,最长480000 sample
AUDIO_MAX_LENGTH = 480000
TEXT_MAX_LENGTH = 120

DEVICE = "gpu" if torch.cuda.is_available() else "cpu"

###################### 读取数据信息并分离出train和val
dataset_dir = DATASET_DIR
transcripts_path_list = glob.glob(os.path.join(dataset_dir, "*.txt"))
print(len(transcripts_path_list))
13388

读取数据信息并分离出train和val

dataset_dir = DATASET_DIR
transcripts_path_list = glob.glob(os.path.join(dataset_dir, "*.txt"))
print(len(transcripts_path_list))

def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:
    waveform, sr = torchaudio.load(wave_path, normalize=True)
    if sample_rate != sr:
        waveform = at.Resample(sr, sample_rate)(waveform)
    return waveform

def get_audio_file_list(transcripts_path_list, text_max_length=120, audio_max_sample_length=480000, sample_rate=16000):
    audio_transcript_pair_list = []
    for transcripts_path in tqdm(transcripts_path_list):
        # audio文件目录确认
        audio_dir = os.path.dirname(transcripts_path)

        # 从翻译文本获取音频和文本
        with open(transcripts_path, "r") as f:
            text_list = f.readlines()
        for text in text_list:
            audio_id, text = text.replace("\n", "").split(":")
            #print(audio_id, text)

            audio_path = os.path.join(audio_dir, f"{audio_id}.wav")
            if os.path.exists(audio_path):
                # 检查数据
                audio = load_wave(audio_path, sample_rate=sample_rate)[0]
                if len(text) > text_max_length or len(audio) > audio_max_sample_length:
                    print(len(text), len(audio))
                    continue
                audio_transcript_pair_list.append((audio_id, str(audio_path), text))
    return audio_transcript_pair_list

train_num = int(len(transcripts_path_list) * TRAIN_RATE)
train_transcripts_path_list, eval_transcripts_path_list = transcripts_path_list[:train_num], transcripts_path_list[train_num:]
train_audio_transcript_pair_list = get_audio_file_list(train_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
eval_audio_transcript_pair_list = get_audio_file_list(eval_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
print("TRAIN AUDIO DATASET NUM: ", len(train_audio_transcript_pair_list))
print("EVAL AUDIO DATASET NUM: ", len(eval_audio_transcript_pair_list))
13388
  0%|          | 0/11379 [00:00<?, ?it/s]  
  0%|          | 0/2009 [00:00<?, ?it/s]
TRAIN AUDIO DATASET NUM:  11379
EVAL AUDIO DATASET NUM:  2009

Data loader

woptions = whisper.DecodingOptions(language="zh", without_timestamps=True)
wmodel = whisper.load_model(name="small",download_root="./whisper-small")
wtokenizer = whisper.tokenizer.get_tokenizer(True, language="zh", task=woptions.task)

class Th30Dataset(torch.utils.data.Dataset):
    def __init__(self, audio_info_list, tokenizer, sample_rate) -> None:
        super().__init__()

        self.audio_info_list = audio_info_list
        self.sample_rate = sample_rate
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.audio_info_list)

    def __getitem__(self, index):
        audio_id, audio_path, text = self.audio_info_list[index]

        #aduio mono
        audio = load_wave(audio_path, sample_rate=self.sample_rate)
        audio = whisper.pad_or_trim(audio.flatten(), AUDIO_MAX_LENGTH)
        mel = whisper.log_mel_spectrogram(audio)

        #text
        text = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text)
        labels = text[1:] + [self.tokenizer.eot]

        return {
            "input_ids": mel,
            "labels": labels,
            "dec_input_ids": text
        }
class WhisperDataCollatorWhithPadding:
    def __call__(self, features):
        input_ids, labels, dec_input_ids = [], [], []
        for f in features:
            input_ids.append(f["input_ids"])
            labels.append(f["labels"])
            dec_input_ids.append(f["dec_input_ids"])

        input_ids = torch.concat([input_id[None, :] for input_id in input_ids])


        label_lengths = [len(lab) for lab in labels]
        dec_input_ids_length = [len(e) for e in dec_input_ids]
        max_label_len = max(label_lengths + dec_input_ids_length)

        labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]
        dec_input_ids = [np.pad(e, (0, max_label_len - e_len), 'constant', constant_values=50257) for e, e_len in zip(dec_input_ids, dec_input_ids_length)] # 50257 is eot token id

        batch = {
            "labels": labels,
            "dec_input_ids": dec_input_ids
        }

        batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()}
        batch["input_ids"] = input_ids

        return batch

dataset = Th30Dataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())

这是典型的Pytorch而不是前篇中Huggingface的数据加载方法,需要实现datasetDataLoader,详细参考Pytorch Lightning官方文档。至此,遗留的第一个问题解决。

验证数据集加载

DEVICE = "gpu" if torch.cuda.is_available() else "cpu"
for b in loader:
    print(b["labels"].shape)
    print(b["input_ids"].shape)
    print(b["dec_input_ids"].shape)

    for token, dec in zip(b["labels"], b["dec_input_ids"]):
        token[token == -100] = wtokenizer.eot
        text = wtokenizer.decode(token)
        print(text)

        dec[dec == -100] = wtokenizer.eot
        text = wtokenizer.decode(dec)
        print(text)
    break
torch.Size([2, 50])
torch.Size([2, 80, 3000])
torch.Size([2, 50])
<|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>
<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着

验证解码器

with torch.no_grad():
    audio_features = wmodel.encoder(b["input_ids"].cuda())
    input_ids = b["input_ids"]
    labels = b["labels"].long()
    dec_input_ids = b["dec_input_ids"].long()

        
    audio_features = wmodel.encoder(input_ids.cuda())
    print(dec_input_ids)
    print(input_ids.shape, dec_input_ids.shape, audio_features.shape)
    print(audio_features.shape)
    print()

# 计算解码器的输出
out = wmodel.decoder(dec_input_ids.cuda(), audio_features)

print(out.shape)
print(out.view(-1, out.size(-1)).shape)
print(b["labels"].view(-1).shape)
tensor([[50258, 50260, 50359, 50363, 45161, 11386, 47446,  5708,  5266,   104,
          5823, 35825, 20708, 17682,  3023,   222,  5975,  1787,   106, 44365,
          3919,    95, 31906,  1593,   247, 11160, 31382,  1881,   237, 31382,
         39861,  5155, 39152,  5648,   247, 10928,  3330,   123, 18937,  2289,
          5881,   249,  5419,   122, 50257, 50257, 50257, 50257, 50257, 50257],
        [50258, 50260, 50359, 50363, 12744, 25281, 22694,  6734, 42503, 12088,
          3308,   111, 36257,  4479,   223,  4035, 32045, 41111,   236,  3308,
           116,  7391,   102,  4479,   245, 11808,   246,  3416,   105, 33597,
         45506, 37960,  8501,   244, 45506, 37960,  3316,   238, 45506, 17819,
            99, 15789,  7732,   100, 44059, 10928, 48839, 16337,   234, 20708]])
torch.Size([2, 80, 3000]) torch.Size([2, 50]) torch.Size([2, 1500, 768])
torch.Size([2, 1500, 768])

torch.Size([2, 50, 51865])
torch.Size([100, 51865])
torch.Size([100])

token转文本输出

tokens = torch.argmax(out, dim=2)
for token in tokens:
    token[token == -100] = wtokenizer.eot
    text = wtokenizer.decode(token)
    print(text)
<|zh|><|translate|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏歌唱家微妙的避开了矛盾<|endoftext|><|endoftext|> <|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|zh|><|transcribe|><|notimestamps|>放眼望去,定河兩旁人身顎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>

构造trainer

class Config:
    learning_rate = 0.0001
    weight_decay = 0.01
    adam_epsilon = 1e-8
    warmup_steps = 2
    batch_size = 16
    num_worker = 2
    num_train_epochs = 1000
    gradient_accumulation_steps = 1
    sample_rate = SAMPLE_RATE


class WhisperModelModule(LightningModule):
    def __init__(self, cfg: Config, model_name="small", lang="zh", train_dataset=[], eval_dataset=[]) -> None:
        super().__init__()
        self.options = whisper.DecodingOptions(language=lang, without_timestamps=True)
        self.model = whisper.load_model(model_name)
        self.tokenizer = whisper.tokenizer.get_tokenizer(True, language="zh", task=self.options.task)

        # only decoder training
        for p in self.model.encoder.parameters():
            p.requires_grad = False

        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
        self.metrics_wer = evaluate.load("wer")
        self.metrics_cer = evaluate.load("cer")

        self.cfg = cfg
        self.__train_dataset = train_dataset
        self.__eval_dataset = eval_dataset

    def forward(self, x):
        return self.model(x)
    def training_step(self, batch, batch_id):
        input_ids = batch["input_ids"]
        labels = batch["labels"].long()
        dec_input_ids = batch["dec_input_ids"].long()

        with torch.no_grad():
            audio_features = self.model.encoder(input_ids)

        out = self.model.decoder(dec_input_ids, audio_features)
        loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))
        self.log("train/loss", loss, on_step=False, on_epoch=True,  prog_bar=True, logger=True)
        return loss
    
    def on_train_epoch_end(self):
        avg_loss = self.trainer.callback_metrics.get("train/loss")
        
        # 获取当前的 epoch 数量
        epoch = self.current_epoch
        
        print(f"Epoch: {epoch}, Training - Loss: {avg_loss:.4f}")

    def validation_step(self, batch, batch_id):
        input_ids = batch["input_ids"]
        labels = batch["labels"].long()
        dec_input_ids = batch["dec_input_ids"].long()


        audio_features = self.model.encoder(input_ids)
        out = self.model.decoder(dec_input_ids, audio_features)

        loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))

        out[out == -100] = self.tokenizer.eot
        labels[labels == -100] = self.tokenizer.eot

        o_list, l_list = [], []
        for o, l in zip(out, labels):
            o = torch.argmax(o, dim=1)
            o_list.append(self.tokenizer.decode(o))
            l_list.append(self.tokenizer.decode(l))
        wer = self.metrics_wer.compute(references=l_list, predictions=o_list)

        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val/wer", wer, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        # 打印到终端
        #print(f"Validation - Loss: {loss:.4f}, WER: {wer:.4f}")

        return {
            "wer": wer,
            "loss": loss
        }
    
    def on_validation_epoch_end(self):
        avg_loss = self.trainer.callback_metrics.get("val/loss")
        avg_wer = self.trainer.callback_metrics.get("val/wer")
        
        # 获取当前的 epoch 数量
        epoch = self.current_epoch

        print(f"Epoch: {epoch}, Validation - Loss: {avg_loss:.4f}, WER: {avg_wer:.4f}")
    

    def configure_optimizers(self):
        """创建优化程序和调度器 """
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters()
                            if not any(nd in n for nd in no_decay)],
                "weight_decay": self.cfg.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters()
                            if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.cfg.learning_rate,
                          eps=self.cfg.adam_epsilon)
        self.optimizer = optimizer

        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.cfg.warmup_steps,
            num_training_steps=self.t_total
        )
        self.scheduler = scheduler

        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def setup(self, stage=None):
        """初始设置(读取数据集)"""

        if stage == 'fit' or stage is None:
            self.t_total = (
                    (len(self.__train_dataset) // (self.cfg.batch_size))
                    // self.cfg.gradient_accumulation_steps
                    * float(self.cfg.num_train_epochs)
            )

    def train_dataloader(self):
        """ 创建训练数据加载程序 """
        dataset = Th30Dataset(self.__train_dataset, self.tokenizer, self.cfg.sample_rate)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=self.cfg.batch_size,
                                           drop_last=True, shuffle=True, num_workers=self.cfg.num_worker,
                                           collate_fn=WhisperDataCollatorWhithPadding()
                                           )

    def val_dataloader(self):
        """ 创建验证数据加载程序 """
        dataset = Th30Dataset(self.__eval_dataset, self.tokenizer, self.cfg.sample_rate)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=self.cfg.batch_size,
                                           num_workers=self.cfg.num_worker,
                                           collate_fn=WhisperDataCollatorWhithPadding()
                                           )  

主要是对LightningModule类相关方法的重载,定义了train、validate以及optimizer的行为,以及在训练过程中日志和相关信息、checkpoint的保存。

启动训练

log_output_dir = "./logs"
check_output_dir = "./artifacts"

train_name = "whisper"
train_id = "00001"

model_name = "small"
lang = "zh"


cfg = Config()

# os.mkdir(log_output_dir)
# os.mkdir(check_output_dir)

tflogger = TensorBoardLogger(
    save_dir=log_output_dir,
    name=train_name,
    version=train_id
)

from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{check_output_dir}/checkpoint",
    filename="checkpoint-{epoch:04d}",
    save_top_k=2, # all model save
    save_on_train_epoch_end=False,
    monitor='val/wer',  # 需要监控的验证损失
    mode='min',  # 最小化 val_loss
    verbose=True  # 打印更多的信息到控制台
)
callback_list = [checkpoint_callback, LearningRateMonitor(logging_interval="epoch")]
model = WhisperModelModule(cfg, model_name, lang, train_audio_transcript_pair_list, eval_audio_transcript_pair_list)

trainer = Trainer(
    precision=16,
    accelerator="gpu",
    max_epochs=cfg.num_train_epochs,
    check_val_every_n_epoch=2,
    accumulate_grad_batches=cfg.gradient_accumulation_steps,
    logger=tflogger,
    callbacks=callback_list
)

trainer.fit(model)
```shell
对于10小时数据集,你可能看到如下输出:

Epoch: 0, Training - Loss: 1.0872
Epoch: 1, Validation - Loss: 0.4207, WER: 0.9847
Epoch: 1, Training - Loss: 0.2955
Epoch: 2, Training - Loss: 0.5555
Epoch: 3, Validation - Loss: 0.2505, WER: 0.9006
Epoch: 3, Training - Loss: 0.0979
Epoch: 4, Training - Loss: 0.0602
Epoch: 5, Validation - Loss: 0.2889, WER: 0.8764
Epoch: 5, Training - Loss: 0.0721
Epoch: 6, Training - Loss: 0.0947
Epoch: 7, Validation - Loss: 0.3839, WER: 0.9809
Epoch: 7, Training - Loss: 0.1379

对于30小时数据集,即使用完整的th30数据,其中85%用于Training,而15%用于validation你可能看到如下输出:
```shell
3051.3s	121	Epoch: 1, Validation - Loss: 0.0588, WER: 0.3499
3061.1s	122	Epoch: 1, Training - Loss: 0.0340
4279.7s	123	Epoch: 2, Training - Loss: 0.0268
5691.4s	124	Epoch: 3, Validation - Loss: 0.0676, WER: 0.8318
5701.1s	125	Epoch: 3, Training - Loss: 0.0201
6919.2s	126	Epoch: 4, Training - Loss: 0.0257
8329.5s	127	Epoch: 5, Validation - Loss: 0.0484, WER: 0.8472
8329.5s	128	Epoch: 5, Training - Loss: 0.0144
9547.5s	129	Epoch: 6, Training - Loss: 0.1127
10959.3s	130	Epoch: 7, Validation - Loss: 0.0422, WER: 0.3982
10969.4s	131	Epoch: 7, Training - Loss: 0.0053
12188.2s	132	Epoch: 8, Training - Loss: 0.0076
13600.0s	133	Epoch: 9, Validation - Loss: 0.0482, WER: 0.8158
13600.0s	134	Epoch: 9, Training - Loss: 0.0126
14819.0s	135	Epoch: 10, Training - Loss: 0.0152
16230.8s	136	Epoch: 11, Validation - Loss: 0.0544, WER: 0.6829
16230.8s	137	Epoch: 11, Training - Loss: 0.0114
17450.1s	138	Epoch: 12, Training - Loss: 0.0174
18862.4s	139	Epoch: 13, Validation - Loss: 0.0523, WER: 0.3225
18872.1s	140	Epoch: 13, Training - Loss: 0.0117
20091.5s	141	Epoch: 14, Training - Loss: 0.0075
21503.2s	142	Epoch: 15, Validation - Loss: 0.0567, WER: 0.5187
21503.2s	143	Epoch: 15, Training - Loss: 0.0137
22722.5s	144	Epoch: 16, Training - Loss: 0.0150
24134.2s	145	Epoch: 17, Validation - Loss: 0.0631, WER: 0.4559
24134.2s	146	Epoch: 17, Training - Loss: 0.0122
25352.9s	147	Epoch: 18, Training - Loss: 0.0120
26765.0s	148	Epoch: 19, Validation - Loss: 0.0523, WER: 0.7387
26765.0s	149	Epoch: 19, Training - Loss: 0.0060
27983.9s	150	Epoch: 20, Training - Loss: 0.0154
29395.1s	151	Epoch: 21, Validation - Loss: 0.0520, WER: 0.4749
29395.1s	152	Epoch: 21, Training - Loss: 0.0073
30612.5s	153	Epoch: 22, Training - Loss: 0.6361
32022.4s	154	Epoch: 23, Validation - Loss: 0.0396, WER: 0.2912
32033.0s	155	Epoch: 23, Training - Loss: 0.0029
33250.8s	156	Epoch: 24, Training - Loss: 0.0036
34662.0s	157	Epoch: 25, Validation - Loss: 0.0461, WER: 0.6043
34662.0s	158	Epoch: 25, Training - Loss: 0.0094
35880.1s	159	Epoch: 26, Training - Loss: 0.0082
37291.0s	160	Epoch: 27, Validation - Loss: 0.0428, WER: 0.7481
37291.0s	161	Epoch: 27, Training - Loss: 0.0051
38509.7s	162	Epoch: 28, Training - Loss: 0.0075
39920.4s	163	Epoch: 29, Validation - Loss: 0.0447, WER: 0.8736
39920.4s	164	Epoch: 29, Training - Loss: 0.0091
41138.9s	165	Epoch: 30, Training - Loss: 0.0088
42549.9s	166	Epoch: 31, Validation - Loss: 0.0530, WER: 0.4500
42549.9s	167	Epoch: 31, Training - Loss: 0.0072

遗留的第二个问题

首先是数据集的问题,因为可以看到随着时长的增加,看到模型训练过程在符合预期方向走,

  1. 最低数据量:起步来说,至少需要几个小时的音频数据来进行有效的fine-tuning。例如,从10小时开始,这是一个相对较小的数据集,可以用来调试模型和流程。

  2. 中等数据量:为了获得更佳的效果,推荐使用20至50小时的音频数据。这可以帮助模型更好地学习到特定语言的特性。

  3. 理想数据量:如果资源允许,使用超过100小时的音频数据将更有助于模型性能的提升。更多的数据可以显著提高模型的泛化能力和准确性。

当然对于大模型,数据质量越高越好,数据多样性越多越好。
进一步通过tensorboard图可以看到:
请添加图片描述
在运行12个小时之后可以看到WER比一开始的确实下降了不少,但是还没有达到20%左右,最低的WER在0.2912,但是这里可以观察到一个非常有趣的现象:

在观察到训练损失(Training Loss)持续下降而验证损失(Validation Loss)和字错误率 (WER, Word Error Rate) 没有持续改善或波动较大的情况时,这通常是过拟合的一个迹象。在这种情况下,模型在训练数据上表现得越来越好,但在未见过的验证数据上的表现却没有相对应的提升,甚至出现恶化。

由于callback回调中会保持前两个在验证集上WER最小的两个checkpoint,接下来有几个思路:

1.分析模型在验证集上的错误,看是否存在特定模式或类型的错误,这可能帮助诊断问题并指导进一步模型调整,因为我们是在whipser开源的基础上fine-tune的,所以不可能简化模型结构的本身,如减少层数或神经元数目以改善过拟合。

2.可以考虑正则化技术(L2正则化、Dropout)等以有助于缓解过拟现象,增强模型的泛化能力

3.调整训练策略,调整学习率或者使用不同的优化器,以评估模型在验证集上的表现;

4.增加更多数据,帮助模型学习到更多特征,从而提高模型泛化能力

观察验证集识别效果

由于输出缩略或视觉上的相似性,一些小的差异(如标点、空白或特殊字符)可能不容易觉察。这些微小的差异在计算WER时会被考虑进去,但在人眼检查时可能会被忽略。
请添加图片描述
可以看到基本上是一致的,但是个别词是有出入的,这是因为th30是人工读的,准确性比较高,并不意味着通话、会议、游戏场景的识别率也能如此。
这里再留几个尾巴给读者自己实现:

CER

1.中文是基于字符的语言,通常我们会使用CER(Character Error Rate,字符错误率)来进行更精确的评估。然而,如果你使用的是WER来评估中文语音识别的质量,这里有几点可能需要注意:

  • 在处理中文时,如果WER是基于词的,就必须先进行准确的分词。中文没有明显的词与词之间的分隔,因此分词的准确性对于WER的计算非常关键。错误的分词可能导致高WER,即使识别的字符完全正确。
  • 中文中的一些微小差异,如同音字错误、词序变化或者是语气词的使用,都可以在视觉上看起来非常相似,但在WER的计算中会被视为错误。
  • 你查看的样本可能并不代表整体数据集的平均表现。此外,中文语音识别可能特别擅长处理某些特定的语句或者在某些领域表现更好。

数据质量

除了数据量之外,数据的质量也非常重要:

  • 多样性:数据应该涵盖多种口音、语速和语调,以及不同的背景噪声环境,这将帮助模型在各种输入条件下都能保持稳定的表现。
  • 标注准确性:确保你的数据标注尽可能准确,错误的标注会直接影响模型学习的结果。

预处理和增强

  • 预处理:对音频进行预处理,如采样率转换(确保和模型训练时使用的采样率一致),音量标准化等。
  • 数据增强:可以考虑使用音频数据增强技术,如添加背景噪声、改变语速和音高等,以增加模型的鲁棒性。

资源和迭代

  • 计算资源:fine-tuning一个语音识别模型可能需要大量的计算资源,特别是当使用大量数据时。确保你有足够的GPU资源进行训练。
  • 迭代和评估:在fine-tuning过程中需要多次迭代和评估,以找到最优的模型参数和设置。

总结来说,训练的过程如炼丹,有些训练的经验是不能从小模型直接用到大模型上的。比如small和对于large-v3两种。
在模型相对较小的时候,learning rate的设置可以比较激进,但是对非常大模型的时候,较大的lr可能导致模型一开始loss就无法收敛,是发散的,但如果设置的lr比较小,那可能使得训练的时长成倍增加,怎么办呢?针对很大的模型,warm-up策略是很多时候会使用的。

load weight and inference

checkpoint_path = "whisper-checkpoint/checkpoint-epoch0023.ckpt"
state_dict = torch.load(checkpoint_path)
print(state_dict.keys())
state_dict = state_dict['state_dict']
/tmp/ipykernel_36/4099222220.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(checkpoint_path)
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecision'])

加载模型参数

cfg = Config()
whisper_model = WhisperModelModule(cfg)
whisper_model.load_state_dict(state_dict)
100%|███████████████████████████████████████| 461M/461M [00:05<00:00, 87.6MiB/s]
Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]
Downloading builder script:   0%|          | 0.00/5.60k [00:00<?, ?B/s]
<All keys matched successfully>

前向推理

woptions = whisper.DecodingOptions(language="zh", without_timestamps=True)
dataset = Th30Dataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())

refs = []
res = []
for b in tqdm(loader):
    input_ids = b["input_ids"].half().cuda()
    labels = b["labels"].long().cuda()
    with torch.no_grad():
        #audio_features = whisper_model.model.encoder(input_ids)
        #out = whisper_model.model.decoder(enc_input_ids, audio_features)
        results = whisper_model.model.decode(input_ids, woptions)
        for r in results:
            res.append(r.text)
        
        for l in labels:
            l[l == -100] = wtokenizer.eot
            ref = wtokenizer.decode(l)
            refs.append(ref)
     ```
     打印推理结果
     ```
     for k, v in zip(refs, res):
    print("-"*10)
    print(k)
    print(v)

部分输出结果

  ----------
<|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
节目单上赫然印着特邀中央乐团百余位演奏歌唱家微妙地避开了矛盾
----------
<|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>
放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着
----------
<|zh|><|transcribe|><|notimestamps|>旅与游的时间比往往旅长游短与游客的愿望相悖<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
旅与游的时间比往往旅长游短与游客的愿望相悖
----------
<|zh|><|transcribe|><|notimestamps|>中学毕业后他考入尼恩罗德商业学院毕业后曾服兵役并在一家贸易公司任职<|endoftext|>
中学毕业后他考入尼恩罗德商业学院毕业后曾服兵役并在一家贸易公司任职
----------
<|zh|><|transcribe|><|notimestamps|>该片导演为虞石束杨军主要演员有韩夫一李进慕爱秋周桂云金鑫冯云魁等<|endoftext|>
该片导演为虞石束杨军主要演员有韩夫一李进慕爱秋周桂云金鑫冯云魁等
----------
<|zh|><|transcribe|><|notimestamps|>何勰二话没说立即交了一千五百元的押金又为洛桑卓玛买来了全套新衣服和住院用品<|endoftext|><|endoftext|>
何勰二话没说立即交了一千五百元的押金又为洛桑卓玛买来了全套新衣服和住院用品
----------
<|zh|><|transcribe|><|notimestamps|>印加人所创造的文明与玛雅文明阿兹特克文明一起被誉为美洲印第安三大文明<|endoftext|>
印加人所创造的文明与玛雅文明阿兹特克文明一起被誉为美洲印第安三大文明
----------
<|zh|><|transcribe|><|notimestamps|>今天陪以萌找冯邦找得又累又饿但看见以萌那副着急样我一点也吃不下<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
今天陪以萌找冯邦找得又累又饿但看见以萌那副着急样我一点也吃不下
----------
<|zh|><|transcribe|><|notimestamps|>亲英的北爱尔兰新教派武装十二日晚发表声明威胁要报复爱尔兰共和军<|endoftext|>
亲英的北爱尔兰新教派武装十二日晚发表声明威胁要报复爱尔兰共和军
----------
<|zh|><|transcribe|><|notimestamps|>小仲不顾闲言碎语一天几趟往我家跑为我洗衣做饭熬药煎汤<|endoftext|><|endoftext|><|endoftext|>
小仲不顾闲言碎语一天几趟往我家跑为我洗衣做饭熬药煎汤
----------
<|zh|><|transcribe|><|notimestamps|>这位病人因贲门下胃底大弯静脉曲张伴血管瘤破裂胃内大量喷血<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
这位病人因贲门下胃底大弯静脉曲张伴血管瘤破裂胃内大量喷血
----------
<|zh|><|transcribe|><|notimestamps|>驻藏边防某部二连战士赵金站岗时隐隐约约听见营区外的不远处有哭泣声<|endoftext|>
驻藏边防某部二连战士赵金站岗时隐隐约约听见营区外的不远处有哭泣声
----------
<|zh|><|transcribe|><|notimestamps|>其种植的红富士苹果以色泽艳丽果质细脆汁多味美和极耐贮运而享誉海内外<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
其种植的红富士苹果以色泽艳丽果质细脆汁多味美和极耐贮存而享誉海内外
----------
<|zh|><|transcribe|><|notimestamps|>一九四一年十一月陕甘宁边区根据三三制原则举行参议会议员竞选一位名叫森健的学员被推为候选人<|endoftext|>
一九四一年十一月陕甘宁边区根据三三制原则举行参议会议员竞选一位名叫森健的学员被推为候选人
----------
<|zh|><|transcribe|><|notimestamps|>当船往下漂时白唇鹿扬起四蹄在岸边追随好像是送行一直跑了十几里太亲切了<|endoftext|>
当船往下漂时白唇鹿扬起四蹄在岸边追赶好像是送行一直跑了十几里太亲切了
----------
<|zh|><|transcribe|><|notimestamps|>有的单位按年人均月收入减去费用八百元后的余额为应纳税所得额<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
有的单位按年人均月收入减去费用八百元后的余额为应纳税所得额
----------
<|zh|><|transcribe|><|notimestamps|>女性腰部以上特别肥胖者易患乳腺癌腰围与臀围差别不大者患癌率比一般妇女高六倍<|endoftext|>
女性腰部以上特别肥胖者易患乳腺癌腰围与臀围差别不大者患癌率比一般妇女高六倍
----------
<|zh|><|transcribe|><|notimestamps|>日本队在男子团体赛中获银牌队员岩井哲贤在个人全能赛也夺得一枚银牌<|endoftext|><|endoftext|><|endoftext|>
日本队在男子团体赛中获银牌队员岩井哲贤在个人全能赛也夺得一枚银牌
----------
<|zh|><|transcribe|><|notimestamps|>如此举措源于杭州娃哈哈食品集团公司总经理宗庆后对市场特质的洞悉<|endoftext|><|endoftext|><|endoftext|>
如此举措源于杭州娃哈哈食品集团公司总经理宗庆后对市场特质的洞悉
  ```
接下来还有三个问题对于应用更需要细致考虑:
1.Whisper除了识别,还有直接翻译功能,在以前要先识别成中文,再汉译英等,这个好处是显而易见的,首先只要一个模型,节约部分人力、机器以及服务端GPU,业务场景上可以是会议的实时翻译、看英文视频实时翻译成中文,这会减少latency,用户体验也更好;
2.如何在实时的流式场景中使用?
3.kv-caching是个什么技术?12倍是如何做到的?这在工程部署商用价值非常大。

欢迎点赞、收藏、关注,以便及时收到下一篇推送。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2085836.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【netty系列-08】深入Netty组件底层原理和基本实现

Netty系列整体栏目 内容链接地址【一】深入理解网络通信基本原理和tcp/ip协议https://zhenghuisheng.blog.csdn.net/article/details/136359640【二】深入理解Socket本质和BIOhttps://zhenghuisheng.blog.csdn.net/article/details/136549478【三】深入理解NIO的基本原理和底层…

第一个go程序

package main import "fmt" func main(){fmt.Println("Hello World") }hello.go所在目录 运行go程序

美团代付多模版三合一源码 附教程

简介 美团代付多模版三合一源码 附教程 界面

这一轮医疗数字化,沃可趣以医疗专业人员交流成长为中心

沃可趣看见医疗行业人员需求痛点&#xff0c;量身打造数字服务平台&#xff0c;促进知识分享&#xff0c;赋能活动组织。 从电子病历的普及到远程医疗的兴起&#xff0c;从人工智能辅助诊断到大数据在医疗管理中的应用&#xff0c;科技进步正在大力推动医疗领域的发展。然而&a…

Ubuntu系统本地搭建WordPress网站并一键发布内网站点至公网实战

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

嵌入式技术文件、学习资料、在线工具、学习网站、技术论坛,非常全面的分享~~~

一、数据手册查询及下载网站 1、【ALLDATASHEET 自称是最大的在线电子元件数据的搜索引擎】ALLDATASHEETCN.COM - 电子元件和半导体及其他半导体的数据表搜索网站。电子元件和半导体, 集成电路, 二极管, 三端双向可控硅 和其他半导体的https://www.alldatasheetcn.com/ 2、【…

defineProps、defineEmits、 defineExpose的TS写法

小满视频 1 defineProps&#xff1a;父向子传递数据 作用&#xff1a;父组件向子组件传递数据 1.1 传递纯类型参数的方式来声明 父组件中的代码&#xff1a; 父组件App.vue <template><div><span>传递给子组件的响应式数据&#xff1a;</span>&l…

【循环顺序队的实现】

1.队列的逻辑结构 与 抽象数据类型定义 先进先出的线性表 在顺序队列中&#xff0c;我们使用头指针front指向队首元素&#xff1b;用尾指针rear指向队尾元素的下一个位置&#xff08;当然这里的指针是用下标模拟出来的&#xff09; 同时顺序队列中的元素当然是用数组来存储的 …

【系统架构设计】嵌入式系统设计(1)

【系统架构设计】嵌入式系统设计&#xff08;1&#xff09; 嵌入式系统概论嵌入式系统的组成硬件嵌入式处理器总线存储器I/O 设备与接口 软件 嵌入式开发平台与调试环境交叉平台开发环境交叉编译环境调试 嵌入式网络系统嵌入式数据库管理系统实时系统与嵌入式操作系统嵌入式系统…

【Qt笔记】QToolButton控件详解

目录 一、前言 二、QToolButton的基本特性 2.1 图标和文本 2.2 自动提升 2.3 下拉菜单 2.4 工具提示 2.5 弹出模式 三、高级功能 3.1 自定义大小与形状 3.2 检查框与单选按钮 3.3 动画效果 四、常用方法与信号槽 常用方法 信号槽 五、实际应用示例 说明 六、总…

ESP32 CYD 使用 LVGL 在屏幕上显示图像 | Random Nerd Tutorials

在本指南中&#xff0c;你将学习如何使用LVGL在ESP32 Cheap Yellow Display (CYD) 板上处理和加载图像。ESP32将使用Arduino IDE进行编程。 对ESP32 Cheap Yellow Display不熟悉&#xff1f; 从这里开始&#xff1a;开始使用ESP32 Cheap Yellow Display Board – CYD (ESP32-2…

线性代数 第三讲 线性相关无关 线性表示

线性代数 第三讲 线性相关无关 线性表示 文章目录 线性代数 第三讲 线性相关无关 线性表示1.向量运算1.线性相关与线性无关1.1 线性相关与线性无关基本概念 2.线性表示&#xff08;线性组合&#xff09;3.线性相关无关与线性表示的定理大总结3.1 向量β可由向量组线性表出的同义…

心觉:潜意识显化很简单,只是很多人想复杂了

很多人知道潜意识的力量很大&#xff0c;是意识力量的30000倍以上 也知道该怎么显化自己的潜意识 但是就是做不到 这就像很多肥胖的人知道运动可以减肥 知道减肥之后就可以穿漂亮的衣服 知道减肥之后自己有多帅多美 但是就是迈不开腿 根本原因是你的潜意识和意识上的认知不…

RenderMan v26.2更新内容!云渲染平台支持新版本

皮克斯的最新RenderMan v26.2版本带来了一系列激动人心的新特性和改进&#xff0c;进一步巩固了其在高端渲染领域的领导地位&#xff0c;为艺术家们提供了更丰富的创意工具和更流畅的工作流程。作为老牌的云渲染农场&#xff0c;瑞云依然支持新版本的使用。 RenderMan v26.2更新…

移动端视频编辑SDK解决方案,关键帧曲线塑造生动效果

美摄科技&#xff0c;作为移动视频编辑技术的领航者&#xff0c;携其革命性的移动端视频编辑SDK解决方案&#xff0c;正以前所未有的创新力&#xff0c;为视频创作者们开启了一扇通往无限创意的大门。 重塑视频编辑体验&#xff0c;让创意触手可及 美摄科技的移动端视频编辑S…

公网信息泄露监测(网盘、暗网、搜索引擎、文档平台)思路分享

一、背景 众测项目中白帽可能会提交一些信息泄露漏洞&#xff0c;同时甲方可会收到一些白帽提交的公网信息泄露文件漏洞&#xff0c;例如百度网盘被员工分享某些文件或者某些包含敏感信息的文件可以通过如谷歌、百度等搜索引擎通过特定语法搜索到。为了可以及时发现泄露的文件…

九泰智库 | 医械周刊- Vol.54

⚖️ 法规动态 国家药监局综合司发布医疗器械管理法草案征求意见稿 国家药监局综合司发布了《中华人民共和国医疗器械管理法&#xff08;草案征求意见稿&#xff09;》&#xff0c;公开征求意见&#xff0c;以加强医疗器械的管理并推动产业高质量发展。该草案共十一章190条&a…

深入解析财务报表:掌握重要财务指标的技巧

一、概述 财务报表中有大量信息&#xff0c;如果我们在分析时缺乏明确的方向或忽视了重点&#xff0c;就很容易在繁杂的数据中迷失方向。本文将深入探讨财务报表中的几个重要指标&#xff0c;帮助大家更有针对性地理解这些内容&#xff0c;包括如何分析资产负债率、解读净资产…

基于python的web框架 Flask 入门基础知识【1】

Flask是一个轻量级的可定制框架&#xff0c;使用Python语言编写&#xff0c;较其他同类型框架更为灵活、轻便、安全且容易上手。 目录 一、项目环境搭配以及安装运行 1.下载安装 2.最小的应用 3.运行应用 4.运行结果 4.1 外部可见的服务器 二、路由 三、http请求 3.1静…

无人机的核心技术!!!

无人机的核心技术涵盖了多个关键领域&#xff0c;这些技术共同支撑了无人机的稳定飞行、精准控制、高效数据传输以及多样化的应用功能。 1. 飞行控制技术 核心地位&#xff1a;飞行控制技术是无人机的核心关键技术之一&#xff0c;它确保了无人机在复杂飞行环境下的稳定性和安…