Whisper简介
Whisper是OpenAI于2022年9月开源的一个多语种识别模型,目前支持99种语言,是目前性能最好的开源多语种识别ASR大模型,第一版版使用了68万小时标注好的语料预训练模型,而large-v3的标注数据超过了500万小时,其paper中并没透露使用语料的详细来源,估计是爬了一些版权数据,在Huggingface上提到模型有很强的泛化能力,能够在未经特定训练的情况下处理新的语言或任务,同时可以使用fine-tune的方式提高特定语言的识别性能。
开源的Whisper情况如下:
Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed | Layers | Width | Heads |
---|---|---|---|---|---|---|---|---|
tiny | 39 M | tiny.en | tiny | ~1 GB | ~32x | 4 | 384 | 6 |
base | 74 M | base.en | base | ~1 GB | ~16x | 6 | 512 | 8 |
small | 244 M | small.en | small | ~2 GB | ~6x | 12 | 768 | 12 |
medium | 769 M | medium.en | medium | ~5 GB | ~2x | 24 | 1024 | 16 |
large | 1550 M | N/A | large (2022.09) | ~10 GB | 1x | 32 | 1280 | 20 |
large-v2 | 1550 M | N/A | large-v2(2022.12) | ~10 GB | 1x | 32 | 1280 | 20 |
large-v3 | 1550 M | N/A | Large-v3(2023.11) | ~10 GB | 1x | 32 | 1280 | 20 |
只有large-v3是23年底开源的模型,在Encoder和Decoder上同large和large-v2是一样的,以下两点有差异:
-
梅尔谱从80个频点增加到了128个频点
-
为粤语增加了token方法
large-v3在100万小时的弱标注,以及基于large-v2模型的400万标注的结果,总共500万小时音频数据训练了两个2epoch。和large-v2模型相比,large-v3在多种语言上显示性能有提升,大概从10%~20%不等的WER的提升。
但这并不意味着在任何时候v3的模型一定是比v2好的,比如:
I am currently working on a project where my objective is to transcribe audio calls from various languages into English. Until now, our application has been utilizing the large-v2 model, and we are considering migrating to the large-v3 model. However, upon testing both the large-v2 and large-v3 models on a set of 20 audio files, I observed that the large-v2 model generally produces better output compared to the large-v3 model, except in two instances where the large-v3 model performed better. Large-v2 transcripts are better by around 20 - 30%.
I am trying to understand if there’s something I might be overlooking. The large-v3 model is purported to be an improvement, yet in my experience, it seems to be the opposite.
For reference, I am using the code provided for the large-v3 model, which can be found here: huggingface[.]co/openai/whisper-large-v3.
这意味想要针对业务场景使用,可能需要对各个场景和模型进行评估,并且在必要的时候进行fine-tune以及增加音频的前处理。本篇就在介绍Whisper的基础上,介绍基于Huggingface Transformer的fine-tune过程。
Whisper模型结构
Whisper是基于Transformer的编码器-解码器模型,也被称为序列到序列模型。它将音频频谱特征映射为文本token,然后再转为文本。
- 首先,原始音频输入通过特征提取器的作用转换为log-Mel频谱图,v2及以下提取80个梅尔频点,v3提取128个梅尔频点;
- 然后,Transformer编码器对频谱进行编码,形成一系列编码器隐藏状态(hidden states)。
- 最后,解码器通过自回归方式预测文本标记,条件是前面的标记和编码器隐藏状态。
在序列到序列模型中,编码器将音频输入转换为一组隐藏状态表示,提取音频重要特征。解码器扮演语言模型的角色,处理隐藏状态表示并生成相应的文本转录。在系统架构内部引入语言模型称为深度融合。这和早一代的CTC + 𝑛n-gram有巨大的性能提升、人力节省、系统复杂度降低。
通过深度融合,整个系统可以端对端地使用相同的训练数据和损失函数进行训练,从而提供更大的灵活性和通常更出色的性能。
图一:* Whisper 模型. 模型结构是标准的Transformer架构的encoder-decoder 模型. encoder输入是log-Mel 谱. decoder 的输入是文本token,其通过 cross-attention 和encoder输出的hidden states相关联。
自回归decoder预测文本 tokens文本序列的格式如下:
之所以称为自回归,是因为Decoder的输入又给到Decoder自身的输入端了。
Huggingface 工具fine-tune
这里将使用几个流行的Python包来对Whisper模型进行微调。使用Huggingface提供的datasets[audio]下载和准备训练数据,同时使用transformers和加速器来加载和fine-tune Whisper-small模型。此外还需要soundfile包来预处理音频文件,评估和jiwer以评估fine-tune后模型的性能,以及tensorboard来记录我们的度量标准。最后,使用gradio来构建我们微调模型的演示。
首先是GPU的选择,对于V100/16GB GPU
首先是GPU的选择,对于V100/16GB GPU
Model | Train Batch Size | Gradient Acc Steps | Eval Batch size |
---|---|---|---|
small | 16 | 2 | 8 |
medium | 2 | 16 | 1 |
使用NVIDIA V100 GPU进行模型训练时,建议使用“small”模型:由于硬件限制和性能考虑,建议在V100 GPU上运行较小的模型配置,即“small”模型。因为colab/kaggle的免费GPU时长受限,而不是使用“medium”模型,免费GPU训练时间没有那么长且效率较低。这种做法可以更高效地利用硬件资源,同时达到合理的训练效果。
!pip install --upgrade pip
!pip install --upgrade datasets[audio] transformers==4.38.2 accelerate evaluate jiwer tensorboard gradio
Requirement already satisfied: pip in /opt/conda/lib/python3.10/site-packages (23.3.2)
Collecting pip
Downloading pip-24.2-py3-none-any.whl.metadata (3.6 kB)
Downloading pip-24.2-py3-none-any.whl (1.8 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 28.0 MB/s eta 0:00:00a 0:00:01
Installing collected packages: pip
Attempting uninstall: pip
Found existing installation: pip 23.3.2
Uninstalling pip-23.3.2:
Successfully uninstalled pip-23.3.2
Successfully installed pip-24.2
Collecting transformers==4.38.2
Downloading transformers-4.38.2-py3-none-any.whl.metadata (130 kB)
Requirement already satisfied: accelerate in /opt/conda/lib/python3.10/site-packages (0.32.1)
Collecting accelerate
Downloading accelerate-0.33.0-py3-none-any.whl.metadata (18 kB)
Collecting evaluate
Downloading evaluate-0.4.2-py3-none-any.whl.metadata (9.3 kB)
Collecting jiwer
Downloading jiwer-3.0.4-py3-none-any.whl.metadata (2.6 kB)
Requirement already satisfied: tensorboard in /opt/conda/lib/python3.10/site-packages (2.15.1)
Collecting tensorboard
Downloading tensorboard-2.17.1-py3-none-any.whl.metadata (1.6 kB)
Collecting gradio
Downloading gradio-4.41.0-py3-none-any.whl.metadata (15 kB)
...
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.4.1 which is incompatible.
tensorflow 2.15.0 requires tensorboard<2.16,>=2.15, but you have tensorboard 2.17.1 which is incompatible.
ydata-profiling 4.6.4 requires numpy<1.26,>=1.16.0, but you have numpy 1.26.4 which is incompatible.
Successfully installed accelerate-0.33.0 datasets-2.21.0 evaluate-0.4.2 ffmpy-0.4.0 gradio-4.41.0 gradio-client-1.3.0 jiwer-3.0.4 python-multipart-0.0.9 rapidfuzz-3.9.6 ruff-0.6.1 semantic-version-2.10.0 soxr-0.4.0 tensorboard-2.17.1 tokenizers-0.15.2 tomlkit-0.12.0 transformers-4.38.2 urllib3-2.1.0
Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output settings...
使用的small需要的一些模型文件和配置信息在Huggingface上。
登录Huggingface账号
from huggingface_hub import notebook_login
notebook_login()
数据集
采用https://commonvoice.mozilla.org/zh-CN/datasets
因为这个数据集比较小,训练时间会比较快一些
https://huggingface.co/datasets/mozilla-foundation/common_voice_7_0/viewer
from datasets import load_dataset, DatasetDict
common_voice = DatasetDict()
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_7_0", "zh-CN", split="train+validation",trust_remote_code=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_7_0", "zh-CN", split="test", trust_remote_code=True)
print(common_voice)
Downloading builder script: 0%| | 0.00/11.5k [00:00<?, ?B/s]
Downloading readme: 0%| | 0.00/11.3k [00:00<?, ?B/s]
Downloading extra modules: 0%| | 0.00/3.29k [00:00<?, ?B/s]
Downloading extra modules: 0%| | 0.00/49.7k [00:00<?, ?B/s]
Downloading data: 0%| | 0.00/2.37G [00:00<?, ?B/s]
Generating train split: 0 examples [00:00, ? examples/s]
Generating test split: 0 examples [00:00, ? examples/s]
Generating validation split: 0 examples [00:00, ? examples/s]
Generating other split: 0 examples [00:00, ? examples/s]
Generating invalidated split: 0 examples [00:00, ? examples/s]
DatasetDict({
train: Dataset({
features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
num_rows: 30617
})
test: Dataset({
features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
num_rows: 9338
})
})
去除无用的字段
common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
print(common_voice)
DatasetDict({
train: Dataset({
features: ['audio', 'sentence'],
num_rows: 30617
})
test: Dataset({
features: ['audio', 'sentence'],
num_rows: 9338
})
})
特征提取和分词
Whisper的ASR可以分为三个主要部分:
-
一个特征提取器WhisperFeatureExtractor,用于预处理原始音频输入
-
执行序列到序列映射的模型
-
一个分词器WhisperTokenizer ,用于将模型输出后处理为文本格式
Whisper特征提取器接收16kHz采样率的音频,对于语音而言16kHz已经够用,而音乐类模型往往要到44.1kHz且双声道,Whisper 特征提取分两个步骤,首先对音频样本进行填充/截断,使所有样本的输入长度为30秒。少于30秒的样本将填充到30秒,在序列末尾附加零(音频信号中的零对应于无信号或静音)来实现,长于30秒的样本将被截断为30秒。
由于批处理中的所有元素都被填充/截断到输入空间中的最大长度,因此在将音频输入转发到 Whisper 模型时,不需要注意力掩码(attention mask )。这在 Whisper 上是独特的 - 对于大多数音频模型,会使用attention mask标记填充的部分,因此在自注意机制中(self-attention mechanism)应该忽略哪些位置。Whisper 被训练为在没有attention mask的情况下运行,并直接从语音信号中推断哪些输入需要被忽略。
Whisper 特征提取器的第二步是将填充的音频数组转换为对数梅尔频谱图。这些频谱图是信号频率的视觉表示,类似于傅里叶变换。图2显示了一个示例频谱图。沿 y 轴是梅尔通道,这些通道对应特定频率区段。沿 x 轴是时间。每个像素的颜色对应于给定时间该频率区段的对数强度。对数梅尔频谱图是 Whisper 模型预期输入的形式。
特征提取和分词
Whisper的ASR可以分为三个主要部分:
-
一个特征提取器WhisperFeatureExtractor,用于预处理原始音频输入
-
执行序列到序列映射的模型
-
一个分词器WhisperTokenizer ,用于将模型输出后处理为文本格式
Whisper特征提取器接收16kHz采样率的音频,对于语音而言16kHz已经够用,而音乐类模型往往要到44.1kHz且双声道,Whisper 特征提取分两个步骤,首先对音频样本进行填充/截断,使所有样本的输入长度为30秒。少于30秒的样本将填充到30秒,在序列末尾附加零(音频信号中的零对应于无信号或静音)来实现,长于30秒的样本将被截断为30秒。
由于批处理中的所有元素都被填充/截断到输入空间中的最大长度,因此在将音频输入转发到 Whisper 模型时,不需要注意力掩码(attention mask )。这在 Whisper 上是独特的 - 对于大多数音频模型,会使用attention mask标记填充的部分,因此在自注意机制中(self-attention mechanism)应该忽略哪些位置。Whisper 被训练为在没有attention mask的情况下运行,并直接从语音信号中推断哪些输入需要被忽略。
Whisper 特征提取器的第二步是将填充的音频数组转换为对数梅尔频谱图。这些频谱图是信号频率的视觉表示,类似于傅里叶变换。图2显示了一个示例频谱图。沿 y 轴是梅尔通道,这些通道对应特定频率区段。沿 x 轴是时间。每个像素的颜色对应于给定时间该频率区段的对数强度。对数梅尔频谱图是 Whisper 模型预期输入的形式。
from transformers import WhisperFeatureExtractor
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
preprocessor_config.json: 0%| | 0.00/185k [00:00<?, ?B/s]
WhisperTokenizer
Whisper 模型输出是token(其实就是整数),这些token对应于词汇表中的索引。tokenizer将一系列token映射到实际文本字符串(例如 [1169, 3797, 3332] -> “the cat sat”).
Whisper tokenizer是在96种(v3支持99种)语言上预训练的,tokenizer的方法对Unicode编码格式的文本采用的是byte-pair 进行token化,对于中文简体,可以不加改动的使用。
使用的small需要的一些模型文件和配置信息在Huggingface上。
WhisperTokenizer
Whisper 模型输出是token(其实就是整数),这些token对应于词汇表中的索引。tokenizer将一系列token映射到实际文本字符串(例如 [1169, 3797, 3332] -> “the cat sat”).
Whisper tokenizer是在96种语言上预训练的,tokenizer的方法对Unicode编码格式的文本采用的是byte-pair 进行token化,对于中文简体,可以不加改动的使用。
from transformers import WhisperTokenizer
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="mandarin", task="transcribe")
tokenizer_config.json: 0%| | 0.00/283k [00:00<?, ?B/s]
vocab.json: 0%| | 0.00/836k [00:00<?, ?B/s]
tokenizer.json: 0%| | 0.00/2.48M [00:00<?, ?B/s]
merges.txt: 0%| | 0.00/494k [00:00<?, ?B/s]
normalizer.json: 0%| | 0.00/52.7k [00:00<?, ?B/s]
added_tokens.json: 0%| | 0.00/34.6k [00:00<?, ?B/s]
special_tokens_map.json: 0%| | 0.00/2.19k [00:00<?, ?B/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
可以验证一下tokenizer是否是正确的:
input_str = common_voice["train"][0]["sentence"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)
print(f"Input: {input_str}")
print(f"Decoded w/ special: {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal: {input_str == decoded_str}")
2024-08-20 01:05:15.915835: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-20 01:05:15.916000: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-20 01:05:16.054376: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Input: 汉元鼎六年,武帝平定南越国,南越之地重新划郡,番禺仍为南海郡治。
Decoded w/ special: <|startoftranscript|><|zh|><|transcribe|><|notimestamps|>汉元鼎六年,武帝平定南越国,南越之地重新划郡,番禺仍为南海郡治。<|endoftext|>
Decoded w/out special: 汉元鼎六年,武帝平定南越国,南越之地重新划郡,番禺仍为南海郡治。
Are equal: True
WhisperProcessor
from transformers import WhisperProcessor
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="mandarin", task="transcribe")
首先可以查看一下数据集信息:
print(common_voice["train"][0])
{'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/b7cbd58d9ce8dac7731e276330d736a60c07273e298acfd44d96df4a02156a16/cv-corpus-7.0-2021-07-21/zh-CN/clips/common_voice_zh-CN_18531536.mp3', 'array': array([ 0.00000000e+00, 9.86623760e-13, 1.28757139e-15, ...,
2.35939888e-06, -9.32492367e-06, -6.35876040e-06]), 'sampling_rate': 48000}, 'sentence': '汉元鼎六年,武帝平定南越国,南越之地重新划郡,番禺仍为南海郡治。'}
可以看到数据的关系,采样率是48kHz的,因为Whisper需要16kHz的输入,所以这里先进行重采样.
from datasets import Audio
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
def prepare_dataset(batch):
# load and resample audio data from 48 to 16kHz
audio = batch["audio"]
# compute log-Mel input features from input audio array
batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
# encode target text to label ids
batch["labels"] = tokenizer(batch["sentence"]).input_ids
return batch
我们可以使用数据集的 .map 方法将数据准备函数应用于所有训练示例:
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=4)
Map (num_proc=4): 0%| | 0/30617 [00:00<?, ? examples/s]
Map (num_proc=4): 0%| | 0/9338 [00:00<?, ? examples/s]
训练和评估
将用Trainer实现训练的pipeline,具体步骤如下:
- 加载一个预训练的checkpoint,
- 定义一个data collator,data collator接收前面预处理的数据并转成Pytorch后面训练模型时的张量格式
- 评估标准,使用 word error rate (WER)
- 定义训练的参数
加载一个预训练checkpoint
检查挂载路径是否正确
import os
checkpoint_path = '/kaggle/working/whisper-small-zh/checkpoint-1000'
if os.path.exists(checkpoint_path):
print(f"Directory exists: {checkpoint_path}")
for file in os.listdir(checkpoint_path):
print(file)
else:
print(f"Directory does not exist: {checkpoint_path}")
Directory exists: /kaggle/working/whisper-small-zh/checkpoint-1000
scheduler.pt
generation_config.json
training_args.bin
config.json
preprocessor_config.json
trainer_state.json
model.safetensors
optimizer.pt
rng_state.pth
如果没有checkpoint,则可以从Huggingface上下载
from transformers import WhisperForConditionalGeneration
#从Huggingface上加载模型
#model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
#从checkpoint加载模型
model = WhisperForConditionalGeneration.from_pretrained("/kaggle/working/whisper-small-zh/checkpoint-1000")
在推理时间,Whisper 模型会自动检测来源音频的语言,并在该语言的token id。在已知源音频语言的情况下,例如多语言微调,明确设置语言是有益的。这样做可以避免当预测的语言不正确时,导致生成过程中预测文本偏离真实语言。
model.generation_config.language = "mandarin"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None
data collator
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any
decoder_start_token_id: int
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need different padding methods
# first treat the audio inputs by simply returning torch tensors
input_features = [{"input_features": feature["input_features"]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# get the tokenized label sequences
label_features = [{"input_ids": feature["labels"]} for feature in features]
# pad the labels to max length
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
# if bos token is appended in previous tokenization step,
# cut bos token here as it's append later anyways
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
labels = labels[:, 1:]
batch["labels"] = labels
return batch
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
processor=processor,
decoder_start_token_id=model.config.decoder_start_token_id,
)
评估标准
定义评估集中的评估标准,使用WER标准。
import evaluate
metric = evaluate.load("wer")
Downloading builder script: 0%| | 0.00/4.49k [00:00<?, ?B/s]
然后定义compute_metrics()函数接收模型的预测输出并返回WER,首先使用-100替换pad_token_id(在data collator时为了计算loss时忽略padded tokens的反向处理)。然后将id解码成字符串。
def compute_metrics(pred):
pred_ids = pred.predictions
label_ids = pred.label_ids
# replace -100 with the pad_token_id
label_ids[label_ids == -100] = tokenizer.pad_token_id
# we do not want to group tokens when computing the metrics
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
定义参数
- output_dir:模型权重存储的位置
- generation_max_length:在评估时,自回归生成token的最大值
- save_steps:在训练时,中间的checkpoint将在save_step时被保存
- eval_steps:在训练时,每eval_steps到达时会评估模型性能
- report_to:训练的log保持的位置
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
output_dir="./whisper-small-zh", # change to a repo name of your choice
per_device_train_batch_size=16,
gradient_accumulation_steps=2, # increase by 2x for every 2x decrease in batch size
learning_rate=1e-5,
warmup_steps=500,
max_steps=2000,
gradient_checkpointing=True,
fp16=True,
evaluation_strategy="steps",
per_device_eval_batch_size=8,
predict_with_generate=True,
generation_max_length=225,
save_steps=1000,
eval_steps=1000,
logging_steps=100,
report_to=["tensorboard"],
load_best_model_at_end=True,
metric_for_best_model="wer",
greater_is_better=False,
push_to_hub=False,
)
将模型、数据集以及评估封装成一个trainer:
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=common_voice["train"],
eval_dataset=common_voice["test"],
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor.feature_extractor,
)
训练
trainer.train()
## training results
Training Loss Epoch Step Validation Loss Wer
0.5179 2.02 1000 0.3333 72.9831
0.1273 4.04 2000 0.3562 73.9621
0.0163 6.06 3000 0.3790 73.9708
0.004 8.07 4000 0.3946 72.3626
0.025 11.0 5000 0.4019 72.6772
save model
from transformers import pipeline, AutoModel, AutoTokenizer
pipe = pipeline(task='summarization', # replace with whatever task you have
model=model,
tokenizer=tokenizer)
pipe.save_pretrained("my_local_path")
离线识别
from transformers import pipeline
import gradio as gr
pipe = pipeline(model="my_local_path") # change to
def transcribe(audio):
text = pipe(audio)["text"]
return text
iface = gr.Interface(
fn=transcribe,
inputs=gr.Audio(source="microphone", type="filepath"),
outputs="text",
title="Whisper Small zh",
description="Realtime demo for zh speech recognition using a fine-tuned Whisper small model.",
)
iface.launch()
一般来说对于商用的ASR模型,Validation Loss Wer要在25以下,好的可以到10+%,那这就衍生出几个问题来:
1.如果不是Huggingface上可以下载的数据该怎么办?
2.上面的代码是可以训练了,但是训练的时候loss真的会和我们预期一致吗?比如如下怎么办?
3.针对于影视、短视频字幕可以使用Whisper,那么对于实时的视频会议场景,Whisper怎么流式处理?
4.可以在4090消费级显卡上,将推理速度提升12倍的kv-caching技术是什么,Huggingface上WhisperSpeech/WhisperSpeech对推理的改进是哪些?
我们将在下一篇文章揭露1和2,欢迎关注、点赞以便及时收到推送。