自注意力和位置编码(比较卷积神经网络、循环神经网络和自注意力)

news2025/1/10 2:38:51
  • 在自注意力中,查询、键和值都来自同一组输入。

  • 卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。

  • 为了使用序列的顺序信息,可以通过在输入表示中添加位置编码,来注入绝对的或相对的位置信息。

  • 参考:10.6. 自注意力和位置编码 — 动手学深度学习 2.0.0 documentation

在深度学习中,经常使用卷积神经网络(CNN)或循环神经网络(RNN)对序列进行编码。 想象一下,有了注意力机制之后,我们将词元序列输入注意力池化中, 以便同一组词元同时充当查询、键和值。 具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出。 由于查询、键和值来自同一组输入,因此被称为 自注意力(self-attention) (Lin et al., 2017, Vaswani et al., 2017), 也被称为内部注意力(intra-attention) (Cheng et al., 2016, Parikh et al., 2016, Paulus et al., 2017)。 本节将使用自注意力进行序列编码,以及如何使用序列的顺序作为补充信息。

pip install mxnet==1.7.0.post1
pip install d2l==0.15.0
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

1.自注意力

根据 (10.2.4)中定义的注意力汇聚函数f。 下面的代码片段是基于多头注意力对一个张量完成自注意力的计算, 张量的形状为(批量大小,时间步的数目或词元序列的长度,d)。 输出与输入的张量形状相同。 

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()

batch_size, num_queries, valid_lens = 2, 4, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
attention(X, X, X, valid_lens).shape
(2, 4, 100)

 

2.比较卷积神经网络、循环神经网络和自注意力

接下来比较下面几个架构,目标都是将由n个词元组成的序列映射到另一个长度相等的序列,其中的每个输入词元或输出词元都由d维向量表示。具体来说,将比较的是卷积神经网络、循环神经网络和自注意力这几个架构的计算复杂性、顺序操作和最大路径长度。请注意,顺序操作会妨碍并行计算,而任意的序列位置组合之间的路径越短,则能更轻松地学习序列中的远距离依赖关系 (Hochreiter et al., 2001)。

 总而言之,卷积神经网络和自注意力都拥有并行计算的优势, 而且自注意力的最大路径长度最短。 但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。

3.位置编码

在处理词元序列时,循环神经网络是逐个的重复地处理词元的, 而自注意力则因为并行计算而放弃了顺序操作。 为了使用序列的顺序信息,通过在输入表示中添加 位置编码(positional encoding)来注入绝对的或相对的位置信息。 位置编码可以通过学习得到也可以直接固定得到。 接下来描述的是基于正弦函数和余弦函数的固定位置编码 (Vaswani et al., 2017)。

乍一看,这种基于三角函数的设计看起来很奇怪。 在解释这个设计之前,让我们先在下面的PositionalEncoding类中实现它。 

#@save
class PositionalEncoding(nn.Block):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = np.zeros((1, max_len, num_hiddens))
        X = np.arange(max_len).reshape(-1, 1) / np.power(
            10000, np.arange(0, num_hiddens, 2) / num_hiddens)
        self.P[:, :, 0::2] = np.sin(X)
        self.P[:, :, 1::2] = np.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].as_in_ctx(X.ctx)
        return self.dropout(X)

在位置嵌入矩阵P中, 行代表词元在序列中的位置,列代表位置编码的不同维度。 从下面的例子中可以看到位置嵌入矩阵的第6列和第7列的频率高于第8列和第9列。 第6列和第7列之间的偏移量(第8列和第9列相同)是由于正弦函数和余弦函数的交替。

encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.initialize()
X = pos_encoding(np.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(np.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in np.arange(6, 10)])

 

3.1绝对位置信息

为了明白沿着编码维度单调降低的频率与绝对位置信息的关系, 让我们打印出0,1,…,7的二进制表示形式。 正如所看到的,每个数字、每两个数字和每四个数字上的比特值 在第一个最低位、第二个最低位和第三个最低位上分别交替。

for i in range(8):
    print(f'{i}的二进制是:{i:>03b}')
0的二进制是:000
1的二进制是:001
2的二进制是:010
3的二进制是:011
4的二进制是:100
5的二进制是:101
6的二进制是:110
7的二进制是:111

 

在二进制表示中,较高比特位的交替频率低于较低比特位, 与下面的热图所示相似,只是位置编码通过使用三角函数在编码维度上降低频率。 由于输出是浮点数,因此此类连续表示比二进制表示法更节省空间。

#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.asnumpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);
P = np.expand_dims(np.expand_dims(P[0, :, :], 0), 0)
show_heatmaps(P, xlabel='Column (encoding dimension)',
                  ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')

 没搞出来

 

3.2相对位置信息

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/195546.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nostr with NIP-05 Verification Guide

What is a NIPNIPs (Nostr Implementation Possibilities) document what MUST, what SHOULD and what MAY be implemented by Nostr-compatible relay and client software. See a complete list of nips here.NIP-05 提案是针对用户 Nostr 帐户的验证方法,可以将其…

The update is not applicable to your computer

在安装windows CVE-2022-44698补丁的时候出现了报错,"Windows Update Standalone Installer The update is not applicable to your computer" 1.找到漏洞对应的官方文章 CVE-2022-44698 - Security Update Guide - Microsoft - Windows SmartScreen S…

学术科研无从下手?27 条机器学习避坑指南,让你的论文发表少走弯路

内容一览:如果你刚接触机器学习不久,并且未来希望在该领域开展学术研究,那么这份为你量身打造的「避坑指南」可千万不要错过了。 关键词:机器学习 科研规范 学术研究 机器学习学术小白,如何优雅避坑坑、让自己的论文顺…

力扣sql简单篇练习(九)

力扣sql简单篇练习(九) 1 合作过至少三次的演员和导演 1.1 题目内容 1.1.1 基本题目信息 1.1.2 示例输入输出 1.2 示例sql语句 SELECT actor_id,director_id FROM ActorDirector GROUP BY actor_id,director_id HAVING count(timestamp)>31.3 运行截图 2 患某种疾病的患…

结构体内存对齐;内存优化

结构体需要根据数据类型进行内存对齐。 所有数据类型占据空间的大小,一定是它基础类型的倍数。 首先按照最大的数据类型格式来作为最小分割单元。 最大整数倍 struct stu {char gender;unsigned int age; }student1;int main() {printf("sizeof this struct i…

SpringCloud Config分布式配置中心

目录 一、概述 二、Config服务端配置与测试 配置读取规则 三、Config客户端配置与测试 bootstrasp.yml 四、Config客户端之动态刷新 一、概述 官网:Spring Cloud Config 微服务意味着要将单体应用中的业务拆分成一个个子服务,每个服务的粒度相对…

【Java】IDEA调试线上服务

目录IEDA打开线上服务对应的代码,Edit Configuration创建与远程服务的连接复制黄匡生成的参数,添加到服务器启动命令中服务器的服务重新启动,并注意调试端口5005是否被防火墙拦截本地IDEA启动,控制台打印如图即成功代码上加断点&a…

Linux做选择题时的要点

1.线程独有:栈,寄存器,信号屏蔽字,errno...等信息,因此各个线程各自有各自的栈区,但是堆区共用。 2.用户态线程的切换在用户态实现,不需要内核支持。 3.每个线程在进程虚拟地址空间中会分配拥…

部署在Docker中的iServer进行服务迁移

目录前言一:备份与恢复1.备份2.恢复二:迁移配置文件作者:kxj 前言 Linux 容器虚拟技术(LXC,Linux Container)是一种轻量级的虚拟化手段,它利用内核虚拟化技术提供轻量级的虚拟化,来…

【八大数据排序法】希尔排序法的图形理解和案例实现 | C++

第十七章 希尔排序法 目录 第十七章 希尔排序法 ●前言 ●认识排序 ●一、希尔排序法是什么? 1.简要介绍 2.图形理解 3.算法分析 ●二、案例实现 1.案例一 ●总结 前言 排序算法是我们在程序设计中经常见到和使用的一种算法,它主要是将一…

C++引用(特性+使用场景+常引用)

文章目录1. 概念2. 关于别名的理解3. 引用的特性1.引用必须在定义时初始化2.一个变量可以有多个别名3.引用一旦引用一个实体,再不能引用其他实体4.使用场景1. 引用做参数2. 引用做返回值1. 传值返回是否为n直接返回临时变量作为返回值2. 传引用返回编译器傻瓜式判断…

11、循环语句

目录 一、while语句 二、do…while语句 三、for语句 一、while语句 使用while语句可以执行循环结构,其一般形式如下: while语句的执行流程图: while语句首先检验一个条件,也就是括号中的表达式。当条件为真时,就执…

跟同事杠上了,用雪花算法生成的id做主键对MySQL性能有影响?

公司最近开发了一个新项目,设计表时由于有些字段需要对外展示,所以使用了雪花算法生成的id做主键。 不过有位同事对此提出了异议,认为雪花算法生成的id不是顺序递增的,会对MySQL的性能造成影响。 经过交流,发现持有这…

【Linux 系统运维基础】Linux目录 以及重要配置文件

Linux目录 以及重要配置文件 文本讲述了Linux中目录含义 以及我们工作中常用到的路径 1. 目录含义 2. 常用路径地址 2.1 网卡配置文件 /etc/sysconfig/network-scripts但是网卡的名称是有区别的,使用不同服务器生产商的名称是不同的。如下图: 网卡配置…

Window10下FFMPEG的安装与使用

文章目录一.FFMPEG介绍FFMPEG组成二.Windows10下FFMPEG安装三.FFMPEG的使用1.关键指令一.FFMPEG介绍 FFmpeg是一套可以用来记录、转换数字音频、视频,并能将其转化为流的开源计算机程序。采用LGPL或GPL许可证。它提供了录制、转换以及流化音视频的完整解决方案。它…

字节前端面试题目2

1.为什么通常在发送数据埋点请求的时候使用的是 1x1 像素的透明 gif 图片? 1. 没有跨域问题,一般这种上报数据,代码要写通用的;(排除 ajax) 2. 不会阻塞页面加载,影响用户的体验,只…

基于SSM的图书购物商城设计与实现

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下,你想解决的问…

【C++】三大特性之继承

目录 一、继承的概念及定义 1.继承的概念 2. 继承定义 2.1定义格式 2.2继承关系和访问限定符 2.3继承基类成员访问方式的变化 二、基类和派生类对象赋值转换 三、继承中的作用域 四、派生类的默认成员函数 五、友元与继承 六、继承与静态成员 七、复杂的菱形继承及菱…

机器学习中的数学原理——线性不可分

这个专栏主要是用来分享一下我在 机器学习中的 学习笔记及一些感悟,也希望对你的学习有帮助哦!感兴趣的小伙伴欢迎 私信或者评论区留言!这一篇就更新一下《 白话机器学习中的数学——线性不可分》! 目录 一、什么是线性不可分 二…

用125行C语言编写一个简单的16位虚拟机

改博文用图文代码的方式详细描述了实现的具体过程,包含每一条指令的含义。 虚拟机 在计算领域,VM(虚拟机)是一个术语,指的是模拟/虚拟化计算机系统/架构的系统。 从广义上讲,有两类虚拟机: 系统…