机器学习深度学习——自注意力和位置编码(数学推导+代码实现)

news2025/1/17 5:52:10

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——注意力分数(详细数学推导+代码实现)
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

自注意力和位置编码

  • 引入
  • 自注意力
    • 多头注意力
    • 基于多头注意力实现自注意力
  • 比较CNN、RNN和self-attention
    • 结论
    • 剖析——CNN
    • 剖析——RNN
    • 剖析——self-attention
    • 总结
  • 位置编码
    • 绝对位置信息
    • 相对位置信息
  • 小结

引入

在深度学习中,经常使用CNN和RNN对序列进行编码。有了自注意力之后,我们将词元序列输入注意力池化中,以便同一组词元同时充当查询、键和值。具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出。由于查询、键和值来自同一组输入,因此被称为自注意力(self-attention)。下面将使用自注意力进行序列编码。

import math
import torch
from torch import nn
from d2l import torch as d2l

自注意力

给定一个由词元组成的序列:
x 1 , . . . , x n 其中任意 x i ∈ R d x_1,...,x_n\\ 其中任意x_i∈R^d x1,...,xn其中任意xiRd
该序列的自注意力输出为一个长度相同的序列:
y 1 , . . . , y n 其中 y i = f ( x i , ( x 1 , x 1 ) , . . . , ( x n , x n ) ) ∈ R d y_1,...,y_n\\ 其中y_i=f(x_i,(x_1,x_1),...,(x_n,x_n))∈R^d y1,...,yn其中yi=f(xi,(x1,x1),...,(xn,xn))Rd
自注意力就是这样,任意的xi都是既当key,又当value,还当query。
下面的代码片段是基于多头注意力对一个张量完成自注意力的计算,张量形状为(批量大小,时间步数目或词元序列长度,d)。输出与输入的张量形状相同。
而在此之前,简单讲解下多头注意力,接着基于多头注意力实现自注意力。

多头注意力

当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系。因此允许注意力机制组合使用查询、键和值的不同子空间表示是有益的。
因此,与其只使用一个注意力池化,我们可以独立学习得到h组不同的线性投影来变换查询、键和值。然后,这h组变换后的查询、键和值将并行地送到注意力池化中。最后将这h个注意力池化的输出拼接在一起,并通过另一可以学习的线性投影进行变换,来产生最终输出。这就是多头注意力(multihead attention),如下图所示:
在这里插入图片描述
而多头注意力的实现过程通常使用的是缩放点积注意力来作为每一个注意力头,我们设定:
p q = p k = p v = p o / h p_q=p_k=p_v=p_o/h pq=pk=pv=po/h
值得注意的是,如果将查询、键和值的线性变化的输出数量设置为:
p q h = p k h = p v h = p o p_qh=p_kh=p_vh=p_o pqh=pkh=pvh=po
就可以并行计算h个头,下面代码中的po是通过num_hiddens指定的。

代码如下:

#@save
class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状:
        # (batch_size,查询或者“键-值”对的个数,num_hiddens)
        # valid_lens 的形状:
        # (batch_size,)或(batch_size,查询的个数)
        # 经过变换后,输出的queries,keys,values 的形状:
        # (batch_size*num_heads,查询或者“键-值”对的个数,
        # num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # output的形状:(batch_size*num_heads,查询的个数,
        # num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

#@save
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
    # num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

基于多头注意力实现自注意力

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()

可以输出验证一下:

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
print(attention(X, X, X, valid_lens).shape)

输出结果:

torch.Size([2, 4, 100])

比较CNN、RNN和self-attention

首先看这个图:
在这里插入图片描述
接下来进行CNN、RNN以及self-attention三个架构的比较,首先这三个架构目标都是要将n个词元组成的序列映射到另一个长度相同的序列,其中的每个输入词元或输出词元都由d维向量表示。我们的比较将基于计算的复杂性、顺序操作和最大路径长度,先给出结论再进行剖析解释。
我们首先要知道,顺序操作会妨碍并行计算,而任意的序列位置组合之间的路径越短,则能更轻松地学习序列中的远距离依赖关系。

结论

计算复杂度并行度最大路径长度
CNNO(knd2)O(n)O(n/k)
RNNO(nd2)O(1)O(n)
self-attentionO(n2d)O(n)O(1)

剖析——CNN

考虑一个卷积核大小为k的卷积层,由于序列长度是n,输入和输出的通道数量都是d,所以卷积层的计算复杂度为O(knd2)。而如上图所示,可以看出CNN网络是分层的,因此会有O(1)个顺序操作,那么这代表着通道可以并行执行n个词元,那么并行度就是O(n)。
上图中可以看出k=3,因为这样刚好就使得x1和x5处于这个卷积核大小为3的双层卷积神经网络的感受野内。因此最大的路径长度一定是不会超过n/k的,下标为n的也会因为卷积核被限制到一个感受野内,因此可以知道最大路径长度为O(n/k)。

剖析——RNN

当更新RNN的隐状态时,d×d权重矩阵和d维隐状态的乘法计算复杂度为O(d2),再加上序列长度为n,因此RNN的计算复杂度为O(nd2),由上图也可以看出n个序列的顺序操作是没办法并行化的,则并行度为O(1),最大路径长度是O(n)(可以理解成当我们要组合y1和yn的时候,这时候长度为n)。

剖析——self-attention

查询、键、值都是n×d矩阵。计算过程为:n×d矩阵乘以d×n矩阵,之后得到的n×n矩阵再乘以n×d矩阵,因此自注意力有O(n2d)的计算复杂度。而上图展示了自注意力的强大,O(n)的并行度显而易见,同时最大路径长度是O(1),因为他们可以任意组合。

总结

总而言之,卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。
但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。

位置编码

在处理词元序列时,循环神经网络是逐个的重复地处理词元的,而自注意力则因为并行计算而放弃了顺序操作。为了使用序列的顺序信息,通过在输入表示中添加位置编码来注入绝对的或相对的位置信息。
位置编码可以通过学习得到也可以直接固定得到,下面讲解基于正弦函数和余弦函数的固定位置编码。
假设输入表示X∈Rn×d包含一个序列中n个词元的d维嵌入表示。位置编码使用相同形状的位置嵌入矩阵P∈Rn×d输出X+P,矩阵第[i,2j](偶数列)和[i,2j+1](奇数列)列上的元素为:
p i , 2 j = s i n ( i 1000 0 2 j / d ) , p i , 2 j + 1 = c o s ( i 1000 0 2 j / d ) p_{i,2j}=sin(\frac{i}{10000^{2j/d}}),\\ p_{i,2j+1}=cos(\frac{i}{10000^{2j/d}}) pi,2j=sin(100002j/di),pi,2j+1=cos(100002j/di)
看起来很奇怪,在后面讲解的时候就能看出来了,先定义一个类来实现它:

#@save
class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

我们可以进行打印图像,可以清晰看到6、7列比8、9列频率高,而6与7(8与9同理)由于正余弦函数的相位交替,而导致偏移量不同。

encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])
d2l.plt.show()

运行结果:
在这里插入图片描述

绝对位置信息

其实就是二进制了,想象一下0-7的二进制表示是各不相同的,而且容易知道:较高比特位的交替频率低于较低比特位(而使用三教函数的话输出的是浮点数,显然会更省空间)。

相对位置信息

除了捕获绝对位置信息之外,上述的位置编码还允许模型学习得到输入序列中相对位置信息。这是因为对于任何确定的位置偏移σ,位置i+σ处的位置编码可以线性投影位置i处的位置编码来表示。
用数学来表示:
令 w j = 1 / 1000 0 2 j / d ,对于任何确定的位置偏移 σ : [ c o s ( σ w j ) s i n ( σ w j ) − s i n ( σ w j ) c o s ( σ w j ) ] [ p i , 2 j p i , 2 j + 1 ] = [ c o s ( σ w j ) s i n ( i w j ) + s i n ( σ w j ) c o s ( i w j ) − s i n ( σ w j ) s i n ( i w j ) + c o s ( σ w j ) c o s ( i w j ) ] = [ s i n ( ( i + σ ) w j ) c o s ( ( i + σ ) w j ) ] ——积化和差 = [ p i + σ , 2 j p i + σ , 2 j + 1 ] 令w_j=1/10000^{2j/d},对于任何确定的位置偏移σ:\\ \begin{bmatrix} cos(σw_j)&sin(σw_j)\\ -sin(σw_j)&cos(σw_j) \end{bmatrix} \begin{bmatrix} p_{i,2j}\\ p_{i,2j+1} \end{bmatrix}\\ =\begin{bmatrix} cos(σw_j)sin(iw_j)+sin(σw_j)cos(iw_j)\\ -sin(σw_j)sin(iw_j)+cos(σw_j)cos(iw_j) \end{bmatrix}\\ =\begin{bmatrix} sin((i+σ)w_j)\\ cos((i+σ)w_j) \end{bmatrix}——积化和差\\ =\begin{bmatrix} p_{i+σ,2j}\\ p_{i+σ,2j+1} \end{bmatrix} wj=1/100002j/d,对于任何确定的位置偏移σ[cos(σwj)sin(σwj)sin(σwj)cos(σwj)][pi,2jpi,2j+1]=[cos(σwj)sin(iwj)+sin(σwj)cos(iwj)sin(σwj)sin(iwj)+cos(σwj)cos(iwj)]=[sin((i+σ)wj)cos((i+σ)wj)]——积化和差=[pi+σ,2jpi+σ,2j+1]
2×2投影矩阵不依赖于任何位置的索引i。

小结

1、在自注意力中,查询、键和值都来自同一组输入。
2、卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。
3、为了使用序列的顺序信息,可以通过在输入表示中添加位置编码,来注入绝对的或相对的位置信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/888491.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

导读-Linux简介

Linux简介 ​ 总所周知,计算机系统包含硬件和软件两部分。硬件部分被称为裸机,主要包括中央处理器(CPU)、内存、外存和各种外部设备。软件部分主要包括系统软件和应用软件两部分。系统软件包括操作系统、汇编语言、编译程序、数据…

Leetcode Top 100 Liked Questions(序号53~74)

53. Maximum Subarray 题意:一个数组,找到和最大的子串 我的思路 我记得好像On的动态规划来做的?但是想不起来了,先死做,用的前缀和——TLE超时 那就只能想想dp怎么做了 假设dp[i]表示的是以 i 为右端点的最大的…

在医疗行业数字孪生能做些什么?

数字孪生技术随着发展正在多行业遍地开花,在之前的文章中也为大家介绍过数字孪生的行业应用,今天带大家一起探讨一下数字孪生在医疗行业的表现。其实数字孪生在医疗行业已有很多应用案例,从医疗诊断到手术模拟,再到药物研发&#…

android内存分析工具记录,请利用好最后2个神器

相机见证了java内存暴增和native持续增长的问题,因此这里记录一下使用的工具情况,方便后续继续使用 一、java 内存 如果是java层的内存可以直接借助leakCanary工具,配置也很简单,直接在build.gradle中添加依赖即可: …

Java语言怎么输出有颜色的字符串呢?

在Java中,我们应该如何输出有颜色的文字字符串呢? 目录 一、使用方法 二、举例说明 三、常见的颜色及其对应的ANSI转义序列 一、使用方法 在Java中,可以使用ANSI转义序列来改变输出文本的颜色。 二、举例说明 (1&#xff…

【mysql报错解决】MySql.Data.MySqlClient.MySqlException (0x80004005)或1366

场景:c#使用mysql数据库执行数据库迁移,使用了新增inserter的语句,然后报错 报错如下: 1.MySql.Data.MySqlClient.MySqlException (0x80004005): Incorrect string value: ‘\xE6\x9B\xB4\xE6\x94\xB9…’ for column ‘Migratio…

LVS 负载均衡集群

集群 集群(Cluster)是一组相互连接的计算机或服务器,它们通过网络一起工作以完成共同的任务或提供服务。集群的目标是通过将多台计算机协同工作,提高计算能力、可用性、性能和可伸缩性,适用于大量高并发的场景。 集群…

安科瑞变电所运维平台在电力系统中应用分析

摘要:现代居民生活、工作对电力资源的需求量相对较多,给我国的电力产业带来了良好的发展机遇与挑战。探索电力系统基本构成, 将变电运维安全管理以及相应的设备维护工作系统性开展,能够根据项目实践工作要求,将满足要求…

【讨论】视频监控集中存储方案如何做?

视频监控集中存储是指将多个视频监控摄像头所捕捉到的视频信号集中存储于一个中央设备,这个中央设备可以是服务器、网络存储设备或其他专用设备。通过集中存储,可以避免因为存储设备分散而导致的管理不便和难以有效地管理和检索视频数据,同时…

iTOP-RK3568开发板ubuntu环境下安装Eclipse

eclipse 是使用 Java 语言开发的,一个 Java 应用程序,这意味着 eclipse 只能运行在 Java虚拟机上。倘若没有安装 JDK(Java Development Kit),即使在 ubuntu 上安装了 eclipse,也不能运行,所以要…

软文发布问题解答:高效宣传与推广指南

以下是一秒推小编针对软文发布的20个常见问题及回答: 1. 什么是软文? 答:软文是指用文学手法、写作技巧撰写的宣传文章,以实现对特定受众的陈述、说明和推销。 2. 发布软文的目的是什么? 答:发布软文的目…

奥威BI数据可视化工具:360度呈现数据,告别枯燥表格

随着企业数据量的不断增加,如何有效地进行数据分析与决策变得越来越重要。奥威BI数据可视化工具作为一款强大的数据分析工具,在帮助企业深入挖掘数据价值方面具有显著优势。 奥威BI数据可视化工具是一款基于数据仓库技术的数据分析工具,具有…

RTT(RT-Thread)ADC设备(RTT保姆级介绍)

目录 ADC设备 前言 ADC相关参数说明 访问ADC设备 配置ADC设备 ADC实例 硬件设计 软件设计 ADC设备 前言 ADC(Analog-to-Digital Converter) 指模数转换器。是指将连续变化的模拟信号转换为离散的数字信号的器件。 对于ADC的详细介绍和在STM32中的裸机应用可参考以下…

C语言暑假刷题冲刺篇——day2

目录 一、选择题 二、编程题 🎈个人主页:库库的里昂 🎐CSDN新晋作者 🎉欢迎 👍点赞✍评论⭐收藏✨收录专栏:C语言每日一练 ✨其他专栏:代码小游戏C语言初阶🤝希望作者的文章能对你…

基于OFDM+64QAM系统的载波同步matlab仿真,输出误码率,星座图,鉴相器,锁相环频率响应以及NCO等

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 2.1 OFDM原理 2.2 64QAM调制 2.3 载波同步 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 MATLAB2022a 3.部分核心程序 ............................................…

图数据库_Neo4j学习cypher语言_使用CQL_构建明星关系图谱_导入明星数据_导入明星关系数据_创建明星关系---Neo4j图数据库工作笔记0009

首先找到明星数据 可以看到有一个sheet1,是,记录了所有的关系的数据 然后比如我们搜索一个撒贝宁,可以看到撒贝宁的数据 然后这个是构建的CQL语句 首先我们先去启动服务 neo4j console 然后我们再来看一下以前导入的,可以看到导入很简单, 就是上面有CQL 看一下节点的属性

RTT(RT-Thread)IIC设备

目录 IIC设备 IIC介绍 电气连接 IIC总线时序 IIC协议 读协议 写协议 访问I2C总线设备 查找 I2C 总线设备 I2C数据读写(数据传输) 配置IIC步骤 IIC设备 IIC介绍 I2C(Inter Integrated Circuit)总线是 PHILIPS 公司开发…

【vue3】对axios进行封装,方便更改路由并且可以改成局域网ip访问(附代码)

对axios封装是在main.js里面进行封装,因为main.js是一个vue项目的入口 步骤: 在1处创建一个axios实例为http,baseURL是基础地址(根据自己的需求写),写了这个在vue界面调用后端接口时只用在post请求处写路由…

【Git】(三)回退版本

1、git reset命令 1.1 回退至上一个版本 git reset --hard HEAD^ 1.2 将本地的状态回退到和远程的一样 git reset --hard origin/master 注意:谨慎使用 –-hard 参数,它会删除回退点之前的所有信息。HEAD 说明:HEAD 表示当前版本HEAD^ 上…

深入探究Linux黑客渗透测试:方法、工具与防御

💂 个人网站:【工具大全】【游戏大全】【神级源码资源网】🤟 前端学习课程:👉【28个案例趣学前端】【400个JS面试题】💅 寻找学习交流、摸鱼划水的小伙伴,请点击【摸鱼学习交流群】 引言 随着信息技术的迅…