LLM - Transformer 的 Q/K/V 详解

news2024/11/16 18:00:11

目录

一.引言

二.传统 Q/K/V

三.Transformer Q/K/V

- Input Query

- Q/K/V 获取

- Q/K 相似度计算

- 注意力向量

- Multi Head

四.代码测试

- 初始化

- Attention

- Main

五.总结


一.引言

Transformer 的输入是我们的一个 query 句子,例如 "我爱中国",但是 Transformer 处理时却 1 生 3 得到了 Q/K/V,下面我们从传统机器学习和 Transformer 两个角度看一下 Q/K/V 从哪里来,去哪里去。

二.传统 Q/K/V

对于传统机器学习的 Attention 而言,Q/K 分别代表 Query 和 Keys 即我们候选物品与另一个物品的向量,通过 matmul(Q,K) 即可计算二者的相似度,最后 softmax 归一化乘上 valuse 即可得到。MatMul 后的 Scala 是除以当前 sqrt(dim) 的缩放操作,防止乘积结果太大。

最常见的就是电商场景,例如我们浏览了一件羽绒服,这个就是我们的 Q,而我们历史浏览过多个商品例如鞋子、包包、大衣,这些都是 K,通过 matmul(Q,K) 可以计算出羽绒服与我们浏览的 N 个商品每个商品的相似度,归一化后的值有大有小,从而代表当前商品 K 对 Q 向量的贡献度衡量,进行加权求和即可得到候选集 Q 在浏览记录 K 下的注意力向量表征。

三.Transformer Q/K/V

Transformer 场景下,我们的输入只有 text = "我爱中国",此时不存在上面类似候选商品 Q 和历史浏览商品 K 的关系,怎门办,也好办,Q/K 都是自己就可以,也就是 Self-Attention,下面我们看下如何操作。

- Input Query

对于给定的句子 text,其长度为 L,以 "我想吃酸菜鱼" 为例,其 seq_len = 6, dim = [L x 768]

这里我们可以构建 Embedding 层,将 text 使用 Tokenizer 分词后,得到其对应的 Embedding 表征。除此之外,我们定义 Attention Module,同时根据传入的 Head 数计算拆分后的 head_dim。

class BitDDDAttention(nn.Module):

    def __init__(self, embed_dim, num_heads):
        super(BitDDDAttention, self).__init__()

        self.embed_dim = embed_dim  # embedding 维度
        self.num_heads = num_heads  # head 数量
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads

        # 构建 Q/K/V 向量以及最后的全连接 MLP
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

- Q/K/V 获取

这里 Q/K/V 都是通过一层 Linear 映射得到的,torch 的话就是 nn.Linear,映射后的 Q/K/V 维度不变,依然是 L x 768:

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)
        query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

初始化 Q/K/V 的转换 Linear 矩阵,并根据 head_num 与 head_dim 转换,转换后的维度为:

(bsz, num_heads, seq_len, head_dim)

- Q/K 相似度计算

Attention(Q,K_i,V_i) = softmax(\frac{Q^TK_i}{\sqrt{d_k}})V_i

相似度计算考虑如上 Scale-Dot-Attention 公式,Q 的维度为 L x 768,K 转置后为 768 x L,二者 MatMul 得到 L x L 的方阵,其中 (0,0) 位置代表 "我想吃酸菜鱼" 中 "我" 对 V 中 "我" 字向量的注意力权重,以此类推,整个结果表示 "我想吃酸菜鱼" 中每个字对其他字的注意力,当然也包括自己,所以这种操作也被称为自注意力。

        # Compute the attention scores
        attention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5
        attention_probs = torch.softmax(attention_scores, dim=-1)

Q Dim: (bsz, num_heads, seq_len, head_dim)

K.T Dim: (bsz, num_heads, head_dim, seq_len)

相乘后得到 (bsz, num_heads, seq_len, seq_len) 的注意力权重。

Tips:

这里做了除以 sqrt(dim) 的操作,主要是为了缩小点积范围,确保 softmax 梯度稳定性,而为什么需要做 softmax 一是为了保证权重的非负性,同时增加非线性操作。第 row 行归一化的权重,就是第 row 个元素对其他元素的注意力权重。

- 注意力向量

经过 softmax 后的 LxL 的注意力权重矩阵分别与 LxDim 的 Value 相乘,最终得到自注意力后的矩阵。第一行代表 "我" 对 "我想吃酸菜鱼" 中每个注意力权重在 "我" 向量上的加权平均结果,以此类推,最后得到 seq_len 即 L 个加权平均的表征。

        attention_output = torch.matmul(attention_probs, value)
                                .permute(0, 2, 1, 3)
                                .contiguous()
                                .view(batch_size, seq_len, self.embed_dim)                                                                                                                          

计算后通过 view 将多个 head 的向量合并,恢复为原始的 embed_dim,此时维度依然为 

(bsz, seq_len, head_num x head_dim)

- Multi Head

实践场景中很多都使用 multi-head Attention 即多头注意力,我们可以理解为把原始向量进行拆分,将其多个部分分别进行上面的 Attention 操作。比如还是 "我想吃酸菜鱼",我们得到的 Q/K/V 都是 L x 768 维,此时令 head = 4,则 L x 768 会切分为 4 份 L x 192,分别将 L x 192 做 Self-Attention 后再 Concat,即可再次得到 L x 768。Multi-Head 的思想意在学习向量不同位置对语义的不同表征。

四.代码测试

- 初始化

#!/usr/bin/python
# -*- coding: UTF-8 -*-

import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer


class BitDDDAttention(nn.Module):

    def __init__(self, embed_dim, num_heads):
        super(BitDDDAttention, self).__init__()

        self.embed_dim = embed_dim  # embedding 维度
        self.num_heads = num_heads  # head 数量
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads

        # 构建 Q/K/V 向量以及最后的全连接 MLP
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

定义 head_num、head_dim 以及 Q/K/V 的 Linear 线性层。

- Attention

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # Split the embedding into num_heads and reshape to (batch_size, num_heads, seq_len, head_dim)
        query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Compute the attention scores
        attention_scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / self.head_dim ** 0.5
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # Apply the attention weights to the value
        attention_output = torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous().view(batch_size,
                                                                                                      seq_len,
                                                                                                      self.embed_dim)

        # Apply a linear layer to the output
        x = self.fc_out(attention_output)
        return x

执行 Scale-Dot-Attention,Multi-Head 可以看作是分开做了多个 Self-Attention。

- Main

if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    pretrained_bert = BertModel.from_pretrained('bert-base-uncased')

    input_texts = ["This is a test sentence.", "Here is another test sentence."]
    input_ids = [tokenizer.encode(text, add_special_tokens=True, max_length=10, padding='max_length', truncation=True,
                                  return_tensors='pt') for text in input_texts]
    input_ids = torch.cat(input_ids, dim=0)  # Concatenate and add batch dimension

    with torch.no_grad():
        embedded_output = pretrained_bert(input_ids)[0]  # Get the output of the BERT model
        print(embedded_output.size())  # Output shape should be (2, 10, embedding_dim)

        embed_dim = embedded_output.size(-1)
        num_heads = 4
        model = BitDDDAttention(embed_dim, num_heads)

        output = model(embedded_output)
        print(output.size())  # Output shape should be (2, 10, embed_dim)

输入两条样本,batch 为 2,seq_len 为 10,最终调用 Attention 进行 forward 即可得到一次 Attention 的结果,维度为: torch.Size([2, 10, 768])。

五.总结

本节内容基于 Torch 实现了 MultiHeadAttention,由于 Q/K/V 都是来自于同一句话,故其名为自注意力即 Self-Attention,如果想要类似推荐算法中的不同商品比较,则 Q 是一句话,KV 换成另一句话即可。这里有一点需要注意,我们在 K/V 中调整任意两个字的位置,其对最终的计算没有影响,即在注意力机制中是没有位置信息的,这也是为什么 LLM 要引入 Position Embedding 的原因。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1415165.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

二叉搜索树,穷举(全排列)

力扣230.二叉搜索树中第K小的元素 /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}* TreeNode(int val) { this.val val; }* TreeNode(int val, TreeNode left, TreeN…

Doris安装部署文档

简介 网址:Apache Doris: Open-Source Real-Time Data Warehouse - Apache Doris Apache Doris 是一个基于 MPP 架构的高性能、实时的分析型数据库,以极速易用的特点被人们所熟知,仅需亚秒级响应时间即可返回海量数据下的查询结果&#xff…

mac配置L2TP连接公司内网

1. 打开系统设置 2. 打开网络 3. 点击网络页面其他服务右下角三个点,添加VPN配置中的L2TP 4. 配置VPN,服务器填写公司的服务器ip,共享密钥没有可以随便填写 5. 打开终端编辑文件 sudo vim /etc/ppp/opt…

Java笔记 --- 三、方法引用

三、方法引用 概述 分类 引用静态方法 引用成员方法 本类中注意,静态方法中没有this,需要创建本类的对象 引用构造方法 其他的调用方式 使用类名引用成员方法 引用数组的构造方法

spring整合mybatis的底层原理

spring整合mybatis的底层原理 原理: FactoryBean的自定义对象jdk动态代理Mapper接口对象 一、手写一个spring集成mybatis 目录结构: 1.1 入口类 public class Test {public static void main(String[] args) {AnnotationConfigApplicationContext co…

Java基于 SpringBoot+Vue 的高校心理教育辅导系统的研究与实现

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝30W、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

智能工厂4G无线设备预测维护云端联动的DI、AI、DO混合信号处理单元

在现代工业智能化进程中,一款集成了丰富I/O接口并能与各大云平台无缝对接的智能设备显得尤为重要。比如最近推出的这款创新产品,它集合了8路数字输入通道,涵盖了干湿节点的识别功能,适用于多种开关量信号的读取;同时&a…

B端系统优化:用好卡片式设计,效果立竿见影翻倍。

B端系统中,卡片式设计是一种常见的界面设计风格,它将信息和功能组织成一系列独立的卡片,每个卡片通常包含一个特定的信息块或功能模块。 一、卡片式设计的特点包括: 模块化和可重用性:卡片可以被独立设计和开发&#…

svg 属性详解:填充与边框

svg 属性详解:填充与边框 1 颜色和透明度2 填充规则 fill-rule3 边框样式3.1 stroke-width3.2 stroke-linecap3.3 stroke-linejoin3.4 stroke-dasharray 1 颜色和透明度 图像都有颜色,svg 中可以使用属性 fill 和 stroke 来修改图形的颜色。fill 属性设置…

第十六讲_HarmonyOS应用程序包介绍

HarmonyOS应用程序包介绍 1. 应用程序包概述1.1 多 Module 设计的好处1.2 Module 的类型 2. 应用程序包结构2.1 应用的配置文件2.2 资源目录 3. 应用程序编译后包结构 1. 应用程序包概述 官方推荐基于Stage模型开发HarmonyOS应用程序,一个应用可以包含一个或多个Mo…

计算机基础之微处理器简介

微处理器 微处理器定义 微型计算机的CPU也被称为微处理器,是将运算器、控制器和高速缓存集成在一起的超大规模集成电路芯片,是计算机的核心部件。能完成取指令、执行指令,以及与外界存储器和逻辑部件交换信息等操作。 微处理器发展 CPU从…

研发日记,Matlab/Simulink避坑指南(七)——数据溢出钳位Bug

文章目录 前言 背景介绍 问题描述 分析排查 解决方案 总结归纳 前言 见《研发日记,Matlab/Simulink避坑指南(二)——非对称数据溢出Bug》 见《研发日记,Matlab/Simulink避坑指南(三)——向上取整Bug》 见《研发日记,Matlab/Simulink避坑…

【CSS】实现鼠标悬停图片放大的几种方法

1.背景图片放大 使用css设置背景图片大小100%&#xff0c;同时设置位置和过渡效果&#xff0c;然后使用&#xff1a;hover设置当鼠标悬停时修改图片大小&#xff0c;实现悬停放大效果。 <!DOCTYPE html> <html lang"en"> <head><meta charset…

C++大学教程(第九版)7.19 将7.10节vector对象的例子转换成array对象

文章目录 题目代码运行截图 题目 (将7.10节vector 对象的例子转换成array 对象)将图7.26中 vector 对象的例子转换成使用array 对象。请消除任何 vector 对象仅有的特性。 分析&#xff1a; vector对象独有的特性&#xff1a; 1.vector对象长度可变 2.长度不同的vector对象可…

基于springboot校友社交系统源码和论文

校友社交系统提供给用户一个校友社交信息管理的网站&#xff0c;最新的校友社交信息让用户及时了解校友社交动向,完成校友社交的同时,还能通过论坛中心进行互动更方便。本系统采用了B/S体系的结构&#xff0c;使用了java技术以及MYSQL作为后台数据库进行开发。系统主要分为系统…

为什么选择快速应用开发:提高业务响应速度与竞争力的关键

如今&#xff0c;企业想要持续蓬勃发展&#xff0c;就需要具备快速满足客户期望的能力。无论是十几年历史的重要市场占有者推出新的APP&#xff0c;还是在疫情期间从线下转向线上电商营销&#xff0c;企业都需要主动适应市场。随着为客户提供新的服务方式&#xff0c;员工也需要…

以前年度资产价值导入以及汇率自动计算解决方案

文章目录 1 Introduction2 Code3 Summary 1 Introduction In the example we will finish ABLDT function and modify asset value . 2 Code DATA: key TYPE bapi1022_key,generaldata TYPE bapi1022_feglg001,generaldatax TYPE bapi1…

品牌突围|内容营销「共创公式」全面讲解

为什么品牌要扎根小红书&#xff1f;除了种草投放&#xff0c;品牌还能做些什么&#xff1f; 在小红书&#xff0c;迎接消费者共创的时代&#xff0c;激活品牌营销的无限潜能。 拥抱多元 在新机遇中预见未来 2023年&#xff0c;各大社交媒体平台涌现出了许多热点&#xff0c…

keil使用教程

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据 总结 前言 例如&#xff1a;随着人工智能的不断发展&#xff0c;机器学习这门技术也越来越重…

jupyter python笔记杂乱

问题产生的原因: 无法执行sess.run()的原因是tensorflow版本不同导致的&#xff0c;tensorflow版本2.0无法兼容版本1.0 解决办法: tf.compat.v1.disable_eager_execution() 确保tf’2能运行tf1的代码 notebok打开指定文件夹 直接解决