flash attention 参数(笔记)

news2025/1/23 17:50:37

目录

一、flash attention官方

  1.1、flash attention安装

二、flash attention 常见函数

  2.1、flash_attn_varlen_qkvpacked_func

  2.2、flash_attn_varlen_kvpacked_func

  2.3、flash_attn_varlen_func

  ​​​​​​​2.4、flash_attn_with_kvcache

  2.5、flash_attn_func


一、flash attention官方

版本: flash-attn  2.5.7

flash-attention/flash_attn/flash_attn_interface.py at main · Dao-AILab/flash-attention · GitHubFast and memory-efficient exact attention. Contribute to Dao-AILab/flash-attention development by creating an account on GitHub.icon-default.png?t=N7T8https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py

  1.1、flash attention安装(下文1.2节)

Trl SFT: llama2-7b-hf使用QLora 4bit量化后ds zero3加上flash atten v2单机多卡训练(笔记)_unsloth 多卡-CSDN博客文章浏览阅读812次,点赞18次,收藏25次。第三 参考官方命令: https://github.com/Dao-AILab/flash-attention。第一 确保 linux "外界"的 cuda版本 与 conda 虚拟环境中cuda版本一致。第二 安装好 c++ g++ ninja。_unsloth 多卡https://blog.csdn.net/qq_16555103/article/details/137677561

二、flash attention 常见函数

  2.1、flash_attn_varlen_qkvpacked_func

输入前、输出后需要使用unpad、pad

该函数 flash_attn_varlen_qkvpacked_func 用于计算可变长序列的注意力输出,其中 query、key 和 value 已被打包成一个张量。

该函数的主要作用是:

  1. 高效计算注意力输出:通过将 query、key 和 value 打包成一个张量作为输入,避免了显式连接 Q、K、V 的梯度,从而提高了计算效率。

  2. 支持变长序列:函数通过 cu_seqlens 参数接收每个序列的累积长度,可以有效处理变长序列的情况。

  3. 支持多种注意力模式:函数支持因果注意力掩码(用于自回归建模)、滑动窗口局部注意力(只关注特定范围内的 key)和添加注意力分数偏置等功能。

  4. 提供确定性反向传播选项:可以选择使用确定性反向传播实现,虽然稍慢但使用更多内存,保证了结果的确定性。

  5. 返回注意力概率(仅用于测试):可以选择返回注意力概率,但这些概率可能不具有正确的缩放,仅用于测试目的。

该函数的输入参数包括 query、key 和 value 张量、序列长度信息、Dropout 率、softmax 缩放因子、注意力模式选项等。输出则是注意力层的输出张量,以及可选的注意力概率和 softmax 归一化因子。

该函数利用了 PyTorch 的自定义 CUDA 扩展,提供了高效的注意力计算能力,同时支持了多种注意力模式和变长序列输入。

def flash_attn_varlen_qkvpacked_func(
    qkv,  # query、key 和 value 张量,形状为 (total, 3, nheads, headdim),其中 total 是批次中 token 的总数,3 表示 query、key 和 value 被打包在一起。例如,如果批次大小为 2,序列长度分别为 3 和 5,头数为 4,head 维度为 64,则 qkv 的形状为 (8, 3, 4, 64)
    cu_seqlens,  # 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,用于索引 qkv。例如,如果批次大小为 2,序列长度分别为 3 和 5,则 cu_seqlens 为 [0, 3, 8]
    max_seqlen,  # 批次中序列的最大长度,整数值。例如,如果批次中的序列长度分别为 3 和 5,则 max_seqlen 为 5
    dropout_p=0.0,  # dropout 概率,在评估(evaluation)时应设置为 0.0,以保留所有神经元的输出,防止信息损失。通常在训练时使用一个小于 1 的值,如 0.1,而在评估时设置为 0.0
    softmax_scale=None,  # 在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim)。如果 headdim 为 64,则默认缩放因子为 1 / sqrt(64) ≈ 0.126
    causal=False,  # 是否应用因果注意力掩码(用于自回归建模)。如果设置为 True,则查询只能关注之前的输出,无法关注未来的输出,这在语言模型等自回归任务中很有用
    window_size=(-1, -1),  # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限上下文窗口,即不进行任何窗口限制。如果设置为 (2, 2),则查询位置 i 只能关注位置 [i-2, i+2] 范围内的键
    alibi_slopes=None,  # 一种用于注意力分数偏置的方法,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32。例如,如果 nheads 为 4,则 alibi_slopes 可以是形状为 (4,) 的张量
    deterministic=False,  # 是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的。通常在评估时不需要使用确定性实现,因为评估时不需要计算梯度
    return_attn_probs=False,  # 是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确,因为在实际应用中,通常不需要直接访问注意力概率,而是更关注注意力层的输出
):
    """
    dropout_p 在评估(evaluation)时应该设置为 0.0,以保留所有神经元的输出,防止信息损失。

    如果 query、key 和 value 已经打包成一个张量,调用这个函数会比调用 flash_attn_varlen_func 更快,因为反向传播避免了显式拼接 query、key 和 value 的梯度,从而减少了内存复制和计算量。

    对于多查询注意力(MQA)和分组查询注意力(GQA),请参见 flash_attn_varlen_kvpacked_func 和 flash_attn_varlen_func。

    如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 处的查询只会注意到位置在 [i - window_size[0], i + window_size[1]] 范围内(包括边界)的键。

    参数说明:
    qkv: (total, 3, nheads, headdim),其中 total 是批次中 token 的总数,3 表示 query、key 和 value 被打包在一起
    cu_seqlens: (batch_size + 1,),数据类型为 torch.int32,表示批次中每个序列的累积长度,用于索引 qkv
    max_seqlen: 整数,批次中序列的最大长度
    dropout_p: 浮点数,dropout 概率
    softmax_scale: 浮点数,在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim)
    causal: 布尔值,是否应用因果注意力掩码(用于自回归建模)
    window_size: (left, right),整数元组,如果不是 (-1, -1),则实现滑动窗口局部注意力
    alibi_slopes: (nheads,) 或 (batch_size, nheads),fp32 张量,一种用于注意力分数偏置的方法
    deterministic: 布尔值,是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的
    return_attn_probs: 布尔值,是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确

    返回值:
    out: (total, nheads, headdim),注意力层的输出
    softmax_lse [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen),每行的 QK^T * scaling 的 logsumexp(即 softmax 归一化因子的对数)
    S_dmask [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen),softmax 的输出(可能有不同的缩放),它也编码了 dropout 模式(负值表示该位置被 dropout,非负值表示被保留)
    """

    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,  # query、key 和 value 张量
        cu_seqlens,  # 序列的累积长度
        max_seqlen,  # 批次中序列的最大长度    
        dropout_p,  # dropout 概率
        softmax_scale,  # softmax 缩放因子
        causal,  # 是否应用因果注意力掩码
        window_size,  # 滑动窗口大小,用于实现局部注意力
        alibi_slopes,  # 注意力分数偏置方法
        deterministic,  # 是否使用确定性反向传播实现
        return_attn_probs,  # 是否返回注意力概率(仅用于测试)
    )
  2.2、flash_attn_varlen_kvpacked_func

输入前、输出后需要使用unpad、pad

该函数 flash_attn_varlen_kvpacked_func 用于计算变长序列的注意力输出,其中 key 和 value 已被打包成一个张量。它是 Flash Attention 库中的一个函数,旨在高效地计算注意力输出,同时支持多查询注意力(MQA)和分组查询注意力(GQA)。

该函数的主要作用是:

  1. 高效计算注意力输出:通过将 key 和 value 打包成一个张量作为输入,避免了显式连接 K、V 的梯度,从而提高了计算效率。

  2. 支持变长序列:函数通过 cu_seqlens_q 和 cu_seqlens_k 参数接收每个序列的累积长度,可以有效处理变长序列的情况。

  3. 支持多种注意力模式:函数支持因果注意力掩码(用于自回归建模)、滑动窗口局部注意力(只关注特定范围内的 key)和添加注意力分数偏置等功能。

  4. 支持多查询注意力(MQA)和分组查询注意力(GQA):通过将 key 和 value 的注意力头数设置为少于 query 的注意力头数,可以实现 MQA 和 GQA。例如,如果 query 有 6 个注意力头,key 和 value 有 2 个注意力头,那么 query 的头 0、1、2 将关注 key 和 value 的头 0,query 的头 3、4、5 将关注 key 和 value 的头 1。

  5. 提供确定性反向传播选项:可以选择使用确定性反向传播实现,虽然稍慢但使用更多内存,保证了结果的确定性。

  6. 返回注意力概率(仅用于测试):可以选择返回注意力概率,但这些概率可能不具有正确的缩放,仅用于测试目的。

该函数的输入参数包括 query、key 和 value 张量、序列长度信息、Dropout 率、softmax 缩放因子、注意力模式选项等。输出则是注意力层的输出张量,以及可选的注意力概率和 softmax 归一化因子。

该函数利用了 PyTorch 的自定义 CUDA 扩展,提供了高效的注意力计算能力,同时支持了多种注意力模式、变长序列输入和 MQA/GQA 等特性,在自然语言处理等领域具有广泛的应用。

def flash_attn_varlen_kvpacked_func(
    q,  # query 张量,形状为 (total_q, nheads, headdim),例如 (1024, 16, 64),其中 total_q=1024 是批次中 query token 的总数,nheads=16 是注意力头数,headdim=64 是每个注意力头的维度
    kv,  # key-value 张量,形状为 (total_k, 2, nheads_k, headdim),例如 (2048, 2, 8, 64),其中 total_k=2048 是批次中 key token 的总数,2 表示 key 和 value 被打包在一起,nheads_k=8 是 key 和 value 的注意力头数,headdim=64 是每个注意力头的维度。注意,nheads_k 可以小于 nheads,这支持了多查询注意力(MQA)和分组查询注意力(GQA)的用法
    cu_seqlens_q,  # query 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,例如 [0, 10, 20, 32, 42],表示批次中有 4 个序列,第一个序列长度为 10,第二个序列长度为 10,第三个序列长度为 12,第四个序列长度为 10。这些累积长度用于索引 q 张量
    cu_seqlens_k,  # key 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,例如 [0, 15, 30, 45, 60],用于索引 kv 张量
    max_seqlen_q,  # 批次中 query 序列的最大长度,例如 12,表示批次中最长的 query 序列长度为 12
    max_seqlen_k,  # 批次中 key 序列的最大长度,例如 15,表示批次中最长的 key 序列长度为 15
    dropout_p=0.0,  # dropout 概率,在评估(evaluation)时应设置为 0.0,以保留所有神经元的输出,防止信息损失。在训练(training)时,可以设置一个小于 1 的值,例如 0.1,表示将有 10% 的神经元被随机丢弃
    softmax_scale=None,  # 在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim),例如如果 headdim=64,则默认缩放因子为 1 / sqrt(64) ≈ 0.126,这是一种常见的缩放方式,可以使注意力分数的方差保持在合理范围内,防止出现较大的梯度
    causal=False,  # 是否应用因果注意力掩码(用于自回归建模),如果设置为 True,则查询只能关注之前的输出,无法关注未来的输出,这在诸如语言模型等自回归任务中非常有用
    window_size=(-1, -1),  # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限上下文窗口,即不进行任何窗口限制。如果设置为 i, 关注窗口为 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的键
    alibi_slopes=None,  # 一种用于注意力分数偏置的方法,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32。例如,如果 nheads=16,则 alibi_slopes 可以是形状为 (16,) 的张量,每个元素对应一个注意力头的偏置斜率。如果设置了该参数,则会对查询 i 和键 j 之间的注意力分数加上一个偏置 (-alibi_slope * |i + seqlen_k - seqlen_q - j|),这种偏置可以鼓励模型关注更靠近查询的键,或者更远离查询的键,具体取决于 alibi_slopes 的值
    deterministic=False,  # 是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的。在评估(evaluation)时,通常不需要使用确定性实现,因为评估时不需要计算梯度
    return_attn_probs=False,  # 是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确,因为在实际应用中,通常不需要直接访问注意力概率,而是更关注注意力层的输出
):
    """
    dropout_p 在评估(evaluation)时应该设置为 0.0,以保留所有神经元的输出,防止信息损失。

    如果 K 和 V 已经打包成一个张量,调用这个函数会比调用 flash_attn_func 更快,因为反向传播避免了显式拼接 K 和 V 的梯度,从而减少了内存复制和计算量。

    支持多查询注意力(MQA)和分组查询注意力(GQA),只需将 KV 的头数设置为少于 Q 的头数即可。注意,Q 的头数必须能被 KV 的头数整除。
    例如,如果 Q 有 6 个头,而 K 和 V 有 2 个头,那么 Q 的头 0、1、2 将注意到 K 和 V 的头 0,而 Q 的头 3、4、5 将注意到 K 和 V 的头 1。
    这种机制使得模型可以在不同的头组之间共享计算资源,从而提高计算效率。

    如果 causal=True,则因果掩码对齐到注意力矩阵的右下角。这种掩码模式确保了在自回归(auto-regressive)建模中,查询只能关注之前的输出,无法关注未来的输出。
    例如,如果 seqlen_q = 2 且 seqlen_k = 5,则因果掩码(1 = 保留,0 = 掩码)为:
        1 1 1 1 0
        1 1 1 1 1
    如果 seqlen_q = 5 且 seqlen_k = 2,则因果掩码为:
        0 0
        0 0
        0 0 
        1 0
        1 1
    如果掩码的一行全为零,则该查询的输出也将为零,因为它无法关注任何有效的键。

    如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 处的查询只会注意到位置在
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的键。
    例如,如果 seqlen_q = 3, seqlen_k = 5, window_size = (1, 2),则查询位置 0 只能关注键位置 [0, 3],查询位置 1 只能关注键位置 [1, 4],查询位置 2 只能关注键位置 [2, 4]。
    这种局部注意力机制可以显著提高计算效率,尤其是在处理长序列时,因为它减少了需要计算的注意力分数的数量。但同时也会牺牲一些表达能力,因为查询无法关注整个序列。
    """

    # 返回值:
    # out: (total, nheads, headdim),注意力层的输出,例如 (1024, 16, 64)
    # softmax_lse [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen),每行的 QK^T * scaling 的 logsumexp(即 softmax 归一化因子的对数),例如 (4, 16, 12)
    # S_dmask [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen),softmax 的输出(可能有不同的缩放),它也编码了 dropout 模式(负值表示该位置被 dropout,非负值表示被保留),例如 (4, 16, 12, 15)

    return FlashAttnVarlenKVPackedFunc.apply(
        q,  # query 张量
        kv,  # 打包的 key-value 张量
        cu_seqlens_q,  # query 序列的累积长度
        cu_seqlens_k,  # key 序列的累积长度 
        max_seqlen_q,  # 批次中 query 序列的最大长度
        max_seqlen_k,  # 批次中 key 序列的最大长度
        dropout_p,  # dropout 概率
        softmax_scale,  # softmax 缩放因子
        causal,  # 是否应用因果注意力掩码
        window_size,  # 滑动窗口大小,用于实现局部注意力
        alibi_slopes,  # 注意力分数偏置方法
        deterministic,  # 是否使用确定性反向传播实现
        return_attn_probs,  # 是否返回注意力概率(仅用于测试)
    )

"""
  一个具体的 GQA 示例:
    假设我们有一个批次,其中包含 4 个序列,每个序列的长度分别为 10、10、12 和 10。
    我们希望使用 GQA,其中 query 有 6 个头(nheads=6),而 key 和 value 有 2 个头(nheads_k=2)。
    每个注意力头的维度为 64(headdim=64)。

    在这种情况下,各个参数的值如下:

    q: (total_q, nheads, headdim) = (42, 6, 64)
       total_q = 10 + 10 + 12 + 10 = 42,是批次中所有序列的 token 总数
       nheads = 6,表示 query 有 6 个注意力头
       headdim = 64,表示每个注意力头的维度为 64

    kv: (total_k, 2, nheads_k, headdim) = (42, 2, 2, 64) 
        total_k = 10 + 10 + 12 + 10 = 42,与 total_q 相同,因为每个序列的 query 长度和 key 长度是相同的
        2 表示 key 和 value 被打包在一起
        nheads_k = 2,表示 key 和 value 有 2 个注意力头
        headdim = 64,表示每个注意力头的维度为 64

    cu_seqlens_q: [0, 10, 20, 32, 42],表示批次中每个序列的累积长度,用于索引 q
                  第一个序列的长度为 10,第二个序列的长度为 10,第三个序列的长度为 12,第四个序列的长度为 10

    cu_seqlens_k: [0, 10, 20, 32, 42],与 cu_seqlens_q 相同,因为每个序列的 query 长度和 key 长度是相同的

    max_seqlen_q: 12,批次中最长的 query 序列长度
    max_seqlen_k: 12,与 max_seqlen_q 相同,因为每个序列的 query 长度和 key 长度是相同的

    在这个 GQA 设置下,query 的 6 个头将被分成 2 组,每组 3 个头,分别关注 key 和 value 的 2 个头:
    - query 的头 0 1 2 将关注 key 和 value 的头 0
    - query 的头 3 4 5 将关注 key 和 value 的头 1
 

    这种分组机制允许不同的 query 头关注不同的 key 和 value 子空间,提高了计算效率和表达能力。
    同时,通过减少 key 和 value 的头数,可以显著降低计算和存储开销。
"""
  2.3、flash_attn_varlen_func

输入前、输出后需要使用unpad、pad

该函数 flash_attn_varlen_func 用于计算变长序列的注意力输出,其中 query、key 和 value 是分开的张量。它是 Flash Attention 库中的一个函数,旨在高效地计算注意力输出,同时支持多查询注意力(MQA)和分组查询注意力(GQA)。

该函数的主要作用是:

  1. 高效计算注意力输出:通过分开输入 query、key 和 value 张量,可以有效利用计算资源进行注意力计算。

  2. 支持变长序列:函数通过 cu_seqlens_q 和 cu_seqlens_k 参数接收每个序列的累积长度,可以有效处理变长序列的情况。

  3. 支持多种注意力模式:函数支持因果注意力掩码(用于自回归建模)、滑动窗口局部注意力(只关注特定范围内的 key)和添加注意力分数偏置等功能。

  4. 支持多查询注意力(MQA)和分组查询注意力(GQA):通过将 key 和 value 的注意力头数设置为少于 query 的注意力头数,可以实现 MQA 和 GQA。例如,如果 query 有 6 个注意力头,key 和 value 有 2 个注意力头,那么 query 的头 0、1、2 将关注 key 和 value 的头 0,query 的头 3、4、5 将关注 key 和 value 的头 1。

  5. 支持分块稀疏注意力:可以通过提供 block_table 参数来启用分块稀疏注意力,进一步提高计算效率。

  6. 提供确定性反向传播选项:可以选择使用确定性反向传播实现,虽然稍慢但使用更多内存,保证了结果的确定性。

  7. 返回注意力概率(仅用于测试):可以选择返回注意力概率,但这些概率可能不具有正确的缩放,仅用于测试目的。

该函数的输入参数包括 query、key 和 value 张量、序列长度信息、Dropout 率、softmax 缩放因子、注意力模式选项等。输出则是注意力层的输出张量,以及可选的注意力概率和 softmax 归一化因子。

该函数利用了 PyTorch 的自定义 CUDA 扩展,提供了高效的注意力计算能力,同时支持了多种注意力模式、变长序列输入、MQA/GQA 和分块稀疏注意力等特性,在自然语言处理等领域具有广泛的应用。

def flash_attn_varlen_func(
    q, # 输入的 query 张量,形状为 (total_q, nheads, headdim),其中 total_q 是批量中所有查询token的总数,nheads 是注意力头数,headdim 是每个注意力头的维度
    k, # 输入的 key 张量,形状为 (total_k, nheads_k, headdim),其中 total_k 是批量中所有 key token的总数,nheads_k 是 key 的注意力头数,headdim 是每个注意力头的维度
    v, # 输入的 value 张量,形状为 (total_k, nheads_k, headdim),其中 total_k 是批量中所有 key token的总数,nheads_k 是 key 的注意力头数,headdim 是每个注意力头的维度
    cu_seqlens_q, # 批量中每个查询序列的累积长度,形状为 (batch_size + 1,),数据类型为 torch.int32,用于从 q 中索引相应的位置
    cu_seqlens_k, # 批量中每个 key 序列的累积长度,形状为 (batch_size + 1,),数据类型为 torch.int32,用于从 k 和 v 中索引相应的位置
    max_seqlen_q, # 批量中最大查询序列长度
    max_seqlen_k, # 批量中最大 key 序列长度
    dropout_p=0.0, # Dropout 率,在评估(evaluation)时应设置为 0.0
    softmax_scale=None, # softmax 缩放因子,默认为 1 / sqrt(headdim)
    causal=False, # 是否应用因果注意力掩码,用于自回归(auto-regressive)建模
    window_size=(-1, -1), # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限制上下文窗口
    alibi_slopes=None, # 用于添加注意力分数偏置,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32
    deterministic=False, # 是否使用确定性反向传播实现,比非确定性实现稍慢但使用更多内存,前向传播始终是确定性的
    return_attn_probs=False, # 是否返回注意力概率,仅用于测试,返回的概率可能不具有正确缩放
    block_table=None # 可选的块表,用于分块稀疏注意力
):
    """
    解决问题: 计算变长序列的注意力输出,其中 query、key 和 value 是分开的张量。支持多查询注意力(MQA)和分组查询注意力(GQA)。

    注意事项:
    - dropout_p 应在评估时设置为 0.0。
    - 支持多查询注意力(MQA)和分组查询注意力(GQA),通过将 K、V 的注意力头数设置为少于 Q 的注意力头数来实现。Q 的注意力头数必须能被 K、V 的注意力头数整除。
      例如,如果 Q 有 6 个注意力头,K、V 有 2 个注意力头,那么 Q 的头 0、1、2 将关注 K、V 的头 0,Q 的头 3、4、5 将关注 K、V 的头 1。
    - 如果 causal=True,因果掩码将与注意力矩阵的右下角对齐。
      例如,如果 seqlen_q = 2 且 seqlen_k = 5,因果掩码(1 = 保留,0 = 掩码)为:
        1 1 1 1 0
        1 1 1 1 1
      如果 seqlen_q = 5 且 seqlen_k = 2,因果掩码为:
        0 0
        0 0
        0 0
        1 0
        1 1
      如果掩码的一行全为零,输出也将为零。
    - 如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 的查询将只关注位于 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内的 key。
    - 可以通过提供 block_table 参数来启用分块稀疏注意力。

    返回值:
    - out: 注意力层的输出张量,形状为 (total, nheads, headdim),其中 total = total_q。
    - softmax_lse [可选,如果 return_attn_probs=True]: 每行的 QK^T * scaling 的 logsumexp 值,形状为 (batch_size, nheads, seqlen),即 softmax 归一化因子的对数。
    - S_dmask [可选,如果 return_attn_probs=True]: softmax 的输出,可能具有不同的缩放,形状为 (batch_size, nheads, seqlen, seqlen),还编码了 Dropout 模式(负值表示该位置被丢弃,非负值表示该位置被保留)。
    """
    return FlashAttnVarlenFunc.apply(
        q, # 输入的 query 张量
        k, # 输入的 key 张量
        v, # 输入的 value 张量
        cu_seqlens_q, # 批量中每个查询序列的累积长度
        cu_seqlens_k, # 批量中每个 key 序列的累积长度
        max_seqlen_q, # 批量中最大查询序列长度
        max_seqlen_k, # 批量中最大 key 序列长度
        dropout_p, # Dropout 率
        softmax_scale, # softmax 缩放因子
        causal, # 是否应用因果注意力掩码
        window_size, # 用于实现滑动窗口局部注意力
        alibi_slopes, # 用于添加注意力分数偏置
        deterministic, # 是否使用确定性反向传播实现
        return_attn_probs, # 是否返回注意力概率
        block_table, # 可选的块表,用于分块稀疏注意力
    )
  2.4、flash_attn_with_kvcache

该函数 flash_attn_with_kvcache 用于在推理(inference)过程中计算注意力层的输出,同时支持使用 key 和 value 缓存,以及旋转位置嵌入等技术。它是一个高效的注意力计算函数,可以在推理时加速序列生成任务。

函数的主要特点和技术如下:

  1. 支持更新 key 和 value 缓存:如果提供了新的 k 和 v 张量,函数会将它们的值原地更新到 k_cache 和 v_cache 中。这对于增量解码非常有用,可以一次性完成缓存更新和注意力计算。

  2. 旋转位置嵌入 (Rotary Position Embedding):如果提供了 rotary_cos 和 rotary_sin,函数会对 qk 应用旋转位置嵌入。旋转位置嵌入是一种编码序列位置信息的方法,可以提高注意力模型在长序列任务中的性能。

  3. 因果注意力掩码 (Causal Attention Mask):如果设置 causal=True,函数会应用因果注意力掩码,确保模型只关注当前位置之前的输出,实现自回归(auto-regressive)特性。

  4. 滑动窗口局部注意力 (Sliding Window Local Attention):如果设置 window_size != (-1, -1),函数会实现滑动窗口局部注意力,对于每个查询,只关注一定窗口范围内的 key。这可以减少计算开销,适用于一些特定任务。

  5. 多查询和分组查询注意力 (MQA/GQA):函数支持将 q 的头数设置为 kv 头数的整数倍,实现多查询和分组查询注意力。这种技术可以提高计算效率。

  6. 分块 key/value 缓存:如果提供了 block_table,函数会将 k_cache 和 v_cache 视为分页缓存,支持高效的缓存管理。

  7. 注意力分数偏置 (Alibi Slopes):如果提供了 alibi_slopes,函数会为每个查询-key 对的注意力分数加上一个与位置相关的偏置项。这是一种正则化技术,可以改善注意力模型的性能。

  8. CUDA 内核加速:函数的核心计算由 CUDA 内核 flash_attn_cuda.fwd_kvcache 完成,提供高性能的并行计算能力。

总的来说,flash_attn_with_kvcache 函数集成了多种先进的注意力计算技术,可以高效地

def flash_attn_with_kvcache(
    q, # 查询张量,形状为 (batch_size, seqlen, nheads, headdim)
    k_cache, # key 缓存张量,形状为 (batch_size_cache, seqlen_cache, nheads_k, headdim) 或 (num_blocks, page_block_size, nheads_k, headdim)
    v_cache, # value 缓存张量,形状与 k_cache 相同
    k=None, # 可选的新 key 张量,形状为 (batch_size, seqlen_new, nheads_k, headdim)
    v=None, # 可选的新 value 张量,形状与 k 相同
    rotary_cos=None, # 可选的旋转位置嵌入余弦值,形状为 (seqlen_ro, rotary_dim / 2)
    rotary_sin=None, # 可选的旋转位置嵌入正弦值,形状与 rotary_cos 相同
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, # 缓存序列长度,可以是整数或张量
    cache_batch_idx: Optional[torch.Tensor] = None, # 缓存批次索引张量,形状为 (batch_size,)
    block_table: Optional[torch.Tensor] = None, # 可选的块表张量,形状为 (batch_size, max_num_blocks_per_seq)
    softmax_scale=None, # softmax 缩放系数,默认为 1 / sqrt(headdim)
    causal=False, # 是否进行因果注意力掩码
    window_size=(-1, -1), # 滑动窗口大小,(-1, -1)表示无限上下文窗口
    rotary_interleaved=True, # 是否交错旋转位置嵌入
    alibi_slopes=None, # 可选的注意力分数偏置,形状为 (nheads,) 或 (batch_size, nheads)
    num_splits=0, # 将 key/value 沿序列维度分割的数量,0 表示自动确定
):
    """
    该函数用于在推理(inference)过程中计算注意力层的输出,同时支持使用 key 和 value 缓存,以及旋转位置嵌入等技术。

    如果 k 和 v 不为 None,则会将它们的新值原地更新到 k_cache 和 v_cache 中。这对于增量解码很有用:
    你可以传入上一步的缓存 key/value,并使用当前步的新 key/value 进行更新,然后使用更新后的缓存进行注意力计算,所有操作都在一个内核中完成。

    如果你传入了 k/v,你必须确保缓存足够大,可以容纳新的值。例如,KV 缓存可以预先分配最大序列长度,并使用 cache_seqlens 跟踪每个序列在批次中的当前长度。

    如果传入了 rotary_cos 和 rotary_sin,则会应用旋转位置嵌入。key @k 将在索引 cache_seqlens、cache_seqlens + 1 等处被 rotary_cos 和 rotary_sin 旋转。
    如果是因果注意力或局部注意力(即 window_size != (-1, -1)),则查询 @q 将在索引 cache_seqlens、cache_seqlens + 1 等处被 rotary_cos 和 rotary_sin 旋转。
    如果既不是因果注意力也不是局部注意力,则查询 @q 将只在索引 cache_seqlens 处被 rotary_cos 和 rotary_sin 旋转(即我们认为 @q 中的所有标记都位于 cache_seqlens 位置)。

    该函数支持多查询和分组查询注意力(MQA/GQA),方法是将 KV 的头数量设置为少于 Q 的头数量。注意 Q 中的头数量必须能被 KV 中的头数量整除。
    例如,如果 Q 有 6 个头,而 K、V 有 2 个头,那么 Q 的第 0、1、2 个头将关注 K、V 的第 0 个头,而 Q 的第 3、4、5 个头将关注 K、V 的第 1 个头。

    如果 causal=True,则因果掩码对齐到注意力矩阵的右下角。例如,如果 seqlen_q = 2 且 seqlen_k = 5,则因果掩码(1 = 保留,0 = 掩码)为:
        1 1 1 1 0
        1 1 1 1 1
    如果 seqlen_q = 5 且 seqlen_k = 2,则因果掩码为:
        0 0
        0 0
        0 0
        1 0
        1 1
    如果掩码行全为 0,则输出也将为 0。

    如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 的查询将只关注位置在 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的 key。

    注意:不支持反向传播。

    参数:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) 如果没有 block_table,
            或 (num_blocks, page_block_size, nheads_k, headdim) 如果有 block_table (即分页 KV 缓存)
            page_block_size 必须是 256 的倍数。
        v_cache: 形状与 k_cache 相同
        k [可选]: (batch_size, seqlen_new, nheads_k, headdim)。如果不为 None,我们将它与 k_cache 拼接,从 cache_seqlens 指定的索引开始。
        v [可选]: (batch_size, seqlen_new, nheads_k, headdim)。与 k 类似。
        rotary_cos [可选]: (seqlen_ro, rotary_dim / 2)。如果不为 None,我们将对 k 和 q 应用旋转位置嵌入。只在 k 和 v 被传入时适用。rotary_dim 必须能被 16 整除。
        rotary_sin [可选]: (seqlen_ro, rotary_dim / 2)。与 rotary_cos 类似。
        cache_seqlens: int 或 (batch_size,), 数据类型 torch.int32。KV 缓存的序列长度。
        block_table [可选]: (batch_size, max_num_blocks_per_seq), 数据类型 torch.int32。
        cache_batch_idx: (batch_size,), 数据类型 torch.int32。用于索引 KV 缓存的索引。如果为 None,我们假设批次索引为 [0, 1, 2, ..., batch_size - 1]。如果索引不是唯一的,并且提供了 k 和 v,那么缓存中更新的值可能来自任何重复的索引。
        softmax_scale: float。QK^T 在应用 softmax 之前的缩放系数。默认为 1 / sqrt(headdim)。
        causal: bool。是否应用因果注意力掩码(例如用于自回归建模)。
        window_size: (left, right)。如果不是 (-1, -1),则实现滑动窗口局部注意力。
        rotary_interleaved: bool。只在传入 rotary_cos 和 rotary_sin 时适用。如果为 True,旋转位置嵌入将组合维度 0 & 1、2 & 3 等。如果为 False,旋转位置嵌入将组合维度 0 & rotary_dim / 2、1 & rotary_dim / 2 + 1(即 GPT-NeoX 风格)。
        alibi_slopes: (nheads,) 或 (batch_size, nheads), fp32。将 (-alibi_slope * |i + seqlen_k - seqlen_q - j|) 的偏置加到查询 i 和 key j 的注意力分数上。
        num_splits: int。如果大于 1,则将 key/value 沿序列维度分割成这么多块。如果 num_splits == 1,我们不分割 key/value。如果 num_splits == 0,我们使用启发式方法自动确定分割数量。除非你知道你在做什么,否则不要更改这个参数。

    返回:
        out: (batch_size, seqlen, nheads, headdim)。
    """
    # 确保 k_cache 和 v_cache 的最后一维是连续的
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    
    # 如果 q、k 或 v 的最后一维不是连续的,则对它们进行连续化
    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    
    # 如果没有指定 softmax 缩放系数,则使用默认值 1 / sqrt(headdim)
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    
    # 如果 cache_seqlens 是整数,则将其转换为张量
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    
    # 确保 cache_batch_idx 和 block_table 是连续的
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    
    # 调用 CUDA 内核计算注意力输出
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        None,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        rotary_interleaved,
        num_splits,
    )
    
    return out
  2.5、flash_attn_func

该函数 flash_attn_func 用于计算注意力层的输出,是一个高效的注意力计算函数,支持多种先进的注意力技术。

函数的主要特点和技术如下:

  1. 多查询和分组查询注意力 (MQA/GQA):函数支持将 q 的头数设置为 kv 头数的整数倍,实现多查询和分组查询注意力。这种技术可以提高计算效率。

  2. 因果注意力掩码 (Causal Attention Mask):如果设置 causal=True,函数会应用因果注意力掩码,确保模型只关注当前位置之前的输出,实现自回归(auto-regressive)特性。

  3. 滑动窗口局部注意力 (Sliding Window Local Attention):如果设置 window_size != (-1, -1),函数会实现滑动窗口局部注意力,对于每个查询,只关注一定窗口范围内的 key。这可以减少计算开销,适用于一些特定任务。

  4. 注意力分数偏置 (Alibi Slopes):如果提供了 alibi_slopes,函数会为每个查询-key 对的注意力分数加上一个与位置相关的偏置项。这是一种正则化技术,可以改善注意力模型的性能。

  5. 确定性反向传播 (Deterministic Backward Pass):函数支持使用确定性反向传播实现,虽然稍慢但使用更多内存,正向传播始终是确定性的。

  6. 返回注意力概率:如果设置 return_attn_probs=True,函数会返回注意力概率,但这只用于测试,返回的概率可能由于缩放问题而不准确。

  7. CUDA 内核加速:函数的核心计算由 CUDA 内核 FlashAttnFunc.apply 完成,提供高性能的并行计算能力。

总的来说,flash_attn_func 函数集成了多种先进的注意力计算技术,可以高效地计算注意力层的输出,支持各种用途和优化方式。

def flash_attn_func(
    q, # 查询张量,形状为 (batch_size, seqlen, nheads, headdim)
    k, # key 张量,形状为 (batch_size, seqlen, nheads_k, headdim)
    v, # value 张量,形状为 (batch_size, seqlen, nheads_k, headdim)
    dropout_p=0.0, # dropout 概率,评估时应设置为 0.0
    softmax_scale=None, # softmax 缩放系数,默认为 1 / sqrt(headdim)
    causal=False, # 是否应用因果注意力掩码,例如用于自回归建模
    window_size=(-1, -1), # 滑动窗口大小,(-1, -1) 表示无限上下文窗口
    alibi_slopes=None, # 注意力分数偏置,形状为 (nheads,) 或 (batch_size, nheads)
    deterministic=False, # 是否使用确定性反向传播实现,稍慢但使用更多内存,正向传播始终是确定性的
    return_attn_probs=False, # 是否返回注意力概率,仅用于测试,返回的概率可能由于缩放问题而不准确
):
    """
    该函数用于计算注意力层的输出。

    支持多查询和分组查询注意力(MQA/GQA),方法是将 KV 的头数量设置为少于 Q 的头数量。
    注意 Q 中的头数量必须能被 KV 中的头数量整除。
    例如,如果 Q 有 6 个头,而 K、V 有 2 个头,那么 Q 的第 0、1、2 个头将关注 K、V 的第 0 个头,而 Q 的第 3、4、5 个头将关注 K、V 的第 1 个头。

    如果 causal=True,则因果掩码对齐到注意力矩阵的右下角。
    例如,如果 seqlen_q = 2 且 seqlen_k = 5,则因果掩码(1 = 保留,0 = 掩码)为:
        1 1 1 1 0
        1 1 1 1 1
    如果 seqlen_q = 5 且 seqlen_k = 2,则因果掩码为:
        0 0
        0 0
        0 0
        1 0
        1 1
    如果掩码行全为 0,则输出也将为 0。

    如果 window_size != (-1, -1),则实现滑动窗口局部注意力。
    位置 i 的查询将只关注位置在 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的 key。

    参数:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float。dropout 概率。
        softmax_scale: float。QK^T 在应用 softmax 之前的缩放系数。默认为 1 / sqrt(headdim)。
        causal: bool。是否应用因果注意力掩码(例如用于自回归建模)。
        window_size: (left, right)。如果不是 (-1, -1),则实现滑动窗口局部注意力。
        alibi_slopes: (nheads,) 或 (batch_size, nheads), fp32。将 (-alibi_slope * |i + seqlen_k - seqlen_q - j|) 的偏置加到查询 i 和 key j 的注意力分数上。
        deterministic: bool。是否使用确定性反向传播实现,稍慢但使用更多内存。正向传播始终是确定性的。
        return_attn_probs: bool。是否返回注意力概率。这个选项仅用于测试。返回的概率可能由于缩放问题而不准确。

    返回:
        out: (batch_size, seqlen, nheads, headdim)。
        softmax_lse [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen)。每行 QK^T * 缩放系数的 logsumexp (例如,softmax 归一化因子的对数)。
        S_dmask [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen)。softmax 的输出(可能有不同的缩放)。它还编码了 dropout 模式(负值表示该位置被丢弃,非负值表示被保留)。
    """
    # 调用 FlashAttnFunc 类的 apply 方法计算注意力输出
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1639040.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

1. 深度学习笔记--神经网络中常见的激活函数

1. 介绍 每个激活函数的输入都是一个数字,然后对其进行某种固定的数学操作。激活函数给神经元引入了非线性因素,如果不用激活函数的话,无论神经网络有多少层,输出都是输入的线性组合。激活函数的意义在于它能够引入非线性特性&am…

魔方阵(C语言)

一、魔方阵规律; 8 1 6 3 5 7 4 9 2 魔方阵中各数的排列规律如下: (1)将1放在第1行中间一列。 (2)从2开始直到nn止,各数依次按此规则存放:每一个数存放的行比前一个数的行数减1,列数加1(例如上…

Go 语言基础(二)【数组、切片、指针、map、struct】

1、数组 特别需要注意的是:在 Go 语言中,数组长度也是数组类型的一部分!所以尽管元素类型相同但是长度不同的两个数组,它们的类型并不相同。 1.1、数组的初始化 1.1.1、通过初始化列表{}来设置值 var arr [3]int // int类型的数…

HSDB使用教程

HSDB:Hostspot Debugger,JVM内置的工具,用于深入分析JVM运行时的内部状态 启动HSDB java -cp D:/tools/jdk-1.8/lib/sa-jdi.jar sun.jvm.hotspot.HSDB 获取进程id jps 连接到指定进程 查找类 通过查询查找对象 输入查询语句 select d from …

Linux 学习 --- 编辑 vi 命令

1、vi 基本概念(了解) 基本上 vi 可以分为三种状态,分别是命令模式 (command mode)、插入模式 (Insert mode) 和底行模式 (last line mode),各模式的功能区分如下: 命令行模式 command mode)  控制屏幕光标的移动&a…

「笔试刷题」:字母收集

一、题目 描述 有一个 𝑛∗𝑚 的矩形方阵,每个格子上面写了一个小写字母。 小红站在矩形的左上角,她每次可以向右或者向下走,走到某个格子上就可以收集这个格子的字母。 小红非常喜欢 "love" 这四个字母。…

kubernetes中使用ELK进行日志收集

目录 一、需要收集哪些日志 1、kubernetes集群的系统组件日志 2、应用日志 二、日志收集方案ELK 1、收集日志:Logstash 2、存储日志:Elasticsearch 3、展示日志:Kibana 三、安装elk 1、下载安装包 2、创建用户并切换到新用户 3、上…

【Web】CTFSHOW 中期测评刷题记录(1)

目录 web486 web487 web488 web489 web490 web491 web492 web493 web494 web495 web496 web497 web498 web499 web500 web501 web502 web503 web505 web506 web507 web508 web509 web510 web486 扫目录 初始界面尝试文件包含index.php&am…

ubuntu与redhat的不同之处

华子目录 什么是ubuntu概述 ubuntu版本简介桌面版服务器版 安装部署部署后的设置设置root密码关闭防火墙启用允许root进行ssh登录更改apt源安装所需软件 网络配置Netplan概述配置详解配置文件DHCP静态IP设置设置 软件安装方法apt安装软件作用常用命令配置apt源 deb软件包安装概…

基于React实现B站评论区

今天继续来学习一下React,使用React实现B站评论区,如下图: 在使用React开发类似B站评论区的功能时,我们需要考虑以下几个关键点来构建一个基本的评论系统: 1. 设计组件结构 首先,设计组件结构是关键。至少…

【数据结构】:链表的带环问题

🎁个人主页:我们的五年 🔍系列专栏:数据结构 🌷追光的人,终会万丈光芒 前言: 链表的带环问题在链表中是一类比较难的问题,它对我们的思维有一个比较高的要求,但是这一类…

ThreeJS:Mesh网格与三维变换

Mesh网格 ThreeJS中,Mesh表示基于以三角形为多边形网格(polygon mesh)的物体的类,同时也作为其它类的基类。 通过Mesh网格,我们可以组合Geometry几何体与Material材质属性,在3D世界中,定义一个物体。例如:之…

Unity ParticleSystem 入门

概述 在项目的制作过程成,一定少不了粒子系统的使用吧,如果你想在项目粒子效果,那这部分的内容一定不要错过喔!我添加了理解和注释更好理解一点! 这次的内容比较多,右侧有目录,可以帮助快速导…

高中数学:三角函数公式汇总及推导

一、定义 常用三角函数值 参考: 三角函数定义 二、基本三角函数及相互关系 sinx cosx tanx cscx secx cotx 函数间相互关系 参考: cosx、sinx、tanx的函数图像与性质 secx、cscx、cotx函数图像及相关关系 三、诱导公式 口诀:奇变…

通信接口——时钟和信号

前言 所有接口只要抓住三个核心点就能分清:时钟同步和异步,时钟的来源,信号的传输方向。 一、时钟同步和异步 接口之间的交互方式存在多种形式,如果按照是否有公共时钟CLK的参与,可以分为同步传输和异步传输。 同步&…

C语言——队列的实现

队列按照先进先出(FIFO,First In First Out)的原则管理数据。这意味着最先进入队列的元素会被最先移出,类似于排队等候服务的情况。队列通常有两个主要操作:入队(enqueue),将元素添加…

DRF返回值源码分析

DRF返回值源码分析 1 返回值 在视图中定义finalize_response方法(也可以用来判断是否异常) 自定义异常 配置文件 # settings.py REST_FRAMEWORK {"EXCEPTION_HANDLER": utils.handlers.exception_handler # 自定义的exceptional_handler路…

Sarcasm detection论文解析 |利用对话语境进行讽刺分析

论文地址: 论文地址:Sarcasm Analysis Using Conversation Context | Computational Linguistics | MIT Press github地址:https://github.com/debanjanghosh/sarcasm_context Alex-Fabbri/deep_learning_nlp_sarcasm: code for deep learnin…

基于springboot实现公司日常考勤系统项目【项目源码+论文说明】

基于springboot实现公司日常考勤系统演示 摘要 目前社会当中主要特征就是对于信息的传播比较快和信息内容的安全问题,原本进行办公的类型都耗费了很多的资源、传播的速度也是相对较慢、准确性不高等许多的不足。这个系统就是运用计算机软件来完成对于企业当中出勤率…

debug的基本使用

1.简介   首先看下IDEA中Debug模式下的界面。 如下是在IDEA中启动Debug模式,进入断点后的界面,我这里是Windows,可能和Mac的图标等会有些不一样。就简单说下图中标注的8个地方: ① 以Debug模式启动服务,左边的一个按…