探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力

news2024/12/24 11:16:10

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力

Grouped-query Attention,简称GQA

分组查询注意力(Grouped-query Attention,简称GQA)是多查询和多头注意力的插值。它在保持与多查询注意力相当的处理速度的同时,实现了与多头注意力相似的质量。

在这里插入图片描述

自回归解码的标准做法是缓存序列中先前标记的键和值,以加快注意力计算的速度。

  • 然而,随着上下文窗口或批量大小的增加,多头注意力(Multi-Head Attention,简称MHA)模型中键值缓存(Key-Value Cache,简称KV Cache)的大小所关联的内存成本显著增加。

  • 多查询注意力(Multi-Query Attention,简称MQA)是一种机制,它对多个查询仅使用单个键值头,这可以节省内存并大幅加快解码器的推理速度。

  • Llama(一种模型)整合了GQA,以解决在Transformer模型自回归解码期间的内存带宽挑战。主要问题源于GPU进行计算的速度比它们将数据移入内存的速度快。在每个阶段都需要加载解码器权重和注意力键,这消耗了大量的内存。

在这里插入图片描述
在这里插入图片描述

class SelfAttention(nn.Module): 
    def  __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Indicates the number of heads for the queries
        self.n_heads_q = args.n_heads
        # Indiates how many times the heads of keys and value should be repeated to match the head of the Query
        self.n_rep = self.n_heads_q // self.n_kv_heads
        # Indicates the dimentiona of each head
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))

    def forward(self, x: torch.Tensor, start_pos: int, freq_complex: torch.Tensor):
        batch_size, seq_len, _ = x.shape #(B, 1, dim)
        # Apply the wq, wk, wv matrices to query, key and value
        # (B, 1, dim) -> (B, 1, H_q * head_dim)
        xq = self.wq(x)
        # (B, 1, dim) -> (B, 1, H_kv * head_dim)
        xk = self.wk(x)
        xv = self.wv(x)

        # (B, 1, H_q * head_dim) -> (B, 1, H_q, head_dim)
        xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)
        xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        # (B, 1, H_kv * head_dim) -> (B, 1, H_kv, head_dim)
        xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

        # Apply the rotary embeddings to the keys and values
        # Does not chnage the shape of the tensor
        # (B, 1, H_kv, head_dim) -> (B, 1, H_kv, head_dim)
        xq = apply_rotary_embeddings(xq, freq_complex, device=x.device)
        xk = apply_rotary_embeddings(xk, freq_complex, device=x.device)

        # Replace the enty in the cache for this token
        self.cache_k[:batch_size, start_pos:start_pos + seq_len] = xk
        self.cache_v[:batch_size, start_pos:start_pos + seq_len] = xv

        # Retrive all the cached keys and values so far
        # (B, seq_len_kv, H_kv, head_dim)
        keys = self.cache_k[:batch_size, 0:start_pos + seq_len]
        values = self.cache_v[:batch_size, 0:start_pos+seq_len] 

        # Repeat the heads of the K and V to reach the number of heads of the queries
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # (B, 1, h_q, head_dim) --> (b, h_q, 1, head_dim)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # (B, h_q, 1, head_dim) @ (B, h_kv, seq_len-kv, head_dim) -> (B, h_q, 1, seq_len-kv)
        scores = torch.matmul(xq, keys.transpose(2,3)) / math.sqrt(self.head_dim)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        # (B, h_q, 1, seq_len) @ (B, h_q, seq_len-kv, head_dim) --> (b, h-q, q, head_dim)
        output = torch.matmul(scores, values)

        # (B, h_q, 1, head_dim) -> (B, 1, h_q, head_dim) -> ()
        output = (output.transpose(1,2).contiguous().view(batch_size, seq_len, -1))
        return self.wo(output) # (B, 1, dim) -> (B, 1, dim)

系列博客

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(一)
https://duanzhihua.blog.csdn.net/article/details/138208650
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(二)
https://duanzhihua.blog.csdn.net/article/details/138212328

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(三)KV缓存
https://duanzhihua.blog.csdn.net/article/details/138213306
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1627704.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

栈和队列OJ——括号匹配问题,用队列实现栈,用栈实现队列,设计循环队列

题目1——括号匹配问题 题目来源. - 力扣(LeetCode) 思路——辅助栈法 括号匹配问题是一个经典的计算机科学问题,常用于检查一个字符串中的括号是否正确匹配。这包括各种括号,如小括号“()”,大括号“{}”&#xff0…

Eagle for Mac:强大的图片管理工具

Eagle for Mac是一款专为Mac用户设计的图片管理工具,旨在帮助用户更高效、有序地管理和查找图片资源。 Eagle for Mac v1.9.2中文版下载 Eagle支持多种图片格式,包括JPG、PNG、GIF、SVG、PSD、AI等,无论是矢量图还是位图,都能以清…

你的网站还在使用HTTP? 免费升级至HTTPS吧

如果您的网站还在使用老的http协议,可以申请一个免费的SSL证书升级至https! 具体步骤如下: 1 申请免费SSL证书 根据你的需求选择合适的SSL证书类型,如单域名证书,多域名证书、通配符证书 登录免费供应商JoySSL官网&…

18 JavaScript学习:错误

JavaScript错误 JavaScript错误通常指的是在编写JavaScript代码时发生的错误。这些错误可能是语法错误、运行时错误或逻辑错误。以下是对这些错误的一些常见分类和解释: 语法错误: 这类错误发生在代码编写阶段,通常是由于代码不符合JavaScrip…

排队叫号取号投屏语音播报小程序开源版开发

排队叫号取号投屏语音播报小程序开源版开发 多场景排队叫号系统,支持大屏幕投屏,语音播报叫号,可用于餐厅排队取餐、美甲店排队取号、排队领取、排队就诊、排队办理业务等诸多场景,助你轻松应对各种排队取号叫号场景。 功能特性…

IBM SPSS Statistics for Mac v27.0.1中文激活版:强大的数据分析工具

IBM SPSS Statistics for Mac是一款功能强大的数据分析工具,为Mac用户提供了高效、精准的数据分析体验。 IBM SPSS Statistics for Mac v27.0.1中文激活版下载 该软件拥有丰富的统计分析功能,无论是描述性统计、推论性统计,还是高级的多元统计…

C++进阶--智能指针

智能指针的概念 智能指针是C中的一个重要概念,用于管理动态分配的对象内存。它是一个类模板,通过封装原始指针,并在对象生命周期结束时自动释放内存,从而避免了内存泄漏和资源管理的繁琐工作。 C标准库提供了多种常见的智能指针…

MySQL常见问题与解决方案详述

MySQL:常见问题与解决方案详述 作为一款广泛使用的开源关系型数据库管理系统,MySQL对于初学者来说既充满吸引力又充满挑战。本文将列举初学者在使用MySQL过程中可能遇到的一些典型问题,并提供详细的解决方案,配以图片辅助说明&am…

Visual Studio 对 C++ 头文件和模块的支持

在 C 编程领域,头文件和模块的管理有时候确实比较令人头疼。但是,有许多工具和功能可以简化此过程,提高效率并减少出错的可能性。下面是我们为 C 头文件和模块提供的几种工具的介绍。 构建明细 通过菜单栏 Build > Run Build Insights&a…

Eudic欧路词典for Mac:专业英语学习工具

Eudic欧路词典for Mac,作为专为Mac用户设计的英语学习工具,凭借其简捷高效的特点,成为众多英语学习者不可或缺的助手。 Eudic欧路词典for Mac v4.6.4激活版下载 这款词典整合了多个权威词典资源,如牛津、柯林斯、朗文等&#xff0…

低代码技术的全面应用:加速创新、降低成本

引言 在当今数字化转型的时代,企业和组织面临着不断增长的应用程序需求,以支持其业务运营和创新。然而,传统的软件开发方法通常需要大量的时间、资源和专业技能,限制了企业快速响应市场变化和业务需求的能力。在这样的背景下&…

杰发科技AC7840——CAN通信简介(7)_波形分析

参考: CAN总线协议_stm32_mustfeng-GitCode 开源社区 0. 简介 隐形和显性波形 整帧数据表示 1. 字节描述 CAN数据帧标准格式域段域段名位宽:bit描述帧起始SOF(Start Of Frame)1数据帧起始标志,固定为1bit显性(b0)仲裁段dentify(ID)11本数…

HarmonyOS开发案例:【 自定义弹窗】

介绍 基于ArkTS的声明式开发范式实现了三种不同的弹窗,第一种直接使用公共组件,后两种使用CustomDialogController实现自定义弹窗,效果如图所示: 相关概念 [AlertDialog]:警告弹窗,可设置文本内容和响应回…

视频输入c++ 调用 libtorch推理

1、支持GPU情况 libtorch 支持GPU情况比较奇怪,目前2.3 版本需要在链接器里面加上以下命令,否则不会支持gpu -INCLUDE:?ignore_this_library_placeholderYAHXZ 2 探测是否支持 加一个函数看你是否支持torch,不然不清楚,看到…

数据库和表创建练习

一丶要求 1.创建一个数据库db_classes 2 创建一行表db_hero 3. 将四大名著中的常见人物插入这个英雄表 二丶创建db_classes一个数据库, 使用数据库默认的字符集 create database db_classes; 三丶创建一行表db_hero 1.先切换到我们创建的db_classes;数据库中 use db_class…

HTTP的MIME 类型(2024-04-27)

1、简介 MIME (Multipurpose Internet Mail Extensions) 是描述消息内容类型的标准,用来表示文档、文件或字节流的性质和格式。 MIME 消息能包含文本、图像、音频、视频以及其他应用程序专用的数据。 浏览器通常使用 MIME 类型(而不是文件扩展名&…

打包的意义 作用等前端概念集合 webpack基础配置等

基础网页是什么? 在学校最基础的三剑客 原生JS CSS H5就可以开发静态网页了 对于浏览器而言也能识别这些基础的文件和语法,真正的所见即所得,非常直接。 为什么要使用框架库? 对于常用的前端框架而言,无论是Vue Rea…

【面试经典 150 | 回溯】组合

文章目录 写在前面Tag题目来源解题思路方法一:回溯 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法,两到三天更新一篇文章,欢迎催更…… 专栏内容以分析题目为主,并附带一些对于本题涉及到的数据结构等内容进行回顾…

企业微信开发

侧边栏开发 企业内应用 创建应用 录入必要信息 配置 网页授权及JS-SDK 需要按照提示,把认证的txt暴露出来,能够访问即可。 下图为认证成功的截图 配置侧边栏工具栏 录入页面名称(tab页展示名)、页面URL 配置授权可信ip 用于…

boa交叉编译(移植到arm)

参考:CentOS7 boa服务器的搭建和配置-CSDN博客 以下操作在宿主机/编译平台操作: 1. 先执行[参考]1到3、 4.2、4.3、4.4、4.5 2. 修改MakeFile # 由以下: CC gcc CPP gcc -E # 改为: CC arm-linux-gnueabihf-gcc CPP arm-l…