【LLM加速】注意力优化(基于位置/内容的稀疏注意力 | flashattention)

news2025/1/10 20:37:32

note

(1)近似注意力:

  • Routing Transformer采用K-means 聚类方法,针对Query和Key进行聚类,类中心向量集合为 { μ i } i = 1 k \left\{\boldsymbol{\mu}_i\right\}_{i=1}^k {μi}i=1k ,其中k 是类中心的个数。每个Query 只与其处在相同簇 (Cluster) 下的Key 进行交互。
  • Reformer 则采用局部敏感哈希 (Local-Sensitive Hashing,LSH) 的方法为每个Query 选择Key-Value 对。其主要思想是使用LSH 函数对Query 和Key 进行哈希计算,将它们划分到多个桶内,以提升在同一个桶内的Query 和Key 参与交互的概率。

(2)在Transformer 结构中,自注意力机制的时间和存储复杂度与序列的长度呈平方的关系,因此占用了大量的计算设备内存并消耗了大量的计算资源。如何优化自注意力机制的时空复杂度、增强计算效率是大语言模型面临的重要问题。

  • 方法一:从近似注意力出发,旨在减少注意力计算和内存需求,提出了稀疏近似、低秩近似等方法。
  • 方法二:从计算加速设备本身的特性出发,研究如何更好地利用硬件特性对Transformer 中的注意力层进行高效计算。

(3)FlashAttention目标是尽可能高效地使用SRAM来加快计算速度,避免从全局内存中读取和写入注意力矩阵。达成该目标需要能做到在不访问整个输入的情况下计算Softmax函数,并且后向传播中不能存储中间注意力矩阵。

文章目录

  • note
  • 一、近似注意力
    • 1. 基于位置的稀疏注意力机制
    • 2. 基于内容的稀疏注意力机制
      • (1)Routing Transformer:使用聚类
      • (2)Reformer:使用LSH哈希
  • 二、计算加速
    • 1. GPU硬件基础知识
    • 2. flashattention
    • 3. 多查询注意力MQA
      • (1)MHA和MQA的区别
      • (2)MHA和MQA的具体代码
      • (3)使用矩阵乘法matmul广播实现参数共享
      • (4)tgi框架中的MQA
  • Reference

一、近似注意力

对一些训练好的Transformer 结构中的注意力矩阵进行分析时发现,其中很多是稀疏的,因此可以通过限制Query-Key 对的数量来降低计算复杂度。这类方法称为稀疏注意力(SparseAttention)机制。可以将稀疏化方法进一步分成基于位置的和基于内容信息的两类。

1. 基于位置的稀疏注意力机制

基于位置的稀疏注意力机制的基本类型如下图,主要包含如下五种类型:全局注意力(Global Attention)、带状注意力(Band Attention)、膨胀注意力(Dilated Attention)、随机注意力(Random Attention)、局部块注意力(Block Local Attention)。

这些注意力机制的区别主要在于它们如何选择序列中的元素来计算注意力权重,这直接影响计算复杂度、处理长距离依赖的能力以及对不同类型任务的适用性。每种注意力机制的关键区别和特点:

  1. 全局注意力(Global Attention):

    • 关键特点:在计算每个位置的注意力时,考虑序列中的所有其他位置。
    • 优点:能够捕获全局依赖性,理论上可以处理任意距离的关系。
    • 缺点:计算复杂度高,随序列长度的平方增长,不适合处理长序列。
  2. 带状注意力(Band Attention):

    • 关键特点:仅在每个位置的一个固定宽度的带内计算注意力权重,通常集中在序列的对角线附近。
    • 优点:减少了计算量,适合捕获局部依赖性。
    • 缺点:可能忽略重要的长距离依赖。
  3. 膨胀注意力(Dilated Attention):

    • 关键特点:通过引入膨胀因子来间隔地选择序列中的元素进行注意力计算,从而覆盖更广的范围。和CNN中的Dilated Conv类似,通过增加空隙以获取更大的感受野
    • 优点:在降低计算复杂度的同时,能够捕获更远的依赖性。
    • 缺点:可能不如全局注意力在捕捉所有长距离依赖上有效。
  4. 随机注意力(Random Attention):

    • 关键特点:随机选择序列中的位置来计算注意力权重。即通过随机采样,提升非局部的交互。
    • 优点:显著降低计算需求,引入随机性可能帮助模型探索更多的依赖关系。
    • 缺点:随机性可能导致忽略一些关键的依赖关系。
  5. 局部块注意力(Block Local Attention):

    • 关键特点:将序列分割成多个块,在这些局部块内计算注意力权重。使用多个不重叠的块Block来限制信息交互。
    • 优点:大幅降低计算复杂度,适合处理长序列。
    • 缺点:如果不允许跨块计算,则可能忽略块间的依赖关系。

总结来说,这些注意力机制通过不同的策略平衡计算复杂度和模型的捕获依赖能力。选择哪种注意力机制取决于特定任务的需求,例如处理长序列数据时可能更倾向于使用带状、膨胀、随机或局部块注意力机制,而在不那么受限于计算资源的情况下,全局注意力可能是最好的选择,因为它能够捕获全局依赖性。

在这里插入图片描述
下面给出带状注意力的栗子:

# query-shape: [bs, seq_len, emb_dim]
def band_attention(query, key, value, band_width):
    """
    Args:
        query, key, value: standard attention inputs
        band_width: The width of the band around the diagonal to compute attention.
    Returns:
        Tensor: The output of the attention mechanism.
    """
    batch_size, seq_len, d_k = query.size()
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

    # Create a mask to zero out attention scores outside the band
    idxs = torch.arange(seq_len).unsqueeze(0).to(query.device)
    mask = (idxs - idxs.transpose(0, 1)).abs().ge(band_width).to(scores.dtype)
    scores.masked_fill_(mask, float('-inf'))

    attention = F.softmax(scores, dim=-1)
    output = torch.matmul(attention, value)
    return output

# 测试的case
def band_attention_test():
    import torch
    # 假设输入数据的维度
    batch_size = 2
    seq_length = 10
    embed_size = 128
    heads = 8
    # 生成随机数据作为输入
    values = torch.rand(batch_size, seq_length, embed_size)
    keys = torch.rand(batch_size, seq_length, embed_size)
    queries = torch.rand(batch_size, seq_length, embed_size)
    # 定义带宽
    band_width = 3
    # 使用相同的随机数据输入
    band_attention_output = band_attention(queries, keys, values, band_width)
    # Band Attention Output Shape: torch.Size([2, 10, 128])
    print("Band Attention Output Shape:", band_attention_output.shape)

可以看到上面的mask矩阵确实是带状的:
在这里插入图片描述

现有的稀疏注意力机制,通常是基于上述五种基于位置的稀疏注意力机制的复合模式,下图给出了一些典型的稀疏注意力模型:

  • star-transformer:使用带状注意力和全局注意力的组合,只包括一个全局注意力节点和宽度为3的带状注意力,其中任意两个非相邻节点通过一个共享的全局注意力连接,而相邻节点则直接相连。
  • longformer:将上层中的一些带状注意力头部替换为具有扩张窗口的注意力,在增加感受野同时不增加计算量
  • ETC(Extended Transformer Construction):利用带状注意力和外部全局节点注意力(External Global-node Attention)的组合。ETC 稀疏注意力还包括一种掩码机制来处理结构化输入,并采用对比预测编码(Contrastive Predictive Coding,CPC)进行预训练。
  • BigBird:使用带状和全局注意力,还使用额外的随机注意力来近似全连接注意力,此外还揭示了稀疏编码器和稀疏解码器的使用可以模拟任何图灵机

在这里插入图片描述

2. 基于内容的稀疏注意力机制

基于内容的稀疏注意力机制根据输入数据创建稀疏注意力,其中一种很简单的方法是选择和给定查询 (Query) 有很高相似度的键 (Key)。

(1)Routing Transformer:使用聚类

(1)Routing Transformer采用K-means 聚类方法,针对Query和Key进行聚类,类中心向量集合为 { μ i } i = 1 k \left\{\boldsymbol{\mu}_i\right\}_{i=1}^k {μi}i=1k ,其中k 是类中心的个数。每个Query 只与其处在相同簇 (Cluster) 下的Key 进行交互。中心向量采用滑动平均的方法进行更新:
μ ~ ← μ ~ + ( 1 − λ ) ( ∑ i : μ ( q i ) = μ q i + ∑ j : μ ( k j ) = μ k j ) c μ ← λ c μ + ( 1 − λ ) ∣ μ ∣ μ ← μ ~ c μ \begin{gathered} \widetilde{\boldsymbol{\mu}} \leftarrow \tilde{\boldsymbol{\mu}}+(1-\lambda)\left(\sum_{i: \mu\left(\boldsymbol{q}_i\right)=\mu} \boldsymbol{q}_i+\sum_{j: \mu\left(\boldsymbol{k}_j\right)=\mu} \boldsymbol{k}_j\right) \\ c_\mu \leftarrow \lambda c_\mu+(1-\lambda)|\mu| \\ \mu \leftarrow \frac{\widetilde{\boldsymbol{\mu}}}{c_\mu} \end{gathered} μ μ~+(1λ) i:μ(qi)=μqi+j:μ(kj)=μkj cμλcμ+(1λ)μμcμμ

(2)Reformer:使用LSH哈希

(2)Reformer 则采用局部敏感哈希 (Local-Sensitive Hashing,LSH) 的方法为每个Query 选择Key-Value 对。其主要思想是使用LSH 函数对Query 和Key 进行哈希计算,将它们划分到多个桶内,以提升在同一个桶内的Query 和Key 参与交互的概率。假设 b b b 是桶的个数,给定一个大小为 [ D k , b / 2 ] [D k , b / 2] [Dkb/2] 的随机矩阵 R R R , LSH 函数的定义为:
h ( x ) = arg ⁡ max ⁡ ( [ x R ; − x R ] ) h(\boldsymbol{x})=\arg \max ([\boldsymbol{x} R ;-\boldsymbol{x} R]) h(x)=argmax([xR;xR])

如果 h q i = h k j h \boldsymbol{q}_i=h \boldsymbol{k}_j \quad hqi=hkj 时, q i \boldsymbol{q}_i qi 才可以与相应的Key-Value对进行交互。

二、计算加速

1. GPU硬件基础知识

NVIDIA GPU中的内存(显存)按照它们物理上是在GPU芯片内部还是板卡RAM存储芯片上,决定了它们的速度、大小以及访问限制。GPU显存分为:

  • 全局内存(Global memory)
  • 本地内存(Local memory)
  • 共享内存(Shared memory,SRAM)
  • 寄存器内存(Register memory)
  • 常量内存(Constant memory)
  • 纹理内存(Texture memory)

在这里插入图片描述

全局内存和本地内存使用的高带宽显存(High Bandwidth Memory,HBM)位于板卡RAM存储芯片上,该部分内存容量很大。全局内存是所有线程都可以访问,而本地内存则只能当前线程访问。NVIDIA H100中全局内存有80GB空间,其访问速度虽然可以达到3.35TB/s,但是如果全部线程同时访问全局内存时,其平均带宽仍然很低

共享内存和寄存器位于GPU芯片上,因此容量很小,并且共享内存只有在同一个GPU线程块(Thread Block)内的线程才可以共享访问,而寄存器仅限于同一个线程内部才能访问。NVIDIA H100中每个GPU线程块在流式多处理器(Stream Multi-processor,SM)可以使用的共享存储容量仅有228KB,但是其速度非常快,远高于全局内存的访问速度。

根据自注意力机制的原理,在GPU中进行计算时,传统的方法还需要引入两个中间矩阵 S 和 P 并存储到全局内存中。具体计算过程如下:
S = Q × K , P = Softmax ⁡ ( S ) , O = P × V \boldsymbol{S}=\boldsymbol{Q} \times \boldsymbol{K}, \boldsymbol{P}=\operatorname{Softmax}(\boldsymbol{S}), \boldsymbol{O}=\boldsymbol{P} \times \boldsymbol{V} S=Q×K,P=Softmax(S),O=P×V

按照上述计算过程,需要:

  • 首先从全局内存中读取矩阵 Q Q Q K K K ,并将计算好的矩阵 S S S再写入全局内存
  • 之后再从全局内存中获取矩阵 S S S ,计算Softmax得到矩阵 P P P 再写入全局内存
  • 之后读取矩阵 P P P 和矩阵 V V V ,计算得到矩阵 O O O

这样的过程会极大占用显存的带宽。在自注意力机制中,计算速度比内存速度快得多 ,因此计算效率越来越多地受到全局内存访问的瓶颈。

2. flashattention

FlashAttention就是通过利用GPU硬件中的特殊设计,针对全局内存和共享存储的I/O速度的不同,尽可能地避免HBM中读取或写入注意力矩阵。

FlashAttention目标是尽可能高效地使用SRAM来加快计算速度,避免从全局内存中读取和写入注意力矩阵。达成该目标需要能做到在不访问整个输入的情况下计算Softmax函数,并且后向传播中不能存储中间注意力矩阵

在这里插入图片描述

FlashAttention 就提出了不使用中间注意力矩阵,通过存储归一化因子来减少全局内存消耗的方法。

FlashAttention 算法并没有将S、P 整体写入全局内存,而是通过分块写入,存储前向传递的Softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从全局内存中读取中间注意力矩阵的标准方法更快。

虽然大幅减少了全局内存的访问量,重新计算也导致FLOP(FLOPS指标,Floating Point Operations per Second 指每秒浮点运算次数) 增加,但其运行的速度更快且使用的内存更少。

在这里插入图片描述

3. 多查询注意力MQA

多查询注意力(Multi Query Attention)是多头注意力的一种变体。其特点是,在多查询注意力中不同的注意力头共享一个键和值的集合,每个头只单独保留了一份查询参数,因此键和值的矩阵仅有一份,这大幅减少了显存占用,使其更高效。
在这里插入图片描述

由于多查询注意力改变了注意力机制的结构,因此模型通常需要从训练开始就支持多查询注意力。文献研究结果表明,可以通过对已经训练好的模型进行微调来添加多查询注意力支持,仅需要约5% 的原始训练数据量就可以达到不错的效果。

包括Falcon[64]、SantaCoder[65]、StarCoder[66] 在内的很多模型都采用了多查询注意力机制。

(1)MHA和MQA的区别

MHA 和 MQA 之间的区别主要在于建立 Wqkv Layer 上(如下代码)。在MQA中,除了query向量还保存8个头,key和value向量都只剩下1个【公共头】,即前面说的所有head之间共享一份key和value参数。

# Multi Head Attention
self.Wqkv = nn.Linear(                        # 【关键】Multi-Head Attention 的创建方法
    self.d_model, 
    3 * self.d_model,                         # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
    device=device
)
query, key, value = qkv.chunk(                # 【关键】每个 tensor 都是 (1, 512, 768)
    3, 
    dim=2
)
  
# Multi Query Attention
self.Wqkv = nn.Linear(                                # 【关键】Multi-Query Attention 的创建方法
    d_model,
    d_model + 2 * self.head_dim,                      # 只创建 query 的 head 向量,所以只有 1 个 d_model
    device=device,                                    # 而 key 和 value 不再具备单独的头向量
)
query, key, value = qkv.split(                        # query -> (1, 512, 768)
    [self.d_model, self.head_dim, self.head_dim],     # key   -> (1, 512, 96)
    dim=2                                             # value -> (1, 512, 96)
)

(2)MHA和MQA的具体代码

其中MultiheadAttentionMultiQueryAttention类完整的代码如下。

class MultiheadAttention(nn.Module):
 
     def __init__(
             self,
             d_model: int,
             n_heads: int,
             device: str
        ):
         """
        Multi Head init func.
 
        Args:
            d_model (int): hidden state size, e.g. 768
            n_heads (int): 设定的注意力头数, e.g. 8
            device (str): _description_
        """
         super().__init__()
 
         self.d_model = d_model
         self.n_heads = n_heads
     
         self.Wqkv = nn.Linear(                       # Multi-Head Attention 的创建方法
             self.d_model,
             3 * self.d_model,                        # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
             device=device
        )                                            # (d_model, 3 * d_model)
         self.attn_fn = scaled_multihead_dot_product_attention
         self.out_proj = nn.Linear(
             self.d_model,
             self.d_model,
             device=device
        )
 
     def forward(
         self,
         x
    ):
         """
        forward func.
 
        Args:
            x (tensor): (batch, hidden_state, d_model) e.g. -> (1, 768, 512)
 
        Returns:
            _type_: _description_
        """
         qkv = self.Wqkv(x)                            # (1, 768, 3 * 768)
 
         query, key, value = qkv.chunk(                # 每个 tensor 都是 (1, 512, 768)
             3,
             dim=2
        )    
 
         context, attn_weights, past_key_value = self.attn_fn(
             query,
             key,
             value,
             self.n_heads
        )                                             # (1, 512, 768)
 
         return self.out_proj(context), attn_weights, past_key_value
 
 
 class MultiQueryAttention(nn.Module):
     """Multi-Query self attention.
 
    Using torch or triton attention implemetation enables user to also use
    additive bias.
    """
 
     def __init__(
         self,
         d_model: int,
         n_heads: int,
         device: Optional[str] = None,
    ):
         super().__init__()
 
         self.d_model = d_model
         self.n_heads = n_heads
         self.head_dim = d_model // n_heads
 
         self.Wqkv = nn.Linear(                           # Multi-Query Attention 的创建方法
             d_model,
             d_model + 2 * self.head_dim,                 # 只创建 query 的 head 向量,所以只有 1 个 d_model
             device=device,                               # 而 key 和 value 则只共享各自的一个 head_dim 的向量
        )
 
         self.attn_fn = scaled_multihead_dot_product_attention
         self.out_proj = nn.Linear(
             self.d_model,
             self.d_model,
             device=device
        )
         self.out_proj._is_residual = True  # type: ignore
 
     def forward(
         self,
         x,
    ):
         qkv = self.Wqkv(x)                                           # (1, 512, 960)
 
         query, key, value = qkv.split(                               # query -> (1, 512, 768)
            [self.d_model, self.head_dim, self.head_dim],             # key   -> (1, 512, 96)
             dim=2                                                    # value -> (1, 512, 96)
        )
 
         context, attn_weights, past_key_value = self.attn_fn(
             query,
             key,
             value,
             self.n_heads,
             multiquery=True,
        )
 
         return self.out_proj(context), attn_weights, past_key_value

(1)初始化函数 __init__

  • __init__(self, d_model: int, n_heads: int, device: Optional[str] = None): 这是类的初始化函数,用于创建类的实例时初始化其属性。它接受三个参数:模型的维度 d_model、注意力头的数量 n_heads,以及设备 device(可选),用于指定模块运行的硬件(CPU或GPU)。
  • self.d_model = d_modelself.n_heads = n_heads: 这两行代码将传入的模型维度和头的数量保存为类的属性。
  • self.head_dim = d_model // n_heads: 计算每个头的维度,即将模型维度均分到每个头上。
  • self.Wgkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device): 创建一个线性层 Wgkv,用于生成查询(Q)、键(K)和值(V)。这个线性层的输出维度是 d_model + 2 * self.head_dim,意味着查询的维度保持为 d_model,而键和值的维度为 self.head_dim。这种设计减少了模型参数,因为它没有为键和值分别创建额外的线性变换。
  • self.attn_fn = scaled_multihead_dot_product_attentionself.out_proj = nn.Linear(self.d_model, self.d_model, device=device): 定义了一个注意力函数 attn_fn 和一个输出投影层 out_projattn_fn 负责计算多头点积注意力,而 out_proj 用于将注意力机制的输出转换回原始输入的维度。

(2)前向传播函数 forward

  • def forward(self, X): 定义了前向传播函数,它接收一个输入张量 X
  • gkv = self.Wgkv(X): 首先,输入通过 Wgkv 线性层,产生了合并的查询、键、值矩阵。
  • query, key, value = gkv.split([self.d_model, self.head_dim, self.head_dim], dim=2): 然后,将 gkv 拆分为查询、键和值三部分。注意拆分的维度与 Wgkv 层的输出设计相匹配。
  • context, attn_weights, past_key_value = self.attn_fn(query, key, value, self.n_heads, multiquery=True): 使用定义的注意力函数计算注意力,multiquery=True 参数指示使用多查询注意力机制。
  • return self.out_proj(context), attn_weights, past_key_value: 最后,将注意力的输出通过 out_proj 投影层,然后将结果、注意力权重和过去的键值对返回。

(3)使用矩阵乘法matmul广播实现参数共享

其中注意上面的scaled_multihead_dot_product_attention函数就是实现刚才说的一份key和value参数让多个头使用,使用矩阵乘法matmul进行广播,实现参数共享。

def scaled_multihead_dot_product_attention(
        query,
        key,
        value,
        n_heads,
        multiquery=False,
    ):
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)         # (1, 512, 768) -> (1, 8, 512, 96)
    kv_n_heads = 1 if multiquery else n_heads
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)        # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery 
                                                                    # (1, 512, 96) -> (1, 1, 96, 512)  if multiquery
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)      # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery 
                                                                    # (1, 512, 96) -> (1, 1, 512, 96)  if multiquery
    
    attn_weight = q.matmul(k) * softmax_scale                       # (1, 8, 512, 512)
    attn_weight = torch.softmax(attn_weight, dim=-1)                # (1, 8, 512, 512)
 
    out = attn_weight.matmul(v)                                     # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)
    out = rearrange(out, 'b h s d -> b s (h d)')                    # (1, 512, 768)
 
    return out, attn_weight, past_key_value

(4)tgi框架中的MQA

具体还可以参考tgi框架中的MQA代码:

class MultiQueryAttention(nn.Module):
    """Multi-Query self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        attn_impl = config.attn_config["attn_impl"]
        self.attn_impl = config.attn_config["attn_impl"]
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.d_model = config.d_model
        d_model = config.d_model
        self.n_heads = config.n_heads
        self.softmax_scale = config.attn_config["softmax_scale"]
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.head_dim)
        self.attn_dropout_p = config.attn_config["attn_pdrop"]
        # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
        self.Wqkv = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
        )
        fuse_splits = (d_model, d_model + self.head_dim)
        if self.qk_ln:
            raise NotImplementedError("qk_ln not supported")
        if self.attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    "While `attn_impl: triton` can be faster than `attn_impl: flash` "
                    + "it uses more memory. When training larger models this can trigger "
                    + "alloc retries which hurts performance. If encountered, we recommend "
                    + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
                )
        elif self.attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    "Using `attn_impl: torch`. If your model does not use `alibi` or "
                    + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
                    + "we recommend using `attn_impl: triton`."
                )
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=not config.no_bias,
        )
        # self.out_proj._is_residual = True

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        (query, key, value) = qkv.split(
            [self.d_model, self.head_dim, self.head_dim], dim=2
        )
        key_padding_mask = attention_mask
        if self.qk_ln:
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)
        (context, attn_weights, past_key_value) = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
            multiquery=True,
        )
        return (self.out_proj(context), attn_weights, past_key_value)

Reference

[1] https://github.com/huggingface/text-generation-inference
[2] LLM 加速技巧:Muti Query Attention
[3] 训练模型算力的单位:FLOPs、FLOPS、Macs 与 估算模型(FC, CNN, LSTM, Transformers&&LLM)的FLOPs
[4] FlashAttention 的速度优化原理是怎样的?
[5] FlashAttention图解(如何加速Attention)
[6] flashattention论文:https://arxiv.org/pdf/2205.14135.pdf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1522964.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

源于一区| 改善性能的5种高效而小众的变异策略,一键调用 (Matlab)

基于群体的优化算法在达到迭代后期时种群多样性往往会速降,进化将陷入停滞,而许多算法本身并没有突变机制,一旦受到局部最优值的约束,就很难摆脱这些约束。它还将减少种群多样性,减缓收敛速度。 变异策略可以增加种群…

2025武忠祥考研数学,视频百度网盘+基础全程课程PDF

“得数学者的天下”,25考研首先要开始的就是数学复习,而数学复习首先要开始的必然是高数! 很多同学选择了跟着武忠祥老师学习高数,但是具体要怎么学?用什么书?怎么刷题?快来看看以 下的武忠祥…

GenAI开源公司汇总

主要分类如下: 1. 基础模型:这些是机器学习和AI的核心模型提供商,它们提供基础的算法和技术支持。 2. 模型部署与推断:提供云服务和计算资源,帮助用户部署和运行AI模型。 3. 开发者工具:支持AI/ML的开发…

【01】htmlcssgit

01-前端干货-html&css 防脱发神器 一图胜千言 使用border-box控制尺寸更加直观,因此,很多网站都会加入下面的代码 * {margin: 0;padding: 0;box-sizing: border-box; }颜色的 alpha 通道 颜色的 alpha 通道标识了色彩的透明度,它是一个 0~1 之间的取值,0 标识完全…

开发指南013-国际化-后台部分

平台底层做了国际化处理。开发时候根据项目性质,决定是否采用国际化,但是底层所需资源必须包含(一些底层例如登录校验都做了对应处理)。平台先支持中文简体、中文繁体、英文、日文,必要时可以随时扩展其他语言。 国际化…

单片机FLASH深度解析和编程实践(上)

本篇文章主要针对单片机FLASH编程和FLASH基本原理进行学习分享。以STM32单片机作为实例进行编程实训。 关于FLASH操作的相关寄存器及编程,大家可以参考下一篇文章: 单片机FLASH深度解析和编程实践(下)-CSDN博客 目录 一、STM32编程方式 二、…

Linux批量注释

1.注释行 1.按ctrlv进入块选择模式 ,然后上下键选中需要注释的行 2.按shifti(也就是大写I) 然后输入// 或 # 3.按ESC键 2.取消注释行 1.按ctrlv进入块选择模式, 然后上下键选中需要取消注释的行 2.然后按d

QT C++ QButtonGroup应用

//QT 中,按钮数量比较少,可以分别用各按钮的信号和槽处理。 //当按钮数量较多时,用QButtonGroup可以实现共用一个槽函数,批量处理,减少垃圾代码, //减少出错。 //开发平台:win10QT6.2.4 MSVC…

面向控制台编程?Java的GUI开发

记得之前刚开始学习Java,按部就班去阅读《Java核心技术》这本书的时候,总是听别人提起,java swing那一章不用看了。然后直到对着控制台编程了半年,回来捡起了Swing图形界面,跟着网上搞了坦克大战的游戏,总觉…

【蓝桥杯选拔赛真题38】C++判断数字 第十四届蓝桥杯青少年创意编程大赛 算法思维 C++编程选拔赛真题解析

目录 C判断数字 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、程序说明 五、运行结果 六、考点分析 七、推荐资料 C判断数字 第十四届蓝桥杯青少年创意编程大赛C选拔赛真题 一、题目要求 1、编程实现 给定一个正整数N(100≤N<100000)…

从零开始搭建游戏服务器 第二节 Actor模型与应用

目录 复习本节内容正文什么是Actor模型如何应用创建Actor基类创建RootActor创建AkkaContext创建ConnectActorManager和ConnectActor生成actor并发送消息给它 课后作业结尾 复习 上一节我们使用gradle构建了一个多模块系统。 并且在登录服启动了Netty服务&#xff0c;监听confi…

字符串的模式匹配算法

一、朴素模式匹配算法 二、KMP算法 三、KMP求Next数组 四、KMP求NextVal数组

Nacos与Eureka的使用与区别

Nacos与Eureka的使用与区别 单体架构&#xff1a;优点缺点 分布式架构需要考虑的问题&#xff1a;微服务企业需求 认识SpringCloud服务的拆分与远程调用微服务调用方式 Eureka提供者和消费者架构搭建Eureka服务注册服务发现 Ribbon负载均衡饥饿加载总结 Nacos注册中心Nacos安装…

Task-balanced distillation for object detection用于

Task-balanced distillation for object detection用于目标检测的任务平衡蒸馏 摘要 主流的目标检测器通常由分类和回归两个子任务组成&#xff0c;由两个并行头部实现。这种经典的设计范式不可避免的导致分类得分和定位质量&#xff08;IOU&#xff09;之间的空间分布不一致…

鸿蒙开发实现弹幕功能

鸿蒙开发实现弹幕功能如下&#xff1a; 弹幕轮播组件&#xff1a;BannerScroll import type { IDanMuInfoList, IDanMuInfoItem } from ../model/DanMuData //定义组件 Component export default struct BannerScroll {//Watch 用来监视状态数据的变化&#xff0c;包括&#…

软考79-上午题-【面向对象技术3-设计模式】-结构型设计模式02

一、组合模式 1-1、意图 将对象组合成树型结构&#xff0c;以表示"部分-整体"的层次结构。Composite使得用户对单个对象和组 合对象的使用具有一致性。 示例&#xff1a;对象&#xff1a;文件、文件夹 1-2、结构 Component 为组合中的对象声明接口&#xff1b;在适…

CKA认证之Etcd备份与恢复

题目介绍&#xff1a; 资料参考&#xff1a; https://kubernetes.io/zh-cn/docs/tasks/administer-cluster/configure-upgrade-etcd 解题&#xff1a; 1、备份 #参考模板列出 etcdctl 可用的各种选项。 #例如&#xff0c;你可以通过指定端点、证书和密钥来制作快照&#xff0…

Android 开发 地图 polygon 显示信息

问题 Android 开发 地图 polygon 显示信息 详细问题 笔者进行Android项目开发&#xff0c;接入高德地图绘制区域后&#xff0c;需要在指定区域&#xff08;位置&#xff09;内显示文本信息&#xff0c;如何实现 实现效果 解决方案 代码 import com.amap.api.maps.model.T…

Day47-http和www基础2

Day47-http和www基础2 1.Linux下实践观察HTTP协议通信过程1.1 DNS解析过程1.2 tcp三次握手1.3 发起HTTP请求报文阶段1.4 请求报文阶段1.5 响应报文部分1.6 分析URL&#xff0c;DNS解析1.7 HTTPS连接建立的过程1.8 请求报文阶段1.8.1 referer1.8.2 如何解决防盗链&#xff1f; 2…

【机器学习】分类模型的评价方法

&#x1f33b;个人主页&#xff1a;相洋同学 &#x1f947;学习在于行动、总结和坚持&#xff0c;共勉&#xff01; #学习笔记# 目录 一、混淆矩阵&#xff08;Confusion Matrix&#xff09; 二、评估指标&#xff08;Evaluation metrics&#xff09; 1.正确率(accuracy) …