whisper 模型源码解读

news2024/10/5 19:19:54

在这里插入图片描述

whisper官方源码

whisper 模型官方代码:https://github.com/openai/whisper/blob/main/whisper/model.py ;注释如下

import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

# 从其他模块导入必要的函数
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function

@dataclass
class ModelDimensions:
    """
    该类用于存储模型的各项参数
    """
    n_mels: int  # Mel谱图的频带数量
    n_audio_ctx: int  # 音频上下文窗口大小
    n_audio_state: int  # 音频状态维度
    n_audio_head: int  # 音频注意力头数量
    n_audio_layer: int  # 音频层数量
    n_vocab: int  # 词汇表大小
    n_text_ctx: int  # 文本上下文窗口大小
    n_text_state: int  # 文本状态维度
    n_text_head: int  # 文本注意力头数量
    n_text_layer: int  # 文本层数量

class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        """
        重写 forward 方法,确保输入张量的类型在归一化前后保持一致
        """
        return super().forward(x.float()).type(x.dtype)

class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        """
        重写 forward 方法,确保权重和偏置与输入张量的类型一致
        """
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )

class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        """
        重写 _conv_forward 方法,确保卷积操作中的权重和偏置与输入张量的类型一致
        """
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )

def sinusoids(length, channels, max_timescale=10000):
    """
    生成用于位置嵌入的正弦曲线
    """
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        """
        初始化多头注意力层
        """
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """
        多头注意力的前向传播
        """
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # 如果没有缓存键和值,则正常计算
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # 如果有缓存,则使用缓存的键和值
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        """
        计算 QKV 注意力
        """
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()

class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        """
        初始化残差注意力块
        """
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """
        残差注意力块的前向传播
        """
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x

class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        """
        初始化音频编码器
        """
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        前向传播,处理音频输入

        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            音频的Mel谱图
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "音频形状不正确"
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x

class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        """
        初始化文本解码器
        """
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        前向传播,处理文本输入并结合音频特征

        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            文本的标

记序列
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            编码后的音频特征
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits

class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        """
        初始化 Whisper 模型
        """
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # 默认情况下,使用解码器层的后一半进行时间对齐;
        # 若要使用特定的注意力头,可以使用 `set_alignment_heads()` 方法。
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        """
        设置对齐的注意力头
        """
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        """
        编码音频特征
        """
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        """
        获取预测的logits
        """
        return self.decoder(tokens, audio_features)

    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        前向传播
        """
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        """
        获取模型所在的设备
        """
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        """
        判断模型是否支持多语言
        """
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        """
        获取模型支持的语言数量
        """
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        为键和值的投影模块安装缓存钩子

        返回
        -------
        cache : Dict[nn.Module, torch.Tensor]
            映射键/值投影模块到其缓存的字典对象
        hooks : List[RemovableHandle]
            用于停止调用钩子的 PyTorch RemovableHandle 对象列表
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # 第一次标记或交叉注意时保存原始值
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function  # 语言检测函数
    transcribe = transcribe_function  # 转录函数
    decode = decode_function  # 解码函数

语音识别自回归解码过程分析和举例说明

分析

语音识别自回归解码过程通常涉及以下步骤:

  1. 音频预处理:首先将输入的音频信号转换为Mel谱图。这一步骤在实际应用中通常由音频前端处理模块完成。

  2. 音频编码:将预处理后的Mel谱图输入到音频编码器中,生成音频特征表示。这些特征表示将作为后续文本解码器的输入。

  3. 文本解码:文本解码器通过自回归方式生成文本序列。具体来说,文本解码器在每个时间步上根据前一步生成的文本标记以及音频特征生成下一个文本标记。

  4. 语言检测和转录:在生成的文本序列基础上,可以进行语言检测,确认文本所使用的语言。此外,转录过程将生成的文本序列转换为最终的文本输出。

具体步骤

以下代码展示了上述过程的具体实现:

import torch

# 初始化模型参数
dims = ModelDimensions(
    n_mels=80,
    n_audio_ctx=1500,
    n_audio_state=512,
    n_audio_head=8,
    n_audio_layer=6,
    n_vocab=51865,
    n_text_ctx=448,
    n_text_state=512,
    n_text_head=8,
    n_text_layer=6,
)

# 创建模型实例
model = Whisper(dims)

# 假设我们有一个Mel谱图输入
mel_spectrogram = torch.randn(1, 80, 1500)  # (batch_size, n_mels, n_audio_ctx)

# 编码音频特征
audio_features = model.embed_audio(mel_spectrogram)

# 假设我们有一个初始的文本标记序列
initial_tokens = torch.tensor([[1, 2, 3]])  # (batch_size, seq_len)

# 自回归解码过程
for _ in range(10):  # 假设生成长度为10的序列
    logits = model.logits(initial_tokens, audio_features)
    next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
    initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

# 最终生成的文本标记序列
final_tokens = initial_tokens

# 打印生成的文本标记序列
print("Generated tokens:", final_tokens)

举例说明

假设我们有一段音频,其Mel谱图表示如下:

mel_spectrogram = torch.randn(1, 80, 1500)

我们希望通过自回归解码生成对应的文本表示。首先,我们将Mel谱图输入到音频编码器中,得到音频特征表示:

audio_features = model.embed_audio(mel_spectrogram)

然后,我们使用一个初始的文本标记序列(例如,序列开始标记)开始自回归解码过程:

initial_tokens = torch.tensor([[1]])  # 序列开始标记

在每个时间步,我们根据当前的文本标记序列和音频特征生成下一个文本标记:

logits = model.logits(initial_tokens, audio_features)
next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

这个过程重复若干次(例如10次)直到生成完整的文本序列:

for _ in range(10):
    logits = model.logits(initial_tokens, audio_features)
    next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
    initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

最终得到的文本标记序列为:

final_tokens = initial_tokens
print("Generated tokens:", final_tokens)

以上示例展示了从音频输入到文本输出的完整自回归解码过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1829669.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ffmpeg解封装rtsp并录制视频-(2)使用VLC模拟一个rtsp服务器并用ffmpeg解封装该rtsp流

VCL模拟服务器并打开播放该视频文件&#xff1a; - 准备好一个mp4文件&#xff0c;打开vlc软件 - 选择“媒体”》“流” - 添加一个mp4文件 - 点击下方按钮选择“串流” - 下一步目标选择rtsp 点击“添加” - 端口默认8554 - 路径设置 /test - 用…

XML Encoding = ‘GBK‘ after STRANS,中文乱码

最近帮同事处理了一个中信银行银企直连接口的一个问题,同事反馈,使用STRANS转换XML后,encoding始终是’utf-16’,就算指定了GBK也不行。尝试了很多办法始终不行,发到银行的数据中,中文始终是乱码。 Debug使用HTML视图看报文时也可以看到中文是乱码。 解决方案: 使用cl…

飞腾银河麒麟V10安装Todesk

下载安装包 下载地址 https://www.todesk.com/linux.html 安装 yum makecache yum install libappindicator-gtk3-devel.aarch64 rpm -ivh 下载的安装包文件后台启动 service todeskd start修改配置 编辑 /opt/todesk/config/config.ini 移除自动更新临时密码 passupda…

STM32CubeMX配置-RTC周期唤醒

一、简介 MCU为STM32G070&#xff0c;采用内部时钟32KHZ&#xff0c;配置为周期6s唤醒&#xff0c;调用回调函数&#xff0c;进行喂狗操作。 二、配置 初始时间、日期、周期唤醒时间配置。 开启周期唤醒中断 三、生成代码 调用回调函数&#xff0c;进行喂狗操作。 //RTC唤醒回…

STM32理论 —— μCOS-Ⅲ(2/2):时间管理、消息队列、信号量、任务内嵌信号量/队列、事件标志、软件定时器、内存管理

文章目录 9. 时间管理9.1 OSTimeDly()9.2 OSTimeDlyHMSM()9.3 OSTimeDlyResume()9.4 延时函数实验 10. 消息队列10.1 创建消息队列函数OSQCreate()10.2 发送消息到消息队列函数(写入队列)OSQPost()10.3 获取消息队列中的消息函数(读出队列)OSQPend()10.4 消息队列操作实验 11. …

绿色版DirectoryOpus功能强大且高度可定制的Windows文件管理器

Directory Opus&#xff08;通常简称为DOpus&#xff09;是一款功能强大且高度可定制的Windows文件管理器。它提供了许多超越Windows默认文件资源管理器&#xff08;Explorer&#xff09;的功能&#xff0c;使得文件和文件夹的管理变得更加高效和直观。以下是对Directory Opus的…

Hex-Rays IDA Pro V7安装教程 (交互式反汇编工具)

前言 DA Pro就像提供了一张二进制的地图&#xff0c;标注了系统函数以及分析人员注解的函数调用&#xff0c;同时展现出各级函数和代码块之间的调用关系。此外&#xff0c;IDA Pro的扩展性能很好&#xff0c;可以利用IDA Pro提供的API接口和IDC脚本来扩展应用&#xff0c;而且…

R语言 | 绘制带P值的差异柱状图

原文链接&#xff1a;R语言 | 绘制带P值的差异柱状图 本期教程 小杜的生信笔记&#xff0c;自2021年11月开始做的知识分享&#xff0c;主要内容是R语言绘图教程、转录组上游分析、转录组下游分析等内容。凡是在社群同学&#xff0c;可免费获得自2021年11月份至今全部教程&…

ElementPlus非表单组件ElUpload值更新后校验不消失问题

项目场景&#xff1a; el-form表单中有一个上传组件&#xff0c;有必填校验。 问题描述 先触发表单的必填校验(点击提交按钮)&#xff0c;然后再上传文件&#xff0c;必填校验的提示一直存在&#xff0c;如果再次点击提交&#xff0c;手动触发表单校验&#xff0c;必填校验消…

2. 机器学习概述

机器学习是对能通过经验自动改进的计算机算法的研究。 ---汤姆. 米切尔 1997 通俗来讲&#xff0c;机器学习就是让计算机从数据中进行自动学习&#xff0c;得到某种知识&#xff08;或规律&#xff09;。在早期的工程领域&#xff0c;机器学习也经常被称为模式识别&#xff08;…

Termius安装docker

安装Termius 直接上官网 新建主机 更新一下yum 更新完成 安装docker的包 设置阿里的yum 直接用命令安装 设置一下开机启动&#xff0c;可以查看docker的版本 拉去portainer的镜像发现报错 出现Get “https://registry-1.docker.io/v2/...“:net/http: TLS hand shake time o…

如何从索尼存储卡恢复数据?

Sony 存储卡广泛用于在数码相机、数码摄像机等中存储照片和视频。如果您从 Sony 存储卡中删除重要数据而未备份&#xff0c;您仍然可以找回丢失的数据。实际上&#xff0c;已删除的视频/照片或文档不会永远丢失&#xff0c;它们仍存储在 Sony 存储卡上&#xff0c;可以通过数据…

功能强大的API函数FindFirstFile使用介绍(附源码)

在处理文件的相关代码中,会频繁使用到Windows系统API函数FindFirstFile,这个函数功能很强大,很多功能都不开它。本文就根据我们在项目中使用该函数的情况,来大概地梳理一下使用FindFirstFile都可以实现哪些常用的功能。 1、FindFirstFile函数声明与WIN32_FIND_DATA结构体 我…

英伟达发布Nemotron-4 340B通用模型:专为生成合成数据设计的突破性AI

引言 2023年6月14日&#xff0c;英伟达发布了Nemotron-4 340B通用模型&#xff0c;专为生成训练大语言模型的合成数据而设计。这一模型可能彻底改变训练大模型时合成数据的生成方式&#xff0c;标志着AI行业的一个重要里程碑。本文将详细介绍Nemotron-4 340B的各个方面&#x…

java高级——Arrays工具类(包含核心的归并和二分排序以及多个底层知识点)

java高级——Arrays工具类 前情提要文章介绍提前了解的知识点1 二分查找思想 Arrays常用方法介绍&#xff08;8大类&#xff09;1. 创建数组1.1 copyOf&#xff08;&#xff09;1.2 copyOfRange&#xff08;&#xff09;1.3 fill&#xff08;&#xff09; 2. 数组转集合&#x…

Win11安装WSA 安卓系统,然后再电脑安装APK文件

参考文章&#xff1a; https://blog.csdn.net/m0_56076343/article/details/122334759 https://blog.csdn.net/u012514495/article/details/120885242 在微软的网站下载 打开&#xff1a;https://store.rg-adguard.net/ &#xff0c;如下图&#xff1a; 在 1 的那个地方&am…

Arthas线上环境问题排查定位工具

一、Arthas简介 Arthas是alibaba推出的一款JVM性能诊断调优的工具&#xff0c;也可以称之为是线上监控诊断产品&#xff0c;通过全局的视角可以实时的查看应用load、内存、GC、线程的状态信息&#xff0c;并且还可以在不修改应用代码的前提下&#xff0c;对业务问题进行诊断&a…

yolo实现大人 小孩 老年人的识别

通过构建人脸检测数据集&#xff0c;实现检测人脸模型的训练 通过构建小孩人脸 大人人脸 老年人人脸的分类数据集&#xff0c;训练分类模型 通过级联人脸检测模型与分类模型&#xff0c;实现图片 视频 摄像头中的人脸检测➕年龄属性判断 python开发语言 pytorch框架 yolo算…

电子设计教程基础篇(电容)

文章目录 前言一、电容原理1.原理2.公式 二、电容种类1.结构1、固定电容2、可变电容3、微调电容 2.介质材料1、气体介质电容1、空气电容2、真空电容3、充气式电容 2、固体介质电容1、无机1、云母电容2、陶瓷电容1、瓷片电容2、独石电容 3、玻璃釉电容&#xff08;CI&#xff09…

go的netpoll学习

go的运行时调度框架简介 Go的运行时&#xff08;runtime&#xff09;中&#xff0c;由调度器管理&#xff1a;goroutine&#xff08;G&#xff09;、操作系统线程&#xff08;M&#xff09;和逻辑处理器&#xff08;P&#xff09;之间的关系 以实现高效的并发执行 当一个gorout…