辅导男朋友转算法岗的第2天|self Attention与kv cache

news2025/1/11 15:45:30

文章目录

    • 公式
    • KV Cache
    • MHA、MQA、GQA
  • 面试题

公式

$ \text{Output} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \times V$ 复杂度是O( n 2 n^2 n2)

KV Cache

推理阶段最常用的缓存机制,用空间换时间。

原理:

在进行自回归解码的时候,新生成的token会加入序列,一起作为下一次解码的输入。

由于单向注意力的存在,新加入的token并不会影响前面序列的计算,因此可以把已经计算过的每层的kv值保存起来,这样就节省了和本次生成无关的计算量。

通过把kv值存储在速度远快于显存的L2缓存中,可以大大减少kv值的保存和读取,这样就极大加快了模型推理的速度。

分别做一个k cache和一个v cache,把之前计算的k和v存起来

以v cache为例:

在这里插入图片描述

存在的问题:存储碎片化

解决方法:page attention(封装在vllm里了)

MHA、MQA、GQA

Multi-Head Attention、Multi-Query Attention、Group-Query Attention

目的:优化KV Cache所需空间大小

原理是共享k和v,但是使用MQA效果会差一些,于是又出现了GQA这种折中的办法

在这里插入图片描述

面试题

为什么除以 d k \sqrt{d_k} dk

压缩softmax输入值,以免输入值过大,进入了softmax的饱和区,导致梯度值太小而难以训练。

Multihead的好处

1、每个head捕获不同的信息,多个头能够分别关注到不同的特征,增强了表达能力。多个头中,会有部分头能够学习到更高级的特征,并减少注意力权重对角线值过大的情况。

比如部分头关注语法信息,部分头关注知识内容,部分头关注近距离文本,部分头关注远距离文本,这样减少信息缺失,提升模型容量。

2、类似集成学习,多个模型做决策,降低误差

decoder-only模型在训练阶段和推理阶段的input有什么不同?

  • 训练阶段:模型一次性处理整个输入序列,输入是完整的序列,掩码矩阵是固定的上三角矩阵。
  • 推理阶段:模型逐步生成序列,输入是一个初始序列,然后逐步添加生成的 token。掩码矩阵需要动态调整,以适应不断增加的序列长度,并考虑缓存机制。

手撕必背-多头注意力

逐头计算

import torch.nn as nn
class MultiHeadAttentionScores(nn.Module):

    def __init__(self, hidden_size, num_attention_heads, attention_head_size):
        super(MultiHeadAttentionScores, self).__init__()
        self.num_attention_heads = num_attention_heads # 8,16, 32, 64
        
        # Create a query, key, and value projection layer
        # for each attention head.  W^Q, W^K, W^V
        self.query_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size) 
            for _ in range(num_attention_heads)
        ])
        
        self.key_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size) 
            for _ in range(num_attention_heads)
        ])
        
        self.value_layers = nn.ModuleList([
            nn.Linear(hidden_size, attention_head_size) 
            for _ in range(num_attention_heads)
        ])

    def forward(self, hidden_states):
        # Create a list to store the outputs of each attention head
        all_attention_outputs = []

        for i in range(self.num_attention_heads): # i.e. 8
            query_vectors = self.query_layers[i](hidden_states)
            key_vectors = self.key_layers[i](hidden_states)
            value_vectors = self.value_layers[i](hidden_states)
            
            # softmax(Q&K^T)*V
            attention_scores = torch.matmul(query_vectors, key_vectors.transpose(-1, -2))
            # attention_scores combined with softmax--> normalized_attention_score
            attention_outputs = torch.matmul(attention_scores, value_vectors)
            all_attention_outputs.append(attention_outputs)

        return all_attention_outputs

矩阵运算

import torch
import torch.nn as nn

class MultiHeadAttentionScores(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, attention_head_size):
        super(MultiHeadAttentionScores, self).__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size
        self.hidden_size = hidden_size
        
        self.query = nn.Linear(hidden_size, num_attention_heads * attention_head_size)
        self.key = nn.Linear(hidden_size, num_attention_heads * attention_head_size)
        self.value = nn.Linear(hidden_size, num_attention_heads * attention_head_size)

    def forward(self, hidden_states):
        batch_size = hidden_states.size(0)
        
        query_layer = self.query(hidden_states)
        key_layer = self.key(hidden_states)
        value_layer = self.value(hidden_states)
        
        query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        
        attention_outputs = torch.matmul(attention_probs, value_layer)
        
        attention_outputs = attention_outputs.transpose(1, 2).contiguous().view(batch_size, -1, self.num_attention_heads * self.attention_head_size)
        
        return attention_outputs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1717864.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SG7050EEN差分晶体振荡器:为5G路由器提供卓越的时钟源

随着5G技术的快速发展,5G路由器作为连接高速网络的重要设备,正迅速普及。为了确保5G路由器在高宽带和低延迟的网络环境中表现出色,选择一款高性能的晶体振荡器至关重要。爱普生推出的SG7050EEN差分晶体振荡器,以其高精度、低相位噪…

K210视觉识别模块学习笔记2:固件的下载升级_官方数字识别例程导入方法

今日开始学习K210视觉识别模块:固件的下载升级_官方数字识别例程导入方法 主要学习如何升级固件库,在哪下载固件库,以及如何在TF卡正确导入官方例程: 亚博智能的K210视觉识别模块...... 本次最终目的是正确导入官方的数字识别例程&#xff0…

Python 之SQLAlchemy使用详细说明

目录 1、SQLAlchemy 1.1、ORM概述 1.2、SQLAlchemy概述 1.3、SQLAlchemy的组成部分 1.4、SQLAlchemy的使用 1.4.1、安装 1.4.2、创建数据库连接 1.4.3、执行原生SQL语句 1.4.4、映射已存在的表 1.4.5、创建表 1.4.5.1、创建表的两种方式 1、使用 Table 类直接创建表…

小程序使用Canvas设置文字竖向排列

在需要使用的js页面引入js文件,传入对应参数即可 /** * 文本竖向排列 */ function drawTextVertical(context, text, x, y) {var arrText text.split();var arrWidth arrText.map(function (letter) {return 26; // 字体间距,需要自定义可以自己加参数,根据传入参数进行…

飞凌嵌入式FET3568/3568J-C核心板现已适配OpenHarmony4.1

近日,飞凌嵌入式为FET3568/3568J-C核心板适配了OpenHarmony4.1系统,新系统的加持使核心板在兼容性、稳定性与安全性等方面都得到进一步提升,不仅为FET3568/3568J-C核心板赋予了更强大的功能,也为开发者们提供了更加广阔的创新空间…

WordPress中借助Table of Contents Plus+Widget Options插件,实现仅在文章侧边栏显示文章目录的功能

本文转自博主的个人博客:https://blog.zhumengmeng.work,欢迎大家前往查看。 原文链接:点我访问 序言:今天心血来潮,写了一篇文章,忽然发现自己的文章极少有目录,这对于长文章的阅读来说是十分不利的&#…

vivado 时序约束

时间限制 以下ISE设计套件时序约束可以表示为XDC时序约束 Vivado设计套件。每个约束描述都包含一个UCF示例和 等效的XDC示例。 在未直接连接到边界的网络上创建时钟时,UCF和XDC不同 的设计(如端口)。在XDC中,当在上定义带有create…

微信小程序发送订阅消息

小程序后台。订阅消息里面,新建一个消息模板 小程序代码,登录后,弹出订阅信息 requestSubscribeMessage: function () {wx.requestSubscribeMessage({tmplIds: [-323232-32323], // 替换为你的模板IDsuccess(res) {// 用户订阅结果console.l…

【康耐视国产案例】Nvidia/算能+智能AI相机:用AI驱动 | 降低电动车成本的未来之路

受环保观念影响、政府激励措施推动与新能源技术的发展,消费者对电动汽车(EV)的需求正在不断增长,电动汽车已经成为了未来出行方式的重要组成部分。然而,电动汽车大规模取代燃油汽车的道路还很漫长。最大的障碍就是电动汽车的售价相对过高。尽…

day2数据结构

双链表的插入 循环链表,判断循环链表是否为空 指向的是自己 仅设表尾指针的循环链表合并 代码举例 删除线性表的最小值,并由函数返回删除的值,空的位置,由最后一个元素填补,若表为空显示出错信息 &L 因为L会发生…

深入理解flask规则构建与动态变量应用

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、引言 二、Flask规则基础 1. 静态规则与动态规则 2. 规则语法与结构 三、动态变量应用…

AI作画算法原理

1.概述 AI作画算法的原理相当复杂,涉及多个领域的知识,包括计算机视觉、机器学习和神经网络等。我们从以下几个方面来描述AI作画算法的基本原理。 2. 数据准备 在数据准备方面,AI作画算法通常需要大量的图像数据作为训练样本。可以是各种各…

52-QSplitter类QDockWidget类

一 QSplitter类 Qt提供QSplitter(QSplitter)类来进行分裂布局&#xff0c;QSplitter派生于QFrame。 #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow>class MainWindow : public QMainWindow {Q_OBJECTpublic:MainWindow(QWidget *parent nullptr);~…

【深度强化学习】如何平衡cpu和gpu来加快训练速度(实录)

文章目录 问题抛出问题展示 问题探索参考&#xff1a;如何平衡cpu和gpu来加快训练速度呢&#xff1f; 解决问题实现逻辑&#xff1a;PPO算法示例&#xff1a;偷懒改法&#xff1a;第三处修改再次修改--24.5.22 不偷懒改法修改总结1 最终成绩&#xff08;不是&#xff09;附加赛…

Python中的 Lambda 函数

大家好&#xff0c;在 Python 编程的世界里&#xff0c;有一种功能强大却不常被提及的工具&#xff0c;它就是 Lambda 函数。这种匿名函数在 Python 中拥有着令人惊叹的灵活性和简洁性&#xff0c;却常常被许多开发者忽视或者只是将其当作一种附加功能。Lambda 函数的引入&…

Windows系统WDS+MDT网络启动自动化安装

Windows系统WDS+MDT网络启动自动化安装 适用于在Windows系统上WDS+MDT网络启动自动化安装 1. 安装准备 1.下载windows server 2019、windows 10 pro的ISO文件,并安装好windows server 2019 2.下载windows 10 2004版ADK及镜像包 1.1 安装平台 Windows 111.2. 软件信息 软件…

【Python】解决Python报错:IndexError: pop from empty list

&#x1f9d1; 博主简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟&#xff0c;欢迎关注。提供嵌入式方向…

提高倾斜摄影三维模型OSGB格式轻量化

提高倾斜摄影三维模型OSGB格式轻量化 倾斜摄影三维模型以其高精度和真实感受在城市规划、建筑设计和虚拟漫游等领域发挥着重要作用。然而&#xff0c;由于其庞大的数据量和复杂的几何结构&#xff0c;给数据存储、传输和可视化带来了挑战。为了提高倾斜摄影三维模型的性能和运行…

C/C++中互斥量(锁)的实现原理探究

互斥量的实现原理探究 文章目录 互斥量的实现原理探究互斥量的概念何为原子性操作原理探究 互斥量的概念 ​ 互斥量&#xff08;mutex&#xff09;是一种同步原语&#xff0c;用于保护多个线程同时访问共享数据。互斥量提供独占的、非递归的所有权语义&#xff1a;一个线程从成…

Docker管理工具Portainer忘记admin登录密码

停止Portainer容器 docker stop portainer找到portainer容器挂载信息 docker inspect portainer找到目录挂载信息 重置密码 docker run --rm -v /var/lib/docker/volumes/portainer_data/_data:/data portainer/helper-reset-password生成新的admin密码&#xff0c;使用新密…