DeepSeek开源周首日:发布大模型加速核心技术可变长度高效FlashMLA 加持H800算力解码性能狂飙升至3000GB/s

news2025/2/26 3:10:23

FlashMLA的核心技术特性包括对BF16精度的全面支持,以及采用块大小为64的页式键值缓存(Paged KV Cache)系统,实现更精确的内存管理。在性能表现方面,基于CUDA12.6平台,FlashMLA在H800SXM5GPU上创下了显著成绩:在内存受限场景下达到3000GB/s的处理速度,在计算受限场景下则实现580TFLOPS的算力水平。

1. 核心功能与特性

  • 性能提升
    FlashMLA在H800 SXM5 GPU(CUDA 12.6)上表现亮眼:

    • 内存受限场景下带宽达3000 GB/s
    • 计算受限场景下算力峰值达580 TFLOPS(BF16精度)
  • 关键技术优化

    • 变长序列处理:针对自然语言处理中的动态序列长度优化,提升长文本推理效率。
    • 分页KV缓存:块大小为64的分页机制,减少显存碎片化,提升内存利用率。
    • BF16支持:通过低精度计算降低内存占用,同时保持模型性能。
  • MLA架构创新
    相比传统注意力机制,MLA通过低秩压缩技术将每次查询的KV缓存量减少93.3%,显著降低推理时的显存需求,尤其适合长上下文场景。


2. 技术背景与意义

  • 解决行业痛点
    Transformer模型在长序列推理时面临KV缓存膨胀问题,导致显存占用高、硬件成本攀升。FlashMLA通过MLA架构和并行解码设计,将推理成本降低约80-90%,同时支持更高吞吐量

  • 开源生态价值
    FlashMLA开源代码库(GitHub链接)整合了FlashAttention-2/3和CUTLASS的技术实现,为开发者提供可复现的优化方案,加速AGI技术迭代。


3. 应用场景与部署

  • 适用场景

    • 大语言模型(LLM)推理加速,如对话AI、实时翻译、长文本生成等。
    • 需要低延迟、高吞吐的工业级NLP任务。
  • 部署要求

    • 硬件:Hopper架构GPU(如H800/H100)
    • 软件:CUDA 12.3+、PyTorch 2.0+

4. 对行业的影响

  • 成本革命
    DeepSeek通过MLA技术将模型训练和推理成本压缩至行业标杆水平。例如,其V3模型的训练成本仅600万美元(未含研发投入),而MLA的推理优化进一步降低商业化门槛。

  • 算力效率提升
    结合MoE(混合专家模型)架构和多Token预测技术,DeepSeek在单位算力下实现更高性能,推动行业从“堆算力”向“优化算法”转型。

  • 开源竞争格局
    此次开源被视为对Meta Llama、Mistral等项目的直接挑战,可能加速闭源与开源模型的性能差距缩小。


FlashMLA的发布标志着DeepSeek在高效计算领域的技术领先地位,其开源策略或将重塑大模型开发范式,推动更多低成本、高性能AI应用的涌现。

5.快速开始

安装

可以使用以下命令进行安装:

python setup.py install
基准测试

运行以下命令进行基准测试:

python tests/test_flash_mla.py
使用示例

在Python中可以这样使用:

from flash_mla import get_mla_metadata, flash_mla_with_kvcache

tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

for i in range(num_layers):
    ...
    o_i, lse_i = flash_mla_with_kvcache(
        q_i, kvcache_i, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits, causal=True,
    )
    ...

6.核心代码的详细解释

以下是对 FlashMLA/flash_mla/flash_mla_interface.py 文件中:

get_mla_metadata 函数

def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
  • 功能:该函数用于获取MLA(Multi-Head Attention)的元数据。
  • 参数
    • cache_seqlens:一个形状为 (batch_size)torch.Tensor,数据类型为 torch.int32,表示缓存的序列长度。
    • num_heads_per_head_k:整数类型,其值等于 seq_len_q * num_heads_q // num_heads_k
    • num_heads_k:整数类型,表示 num_heads_k 的值。
  • 返回值
    • tile_scheduler_metadata:形状为 (num_sm_parts, TileSchedulerMetaDataSize)torch.Tensor,数据类型为 torch.int32
    • num_splits:形状为 (batch_size + 1)torch.Tensor,数据类型为 torch.int32
  • 实现细节:该函数直接调用 flash_mla_cuda 模块中的 get_mla_metadata 函数,并将输入参数传递给它,然后返回该函数的结果。

flash_mla_with_kvcache 函数

def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head_dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return out, softmax_lse
  • 功能:该函数用于执行带有键值缓存(KVCache)的MLA操作。
  • 参数
    • q:形状为 (batch_size, seq_len_q, num_heads_q, head_dim)torch.Tensor,表示查询张量。
    • k_cache:形状为 (num_blocks, page_block_size, num_heads_k, head_dim)torch.Tensor,表示键缓存张量。
    • block_table:形状为 (batch_size, max_num_blocks_per_seq)torch.Tensor,数据类型为 torch.int32,表示块表。
    • cache_seqlens:形状为 (batch_size)torch.Tensor,数据类型为 torch.int32,表示缓存的序列长度。
    • head_dim_v:整数类型,表示 v 的头维度。
    • tile_scheduler_metadata:形状为 (num_sm_parts, TileSchedulerMetaDataSize)torch.Tensor,数据类型为 torch.int32,由 get_mla_metadata 函数返回。
    • num_splits:形状为 (batch_size + 1)torch.Tensor,数据类型为 torch.int32,由 get_mla_metadata 函数返回。
    • softmax_scale:可选的浮点数,表示在应用softmax之前对 QK^T 进行缩放的比例,默认为 1 / sqrt(head_dim)
    • causal:布尔类型,表示是否应用因果注意力掩码,默认为 False
  • 返回值
    • out:形状为 (batch_size, seq_len_q, num_heads_q, head_dim_v)torch.Tensor,表示输出张量。
    • softmax_lse:形状为 (batch_size, num_heads_q, seq_len_q)torch.Tensor,数据类型为 torch.float32,表示softmax的对数和指数(LogSumExp)。
  • 实现细节
    • 如果 softmax_scale 未提供,则将其设置为 q 张量最后一个维度的平方根的倒数。
    • 调用 flash_mla_cuda 模块中的 fwd_kvcache_mla 函数,传递相应的参数,并将返回的结果赋值给 outsoftmax_lse
    • 最后返回 outsoftmax_lse

这些函数主要是作为Python接口,调用底层的CUDA实现(flash_mla_cuda 模块)来完成MLA操作和元数据的获取。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2306137.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

01 冲突域和广播域的划分

目录 1、冲突域和广播域的划分 1.1、冲突域 1.2、广播域 1.3、对比总结 1.4、冲突域与广播域个数计算例题 2、交换机和路由器的结构 2.1、交换机的结构 2.2、路由器的结构 1、冲突域和广播域的划分 1.1、冲突域 冲突域是指网络中可能发生数据帧冲突的物理范围。当多…

nodejs npm install、npm run dev运行的坎坷之路

1、前面的种种都不说了,好不容易运行起来oap-portal项目,运行idm-ui项目死活运行不起来,各种报错,各种安装,各种卸载nodejs,卸载nvm,重装,都不好使。 2、甚至后来运行npm install会…

大型装备故障诊断解决方案

大型装备故障诊断解决方案 方案背景 在全球航空工业迅猛发展的背景下,我国在军用和民用飞机自主研发制造领域取得了显著成就。尤其是在国家大力支持下,国内飞机制造企业攻克了诸多关键技术难题,实现了从设计研发到生产制造再到售后保障的完整…

反向代理模块kfj

1 概念 1.1 反向代理概念 反向代理是指以代理服务器来接收客户端的请求,然后将请求转发给内部网络上的服务器,将从服务器上得到的结果返回给客户端,此时代理服务器对外表现为一个反向代理服务器。 对于客户端来说,反向代理就相当于…

Python Seaborn库使用指南:从入门到精通

1. 引言 Seaborn 是基于 Matplotlib 的高级数据可视化库,专为统计图表设计。它提供了更简洁的 API 和更美观的默认样式,能够轻松生成复杂的统计图表。Seaborn 在数据分析、机器学习和科学计算领域中被广泛使用。 本文将详细介绍 Seaborn 的基本概念、常用功能以及高级用法,…

Android之APP更新(通过接口更新)

文章目录 前言一、效果图二、实现步骤1.AndroidManifest权限申请2.activity实现3.有版本更新弹框UpdateappUtilDialog4.下载弹框DownloadAppUtils5.弹框背景图 总结 前言 对于做Android的朋友来说,APP更新功能再常见不过了,因为平台更新审核时间较长&am…

JVM生产环境问题定位与解决实战(二):JConsole、VisualVM到MAT的高级应用

生产问题定位指南:几款必备的可视化工具 引言 在上一篇文章中,详细的介绍了JDK自带的一系列命令行工具,,如jps、jmap、jstat、jstack以及jcmd等,这些工具为排查和诊断Java虚拟机(JVM)问题提供…

力扣3102.最小化曼哈顿距离

力扣3102.最小化曼哈顿距离 题目 题目解析及思路 题目要求返回移除一个点后的最小的最大曼哈顿距离 最大最小值的题一般直接想到二分 本题有一个简单办法就是利用切比雪夫距离 当正方形转45,即边上点**( x , y ) -> (x y , y - x)时,两点间max(…

国标28181协议在智联视频超融合平台中的接入方法

一. 国标28181介绍 国标 28181 协议全称是《安全防范视频监控联网系统信息传输、交换、控制技术要求》,是国内视频行业最重要的国家标准,目前有三个版本: 2011 年:推出 GB/T 28181-2011 版本,为安防行业的前端设备、平…

【学习笔记】LLM+RL

文章目录 1 合成数据与模型坍缩(model collapse),1.1 递归生成数据与模型坍缩1.2 三种错误1.3 理论直觉1.4 PPL指标 2 基于开源 LLM 实现 O1-like step by step 慢思考(slow thinking),ollama,streamlit2.1…

【论文精读】YOLO-World:实时开放词汇目标检测

论文地址: YOLO-World: Real-Time Open-Vocabulary Object Detection 源代码:YOLO-World 摘要 YOLO系列检测器因其高效性和实用性而被广泛认可。然而,它们依赖于预定义和训练过的物体类别,这限制了其在开放场景中的适用性。为了…

【AI时代】可视化训练模型工具LLaMA-Factory安装与使用

文章目录 安装训练使用 安装 官方地址:https://github.com/hiyouga/LLaMA-Factory 创建虚拟环境 conda create -n llama-factory conda activate llama-factory安装 git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git cd LLaMA-Factory pip in…

将产品照片(form.productPhotos)转为 JSON 字符串发送给后端

文章目录 1. 前端 form.productPhotos 的当前处理a. 组件绑定b. 当前发送逻辑 2. 如何将 form.productPhotos 转为 JSON 字符串发送给后端a. 修改前端 save() 方法b. 确保 esave API 支持接收字符串 基于你提供的 identify-form.vue 代码,我将分析如何将产品照片&a…

【科研绘图系列】R语言绘制小提琴图、散点图和韦恩图(violin scatter plot Venn)

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍加载R包数据下载画图1画图2画图3画图4画图5画图6画图7参考介绍 【科研绘图系列】R语言绘制小提琴图、散点图和韦恩图(violin & scatter plot & Venn) 加载R包 library…

kotlin 知识点一 变量和函数

在Kotlin中定义变量的方式和Java 区别很大,在Java 中如果想要定义一个变 量,需要在变量前面声明这个变量的类型,比如说int a表示a是一个整型变量,String b表 示b是一个字符串变量。而Kotlin中定义一个变量,只允许在变量…

solidity之Foundry安装配置(一)

一门面向合约的高级编程语言,主要用来编写以太坊只能合约。 Solidity受C语言,Python和js影响,但为编译成为以太坊虚拟机字节码在EVM上执行,很多特性和限制都和EVM相关。 Solidity 是静态类型语言,支持继承、库、自定义…

PHP-create_function

[题目信息]: 题目名称题目难度PHP-create_function2 [题目考点]: create_function ( string args , string args , string code )[Flag格式]: SangFor{wWx5dEGHHhDUwmST4bpXwfjSzq43I6cz}[环境部署]: docker-compose.yml文件或者docker …

FFmpeg 是什么?为什么?怎么用?

摘要:本文介绍了 FFmpeg,一个功能强大的开源多媒体处理工具,广泛应用于视频和音频文件的处理。FFmpeg 支持多种多媒体格式,能够实现视频编码/解码、格式转换、裁剪、合并、音频提取、流媒体处理等功能。本文详细阐述了 FFmpeg 的主…

云计算及其他计算

云计算知识思维导图:https://kdocs.cn/l/cpl2Kizx7IyC 云计算的核心判断标准通常基于美国国家标准与技术研究院(NIST)的定义,并结合实际应用场景。以下是判断一个服务是否为云计算的关键标准,以及对应的服务类型&#…

前端Toast提示快速入门

White graces:个人主页 🙉专栏推荐:Java入门知识🙉 🐹今日诗词:十年一觉扬州梦,赢得青楼薄幸名🐹 ⛳️点赞 ☀️收藏⭐️关注💬卑微小博主🙏 ⛳️点赞 ☀️收藏⭐️关注&#x1f4…