深度学习中的多头注意力机制:原理与实现解析

news2024/11/14 12:04:19

4. Multi-Head Attention

深度学习中的多头注意力机制:原理与实现解析

在自然语言处理和计算机视觉的任务中,多头注意力(Multi-Head Attention)已经成为Transformer模型中必不可少的组成部分。多头注意力机制不仅能够让模型关注到输入的不同方面,还能更好地捕获词语间复杂的上下文关系。今天,我们将深入解析多头注意力的原理与实现!


为什么需要多头注意力?

单一的注意力头只能捕获句子中的一种关系或模式,而在实际应用中,句子中的不同词语往往有复杂的关系。多头注意力通过并行多个注意力头,让模型能够关注到输入的多个不同层面,从而更全面地理解输入内容。每个头会从不同的角度捕捉句子中的依赖关系,有助于提升模型的表达能力和对上下文的理解。


多头注意力的工作原理

1. 生成 Q、K、V 矩阵

多头注意力机制的输入是三个矩阵:Query(查询)矩阵 QKey(键)矩阵 KValue(值)矩阵 V,每个矩阵都包含输入序列的信息:

  • Query(Q):代表要关注的内容
  • Key(K):输入特征标签,用于表示每个词的特征
  • Value(V):实际包含的内容信息

2. 多头注意力的计算步骤

假设我们有一个输入向量 x x x h h h 个注意力头,每个头的步骤如下:

  1. 线性变换:对输入向量 x x x 进行线性变换,生成 Q , K , V Q, K, V Q,K,V 三个矩阵。每个注意力头有自己的权重矩阵,这使得每个头都可以从不同的视角理解输入。

  2. 计算注意力权重:通过点积注意力计算每个 Query 和 Key 之间的相似度,用 softmax 得到注意力权重,公式如下:

    $$

    \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \cdot V

    $$

    其中 d k d_k dk 是 Key 的维度,用于缩放,防止数值过大。

  3. 并行计算多个头:对每个头进行相同的计算。每个头的注意力权重不同,这使得每个头可以关注不同的上下文关系。

  4. 合并输出:将多个头的输出拼接,生成最终的多头注意力结果。通常通过线性变换将结果映射回原来的维度。


多头注意力公式

假设我们有 h h h 个注意力头,每个头的输出为 Attention i ( Q i , K i , V i ) \text{Attention}_i(Q_i, K_i, V_i) Attentioni(Qi,Ki,Vi) ,最终的多头注意力输出为:

$$

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) \cdot W^O

$$

其中:

  • head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)
  • W i Q , W i K , W i V W_i^Q, W_i^K, W_i^V WiQ,WiK,WiV 是每个头的线性变换矩阵。
  • W O W^O WO 是最终输出的线性映射矩阵,用于将拼接结果映射回原始维度。

自己实现多头注意力类

接下来我们通过代码实现一个简单的 MultiHeadAttention 类,以更好地理解多头注意力机制的实现细节。

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # 确保嵌入维度能整除头数
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        # 定义 Q、K、V 的线性层
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        N = query.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # 将 Q、K、V 分成多个头
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # 计算注意力得分
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) / (self.head_dim ** (1/2))
        attention = torch.softmax(energy, dim=3)

        # 计算注意力输出
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        # 拼接头的输出,并通过最后的线性层
        out = self.fc_out(out)
        return out


代码解析

  • 初始化:定义了输入的维度、头数、每个头的维度,并创建了用于生成 Q、K、V 的线性层。
  • 分割多头:将输入 Q、K、V 按头数分割,使得每个头能独立计算注意力。
  • 计算注意力得分:通过点积计算 Q 和 K 之间的相似度,并使用 softmax 获得注意力权重。
  • 输出计算:将每个头的权重与 V 相乘,拼接各个头的输出,最后通过线性层映射到原始维度。

测试代码

我们可以通过以下测试代码验证 MultiHeadAttention 的输出是否正常。

embed_size = 256
heads = 8
seq_len = 10
x = torch.rand((3, seq_len, embed_size))  # 假设 batch size 为 3,序列长度为 10

multihead_attention = MultiHeadAttention(embed_size, heads)
output = multihead_attention(x, x, x)
print("多头注意力输出形状:", output.shape)

你会看到输出的形状为 (3, seq_len, embed_size),这与输入形状一致,验证了多头注意力的效果。


总结

  • 多头注意力是对单头注意力的扩展,可以让模型从多个角度捕获输入序列中的复杂关系。
  • 每个头独立生成 Q、K、V,并通过点积计算相似度,从而获得多样化的上下文信息。
  • 多头注意力在自然语言处理和计算机视觉任务中广泛应用,有助于模型更全面地理解输入数据。

希望通过这篇文章的讲解与代码示例,能帮助你理解多头注意力的原理与实现。如果有任何疑问,欢迎留言讨论!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2238933.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【测试框架篇】单元测试框架pytest(3):用例执行参数详解

一、前言 上一篇内容介绍了用例编写的规则以及执行用例,执行用例时我们发现有些print输出内容,结果没有给我们展示,这是因为什么原因呢?接下来我们会针对这些问题进行阐述。 二、参数大全 我们可以在cmd中通过输入 pytest -h 或…

设计模式-七个基本原则之一-开闭原则 + SpringBoot案例

开闭原则:(SRP) 面向对象七个基本原则之一 对扩展开放:软件实体(类、模块、函数等)应该能够通过增加新功能来进行扩展。对修改关闭:一旦软件实体被开发完成,就不应该修改它的源代码。 要看实际场景,比如组内…

Android Room框架使用指南

Room框架使用指南 项目效果创建应用,配置Gradle1、在app Module的build.gradle配置kapt插件2、配置依赖:3、配置依赖包版本号创建实体类创建DAO1、DAO简介2、WordDao设计以及相关注解说明3、监听数据变化添加Room数据库1、Room数据库简介2、实现Room数据库实现存储库实现View…

前端开发中常用的包管理器(npm、yarn、pnpm、bower、parcel)

文章目录 1. npm (Node Package Manager)2. Yarn (Yarn Package Manager)3. pnpm4. Bower5. Parcel总结 前端开发中常用的包管理器主要有以下几个: 1. npm (Node Package Manager) 简介: npm 是 Node.js 的默认包管理器,也是最广泛使用的包…

C++builder中的人工智能(23):在现代C++ Windows上轻松录制声音

在这篇文章中,我们将探讨如何在现代C Windows上轻松录制声音。声音以波形和数字形式存在,其音量随时间变化。在C Builder中,使用Windows设备进行录音非常简单。要录制声音,在多设备应用程序中,必须使用FMX.Media.hpp头…

科目一汇总笔记2024

知识点,一天看一遍;提前一周即可;真实考试比“驾校宝典”模拟题简单。 1 知识点汇总 2 错题总结 增驾1轻 2中 3重 能见度 200 100 50 速度60 40 20 两条车道是:100 60 三条车道是:110 90 60 四条车道是:110 90 90 60 高速小车最高120其…

【详细】如何优雅地删除 Docker 容器与镜像

内容预览 ≧∀≦ゞ 镜像与容器的区别删除容器和镜像的具体步骤1. 删除容器步骤 1:查看当前运行的容器步骤 2:停止容器步骤 3:删除容器 2. 删除镜像步骤 1:查看镜像列表步骤 2:删除镜像 3. 删除所有容器和镜像 使用 1Pa…

华为eNSP:AAA认证(pap和chap)telnet/ssh

pap模式 一、拓扑图 二、配置过程 1、这个型号路由器是不带串口的,所以需要添加串口板卡 2、加入串行接口卡槽 右击路由,选择设置,将串口板卡拖动到路由器扩展槽,并开机即可 3、认证方路由器配置 [r8]aaa #进入aaa认证 [r8-a…

HCIP—快速生成树协议(RSTP)实验配置

一、回顾STP和STP的缺点和不足 1.STP的概述: STP(生成树协议)是一种用于在网络中防止产生环路的链路管理协议。 2.STP的作用: 解决二层环路,防止广播报文产生。但是网络拓扑收敛较慢,影响通信质量。 3…

qt QSyntaxHighlighter详解

1、概述 QSyntaxHighlighter是Qt文本处理框架中的一个强大工具,它专门用于实现文本编辑器中的语法高亮功能。通过自定义高亮规则,QSyntaxHighlighter可以实现对代码编辑器、富文本编辑器中的关键字、注释等内容的高亮显示。这一功能对于提升代码的可读性…

PyQt5 加载UI界面与资源文件

步骤一: 使用 Qt Designer 创建 XXX.ui文件 步骤二: 使用 Qt Designer 创建 资源文件 步骤三: Python文件中创建相关类, 使用 uic.loadUi(mainwidget.ui, self ) 加载UI文件 import sys from PyQt5 import QtCore, QtWidgets, uic from PyQt5.QtCore import Qt f…

国家级财经类211/985学科院校招收申请制硕士

国家级财经类211/985学科院校招收申请制硕士 ◎免试入学,边学边考,申硕便捷; ●1.5-2年制,无需辞职,远程学习; ◎考试方式灵活,可多次申考; ●申请考核制,学信网报名注…

Spring Boot - 扩展点 EnvironmentPostProcessor源码分析及真实案例

文章目录 概述EnvironmentPostProcessor 作用EnvironmentPostProcessor 实现和注册创建类并实现接口注册到 Spring Boot常见应用场景 源码分析1. EnvironmentPostProcessor 接口定义2. 扩展点加载流程3. 加载 EnvironmentPostProcessor 实现类4. EnvironmentPostProcessor 执行…

解决表格出现滚动条样式错乱问题

自定义表格出现滚动条时,会因为宽度不对等导致样式错乱; 解决思路: 监听表格数据的变化,当表格出现滚动条时,再调用更新宽度的方法updateWidth,去改变表格头部的宽度,最终保持表格头部和内容对…

.NET中通过C#实现Excel与DataTable的数据互转

在.NET框架中,使用C#进行Excel数据与DataTable之间的转换是数据分析、报表生成、数据迁移等操作中的常见需求。这一过程涉及到将Excel文件中的数据读取并加载至DataTable中,以便于利用.NET提供的丰富数据处理功能进行操作,同时也包括将DataTa…

albert模型实现微信公众号虚假新闻分类

项目源码获取方式见文章末尾! 600多个深度学习项目资料,快来加入社群一起学习吧。 《------往期经典推荐------》 项目名称 1.【基于CNN-RNN的影像报告生成】 2.【卫星图像道路检测DeepLabV3Plus模型】 3.【GAN模型实现二次元头像生成】 4.【CNN模型实现…

java的JJWT 0.91在jdk21中报错的解决方法

参考了很多其他人的办法,只有这种方式可以解决问题 JSON Web Token(缩写 JWT) 目前最流行、最常见的跨域认证解决方案,前端后端都需要会使用的东西 如果根据黑马的视频,导入了阿里云OSS的相关依赖,自然不会…

最高提升20倍吞吐量!豆包大模型团队发布全新 RLHF 框架,现已开源!

文章来源|豆包大模型团队 强化学习(RL)对大模型复杂推理能力提升有关键作用,然而,RL 复杂的计算流程以及现有系统局限性,也给训练和部署带来了挑战。传统的 RL/RLHF 系统在灵活性和效率方面存在不足&#x…

云计算:定义、类型及对企业的影响

💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 云计算:定义、类型及对企业的影响 云计算:定义、类型及对企业的影响 云计算:定义、类型及对企…

RSTP的配置

RSTP相对于STP在端口角色、端口状态、配置BPDU格式、配置BPDU的处理方式、快速收敛机制、拓扑变更机制和4种保护特性方面的详细改进说明: 端口角色: STP中定义了三种端口角色:根端口(Root Port)、指定端口&#xff0…