目录
一、flash attention官方
1.1、flash attention安装
二、flash attention 常见函数
2.1、flash_attn_varlen_qkvpacked_func
2.2、flash_attn_varlen_kvpacked_func
2.3、flash_attn_varlen_func
2.4、flash_attn_with_kvcache
2.5、flash_attn_func
一、flash attention官方
版本: flash-attn 2.5.7
flash-attention/flash_attn/flash_attn_interface.py at main · Dao-AILab/flash-attention · GitHubFast and memory-efficient exact attention. Contribute to Dao-AILab/flash-attention development by creating an account on GitHub.https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py
1.1、flash attention安装(下文1.2节)
Trl SFT: llama2-7b-hf使用QLora 4bit量化后ds zero3加上flash atten v2单机多卡训练(笔记)_unsloth 多卡-CSDN博客文章浏览阅读812次,点赞18次,收藏25次。第三 参考官方命令: https://github.com/Dao-AILab/flash-attention。第一 确保 linux "外界"的 cuda版本 与 conda 虚拟环境中cuda版本一致。第二 安装好 c++ g++ ninja。_unsloth 多卡https://blog.csdn.net/qq_16555103/article/details/137677561
二、flash attention 常见函数
2.1、flash_attn_varlen_qkvpacked_func
输入前、输出后需要使用unpad、pad
该函数
flash_attn_varlen_qkvpacked_func
用于计算可变长序列的注意力输出,其中 query、key 和 value 已被打包成一个张量。该函数的主要作用是:
高效计算注意力输出:通过将 query、key 和 value 打包成一个张量作为输入,避免了显式连接 Q、K、V 的梯度,从而提高了计算效率。
支持变长序列:函数通过
cu_seqlens
参数接收每个序列的累积长度,可以有效处理变长序列的情况。支持多种注意力模式:函数支持因果注意力掩码(用于自回归建模)、滑动窗口局部注意力(只关注特定范围内的 key)和添加注意力分数偏置等功能。
提供确定性反向传播选项:可以选择使用确定性反向传播实现,虽然稍慢但使用更多内存,保证了结果的确定性。
返回注意力概率(仅用于测试):可以选择返回注意力概率,但这些概率可能不具有正确的缩放,仅用于测试目的。
该函数的输入参数包括 query、key 和 value 张量、序列长度信息、Dropout 率、softmax 缩放因子、注意力模式选项等。输出则是注意力层的输出张量,以及可选的注意力概率和 softmax 归一化因子。
该函数利用了 PyTorch 的自定义 CUDA 扩展,提供了高效的注意力计算能力,同时支持了多种注意力模式和变长序列输入。
def flash_attn_varlen_qkvpacked_func(
qkv, # query、key 和 value 张量,形状为 (total, 3, nheads, headdim),其中 total 是批次中 token 的总数,3 表示 query、key 和 value 被打包在一起。例如,如果批次大小为 2,序列长度分别为 3 和 5,头数为 4,head 维度为 64,则 qkv 的形状为 (8, 3, 4, 64)
cu_seqlens, # 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,用于索引 qkv。例如,如果批次大小为 2,序列长度分别为 3 和 5,则 cu_seqlens 为 [0, 3, 8]
max_seqlen, # 批次中序列的最大长度,整数值。例如,如果批次中的序列长度分别为 3 和 5,则 max_seqlen 为 5
dropout_p=0.0, # dropout 概率,在评估(evaluation)时应设置为 0.0,以保留所有神经元的输出,防止信息损失。通常在训练时使用一个小于 1 的值,如 0.1,而在评估时设置为 0.0
softmax_scale=None, # 在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim)。如果 headdim 为 64,则默认缩放因子为 1 / sqrt(64) ≈ 0.126
causal=False, # 是否应用因果注意力掩码(用于自回归建模)。如果设置为 True,则查询只能关注之前的输出,无法关注未来的输出,这在语言模型等自回归任务中很有用
window_size=(-1, -1), # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限上下文窗口,即不进行任何窗口限制。如果设置为 (2, 2),则查询位置 i 只能关注位置 [i-2, i+2] 范围内的键
alibi_slopes=None, # 一种用于注意力分数偏置的方法,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32。例如,如果 nheads 为 4,则 alibi_slopes 可以是形状为 (4,) 的张量
deterministic=False, # 是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的。通常在评估时不需要使用确定性实现,因为评估时不需要计算梯度
return_attn_probs=False, # 是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确,因为在实际应用中,通常不需要直接访问注意力概率,而是更关注注意力层的输出
):
"""
dropout_p 在评估(evaluation)时应该设置为 0.0,以保留所有神经元的输出,防止信息损失。
如果 query、key 和 value 已经打包成一个张量,调用这个函数会比调用 flash_attn_varlen_func 更快,因为反向传播避免了显式拼接 query、key 和 value 的梯度,从而减少了内存复制和计算量。
对于多查询注意力(MQA)和分组查询注意力(GQA),请参见 flash_attn_varlen_kvpacked_func 和 flash_attn_varlen_func。
如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 处的查询只会注意到位置在 [i - window_size[0], i + window_size[1]] 范围内(包括边界)的键。
参数说明:
qkv: (total, 3, nheads, headdim),其中 total 是批次中 token 的总数,3 表示 query、key 和 value 被打包在一起
cu_seqlens: (batch_size + 1,),数据类型为 torch.int32,表示批次中每个序列的累积长度,用于索引 qkv
max_seqlen: 整数,批次中序列的最大长度
dropout_p: 浮点数,dropout 概率
softmax_scale: 浮点数,在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim)
causal: 布尔值,是否应用因果注意力掩码(用于自回归建模)
window_size: (left, right),整数元组,如果不是 (-1, -1),则实现滑动窗口局部注意力
alibi_slopes: (nheads,) 或 (batch_size, nheads),fp32 张量,一种用于注意力分数偏置的方法
deterministic: 布尔值,是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的
return_attn_probs: 布尔值,是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确
返回值:
out: (total, nheads, headdim),注意力层的输出
softmax_lse [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen),每行的 QK^T * scaling 的 logsumexp(即 softmax 归一化因子的对数)
S_dmask [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen),softmax 的输出(可能有不同的缩放),它也编码了 dropout 模式(负值表示该位置被 dropout,非负值表示被保留)
"""
return FlashAttnVarlenQKVPackedFunc.apply(
qkv, # query、key 和 value 张量
cu_seqlens, # 序列的累积长度
max_seqlen, # 批次中序列的最大长度
dropout_p, # dropout 概率
softmax_scale, # softmax 缩放因子
causal, # 是否应用因果注意力掩码
window_size, # 滑动窗口大小,用于实现局部注意力
alibi_slopes, # 注意力分数偏置方法
deterministic, # 是否使用确定性反向传播实现
return_attn_probs, # 是否返回注意力概率(仅用于测试)
)
2.2、flash_attn_varlen_kvpacked_func
输入前、输出后需要使用unpad、pad
该函数
flash_attn_varlen_kvpacked_func
用于计算变长序列的注意力输出,其中 key 和 value 已被打包成一个张量。它是 Flash Attention 库中的一个函数,旨在高效地计算注意力输出,同时支持多查询注意力(MQA)和分组查询注意力(GQA)。该函数的主要作用是:
高效计算注意力输出:通过将 key 和 value 打包成一个张量作为输入,避免了显式连接 K、V 的梯度,从而提高了计算效率。
支持变长序列:函数通过
cu_seqlens_q
和cu_seqlens_k
参数接收每个序列的累积长度,可以有效处理变长序列的情况。支持多种注意力模式:函数支持因果注意力掩码(用于自回归建模)、滑动窗口局部注意力(只关注特定范围内的 key)和添加注意力分数偏置等功能。
支持多查询注意力(MQA)和分组查询注意力(GQA):通过将 key 和 value 的注意力头数设置为少于 query 的注意力头数,可以实现 MQA 和 GQA。例如,如果 query 有 6 个注意力头,key 和 value 有 2 个注意力头,那么 query 的头 0、1、2 将关注 key 和 value 的头 0,query 的头 3、4、5 将关注 key 和 value 的头 1。
提供确定性反向传播选项:可以选择使用确定性反向传播实现,虽然稍慢但使用更多内存,保证了结果的确定性。
返回注意力概率(仅用于测试):可以选择返回注意力概率,但这些概率可能不具有正确的缩放,仅用于测试目的。
该函数的输入参数包括 query、key 和 value 张量、序列长度信息、Dropout 率、softmax 缩放因子、注意力模式选项等。输出则是注意力层的输出张量,以及可选的注意力概率和 softmax 归一化因子。
该函数利用了 PyTorch 的自定义 CUDA 扩展,提供了高效的注意力计算能力,同时支持了多种注意力模式、变长序列输入和 MQA/GQA 等特性,在自然语言处理等领域具有广泛的应用。
def flash_attn_varlen_kvpacked_func(
q, # query 张量,形状为 (total_q, nheads, headdim),例如 (1024, 16, 64),其中 total_q=1024 是批次中 query token 的总数,nheads=16 是注意力头数,headdim=64 是每个注意力头的维度
kv, # key-value 张量,形状为 (total_k, 2, nheads_k, headdim),例如 (2048, 2, 8, 64),其中 total_k=2048 是批次中 key token 的总数,2 表示 key 和 value 被打包在一起,nheads_k=8 是 key 和 value 的注意力头数,headdim=64 是每个注意力头的维度。注意,nheads_k 可以小于 nheads,这支持了多查询注意力(MQA)和分组查询注意力(GQA)的用法
cu_seqlens_q, # query 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,例如 [0, 10, 20, 32, 42],表示批次中有 4 个序列,第一个序列长度为 10,第二个序列长度为 10,第三个序列长度为 12,第四个序列长度为 10。这些累积长度用于索引 q 张量
cu_seqlens_k, # key 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,例如 [0, 15, 30, 45, 60],用于索引 kv 张量
max_seqlen_q, # 批次中 query 序列的最大长度,例如 12,表示批次中最长的 query 序列长度为 12
max_seqlen_k, # 批次中 key 序列的最大长度,例如 15,表示批次中最长的 key 序列长度为 15
dropout_p=0.0, # dropout 概率,在评估(evaluation)时应设置为 0.0,以保留所有神经元的输出,防止信息损失。在训练(training)时,可以设置一个小于 1 的值,例如 0.1,表示将有 10% 的神经元被随机丢弃
softmax_scale=None, # 在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim),例如如果 headdim=64,则默认缩放因子为 1 / sqrt(64) ≈ 0.126,这是一种常见的缩放方式,可以使注意力分数的方差保持在合理范围内,防止出现较大的梯度
causal=False, # 是否应用因果注意力掩码(用于自回归建模),如果设置为 True,则查询只能关注之前的输出,无法关注未来的输出,这在诸如语言模型等自回归任务中非常有用
window_size=(-1, -1), # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限上下文窗口,即不进行任何窗口限制。如果设置为 i, 关注窗口为 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的键
alibi_slopes=None, # 一种用于注意力分数偏置的方法,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32。例如,如果 nheads=16,则 alibi_slopes 可以是形状为 (16,) 的张量,每个元素对应一个注意力头的偏置斜率。如果设置了该参数,则会对查询 i 和键 j 之间的注意力分数加上一个偏置 (-alibi_slope * |i + seqlen_k - seqlen_q - j|),这种偏置可以鼓励模型关注更靠近查询的键,或者更远离查询的键,具体取决于 alibi_slopes 的值
deterministic=False, # 是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的。在评估(evaluation)时,通常不需要使用确定性实现,因为评估时不需要计算梯度
return_attn_probs=False, # 是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确,因为在实际应用中,通常不需要直接访问注意力概率,而是更关注注意力层的输出
):
"""
dropout_p 在评估(evaluation)时应该设置为 0.0,以保留所有神经元的输出,防止信息损失。
如果 K 和 V 已经打包成一个张量,调用这个函数会比调用 flash_attn_func 更快,因为反向传播避免了显式拼接 K 和 V 的梯度,从而减少了内存复制和计算量。
支持多查询注意力(MQA)和分组查询注意力(GQA),只需将 KV 的头数设置为少于 Q 的头数即可。注意,Q 的头数必须能被 KV 的头数整除。
例如,如果 Q 有 6 个头,而 K 和 V 有 2 个头,那么 Q 的头 0、1、2 将注意到 K 和 V 的头 0,而 Q 的头 3、4、5 将注意到 K 和 V 的头 1。
这种机制使得模型可以在不同的头组之间共享计算资源,从而提高计算效率。
如果 causal=True,则因果掩码对齐到注意力矩阵的右下角。这种掩码模式确保了在自回归(auto-regressive)建模中,查询只能关注之前的输出,无法关注未来的输出。
例如,如果 seqlen_q = 2 且 seqlen_k = 5,则因果掩码(1 = 保留,0 = 掩码)为:
1 1 1 1 0
1 1 1 1 1
如果 seqlen_q = 5 且 seqlen_k = 2,则因果掩码为:
0 0
0 0
0 0
1 0
1 1
如果掩码的一行全为零,则该查询的输出也将为零,因为它无法关注任何有效的键。
如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 处的查询只会注意到位置在
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的键。
例如,如果 seqlen_q = 3, seqlen_k = 5, window_size = (1, 2),则查询位置 0 只能关注键位置 [0, 3],查询位置 1 只能关注键位置 [1, 4],查询位置 2 只能关注键位置 [2, 4]。
这种局部注意力机制可以显著提高计算效率,尤其是在处理长序列时,因为它减少了需要计算的注意力分数的数量。但同时也会牺牲一些表达能力,因为查询无法关注整个序列。
"""
# 返回值:
# out: (total, nheads, headdim),注意力层的输出,例如 (1024, 16, 64)
# softmax_lse [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen),每行的 QK^T * scaling 的 logsumexp(即 softmax 归一化因子的对数),例如 (4, 16, 12)
# S_dmask [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen),softmax 的输出(可能有不同的缩放),它也编码了 dropout 模式(负值表示该位置被 dropout,非负值表示被保留),例如 (4, 16, 12, 15)
return FlashAttnVarlenKVPackedFunc.apply(
q, # query 张量
kv, # 打包的 key-value 张量
cu_seqlens_q, # query 序列的累积长度
cu_seqlens_k, # key 序列的累积长度
max_seqlen_q, # 批次中 query 序列的最大长度
max_seqlen_k, # 批次中 key 序列的最大长度
dropout_p, # dropout 概率
softmax_scale, # softmax 缩放因子
causal, # 是否应用因果注意力掩码
window_size, # 滑动窗口大小,用于实现局部注意力
alibi_slopes, # 注意力分数偏置方法
deterministic, # 是否使用确定性反向传播实现
return_attn_probs, # 是否返回注意力概率(仅用于测试)
)
"""
一个具体的 GQA 示例:
假设我们有一个批次,其中包含 4 个序列,每个序列的长度分别为 10、10、12 和 10。
我们希望使用 GQA,其中 query 有 6 个头(nheads=6),而 key 和 value 有 2 个头(nheads_k=2)。
每个注意力头的维度为 64(headdim=64)。
在这种情况下,各个参数的值如下:
q: (total_q, nheads, headdim) = (42, 6, 64)
total_q = 10 + 10 + 12 + 10 = 42,是批次中所有序列的 token 总数
nheads = 6,表示 query 有 6 个注意力头
headdim = 64,表示每个注意力头的维度为 64
kv: (total_k, 2, nheads_k, headdim) = (42, 2, 2, 64)
total_k = 10 + 10 + 12 + 10 = 42,与 total_q 相同,因为每个序列的 query 长度和 key 长度是相同的
2 表示 key 和 value 被打包在一起
nheads_k = 2,表示 key 和 value 有 2 个注意力头
headdim = 64,表示每个注意力头的维度为 64
cu_seqlens_q: [0, 10, 20, 32, 42],表示批次中每个序列的累积长度,用于索引 q
第一个序列的长度为 10,第二个序列的长度为 10,第三个序列的长度为 12,第四个序列的长度为 10
cu_seqlens_k: [0, 10, 20, 32, 42],与 cu_seqlens_q 相同,因为每个序列的 query 长度和 key 长度是相同的
max_seqlen_q: 12,批次中最长的 query 序列长度
max_seqlen_k: 12,与 max_seqlen_q 相同,因为每个序列的 query 长度和 key 长度是相同的
在这个 GQA 设置下,query 的 6 个头将被分成 2 组,每组 3 个头,分别关注 key 和 value 的 2 个头:
- query 的头 0 1 2 将关注 key 和 value 的头 0
- query 的头 3 4 5 将关注 key 和 value 的头 1
这种分组机制允许不同的 query 头关注不同的 key 和 value 子空间,提高了计算效率和表达能力。
同时,通过减少 key 和 value 的头数,可以显著降低计算和存储开销。
"""
2.3、flash_attn_varlen_func
输入前、输出后需要使用unpad、pad
该函数
flash_attn_varlen_func
用于计算变长序列的注意力输出,其中 query、key 和 value 是分开的张量。它是 Flash Attention 库中的一个函数,旨在高效地计算注意力输出,同时支持多查询注意力(MQA)和分组查询注意力(GQA)。该函数的主要作用是:
高效计算注意力输出:通过分开输入 query、key 和 value 张量,可以有效利用计算资源进行注意力计算。
支持变长序列:函数通过
cu_seqlens_q
和cu_seqlens_k
参数接收每个序列的累积长度,可以有效处理变长序列的情况。支持多种注意力模式:函数支持因果注意力掩码(用于自回归建模)、滑动窗口局部注意力(只关注特定范围内的 key)和添加注意力分数偏置等功能。
支持多查询注意力(MQA)和分组查询注意力(GQA):通过将 key 和 value 的注意力头数设置为少于 query 的注意力头数,可以实现 MQA 和 GQA。例如,如果 query 有 6 个注意力头,key 和 value 有 2 个注意力头,那么 query 的头 0、1、2 将关注 key 和 value 的头 0,query 的头 3、4、5 将关注 key 和 value 的头 1。
支持分块稀疏注意力:可以通过提供
block_table
参数来启用分块稀疏注意力,进一步提高计算效率。提供确定性反向传播选项:可以选择使用确定性反向传播实现,虽然稍慢但使用更多内存,保证了结果的确定性。
返回注意力概率(仅用于测试):可以选择返回注意力概率,但这些概率可能不具有正确的缩放,仅用于测试目的。
该函数的输入参数包括 query、key 和 value 张量、序列长度信息、Dropout 率、softmax 缩放因子、注意力模式选项等。输出则是注意力层的输出张量,以及可选的注意力概率和 softmax 归一化因子。
该函数利用了 PyTorch 的自定义 CUDA 扩展,提供了高效的注意力计算能力,同时支持了多种注意力模式、变长序列输入、MQA/GQA 和分块稀疏注意力等特性,在自然语言处理等领域具有广泛的应用。
def flash_attn_varlen_func(
q, # 输入的 query 张量,形状为 (total_q, nheads, headdim),其中 total_q 是批量中所有查询token的总数,nheads 是注意力头数,headdim 是每个注意力头的维度
k, # 输入的 key 张量,形状为 (total_k, nheads_k, headdim),其中 total_k 是批量中所有 key token的总数,nheads_k 是 key 的注意力头数,headdim 是每个注意力头的维度
v, # 输入的 value 张量,形状为 (total_k, nheads_k, headdim),其中 total_k 是批量中所有 key token的总数,nheads_k 是 key 的注意力头数,headdim 是每个注意力头的维度
cu_seqlens_q, # 批量中每个查询序列的累积长度,形状为 (batch_size + 1,),数据类型为 torch.int32,用于从 q 中索引相应的位置
cu_seqlens_k, # 批量中每个 key 序列的累积长度,形状为 (batch_size + 1,),数据类型为 torch.int32,用于从 k 和 v 中索引相应的位置
max_seqlen_q, # 批量中最大查询序列长度
max_seqlen_k, # 批量中最大 key 序列长度
dropout_p=0.0, # Dropout 率,在评估(evaluation)时应设置为 0.0
softmax_scale=None, # softmax 缩放因子,默认为 1 / sqrt(headdim)
causal=False, # 是否应用因果注意力掩码,用于自回归(auto-regressive)建模
window_size=(-1, -1), # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限制上下文窗口
alibi_slopes=None, # 用于添加注意力分数偏置,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32
deterministic=False, # 是否使用确定性反向传播实现,比非确定性实现稍慢但使用更多内存,前向传播始终是确定性的
return_attn_probs=False, # 是否返回注意力概率,仅用于测试,返回的概率可能不具有正确缩放
block_table=None # 可选的块表,用于分块稀疏注意力
):
"""
解决问题: 计算变长序列的注意力输出,其中 query、key 和 value 是分开的张量。支持多查询注意力(MQA)和分组查询注意力(GQA)。
注意事项:
- dropout_p 应在评估时设置为 0.0。
- 支持多查询注意力(MQA)和分组查询注意力(GQA),通过将 K、V 的注意力头数设置为少于 Q 的注意力头数来实现。Q 的注意力头数必须能被 K、V 的注意力头数整除。
例如,如果 Q 有 6 个注意力头,K、V 有 2 个注意力头,那么 Q 的头 0、1、2 将关注 K、V 的头 0,Q 的头 3、4、5 将关注 K、V 的头 1。
- 如果 causal=True,因果掩码将与注意力矩阵的右下角对齐。
例如,如果 seqlen_q = 2 且 seqlen_k = 5,因果掩码(1 = 保留,0 = 掩码)为:
1 1 1 1 0
1 1 1 1 1
如果 seqlen_q = 5 且 seqlen_k = 2,因果掩码为:
0 0
0 0
0 0
1 0
1 1
如果掩码的一行全为零,输出也将为零。
- 如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 的查询将只关注位于 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内的 key。
- 可以通过提供 block_table 参数来启用分块稀疏注意力。
返回值:
- out: 注意力层的输出张量,形状为 (total, nheads, headdim),其中 total = total_q。
- softmax_lse [可选,如果 return_attn_probs=True]: 每行的 QK^T * scaling 的 logsumexp 值,形状为 (batch_size, nheads, seqlen),即 softmax 归一化因子的对数。
- S_dmask [可选,如果 return_attn_probs=True]: softmax 的输出,可能具有不同的缩放,形状为 (batch_size, nheads, seqlen, seqlen),还编码了 Dropout 模式(负值表示该位置被丢弃,非负值表示该位置被保留)。
"""
return FlashAttnVarlenFunc.apply(
q, # 输入的 query 张量
k, # 输入的 key 张量
v, # 输入的 value 张量
cu_seqlens_q, # 批量中每个查询序列的累积长度
cu_seqlens_k, # 批量中每个 key 序列的累积长度
max_seqlen_q, # 批量中最大查询序列长度
max_seqlen_k, # 批量中最大 key 序列长度
dropout_p, # Dropout 率
softmax_scale, # softmax 缩放因子
causal, # 是否应用因果注意力掩码
window_size, # 用于实现滑动窗口局部注意力
alibi_slopes, # 用于添加注意力分数偏置
deterministic, # 是否使用确定性反向传播实现
return_attn_probs, # 是否返回注意力概率
block_table, # 可选的块表,用于分块稀疏注意力
)
2.4、flash_attn_with_kvcache
该函数
flash_attn_with_kvcache
用于在推理(inference)过程中计算注意力层的输出,同时支持使用 key 和 value 缓存,以及旋转位置嵌入等技术。它是一个高效的注意力计算函数,可以在推理时加速序列生成任务。函数的主要特点和技术如下:
支持更新 key 和 value 缓存:如果提供了新的
k
和v
张量,函数会将它们的值原地更新到k_cache
和v_cache
中。这对于增量解码非常有用,可以一次性完成缓存更新和注意力计算。旋转位置嵌入 (Rotary Position Embedding):如果提供了
rotary_cos
和rotary_sin
,函数会对q
、k
应用旋转位置嵌入。旋转位置嵌入是一种编码序列位置信息的方法,可以提高注意力模型在长序列任务中的性能。因果注意力掩码 (Causal Attention Mask):如果设置
causal=True
,函数会应用因果注意力掩码,确保模型只关注当前位置之前的输出,实现自回归(auto-regressive)特性。滑动窗口局部注意力 (Sliding Window Local Attention):如果设置
window_size != (-1, -1)
,函数会实现滑动窗口局部注意力,对于每个查询,只关注一定窗口范围内的 key。这可以减少计算开销,适用于一些特定任务。多查询和分组查询注意力 (MQA/GQA):函数支持将
q
的头数设置为k
、v
头数的整数倍,实现多查询和分组查询注意力。这种技术可以提高计算效率。分块 key/value 缓存:如果提供了
block_table
,函数会将k_cache
和v_cache
视为分页缓存,支持高效的缓存管理。注意力分数偏置 (Alibi Slopes):如果提供了
alibi_slopes
,函数会为每个查询-key 对的注意力分数加上一个与位置相关的偏置项。这是一种正则化技术,可以改善注意力模型的性能。CUDA 内核加速:函数的核心计算由 CUDA 内核
flash_attn_cuda.fwd_kvcache
完成,提供高性能的并行计算能力。总的来说,
flash_attn_with_kvcache
函数集成了多种先进的注意力计算技术,可以高效地
def flash_attn_with_kvcache(
q, # 查询张量,形状为 (batch_size, seqlen, nheads, headdim)
k_cache, # key 缓存张量,形状为 (batch_size_cache, seqlen_cache, nheads_k, headdim) 或 (num_blocks, page_block_size, nheads_k, headdim)
v_cache, # value 缓存张量,形状与 k_cache 相同
k=None, # 可选的新 key 张量,形状为 (batch_size, seqlen_new, nheads_k, headdim)
v=None, # 可选的新 value 张量,形状与 k 相同
rotary_cos=None, # 可选的旋转位置嵌入余弦值,形状为 (seqlen_ro, rotary_dim / 2)
rotary_sin=None, # 可选的旋转位置嵌入正弦值,形状与 rotary_cos 相同
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, # 缓存序列长度,可以是整数或张量
cache_batch_idx: Optional[torch.Tensor] = None, # 缓存批次索引张量,形状为 (batch_size,)
block_table: Optional[torch.Tensor] = None, # 可选的块表张量,形状为 (batch_size, max_num_blocks_per_seq)
softmax_scale=None, # softmax 缩放系数,默认为 1 / sqrt(headdim)
causal=False, # 是否进行因果注意力掩码
window_size=(-1, -1), # 滑动窗口大小,(-1, -1)表示无限上下文窗口
rotary_interleaved=True, # 是否交错旋转位置嵌入
alibi_slopes=None, # 可选的注意力分数偏置,形状为 (nheads,) 或 (batch_size, nheads)
num_splits=0, # 将 key/value 沿序列维度分割的数量,0 表示自动确定
):
"""
该函数用于在推理(inference)过程中计算注意力层的输出,同时支持使用 key 和 value 缓存,以及旋转位置嵌入等技术。
如果 k 和 v 不为 None,则会将它们的新值原地更新到 k_cache 和 v_cache 中。这对于增量解码很有用:
你可以传入上一步的缓存 key/value,并使用当前步的新 key/value 进行更新,然后使用更新后的缓存进行注意力计算,所有操作都在一个内核中完成。
如果你传入了 k/v,你必须确保缓存足够大,可以容纳新的值。例如,KV 缓存可以预先分配最大序列长度,并使用 cache_seqlens 跟踪每个序列在批次中的当前长度。
如果传入了 rotary_cos 和 rotary_sin,则会应用旋转位置嵌入。key @k 将在索引 cache_seqlens、cache_seqlens + 1 等处被 rotary_cos 和 rotary_sin 旋转。
如果是因果注意力或局部注意力(即 window_size != (-1, -1)),则查询 @q 将在索引 cache_seqlens、cache_seqlens + 1 等处被 rotary_cos 和 rotary_sin 旋转。
如果既不是因果注意力也不是局部注意力,则查询 @q 将只在索引 cache_seqlens 处被 rotary_cos 和 rotary_sin 旋转(即我们认为 @q 中的所有标记都位于 cache_seqlens 位置)。
该函数支持多查询和分组查询注意力(MQA/GQA),方法是将 KV 的头数量设置为少于 Q 的头数量。注意 Q 中的头数量必须能被 KV 中的头数量整除。
例如,如果 Q 有 6 个头,而 K、V 有 2 个头,那么 Q 的第 0、1、2 个头将关注 K、V 的第 0 个头,而 Q 的第 3、4、5 个头将关注 K、V 的第 1 个头。
如果 causal=True,则因果掩码对齐到注意力矩阵的右下角。例如,如果 seqlen_q = 2 且 seqlen_k = 5,则因果掩码(1 = 保留,0 = 掩码)为:
1 1 1 1 0
1 1 1 1 1
如果 seqlen_q = 5 且 seqlen_k = 2,则因果掩码为:
0 0
0 0
0 0
1 0
1 1
如果掩码行全为 0,则输出也将为 0。
如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 的查询将只关注位置在 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的 key。
注意:不支持反向传播。
参数:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) 如果没有 block_table,
或 (num_blocks, page_block_size, nheads_k, headdim) 如果有 block_table (即分页 KV 缓存)
page_block_size 必须是 256 的倍数。
v_cache: 形状与 k_cache 相同
k [可选]: (batch_size, seqlen_new, nheads_k, headdim)。如果不为 None,我们将它与 k_cache 拼接,从 cache_seqlens 指定的索引开始。
v [可选]: (batch_size, seqlen_new, nheads_k, headdim)。与 k 类似。
rotary_cos [可选]: (seqlen_ro, rotary_dim / 2)。如果不为 None,我们将对 k 和 q 应用旋转位置嵌入。只在 k 和 v 被传入时适用。rotary_dim 必须能被 16 整除。
rotary_sin [可选]: (seqlen_ro, rotary_dim / 2)。与 rotary_cos 类似。
cache_seqlens: int 或 (batch_size,), 数据类型 torch.int32。KV 缓存的序列长度。
block_table [可选]: (batch_size, max_num_blocks_per_seq), 数据类型 torch.int32。
cache_batch_idx: (batch_size,), 数据类型 torch.int32。用于索引 KV 缓存的索引。如果为 None,我们假设批次索引为 [0, 1, 2, ..., batch_size - 1]。如果索引不是唯一的,并且提供了 k 和 v,那么缓存中更新的值可能来自任何重复的索引。
softmax_scale: float。QK^T 在应用 softmax 之前的缩放系数。默认为 1 / sqrt(headdim)。
causal: bool。是否应用因果注意力掩码(例如用于自回归建模)。
window_size: (left, right)。如果不是 (-1, -1),则实现滑动窗口局部注意力。
rotary_interleaved: bool。只在传入 rotary_cos 和 rotary_sin 时适用。如果为 True,旋转位置嵌入将组合维度 0 & 1、2 & 3 等。如果为 False,旋转位置嵌入将组合维度 0 & rotary_dim / 2、1 & rotary_dim / 2 + 1(即 GPT-NeoX 风格)。
alibi_slopes: (nheads,) 或 (batch_size, nheads), fp32。将 (-alibi_slope * |i + seqlen_k - seqlen_q - j|) 的偏置加到查询 i 和 key j 的注意力分数上。
num_splits: int。如果大于 1,则将 key/value 沿序列维度分割成这么多块。如果 num_splits == 1,我们不分割 key/value。如果 num_splits == 0,我们使用启发式方法自动确定分割数量。除非你知道你在做什么,否则不要更改这个参数。
返回:
out: (batch_size, seqlen, nheads, headdim)。
"""
# 确保 k_cache 和 v_cache 的最后一维是连续的
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
# 如果 q、k 或 v 的最后一维不是连续的,则对它们进行连续化
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
# 如果没有指定 softmax 缩放系数,则使用默认值 1 / sqrt(headdim)
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
# 如果 cache_seqlens 是整数,则将其转换为张量
if cache_seqlens is not None and isinstance(cache_seqlens, int):
cache_seqlens = torch.full(
(k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
)
cache_seqlens = maybe_contiguous(cache_seqlens)
# 确保 cache_batch_idx 和 block_table 是连续的
cache_batch_idx = maybe_contiguous(cache_batch_idx)
block_table = maybe_contiguous(block_table)
# 调用 CUDA 内核计算注意力输出
out, softmax_lse = flash_attn_cuda.fwd_kvcache(
q,
k_cache,
v_cache,
k,
v,
cache_seqlens,
rotary_cos,
rotary_sin,
cache_batch_idx,
block_table,
alibi_slopes,
None,
softmax_scale,
causal,
window_size[0],
window_size[1],
rotary_interleaved,
num_splits,
)
return out
2.5、flash_attn_func
该函数
flash_attn_func
用于计算注意力层的输出,是一个高效的注意力计算函数,支持多种先进的注意力技术。函数的主要特点和技术如下:
多查询和分组查询注意力 (MQA/GQA):函数支持将
q
的头数设置为k
、v
头数的整数倍,实现多查询和分组查询注意力。这种技术可以提高计算效率。因果注意力掩码 (Causal Attention Mask):如果设置
causal=True
,函数会应用因果注意力掩码,确保模型只关注当前位置之前的输出,实现自回归(auto-regressive)特性。滑动窗口局部注意力 (Sliding Window Local Attention):如果设置
window_size != (-1, -1)
,函数会实现滑动窗口局部注意力,对于每个查询,只关注一定窗口范围内的 key。这可以减少计算开销,适用于一些特定任务。注意力分数偏置 (Alibi Slopes):如果提供了
alibi_slopes
,函数会为每个查询-key 对的注意力分数加上一个与位置相关的偏置项。这是一种正则化技术,可以改善注意力模型的性能。确定性反向传播 (Deterministic Backward Pass):函数支持使用确定性反向传播实现,虽然稍慢但使用更多内存,正向传播始终是确定性的。
返回注意力概率:如果设置
return_attn_probs=True
,函数会返回注意力概率,但这只用于测试,返回的概率可能由于缩放问题而不准确。CUDA 内核加速:函数的核心计算由 CUDA 内核
FlashAttnFunc.apply
完成,提供高性能的并行计算能力。总的来说,
flash_attn_func
函数集成了多种先进的注意力计算技术,可以高效地计算注意力层的输出,支持各种用途和优化方式。
def flash_attn_func(
q, # 查询张量,形状为 (batch_size, seqlen, nheads, headdim)
k, # key 张量,形状为 (batch_size, seqlen, nheads_k, headdim)
v, # value 张量,形状为 (batch_size, seqlen, nheads_k, headdim)
dropout_p=0.0, # dropout 概率,评估时应设置为 0.0
softmax_scale=None, # softmax 缩放系数,默认为 1 / sqrt(headdim)
causal=False, # 是否应用因果注意力掩码,例如用于自回归建模
window_size=(-1, -1), # 滑动窗口大小,(-1, -1) 表示无限上下文窗口
alibi_slopes=None, # 注意力分数偏置,形状为 (nheads,) 或 (batch_size, nheads)
deterministic=False, # 是否使用确定性反向传播实现,稍慢但使用更多内存,正向传播始终是确定性的
return_attn_probs=False, # 是否返回注意力概率,仅用于测试,返回的概率可能由于缩放问题而不准确
):
"""
该函数用于计算注意力层的输出。
支持多查询和分组查询注意力(MQA/GQA),方法是将 KV 的头数量设置为少于 Q 的头数量。
注意 Q 中的头数量必须能被 KV 中的头数量整除。
例如,如果 Q 有 6 个头,而 K、V 有 2 个头,那么 Q 的第 0、1、2 个头将关注 K、V 的第 0 个头,而 Q 的第 3、4、5 个头将关注 K、V 的第 1 个头。
如果 causal=True,则因果掩码对齐到注意力矩阵的右下角。
例如,如果 seqlen_q = 2 且 seqlen_k = 5,则因果掩码(1 = 保留,0 = 掩码)为:
1 1 1 1 0
1 1 1 1 1
如果 seqlen_q = 5 且 seqlen_k = 2,则因果掩码为:
0 0
0 0
0 0
1 0
1 1
如果掩码行全为 0,则输出也将为 0。
如果 window_size != (-1, -1),则实现滑动窗口局部注意力。
位置 i 的查询将只关注位置在 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的 key。
参数:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float。dropout 概率。
softmax_scale: float。QK^T 在应用 softmax 之前的缩放系数。默认为 1 / sqrt(headdim)。
causal: bool。是否应用因果注意力掩码(例如用于自回归建模)。
window_size: (left, right)。如果不是 (-1, -1),则实现滑动窗口局部注意力。
alibi_slopes: (nheads,) 或 (batch_size, nheads), fp32。将 (-alibi_slope * |i + seqlen_k - seqlen_q - j|) 的偏置加到查询 i 和 key j 的注意力分数上。
deterministic: bool。是否使用确定性反向传播实现,稍慢但使用更多内存。正向传播始终是确定性的。
return_attn_probs: bool。是否返回注意力概率。这个选项仅用于测试。返回的概率可能由于缩放问题而不准确。
返回:
out: (batch_size, seqlen, nheads, headdim)。
softmax_lse [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen)。每行 QK^T * 缩放系数的 logsumexp (例如,softmax 归一化因子的对数)。
S_dmask [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen)。softmax 的输出(可能有不同的缩放)。它还编码了 dropout 模式(负值表示该位置被丢弃,非负值表示被保留)。
"""
# 调用 FlashAttnFunc 类的 apply 方法计算注意力输出
return FlashAttnFunc.apply(
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)