【大模型训练】Flash Attention详解

news2024/9/20 6:41:21

文章目录

  • 前言
  • 预备知识
  • FlashAttention1
    • 传统Attention计算方式
    • FlashAttention1的基本原理
      • 除去Softmax操作的分块计算
      • Softmax分块计算
      • Attention分块计算
  • FlashAttention2
  • 参考资料

前言

FlashAttention系列工作,是一种加速注意力计算方法,目前已经应用在:GPT-3、Falcon2(阿联酋大模型)、Llama2、Megatron-LM、GPT-4等流行LLM上。并且FlashAttention2已经集成到了pytorch2.0中,可以很便捷的调用。

1. FlashAttention动机Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length.,可以看出由于Transformer中self-attention 的时间和内存复杂度是序列长度的二次方,所以序列过长时,算法速度会变慢,需要消耗很高的内存,导致低效的

2. FlashAttention主要贡献

  • FlashAttention利用底层硬件的内存层次知识,例如GPU的内存层次结构,来提高计算速度和减少内存访问开销
  • 核心原理是通过将输入分块,并在每个块上执行注意力操作,从而减少对高带宽内存(HBM)的读写操作
  • FlashAttention减少了内存读写量,从而实现了2-4倍的计算加速

预备知识

1. GPU 内存层次结构

GPU 内存层次结构 包含多种不同大小和速度的内存形式,内存容量越小,读写速度越快。以A100 GPU为例,主要有两种类型,如下图所示:

在这里插入图片描述

  • 高带宽内存 (HBM):也就是我们常说的GPU显存,A100具有 40-80GB HBM,带宽为 1.5-2.0TB/s
  • SRAM:位于GPU片上,每个 108 个流式多处理器(SM, Streaming Multiprocessor)都有 192KB 片上 SRAM,带宽估计约为 19TB/s
  • 二者的位置分布如下图所示,其中HBM在VRAM部门,而SRAM在GPU内部:

在这里插入图片描述

可以看到,片上 SRAM 比 HBM 快一个数量级,但内存容量小很多数量级。随着计算相对于内存速度变得更快,内存 (HBM) 访问越来越成为操作瓶颈。因此,利用快速 SRAM 变得更加重要。

2. GPU执行过程

GPU 有大量线程(threads )来执行操作(称为内核 Kernel)。每个Kernel将输入从 HBM 加载到寄存器和 SRAM,进行计算,然后将输出写入 HBM。

3. 性能特点

根据计算和内存访问的平衡,操作可以分为计算限制或内存限制。这通常通过算术强度来衡量,即内存访问的每个字节的算术运算数量。

  • 计算限制:操作所花费的时间由算术运算的数量决定,而访问 HBM 的时间要少得多。典型的例子是大内部维度的矩阵乘法,以及大量通道的卷积
  • 内存限制:操作所花费的时间由内存访问次数决定,而计算所花费的时间要少得多。示例包括大多数其他操作:逐元素(激活、dropout)和归约(求和、softmax、批量归一化、层归一化)

FlashAttention1

在这里插入图片描述

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
paper,code

传统Attention计算方式

传统的Attention计算过程如下:

在这里插入图片描述

  • 首先 Q , K , V Q,K,V Q,K,V矩阵计算好,放在HBM中
  • 接着,为了计算QK注意力得分,将 Q , K Q,K Q,K从HBM中取出来写入SRAM,然后计算 S = Q K T S=QK^T S=QKT,再把 S S S从SRAM写入HBM
  • 然后,从HBM加载 S S S到SRAM,计算 P = S o f t m a x ( S ) P=\rm{Softmax}(\it{S}) P=Softmax(S),然后再把 P P P从SRAM写入HBM
  • 最后,从HBM加载 P , V P,V P,V到SRAM,计算最终的输出 O = P V O=PV O=PV,然后把 O O O写入HBM

可以看到,这个过程中存在多次HBM和SRAM之间的读写操作。同时由于Attention的计算方式,导致中间的临时变量 S , P S,P S,P的参数量和输入序列长度的平方成正比:

在这里插入图片描述

因此,在训练时,长序列输入在计算Attention时会产生更大参数量的临时变量,会占用更大显存空间,导致更多的访问消耗。也就是说,Attention操作主要是内存限制问题,通信时间是制约计算效率的主要因素。

FlashAttention1的基本原理

因此,FlashAttention主要的思想,就是减少通信时间,也就是减少IO操作,使得计算尽可能多的访问片上的SRAM,尽可能少的访问片外的HBM。

  • 通过分块计算,融合多个操作,减少中间结果缓存
  • 反向传播时,重新计算中间结果(类似于梯度检查点的原理)

在这里插入图片描述

除去Softmax操作的分块计算

在计算Attention主要有两个临时变量 S , P S,P S,P,FlashAttention的分块计算,使得不需要存储这两个临时变量,而是直接在SRAM计算得到部分最终结果 O O O,从而减少了内存访问开销。这里我们先忽略Softmax操作,因为在分块计算,他比较麻烦。

这里假设 Q , K , V Q,K,V Q,K,V矩阵的大小为 ( 4 , 3 ) (4, 3) (4,3),那么FlashAttention的分块计算过程如下,由于矩阵乘法的性质,每次分块计算得到的结果,都是最终结果矩阵中的一部分值:
在这里插入图片描述

Softmax分块计算

下面,我们来解决Softmax这个麻烦的操作,首先来回顾Softmax的计算公式:

softmax ⁡ ( { x 1 , … , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N \operatorname{softmax}\left(\left\{x_1, \ldots, x_N\right\}\right)=\left\{\frac{e^{x_i}}{\sum_{j=1}^N e^{x_j}}\right\}_{i=1}^N softmax({x1,,xN})={j=1Nexjexi}i=1N

但是,如果数据类型为FP16,那么最大可以表示为65536,因此当 x i x_i xi为12时, e 12 = 162754 e^{12}=162754 e12=162754,超过了FP16所能表示的最大值。因此,我们需要使用Safe_Softmax方法来避免这个问题:

m = max ⁡ ( x i ) softmax ⁡ ( { x 1 , … , x N } ) = { e x i / e m ∑ j = 1 N e x j / e m } i = 1 N = { e x i − m ∑ j = 1 N e x j − m } i = 1 N \begin{aligned} & m=\max \left(x_i\right) \\ & \operatorname{softmax}\left(\left\{x_1, \ldots, x_N\right\}\right)=\left\{\frac{e^{x_i} / e^m}{\sum_{j=1}^N e^{x_j} / e^m}\right\}_{i=1}^N=\left\{\frac{e^{x_i-m}}{\sum_{j=1}^N e^{x_j-m}}\right\}_{i=1}^N \end{aligned} m=max(xi)softmax({x1,,xN})={j=1Nexj/emexi/em}i=1N={j=1Nexjmexim}i=1N

也就是在计算Softmax之前,先对输入数据进行归一化处理。此时计算Softmax的流程为:

x = [ x 1 , … , x N ] m ( x ) : = max ⁡ ( x ) p ( x ) : = [ e x 1 − m ( x ) , … , e x N − m ( x ) ] l ( x ) : = ∑ i p ( x ) i softmax ⁡ ( x ) : = p ( x ) l ( x ) \begin{aligned} & x=\left[x_1, \ldots, x_N\right] \\ & m(x):=\max (x) \\ & p(x):=\left[e^{x_1-m(x)}, \ldots, e^{x_N-m(x)}\right] \\ & l(x):=\sum_i p(x)_i \\ & \operatorname{softmax}(x):=\frac{p(x)}{l(x)} \end{aligned} x=[x1,,xN]m(x):=max(x)p(x):=[ex1m(x),,exNm(x)]l(x):=ip(x)isoftmax(x):=l(x)p(x)

接下来看分块处理时,假设这里分为两块处理,我们首先需要在每个块内找到最大值(使用临时变量来保存),做归一化处理:

x = [ x 1 , … , x N , … x 2 N ] x 1 = [ x 1 , … , x N ] x 2 = [ x N + 1 , … x 2 N ] m ( x 1 ) p ( x 1 ) l ( x 1 ) m ( x 2 ) p ( x 2 ) l ( x 2 ) \begin{aligned} & x=\left[x_1, \ldots, x_N, \ldots x_{2 N}\right] \\ & x^1=\left[x_1, \ldots, x_N\right] \quad x^2=\left[x_{N+1}, \ldots x_{2 N}\right] \\ & m\left(x^1\right) \quad p\left(x^1\right) \quad l\left(x^1\right) \quad m\left(x^2\right) \quad p\left(x^2\right) \quad l\left(x^2\right) \end{aligned} x=[x1,,xN,x2N]x1=[x1,,xN]x2=[xN+1,x2N]m(x1)p(x1)l(x1)m(x2)p(x2)l(x2)

然后再计算数据的全局最大值,并且更新 p ( x ) , l ( x ) p(x),l(x) p(x),l(x),最后计算得到输入数据的Softmax值:

  • 由于全局最大值,一定是各个块内最大值中的一个
  • 因此在更新 p ( x ) , l ( x ) p(x),l(x) p(x),l(x),只需要乘以每个块最大值相对于全局最大值的差值的指数,就可以了
    m ( x ) : = max ⁡ ( m ( x 1 ) , m ( x 2 ) ) p ( x ) : = [ e m ( x 1 ) − m ( x ) p ( x 1 ) , e m ( x 2 ) − m ( x ) p ( x 2 ) ] l ( x ) : = e m ( x 1 ) − m ( x ) l ( x 1 ) + e m ( x 2 ) − m ( x ) l ( x 2 ) softmax ⁡ ( x ) : = p ( x ) l ( x ) \begin{aligned} &\begin{aligned} & m(x):=\max \left(m\left(x^1\right), m\left(x^2\right)\right) \\ & p(x):=\left[e^{m\left(x^1\right)-m(x)} p\left(x^1\right), e^{m\left(x^2\right)-m(x)} p\left(x^2\right)\right] \\ & l(x):=e^{m\left(x^1\right)-m(x)} l\left(x^1\right)+e^{m\left(x^2\right)-m(x)} l\left(x^2\right) \end{aligned}\\ &\operatorname{softmax}(x):=\frac{p(x)}{l(x)} \end{aligned} m(x):=max(m(x1),m(x2))p(x):=[em(x1)m(x)p(x1),em(x2)m(x)p(x2)]l(x):=em(x1)m(x)l(x1)+em(x2)m(x)l(x2)softmax(x):=l(x)p(x)

Attention分块计算

最后,我们来看一下FlashAttention的完整计算流程:
在这里插入图片描述

FlashAttention2

在这里插入图片描述
相比于FlashAttention1的改进:

  • 减少了非矩阵乘法计算,可以利用Tensor Core加速计算
  • 调整了内外训练方式,改为 Q 为外层循环,KV 为内层循环,进一步减少HBM读写,增加了并行度
  • 如果一个Block处于矩阵上三角部分(Mask机制),则不进行attention计算,进一步优化了计算效率

参考资料

  • [1] https://www.bilibili.com/video/BV1UT421k7rA/?share_source=copy_web&vd_source=79b1ab42a5b1cccc2807bc14de489fa7
  • [2] https://zhuanlan.zhihu.com/p/676655352

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2132148.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决TensorFlow-GPU安装错误:Python版本兼容性与环境配置问题

创作不易,您的打赏、关注、点赞、收藏和转发是我坚持下去的动力! 从错误信息中可以看到,tensorflow-gpu 安装时出现了 packaging.requirements.InvalidRequirement 错误,具体是因为解析 Python 版本时出现了问题。这通常是由于环…

OpenAI全新发布o1模型:开启 AGI 的新时代

OpenAI全新发布o1模型:开启 AGI 的新时代 欢迎关注【youcans的AGI学习笔记】原创作品 2024年9月13日,OpenAI新模型o1 正式发布。o1 在测试化学、物理和生物学专业知识的基准 GPQA-diamond 上,全面超过了人类博士专家。 OpenAI 宣称&#xff…

CANFD芯片应用中关键功能和性能指标分析

CAN FD芯片通信速率高达5Mbps,需要线缆少传输距离较远,在汽车、工业、宇航、能源等领域应用越来越广。 1)汽车工业:汽车内部电子系统日益复杂,需要高速、可靠的数据传输来确保车辆的安全和性能。CAN FD通信提供了更高…

R数据对象快速保存与读取:qs包

qs:R对象的快速序列化 qs是一个R语言包,使用qs可以快速地从磁盘中保存和读取对象。** 它的主要目的是替换R中的saveRDS和readRDS函数,提供了一个更加快速而完整的数据读写方法。 ** 受到fst的启发,qs通过lz4/zstd库使用了类似的块…

人工智能和机器学习:探讨人工智能和机器学习的最新发展、应用、挑战和未来趋势

人工智能和机器学习是当前科技领域的热点话题,其最新发展、应用、挑战和未来趋势备受关注。 最新发展: 人工智能和机器学习技术在近年来得到了快速发展,尤其是深度学习技术的广泛应用。例如,深度学习在图像识别、语音识别、自然语…

docker入门安装及使用

docker概述 docker是一种容器技术,它提供了标准的应用镜像(包含应用和应用多需要的依赖),因此,我们可以非常轻松的在docker中安装应用,安装好的应用相当于一个独立的容器 如下图所示,为docker中…

机器学习文献|基于循环细胞因子特征,通过机器学习算法预测NSCLC免疫治疗结局

今天我们一起学习一篇最近发表在Journal for immunotherapy of cancer (IF 10.9)上的文章,Machine learning for prediction of immunotherapeutic outcome in non-small-cell lung cancer based on circulating cytokine signatures[基于循环…

制证书、制电子印章、签章 -- 演示程序说明

ofd签章系统涉及证书的制作、电子印章制作、签章、验章等环节。关于ofd签章原理,本人写过多篇文章进行了阐述; 见文章《ofd板式文件 电子签章实现方法》、《一款简单易用的印章设计工具》、《签章那些事 -- 让你全面了解签章的流程》。 为了进一步加深对签章过程的理…

基于Spring Security OAuth2认证中心授权模式扩展

介绍 Spring Security OAuth2 默认实现的四种授权模式在实际的应用场景中往往满足不了预期。 需要扩展如下需求: 手机号短信验证码登陆微信授权登录 本次主要通过继承Spring Security OAuth2 抽象类和接口,来实现对oauth2/token接口的手机号短信的认证…

GD32F4开发 -- FATFS移植

之前已经讲了 GD32F4开发 – FATFS文件系统 现在将其一直到我的工程。 一、移植 在工程里创建FATFS文件夹。 移植正点原子 实验39 FATFS实验里的代码。 移植完后如下图: 注意:ffconf.h文件,找到对应宏并按照需求修改。 二、创建 FATFS 分…

最新中科院预警名单发布,多本高分区期刊被标记“On hold”(附20-24年所有名单)

2024年2月,期刊分区表团队发布2024年度《国际期刊预警名单 》。 最新版的《国际期刊预警名单》共有24本期刊,较23年版本的28本减少了4本,全部预警期刊当中,医学类数量最多,达11本。期刊JOURNAL OF BIOMATERIALS AND T…

高效率免费创作文章,4款ai写作生成器来帮忙

高效率免费创作文章,这对于每个创作者来说是非常不错的方法,即能提高创作效率,而且还能节省文章创作成本,但是想要高效率免费创作我们就需要找到相应的ai写作生成器来帮忙。因为如果是人工创作文章就需要耗费时间成本与人力成本的…

在pycharm终端中运行pip命令安装模块时,出现了“你要如何打开这个文件”弹出窗口,是什么状况?

这种情况发生在Windows系统上,当在PyCharm终端中运行pip命令安装模块时,如果系统无法确定要使用哪个程序打开该文件,就会出现“你要如何打开这个文件”弹出窗口。 解决方法是: 选择“查找一个应用于此文件”的选项。在弹出的窗口…

C++与C语言的区别

前言 本文主要用C语言和C做对比来学习C,便于个人理解。C包含C语言,是对C语言的扩展,在C中,支持C语言的语法使用,C是C语言的超集 一、C与C语言的区别 C语言简单高效,适合低级系统编程和硬件相关的开发。…

揭秘Web3新纪元:算力共享平台如何重塑数字世界的力量源泉

目录 一、Web3:算力共享的新舞台 二、技术革新:解锁算力的无限潜能 三、应用场景:算力如何改变世界 四、未来展望:算力共享的无尽可能 在区块链技术的浪潮中,Web3.0的曙光正引领我们迈向一个前所未有的数字时代。而在这场变革的洪流中,基于Web3的算力共享平台犹如一股…

Redis集群_主从复制

Redis集群基本概念 在实际项目中,一般不会只在一台机器上部署redis服务器,因为单台redis服务器不能满足高并发的压力,另外如果该服务器或者redis失效,整个系统就可能崩溃项目里一般会用主从复制的模式来提升性能,用集…

“精装朋友圈”的年轻人,开始在40度高温买羽绒服

文 | 螳螂观察 作者 | 如意 人生一世,苦了自己也不能苦朋友圈。 这届的年轻人,无论人生有多“毛坯”,都有一个一生要强的朋友圈,而且“装修”朋友圈还有一套哲学,信奉图片精修,排版讲究,文案…

OpenAI o1 Review 大模型PHD水平数理推理能力 OpenAI o1 vs GPT4o vs Gemini vs Claude

1. 介绍 OpenAI昨天发布了o1推理优化的大模型,利用了CoT (Chain of Thought) 思维链推理机制,提升了针对数学/物理/编程/逻辑等复杂问题的推理能力。OpenAI官方网站评测 OpenAI o1大模型对比GPT4o的数学、编程能力有显著提升。我们利用DeepNLP的AI Stor…

2024.9.13 Python与图像处理新国大EE5731课程大作业,SIFT 特征和描述符,单应性矩阵透视变换

1.SIFT特征点和描述符 import cv2 import numpy as np import matplotlib.pyplot as plt # read image img cv2.imread(im01.jpg,cv2.IMREAD_COLOR) gray cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) plt.imshow(gray,plt.cm.gray)提取图片,以灰度图像输出 #SIFT sift…

【免费分享】OpenHarmony鸿蒙物联网开发板资料包一网打尽,附教程/视频/项目/源码...

想要深入学习鸿蒙设备开发及鸿蒙物联网开发吗?现在机会来了!我们为初学者们准备了一份全面的资料包,包括原理图、教程、视频、项目、源码等,所有资料全部免费领取,课程视频可试看(购买后看完整版&#xff0…