From self-attention 2 flash-attention 数学原理与 cuda 实现优化

news2025/1/17 15:20:31

self attension 是transformer 编码器和解码器中共同的一个计算环节,在整个transformer 网络体系中耗费的算力比例占主导。所以节省self attention 的正向和反向的计算时间,就可以加速 transormer 的训练和推理过程。

1,self attention 的数学提炼

两个矩阵乘法,加入一个列向的softmax

input   矩阵: \mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbf{R}^{N \times d}

output 矩阵:\mathbf{O} \in \mathbf{R}^{N \times d}

 

\mathbf{self\ attention\ algorithm:}

        step1:        \mathbf{S} = \mathbf{Q}*\mathbf{K}^t

        step2:        \mathbf{P} = \mathbf{softmax_{column}(S)}

        step3:        \mathbf{O} = \mathbf{P}*\mathbf{V}

2,cpu 实现self attention

这里的数据类型使用了 float,实际网络中一般采用 fp16,数学过程是相同的;

cpu_self_attention.cpp

#include <stdio.h>
#include <string.h>

#include "cpu_gemm.h"
#include "utils.h"
#include "soft_max.h"
//all matrices are row major.

void cpu_self_attention(float* Q, int ldq,
						float* K, int ldk,
						float* V, int ldv,
						float* S, int lds,
						float* P, int ldp,
						float* O, int ldo,
						int N, int d)
{
	gemm_nt(Q, ldq, K, ldk, S, lds, N, N, d);// S = Q*K^t     (NxN) = (Nxd) * (dxN)
					printf("\nS =\n");	print_matrix(S, N, N, lds);
	soft_max_column(P, ldp, S, lds, N, N);// P(NxN) = softmax(S(NxN))
					printf("\nP =\n");	print_matrix(S, N, N, lds);
	gemm_nn(P, ldp, V, ldv, O, ldo, N, d, N);// O = P*V     (Nxd) = (NxN) * (Nxd)
}

cpu_gemm.cpp

#include "cpu_gemm.h"

void gemm_nn(float *A, int lda,		//A(M x K) rowMj
	     	 float *B, int ldb,		//B(K x N) rowMj
	     	 float *C, int ldc,		//C(M x N) rowMj
	      	 int M,
			 int N,
			 int K)
{
	for(int i=0; i<M; i++)
	{
		for(int j=0; j<N; j++)
		{
			float sigma = 0.0;

			for(int k=0; k<K; k++)
			{
				sigma += A[i*lda + k] * B[k*ldb + j];
			}

			C[i*ldc + j] = sigma;
		}
	}
}

void gemm_nt(float *A, int lda,		//A(M x K) rowMj
	     	 float *B, int ldb,		//B(N x K) rowMj
	     	 float *C, int ldc,		//C(M x N) rowMj
	      	 int M,
			 int N,
			 int K)
{
	for(int i=0; i<M; i++)
	{
		for(int j=0; j<N; j++)
		{
			float sigma = 0.0;

			for(int k=0; k<K; k++)
			{
				sigma += A[i*lda + k] * B[k + j*ldb];
			}

			C[i*ldc + j] = sigma;
		}
	}
}

cpu_softmax_column.cpp

这里使用的是未数值优化的方式,直接按照原始公式计算:

#include "soft_max.h"
void soft_max_column(float *P, int ldp, float* S, int lds, int M, int N)//P = softmax(S)  P(i,j) = exp(S(i,j))/sigma(exp(S(r,j)));  r=0,1,..,n-1 ;
{
    for(int j=0; j<N; j++){
        float sigma = 0.0f;

        for(int i=0; i<M; i++){
            sigma += exp(S[i*lds + j])
        }

        for(int i=0; i<M; i++){
            P[i*ldp + j] = S[i*lds + j]/sigma;
        }
    }
}

3, gpu 实现 self attention 正向

cuda 实现上述过程:

gpu_self_attention.cu

gpu_gemm.cu

gpu_softmax_column.cu

4,为什么不需要gpu 实现self attention 反向

融合上述过程

5, gpu 实现 flash attention 反向

融合算子

数学原理

cuda 实现

挖坑,未完待续 。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1804417.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

学习笔记——路由网络基础——环回接口(loopback)

6、环回接口(loopback) (1)定义 环回接口(loopback) &#xff1a;是一种虚拟的接口&#xff0c;是一种纯软件性质的虚拟接口&#xff0c;模拟一个单独的网段。 Loopback等于在设备中模拟另外不同的网络&#xff0c;实现不需要物理接口连接设备&#xff0c;依然可以模拟的功能…

MobileNetV4实战:使用 MobileNetV4实现图像分类任务(二)

文章目录 训练部分导入项目使用的库设置随机因子设置全局参数图像预处理与增强读取数据设置Loss设置模型设置优化器和学习率调整策略设置混合精度&#xff0c;DP多卡&#xff0c;EMA定义训练和验证函数训练函数验证函数调用训练和验证方法 运行以及结果查看测试完整的代码 在上…

了解Synchronized对象头?

1、对象头的结构 Java对象存储在内存中结构为&#xff1a; 对象头&#xff08;Header&#xff09;&#xff1a;实例数据&#xff08;Instance Data&#xff09;&#xff1a;定义类中的成员属性对齐填充字节&#xff08;Padding&#xff09;&#xff1a;由于HotSpot虚拟机的自…

高通SDX12:Voice Over USB 功能调试

一、功能概述及使用环境 Linux PC 作为上位机,内置 SLIC基于高通 SDX12 平台的设备作为从设备,通过USB连接到 Linux PC 上,在 PC 上枚举 UAC 设备从设备进行 MO/MT Call 时,上位机使用 arecord 进行录音,音频数据通过 USB 传至上位机,上位机停止录音后再使用 aplay 进行播…

经典文献阅读之--Online Monocular Lane Mapping(使用Catmull-Rom样条曲线完成在线单目车道建图)

0. 简介 对于单目摄像头完成SLAM建图这类操作&#xff0c;对于自动驾驶行业非常重要&#xff0c;《Online Monocular Lane Mapping Using Catmull-Rom Spline》介绍了一种仅依靠单个摄像头和里程计生成基于样条的在线单目车道建图方法。我们提出的技术将车道关联过程建模为一个…

【STM32】ucOS-III多任务程序

【STM32】uc/OS-III多任务程序 文章目录 【STM32】uc/OS-III多任务程序STM32F103C8T6移植uC/OS-III基于HAL库超完整详细过程与相关实验实验任务实验过程一、 uC/OS-III源码下载二、 建立STM32CubeMX工程三、 复制uC/OS-III文件到工程文件夹四、 添加工程组件和头文件路径五、修…

【中颖】SH79F9202 串口通信

头文件 uart.h #ifndef UART_H #define UART_H#include "SH79F9202.h" #include "LCD.h" #include "timer2.h" #include "timer5.h" #include "cpu.h" #include "key.h" #include "io.h" #include &qu…

【C++】深入理解decltype和decltype(auto)

深入理解decltype和decltype&#xff08;auto&#xff09; 一、decltype语法介绍二、decltype的推导规则1. expr不加括号2. expr加上括号 三、关于decltype的CV属性推导四、 decltype(auto) 的使用 一、decltype语法介绍 decltype关键字是C11新标准引入的关键字&#xff0c;它…

向量数据库是什么?

向量数据库是什么&#xff1f; 随着人工智能和机器学习技术的迅猛发展&#xff0c;向量数据库作为一种新型数据库引起了广泛关注。向量数据库专门用于存储和查询高维向量数据&#xff0c;是在大规模数据检索和相似性搜索领域的重要工具。 向量数据库的定义 向量数据库是一种…

心链13---主页切换功能 + loading特效 + 导航栏完善 + 队伍页接口修改

心链 — 伙伴匹配系统 直接取出所有用户&#xff0c;依次和当前用户计算分数&#xff0c;取 TOP N&#xff08;54 秒&#xff09; 优化方法&#xff1a; 切忌不要在数据量大的时候循环输出日志&#xff08;取消掉日志后 20 秒&#xff09;Map 存了所有的分数信息&#xff0c;占…

上位机图像处理和嵌入式模块部署(f407 mcu和其他mcu品类的选择)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 很多朋友读书的时候学的是stm32&#xff0c;工作中用的也是stm32。这本来问题不大&#xff0c;但是过去两三年的经历告诉我们&#xff0c;mcu的使用…

Polar Web【中等】反序列化

Polar Web【中等】反序列化 Contents Polar Web【中等】反序列化思路&探索EXPPHP生成PayloadGET传递参数 运行&总结 思路&探索 一个经典的反序列化问题&#xff0c;本文采用PHP代码辅助生成序列字符串的方式生成 Payload 来进行手动渗透。 打开站点&#xff0c;分析…

Python编程基础4

模块&#xff1a;模块支持从逻辑上组织Python代码&#xff0c;当代码量变得非常大的时候&#xff0c;最好把代码分成一些有组织的代码段。代码片段相互间有一定的联系&#xff0c;可能是一个包含数据成员和方法的类、函数、变量。 搜索路径&#xff1a;模块的导入需要一个叫做‘…

构建智能汽车新质生产力丨美格智能亮相2024高通汽车技术与合作峰会

近日&#xff0c;以“我们一起&#xff0c;驭风前行”为主题的2024高通汽车技术与合作峰会在无锡国际会议中心隆重举行。作为高通公司的战略合作伙伴&#xff0c;美格智能受邀全程参与此次汽车技术与合作峰会。在峰会现场&#xff0c;美格智能产品团队隆重展示了多款基于高通平…

Wireshark自定义Lua插件

背景&#xff1a; 常见的抓包工具有tcpdump和wireshark&#xff0c;二者可基于网卡进行抓包&#xff1a;tcpdump用于Linux环境抓包&#xff0c;而wireshark用于windows环境。抓包后需借助包分析工具对数据进行解析&#xff0c;将不可读的二进制数转换为可读的数据结构。 wires…

VUE封装-自定义权限控制指令

在实际开发中&#xff0c;会遇到很多的权限控制、资源位的场景&#xff0c;其实就是用来控制某个组件的展示与否&#xff0c;可以是一个按钮、一个报表、一个TAB页面等 例如下图&#xff0c;我想通过当前登录的用户控制谷歌的这个logo显示与否 因为设计到的权限、资源位控制比…

摆脱Jenkins - 使用google cloudbuild 部署 java service 到 compute engine VM

在之前 介绍 cloud build 的文章中 初探 Google 云原生的CICD - CloudBuild 已经介绍过&#xff0c; 用cloud build 去部署1个 spring boot service 到 cloud run 是很简单的&#xff0c; 因为部署cloud run 无非就是用gcloud 去部署1个 GAR 上的docker image 到cloud run 容…

GUI编程-01

组件 窗口 弹窗 面板 文本框 列表框 按钮 图片 监听事件 鼠标 键盘事件 破解工具 Java提供了丰富的图形用户界面&#xff08;Graphics User Interface&#xff0c;GUI&#xff09;的类库&#xff0c;基于这些类库可以编写窗口程序。 Java关于图形界面的类库主要放在…

【Redis学习笔记05】Jedis客户端(string、list、set)

Jedis客户端 1. 命令 1.1 String类型 1.1.1 常见命令 SET命令 语法&#xff1a;SET key value [EX seconds | PX milliseconds] [NX|XX] 说明&#xff1a;将string类型的value值设置到指定key中&#xff0c;如果之前该key存在&#xff0c;则会覆盖原先的值&#xff0c;原先…

数染色体 算法 python源码

效果图如下&#xff1a; 原图&#xff1a; 完整代码&#xff1a; import cv2 import numpy as np from skimage import measure import randomimage cv2.imread(113.jpg, cv2.IMREAD_GRAYSCALE)blurred_img cv2.GaussianBlur(image, (5, 5), 0)_, binary_image cv2.thresho…