FlashSequence: SORA视频生成长序列任务训练解决方案

news2025/1/15 16:32:48

作者:黄奕桐、沈雯婷、艾宝乐、王昂、九丰

摘要

我们提出了长序列训练方案 FlashSequence 并集成在 PAI-TorchAcc (阿里云机器学习平台开发的Pytorch上的大模型训练加速框架)中,该方案能够支持SORA类超长序列模型的高效训练。在两机 16 卡 A100 上,FlashSequence 能够训练 1M 的长序列模型,并达到了 51.7%的 MFU,接近占据 E2E 95%时间的 FlashAttention 53.5%的 MFU。

一、横空出世的 SORA

SORA 介绍

SORA 是一个文生视频的模型,可以根据输入的文本生成对应的视频。

图1: SORA,这个图的核心部分来自:https://openai.com/research/video-generation-models-as-world-simulators。加上了text encoder和DiT blocks。

SORA 在训练时输入的视频可以看成是若干帧图像, 通过visual encoder得到spatial tempral patches 并 flatten成一维作为transformer tokens,同时输入的文本通过 text encoder 生成 embed,两者送入 diffusion transformer (DiT) 进行训练。

DiT 模型

图2: DiT模型的网络结构,来自:https://arxiv.org/abs/2212.09748

图 2 是 DiT 模型的网络结构,可以看到,DiT 模型和 LLAMA 等 LLM 模型的结构上大体上相同,DiT 模型多了对于输入 latent 的 patchify 处理(转换为 LLM 模型需要的 tokens 输入)、在 DecoderLayer 中增加了与文本输入的交互等。基本上 LLM 有的 Multi-Head Self-Attention、Pointwise Feadforward(MLP)结构,DiT 模型也有。

从整体上来看,DiT 模型与 LLM 模型的区别不大,从计算上来看,主要的计算量还是在 Attention 和 MLP 部分,区别在于 Attention 部分通常不会使用 casual mask,导致计算量较大;从显存上看,文本交互的部分会引入额外的显存使用。

训练需求

和LLM只有一个模型不同的是,文生视频模型由多个模型组成,包括:对文本进行编码的 text encoder、对视频进行编码的 visual encoder、DiT、用于推理的和 DiT 模型相同大小的 EMA 模型等。一般来说,text encoder为一个LM模型,预计在几 B 左右;visual encoder 通常为包含conv的VAE,参数量较小;DiT 为一个中小规模的模型,在 1B ~ 30B 左右。其中text encoder, visual encoder 通常为 pretrained 模型,在一些场景下无需进行训练。整体的模型参数量在 10B ~ 60B 左右,与一般的 LLM 模型差别不大。

同时,与一般的 LLM 模型不同的是,文生视频模型的输入 token 数在几十K到几M之间,因此文生视频模型训练的核心挑战是对长序列的高效支持

二、 workload 分析

计算量

对于 text encoder 和 DiT 模型,主要计算量为L(\alpha bsh^2+\beta bs^2h),其中 L 表示 decoder layer 层数、b 表示 micro batch size,s 表示 sequence length,h 表示 hidden dim size,\alpha表示 linear 层的系数,\beta表示 attention 层的系数(在一般的非 causal attention 上前向是 4,后向是 8)。在 s 比较大时,如几十K到几M 时,其他的算子如 element wise 算子在这里可以忽略。text encoder 的 s 通常比较固定的,且远小于视频输入的 token 数。对于 VAE 模型,其主要计算量为一系列卷积操作,在长序列场景,计算量一般小于 DiT 模型。

因此,在视频输入的 token 数 s 比较大时,如几十K到几M 时,整个训练的计算量绝大部分在 DiT 模型上。同时,在 s 比较大的时候,attention 部分的计算量会按照 token 数平方增长,而 linear 层的增长只是线性增长,因此,attention 部分的计算量会逐渐成为整个训练过程中的瓶颈。在 1M 场景,attention 部分的计算时间可以达到 E2E 时间的 95%。

显存使用

如前所述,文生视频整体的模型参数量在 10B ~ 60B 左右,与一般的 LLM 模型差别不大。在使用 7B 的 text encoder 和 7B 的 DiT 模型时,模型常驻的显存大概在 130GB(包含各个模型的参数和 DiT 部分的 optimizer state)。

在文生视频的训练过程中,通常只会训练 DiT 部分,而 text encoder 和 VAE 部分通常不会进行训练,所以显存使用主要在 DiT 的训练部分。text encoder 和 VAE 部分的显存主要考虑临时的 tensor 使用会不会导致 OOM。

在目前的 LLM 模型中,通常会使用 FlashAttention 进行性能和显存优化。在使用了 FlashAttention 之后,DiT 部分的显存使用为\gamma Lbsh,其中 \gamma为一个 Decoder Layer 的显存使用系数,取决于 DiT 模型的实现,这个值会有所变化,但是由于 text 的输入,通常会比普通的 LLM 模型大,例如,一般的 LLM 模型可以是 34,而 DiT 模型会达到 60 ~ 70。在 tokens 数为 1M 、micro batch size 为 1 的场景下,7B 的 DiT 模型总的 activation 显存使用可以达到 8000 GB 以上。

可以看到,相比于计算量按照 token 数 s 平方增长,显存量是按照 token 数 s 线性增长的,这也为后续显存优化提供了参考。

三、FlashSequence 长序列训练方案

FlashSequence

基于 workload 分析,我们提出了 FlashSequence 这一解决方案:

  • 分布式策略:

  • 为了切分中小规模模型的参数,FlashSequence 使用了 FSDP 这一分布式策略,同时,FlashSequence 在 FSDP 外面嵌套使用了 DP 提升多机拓展性。

  • 为了切分长序列训练场景下的 activation,FlashSequence 使用了 context parallel 对 sequence 维度进行切分,同时,FlashSequence 提出 2D context parallel 的方案减少 context parallel 跨机的通信开销。我们还去除了使用 context parallel 之后带来的冗余重复计算。

  • 显存优化策略:

  • FlashSequence 通过使用 CPU offloading 将 activation offload 到 CPU 内存上减少显存,同时极大减少了 gradient checkpoint(GC)带来的额外重算计算量。CPU offloading 的策略在长序列场景下数据传输时间能够和计算时间完全 overlap,相比 GC 能够在不影响 E2E 时间的情况下减少显存。

  • 为了避免 CPU 内存 OOM 和减少一部分 offloading 时间,FlashSequence 使用了 selective GC,selective GC 会优先选择显存计算比高的部分。

  • FlashSequence 还使用了 PyTorch expandable allocator 解决长序列场景下显存碎片过多的问题。

分布式策略

整体思路

在模型中存在两种类型的 tensor,一种是参数相关的包括模型参数、optimizer state、gradients,另一种是 activation。由于长序列场景下参数和 activation 都是不可忽视的,为了避免 OOM,我们需要同时切分参数和 activation。例如,常驻的参数和 optimizer state 可以达到 130GB,而 activation 在 1M 场景下可以达到 8000GB 以上。

参数切分

在参数的切分方面,我们存在多种选择,比如 TP、PP、FSDP 等,但是由于模型本身规模不是特别大,同时考虑到计算和通信的 overlap 情况,FlashSequence 选择了 FSDP 这种参数切分策略。不同于 TP 和 PP,FSDP 的通信除了第一个 layer 的 allgather 之外都能和计算 overlap,没有和计算 overlap 的通信时间在 FSDP 较小的情况下通常可以忽略。虽然 TP 也可以同时切分 activation,但是 TP 会引入无法 overlap 的通信,同时 PP 需要比较大的 gradient accumulation steps 才能掩盖 bubble。

activation 切分

对于 activation,DiT blocks 输入的 shape 为 [batch, sequence, hidden_dim]。由于长序列场景 activation 非常大,所以 micro batch size 通常为 1,这一维度无法切分。在 sequence 维度的切分目前存在 context parallel 如DeepSpeed-Ulysses 和 Ring Attention,以及 Megatron 的 sequence parallel。在 hidden_dim 的维度的切分主要是 Megatron 的 tensor parallel(以切分 weight 的方式实现对 activation 的 hidden dim 维度的切分)。纯粹的 tensor parallel 在 layer norm 等部分还是需要全量的 tensor,这一点在长序列场景是不可接受的。通常目前的主流做法是 Megatron 的 TP-SP 切分方式,这种切分方式和 context parallel 一样可以完整切分 layer 内的 activation。

对于 TP-SP 的切分方式,通信量为\frac{16Lsbh(t-1)}{t},其中t为 TP-SP 的数目,L 为 layer 数、s 为 sequence、b 为 micro batch size、h 为 hidden dim。对于 context parallel,以 DeepSpeed-Ulysses 为例,通信量为\frac{8\Psi(t-1)}{t}+\frac{16Lsbh(t-1)}{t^2},其中\Psi为模型参数量,前面一项是对模型参数的 all reduce 通信,后面一项是对 self attn 的 q、k、v、out 的 alltoall 通信。对于DeepSpeed-Ulysses,模型参数的 all reduce 可以被计算 overlap(类似 DDP),而后面不能 overlap 的通信小于 TP-SP 的切分方式。

从上面的对比可以看出,DeepSpeed-Ulysses 不能 overlap 的通信理论上是小于 TP-SP 的(即使考虑 TP-SP 后向通信可以 overlap)。同时,我们使用 FSDP 切分参数之后也不再需要 TP 对模型参数进行切分。在这种场景下面,FSDP 只是一种切分模型参数的分布式策略,其数据并行的含义被弱化了,不再是开启多少 FSDP 读取多少不同的数据样本,只需要保证 context parallel 的一个 group 内读取相同的数据即可。

综上所述,FSDP+context parallel 的方式优于 TP-SP 的切分方式。同时context parallel 还可以使用 Ring Attention 的方式进一步减少不能 overlap 的通信。

FSDP+DP

在文生视频这种中小规模模型的场景下,FSDP 不需要开很大就可以避免 OOM,在 7B 及以下规模,使用 FSDP=8 就足够满足显存使用需求,同时还能使用高速的机内带宽进行通信。

在更多的卡数下,FSDP 的拓展性会存在一些问题,为了避免这些问题,FlashSequence 进行了 DP 和 FSDP 的嵌套,在外层使用 DP,在内层使用 FSDP。虽然 FSDP 和 DP 的通信都能被计算 overlap,但是 DP 的通信量小于 FSDP,同时 DP 只在计算时间更长的后向进行通信,所以,DP 相比于 FSDP 拥有更好的多机拓展性。在使用了 DP+FSDP 的组合之后,不只能满足参数切分的需求,同时提升了多机的拓展性。

Context Parallel

context parallel 的好处是只在 attention 部分和 transformer 模型之后引入了额外通信,在其他的部分比如 MLP 均不需要额外的通信,而且 gradients 的同步使用 DP+FSDP 就可以完成。同时,在 context parallel 的作用域之内 activation 和计算可以被均匀切分。

目前的 context parallel 都是在一开始就对 sequence 维度进行切分。唯一的区别在于 attention 部分的处理,DeepSpeed-Ulysses 会将 sequence 维度的切分转换为 head 维度的切分再进行 attention 的计算,而 RingAttention 会依然保留 sequence 维度的切分对 attention 的计算进行特殊处理。

DeepSpeed-Ulysses

图3:DeepSpeed-Ulysses,来自:https://arxiv.org/abs/2309.14509

如图 3 所示,DeepSpeed-Ulysses 会对 q、k、v 分别进行 all to all,将 sequence 维度的切分转换为 head 维度的切分再进行 attention 的计算,然后再对 attention 的输出进行 all to all,将 head 维度的切分转换回 sequence 维度的切分。由于 attention 的计算在 head 维度是并行的,所以这样操作之后不需要对 attention 的计算进行额外处理。可以看到,DeepSpeed-Ulysses 切分的是 head 维度,所以这使得DeepSpeed-Ulysses 的并行数目最多开到 head 的大小。

单个 layer 内 DeepSpeed-Ulysses 的通信和计算对比为:\frac{16sbh(t-1)}{t^2B}:\frac{\alpha bsh^2+\beta bs^2h}{tF}= \frac{16(t-1)F}{t(\alpha h+\beta s)B},其中 F 为 GPU 计算 FLOPS,B 为 alltoall 通信带宽。可以看到,随着 s 的变大,DeepSpeed-Ulysses 的通信占比会逐渐降低,最终达到一个可以忽略的程度,在 seq len = 256K 单机 8 卡的场景下,DeepSpeed-Ulysses 的通信时间在 E2E 的时间占比已经低于 1%,在 seq len=64K 的场景下也只有 2%~ 3%。但是,在涉及到跨机通信时,DeepSpeed-Ulysses 的通信开销由于机间通信带宽较低会变得不可忽视。在 256K 的场景下 2 机 16 卡会达到 10%以上。

Ring Attention

图4:Ring Attention,来自:https://arxiv.org/abs/2310.01889
 

如图 4 所示,Ring Attention 的实现过程中会保持 sequence 维度的切分。Ring Attention 会以 ring 的方式发送和接收其他 device 上的 k 和 v,同时计算本地的 q、k、v 分块的 attention,对输出进行一些矫正保证正确性。这种方式可以使得计算和通信能够 overlap 起来。

Ring Attention 计算和通信 overlap 的理论条件是:考虑前向的一个小的 Attention,通信量为 k 和 v:4bsh,计算量为:4bs^2h,所以计算能够掩盖通信的条件为:4bs^2h/F \ge 4bsh/B \implies s \ge F/B,其中 F 为 GPU 计算 FLOPS,B 为 send/recv 通信带宽。在实际运行过程中,还需要考虑 Flash Attention 的计算利用率和 send/recv 的带宽利用率,根据机器和算子性能的不同,在涉及跨机通信时,在 A100 上面下单卡需要 24K 的序列长度才能 overlap。

可以看到,Ring Attention 的优势是通信能够和计算 overlap,但是需要保证 s 切分后单 GPU 卡上的句子长度满足 overlap 条件。

2D context parallel

对于 context parallel,由于只有 attention 部分存在通信,所以我们需要考虑的只是 attention 部分的处理。在 attention 部分,activation 的 shape 为 [batch, sequence, heads, head_dim],由于维度的大小关系,在这其中 sequence、heads 和 head_dim 是可以进行切分的 sequence 和 heads 的切分分别代表了 Ring Attention 和 Ulysses。head_dim 维度由于是矩阵乘的 contracted 维度,这种维度的切分一般不可避免会引入无法 overlap 的 allreduce 或者 allgather 等通信算子,这会使得通信量大于 Ulysses 的 alltoall。

除此之外,我们还可以同时切分 sequence 维度和 heads 维度。在这种情况下,我们只需要进行一部分通信量较少的 alltoall 通信将一部分 sequence 维度转换为 head 维度,同时,针对剩余的 sequence 维度的切分,可以使用可以 overlap 的 send/recv 通信进行处理。由于 alltoall 的跨机性能较差同时 send/recv 的通信时间可以被计算 overlap,FlashSequence 让外层 alltoall 的通信使用机内的 nvlink 进行通信,内层的 send/recv 使用机间带宽进行通信。我们称这种 context parallel 为 2D context parallel。

2D context parallel 相比 DeepSpeed-Ulysses 可以减少没有 overlap 的 alltoall 时间,相比 Ring Attention 可以在单机 tokens 数较小时减少 send/recv 的次数和 attention 的计算时间,使得 send/recv 和计算可以 overlap。这种设计在 context parallel 涉及跨机通信时会显著减少没有和计算 overlap 的通信时间在 E2E 中的占比,在 seqlen = 256K、2 机 16 卡的场景可以将 DeepSpeed-Ulysses 的通信时间从 10%以上减少到低于 1%。

分布式策略的冗余计算优化

在上面我们提到使用 context parallel 对 sequence 维度进行切分,但是这个切分是存在边界的,一般情况下我们会在 activation 的 shape 转换为 transformer 需要的 shape 之后(比如 DiT 模型的 patchify 之后)才对sequence 维度进行切分。由于 context parallel 需要 group 内的 device 读取相同的数据,这就会导致从 dataloader 读取样本到 sequence 维度切分之间在 group 内的 device 进行的是相同的计算。这一部分在 SORA 模型中通常是 visual encoder 和 text encoder 模型,分别负责对视频和文本进行编码。这些计算在中小长度的序列长度下占比比较高,取决于具体模型实现和序列长度,可以达到 20%甚至 70%。

为了为了去除这一部分的冗余计算,我们可以让 context parallel group 内的 device 读取不同的数据,在需要 sequence 维度切分时进行一个 context parallel 大小的 loop 遍历,依次对前面不同 device 读取的数据进行 broadcast ,使得 transformer 的部分输入的数据一样。这样处理之后,VAE+text encoder 的时间占比会减少到之前的 1/context parallel size,带来 E2E 性能提升。

显存优化策略

使用分布式策略可以进行模型参数和 activation 的切分以减少显存,但是分布式策略的切分会引入通信开销,在更多卡参与切分时,这些开销会逐渐变得不可忽视。例如 activation 在 1M 场景下可以达到 8000GB 以上,使用 80GB 的 GPU 就需要至少 100 张卡,这是不可接受的。因此,我们还需要一些显存优化策略来进一步减少显存。

在目前的实践中,gradient checkpoint(GC)是较为常见的策略,GC 的重点是选择合适的重算部分以减少额外的计算开销。CPU offloading 在 DeepSpeed 中通常是对参数进行 offload,但是在长序列场景,我们发现 CPU offloading 在 activation 上相比 GC 也能带来明显的性能提升。显存碎片在长序列场景也会经常遇到,经常会出现 PyTorch reserve 了 10 几 GB 的显存却无法分配一个几百 MB 的 tensor,进而导致 OOM。

Selective GC

gradient checkpoint(GC)的思想是在前向过程中不保留 activation,在后向时重新运行一次前向生成 activation。在使用 GC 的过程中,最主要的问题是选择好重算的部分。目前主流的做法是对整个 decoder layer 进行 GC(full GC)或者对 Attention 部分进行 GC(Megatron selective GC)。

但是,如前所述,在长序列场景,attention 部分占据了绝大部分的计算,重算 attention 的开销很大。同时,与较小序列不同的是,在长序列场景,MLP 部分也是可以考虑进行 GC 的,在 1M 场景,MLP 的 E2E 占比已经低于 5%,重算的开销较小。

FlashSequence 会优先选择显存计算比高的部分。按照模型中的算子 FLOPS以及算子节省的显存量,FlashSequence 会选择依次节省显存收益较大的部分。

CPU Offloading

CPU Offloading 的思想是将部分 tensor 从显存传输到 CPU 内存上,在需要时再 prefetch 回来。在 DeepSpeed 中,这一技术通常只在参数上使用,这是因为之前的场景 offload activation 会有比较大的 PCIe 传输开销。然而,在长序列场景,如上面所述,相比于计算量按照 token 数 s 平方增长,显存量是按照 token 数 s 线性增长的,这就使得在长序列场景,计算的时间会逐渐超过 offload activation 的 PCIe 传输时间。在 64K 场景,offload 一层 decoder layer activation 的 PCIe 传输时间可能需要 2 ~ 3 层 layer 的计算进行 overlap,而在超过 256K 的场景,offload 一层 decoder layer activation 的 PCIe 传输时间仅需一层 layer 的计算就可以 overlap。在不同的模型下,这个 overlap 的 layer 的数目会有所区别,但是随着序列长度 s 的增长,最终都会达到一个可用的状态,比如在 64K 上使用 offloading 就可以无损减少多层 decoder layer 的 activation 显存占用。

以一层 decoder layer 的 activation 作为 offload 的粒度,offload 一层 decoder layer 可以达到和 GC 一层 decoder layer 类似的显存减少量,同时在长序列场景,offloading 的传输时间能够被计算时间 overlap,相当于在 E2E 性能无损的情况下减少了显存,相比于 GC 能够减少额外的计算开销。

虽然 offloading 在长序列场景拥有比 GC 更好的性能表现,offloading 本身也存在一些问题:

  1. 较短的序列长度需要多层 layer 的计算才能 overlap 传输时间,当然这个在更长的序列长度上不是问题。

  2. offloading 需要使用 CPU 的 pinned 内存,而 CPU 的内存虽然有 1TB ~ 2TB,但是在长序列场景,8 张卡的 offloading 所需要使用的内存总量会很快超过 CPU 的内存。这可以通过结合部分 selective GC 进行解决。

  3. offloading 会和跨机通信(RDMA 也会使用一部分 PCIe 资源)竞争 PCIe,这种影响在前向计算中比较明显。但是在使用了 DP+FSDP 和 2D context parallel 的组合之后,大部分通信都是使用 nvlink,机间通信也能够被计算 overlap,所以对 E2E 的性能影响不大。

基于上述问题,FlashSequence 在优先 CPU offloading 的同时使用 selective GC,避免 CPU 内存 OOM 的同时减少重算的 FLOPS。

显存碎片

在长序列场景,一个 tensor 的显存使用可以达到几百 MB 甚至几 GB,在这种场景下,PyTorch 的 caching allocator 会导致比较多的显存碎片。

图5: caching allocator的显存分配情况

图6: expandable allocator的显存分配情况

图 5 是某个长序列场景 OOM 时的显存使用情况,其中空白部分是还没分配但是被 PyTorch reserve 的显存。这个 OOM 本来是不应该出现的,因为这个时候请求分配的 tensor 只需要 500 多 MB 的显存,而 PyTorch reserve 的未分配显存有 7.5GB。但是因为 PyTorch reserve 的未分配显存都是不连续的(大的空白是 200 多 MB),所以导致了 OOM。这个显存碎片问题在更长的序列场景会更加常见,有时候可以达到 10 ~ 20GB 的显存碎片。

在 PyTorch 的 2.2 及以上版本,引入了expandable 的 allocator,这个 allocator 可以在有更大显存分配请求的情况下拓展已有的空闲显存块,进而减少原始 caching allocator 的显存碎片。从图 6 中可以看到显存碎片低于 1GB。在大部分长序列场景下,expandable allocator 的显存碎片都比 caching allocator 的小,同时在我们场景下性能基本没有变化。

计算优化

FlashAttention 优化

FlashAttention是DiT 模型中attention部分的常用优化手段,FlashAttention的前向计算量为4bs^2hFLOPS,后向的计算量是10bs^2hFLOPS(FlashAttention 在后向存在部分重算),如前文对计算量的分析,随着序列长度的增长,attention部分在端到端的训练时间中甚至占比到95%以上。因此,FlashAttention的计算性能,也成为整个训练任务最为dominate的部分。

我们对TriDao版本的FlashAttention2在不同序列长度下的性能做了A100上的kernel的性能测试,以batch-size=1, hidden-dim=128为例。通过图 7 和图 8 的性能数字,可以看到:

1.  序列长度和N_HEADS同时太大(N_CTX>=512K and N_HEADS>=8)或太小(N_CTX<=4K 或 N_CTX<=32K and N_HEADS<=4),都会造成一定程度的性能损失。这个性能数字也可以指导对N_HEADS和N_CTX的分片;

2.  根据FlashAttention的性能,我们可以预估一次迭代的训练时长。如1M下,按照前向227TFLOPS/s,后向191TFLOPS/s计算,一个layer计算batch-size=1, hidden-size=4096的FlashAttention计算时间为4bs^2h \times 10^{-12}/227+10bs^2h\times 10^{-12}/191,约5min,L个layer的FlashAttention的用#n_gpus并行计算时间为5Lmin/n_gpus。

图7: FlashAttention在A100上面不同序列长度和HEAD大小下前向的性能

图8: FlashAttention在A100上面不同序列长度和HEAD大小下后向的性能

不同的 Attention 实现

前文中我们预估了1M序列长度下FlashAttention的计算时间,可以看到由于序列长度平方项的计算量的存在,一轮迭代的时间在分钟级别,导致模型训练的速度非常慢。有很多降低Attention二次方序列长度计算量的工作,提升Transformer效率的工作,主要包含以下几种:Linear Attention、Sparse Attention、Mamba、Compress Memory。这些算法由于计算量的减少,在模型的效果方面和原始Transformer存在差异,后续我们将对这方面的工作进行探索,并集成到系统中。

  • Linear Attention:

线性Attention的核心思想是用kernel函数代替softmax,然后通过矩阵乘法的结合律,将序列长度维度的两层循环减少为一层循环,从而将Attention的计算量从序列长度的平方项降低为线性项。线性Attention减少计算量也会造成模型的精度损失,取决于核函数的设计。

相关的工作有Transformers are RNNs, RMKW, Linformer, Lightning, DiJiang等。

  • Sparse Attention:

Sparse Attention通过让每个token对应的向量只跟部分token对应的向量(可见域)计算相关度,使Attention矩阵计算变得稀疏。在Sparse Attention中,如何选择有相关性的元素进行计算,成为影响模型精度的关键。相关工作探索了固定可见域,如OpenAI 2019的Sparse Transformers,与输入数据相关的可见域,如ICLR 2024的Transformer-VQ等多种算法。

  • Mamba

Mamba基于状态空间模型(SSM, State space models),结合了RNN(表达隐藏状态和输入关系,去掉非线性激活函数)和CNN(并行训练),并根据输入动态调整模型的选择性参数,包括当前输入和历史状态信息对输出的影响系数、无关信息的过滤参数等,并通过硬件感知算法来优化计算效率,将Transoformer的计算效率变成序列长度线性相关。

相关的工作包括Mamba,Mamba在视觉模型上的应用如ZigMA, ICLR 24 Diffusion SSM等。

  • Compress Memory

Compress Memory是一种将长序列切分为一个个segment,将历史segment的信息编码到一个固定大小的memory中,将当前segment的attention和memory信息concat到一起,从而将计算复杂度降低到1/n_segments,而memory占用的空间为固定大小,以此方式可以计算无限长度的序列的attention。相关工作如Infini-attention,ICAE等。

四、实验

图9: 不同sequence length和context parallel(CP)下的MFU

为了衡量我们提出的 FlashSequence 的解决方案,我们以纯 Ulysses+FSDP 并使用 full GC 作为 baseline,其中 full GC 的 layer 数依据显存使用决定,少于模型 layer 数。图 9 展示了在 A100 上不同sequence length和context parallel(CP)下的MFU,可以看到,FlashSequence 的 MFU 相比 baseline 平均提高了 11.75%,性能平均提升了 23.3%。同时,FlashSequence 的方案在长序列场景可以获得和 FlashAttention 接近的 MFU,比如在 1M、CP=16 的场景下,FlashSequence 的 MFU 为 51.7%,接近占据 E2E 95%时间的 FlashAttention 53.5%的 MFU。

五、总结与展望

PAI-TorchAcc(Torch Accelerator)是阿里云机器学习平台开发的Pytorch上的大模型训练加速框架。PAI-TorchAcc 通过进行分布式优化、计算优化、显存优化等,为包括 SORA 模型在内的Pytorch上的模型提供高效训练支持。

目前,FlashSequence 已经集成到 PAI-TorchAcc 中,并在开源的 DiT 模型上验证了效果。此外,由于 SORA 所使用的 DiT 类模型结构与 LLM 模型基本类似, FlashSequence 也可以应用在大部分长序列训练场景。后续我们会陆续开源这些工作。

同时,目前长序列的最主要瓶颈都在 FlashAttention 的计算上面,如何优化 FlashAttention 的计算将成为长序列场景下的主要问题。由于 FlashAttention 的计算量按照 token 数平方增长,未来更可能的优化方向是探索计算量更低的 Attention 实现比如线性的 Attention,同时低精度如 FP8 训练、稀疏训练等也都是一些可以探索的方向。

【招聘】最后,如果你对大模型训练加速技术感兴趣,欢迎加入到我们的团队中。目前研究型实习生和社招都在火热招聘中,欢迎投递简历到研究型实习生 - 基于负载与硬件特性协同的大模型训练加速技术研究 或邮箱 wenting.swt@alibaba-inc.com

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1810653.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解密Prompt系列31. LLM Agent之从经验中不断学习的智能体

前言 Agent智能体的工作流可以简单分成两种&#xff1a;一种是固定的静态工作流&#xff0c;一种是智能体自主决策的动态工作流。 静态流程的Agent举几个例子&#xff0c;例如新闻热点追踪推送Agent&#xff0c;每日新论文摘要总结Agent&#xff0c;它们的优点是可控&#xf…

SpringSecurity6从入门到实战之初始用户如何存储到内存(依旧源码级别讲解,耐心看完会有收获)

SpringSecurity6从入门到实战之初始用户如何存储到内存 文接上回,根据登录表单的提交最终得知用户相关信息存储在内存中.那么SpringSecurity是如何在项目启动时将用户信息存储到内存中的呢? 这里我们还是先回到SpringBoot加载配置的地方 UserDetailServiceAutoConfigutation 类…

【PowerDesigner】创建和管理CDM之使用实体间关系

目录 &#x1f30a;1. PowerDesigner简介 &#x1f30d;1.1 常用模型文件 &#x1f30d;1.2 PowerDesigner使用环境 &#x1f30a;2. 创建和管理CDM &#x1f30d;​​​​​​2.1 新建CDM &#x1f30d;2.2 使用实体间关系 &#x1f30c;a. 使用联系 &#x1f30c;b. …

2024年智能医疗与生物医药国际会议(ICIHB 2024)

2024 International Conference on Intelligent Healthcare and Biopharmaceuticals 【1】大会信息 会议简称&#xff1a;ICIHB 2024 大会地点&#xff1a;中国珠海 会议官网&#xff1a;www.icihb.com 投稿邮箱&#xff1a;icihbsub-paper.com 【2】会议简介 2024年智能医…

CISA网络安全事件应急手册

《Cybersecurity Incident & Vulnerability Response Playbooks》是美国CISA&#xff08;Cybersecurity and Infrastructure Security Agency&#xff0c;网络安全和基础设施安全局&#xff09;于2021年11月份发布的指导手册&#xff0c;是基于FCEB&#xff08;Federal Civ…

硬核新品!M4E EDU民航考培一体无人机

天途上新啦&#xff01; 应我国民用无人机首项强制性国家标准《民用无人驾驶航空器系统安全要求》&#xff0c;天途对现有小型无人机训练机的飞控、电池、感知避障和电子围栏等软硬件全面升级设计&#xff0c;严格按国标GB42590-2023规范生产。 M4E EDU四轴多旋翼无人机是天途…

浅谈word格式:.doc和.docx的优缺点及区别

.doc和.docx是两种最为常见的文档格式&#xff0c;它们在多个方面存在着显著的区别。首先&#xff0c;从版本角度来看&#xff0c;.doc是Microsoft Office Word 2003及之前版本的保存类型&#xff0c;而.docx则是Word 2007及之后版本的保存类型。这一区别直接影响了文档在不同版…

【递归、搜索与回溯】穷举vs暴搜vs深搜vs回溯vs剪枝

穷举vs暴搜vs深搜vs回溯vs剪枝 1.全排列2.子集 点赞&#x1f44d;&#x1f44d;收藏&#x1f31f;&#x1f31f;关注&#x1f496;&#x1f496; 你的支持是对我最大的鼓励&#xff0c;我们一起努力吧!&#x1f603;&#x1f603; 管他什么深搜、回溯还是剪枝&#xff0c;画出决…

深圳市萨科微半导体有限公司

深圳市萨科微半导体有限公司凭借碳化硅、氮化镓等新材料、功率器件设计加工环节的先进工艺、高效管理和快速扩大生产规模&#xff0c;不断降低产品价格、提高市场的占有率&#xff0c;受到了世界各地客户的认可。萨科微具有高性能高可靠集成电路的独立研发能力和多年技术储备&a…

Craig Federighi 和 John Giannandrea 在 WWDC 上谈论苹果智能技术

WWDC 主题演讲结束后&#xff0c;苹果公司的克雷格-费德里吉&#xff08;Craig Federighi&#xff09;和约翰-吉安南德雷亚&#xff08;John Giannandrea&#xff09;坐下来&#xff0c;更深入地讨论了苹果智能公司在人工智能方面所做的努力&#xff0c;包括该公司是如何训练模…

在AWS上运行的EKS Elastic Kubernetes Service 创建集群Cluster,Node group, Nodes

1. 前提条件 AWS Account: https://aws.amazon.com/free/Installing KubeCtl CLI https://docs.aws.amazon.com/eks/latest/userguide/getting-started-eksctl.htmlEKS Cluster RoleIAM Role for Node GroupVPCEC2 Key Pair which can be used to SSH to the worker nodesAWS …

深入理解Vue3.js响应式系统基础逻辑

如果您觉得这篇文章有帮助的话&#xff01;给个点赞和评论支持下吧&#xff0c;感谢~ 作者&#xff1a;前端小王hs 阿里云社区博客专家/清华大学出版社签约作者/csdn百万访问前端博主/B站千粉前端up主 此篇文章是博主于2022年学习《Vue.js设计与实现》时的笔记整理而来 书籍&a…

cad导入su线条不在一个平面怎么办?

解决CAD导入sketchup线条不是共面问题&#xff0c;需要考虑到各个步骤如下&#xff1a; 1&#xff09;检查CAD文件。首先要检查CAD文件&#xff0c;确保线条是连接在一起的&#xff0c;并且看看有没有多余的线&#xff0c;以及是否有子线段没有合并&#xff0c;如果有会导致导入…

AdroitFisherman模块测试日志(2024/6/10)

测试内容 测试AdroitFisherman分发包中SHAUtil模块。 测试用具 Django5.0.3框架&#xff0c;AdroitFisherman0.0.31 项目结构 路由设置 总路由 from django.contrib import admin from django.urls import path,include from Base64Util import urls urlpatterns [path(ad…

猫狗识别(超详细版)(py代码)

猫狗识别&#xff08;一&#xff09; 一、图像识别 1.导入必要的库: import torchimport numpy as npimport torchvisionfrom os import pathfrom torchvision import datasets, modelsimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import DataL…

【NUCLEO-G071RB】009——HAL库-显示编译时间

NUCLEO-G071RB&#xff1a;009——HAL库-显示编译时间 编译时间设计目标程序修改运行测试 编译时间 这里的编译时间指的是烧录文件的编译时间&#xff0c;它由编译环境的日期和时间共同决定。 设计目标 1、获取编译时间&#xff0c;默认是ASC码格式 2、将编译时间转换为HEX …

哈尔滨等保如何做?

哈尔滨等保测评是确保信息系统安全稳定运行的重要一环&#xff0c;它涉及到对业务、资产、安全技术和安全管理的全面调研和评估。本文将详细阐述哈尔滨等保测评的实施步骤和注意事项&#xff0c;帮助读者更好地理解和执行等保测评工作。 首先&#xff0c;我们需要明确等保测评的…

新品发布 | 捷云等保一体机2.0全新上市,助力中小企业破解等保难题

等保2.0时代&#xff0c;随着网络威胁不断复杂化和组织化&#xff0c;作为网络安全“弱势群体”的中小企业&#xff0c;等保建设工作正面临着安全意识、管理、人才、资金捉襟见肘等问题&#xff0c;主要体现在以下两个方面&#xff1a; 等保建设流程复杂 中小企事业单位缺乏专…

条件概率的理解

P(A)表示A的先验概率 P(B)表示B的先验概率 P(A | B)表示在B发生的情况下&#xff0c;A的条件概率 P(B | A)表示在A发生的情况下&#xff0c;B的条件概率 先验概率是在进行实验之前基于当前知识对结果概率的最佳合理评估。后验概率是在考虑了新信息后&#xff0c;事件发生的修正…

行为树BehaviorTree

主要依托于BehaviorTree.CPP进行介绍。 1 基本概念 1.1 是什么与用来做什么 官网 https://www.behaviortree.dev/docs/learn-the-basics/BT_basics Unlike a Finite State Machine, a behavior Tree is a tree of hierarchical nodes that controls the flow of execution o…