图解多头注意力机制:维度变化一镜到底

news2025/4/17 4:17:02



一、多头注意力机制概述

多头注意力(Multi-Head Attention)是Transformer模型的核心组件,其核心思想是通过 ‌并行处理多个子空间‌ 来捕捉序列中不同位置间的复杂依赖关系。主要特点:

  • 并行计算:将高维向量拆分为多个低维子空间
  • 多视角学习:每个注意力头关注不同特征模式
  • 高效性:矩阵运算高度可并行化

在这里插入图片描述

二、代码实现

1. pyTorch 实现
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        """
        Args:
            embed_dim: 词向量维度(如512)
            num_heads: 注意力头数量(如8)
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # 每个头的维度(如512//8=64)
        
        assert self.head_dim * num_heads == embed_dim, "维度不可整除"
        
        # 定义线性变换层
        self.query = nn.Linear(embed_dim, embed_dim)  # Q矩阵
        self.key = nn.Linear(embed_dim, embed_dim)    # K矩阵
        self.value = nn.Linear(embed_dim, embed_dim)  # V矩阵
        self.out = nn.Linear(embed_dim, embed_dim)    # 输出层

    def transpose_for_scores(self, x):
        """拆分多头并调整维度顺序
        输入: [batch_size, seq_len, embed_dim]
        输出: [batch_size, num_heads, seq_len, head_dim]
        """
        new_shape = x.size()[:-1] + (self.num_heads, self.head_dim)
        x = x.view(*new_shape)  # 新增头维度
        return x.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]

    def forward(self, query, key, value, mask=None):
        """前向传播流程
        输入形状: [batch_size, seq_len, embed_dim]
        输出形状: [batch_size, seq_len, embed_dim]
        """
        batch_size = query.size(0)
        
        # 1. 线性变换
        Q = self.query(query)  # [N, seq, D]
        K = self.key(key)      # [N, seq, D]
        V = self.value(value)  # [N, seq, D]

        # 2. 拆分多头
        Q = self.transpose_for_scores(Q)  # [N, h, seq, d]
        K = self.transpose_for_scores(K)  # [N, h, seq, d] 
        V = self.transpose_for_scores(V)  # [N, h, seq, d]

        # 3. 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1))  # [N, h, seq_q, seq_k]
        scores /= math.sqrt(self.head_dim)  # 缩放
        
        # 4. 应用掩码(可选)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
            
        # 5. 计算注意力权重
        attn_weights = F.softmax(scores, dim=-1)  # [N, h, seq_q, seq_k]
        
        # 6. 应用权重到Value
        out = torch.matmul(attn_weights, V)  # [N, h, seq_q, d]
        
        # 7. 合并多头
        out = out.permute(0, 2, 1, 3).contiguous()  # [N, seq_q, h, d]
        out = out.view(batch_size, -1, self.embed_dim)  # [N, seq, D]
        
        # 8. 输出层
        return self.out(out), attn_weights
2. tensorFlow实现
# TensorFlow (兼容TF2.x)

import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense

class MultiHeadAttention(Layer):
    def __init__(self, embed_dim, num_heads):
        """
        Args:
            embed_dim: 词向量维度(如512)
            num_heads: 注意力头数量(如8)
        """
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        assert self.head_dim * num_heads == embed_dim, "维度不可整除"
        
        # 定义线性变换层
        self.query_dense = Dense(embed_dim)
        self.key_dense = Dense(embed_dim)
        self.value_dense = Dense(embed_dim)
        self.output_dense = Dense(embed_dim)
        
    def split_heads(self, x, batch_size):
        """拆分多头并调整维度顺序
        输入: [batch_size, seq_len, embed_dim]
        输出: [batch_size, num_heads, seq_len, head_dim]
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.head_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    def call(self, query, key, value, mask=None):
        batch_size = tf.shape(query)
        
        # 1. 线性变换
        Q = self.query_dense(query)  # [N, seq, D]
        K = self.key_dense(key)      # [N, seq, D]
        V = self.value_dense(value)  # [N, seq, D]
        
        # 2. 拆分多头
        Q = self.split_heads(Q, batch_size)  # [N, h, seq, d]
        K = self.split_heads(K, batch_size)  # [N, h, seq, d]
        V = self.split_heads(V, batch_size)  # [N, h, seq, d]
        
        # 3. 计算注意力分数
        matmul_qk = tf.matmul(Q, K, transpose_b=True)  # [N, h, seq_q, seq_k]
        scaled_attention_logits = matmul_qk / tf.math.sqrt(tf.cast(self.head_dim, tf.float32))
        
        # 4. 应用掩码(可选)
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)  # 添加极大负值
        
        # 5. 计算注意力权重
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        
        # 6. 应用权重到Value
        output = tf.matmul(attention_weights, V)  # [N, h, seq_q, d]
        
        # 7. 合并多头
        output = tf.transpose(output, perm=[0, 2, 1, 3])  # [N, seq_q, h, d]
        concat_attention = tf.reshape(output, (batch_size, -1, self.embed_dim))
        
        # 8. 输出层
        return self.output_dense(concat_attention), attention_weights

三、维度变化全流程详解

1. 参数设定
  • batch_size = 2
  • seq_len = 5
  • embed_dim = 512
  • num_heads = 8
  • head_dim = 512 // 8 = 64
2. 维度变化流程图
原始输入: [2, 5, 512]
    │
    ├─线性变换───────保持形状→ [2, 5, 512]
    │
    ├─拆分多头──────→ [2, 8, 5, 64]
    │                (拆分512为8个64维头)
    │
    ├─计算注意力分数──→ [2, 8, 5, 5]
    │                (每个头计算5x5的注意力矩阵)
    │
    ├─Softmax───────→ [2, 8, 5, 5]
    │                (最后一维归一化)
    │
    ├─应用权重到Value→ [2, 8, 5, 64]
    │                (每个头输出新的序列表示)
    │
    ├─合并多头───────→ [2, 5, 512]
    │                (拼接8个64维头恢复512维)
    │
    └─输出层────────→ [2, 5, 512]
3. 关键步骤维度变化

在这里插入图片描述

四、关键实现细节解析

1. 多头拆分与合并
# 拆分多头(核心代码)
new_shape = x.size()[:-1] + (num_heads, head_dim)
x = x.view(*new_shape).permute(0, 2, 1, 3)

# 合并多头(逆过程)
x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, embed_dim)
  • 为什么要permute:将num_heads维度提前,便于后续矩阵乘法并行处理多个头
2. 注意力分数计算
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
  • 转置维度‌:将K的seq_len和head_dim维度交换,使矩阵乘法满足[seq_q, d] x [d, seq_k] → [seq_q, seq_k]
  • 缩放因子‌:防止点积结果过大导致softmax梯度消失
3. 掩码处理技巧

python

scores = scores.masked_fill(mask == 0, -1e9)
  • 作用‌:将填充位置(如)的注意力权重趋近于0
  • 为什么用-1e9‌:经过softmax后,exp(-1e9) ≈ 0

五、完整运行示例

# 测试用例
embed_dim = 512
num_heads = 8
model = MultiHeadAttention(embed_dim, num_heads)

# 生成测试数据
batch_size = 2
seq_len = 5
inputs = torch.randn(batch_size, seq_len, embed_dim)

# 前向传播
output, attn = model(inputs, inputs, inputs)

# 验证输出形状
print(output.shape)  # torch.Size([2, 5, 512])
print(attn.shape)    # torch.Size([2, 8, 5, 5])

六、总结与常见问题

1. 核心优势
  • 并行计算效率‌:通过矩阵运算同时处理所有位置和注意力头
  • 多视角学习‌:不同注意力头可关注语法、语义等不同特征
  • 长距离依赖‌:直接计算任意两个位置间的关联
2. FAQ
  • Q1:为什么需要多个注意力头?‌

  • A:类比CNN中多个卷积核,不同头可以捕捉不同类型的特征依赖

  • Q2:head_dim为什么要设置为embed_dim/num_heads?‌

  • A:保持总参数量不变,确保拆分前后的维度乘积相等(num_heads * head_dim = embed_dim)

  • Q3:permute之后为什么要调用contiguous()?‌

  • A:确保张量在内存中连续存储,避免后续view操作报错

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2316719.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[ISP] 人眼中的颜色

相机是如何记录颜色的,又是如何被显示器还原的? 相机通过记录RGB数值然后显示器显示RGB数值来实现颜色的记录和呈现。道理是这么个道理,但实际上各厂家生产的相机对光的响应各不相同,并且不同厂家显示器对三原色的显示也天差地别&…

解锁MySQL 8.0.41源码调试:Mac 11.6+CLion 2024.3.4实战指南

文章目录 解锁MySQL 8.0.41源码调试:Mac 11.6CLion 2024.3.4实战指南前期准备环境搭建详细步骤安装 CLion安装 CMake 3.30.5准备 MySQL 8.0.41 源码配置 CMake 选项构建 MySQL 项目 调试环境配置与验证配置 LLDB 调试器启动调试验证调试环境 总结与拓展 解锁MySQL 8…

关于xcode Project navigator/项目导航栏的一些说明

本文基于 xcode12.4 版本做说明 首先要明确一点,导航栏这里展示的并不是当前工程在电脑硬盘中的文件结构,它展示的是xxxxxx.xcodeproj/project.pbxproj文件(后文简.pbxproj文件)中的内容。我们在导航栏中的操作就是修改该文件,有些操作会修…

深度解析扣减系统设计:从架构到实践

背景 在当今数字化业务蓬勃发展的时代,扣减系统在众多业务场景中扮演着关键角色。无论是电商平台的库存扣减,还是金融领域的资金扣减、积分系统的积分扣减,一个高效、可靠且数据一致的扣减系统都是业务稳健运行的基石。本文将深入探讨扣减系…

视觉定位项目中可以任意修改拍照点位吗?

修改拍照点位不是那么简单 1. 背景2. 修改拍照点位意味着什么?3. 如何解决这个问题? 1. 背景 在视觉定位的项目中,会遇到这么一种情况:完成三步(9点标定,旋转中心标定,示教基准)之…

深度学习常用操作笔记

深度学习常用操作笔记 指令报错cannot import name Config from mmcvImportError: cannot import name print_log from mmcvImportError: cannot import name init_dist from mmengine.runnerWARNING: Retrying (Retry(total4, connectNone, readNone, redirectNone, statusNon…

C++学习内存管理

1.概念的介绍 总括: 1. 栈(Stack) 存储内容: 局部变量(包括函数参数、非静态局部变量)。 函数调用的上下文信息(如返回地址、寄存器状态等)。 特点: 内存由编译器自动…

git使用。创建仓库,拉取分支,新建分支开发

文章目录 安装 git自己新建仓库,进行代码管理合作开发的流程拉去主分支代码查看本地分支的状态查看远程分支查看远程的仓库信息本地分支切换切换并创建分支提交代码 made by NJITZX git 是一个版本控制工具,真正开发项目中是多个人开发一个项目的&#…

itsdangerous加解密源码分析|BUG汇总

这是我这两天的思考 早知道密码学的课就不旷那么多了 纯个人见解 如需转载,标记出处 目录 一、官网介绍 二、事例代码 源码分析: 加密函数dump源码使用的函数如下: 解密 ​编辑 ​编辑 关于签名: 为什么这个数字签名没有…

不像人做的题————十四届蓝桥杯省赛真题解析(上)A,B,C,D题解析

题目A:日期统计 思路分析: 本题的题目比较繁琐,我们采用暴力加DFS剪枝的方式去做,我们在DFS中按照8位日期的每一个位的要求进行初步剪枝找出所有的八位子串,但是还是会存在19月的情况,为此还需要在CHECK函数…

JavaScript 中 call 和 apply 的用法与区别

文章目录 前言一、 call 方法1.1 基本用法1.2 传递多个参数 二、apply 方法2.1 基本用法2.2 传递数组参数 三、call 和 apply 的区别四、实际应用场景4.1 借用方法4.2 继承与构造函数 五、总结 前言 在 JavaScript 中,call 和 apply 是两个非常重要的函数方法&…

面试系列|蚂蚁金服技术面【1】

哈喽,大家好!今天分享一下蚂蚁金服的 Java 后端开发岗位真实社招面经,复盘面试过程中踩过的坑,整理面试过程中提到的知识点,希望能给正在准备面试的你一些参考和启发,希望对你有帮助,愿你能够获…

使用傅里叶变换测量声卡的频率失真

文章目录 一、说明二、关于声卡的技术详述三、实验代码获取四、结论 一、说明 假如我希望使用我的声卡来模拟软件无线电,利用声音而不是射频信号。我的声卡能胜任这项任务吗?本文将研究一种技术来找出答案。另外,需要了解音频技术的读者也可…

【HTML5】01-HTML摆放内容

本文介绍HTML5摆放标签的知识点。 目录 1. HTML概念 2. HTML骨架 3. 标签的关系 4. 标题标签 5. 段落标签 6. 换行和水平线 7. 文本格式化标签 8. 图像标签 图像 - 属性 9. 路径 相对路径 绝对路径 10. 超链接标签 11. 音频标签 12. 视频标签 1. HTML概念 HTM…

内存管理:

我们今天来学习一下内存管理: 1. 内存分布: 我们先来看一下我们下面的图片: 这个就是我们的内存,我们的内存分为栈区,堆区,静态区,常量区; 我们的函数栈帧开辟消耗的内存就是我们…

设计模式使用Java案例

代码设计要有可维护性,可复用性,可扩展性,灵活性,所有要使用设计模式进行灵活设计代码 创建型 简单工厂模式(Simple Factory) 简单工厂模式(Simple Factory Pattern)是一种创建型…

模运算的艺术:从基础到高阶的算法竞赛应用

在算法竞赛中,模运算(取模运算)是一个非常重要的概念,尤其在处理大数、防止溢出、以及解决与周期性相关的问题时。C 中的模运算使用 % 运算符,但它的行为和使用场景需要特别注意。 1. 模运算的基本概念 模运算是指求一…

ST电机库电流采样 三电阻单ADC

一、概述 下图是三电阻采样的电路结构 其中流过三相系统的电流I1、I2、I3遵循以下关系: 因此,为了重建流过普通三相负载的电流,在我们可以用以上公式计算的情况下,只需要对三相中的两相进行采样即可。 STM32的ADC可以很灵活的配置成同步采集两路ADC数据,…

现代密码学 | 具有保密和认证功能的安全方案

1.案例背景 1.1 2023年6月,微软云电子邮件泄露 事件描述: 2023年6月,属于多家美国政府机构的微软云电子邮件账户遭到非法入侵,其中包括了多位高级政府官员的电子邮件。据报道,美国国务院的10个邮件账户中共有6万封电…

一款基于Python的从常规文档里提取图片的简单工具开发方案

一款基于Python的从常规文档里提取图片的简单工具开发方案 1. 环境准备 安装必需库 pip install python-docx PyMuPDF openpyxl beautifulsoup4 pillow pip install pdfplumber # PDF解析备用方案 pip install tk # Python自带,无需安装工具选择 开发环…