今天,以色列人工智能初创公司 aiOla 宣布推出一种新的开源语音识别模型,其速度比 OpenAI 著名的 Whisper 快 50%。
该模型被正式命名为 Whisper-Medusa,它建立在 Whisper 的基础上,但使用了一种新颖的 "多头关注 "架构,一次预测的标记数量远远超过 OpenAI 的产品。该模型的代码和权重已根据 MIT 许可在 Hugging Face 上发布,允许研究和商业使用。
"aiOla的研究副总裁吉尔-赫兹(Gill Hetz)告诉VentureBeat说:"通过开源发布我们的解决方案,我们鼓励社区内的进一步创新与合作,随着开发人员和研究人员对我们的工作做出贡献并在此基础上进行改进,我们的速度会得到更大的提高和完善。
这项工作可以为复合人工智能系统铺平道路,该系统几乎可以实时理解和回答用户提出的任何问题。
aiOla Whisper-Medusa 的独特之处是什么?
即使在基础模型能够产生多样化内容的时代,高级语音识别仍然具有很强的现实意义。这项技术不仅正在推动医疗保健和金融科技等领域的关键功能–帮助完成转录等任务,而且还在为功能强大的多模态人工智能系统提供动力。去年,该领域的领导者 OpenAI 通过开发自己的 Whisper 模型开始了这一征程。它将用户音频转换成文本,让 LLM 处理查询并提供答案,再将答案转换回语音。
由于 Whisper 能够几乎实时地处理不同语言和口音的复杂语音,它已成为语音识别领域的黄金标准,每月下载量超过 500 万次,并为数以万计的应用程序提供支持。
但是,如果一个模型能比 Whisper 更快地识别和转录语音呢?这就是aiOla新推出的Whisper-Medusa产品所要实现的目标–为语音到文本的无缝转换铺平道路。
为了开发 Whisper-Medusa,该公司修改了 Whisper 的架构,增加了多头注意力机制–众所周知,该机制允许模型通过并行使用多个 “注意力头”,共同关注来自不同位置的不同表示子空间的信息。 结构的改变使模型能够每次预测十个词组,而不是标准的一次预测一个词组,最终使语音预测速度和生成运行时间提高了 50%。
更重要的是,由于 Whisper-Medusa 的主干系统建立在 Whisper 的基础上,因此速度的提高并不会以性能的降低为代价。这款新产品转录文本的准确度与原来的 Whisper 不相上下。Hetz 指出,他们是业内首家成功将该方法应用于 ASR 模型并向公众开放以进一步研究和开发的公司。
"提高 LLM 的速度和延迟比自动语音识别系统要容易得多。由于处理连续音频信号和处理噪音或口音的复杂性,编码器和解码器架构面临着独特的挑战。他说:"我们通过采用新颖的多头注意力方法来应对这些挑战,从而使模型的预测速度提高了近一倍,同时保持了 Whisper 的高准确度。
如何训练语音识别模型?
在训练 Whisper-Medusa 时,aiOla 采用了一种称为弱监督的机器学习方法。作为其中的一部分,它冻结了 Whisper 的主要组件,并使用模型生成的音频转录作为标签来训练额外的标记预测模块。
赫兹告诉 VentureBeat,他们最初使用的是 10 头模型,但很快就会扩展到更大的 20 头版本,能够一次预测 20 个标记,从而在不降低准确性的情况下加快识别和转录速度。
"我们选择对模型进行训练,以便每次预测 10 个词组,从而在保持准确性的同时大幅提高了速度,但同样的方法也可用于在每一步中预测任意数量的词组。研究副总裁解释说:"由于 Whisper 模型的解码器是一次性处理整个语音音频,而不是逐段处理,因此我们的方法减少了多次处理数据的需要,有效地加快了速度。
当被问及是否有公司可以提前使用 Whisper-Medusa 时,Hetz 没有多说。不过,他也指出,他们已经在真实的企业数据使用案例中测试了这一新颖的模型,以确保其在真实场景中的准确表现。最终,他相信识别和转录速度的提高将加快语音应用的周转时间,并为提供实时响应铺平道路。想象一下,Alexa 能在几秒钟内识别您的命令并返回预期的答案。
"任何涉及实时语音到文本功能的解决方案,如对话语音应用中的解决方案,都将使业界受益匪浅。个人和公司可以提高生产率,降低运营成本,并更及时地传送内容,"Hetz 补充说。
代码
Github: https://github.com/aiola-lab/whisper-medusa
import torch
import torchaudio
from whisper_medusa import WhisperMedusaModel
from transformers import WhisperProcessor
model_name = "aiola/whisper-medusa-v1"
model = WhisperMedusaModel.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)
path_to_audio = "path/to/audio.wav"
SAMPLING_RATE = 16000
language = "en"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_speech, sr = torchaudio.load(path_to_audio)
if sr != SAMPLING_RATE:
input_speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(input_speech)
input_features = processor(input_speech.squeeze(), return_tensors="pt", sampling_rate=SAMPLING_RATE).input_features
input_features = input_features.to(device)
model = model.to(device)
model_output = model.generate(
input_features,
language=language,
)
predict_ids = model_output[0]
pred = processor.decode(predict_ids, skip_special_tokens=True)
print(pred)
感谢大家花时间阅读我的文章,你们的支持是我不断前进的动力。期望未来能为大家带来更多有价值的内容,请多多关注我的动态!