关于 UNFUSED_PADDED_MHA 与 FUSED_MHA 的区别
- FUSED_MHA 使用另一种 kernel 执行方式(基于 3rdparty/trt_fused_multihead_attention 中的融合多头注意力实现,将在下一节说明)
- UNFUSED_PADDED_MHA 的 kernel 实现代码位于 src/fastertransformer/kernels/unfused_attention_kernels.cu
/// Multi-head-attention implementation chosen by getAttentionType().
/// The underlying type and enumerator values below are the language defaults
/// (int, 0..3), written out explicitly so they are visible at a glance.
enum class AttentionType : int {
    UNFUSED_MHA        = 0,  // unfused kernels; NOTE(review): presumably the padding-removed variant — confirm
    UNFUSED_PADDED_MHA = 1,  // unfused kernels (unfused_attention_kernels.cu); returned when remove_padding is false
    FUSED_MHA          = 2,  // fused kernel; returned when remove_padding is true
    FUSED_PADDED_MHA   = 3   // fused kernel; returned when remove_padding is false (non-causal models)
};
/* NOTE:
1. only swin-style relative position bias is supported currently
2. gpt-style (causal-mask) models support any-sequence-length fmha, so we don't need to call isValidSeqLen at run-time
3. bert/vit can also support any-seq-length fmha
*/
/**
 * @brief Choose which multi-head-attention implementation to run.
 *
 * A fused kernel is only considered when T is `half` and `is_fuse` is true.
 * In every other case, and whenever no fused kernel supports the given
 * configuration, the unfused kernels are selected.
 *
 * @tparam T  element type; only `half` can select a fused path.
 * @param size_per_head  per-head hidden dimension; must be one of the sizes listed below for fusion.
 * @param sm  SM version, compared against the kSM_* constants.
 * @param remove_padding  true -> non-padded variant (FUSED_MHA / UNFUSED_MHA);
 *                        false -> padded variant.
 * @param max_seq_len  only consulted on the swin relative-position-bias path (must be <= 256).
 * @param is_fuse  set to false to disable fused kernels entirely.
 * @param with_swin_relative_position_bias  true when the model uses swin-style relative position bias.
 * @param causal_mask  true for GPT-style models; their fused path additionally requires the
 *                     environment variable FMHA_ENABLE to be "ON".
 * @return the selected AttentionType (always returns a value).
 */
template<typename T>
AttentionType getAttentionType(size_t     size_per_head,
                               const int  sm,
                               const bool remove_padding,
                               const int  max_seq_len,
                               const bool is_fuse                          = true,
                               const bool with_swin_relative_position_bias = false,
                               const bool causal_mask                      = false)
{
    if (std::is_same<T, half>::value && is_fuse) {
        // Bert/Vit
        if (!causal_mask) {
            if (!with_swin_relative_position_bias
                && (((sm == kSM_70 || sm == kSM_72) && size_per_head == 64)
                    || ((sm == kSM_75 || sm == kSM_80 || sm == kSM_86)
                        && (size_per_head == 64 || size_per_head == 32)))) {
                return remove_padding ? AttentionType::FUSED_MHA : AttentionType::FUSED_PADDED_MHA;
            }
            else if (with_swin_relative_position_bias && (sm == kSM_75 || sm == kSM_80 || sm == kSM_86)
                     && max_seq_len <= 256 && size_per_head == 32) {
                return remove_padding ? AttentionType::FUSED_MHA : AttentionType::FUSED_PADDED_MHA;
            }
        }
        // GPT and its variants
        else {
            // FMHA_ENABLE only affects gpt-style models (causal-mask)
            const char* fmha_enable = std::getenv("FMHA_ENABLE");
            if (fmha_enable != nullptr && std::string(fmha_enable) == "ON") {
                if ((sm == kSM_70 || sm == kSM_72 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89)
                    && (size_per_head == 32 || size_per_head == 40 || size_per_head == 64 || size_per_head == 80
                        || size_per_head == 128 || size_per_head == 144 || size_per_head == 160
                        || size_per_head == 256)) {
                    // Padded GPT input has no fused kernel here, so it uses the unfused padded path.
                    return remove_padding ? AttentionType::FUSED_MHA : AttentionType::UNFUSED_PADDED_MHA;
                }
            }
        }
    }
    // No fused kernel applies: fall back to the unfused kernels. Without this
    // return, control would flow off the end of a non-void function (UB).
    return remove_padding ? AttentionType::UNFUSED_MHA : AttentionType::UNFUSED_PADDED_MHA;
}
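- 如果想执行 FUSED_MHA,需要满足上面 getAttentionType 中的条件:数据类型为 half(FP16)、is_fuse = true、remove_padding = true,并且 sm 与 size_per_head 的组合在支持列表中(对于带 causal_mask 的 GPT 类模型,还需要设置环境变量 FMHA_ENABLE=ON)。
- 具体参数设置可参考:https://github.com/NVIDIA/FasterTransformer/blob/main/docs/bert_guide.md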
因此,核函数的定义与调用需要从 forward 部分开始追踪:
https://github1s.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/bert/Bert.cc#L494
这里调用了 FusedAttentionLayer 的前向传播函数(forward)。
在前向传播函数的融合(fused)分支中,
Dispatcher_fp16 是指向 MHARunner 类型的指针,
通过 .reset() 绑定具体的实现对象,从而实现多态;
最终调用的是 pimpl->run。
指针 pimpl 所对应的内部类定义在:
https://github1s.com/NVIDIA/FasterTransformer/blob/main/3rdparty/trt_fused_multihead_attention/qkvToContext.cu#L62