chatGLM2中的Multi Query Attention

news2025/1/18 3:22:28

目录

原理简介

代码实现和耗时比较

总结分析

          近期一直在玩大模型,对中文支持比较好的就是清华的chatGLM,目前chatGLM由v1升级到了chatGLM2。在gihub上介绍信息如下:

 试用了一下,效果和速度确实有所提升。

 这个得益于chatGLM2应用了许多优化的技术,介绍中有提到过的FlashAttention技术、Multi Query Attention(MQA)技术和int4量化等等。其中MQA技术是对Multi head  Attention(MHA)的一种优化实现,加快了技术速度的同时也保证了效果下降的不厉害。

原理简介

       MQA最早是出现在2019年谷歌的一篇论文Fast Transformer Decoding: One Write-Head is All You Need,之所以没有关注到,是因为之前很少做文本生成,解码序列长度也没有现阶段大模型的要求那么高。MQA的思想其实比较简单(如果对MHA比较熟悉的话),论文中给出的描述如下:

论文的意思是:MQA和MHA除了不同的attention head共享一份keys和values权重之外,其他的都是一样的。现有4个head的attention,每个head分别进行softmax(QK)V注意力计算,那么这样设置的MHA和MQA示意图如下所示:

 

 可以看到MHQ和MQA的不同之处仅仅在于每个头共享相同的K、V权重而Q不同享。

模型效果论文对比如下:

 推理速度上生成一个token时MHA和MQA的encoder分别耗时1.7us和1.5us,而decoder分别46us和3.8us,说明decoder上MQA比MHA快很多。另外在效果上MQA的PPL(越小越好)有所上升,BLEU(越大越好)有所下降,换句话说就是效果有所下降。

代码实现和耗时比较

参考了huggingface的transformers包中的bertselfattention源码实现了一版MHA和MQA,代码如下:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
import math
import torch.nn as nn
import torch
from tqdm import tqdm
import time
class MiltiHeadSelfAttention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(0.1)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self,hidden_states):
        mixed_query_layer = self.query(hidden_states)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return context_layer


class MultiQuerySelfAttention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.attention_head_size)
        self.value = nn.Linear(hidden_size, self.attention_head_size)

        self.dropout = nn.Dropout(0.1)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self,hidden_states):
        # hidden_states (B, L, D)
        mixed_query_layer = self.query(hidden_states)
        # query_layer  (B, h, L, d)
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # 每个key、value head参数都是一样的,只计算一次
        key = self.key(hidden_states)
        #key_layer  (B, 1, L, d)
        key_layer = key.unsqueeze(1)
        value = self.value(hidden_states)
        # value_layer  (B, 1, L, d)
        value_layer = value.unsqueeze(1)

        # key_layer  (B, 1, d, L)
        key_layer = key_layer.transpose(-1, -2)
        #广播算法 (B, h, L, d) * (B, 1, d, L) => (B, h, L, d) * (B, h, d, L) = (B, h, L, L)
        attention_scores = torch.matmul(query_layer, key_layer)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        #广播算法 (B, h, L, L) * (B, 1, L, d) =>(B, h, L, L) * (B, h, L, d)= (B, h, L, d)
        context_layer = torch.matmul(attention_probs, value_layer)
        #(B, h, L, d) => (B, L, h, d)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # (B,L, h*d) => (B,L,D)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        # (B,L, h*d) => (B,L,D)
        context_layer = context_layer.view(new_context_layer_shape)
        return context_layer




if __name__ == '__main__':
    seed = 100
    num_attention_heads, hidden_size = 32, 4096
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    device = "cuda:0"

    embeddings = torch.randn(5, 128, hidden_size).to(device)

    multiquery = MultiQuerySelfAttention(num_attention_heads, hidden_size).to(device)
    print(multiquery)
    total = 0
    for name, param in multiquery.named_parameters():
        if len(param.size()) == 2:
            total += param.shape[0] * param.shape[1]
        else:
            total += param.shape[0]
    print(f"multiquery parameters {total}")
    count = 100
    start = time.time()
    for _ in tqdm(range(count),ncols=50):
        input = embeddings.clone()
        for _ in range(100):
            for i in range(24):
                ouput = multiquery(input)
            input = torch.cat([input,ouput[:,-1:,:]],dim=1)
    end = time.time()
    print(f"multiquery time total cost {round(end - start, 8)} mean cost {round((end - start) / count, 8)}")


    multihead = MiltiHeadSelfAttention(num_attention_heads, hidden_size).to(device)
    print(multihead)
    total = 0
    for name, param in multihead.named_parameters():
        if len(param.size()) == 2:
            total += param.shape[0] * param.shape[1]
        else:
            total += param.shape[0]
    print(f"multihead parameters {total}")
    count = 100
    start = time.time()
    for _ in tqdm(range(count) ,ncols=50):
        input = embeddings.clone()
        for _ in range(100):
            for i in range(24):
                ouput = multihead(input)
            input = torch.cat([input, ouput[:, -1:, :]], dim=1)
    end = time.time()
    print(f"multihead time total cost {round(end-start,8)} mean cost {round((end-start)/count,8)}")

实现中主要借助矩阵计算的broadcast机制(自动广播机制)并行计算、就不用自己来实现每个头单独计算然后进行cat操作,效率比较高。模拟chatGLM2的设置,hidden_size = 4096、num_heads =32,num_layers=24输入一个维度为(5,128,4096)的向量进行文本解码,生成100个token,耗时对比如下:

 生成100个token时,MQA解码平均耗时2.7826秒,MHA解码平均耗时6.4796秒,简单来看MQA在decoder解码加速了一倍。从模型结构来看原始的MHA一层5034W参数,而MQA只有1783W参数,还是通过压缩参数量来实现显存占用的减少以及推理时间的减少。

总结分析

显存占用和推理耗时减小是显而易见的,因为参数量减少了。至于效果变化得很小,只能说多头attention机制中的多头其实并不是一定,之前的bert模型有人探索了改变head头数目,也会保持效果变化不大。在大模型这,可能只需要有不同的head采用不同的query向量,kv一样来保证每个头提取到不同的特征就够了。

什么时候使用MQA有效呢?

1、采用attention的模型,模型规模越大,那么收益就约明显。

2、decoder生成任务相比较encoder任务收益明显大很大,其实decoder生成任务的收益来源于每一次softmax(QK)V注意力计算微小耗时差异的累积,一次生成任务要生成许多个token,一个token需要经历模型结构层数次的softmax(QK)V注意力的计算。

参考文章

Fast Transformer Decoding: One Write-Head is All You Need

ChatGLM2-6B

 huggingface / transformers

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/738885.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pnpm + workspace + changesets

pnpm workspace changesets 构建你的 monorepo 工程 什么是monorepo? 什么是 monorepo?以及和 multirepo 的区别是什么? 关于这些问题,在之前的一篇**介绍 lerna** 的文章中已经详细介绍过,感兴趣的同学可以再回顾下。 简而…

nacos学习积累

官方文档:https://nacos.io/zh-cn/docs/quick-start.html 1 注册中心简介 注册中心对比和选型:Zookeeper、Eureka、Nacos、Consul和ETCD 如果消费者直接连接的提供者。这样做的问题是,若提供者出现宕机,或消费者存在高并发情况&…

SSRF漏洞

前言 作者简介:不知名白帽,网络安全学习者。 博客主页:不知名白帽的博客_CSDN博客-网络安全,CTF,内网渗透领域博主 网络安全交流社区:https://bbs.csdn.net/forums/angluoanquan 目录 SSRF漏洞原理 产生CSRF的函数 SSRF中常见手…

MySQL原理探索——27 主库出问题了,从库怎么办

在前面的第24、25和26篇文章中,介绍了 MySQL 主备复制的基础结构,但这些都是一主一备的结构。 大多数的互联网应用场景都是读多写少,因此你负责的业务,在发展过程中很可能先会遇到读性能的问题。而在数据库层解决读性能问题&#…

Atlassian Bamboo Enterprise Crack

Atlassian Bamboo Enterprise Crack Bamboo Server是专业组织如持续集成、安装和运输的选择。 从代码到安装的连续运输。 在一个工作流中集中发布自动生成、测试和发布。 构建:专注于编码,依靠Bamboo作为自己的CI,构建一个主机!创建多阶段构建…

3.8.cuda运行时API-使用cuda核函数加速yolov5后处理

目录 前言1. Yolov5后处理2. 后处理案例2.1 cpu_decode2.2 gpu_decode 总结 前言 杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。 本次课程学习精简…

【*2200线段树Pushup】CF1567 E

Problem - E - Codeforces 题意&#xff1a; 思路&#xff1a; 维护这些信息即可 Code&#xff1a; #include <bits/stdc.h>#define int long longusing namespace std;const int mxn2e510; const int mxe2e510; const int mod1e97; const int Inf1e18;struct info{in…

【C语言】gcc编译时报错 fatal error: stdio.h: 没有那个文件或目录

零、问题 在Ubuntu20.04.6中使用GCC编译一个HelloWorld代码时遇到如下问题&#xff1a; 首先确认了&#xff0c;自己单词没有拼写错。 然后再检查GCC的版本&#xff0c;确实没问题&#xff1a; 我用的是Ubuntu20.04.6的版本。 壹、解决 没有标准的头文件需要安装build-es…

和鲸社区数据分析每周挑战【第九十七期:技术博客文本分析】

和鲸社区数据分析每周挑战【第九十七期&#xff1a;技术博客文本分析】 文章目录 和鲸社区数据分析每周挑战【第九十七期&#xff1a;技术博客文本分析】一、背景描述二、数据说明三、问题描述四、数据导入五、数据探索性分析六、对文章标题进行文本分类预测1、数据预处理2、逻…

C++万字自学笔记

[TOC] 一、 C基础 C的IDE有CLion、Visual Studio、DEV C、eclipse等等&#xff0c;这里使用CLion进行学习。 0. C初识 0.1 第一个C程序 编写一个C程序总共分为4个步骤 创建项目创建文件编写代码运行程序 #include <iostream>int main() {using namespace std;cout…

新手如何快速安装电脑监控软件?

越来越多的管理者选择使用电脑监控软件&#xff0c;许多新手不知道具体怎样安装&#xff0c;本期将为大家介绍下具体的安装流程。 电脑监控软件购买之后&#xff0c;会提供网址和账号密码&#xff0c;登录后需要先添加员工信息&#xff0c;有三种方法&#xff1a; &#xff0…

Android性能优化(bin启动优化)

我们平时会在android里面写个bin程序来干点活&#xff0c;但是有时候我们会发现很奇怪的现象&#xff0c;我明明很早就启动这个bin了&#xff0c;但是过了很久bin程序的main函数才被调用~。这个是为啥呢&#xff1f;主要有2个原因&#xff1a; 一.bin程序依赖的so库太多&#…

steam搬砖项目,csgo游戏搬砖熟练操作后,可以月入过万~

科思创业汇 大家好&#xff0c;这里是科思创业汇&#xff0c;一个轻资产创业孵化平台。赚钱的方式有很多种&#xff0c;我希望在科思创业汇能够给你带来最快乐的那一种&#xff01; 网上创业创造了一批赚钱的人&#xff0c;年收入从几十万到几百万不等&#xff0c;营业额从几…

基于springboot房屋租赁管理系统

开发工具&#xff1a;IDEA&#xff0c;jdk1.8 服务器&#xff1a;tomcat9.0 数据库&#xff1a;mysql5.7 前端&#xff1a;jsp、bootstrap 技术&#xff1a; springbootmybatis-plus 系统主要分前台和后台&#xff0c;分租客、房东、管理员三个角色 系统功能介绍说明&…

nodejs 高级编程-通信

一、通信基本原理 通信必要条件 主机之间需要有传输介质主机上必须有网卡设备主机之间需要协商网络速率 二、网络通讯方式 常见的通讯方式 交换机通讯路由器通讯 如何建立多台主机互连&#xff1f; 如何定位局域网中的其他主机&#xff1f; 通过Mac地址来唯一标识一台主机…

hcip笔记---ospf的LSA限制和不规则区域

有关ACL&#xff1a;例如&#xff1a;1.1.1.0 0.0.0.255这个网段以及后面跟随的通配符&#xff0c;通配符和反掩码长得很像&#xff0c;同时都是用0标识不可变&#xff0c;1标识可变&#xff0c;但反掩码里的1和0必须连续出现&#xff0c;而通配符则不需要遵循这个规则&#xf…

深入思考Sui的独特性如何构建出跨时代的产品

近日&#xff0c;我们与Mysten Labs产品总监Janet Wu面对面探讨了Web3的产品开发过程&#xff0c;了解了她对Sui上最激动人心的产品用例的看法&#xff0c;以及她对该行业未来的展望。 您能简单介绍一下在Mysten Labs担任产品总监意味着什么吗&#xff1f; 对我而言&#xff…

0基础学习VR全景平台篇 第59篇:专业版功能-跨账号复制

功能位置示意 一、本功能将用在哪里&#xff1f; 跨账号复制&#xff0c;是指将本账号中已发布的VR漫游作品一键复制给其他账号使用。 复制成功后&#xff0c;其他账号中也会生成同样的作品以及获得相关的全景、音频、图片、视频等素材。 并且原作品和复制品可以独立编辑&am…

K8s为什么需要calico? calico 原理深入理解.

文章目录 为什么需要calico&#xff1f;-网络插件”千千万”&#xff0c;为何k8s要用calicocalico的架构calico Pod 跨node通信tunl0 的作用&#xff1f;为什么所有pod的默认网关都是169.254.1.1 &#xff1f;什么是ARP 代理&#xff1f;jksj BGP模式的calico工作原理calico BG…

vue3 报错解决:找不到模块‘xxx.vue’或其相应的类型声明。(Vue 3 can not find module)

src下面建立一个xx.d.ts的文件 declare module *.vue {import { ComponentOptions } from vueconst componentOptions: ComponentOptionsexport default componentOptions }