Flash attention入门

news2025/1/10 2:55:46

一、目录

  1. flash attention
  2. GPU运算流程
  3. flash attention 原理
  4. flash attention 与 standard attention 时间/内存 对比。
  5. flash attention 算法实现
  6. 比较flash attention 计算、memory-efficient attention 等不同内核下用时

二、实现

  1. flash attention
    目的: 提高运行速度,减少内存消耗。

  2. GPU运算流程
    见gpu 入门篇

  3. flash attention 原理
    3.1 原理:
    flashAtention其加速的原理是非常简单的,也是最基础和常见的系统性能优化的手段,即通过利用更高速的上层存储计算单元,减少对低速更下层存储器的访问次数,来提升模型的训练性能。在这里插入图片描述
    图片代表的为带宽大小与内存大小的关系,即从上面的数字可以看出SRAM的访问速率是HBM的10倍左右,然而其能承载的数据量却远远小于HBM。
    CPU 内存大小》GPU 高带宽内存>>GPU SRAM(静态内存)
    GPU SRAM速度>>GPU 高带宽 显存>>CPU 内存速度
    3.2. 创新点:将flashAttention 计算过程由HBM 转为SRAM 中,减少访问次数。
    3.3. 标准attention 计算方法 与flashAttention 计算方法
    标准attention计算:
    首先,从HBM中读取完整的Q和K矩阵(每个大小为N x d),计算点积得到相似度得分S(大小为N x N),需要进行O(Nd + N^2)次HBM访问。
    其次,计算注意力权重P(大小为N x N)时,需要对S进行softmax操作,这需要进行O(N^2)次HBM访问。
    最后,将注意力权重P和值向量V(每个大小为N x d)加权求和得到输出向量O(大小为N x d)时,需要进行O(Nd)次HBM访问。
    标准 Attention 算法的总HBM访问次数为O(Nd + N^2)
    flashAttention计算:
    将原始的注意力矩阵分解成更小的子矩阵,然后分别对这些子矩阵进行计算,只要这个子矩阵的大小可以在SRAM内存放,就可以在计算过程中只访问SRAM。
    计算过程中要尽量的利用SRAM进行计算,避免访问HBM操作。
    3.4. 什么时候使用HBM,什么时候使用SRAM?
    编程时,人为指定SRAM空间。

  4. flash attention 与 standard attention 时间/内存 对比。
    参考:https://zhuanlan.zhihu.com/p/638468472
    以 batch=32, seq_len=512, n_head=16,head_dim=64 为例,记录flash attention 与standard attention 时间/内存对比。在这里插入图片描述flash attention实现:

import torch
from xformers import ops as xops
import time
bs = 32
seq_len = 512
n_head = 16
head_dim = 64
query_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
key_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
value_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")

flash_query_states = query_states.transpose(1, 2)
flash_key_states = key_states.transpose(1, 2)
flash_value_states = value_states.transpose(1, 2)
start_time = time.time()

#xformers 实现的注意力机制, 加速框架
flash_attn_output = xops.memory_efficient_attention(
    flash_query_states, flash_key_states, flash_value_states,
    attn_bias=xops.LowerTriangularMask()

)
print(f'flash attention time: {(time.time()-start_time)*1000} ms')
print(torch.cuda.max_memory_allocated("cuda:0")/1024**2)      #192M
print("=============================")
print(torch.cuda.memory_allocated("cuda:0")/1024**2)         #128M

standard attention 实现:

import torch
from xformers import ops as xops
import time
bs = 32
seq_len = 512
n_head = 16
head_dim = 64
query_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
key_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
value_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
flash_query_states = query_states.transpose(1, 2)
flash_key_states = key_states.transpose(1, 2)
flash_value_states = value_states.transpose(1, 2)
start_time = time.time()
import math
import torch.nn as nn
attention_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool)).view(1, 1, seq_len, seq_len)
attention_mask = attention_mask.to(dtype=torch.float16).cuda()  # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(torch.float16).min           #数据类型
def standard_attention(query_states, key_states, value_states, attention_mask):
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    attn_weights = attn_weights + attention_mask
    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2)
    return attn_output

start_time = time.time()
attn_output = standard_attention(query_states, key_states, value_states, attention_mask)

print(f'standard attention time: {(time.time()-start_time)*1000} ms')
#print(torch.allclose(attn_output, flash_attn_output, rtol=2e-3, atol=2e-3))   #判断两个张量是否接近相等(计算机计算的不精确性,完全相等的浮点数可能存在微小差异)

print(torch.cuda.max_memory_allocated("cuda:0")/1024**2)      #1128M
print("=============================")
print(torch.cuda.memory_allocated("cuda:0")/1024**2)         #136M
  1. flash attention 算法
    参考:https://blog.csdn.net/qinduohao333/article/details/131449876FlashAttention
    算法实现的关键在于以下三点:
    1 softmax的tiling展开,可以支持softmax的拆分并行计算,从而提升计算效率
    2 反向过程中的重计算,减少大量的显存占用,节省显存开销。
    3 通过CUDA编程实现fusion kernel
    参数了解:
    SRAM:静态显存。嵌入在GPU芯片上的SRAM存储器。
    HBM:高带宽内存。使得GPU能够更快地读取和写入数据。
    DRAM: 动态显存。嵌入在CPU芯片上的DARM存储器。
    所以:读写速度 SRAM>HBM>DRAM.在这里插入图片描述
  2. 比较flash attention 计算、memory-efficient attention 等不同内核下用时
    参考:https://blog.51cto.com/u_15293476/6131364
    用时比较: 内核下torch 实现>不指定内核下torch 实现> 内核下flash attention> 内核下 efficient attention.
import torch
import torch.nn.functional as F
from rich import print
from torch.backends.cuda import sdp_kernel    #内核计算
from enum import IntEnum
import torch.utils.benchmark as benchmark
device = "cuda" if torch.cuda.is_available() else "cpu"       #cudnn 需要使用gpu

# 超参数定义
batch_size = 64
max_sequence_len = 256
num_heads = 32
embed_dimension = 32
dtype = torch.float16

# 模拟 q k v
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

# 定义一个计时器:
def torch_timer(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# torch.backends.cuda中也实现了,这里拿出了为了好理解backend_map是啥
class SDPBackend(IntEnum):
    r"""
    Enum class for the scaled dot product attention backends.
    """
    ERROR = -1
    MATH = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2

# 使用上下文管理器context manager来
# 其他三种方案,字典映射
backend_map = {
    SDPBackend.MATH: {               #启用pytorch 实现
        "enable_math": True,
        "enable_flash": False,
        "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {     #启用flashattention
        "enable_math": False,
        "enable_flash": True,
        "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {   #启用memory_efficient attention
        "enable_math": False,
        "enable_flash": False,
        "enable_mem_efficient": True}
}

# 基本版,不指定
print(f"基本对照方案 运行时间: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# 基本对照方案 运行时间: 558.831 microseconds

#内核中运行
with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"math 运行时间: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# math 运行时间: 1013.422 microseconds
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"flash attention 运行时间: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported")
# flash attention 运行时间:  557.343 microseconds
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"Memory efficient 运行时间: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported")
# Memory efficient 运行时间: 428.007 microseconds

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1678797.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NGM-SLAM:首创融合神经辐射场子图的3DGS-SLAM,问鼎SOTA!

论文标题: NGM-SLAM: Gaussian Splatting SLAM with Radiance Field Submap 论文作者: Mingrui Li, Jingwei Huang, Lei Sun Aaron, Xuxiang Tian, Tianchen Deng, Hongyu Wang 导读: 3DGS技术因其性能卓越而备受关注,3DGS-SLA…

GPT-4o 炸裂发布!你竟然还没用上?(附详细教程)

今天AI界的爆炸新闻非chatgpt-4o莫属,从早上到现在随处可见的文章推送,视频推送。 大家或多或少都有耳闻了,今天主要讲一讲我们普通人到底怎么用?如果不氪金行不行?我就想体验一下可不可以?带着问题往下看 …

Python 海龟画图(Turtle)命令大全

移动和绘制 forward() | fd() 使用语法: ​​turtle.forward(距离)​​ ​​turtle.fd(距离)​​ 参数说明: 距离 一个数字 (整数 或者 浮点) (注:单位是像素) 代码示例: import turtle turtle.forward(200) 效果: backward () | bk() | back() 使用语法: ​…

掏心经验分享,软考中项0基础入门篇!

想备考下半年中项(系统集成项目管理工程师)的朋友,不知道如何了解软考中项,今天给大家整理一篇关于我自己在备考软考时的一些考量和踩过的一些坑。(无广,放心看) 很多小伙伴总是听大家说软考中…

你是学会了还是学废了:Elasticsearch 7 集群拷贝到其它环境如何重置密码

欢迎您关注我的公众号【尚雷的驿站】 公众号:尚雷的驿站 CSDN :https://blog.csdn.net/shlei5580 墨天轮:https://www.modb.pro/u/2436 PGFans:https://www.pgfans.cn/user/home?userId4159 前言 本文描述了将生产ES集群打包拷贝…

线性模型之岭回归的用法

实战:使用岭回归模型 完整代码: import numpy as np import matplotlib.pyplot as plt from sklearn.linear_model import LinearRegression from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split fro…

平芯微PW4056HH中文规格书

概述 PW4056HH 是一款完整的采用恒定电流/恒定电压的高压、大电流、单节锂离子电池线性充电 IC。充电电流可达 1A。输入 MAX 低工作电压 3.75V,降低充电功耗,提高效率。 PW4056HH 采用了内部 PMOS 架构,加上防反充电路,不需要外部…

Java开发大厂面试第04讲:深入理解ThreadPoolExecutor,参数含义与源码执行流程全解

线程池是为了避免线程频繁的创建和销毁带来的性能消耗,而建立的一种池化技术,它是把已创建的线程放入“池”中,当有任务来临时就可以重用已有的线程,无需等待创建的过程,这样就可以有效提高程序的响应速度。但如果要说…

Linux服务器lvm磁盘管理fdisk和df磁盘大小不同修改

服务器端由于硬盘是通过VCenter原来100G磁盘复制的虚拟机,复制完成后,原来100G的磁盘通过选择 磁盘重新复制出150G的磁盘,开机后发现还是原来的100G的磁盘,通过fdisk -l 查看有个sdb是150G, 但是已经划转的lvm盘只有100G, 通过df查看也是原来的100G: pvs查看pv里也是10…

【链路层和局域网】

文章目录 链路层和局域网网络节点的连接方式数据链路层和局域网链路层导论链路层:上下文链路层服务链路层在哪里实现?适配器通信错误检测奇偶校验校验和:CRC(循环冗余校验)多点访问链路和协议多路访问协议MAC&#xff…

立创EDA绘制PCB电路板

1、绘制好原理图后,点击设计---原理图转PCB,生成PCB文件 2、将元器件拖入电路板方框内,摆放布局并使用工具栏布线、放置过孔及丝印 3、然后顶层和底层铺铜 4、后面就可以生成制板文件发送嘉立创制板了。

基于国产LoRa的智慧农业解决方案--ASR6601、SX1278

我国《数字乡村发展战略纲要》明确指出“要推进农业数字化转型”,加快推广云计算、大数据、物联网、人工智能在农业生产经营管理中的运用。 然而,目前我国的农业数字化转型还面临着诸多挑战。我国整体农业机械化程度和自动化控制水平仍然较低。由于农田面…

ubuntu quota配置磁盘配额

安装quota工具:sudo apt-get install quota这条命令会安装quota工具&#xff0c;它用于在Linux系统中管理和强制执行磁盘配额。编辑用户quota:sudo edquota -u <username> /data这条命令会打开默认的文本编辑器&#xff0c;允许你为用户liushenshen在/data文件系统上设置…

三.Ubuntu安装MySql数据库

三.Ubuntu安装MySql数据库 1.首先进入Console,登录Ubuntu系统后,更新源,命令:apt update,如图所示。 安装MySQL命令:apt install mysql-server 执行期间按回车,进行下一步,执行过程如图所示: 选择yes或no,此步选择yes 安装完成。 2.提高MySQL安全性,该命令…

代码行数统计工具cloc

Release v2.00 AlDanial/cloc GitHub 代码量代码行数统计工具cloc的正确使用(windows平台亲测有效&#xff0c;本人踩过坑&#xff0c;文中提到&#xff01;)_cloc代码统计工具-CSDN博客

libssh C++封装之六(Dir)

1 概述 libssh是一个在客户端和服务器端实现SSHv2协议的多平台C库。使用libssh,您可以远程执行程序、传输文件、使用安全透明的隧道、管理公钥等等。本文描述的对libssh客户端功能的C++封装。 libssh下载地址 3 实现 3.5 Dir Dir类型管理远程路径,通过SFTP和Channel实现(有…

so-vits-svc:AI翻唱,语音克隆

前言 这个项目是为了让开发者最喜欢的动画角色唱歌而开发的&#xff0c;任何涉及真人的东西都与开发者的意图背道而驰。 项目地址&#xff1a;https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/README_zh_CN.md 安装 可以自行配置&#xff0c;应该也不难 …

做简单易用的GIS资源管理软件

在室外资源管理领域&#xff0c;采用基于GIS的解决方案已成为主流趋势&#xff0c;旨在实现资源的高效利用和管理。GIS技术结合资源对象的规划、定位和监控&#xff0c;为企业提供全面的管理方案&#xff0c;从而优化资源使用、提高运营效率和降低成本。 然而&#xff0c;许多资…

新手也能看懂的前端单元测试框架:Vitest

单元测试的概念及作用 1.什么是单元测试&#xff1f; 单元测试是测试中的一个重要环节&#xff0c;它针对软件中的最小可测试单元进行验证&#xff0c;通常是指对代码中的单个函数、方法或模块进行测试。 单元测试旨在确定特定部分代码的行为是否符合预期&#xff0c;通过针…