Flash Attention 的优点以及Softmax 归一化系数解释

news2024/9/26 1:27:01
文章:FLASHATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness

 

原始Attention 计算使用gpu存储标准流程

涉及两个gpu存储器:

1)SRAM(static Random Access Memory):静态随机存取存储器

2)HBM(High Bandwidth Memory):高带宽存储器

具体流程

Q,K,V 初始存储在HBM中,S=Q(K的转置),P= softmax(S),O = PV

1)将Q,K从HBM取到SRAM当中,计算S,将S 放到HBM当中。

2)将S从HBM取到SRAM当中,计算P,将P 放到HBM当中。

3P)将S从HBM取到SRAM当中,计算O,将O 放到HBM当中。

FlashAttention旨在避免从 HBM(High Bandwidth Memory)中读取和写入注意力矩阵,这需要做到:

  1. 目标一:在不访问整个输入的情况下计算softmax函数的缩减;
  2. 目标二:在后向传播中不能存储中间注意力矩阵.

优化1:

FlashAttention如何实现在不访问整个输入的情况计算softmax大的缩减,标准Attention算法由于要计算softmax,而softmax都是按行来计算的,即在和V做矩阵乘之前,需要让 Q、K 的各个分块完成整一行分块的计算得到Softmax的结果后,再和矩阵V分块做矩阵乘。而在Flash Attention中,将输入分割成块,并在输入块上进行多次传递,从而以增量方式执行softmax缩减

优化2:

在后向传播中不存储中间注意力矩阵,以Flash Attention所提供的算法为例,通过对比标准Attention算法在实现过程中,标准Attention算法的实现需要将计算过程中的S、P写入到HBM中,而这些中间矩阵的大小与输入的序列长度有关且为二次型,因此Flash Attention就提出了不使用中间注意力矩阵,通过存储归一化因子来减少HBM内存的消耗。

在Flash Attention的前向计算算法中我们可以看出,Flash Attention算法并没有将S、P写入HBM中去,而是通过分块写入到HBM中去,存储前向传递的 softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从HBM中读取中间注意力矩阵的标准方法更快。即使由于重新计算导致 FLOPS 增加,但其运行速度更快并且使用更少的内存(序列长度线性),主要是因为大大减少了 HBM 访问量。

1. 存储softmax归一化的系数的含义

在Flash Attention中,"只存储softmax归一化的系数"是指在进行反向传播时,不直接存储整个注意力矩阵,而是只存储用于softmax归一化的系数。这样做的目的是为了减少存储需求,从而节省内存。

在传统的注意力机制中,我们需要计算并存储一个N^2的注意力矩阵,其中N是输入序列的长度。这个矩阵存储了序列中每个元素对其他所有元素的注意力权重。然而,这种方法的存储需求随着序列长度的增加而呈平方级增长,对内存的需求非常大。

Flash Attention通过只存储softmax归一化的系数,避免了存储整个注意力矩阵,从而大大减少了内存需求。这些系数足够用于在反向传播过程中计算梯度,而无需参考完整的注意力矩阵。

2. Flash Attention的优点

Flash Attention的主要优点是它可以显著减少内存需求,同时也加速了计算。这使得模型能够处理更长的序列,或者在内存有限的设备上运行。此外,Flash Attention还保持了与传统注意力机制相同的表现力,因为它仍然能够模拟序列中元素之间的所有对应关系。

Softmax归一化系数解析

1. Softmax归一化

Softmax是一种常用的归一化函数,它可以将一组任意实数转换为一组在(0,1)区间内的实数,且这组实数的总和为1。这使得softmax函数的输出可以被解释为一组概率分布。

2. Softmax归一化系数

在softmax函数中,"归一化系数"通常指的是用于将每个输入值转换为概率的那个系数。具体来说,假设我们有一组输入值x1, x2, ..., xn,那么对应的softmax归一化系数就是1/Z,其中Z是所有输入值经过指数运算后的和,即Z = exp(x1) + exp(x2) + ... + exp(xn)。这个归一化系数保证了softmax函数的输出值的总和为1。

在某些情况下,为了节省存储空间和计算资源,我们可能只会存储这个归一化系数,而不是存储整个softmax函数的输出。例如,在上文提到的Flash Attention中,就只存储了这个归一化系数,而不是存储整个注意力矩阵。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1151637.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python Django 之全局配置 settings 详解

文章目录 1 概述1.1 Django 目录结构 2 常用配置:settings.py2.1 注册 APP:INSTALLED_APPS2.2 模板路径:TEMPLATES2.3 静态文件:STATICFILES_DIRS2.4 数据库:DATABASES2.5 允许访问的主机:ALLOWED_HOSTS 1 …

【C】柔性数组

柔性数组 也许你从来没有听说过柔性数组(flexible array)这个概念,但是它确实是存在的。 C99 中,结构中的最后一个元素允许是未知大小的数组,这就叫做『柔性数组』成员。 例如: 柔性数组的特点 结构中的柔性数组成员前…

tab切换样式

一、样式一 <view class"tabs"><view class"tab" :class"{active: active index}" v-for"(item, index) in tabs" click"active index">{{item.label}}</view></view> data(){return {active: 0…

Codeforces Round 905 (Div. 3)

目录 A. Morning B. Chemistry C. Raspberries D. In Love E. Look Back F. You Are So Beautiful G1. Dances (Easy version) G2. Dances (Hard Version) A. Morning time limit per test 2 seconds memory limit per test 256 megabytes input standard input …

重磅新闻-国内首家八类网线认证分析仪上市了

伴随USA对国内某些敏感企业的非常不友好&#xff0c;设置层层障碍&#xff0c;技术堡垒。使得一些网线基础制造研发、线缆线束厂、汽车生产生产厂、军工用途的线缆品质的认证、以及相关高校的研发受到了不同的程度的阻碍。重磅消息&#xff0c;国内首家八类网线认证测仪-维信仪…

Vue前后端分离的低代码开发框架

目录 项目简介 平台特性 1.构架特性 2.功能特性 技术栈 1.后端技术栈 2.前端技术栈 2.1 Vue2技术栈 2.2 Vue3技术栈 3.数据库支持 部署方式 项目简介 JNPF开发平台是一个基于SpringBootVue3的全栈开发平台&#xff0c;采用微服务、前后端分离架构。前后端封装了上千…

【多线程相关其一】Python并发编程

1.为什么要引入并发编程 场景1&#xff1a;一个网络爬虫&#xff0c;按顺序爬取花了1个小时&#xff0c;采用并发下载减少到20分钟。 场景2&#xff1a;一个APP应用&#xff0c;优化前每次打开页面需要3秒&#xff0c;采用异步并发提升到每次200毫秒 引入并发&#xff0c;就是…

Centos系统安装阿里云盘+简单使用

GitHub地址&#xff1a; 1、安装阿里云盘 wget https://github.com/tickstep/aliyunpan/releases/download/v0.2.7/aliyunpan-v0.2.7-linux-amd64.zip unzip aliyunpan-v0.2.7-linux-amd64.zip mv aliyunpan-v0.2.7-linux-amd64 /usr/local/aliyunpan ln -s /usr/local/aliyu…

安装Jdk 报错 ,Java SE Development Kit 8 Update 202(64-bit)安装完毕之前,向导被中断

具体原因没有找到&#xff0c;估计是由于jdk 没有删干净导致的&#xff0c;我的处理方法是&#xff0c;将 Java的注册表全然后手动安装 Jdk和导入注册表&#xff08;在同事那里获取jdk文件 压缩包&#xff0c;并将 java的注册表导出&#xff0c;放在自己电脑上使用。&#xff0…

动手学深度学习——第七次学

LeNet&#xff08;LeNet-5&#xff09;由两个部分组成&#xff1a; 卷积编码器和全连接层密集块 卷积把高宽不断变小&#xff0c;把通道数逐渐增多&#xff0c;&#xff08;最后高宽会变成&#xff0c;通道会变得很大&#xff0c;然后做全连接进行输出&#xff09;通道信息可以…

Leetcode—1488.避免洪水泛滥【中等】

2023每日刷题&#xff08;十四&#xff09; Leetcode—1488.避免洪水泛滥 算法思想 将晴天的日期全部记录在set<int> sun中使用unordered_map<int, int> lakeRainy来记录每个湖泊上一次下雨的日期遇到晴天时先不用管抽哪个湖当下雨时&#xff0c;湖泊已经装满水时…

QT5 通过 webview2 加载网页

官方文档参考&#xff1a;https://learn.microsoft.com/zh-cn/microsoft-edge/webview2/get-started/win32 Webview2依赖的头文件和库 头文件主要为&#xff1a;WebView2和WixLibrary&#xff0c;存储在include/external 库主要为&#xff1a;WebView2LoaderStatic.lib和W…

C++——类和对象(构造函数与析构函数)

构造函数与析构函数 本章思维导图&#xff1a; 注&#xff1a;本章思维导图对应的Xmind文件和.png文件都已导入到”资料“中 1. 构造函数 以前&#xff0c;我们写一个Date类一般是这么写的&#xff1a; class Date { public :void Init(int year, int month, int day){_year…

Unity Animator cpu性能测试

测试案例&#xff1a; 场景中共有4000个物体&#xff0c;挂在40个animtor 上&#xff0c;每个Animator控制100个物体的动画。 使用工具&#xff1a; Unity Profiler. Unity 版本&#xff1a; unity 2019.4.40f1 测试环境&#xff1a; 手机 测试过程&#xff1a; 没有挂…

解读电力系统中的GPS北斗卫星同步时钟系统

随着电力系统的快速发展,变电站中的各类系统 &#xff1a;计算机监控系统、水情测报系统、视频监控系统 状态监测系统 生产信息管理系统等&#xff0c;各类装置&#xff1a;继电保护装置、故障录波装置、PMU装置、事件顺序记录SOE功能越来越强大&#xff0c;需要采集、记录的数…

CSS3背景样式

在CSS 2.1中&#xff0c;background属性的功能还无法满足设计的需求&#xff0c;为了方便设计师更灵活地设计需要的网页效果&#xff0c;CSS3在原有background基础上新增了一些功能属性&#xff0c;可以在同一个对象内叠加多个背景图像&#xff0c;可以改变背景图像的大小尺寸&…

LeetCode热题100 48.旋转图像

题目描述 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在 原地 旋转图像&#xff0c;这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。 示例 1&#xff1a; 输入&#xff1a;matrix [[1,2,3],[4,5,6],[7,8,9…

2022年09月 Python(二级)真题解析#中国电子学会#全国青少年软件编程等级考试

Python等级考试&#xff08;1~6级&#xff09;全部真题・点这里 一、单选题&#xff08;共25题&#xff0c;每题2分&#xff0c;共50分&#xff09; 第1题 运行以下代码&#xff0c;结果输出的是&#xff1f;&#xff08; &#xff09; means[Thank,You] print(len(means))A…

Android开发知识学习——TCP / IP 协议族

文章目录 学习资源来自&#xff1a;扔物线TCP / IP 协议族TCP连接TCP 连接的建立与关闭TCP 连接的建立为什么要三次握手&#xff1f; TCP 连接的关闭为什么要四次挥手&#xff1f; 为什么要⻓连接&#xff1f; 常见面试题课后题 学习资源来自&#xff1a;扔物线 TCP / IP 协议…