Transformer 中自注意力机制的 一些细节理解

news2024/9/29 7:29:17

摘自知乎博主https://www.zhihu.com/question/362131975/answer/2182682685?utm_oi=78365163782144

作者:月来客栈
链接:https://www.zhihu.com/question/362131975/answer/2182682685

1. 多头注意力机制原理

1.1 动机

首先让我们先一起来看看作者当时为什么要提出Transformer这个模型?需要解决什么样的问题?现在的模型有什么样的缺陷?

1.1.1 面临问题

现在主流的序列模型都是基于复杂的循环神经网络或者是卷积神经网络构造而来的Encoder-Decoder模型,并且就算是目前性能最好的序列模型也都是基于注意力机制下的Encoder-Decoder架构。

由于传统的Encoder-Decoder架构在建模过程中,下一个时刻的计算过程会依赖于上一个时刻的输出,而这种固有的属性就限制了传统的Encoder-Decoder模型就不能以并行的方式进行计算。

所以作者会不停的提及这些传统的Encoder-Decoder模型

最近的工作通过因式分解技巧和条件计算显着提高了计算效率,同时还提高了后者的模型性能。然而,顺序计算的基本限制仍然存在。

1.1.2 解决思路

因此,作者首次提出了一种全新的Transformer架构来解决这一问题,如图1-2所示。当然,Transformer架构的优点在于它完全摈弃了传统的循环结构,取而代之的是只通过注意力机制来计算模型输入与输出的隐含表示,而这种注意力的名字就是大名鼎鼎的自注意力机制(self-attention),也就是图1-2中的Multi-Head Attention模块。


总体来说,所谓自注意力机制就是通过某种运算来直接计算得到句子在编码过程中每个位置上的注意力权重;然后再以权重和的形式来计算得到整个句子的隐含向量表示。最终,Transformer架构就是基于这种的自注意力机制而构建的Encoder-Decoder模型。

解释这句话

  1. “所谓自注意力机制就是通过某种运算来直接计算得到句子在编码过程中每个位置上的注意力权重”

    • 自注意力机制是一种计算方法,用来判断句子中每个词在表示句子含义时的重要性。
    • 在这个过程中,自注意力机制会计算句子中每个词对其他词的“关注度”(也就是注意力权重),这是一种通过数学运算(如点积、归一化等)来完成的。
  2. “然后再以权重和的形式来计算得到整个句子的隐含向量表示”

    • 这些注意力权重会被用于计算每个词的隐含表示(也就是词向量),并通过加权求和的方式整合起来,得到整个句子的表示。
    • 隐含向量是句子的数值化表示,用于之后的模型运算和预测。
  3. “最终,Transformer架构就是基于这种的自注意力机制而构建的Encoder-Decoder模型”

    • Transformer是一种神经网络架构,它的核心部分就是自注意力机制。
    • Transformer模型包括编码器(Encoder)和解码器(Decoder),通过自注意力机制来处理和理解输入数据(如句子)并生成输出(如翻译、文本生成等)。

总结来说,这句话的意思是:自注意力机制是通过计算句子中各个词的重要性来生成整个句子的表示,Transformer模型就是基于这种机制构建的用于处理语言任务的模型。

1.2 技术手段

自注意力机制 ,然后再来探究整体的网络架构。

1.2.1 什么是self-Attention

        就是“Scaled Dot-Product Attention“。注意力机制可以描述为将query和一系列的key-value对映射到某个输出的过程,而这个输出的向量就是根据query和key计算得到的权重作用于value上的权重和(Value加权和)。 需要结合Transformer的解码过程,才能更好地理解。

## 7. ScaledDotProductAttention
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        ## 输入进来的维度分别是 [batch_size x n_heads x len_q x d_k]  K: [batch_size x n_heads x len_k x d_k]  V: [batch_size x n_heads x len_k x d_v]
        ##首先经过matmul函数得到的scores形状是 : [batch_size x n_heads x len_q x len_k]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)

        ## 然后关键词地方来了,下面这个就是用到了我们之前重点讲的attn_mask,把被mask的地方置为无限小,softmax之后基本就是0,对q的单词不起作用
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        #一横行做softmax
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn

从上图和代码可以看出,自注意力机制的核心过程就是通过Q和K计算得到注意力权重;然后再作用于V得到整个权重和输出。具体的,对于输入Q、K和V来说,其输出向量的计算公式为:

  • Q(Query)的维度通常是\text{batch size} \times \text{sequence length} \times d_q   
  • K(Key)和 V(Value)维度类似
  • 其中,batch size 是一次输入的数据批量大小,sequence length 是输入序列的长度(即词的数量), d_q是每个词的 Query 向量的维度。d_k是每个词的 Key 向量的维度。d_v是每个词的 Value 向量的维度.
  • 多头之后,QKV 张量的维度 如下所示
  • 代码解释
  • K.transpose(-1, -2) 通过转置操作将 K 的最后两个维度对调,使得它的列数与 Q 的列数相同,从而能够进行点积计算。

为什么要进行缩放? 去掉方差的影响

是因为通过实验作者发现,对于较大的d_k来说,在完成^{^{^{^{}}}}QK^{T}后将会得到很大的值,而这将导致在经过sofrmax操作后产生非常小的梯度,不利于网络的训练。 (内积 逐元素相乘相加 ,dk维度越大 结果可能就偏大)

最后,文章中还通过随机变量的分析来解释为什么点积结果会变得很大:

借鉴了 (65 封私信 / 81 条消息) transformer的细节到底是怎么样的? - 知乎 (zhihu.com)

中的第二个回答
链接:https://www.zhihu.com/question/362131975/answer/3039107481
 

 我们拿 Q 矩阵中的任意一列 q, K 矩阵中的任意一行 k 出来

翻译:

内积qk的期望

E\left ( q_{i} \times k_{j} \right )=E\left ( q_{i} \right )\times E\left ( k_{j} \right ) =0\times 0=0

所以 根据计算公式可得   内积qk的期望:

E\left ( q\odot k \right )=E(\sum\left ( q_{i}\odot k _{j}\right ) )=0+0+...+0=0

内积qk的方差

 证明参考 https://www.jianshu.com/p/a97ece3459dc

https://www.cnblogs.com/yuyuanliu/p/15968716.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2052256.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

IP SSL证书快速申请教程

在互联网安全领域中,SSL证书是比较普遍的传输数据加密方式之一。SSL证书通过建立加密通道,确保客户端与服务器之间传输的数据不被第三方窃取或篡改。而大多数SSL证书,如单域名SSL证书、多域名SSL证书以及通配符SSL证书,在申请时必…

颇为实用的现代化开源数据表格GristCore

GristCore:用Grist,让数据自动化,让工作更智能。 - 精选真开源,释放新价值。 概览 Grist-core项目是Grist的心脏,是一个创新的在线数据协作平台,它突破了传统电子表格的局限,引入了先进的自动化…

宋仕强论道之效率与成本的关系

宋仕强论道之效率与成本的关系中说,效率于企业的意义重大,一是技术发展和应用带来效率提高,农耕文明与工业时代分别以铁制农具应用和电气化为标志。在现阶段,人工智能(AI)是目前最有效的新技术,…

Linux找回root密码,帮助指令

目录 找回root密码 帮助指令 man获得帮助指令 help指令 应用实例 找回root密码 进入开机界面,输入e进入编辑界面。 在指定位置输入init/bin/sh 再输入ctrlx进入单用户模式。 最后输入passwd修改密码。 帮助指令 man获得帮助指令 man ls后可以看到很多指令以…

【leetcode】相交链表-25-1

方法:遍历 /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode(int x) : val(x), next(NULL) {}* };*/ class Solution { public:ListNode *getIntersectionNode(ListNode *headA, ListNode *headB) {…

Gene_processing_system-v2.0使用之环境变量配置

Gene_processing_system-v2.0环境变量配置 在D盘路径解压上述文件《Gene_processing_system-v2.0.zip》,解压后,对内置Python3.9环境变量进行配置。操作如下: 环境变量配置 第一步:复制python3.9路径值,复制路径值为…

【MySQL】数据的基本操作(CRUD)

系列文章目录 例如:第一章 数据库基础 提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 系列文章目录前言一、对数据简单操作新增操作 (create)查询操作(select)模糊查询分页查询修改操作(update)删除操…

C 开源库之cJSON

cJSON简介 CJSON库是一个用于解析和生成JSON数据的C语言库。 它提供了一组函数,使得在C语言中操作JSON数据变得简单而高效。 您可以使用CJSON库来解析从服务器返回的JSON数据,或者将C语言数据结构转换为JSON格式以进行传输。 cJSON 使用 官网地址&…

JAVA同城找搭子同城交友系统小程序源码

🌈【同城搭子交友系统】—— 遇见你的城市小确幸✨ 👭 城市喧嚣中的温暖邂逅 在繁忙的都市生活中,你是否常常感到孤单,渴望有那么几个志同道合的朋友,一起探索这座城市的每一个角落?🏙️ 同城…

【机器学习-监督学习】逻辑斯谛回归

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈Python机器学习 ⌋ ⌋ ⌋ 机器学习是一门人工智能的分支学科,通过算法和模型让计算机从数据中学习,进行模型训练和优化,做出预测、分类和决策支持。Python成为机器学习的首选语言,…

【前端 23】用Less优化你的CSS书写体验

用Less优化你的CSS书写体验 Less(Leaner Style Sheets)是一种CSS预处理器,它扩展了CSS的功能,引入了变量、嵌套规则、混合(Mixins)、运算等特性,使得CSS编写更加灵活和强大。下面,我…

【虚拟化】KVM命令行安装linux虚拟机

目录 ​一、KVM概述 1.1 KVM是什么 1.2 KVM和QEMU的关系 1.3 kvm相关安装包及其作用 二、安装KVM 三、命令行创建虚拟机并安装CentOS7 四、报错处理 4.1 问题1 4.2 问题2 一、KVM概述 1.1 KVM是什么 KVM(Kernel-based Virtual Machine, 即内核级虚拟机) 是一个开源的系…

热门好用骨传导耳机怎么挑选?推荐这五款值得入手的骨传导耳机

近两年来,骨传导运动蓝牙耳机在运动领域内日益流行。与传统耳机相比,它的显著优势是能够保持双耳开放,不会堵塞耳道,消除了入耳式耳机可能引起的不适感。此外还能避免运动时耳内出汗可能导致的各种卫生和健康问题。很多人就问了&a…

3.4交换机端口安全配置的方法和步骤

一、设置端口安全性 switchport port-security 二、设置某端口的安全Mac地址 switchport port-security mac-address <mac 地址> 三、设置端口允许通过的最多mac地址数量 switchport port-security maximum<数量> 默认为1,通常最多1024个 四、检测到违规时的…

leetcode 堆栈(栈+优先队列)——java实现

java创建堆栈和操作函数 Queue<String> queue new LinkedList<String> ();//队列定义 Deque<String> stack new LinkedList<String>();//堆栈 队列方法&#xff1a; queue.offer(e) null queue.poll() 返回移除的值 queue.peek() 堆栈方法&#xff1…

从零开始学cv-8:直方图操作进阶:直方图匹配,局部直方图均衡化,彩色直方图均衡化

文章目录 一&#xff0c;简介二、直方图匹配三、局部直方图均衡化四、彩色直方图均衡化4.1 rgb彩色直方图均衡化4.2 ycrb 彩色直方图均衡化 一&#xff0c;简介 在上一篇文章中&#xff0c;我们探讨了直方图的基本概念&#xff0c;并详细讲解了如何利用OpenCV来查看图像直方图…

MATLAB 大场景建筑物点云提取方法实现(75)

MATLAB 大场景建筑物点云提取方法实现(75) 一、算法介绍二、算法实现1.代码2.效果展示总结一、算法介绍 本章手动实现了一种建筑物点云提取方法,可以对室外的大规模场景点云中的建筑物进行有效提取,下面是实现的效果和具体的实现方法,直接复制粘贴代码即可使用, 二、算…

【基础算法总结】多源 BFS_多源最短路问题

多源 BFS_多源最短路问题 1.多源 BFS_多源最短路问题2.01 矩阵3.飞地的数量4.地图中的最高点5.地图分析 点赞&#x1f44d;&#x1f44d;收藏&#x1f31f;&#x1f31f;关注&#x1f496;&#x1f496; 你的支持是对我最大的鼓励&#xff0c;我们一起努力吧!&#x1f603;&…

springboot纹理生成图片系统--论文源码调试讲解

第2章 程序开发技术 2.1 MySQL数据库 开发的程序面向用户的只是程序的功能界面&#xff0c;让用户操作程序界面的各个功能&#xff0c;那么很多人就会问&#xff0c;用户使用程序功能生成的数据信息放在哪里的&#xff1f;这个就需要涉及到数据库的知识了&#xff0c;一般来说…

Maven继承和聚合特性

目录 Maven继承关系 1.继承概念 父POM 子模块 2.继承机制 3.示例 4.继承作用 背景 需求 5.注意事项 Maven聚合关系 1. 定义与概念 2. 实现方式 3. 特性与优势 4. 示例 5. 注意事项 Maven继承关系 1.继承概念 Maven 继承是指在 Maven 的项目中&#xff0c;定义…