大模型推理——MLA实现方案

news2025/2/11 2:01:50

1.整体流程

先上一张图来整体理解下MLA的计算过程

2.实现代码

import math
import torch
import torch.nn as nn


# rms归一化
class RMSNorm(nn.Module):
    """

    """
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        hidden_states = hidden_states.float()
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.float()


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed


# 旋转位置编码
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=1024):
        super(RotaryEmbedding, self).__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).float().unsqueeze(1)
        freqs = t @ inv_freq.unsqueeze(0)
        freqs = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("cos_cached", freqs.cos())
        self.register_buffer("sin_cached", freqs.sin())

    def forward(self, q, k):
        cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)
        sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)
        return apply_rotate_pos_emb(q, k, cos, sin)


class MLA(nn.Module):
    def __init__(self,
                 dim,
                 n_heads,
                 q_lora_rank,
                 kv_lora_rank,
                 qk_nope_head_dim,
                 qk_rope_head_dim,
                 v_head_dim,
                 max_seq_len,
                 max_batch_size,
                 mode):
        super().__init__()
        self.dim = dim  # 隐藏层维度
        self.n_heads = n_heads  # 总头数
        self.q_lora_rank = q_lora_rank  # q低秩压缩到的维度
        self.kv_lora_rank = kv_lora_rank  # k/v低秩压缩到的维度
        self.qk_nope_head_dim = qk_nope_head_dim    # q/k不带旋转位置编码的维度
        self.qk_rope_head_dim = qk_rope_head_dim    # q/k带旋转位置编码的维度
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # q/k的总维度,不带旋转位置编码的维度加上带旋转位置编码的维度
        self.v_head_dim = v_head_dim  # value的维度,等于不带旋转位置编码的k维度
        self.mode = mode
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        self.wq_a = nn.Linear(self.dim, self.q_lora_rank)  # q的降维矩阵
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)  # q的升维矩阵
        # 4096*128+128*4864 = 524,288 + 622592 = 1146880    4096*4864 = 19,922,944

        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)  # k/v的降维矩阵
        # nn.Linear(self.dim, self.kv_lora_rank)
        # nn.Linear(self.dim, self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))  # k/v的升维矩阵

        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)

        self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)  # 旋转位置编码
        # 没有矩阵融合
        if self.mode == 'naive':
            self.register_buffer('k_cache',
                                 torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.qk_head_dim),
                                 persistent=False)
            self.register_buffer('v_cache',
                                 torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.v_head_dim),
                                 persistent=False)
        # 有矩阵融合
        else:
            self.register_buffer('kv_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank),
                                 persistent=False)
            self.register_buffer('pe_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim),
                                 persistent=False)

    def forward(self, x, mask=None):

        bs, seq_len, _ = x.shape

        q = self.wq_a(x)  # [bs, seq_len, q_lora_rank]
        q = self.q_norm(q)  # [bs, seq_len, q_lora_rank]
        q = self.wq_b(q)  # [bs, seq_len, n_heads * qk_head_dim]
        q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)  # [bs, seq_len, n_heads, qk_head_dim]
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
                                   dim=-1)  # q_nope shape:[bs, seq_len, n_heads, qk_nope_head_dim] q_pe shape:[bs, seq_len, n_heads, qk_rope_head_dim]

        kv = self.wkv_a(x)  # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim],
                               dim=-1)  # kv shape:[bs, seq_len, kv_lora_rank] k_pe shape:[bs, seq_len, qk_rope_head_dim]

        k_pe = k_pe.unsqueeze(2)  # k_pe shape:[bs, seq_len, 1, qk_rope_head_dim]   一层共享一个key
        q_pe, k_pe = self.rotary_emb(q_pe, k_pe)
        if self.mode == 'naive':

            q = torch.cat([q_nope, q_pe], dim=-1)  # * [bs, seq_len, n_heads, qk_head_dim]

            kv = self.kv_norm(kv)  # [bs, seq_len, kv_lora_rank)]
            kv = self.wkv_b(kv)  # [bs, seq_len, n_heads * (qk_nope_head_dim + v_head_dim)]
            kv = kv.view(bs, seq_len, self.n_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)
            # k shape:[bs, seq_len, n_heads, qk_head_dim]
            self.k_cache[:bs, :seq_len, :, :] = k
            self.v_cache[:bs, :seq_len, :, :] = v
            # scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bs, :seq_len]) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            scores = torch.matmul(q.transpose(1, 2),
                                  self.k_cache[:bs, :seq_len, :, :].transpose(1, 2).transpose(2, 3) / math.sqrt(
                                      self.qk_nope_head_dim + self.qk_rope_head_dim))
            scores = scores.transpose(1, 2)

        else:
            k_pe = k_pe.squeeze(2)
            wkv_b = self.wkv_b.weight  # [n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
            wkv_b = wkv_b.view(self.n_heads, -1,
                               self.kv_lora_rank)  # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope,
                                  wkv_b[:, :self.qk_nope_head_dim])  # q_nope shape:[bs, seq_len, n_heads, kv_lora_rank]
            # q*k(T) = x*wq*(c*wkv_b[:, :self.qk_nope_head_dim])(T) = x*wq*wkv_b[:, :self.qk_nope_head_dim](T)*c(T)    c为压缩后的k/v
            # wq*wkv_b[:, :self.qk_nope_head_dim](T)作为q的投影矩阵  c可以替代原先的k,这样就可以直接使用压缩后的k/v计算注意力了,kv_cache时也只需存储压缩后的k/v
            kv = self.kv_norm(kv)
            self.kv_cache[:bs, :seq_len, :] = kv  # kv shape:[bs, seq_len, kv_lora_rank]
            self.pe_cache[:bs, :seq_len, :] = k_pe  # k_pe shape:[bs, seq_len, qk_rope_head_dim]
            scores_nope = torch.einsum("bshc,btc->bsht", q_nope,
                                       self.kv_cache[:bs, :seq_len, :])  # bshc btc -> bshc bct -> bsht
            scores_pe = torch.einsum("bshr,btr->bsht", q_pe,
                                     self.pe_cache[:bs, :seq_len, :])  # bshr btr -> bshr bt1r -> bshr bthr -> bsht
            scores = (scores_nope + scores_pe) / math.sqrt(
                self.qk_nope_head_dim + self.qk_rope_head_dim)  # [bs, seq_len, n_heads, seq_len]

        if mask is not None:
            # mask shape:[bs, seq_len, seq_len]
            scores += mask.unsqueeze(2)

        scores = scores.softmax(dim=-1)

        if self.mode == 'naive':
            x = torch.einsum("bsht,bthd->bshd", scores,
                             self.v_cache[:bs, :seq_len])  # bsht,bthd -> bhst, bhtd -> bhsd -> bshd
        else:

            # scores * v = scores * c * wkv_b[:, -self.v_head_dim:]
            x = torch.einsum("bsht,btc->bshc", scores,
                             self.kv_cache[:bs, :seq_len])  # x shape:[bs, seq_len, n_heads, kv_lora_rank]
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])  # bshc, hdc -> bshc,dch -> bsdh -> bshd

        x = x.contiguous().view(bs, seq_len, -1)
        x = self.wo(x) 

        return x


if __name__ == '__main__':
    torch.manual_seed(0)
    torch.set_printoptions(precision=3, sci_mode=False)

    x = torch.randn(1, 4, 16)

    dim = 16
    n_heads = 2
    q_lora_rank = 10
    kv_lora_rank = 6
    qk_nope_head_dim = 8
    qk_rope_head_dim = 4
    v_head_dim = 8
    max_seq_len = 10
    max_batch_size = 4
    mode = 'none'

    mla = MLA(dim=dim,
              n_heads=n_heads,
              q_lora_rank=q_lora_rank,
              kv_lora_rank=kv_lora_rank,
              qk_nope_head_dim=qk_nope_head_dim,
              qk_rope_head_dim=qk_rope_head_dim,
              v_head_dim=v_head_dim,
              max_seq_len=max_seq_len,
              max_batch_size=max_batch_size,
              mode=mode)

    print(mla(x))
    print(mla.kv_cache)

参考资料:

https://zhuanlan.zhihu.com/p/16730036197

https://github.com/wyf3/llm_related/tree/main/deepseek_learn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2296070.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大数据项目2:基于hadoop的电影推荐和分析系统设计和实现

前言 大数据项目源码资料说明: 大数据项目资料来自我多年工作中的开发积累与沉淀。 我分享的每个项目都有完整代码、数据、文档、效果图、部署文档及讲解视频。 可用于毕设、课设、学习、工作或者二次开发等,极大提升效率! 1、项目目标 本…

Windows逆向工程入门之汇编环境搭建

公开视频 -> 链接点击跳转公开课程博客首页 -> ​​​链接点击跳转博客主页 Visual Studio逆向工程配置 基础环境搭建 Visual Studio 官方下载地址安装配置选项(后期可随时通过VS调整) 使用C的桌面开发 拓展可选选项 MASM汇编框架 配置MASM汇编项目 创建新项目 选择空…

gc buffer busy acquire导致的重大数据库性能故障

📢📢📢📣📣📣 作者:IT邦德 中国DBA联盟(ACDU)成员,10余年DBA工作经验 Oracle、PostgreSQL ACE CSDN博客专家及B站知名UP主,全网粉丝10万 擅长主流Oracle、MySQL、PG、高斯…

Formily 如何进行表单验证

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

安宝特方案 | AR眼镜:远程医疗的“时空折叠者”,如何为生命争夺每一分钟?

行业痛点:当“千里求医”遇上“资源鸿沟” 20世纪50年代,远程会诊的诞生曾让医疗界为之一振——患者不必跨越山河,专家无需舟车劳顿,一根电话线、一张传真纸便能架起问诊的桥梁。然而,传统远程医疗的局限也日益凸显&a…

使用git commit时‘“node“‘ 不是内部或外部命令,也不是可运行的程序

第一种: 使用git commit -m "xxx"时会报错,我看网上的方法是在命令行后面添加--no-verify:git commit -m "主题更新" --no-verify,但是不可能每次都添加。 最后解决办法是:使用git config --lis…

nodejs - vue 视频切片上传,本地正常,线上环境导致磁盘爆满bug

nodejs 视频切片上传,本地正常,线上环境导致磁盘爆满bug 原因: 然后在每隔一分钟执行du -sh ls ,发现文件变得越来越大,即文件下的mp4文件越来越大 最后导致磁盘直接爆满 排查原因 1、尝试将m3u8文件夹下的所有视…

【MySQL — 数据库基础】深入解析MySQL的聚合查询

1. 聚合查询 1.1 聚合函数 函数说明COUNT ( [DISTINCT] expr)返回查询到的数据的数量( 行数 )SUM ( [DISTINCT] expr)返回查询到的数据的总和,不是数字没有意义AVG ( [DISTINCT] expr)返回查询到的数据的平均值,不是数字没有意义MAX( [DISTINCT] expr)…

windows平台本地部署DeepSeek大模型+Open WebUI网页界面(可以离线使用)

环境准备: 确定部署方案请参考:DeepSeek-R1系列(1.5b/7b/8b/32b/70b/761b)大模型部署需要什么硬件条件-CSDN博客 根据本人电脑配置:windows11 + i9-13900HX+RTX4060+DDR5 5600 32G内存 确定部署方案:DeepSeek-R1:7b + Ollama + Open WebUI 1. 安装 Ollama Ollama 是一…

港中文腾讯提出可穿戴3D资产生成方法BAG,可自动生成服装和配饰等3D资产如,并适应特定的人体模型。

今天给大家介绍一种名为BAG(Body-Aligned 3D Wearable Asset Generation)的新方法,可以自动生成可穿戴的3D资产,如服装和配饰,以适应特定的人体模型。BAG方法通过构建一个多视图图像扩散模型,生成与人体对齐…

数据库 绪论

目录 数据库基本概念 一.基本概念 1.信息 2.数据 3.数据库(DB) 4.数据库管理系统(DBMS) 5.数据库系统(DBS) 二.数据管理技术的发展 1.人工管理阶段 2.文件系统阶段 3.数据库系统阶段 4.数据库管…

跨越边界,大模型如何助推科技与社会的完美结合?

点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入! 概述 2024年,大模型技术已成为人工智能领域的焦点。这不仅仅是一项技术进步,更是一次可能深刻影响社会发展方方面面的变革。大模型的交叉能否推动技术与社会的真正融合?2025年…

kafka生产端之架构及工作原理

文章目录 整体架构元数据更新 整体架构 消息在真正发往Kafka之前,有可能需要经历拦截器(Interceptor)、序列化器(Serializer)和分区器(Partitioner)等一系列的作用,那么在此之后又会…

在 Windows 上使用 ZIP 包安装 MySQL 的详细步骤

以下是使用官方 ZIP 包在 Windows 上安装 MySQL 的详细步骤,确保能通过 mysql -uroot -p 成功连接。 步骤 1:下载 MySQL ZIP 包 访问 MySQL 官方下载页面: https://dev.mysql.com/downloads/mysql/选择 Windows (x86, 64-bit), ZIP Archive&…

记录 | WPF创建和基本的页面布局

目录 前言一、创建新项目注意注意点1注意点2 解决方案名称和项目名称 二、布局2.1 Grid2.1.1 RowDefinitions 行分割2.1.2 Row & Column 行列定位区分 2.1.3 ColumnDefinitions 列分割 2.2 StackPanel2.2.1 Orientation 修改方向 三、模板水平布局【Grid中套StackPanel】中…

SpringCloud - Nacos注册/配置中心

前言 该博客为Nacos学习笔记,主要目的是为了帮助后期快速复习使用 学习视频:7小快速通关SpringCloud 辅助文档:SpringCloud快速通关 源码地址:cloud-demo 一、简介 Nacos官网:https://nacos.io/docs/next/quickstar…

C++ 继承(1)

1.继承概念 我们平时有时候在写多个有内容重复的类的时候会很麻烦 比如我要写Student Teacher Staff 这三个类 里面都要包含 sex name age成员变量 唯一不同的可能有一个成员变量 但是这三个成员变量我要写三遍 太麻烦了 有没有好的方式呢? 有的 就是继承…

【C语言】传值调用与传址调用详解

博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C语言 文章目录 💯前言💯传值调用1. 什么是传值调用?2. 示例代码:传值调用失败的情况执行结果: 3. 为什么传值调用无法修改外部变量? &#x1f4…

蓝桥杯C语言组:图论问题

蓝桥杯C语言组图论问题研究 摘要 图论是计算机科学中的一个重要分支,在蓝桥杯C语言组竞赛中,图论问题频繁出现,对参赛选手的算法设计和编程能力提出了较高要求。本文系统地介绍了图论的基本概念、常见算法及其在蓝桥杯C语言组中的应用&#…

windows通过网络向Ubuntu发送文件/目录

由于最近要使用树莓派进行一些代码练习,但是好多东西都在windows里或虚拟机上,就想将文件传输到树莓派上,但试了发现u盘不能简单传送,就在网络上找到了通过windows 的scp命令传送 前提是树莓派先开启ssh服务,且Window…