MultiHeadAttention多头注意力机制的原理

news2024/9/22 13:38:23

MultiHeadAttention多头注意力作为Transformer的核心组件,其主要由多组自注意力组合构成。

1. self-Attention自注意力机制

在NLP任务中,自注意力能够根据上下文词来重新构建目标词的表示,其之所以被称之为注意力,在于从上下文词中去筛选目标词更需要关注的部分,比如"他叫小明","他"这个词更应该关注"小明"这个上下文。

上图提示了一个输入为两个单词[Thinking, Matchines]的序列在经过自注意力构建后的变换过程:

  • 通过Embeding层,两个单词的one-hot向量转换为embedding向量X=[x1, x2]
  • 通过三组矩阵运算得到query、key、value值,这三组矩阵的输入都是原来同一个输入向量[x1,x2],这也是被称之为自注意力的原因。

       \\ Q=\begin{bmatrix} q_1\\ q_2 \end{bmatrix}_{2\times d_q}=\begin{bmatrix} x_1\\ x_2 \end{bmatrix}_{2\times d_x} *W^Q_{d_x \times d_q}\\ K=\begin{bmatrix} k_1\\ k_2 \end{bmatrix}_{2\times d_k}=\begin{bmatrix} x_1\\ x_2 \end{bmatrix}_{2\times d_x} *W^K_{d_x \times d_k}\\ V=\begin{bmatrix} v_1\\ v_2 \end{bmatrix}_{2\times d_v}=\begin{bmatrix} x_1\\ x_2 \end{bmatrix}_{2\times d_x} *W^K_{d_x \times d_v}

  • 计算query、key间的相似度得分,为了提升计算效率,此处采用缩放点积注意力,其需要query、key向量的维度是相等的,并且都满足零均值和单位方差,此时得分表示:

       \\ score(q, k)=\frac{q\cdot k }{\sqrt{d_k}}\\ Score(Q, K)_{2\times 2}=\begin{bmatrix} s_{11} & s_{12}\\ s_{21} & s_{22}\end{bmatrix}_{2 \times 2}=\frac{1}{\sqrt{d_k}}\begin{bmatrix} q1 && q1\\ q2 && q2 \end{bmatrix}_{2\times d_q}\begin{bmatrix} k1 & k2 \\ k1 & k2 \end{bmatrix}_{d_q \times 2}

  • 对相似度得分矩阵求softmax进行归一化(按axis=1维进行),在实际中由于进行transformer中的输入序列要求是定长的,因此会有补余向量,此时这里softmax会有一个掩蔽操作,将补余部分都置为0。

       softmax(\begin{bmatrix} \xrightarrow[]{s_{11} \ s_{12}}\\ \xrightarrow[]{s_{21} \ s_{22}}\end{bmatrix})=\begin{bmatrix} p_{11} & p_{12}\\ p_{21} & p_{22} \end{bmatrix}

  • 乘以value向量得到输出z:

      Z=\begin{bmatrix} z_1 \\ z_2 \end{bmatrix}= \begin{bmatrix} p_{11} & p_{12}\\ p_{21} & p_{22} \end{bmatrix}\begin{bmatrix} v_1 \\ v_2 \end{bmatrix}

完成了结构的分析后,接下来,我们考虑一个新的问题,为什么自注意力机制会有效?通过三组矩阵Q、K、V我们获得了原来的输入三种不同表征形式,其通过query-key的比较来衡量目标词和上下文词的相似性关联,通过value来提取词的本质特征,最终通过自注意力机制,我们建立了结合上下文信息的词的新特征向量,其本质是特征提取器。

2. MultiHeadAttention多头注意力机制

多头注意力是多组自注意力构件的组合,上文已经提到自注意力机制能帮助建立包括上下文信息的词特征表达,多头注意力能帮忙学习到多种不同类型的上下文影响情况,比如"今天阳光不错,适合出去跑步",在不同情景下,"今天"同"阳光"、"跑步"的相关性是不同,特别是头越多,越有利于捕获更大更多范围的相关性特征,增加模型的表达能力。

 上图描述了多头注意力的处理过程,其实际上将多个自注意机制的产出再经过参数矩阵得到一个新输出。我们将上述自注意步骤引入多头情况,介绍如何通过矩阵来计算,其由3组自注意力组合,输入为2个单词的序列。

  • query、key、value表征向量的计算

       \\ \begin{bmatrix} q^1_1 & q^2_1 & q^3_1\\ q^1_2 & q^2_2 & q^3_2\\ \end{bmatrix}_{2 \times 3d_q}=\begin{bmatrix} x_1 \\ x_2 \end{bmatrix}_{2 \times d_x}W^Q_{d_x \times 3d_q}\\ \begin{bmatrix} k^1_1 & k^2_1 & k^3_1\\ k^1_2 & k^2_2 & k^3_2\\ \end{bmatrix}_{2 \times 3d_k}=\begin{bmatrix} x_1 \\ x_2 \end{bmatrix}_{2 \times d_x}W^K_{d_x \times 3d_k}\\ \begin{bmatrix} v^1_1 & v^2_1 & v^3_1\\ v^1_2 & v^2_2 & v^3_2\\ \end{bmatrix}_{2 \times 3d_v}=\begin{bmatrix} x_1 \\ x_2 \end{bmatrix}_{2 \times d_x}W^V_{d_x \times 3d_v}

  • 计算query、key间的相似度得分

       score(Q,K)=\begin{bmatrix} s^1_{11} & s^1_{12}\\ s^1_{21} & s^1_{22}\\s^2_{11} & s^2_{12}\\s^2_{21} & s^2_{22}\\s^3_{11} & s^3_{12}\\s^3_{21} & s^3_{22} \end{bmatrix}_{3\cdot 2 \times 2}=\frac{1}{\sqrt{d_q}}\begin{bmatrix} q^1_1 & q^1_1 \\ q^1_2 & q^1_2 \\ q^2_1 & q^2_1\\ q^2_2 & q^2_2 \\ q^3_1 & q^3_1 \\ q^3_2 & q^3_2\end{bmatrix}_{3\cdot 2 \times 2}\begin{bmatrix} k^1_1 & k^1_2 \\ k^1_1 & k^1_2 \\ k^2_1 & k^2_2\\ k^2_1 & k^2_2 \\ k^3_1 & k^3_2 \\ k^3_1 & k^3_2\end{bmatrix}_{3\cdot 2 \times 2}

  • 对相似度得分矩阵求softmax
  • 乘以value向量得到各自注意力模块输出,并乘以输出权重矩阵得到最终输出矩阵O,其最终还是得到了多头注意力的输出,其d_o为输出词向量维度,如果其维度等于输入词向量维度时,输出和输入的尺度是一致的,因此多头注意力机制本质仍是特征抽取器。

        \\ Z=\begin{bmatrix} z^1_1 \\ z^1_2\\z^2_1 \\ z^2_2\\z^3_1 \\ z^3_2 \end{bmatrix}= \begin{bmatrix} p^1_{11} & p^1_{12}\\ p^1_{21} & p^1_{22} \\ p^2_{11} & p^2_{12}\\ p^2_{21} & p^2_{22} \\ p^3_{11} & p^3_{12}\\ p^3_{21} & p^3_{22} \end{bmatrix}\begin{bmatrix} v_1 \\ v_2 \\ v_1 \\ v_2 \\ v_1 \\ v_2 \end{bmatrix}\\ O_{2\times d_o}=\begin{bmatrix} z^1_1 & z^2_1 & z^3_1 \\ z^1_2 & z^2_2 & z^3_2 \end{bmatrix}_{2\times 3d_z}W^o_{3d_z \times d_o}

3. MultiHeadAttention多头注意力机制的代码

上图左侧为单点积注意力Dot-Product Attention组件的结构(当Q,K,V为同一输入时称之为自注意力),右侧为多个单注意力组件组成多头注意力,以下是paddle的实现代码:

class DotProductAttention(nn.Layer):
    def __init__(self, query_size, query_vec_size, hidden_vec_size, **kwargs):
        """
        input:
            query_size: int, 序列长度(词数)
            query_vec_size: int, 词向量长度
            hidden_vec_size: int, 输出词向量长度
        """
        super(DotProductAttention, self).__init__(**kwargs)
        self.query_size = query_size
        self.query_vec_size = query_vec_size
        self.hidden_vec_size = hidden_vec_size
        # 线性变换层
        self.W_q = nn.Linear(query_vec_size, hidden_vec_size)
        self.W_k = nn.Linear(query_vec_size, hidden_vec_size)
        self.W_v = nn.Linear(query_vec_size, hidden_vec_size)

    def forward(self, queries, keys, values, valid_lens):
        """
        input:
            queries,keys,values: tensor([batch_size, query_size, query_vec_size]), 输入
            valid_lens: tensor([batch_size]), 序列中有效长度
        output:
            output: tensor([batch_size, query_size, hidden_vec_size]), 输出
        """ 
        #1. Linear: queries, keys, values的线性变换, out shape: [batch_size, query_size, hidden_vec_size]
        queries = self.W_q(queries)
        keys = self.W_k(keys)
        values = self.W_v(values)
        # 2. score, shape: [batch_size, query_size, query_size]
        scores = paddle.bmm(queries, keys.transpose((0, 2, 1))) / math.sqrt(self.hidden_vec_size)
        scores = scores.reshape([-1, self.query_size])
        # 3. mask, shape: [batch_size * query_size, query_size]
        mask = paddle.arange(self.query_size, dtype=paddle.float32)[None, :] < paddle.repeat_interleave(valid_lens, self.query_size)[:, None]
        scores[~mask] = float(-1e6)
        # 4. softmax [batch_size, query_size, query_size]
        scores = scores.reshape([-1, self.query_size, self.query_size])
        scores = nn.functional.softmax(scores, axis=-1)
        # 5. output [batch_size, query_size, query_size] * [batch_size, query_size, hidden_vec_size]
        return paddle.bmm(scores, values)

在实际中,多个单注意力组件的计算可以通过同一矩阵进行并行计算,如第2节所描述,以下完成最终多头注意力的代码,可以看出其同单注意力的代码几乎差不多:

class MultiHeadAttention(nn.Layer):
    def __init__(self, query_size, query_vec_size, hidden_vec_size, output_vec_size, head_num, **kwargs):
        """
        input:
            query_size: int, 序列长度(词数)
            query_vec_size: int, 词向量长度
            hidden_vec_size: int, 变换层词向量长度
            output_vec_size: int, 输出层词向量长度
            head_num: int, 头数
        """
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.query_size = query_size
        self.query_vec_size = query_vec_size
        self.hidden_vec_size = hidden_vec_size
        self.output_vec_size = output_vec_size
        self.head_num = head_num
        # 线性变换层
        self.W_q = nn.Linear(query_vec_size, hidden_vec_size * head_num)
        self.W_k = nn.Linear(query_vec_size, hidden_vec_size * head_num)
        self.W_v = nn.Linear(query_vec_size, hidden_vec_size * head_num)
        self.W_o = nn.Linear(hidden_vec_size * head_num, output_vec_size)

    def forward(self, queries, keys, values, valid_lens):
        """
        input:
            queries,keys,values: tensor([batch_size, query_size, query_vec_size]), 输入
            valid_lens: tensor([batch_size]), 序列中有效长度
        output:
            output: tensor([batch_size, query_size, hidden_vec_size]), 输出
        """ 
        #1. Linear: queries, keys, values的线性变换, out shape: [batch_size, query_size, hidden_vec_size * head_num]
        queries = self.W_q(queries)
        keys = self.W_k(keys)
        values = self.W_v(values)
        # 2. score, shape: [batch_size * head_num, query_size, query_size]
        queries = queries.reshape([-1, self.query_size, self.hidden_vec_size, self.head_num])\
                         .transpose((0, 3, 1, 2))\
                         .reshape([-1, self.query_size, self.hidden_vec_size])
        keys = keys.reshape([-1, self.query_size, self.hidden_vec_size, self.head_num])\
                   .transpose((0, 3, 1, 2))\
                   .reshape([-1, self.query_size, self.hidden_vec_size])
        values = values.reshape([-1, self.query_size, self.hidden_vec_size, self.head_num])\
                   .transpose((0, 3, 1, 2))\
                   .reshape([-1, self.query_size, self.hidden_vec_size])
        scores = paddle.bmm(queries, keys.transpose((0, 2, 1))) / math.sqrt(self.hidden_vec_size)
        scores = scores.reshape([-1, self.query_size])
        # 3. mask, shape: [batch_size * head_num * query_size, query_size]
        mask = paddle.arange(self.query_size, dtype=paddle.float32)[None, :] < paddle.repeat_interleave(valid_lens, self.query_size * self.head_num)[:, None]
        scores[~mask] = float(-1e6)
        # 4. softmax [batch_size, query_size * head_num, query_size]
        scores = scores.reshape([-1, self.head_num, self.query_size, self.query_size])
        scores = nn.functional.softmax(scores, axis=-1)
        # 5. output [batch_size, query_size, head_num * hidden_vec_size]
        z = paddle.bmm(scores.reshape([-1, self.query_size, self.query_size]), values)
        z = z.reshape([-1, self.head_num, self.query_size, self.hidden_vec_size]).transpose((0, 2, 1, 3))
        # 6. output linear
        return self.W_o(z.reshape([-1, self.query_size, self.head_num * self.hidden_vec_size]))

4. 为什么要用注意力机制

按论文Attention Is All You Need的观点,上图为 self-attention同cnn、rnn在复杂度上的比较,其中n是指序列的长度,d是序列词向量的维度,k表示卷积核的大小,我们为什么将self-attention作为序列数据特征编码或解码器,主要基于三点理由:

  • Complexity per Layer 计算量:计算量是指每层的计算量,可以看出当词向量的维度大于序列长度时,self-attention的计算量是要更小的。
  • Sequential Operations 并行实现:self-attention完全可以通过矩阵运算来实现并行计算。
  • Maximum Path Length 上下文依赖特征获取难易:比如序列中第1个词同最后1个词,上下文依赖特征获取难易是指要获取这两个词相互特征需要经过多少步,对于RNN网络需要经过序列长度的步,而对于CNN同卷积核和层数有关,对于self-attention只需要1步就能建立第1个词同最后1个词的关系。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/440458.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Spring6】| Spring6集成MyBatis3.5

目录 一&#xff1a;Spring6集成MyBatis3.5 第一步&#xff1a;准备数据库表 第二步&#xff1a;IDEA中创建一个模块&#xff0c;并引入依赖 第三步&#xff1a;基于三层架构实现&#xff0c;所以提前创建好所有的包 第四步&#xff1a;编写pojo 第五步&#xff1a;编写m…

【Redis数据库】异地公网远程登录连接Redis教程

文章目录 1. Linux(centos8)安装redis数据库2. 配置redis数据库3. 内网穿透3.1 安装cpolar内网穿透3.2 创建隧道映射本地端口 4. 配置固定TCP端口地址4.1 保留一个固定tcp地址4.2 配置固定TCP地址4.3 使用固定的tcp地址连接 转发自CSDN远程穿透的文章&#xff1a;公网远程连接R…

Java阶段二Day05

Java阶段二Day05 文章目录 Java阶段二Day05截至此版本可实现的流程图为V14UserControllerClientHandlerDispatcherServletHttpServletResponseHttpServletRequest V15DispatcherServletHttpServletResponseHttpServletRequest V16HttpServletRequestHttpServletResponse 反射JA…

SpringCloud整合AOP做日志管理

目录 1、前置知识2、步骤2.1、依赖2.2、自定义注解&#xff0c;用于注解式AOP2.3、定制切面类2.4、测试 1、前置知识 切面&#xff08;Aspect&#xff09;&#xff1a;官方的抽象定义为“一个关注点的模块化&#xff0c;这个关注点可能会横切多个对象”&#xff0c;在本例中&a…

超详细Redis入门教程——Redis命令(上)

前言 本文小新为大家带来 超详细Redis入门教程——Redis命令&#xff08;上&#xff09; 相关知识&#xff0c;具体内容包括Redis 基本命令&#xff0c;Key 操作命令&#xff0c;String 型 Value 操作命令&#xff0c;Hash 型 Value 操作命令&#xff0c;List 型 Value 操作命令…

快速搭建外卖配送服务:利用外卖系统源码实现

外卖配送服务已经成为了现代消费者生活的一部分&#xff0c;它不仅方便了消费者的用餐需求&#xff0c;也给商家提供了新的销售渠道&#xff0c;同时也为外卖配送员提供了更多的就业机会。为了满足这个市场的需求&#xff0c;外卖系统源码应运而生。 外卖系统源码是一个集成了…

第一章:数、式、方程与方程组

1.实数 1.内容概述 1.了解实数分类2.数轴3.相反数和倒数4.绝对值5.算数平方根相关概念及有关计算2.实数分类 3.实数的基本概念 1.数轴:规定原点、正方向和单位长度的直线叫做数轴2.相反数:绝对值相同而符号相反的两个数,互称相反数3.倒数:1除以任何数的商,我们叫做倒数,0…

超市购物系统【GUI/Swing+MySQL】(Java课设)

系统类型 Swing窗口类型Mysql数据库存储数据 使用范围 适合作为Java课设&#xff01;&#xff01;&#xff01; 部署环境 jdk1.8Mysql8.0Idea或eclipsejdbc 运行效果 本系统源码地址&#xff1a;https://download.csdn.net/download/qq_50954361/87682510 更多系统资源库…

Jenkins ssh windows 部署 java程序

版权说明&#xff1a; 本文由博主keep丶原创&#xff0c;转载请保留该段内容在文章头部。 原文地址&#xff1a; https://blog.csdn.net/qq_38688267/article/details/130203785 文章目录 前言实现步骤1. windows下载安装ssh2. windows 安装 winsw2.1 下载 winsw2.2 配置winsw2…

Linux 0.11启动过程分析(一)

Linux 0.11 系列文章 Linux 0.11启动过程分析&#xff08;一&#xff09; Linux 0.11 fork 函数&#xff08;二&#xff09; Linux0.11 缺页处理&#xff08;三&#xff09; Linux0.11 根文件系统挂载&#xff08;四&#xff09; Linux0.11 文件打开open函数&#xff08;五&…

[oeasy]python0132_变量含义_meaning_声明_declaration_赋值_assignment

变量定义 回忆上次内容 上次回顾了一下历史 python 是如何从无到有的看到 Guido 长期的坚持和努力 编程语言的基础都是变量声明 python是如何声明变量的呢&#xff1f; 变量 想要定义变量首先明确什么是变量 变量就是数值能变的量英文名称 variable 计算机在内存中分配出…

SpringBoot Starter 作用及原理

本文会以 mybatis 为例&#xff0c;通过对比 mybatis-spring 和 mybatis-spring-boot-starter 代码示例&#xff0c;了解 Starter 的作用。并对 mybatis-spring-boot-starter 进行简单剖析&#xff0c;了解 Starter 原理。 下面还有投票&#xff0c;一起参与进来吧&#x1f44d…

DataEase看中国 - 中国影星“成龙”电影票房数据分析

背景介绍 说起成龙&#xff0c;我们并不陌生&#xff0c;著名的动作明星。以武打动作片出道&#xff0c;凭借动作片《红番区》打入好莱坞&#xff0c;该片打破北美外语片票房纪录。 目前&#xff0c;由成龙、郭麒麟等主演的新片《龙马精神》正在公映&#xff0c;电影《…

【每日一练】JAVA算法求柱状图中最大的矩形面积

文章目录 前言题目分析算法实战1、创建算法方法2、创建测试用例3、查看测试结果 写在最后 前言 作为一名以JAVA语言为主的搬砖人&#xff0c;学习掌握好函数语法很重要&#xff0c;但是算法也是需要掌握的。今天我们就分享一个求柱状图中最大的矩形面积的题目&#xff0c;这个…

torch.utils.data.DataLoader中的next(iter(train_dataloader))

在做实验时&#xff0c;我们常常会使用用开源的数据集进行测试。而Pytorch中内置了许多数据集&#xff0c;这些数据集我们常常使用DataLoader类进行加载。 如下面这个我们使用DataLoader类加载torch.vision中的FashionMNIST数据集。 from torch.utils.data import DataLoader …

数据结构入门(C语言)顺序表的增删查改

目录 前言1. 顺序表的概念2. 动态顺序表2.1 顺序表的初始化与销毁2.2 顺序表的尾插容量检查2.3 顺序表的尾删2.4 顺序表的头插2.5 顺序表的头删2.6 固定位置的插入2.7 固定位置的删除2.8 查找和打印2.9 修改元素主函数部分(菜单) 结语 前言 本章会用C语言来描述数据结构中的顺…

协同运力、算力、存力,加速迈向智能世界

2023年4月20日&#xff0c;华为在HAS2023期间举办“迈向智能世界”主题论坛&#xff0c;吸引了来自全球的分析师、专家学者及媒体与会。会上&#xff0c;华为ICT战略与Marketing总裁彭松发表了“持续技术创新&#xff0c;加速迈向智能世界”的主题演讲。 华为ICT战略与Marketin…

zabbix监控linux主机

1.本实验使用centos7主机&#xff0c;IP地址为10.1.60.115&#xff0c;firewalld和selinux服务已关闭 2.下载zabbix yum源(与zabbix server用一样的版本) rpm -Uvh https://repo.zabbix.com/zabbix/5.0/rhel/7/x86_64/zabbix-release-5.0-1.el7.noarch.rpm 3.安装zabbix客户…

玛雅水上乐园|玩趣系列作品集

玛雅水上乐园曾经是一座历史悠久的玛雅金字塔&#xff0c;曾用于宗教和水上航行&#xff0c;被废弃了 3000 多年。现在&#xff0c;01a1 工作室已将其改造成一个令人兴奋的旅游景点&#xff0c;在这里你可以享受美食和饮料&#xff0c;享受日光浴&#xff0c;并结交新朋友。所以…

从零学习SDK(8)SDK的集成和部署

选择使用SDK与其他平台和服务进行集成和部署的好处有&#xff1a; 简化开发流程&#xff0c;节省时间和成本&#xff0c;无需从零开始编写复杂的代码逻辑。 保证功能的稳定性和兼容性&#xff0c;避免出现各种潜在的错误和问题。 享受SDK提供方的技术支持和更新&#xff0c;获…