FasterTransformer 004 open_attention.h forward

news2024/11/20 3:25:51

initialize

在这里插入图片描述

forward()

  • https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/cuda/open_attention.h#L149-L217

在这里插入图片描述

  • 使用cuBLAS库执行矩阵乘法运算,并对cublasGemmEx()进行三个单独的调用。这些操作包括将属性核与输入张量相乘,并添加偏差项,从而生成查询、键和值矩阵。
    在这些矩阵相乘之后,该函数使用sqrtf()函数计算缩放因子,并将查询、键和值矩阵以及该缩放器传递给另一个名为multiHeadAttr_nofuse_kernelLauncher()的函数。此函数可能会使用额外的计算将多头注意力应用于查询、键和值矩阵,以生成输出矩阵param_.attr_out。
    最后,forward()函数捕获执行过程中抛出的任何运行时错误,并重新抛出它们。

cublasGemmEx *3

check_cuda_error(cublasGemmEx(param_.cublas_handle, 
        CUBLAS_OP_N, CUBLAS_OP_N, 
        n, m, k, 
        &alpha, 
        param_.attr_kernel_Q, AType_, n, 
        param_.from_tensor, BType_, k, 
        &beta, 
        query_buf_, CType_, n, 
        computeType_, 
        static_cast<cublasGemmAlgo_t>(cublasAlgo_[0])));

      check_cuda_error(cublasGemmEx(param_.cublas_handle, 
        CUBLAS_OP_N, CUBLAS_OP_N,
        n, m, k, 
        &alpha, 
        param_.attr_kernel_K, AType_, n, 
        param_.to_tensor, BType_, k, 
        &beta, 
        key_buf_, CType_, n, 
        computeType_, 
        static_cast<cublasGemmAlgo_t>(cublasAlgo_[0])));

      check_cuda_error(cublasGemmEx(param_.cublas_handle, 
        CUBLAS_OP_N, CUBLAS_OP_N, 
        n, m, k,
        &alpha,
        param_.attr_kernel_V, AType_, n, 
        param_.to_tensor, BType_, k, 
        &beta, 
        value_buf_, CType_, n, 
        computeType_, 
        static_cast<cublasGemmAlgo_t>(cublasAlgo_[0])));

cublasGemmEx

cublasGemmEx is a function from the NVIDIA cuBLAS library that performs a generalized matrix multiplication operation (GEMM) on two matrices A and B, and accumulates the result into a third matrix C.

The “Ex” suffix in the function name indicates that this is an extended version of the basic cublasGemm function, which allows for more advanced features such as data type casting, tensor operations, and tensor cores support.

The function signature for cublasGemmEx is as follows:

cublasStatus_t cublasGemmEx(cublasHandle_t handle,
                            cublasOperation_t transa,
                            cublasOperation_t transb,
                            int m,
                            int n,
                            int k,
                            const void *alpha,
                            const void *A,
                            cudaDataType_t Atype,
                            int lda,
                            const void *B,
                            cudaDataType_t Btype,
                            int ldb,
                            const void *beta,
                            void *C,
                            cudaDataType_t Ctype,
                            int ldc,
                            cudaDataType_t computeType,
                            cublasGemmAlgo_t algo);

Here’s a brief overview of the input parameters:

  • handle: A handle to the cuBLAS library context.
  • transa and transb: Transpose operation to be performed on matrices A and B respectively before the GEMM operation.
  • m, n, and k: The dimensions of matrices A, B, and C, respectively.
  • alpha: Scalar used to scale the product of matrices A and B.
  • A, B, and C: Pointers to the device memory storing the matrices A, B, and C.
  • Atype, Btype, and Ctype: Data types of matrices A, B, and C, respectively.
  • lda, ldb, and ldc: The leading dimensions of matrices A, B, and C, respectively.
  • beta: Scalar used to scale matrix C before accumulation.
  • computeType: The data type used for intermediate computations.
  • algo: The algorithm used for the GEMM operation.

The function returns a cublasStatus_t value indicating whether the operation was successful or if an error occurred.

multiHeadAttr_nofuse_kernelLauncher

declaration in “open_attention.h”

  void multiHeadAttr_nofuse_kernelLauncher(
      cudaStream_t stream,
      cublasHandle_t handle,
      DataType_* Q,
      const DataType_* bias_Q,
      DataType_* K,
      const DataType_* bias_K,
      DataType_* V,
      const DataType_* bias_V,
      const DataType_* attr_mask,
      DataType_* dst,
      const int batch_size,
      const int seq_len,
      const int head_num,
      const int size_per_head,
      const DataType_ scaler);

define in “open_attention.cu”

  • 定义了一个模板函数和两个特化模板。程序编译时会匹配然后编译。
template void OpenMultiHeadAttention<OperationType::FP32>::multiHeadAttr_nofuse_kernelLauncher(
      cudaStream_t stream,
      cublasHandle_t handle,
      float* Q,
      const float* bias_Q,
      float* K,
      const float* bias_K,
      float* V,
      const float* bias_V,
      const float* attr_mask,
      float* dst,
      const int batch_size,
      const int seq_len,
      const int head_num,
      const int size_per_head,
      const float scaler);

template void OpenMultiHeadAttention<OperationType::HALF>::multiHeadAttr_nofuse_kernelLauncher(
      cudaStream_t stream,
      cublasHandle_t handle,
      __half* Q,
      const __half* bias_Q,
      __half* K,
      const __half* bias_K,
      __half* V,
      const __half* bias_V,
      const __half* attr_mask,
      __half* dst,
      const int batch_size,
      const int seq_len,
      const int head_num,
      const int size_per_head,
      const __half scaler);
}//namespace cuda
template<OperationType OpType_>
void OpenMultiHeadAttention<OpType_>::multiHeadAttr_nofuse_kernelLauncher(
      cudaStream_t stream,
      cublasHandle_t cublas_handle,
      DataType_* Q,
      const DataType_* bias_Q,
      DataType_* K,
      const DataType_* bias_K,
      DataType_* V,
      const DataType_* bias_V,
      const DataType_* attr_mask,
      DataType_* dst,
      const int batch_size,
      const int seq_len,
      const int head_num,
      const int size_per_head,
      const DataType_ scaler)
{

    int m = batch_size * seq_len;
    int k = head_num * size_per_head;

    dim3 grid;
    dim3 block;

    if(OpType_ == OperationType::FP32)
    {
      const int word_per_block = 1;
      assert(k <= 1024);
      assert(m / word_per_block * 3 <= 65536);

      dim3 grid(m / word_per_block * 3);
      dim3 block(k);
      add_QKV_bias<DataType_><<<grid, block, 0, stream>>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf_, k_buf_, v_buf_,
          batch_size, seq_len, head_num, size_per_head, word_per_block);
    }
    else
    {
      const int word_per_block = 1;
      grid.x = batch_size * seq_len / word_per_block;
      block.x = head_num * size_per_head * word_per_block / 2;

      add_QKV_bias<DataType_><<<grid, block, 0, stream>>>(Q, bias_Q, K, bias_K, V, bias_V, q_buf_, k_buf_, 
      v_buf_, batch_size, seq_len, head_num, size_per_head / 2, word_per_block);
    }

    DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f;
    
    check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle,
      CUBLAS_OP_T, CUBLAS_OP_N,
      seq_len, seq_len, size_per_head,
      &alpha,
      k_buf_, AType_, size_per_head, seq_len * size_per_head,
      q_buf_, BType_, size_per_head, seq_len * size_per_head,
      &beta,
      qk_buf_, CType_, seq_len, seq_len * seq_len,
      batch_size * head_num,
      computeType_,
      static_cast<cublasGemmAlgo_t>(cublasAlgo_[1])));

    if(seq_len <= 32)
      block.x = 32;
    else if(seq_len > 32 && seq_len <= 64)
      block.x = 64;
    else if(seq_len > 64 && seq_len <= 128)
      block.x = 128;
    else if(seq_len > 128 && seq_len <= 256)
      block.x = 256;
    else if(seq_len > 256 && seq_len <= 512)
      block.x = 512;
    else
      block.x = 1024;

    if(batch_size * head_num <= 120)
    {
      grid.x = batch_size * head_num * seq_len;
      softmax_kernel_v2<DataType_><<<grid, block, 0, stream>>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scaler); 
    }
    else
    {
      grid.x = batch_size * head_num;
      softmax_kernel<DataType_><<<grid, block, 0, stream>>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scaler); 
    }

    check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle,
      CUBLAS_OP_N, CUBLAS_OP_N,
      size_per_head, seq_len, seq_len,
      &alpha,
      v_buf_, AType_, size_per_head, seq_len * size_per_head,
      qk_buf_, BType_, seq_len, seq_len * seq_len,
      &beta,
      transpose_dst_, CType_, size_per_head, seq_len * size_per_head,
      batch_size * head_num,
      computeType_,
      static_cast<cublasGemmAlgo_t>(cublasAlgo_[2])));

/* for half2 only */
    if(OpType_ == OperationType::HALF)
    {
      const int seq_per_block = 4;
      grid.x = batch_size * head_num * seq_len / seq_per_block;
      block.x = seq_per_block * size_per_head / 2;

      assert(grid.x * seq_per_block == batch_size * head_num * seq_len);

      transpose<DataType_><<<grid, block, 0, stream>>>(transpose_dst_, dst, 
          batch_size, seq_len, head_num, size_per_head / 2);
    }
    else
    {
      const int seq_per_block = 1;
      grid.x = batch_size * head_num * seq_len / seq_per_block;
      block.x = seq_per_block * size_per_head;
      transpose<DataType_><<<grid, block, 0, stream>>>(transpose_dst_, dst, 
          batch_size, seq_len, head_num, size_per_head);
    }
}

cublasGemmTridedBatchedEx

cublasGemmTridedBatchedEx是cuBLAS库中的一个函数,用于执行跨步分批矩阵乘法。它获取多组输入矩阵,并对每组并行执行相同的矩阵乘法运算,将结果存储回内存。函数名称中的“Ex”表示这是基本“cublasGemmTriedBatched”函数的扩展版本,其中包括用于指定数据类型、比例因子和其他设置的附加选项。
在深度学习的背景下,“cublasGemmTridedBatchedEx”通常用于在Transformer神经网络中执行多头自注意操作。此操作涉及将查询、键和值矩阵集相乘以产生注意力分数,注意力分数用于对值进行加权并计算最终输出。有效地执行这些矩阵乘法对于在大型Transformer模型中实现高性能至关重要。

Attention & Transformer

在这里插入图片描述
在这里插入图片描述

  • FlashAttention:一种具有 IO 感知,且兼具快速、内存高效的新型注意力算法(但非传统attention的近似方法,是数学等价的)
    https://www.bilibili.com/video/BV1SW4y1X7kh/?
    https://zhuanlan.zhihu.com/p/618533434

CG

cuda inline

  • 您可能会注意到的第一件事是__inline__声明。这可能是不必要的。它告诉编译器 将此函数的整个代码放在该点 调用它的位置,而不是导致跳转发生。 这使得编译后的代码运行得更快。另一方面 NVCC知道使许多(也许是大多数)设备功能内联 已经没有问了,所以我们的要求可能是多余的。 我把它放进去,这样我就可以谈论它了。如果函数 更长,也许默认情况下它不会内联,但是 如果您编写该函数主要是为了可读性,则 使其内联可能对您很重要。在这种情况下, 您可以获得更具可读性的代码,而不会牺牲速度来跳转。
  • 在这里插入图片描述

layernorm

  • https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/bert_encoder_transformer.h#L208-L281 中调用了add_bias_input_layernorm_kernelLauncher方法。

  • 一个帮助理解multiHeadAttr_nofuse_kernelLauncher的例子

#include <stdio.h>
#include <iostream>
#include <typeinfo>

// using OperationType = int;
enum class OperationType{FP32, HALF};


template<OperationType op>class First{};
template< OperationType N,template<OperationType>class XXX >
class Second;

template< OperationType N,template<OperationType>class XXX >
class Second{
    public:
        XXX<N> b; 
        Second(){ 
            std::cout<<"NNN"; 
        } 
};

// template<template<OperationType> class MultiHeadAttention_>
// class BertEncoderTransformerTraits<OperationType::FP32, MultiHeadAttention_>
template< template<OperationType>class XXX >
class Second<OperationType::FP32,XXX>{
    public:
        XXX<OperationType::FP32> b; 
        Second(){ 
            std::cout<<"SSSSS"; 
        } 
};

// template< template<OperationType>class XXX >
// class Second<OperationType::HALF,XXX>{
//     public:
//         XXX<OperationType::HALF> b; 
//         Second(){ 
//             std::cout<<"HALF"; 
//         } 
// };




int main()
{
    printf("Hello World\n");
    // Second<OperationType::FP32,First> *second = new  Second<OperationType::FP32,First>();
    Second<OperationType::HALF,First> *second = new  Second<OperationType::HALF,First>();
    return 0;
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/641063.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

<Linux开发>驱动开发 -之-内核定时器与中断

&#xff1c;Linux开发&#xff1e;驱动开发 -之-内核定时器与中断 交叉编译环境搭建&#xff1a; &#xff1c;Linux开发&#xff1e; linux开发工具-之-交叉编译环境搭建 uboot移植可参考以下&#xff1a; &#xff1c;Linux开发&#xff1e; -之-系统移植 uboot移植过程详…

Linux系统下安装Kubernetes(超详细。。。)

一、安装Kubernetes前的准备 1.1 准备Hosts文件 &#xff08;注意&#xff0c;请根据Linux虚拟机的IP地址&#xff0c;修改以下命令后再执行&#xff09; cat >>/etc/hosts<<EOF 192.168.100.146 deploy EOF 1.2 检查虚拟机的hostname cat /etc/hostname验证…

Charles抓包配置

这里写目录标题 一、Windows抓包配置1、Help-SSL Proxying-install Charles Root Certificate2、安装并导入证书&#xff0c;按下方各图完成证书导入后&#xff0c;正常情况下&#xff0c;会显示该证书没有问题。3、SSL证书过期解决办法a、可在windows的设置中搜索证书关键字&a…

c++ nlohmann/json 及修改json文件中个别关键字

(2条消息) nlohmann json使用_nlohmann::json_蜗牛单行道的博客-CSDN博客json为JavaScript object notation 是一种数据格式&#xff0c;逐渐替换掉了传统的xml 。json数据格式的属性名称和字符串值需要用双引号引起来&#xff0c;用单引号或者不用引号会导致读取数据错误。jso…

Django-初

文章目录 一、Django框架介绍二、后台管理第一步:项目的创建与运行第二步:应用的创建和使用第三步: 项目的数据库模型第四步: 启用后台Admin站点管理 三、前台管理第一步: URLconf 路由管理第二步: 视图函数处理业务逻辑第三步: 模板管理实现好看的HTML页面&#xff08;可参考菜…

网络计算模式期末复习(一)

C/S架构 C/S架构即客户端/服务端架构。客户端包含一个或多个在用户电脑上运行的程序&#xff0c;客户端程序发送请求和从服务器接收的数据。服务器端主要提供数据管理、数据共享、数据及系统维护和并发控制等。 B/S架构 B/S架构即浏览器/服务器架构&#xff0c;是随着Intern…

图片上添加贴纸怎么做?这几种方法很简单

在图片上添加贴纸是一种非常实用的图片编辑技巧&#xff0c;通过添加贴纸&#xff0c;图片可以变得更加生动有趣&#xff0c;吸引人们的眼球。贴纸可以是各种形状、颜色和大小&#xff0c;从而丰富图片的视觉效果。例如&#xff0c;在一张风景照片中添加一只卡通动物的图案&…

python中golbal的使用

简介 global关键字定义了一种在局部定义全局变量的方法 python中变量分为全局变量和局部变量&#xff0c;局部变量也叫做内部变量内部变量只能被内部使用&#xff0c;无法被其他函数或者对象使用 使用 简单使用 def fn():global fn_varfn_var "Hello World"fn1()…

为什么网红餐饮都做不长久?如何解决网红餐饮店所面临的问题?

随着社交媒体的兴起&#xff0c;网红餐饮在近年来越来越受到人们的关注。这些网红餐饮通常有着独特的装修风格、口味或者服务方式&#xff0c;吸引了大量的消费者前来体验。然而&#xff0c;有越来越多的网红餐饮因为各种原因而不得不倒闭&#xff0c;这引发了人们对于网红餐饮…

cajviewer怎么转换成pdf格式,分享几个方法给大家!

CAJViewer是一款常用的文献阅读软件&#xff0c;它主要用于打开和阅读中国知网等数据库中的CAJ格式文件。然而&#xff0c;有时候我们可能需要将这些CAJ文件转换为PDF格式&#xff0c;以便更方便地与他人分享或者进行打印。本文将介绍两到三种将CAJViewer文件转换为PDF格式的方…

华为OD机试真题2022Q4 A + 2023 B卷(JavaJavaScript)

大家好&#xff0c;我是哪吒。 五月份之前&#xff0c;如果你参加华为OD机试&#xff0c;收到的应该是2022Q4或2023Q1&#xff0c;这两个都是A卷题。 5月10日之后&#xff0c;很多小伙伴收到的是B卷&#xff0c;那么恭喜你看到本文了&#xff0c;抓紧刷题吧。B卷新题库正在更…

建站记录1:开通阿里云,购买域名,安装宝塔+LAMP系统

个人建站&#xff1a; 因为宝塔系统&#xff08;https://www.bt.cn&#xff09;&#xff0c;可以方便的部署zblog 彩色背景 什么是LAMP&#xff1f; Linux Apache PHP MySQL LAMP 是指Linux&#xff08;操作系统&#xff09; Apache &#xff08;HTTP 服务器&#xff09;…

batch_size对精确度和损失的影响

1 问题 在深度学习的学习过程中&#xff0c;模型性能对batchsize虽然没有学习率那么敏感&#xff0c;但是在进一步提升模型性能时&#xff0c;batch_size就会成为一个非常关键的参数。 batch_size对精度和损失的影响研究。 batch_size [,32,64,128&#xff0c;256] 不同batch_…

镕铭微电子VPU 极致降本增效实践

当前视频行业环境下&#xff0c;硬件芯片的机遇与挑战并存&#xff0c;如何使得硬件芯片产品及方案设计更好地贴近用户、服务用户及满足用户更深层次需求&#xff1f;本次LiveVideoStackCon 2022 北京站邀请到镕铭微电子解决方案架构总监——蔡媛Amy&#xff0c;为大家介绍镕铭…

【熬夜送书 | 第五期】清华社赞助 | 《MySQL系列丛书》

MySQL是什么? MySQL是一种关系型数据库管理系统&#xff0c;由瑞典MySQL AB公司开发。MySQL是最流行的关系型数据库管理系统之一&#xff0c;在WEB应用方面&#xff0c;MySQL是最好的RDBMS(Relational Database Management System:关系数据库管理系统)应用软件之一。 MySQL有…

Arduino esp32 环境配置以及避坑指南

目录 环境配置安装 IDE下载固件 项目测试疑难解答micropython 固件冲突问题 环境配置 安装 IDE 参考文献&#xff1a;CSDN 首先下载 Arduino IDE 请注意&#xff0c;一定要选择 1.8 版本的&#xff0c;千万别用 2.0版本&#xff01;&#xff01;&#xff01; 建议直接下载 win…

通过向量回归、随机森林回归、线性回归和K-最近邻回归将预测结果绘制成图表进行展示

文章目录 表格部分数据如下运行效果如下代码解析完整代码附件 表格部分数据如下 附件里会给出全部数据链接 运行效果如下 代码解析 import pandas as pd import numpy as np import matplotlib.pyplot as plt from matplotlib.font_manager import FontPropertiesfont FontP…

webpack自动化打包webpack-dev-server

在前面的章节中我们每次改完要打包的资源文件&#xff0c;和配置文件都是是输入npx webpack命令手动打包的&#xff0c;那么有没有什么办法可以监听到我们代码的改动&#xff0c;在保存时就自动打包呢&#xff1f; 答案是当然有&#xff0c;不然哪些框架的脚手架是怎么实现保存…

Redis命令-数据结构String类型和Hash类型

1. String类型 字符串类型&#xff0c;Redis中最简单的存储类型 底层都是字节数组形式存储&#xff0c;只不过是编码方式不同&#xff1b; 字符串类型的最大空间不能超过512m&#xff1b; SET/GET/MSET/MGET使用示例&#xff1a; INCR使用示例&#xff1a; INCRBY自增并指定步长…

CSS粘性定位 - 它的真正工作原理!

本文首发于微信公众号&#xff1a;大迁世界, 我的微信&#xff1a;qq449245884&#xff0c;我会第一时间和你分享前端行业趋势&#xff0c;学习途径等等。 更多开源作品请看 GitHub https://github.com/qq449245884/xiaozhi &#xff0c;包含一线大厂面试完整考点、资料以及我的…