字节跳动提出高性能 transformer 推理库,获 IPDPS 2023 最佳论文奖

news2024/10/7 6:44:37

动手点关注

41e78f34afbcb501d14228be0552539c.gif

干货不迷路

字节跳动与英伟达, 加州大学河滨分校联合发表的论文 《ByteTransformer: A High-Performance Transformer Boosted for Variable-Length》在第 37 届 IEEE 国际并行和分布式处理大会(IPDPS 2023)中,从 396 篇投稿中脱颖而出,荣获了最佳论文奖。该论文提出了字节跳动的 GPU transformer 推理库——ByteTransformer。针对自然语言处理常见的可变长输入,论文提出了一套优化算法,这些算法在保证运算正确性的前提下,成功避免了传统实现中的冗余运算,实现了端到端的推理过程的大幅优化。另外, 论文中还手动调优了 transformer  中的 multi-head attention, layer normalization, activation 等核心算子, 将 ByteTransformer 的推理性提升至业界领先水平。与 PyTorch, TensorFlow, NVIDIA FasterTransformer, Microsoft DeepSpeed-Inference 等知名的深度学习库相比,ByteTransformer 在可变长输入下最高实现131%的加速。论文代码已开源。

ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs ( https://arxiv.org/abs/2210.03052 )

IPDPS: 并行和分布式计算方向计算机系统领域的旗舰会议。该会议专注于分享并讨论并行计算,分布式计算,大规模数据处理以及高性能计算等相关领域的最新研究进展。参与的专家学者来自世界各地的顶尖研究机构和企业,共同探讨该领域的创新发展和前沿技术。

code: https://github.com/bytedance/ByteTransformer

transformer 变长文本 padding free

transformer 在自然语言处理(NLP)中被广泛使用,随着 BERT、GPT-3 等大型模型的出现和发展,transformer 模型的重要性变得越来越突出。这些大型模型通常具有超过一亿个参数,需要大量的计算资源和时间进行训练和推理。因此,优化 transformer 性能变得非常重要。

现有的一些深度学习框架,如 Tensorflow,PyTorch,TVM 以及 NVIDIA TensorRT 等,要求输入序列长度相同,才能利用批处理加速 transformer 计算。然而,在实际场景中,输入序列通常是变长的,而零填充会引入大量的额外计算开销。有一些方法在 kernel launch 前对具有相似 seqlen 的输入分组,以最小化 padding,但无法实现 padding free。字节跳动 AML 团队先前提出的“effective transformer” [4],通过对输入的重排列,实现了 QKV projection 和 MLP 的 padding free,但 self attention 部分仍然需要 padding。

为了解决这个问题,字节跳动 AML 团队提出了 ByteTransformer,它实现了变长输入的 padding free 计算,并且实现了全面的 kernel fusion 以进一步提高性能。

11a5b32b65f477ccbe44b1ddea4fd407.png

图一: ByteTransformer 与其他工作 feature 对比

Remove padding 算法

这个算法源自字节跳动 AML 团队之前的工作 "effective transformer",在  NVIDIA 开源 FasterTransformer 中也有集成。ByteTransformer 同样使用该算法去除对 attention 外矩阵乘的额外计算。

算法步骤:

1)计算 attention mask 的前缀和,作为 offsets

2)根据 offsets 把输入张量从 [batch_size, seqlen, hidden_size] 重排列为 valid_seqlen, hidden_size] ,再参与后续的矩阵乘计算,实现 padding free

905d1ba2cd3db4ddd309119aad3d0c8e.png

图二:Remove padding 算法过程

FMHA (Fused Multi-Head Attention)

为了优化 attention 部分的性能,ByteTransformer 中实现了 fused multi-head attention 算子。对于 seqlen 长度, 以 384 为界划分为两种实现方式:

  • 对于短 seqlen, 因为可以把 QK 整行放在共享内存进行 softmax 操作, 通过手写 kernel 的方式实现, 矩阵乘通过调用 wmma 接口使用 TensorCore 保证高性能。

  • 对于长 seqlen, 因为共享内存大小限制,不能在一个手写 kernel 中完成所有操作。基于高性能的 CUTLASS[5] grouped GEMM, 分成两个 gemm kernel 实现, 并把 add_bias, softmax 等操作 fused 到 GEMM kernel 中。

CUTLASS grouped GEMM

NVIDIA 开发的 grouped GEMM 可以在一个 kernel 中完成多个独立矩阵乘问题的计算,利用这个性质可以实现 Attention 中的 padding free。

  • Attention 中的两次矩阵乘操作,都可以拆解为 batch_size x head_num 个独立的矩阵乘子问题。

  • 每个矩阵乘子问题,把问题大小传入到 grouped GEMM,其中 seqlen 传递真实的 valid seqlen 即可。

grouped GEMM 原理:kernel 中每个 threadblock (CTA) 固定 tiling size,每个矩阵乘子问题根据 problem size 和 tiling size,拆解为不同数量的待计算块,再把这些块平均分配到每个 threadblock 中进行计算。

4a6e3a05baf9be0c30e117b5899628ce.png

图三:grouped GEMM 原理示意图。每个子问题拆解为不同数量的块,再对这些块均匀分配,高效地实现单个 kernel 计算多个独立 GEMM 问题。

使用 grouped GEMM 实现 attention 时,由于子问题的数量 batch_size x head_num 通常较大,读取子问题参数会有不小的开销,因为从线程角度看,每个线程都需要遍历读取所有的子问题大小。

为了解决这个问题,ByteTransformer 对 grouped GEMM 中读取子问题参数进行了性能优化,使其可以忽略不计:

1)共享子问题参数。 对同一个输入,不同 head 的 valid seqlen 相同,problem size 也相同,通过共享使参数存储量从 batch_size x head_num 减少到 batch_size

2)warp prefetch. 原始实现中,每个 CUDA thread 依次读取所有的子问题 problem size,效率很低。改为一个 warp 内线程读取连续的 32 个子问题参数,然后通过 warp 内线程通信交换数据,每个线程的读取次数降低到 1/32

25f6411721fbf570c56e61f739d7a83f.png

图四:warp prefetch 示意图。每个 iteration 一个 warp 读取 32 个子问题 size

softmax fusion

为了进一步提高性能,把 Q x K 之后的 softmax 也 fuse 到矩阵乘算子中,相比单独的 softmax kernel 节省了中间矩阵的访存操作。

因为 softmax 需要对整行数据做归约,但因为共享内存大小的限制,一个 threadblock 内不能容纳整行数据,同时 threadblock 间的通信很低效,所以不能仅在 Q x K 的 epilogue 中完成整个 softmax 的操作。把 softmax 拆分成三步计算,分别 fuse 到 Q x K 的 epilogue 中, QK x V 的 prologue 中,以及中间再添加一个轻量的 kernel 做规约。

60f468f885b94fcb9f0647a293c57897.png

图五:softmax fusion 流程示意图。分为三步计算,大部分计算 fuse 到前后的 GEMM kernel 中

算法步骤:

1)partial reduction:Q x K 的 epilogue 中,每个 threadblock 内部规约,计算出 max 和 sum 两个值

2)full reduction:一个轻量级的 kernel,把每一行的 partial reduction 结果继续规约到整行的结果

3)element-wise op:修改了 CUTLASS 的代码,使其支持 prologue fusion,即在加载输入矩阵后,fuse 一些 element-wise 的操作。在 QK x V 的 prologue 中,读取当前行的规约结果,计算出 softmax 的最终结果,再参与后续的矩阵乘计算

性能数据

短 seqlen 手写 kernel 的性能

在 <= 384的短 seqlen 情况下,cuBLAS batch GEMM 相比 PyTorch MHA 性能提高 5 倍,而启用 zero padding 算法优化 softmax 后进一步提高 9% 的性能,ByteTransformer 的 MHA 把两个 batched GEMM 和中间的 softmax 完全 fuse 成一个 kernel,相比三种变体实现,平均加速分别为 617%、42% 和 30%

e0a8e2b19631e41babdb3d3cd7b1ff2b.png

图六:手写 attention kernel 性能对比。注:cuBLAS + zero padding 指对 softmax 的 zero padding

长 seqlen CUTLASS kernel 的性能

在 448~1024 seqlen 下, cuBLAS batched GEMM 比 PyTorch 的 MHA 性能提高了3倍,同时对 softmax 的 zero padding,进一步提高了 17% 的性能,通过引入高性能的 CUTLASS grouped GEMM 以及 softmax fusion,ByteTransformer 的 fused MHA 比变体 MHA 实现的性能提高了451%,110% 和 79%

2fba6c0752271ea98e66f986f9996fc2.png

图七:长 seqlen CUTLASS FMHA 性能对比

全面的 kernel fusion

除矩阵乘和 attention 的优化外,ByteTransformer 还对一些小的操作进行了全面的 kernel fusion,通过减少显存访问和 kernel launch 的开销,可以获得更极致的性能。

add-bias & LayerNorm fusion

矩阵乘之后的 add-bias 和 LayerNorm 操作,通过手写 kernel 的方式做 fusion,这部分操作在 seqlen 为 256 和 1024 的情况下分别占 10% 和 6% 的延迟,fused kernel 可以优化 61% 的性能,对单层 BERT transformer 的性能提升 3.2% (平均 seqlen 128 - 1024 的情况)。

GEMM & add-bias & GELU fusion

通过 CUTLASS fuse epilogue 的方式,把矩阵乘后的 add-bias 操作和 GELU activation 操作 fuse 到矩阵乘 kernel 中。add-bias 和 GELU 在 seqlen 为 256 和 1024 的情况下占比耗时分别为 7% 和 5%。把 add-bias 和 GELU 融合 GEMM 可以完美隐藏这部分的访存延迟,进一步使单层 transformer 性能提升 3.8%

变种 transformer 支持

目前,字节跳动 AML 团队已经在 GitHub 上开源了 ByteTransformer 的标准 BERT 实现(https://github.com/bytedance/ByteTransformer)。除此之外,字节内部版本还支持了许多 transformer 变种,比如 Deberta, Roformer,T5 等等。代码实现易于拓展,并且上述各种优化手段也可以方便地应用到变种 transformer 中。

更多性能数据

与其他 transformer 实现的端到端性能对比

实验配置:标准 BERT transformer,head size = 64, head number = 12, layer = 12, 平均 valid seqlen = 0.6 * 最大 seqlen,使用 A100 GPU 进行测试。比较在 seqlen=64~1024,batch_size=1,8,16 情况下的性能表现。相比 PyTorch JIT, TensorFlow XLA, Tencent TurboTransformer, Microsoft DeepSpeed-Inference 和 NVIDIA FasterTransformer,分别平均加速 87%, 131%, 138%, 74% 和 55%

  • 注:在论文投稿时,PyTorch 和 NVIDIA FasterTransformer 还没有集成 FlashAttention [6]

a830c0b910bdcb12a566f56cd10c7951.png

图八:各 transformer 实现的端到端性能对比

  1. 各优化手段影响拆解

    1. 与 ByteTransformer 自己的基线版本对比,开启各种优化后总体相对基线提升 60%。拆解各优化手段对性能的影响如下:

    2. add-bias & LayerNorm fusion 可以提高性能 3.2%

    3. 将 add-bias & GELU fuse 到 GEMM epilogue 可以进一步提高 3.8%

    4. 引入 remove padding 算法,性能提高 24%

    5. FMHA 额外提高 20%

2b97dcab91e0c7458fe7f3403c6ba805.png

图九:各优化手段性能提升拆解

  1. BERT-like 变种的性能对比

比较 ByteTransformer 在 ALBERT, DistilBERT 和 DeBERTa 这几种模型结构下与最先进的 DL 框架的性能表现。实验配置与标准 BERT 一致,平均 seqlen 为 0.6*最大 seqlen。

对于 ALBERT 和 DistilBERT,ByteTransformer 平均比 PyTorch、TensorFlow、Tencent TurboTransformer、DeepSpeed-Inference 和 NVIDIA FasterTransformer 分别快98%、158%、256%、93% 和 53%。对于 DeBERTa 模型,ByteTransformer 比PyTorch、TensorFlow 和 DeepSpeed 分别快 44%、243% 和 74%。

51c3dd75a6399931fdcbbced08528df8.png

图十:BERT-like 变种的性能对比。部分数据点缺失是因为对应框架不支持或无法成功运行

结论

ByteTransformer 是一种高效的 transformer 实现,它通过一系列优化手段,实现了在 BERT transformer 上的高性能表现。对于变长文本输入,相比其他 transformer 实现,ByteTransformer 具有明显的优势,实验中平均加速可达 50% 以上。适用于加速自然语言处理任务,提高模型训练与推理的效率。同时,ByteTransformer 也为其他研究者提供了一种高效的 transformer 实现方式。其优化手段和性能表现对于实际应用具有重要意义。

字节跳动 AML 团队介绍

AML 是字节跳动公司的机器学习中台,为抖音/今日头条/西瓜视频等业务提供推荐/广告/CV/语音/NLP的训练和推理系统。为公司内业务部门提供强大的机器学习算力,并在这些业务的问题上研究一些具有通用性和创新性的算法。同时,也通过火山引擎将一些机器学习/推荐系统的核心能力提供给外部企业客户。

更多 AML 岗位火热招聘中,欢迎扫描下方二维码,了解详情,进行投递

49470928b3bd3ef98eedabf9da7ce5b4.pngb0d224583a9229fd166f93675a785c85.png

引用:

[1] J. Fang, Y. Yu, C. Zhao, and J. Zhou, “Turbotransformers: an efficient gpu serving system for transformer models,” in Proceedings of the 26th ACM SIGPLAN Symposium on Principles and Practice of Parallel Programming, 2021, pp. 389–402.

[2] R. Y. Aminabadi, S. Rajbhandari, M. Zhang, A. A. Awan, C. Li, D. Li, E. Zheng, J. Rasley, S. Smith, O. Ruwase et al., “Deepspeed inference: Enabling efficient inference of transformer models at unprecedented scale,” arXiv preprint arXiv:2207.00032, 2022.

[3] NVIDIA, https://github.com/NVIDIA/FasterTransformer

[4] ByteDance https://github.com/bytedance/effective_transformer

[5] NVIDIA https://github.com/NVIDIA/cutlass

[6] T. Dao, D. Y. Fu, S. Ermon, A. Rudra, and C. Re, “Flashattention: Fast ´ and memory-efficient exact attention with io-awareness,” arXiv preprint arXiv:2205.14135, 2022.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/652103.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

广东省高校人工智能产教融合院长研讨会召开,校企协同探索AI教育新范式

为深化产教融合、促进校企合作&#xff0c;着力推进人工智能产业和高校人才培养体系相融合&#xff0c;深入探讨校企合作、产教融合与课程建设规划等事宜&#xff0c;2023年6月9日下午&#xff0c;百度飞桨联合广东省计算机学会、华南理工大学计算机科学与工程学院、荔峰科技&a…

这世界好神奇,我们其实并不了解自己的身体

点击文末“阅读原文”即可参与节目互动 剪辑、音频 / 卷圈 运营 / SandLiu 卷圈 监制 / 姝琦 文案 / 粒粒 产品统筹 / bobo 场地支持 / 声湃轩北京站 这是一次突发奇想的天马行空&#xff0c;三个人猝不及防地坐下来大开脑洞&#xff0c;从诗词歌赋聊到人生哲学&#xff…

CVPR 2023 | 美团技术团队精选论文解读

本文精选了美团技术团队被CVPR 2023收录的8篇论文进行解读。这些论文既有自监督学习、领域自适应、联邦学习等通用学习范式方面的技术迭代&#xff0c;也涉及目标检测、跟踪、分割、Low-level Vision等典型视觉任务的性能&#xff0c;体现了美团在基础通用技术和垂直领域技术上…

HotSpot虚拟机对象探索与OutOfMemoryError异常

HotSpot虚拟机对象探索与OutOfMemoryError异常 1.HotSpot虚拟机对象探索 1.1对象的创建 不是一直有一个笑话,别人问程序员有没有对象,程序员会说我没有对象,但是我可以new一个出来 这里就可以判断他学过c或者java等语言 在java中对象的创建一般我们都是通过new来创建的,但…

MyBatis01

ORM&#xff1a;对象关系映射 O&#xff08;Object&#xff09;&#xff1a;Java虚拟机中的Java对象 R&#xff08;Relational&#xff09;&#xff1a;关系型数据库 M&#xff08;Mapping&#xff09;&#xff1a;将Java虚拟机中的Java对象映射到数据库表中一行记录&#xff0…

【王道·操作系统】第三章 内存管理

一、内存管理 1.1 内存的基础知识 内存可存放数据&#xff0c;程序执行前需要先放到内存中才能被CPU处理——缓和CPU与硬盘之间的速度矛盾内存地址从0开始&#xff0c;每个地址对应一个存储单元 按字节编址&#xff1a;每个存储单元大小为1字节(B)&#xff0c;即8个二进制位按…

【Spring Cloud系列】- Eureka使用详解

【Spring Cloud系列】- Eureka使用详解 文章目录 【Spring Cloud系列】- Eureka使用详解一、概述二、Eureka简介三、Eureka结构与作用Eureka结构图Eureka采用CS&#xff08;Client/Server,客户端/服务器&#xff09;架构&#xff0c;它包括以下两大组件 四、Eureka集群及与应用…

梁宁:为什么中国没有像 ChatGPT 和 Vision Pro 这样的创新产品?

6 月 10 日&#xff0c;产品战略专家梁宁和图灵联合创始人刘江围绕“ ChatGPT 真需求”主题进行直播对谈。 梁宁&#xff0c;产品战略专家&#xff0c;曾任湖畔大学产品模块学术主任&#xff0c;联想、腾讯高管&#xff0c;CNET集团副总裁。 工作经历横跨 BAT&#xff0c;与美团…

第14届蓝桥杯Scratch(中级)国赛真题解析2023.5.28

第14届蓝桥杯Scratch(中级)国赛真题解析2023.5.28 一:选择题(50分)第 1 题 单选题(10分) 运行以下程序后,角色说出的数是 ( C )。 *选择题严禁使用程序验证,选择题不答或答错都不扣分 A.150 B.200 C.300 D.600第 2 题 单选题(10分) 对以下程序效果描述完全正确的是 …

【JUC进阶】01. Synchroized实现原理

目录 1、前言 2、Synchronized使用 2.1、对象锁&#xff08;Instance Lock&#xff09; 2.2、类锁&#xff08;Class Lock&#xff09; 2.3、方法锁&#xff08;Method Lock&#xff09; 3、原理分析 3.1、monitor对象 3.2、monitorenter 3.3、monitorexit 3.4、对象…

库克和马斯克之后,比尔盖茨访华,美企认识到中国市场不可替代

微软创始人比尔盖茨已到达了中国&#xff0c;这是美国苹果CEO库克、特斯拉CEO马斯克之后访华的又一位美国重要企业家&#xff0c;那么美国企业家陆续访华是为了什么呢&#xff1f; 一、美企连遭挫折 苹果一季度的业绩显示营收、利润均下滑了个位数&#xff0c;但是苹果看到了隐…

AIGC 加持 Cocos,游戏开发需要几步?

近日&#xff0c;游戏行业知名的 B2B 大会 WN 2023 大会于土耳其首都伊斯坦布尔顺利举办。本次大会邀请了来自全球的游戏开发商、媒体、发行商、分发平台等行业决策者&#xff0c;共同探讨游戏行业未来发展态势&#xff0c;进一步拓展业务&#xff0c;并在世界范围内寻找新的合…

【力扣刷题 | 第十天】347.前k个高频元素 227 简单计算器

前言&#xff1a; 本篇将是最后一篇我们利用栈与队列来解决力扣问题&#xff0c;在下文我们将进入到数这一章&#xff0c;相对应的【夜深人静讲数据结构与算法】专栏中树也会及时更新。 347. 前 K 个高频元素 - 力扣&#xff08;LeetCode&#xff09; 给你一个整数数组 nums 和…

【JS】1714- 重学 JavaScript API - Geolocation API

❝ 前期回顾&#xff1a; 1. Page Visibility API 2. Broadcast Channel API 3. Beacon API 4. Resize Observer API 5. Clipboard API 6. Fetch API 7. Performance API 8. WebStorage API 9. WebSockets API 10. Fullscreen API ❞ 本文将深入探讨 Geolocation API 的概念、使…

华为OD机试真题 JavaScript 实现【关联子串】【2023Q1 100分】,附详细解题思路

一、题目描述 给定两个字符串str1和str2&#xff0c;str1进行排列组合只要有一个为str2的子串则认为str1是str2的关联子串&#xff0c;请返回子串在str2的起始位置&#xff0c;若不是关联子串则返回-1。 二、输入描述 qwe dsgfasgfwe 三、输出描述 -1 四、解题思路 读取…

009、体系架构之HTAP

HTAP HTAP技术传统的HTAP解决方案HATP的要求TiDB的HTAP架构TiDB的HTAP特性使用场景 MPP HTAP技术 传统的HTAP解决方案 HATP的要求 可扩展性 分布式事务分布式存储 同时支持OLTP与OLAP 同时支持行存和列存OLTP与OLAP业务隔离 实时性 行存与列存数据实时同步 TiDB的HTAP架构 …

Committer 迎新!这次是来自阿里云的同学

点击蓝字 关注我们 迎新&#xff01; 截至今天&#xff0c;Apache DolphinScheduler 项目在 GitHub 上的 Star 数已突破 10.6K&#xff0c;贡献者人数也突破了 470 人。社区的不断壮大&#xff0c;离不开每位 Contributor 的支持。 最近&#xff0c;Apache DolphinScheduler 又…

AI模型部署实战:利用CV-CUDA加速视觉模型部署流程

本文首发于公众号【DeepDriving】&#xff0c;欢迎关注。 CV-CUDA简介 随着深度学习技术在计算机视觉领域的发展&#xff0c;越来越多的AI算法模型被用于目标检测、图像分割、图像生成等任务中&#xff0c;如何高效地在云端或者边缘设备上部署这些模型是工程师迫切需要解决的问…

Android 13(T) - 智能指针

Android有一套自己的智能指针管理办法&#xff0c;并且将其运用在源码的各个角落&#xff0c;所以学习Media框架之前&#xff0c;我们有必要先了解下Android智能指针。 本节代码源自于Android 13(T)&#xff0c;参考 (aospxref.com) 1 概述 与智能指针相关的总共有5个类&#…

某小厂面试加答案(6.15)

看 Java 面试题就去 www.javacn.site 磊哥新推出《企业面经和答案》栏目&#xff0c;最近会持续更新&#xff0c;欢迎大家订阅此账号查看&#xff0c;或访问 www.javacn.site 查看。 面经来源于牛客&#xff0c;如下图所示&#xff1a; https://www.nowcoder.com/feed/main/det…