《动手学深度学习 Pytorch版》 10.6 自注意力和位置编码

news2025/1/10 11:52:33

在注意力机制中,每个查询都会关注所有的键-值对并生成一个注意力输出。由于查询、键和值来自同一组输入,因此被称为 自注意力(self-attention),也被称为内部注意力(intra-attention)。本节将使用自注意力进行序列编码,以及使用序列的顺序作为补充信息。

import math
import torch
from torch import nn
from d2l import torch as d2l

10.6.1 自注意力

在这里插入图片描述

给定一个由词元组成的输入序列 x 1 , … , x n \boldsymbol{x}_1,\dots,\boldsymbol{x}_n x1,,xn,其中任意 x i ∈ R d ( 1 ≤ i ≤ n ) \boldsymbol{x}_i\in\R^d\quad(1\le i\le n) xiRd(1in) 。该序列的自注意力输出为一个长度相同的序列 y 1 , … , y n \boldsymbol{y}_1,\dots,\boldsymbol{y}_n y1,,yn,其中:

y i = f ( x i , ( x 1 , x 1 ) , … , ( x n , x n ) ) ∈ R d \boldsymbol{y}_i=f(\boldsymbol{x}_i,(\boldsymbol{x}_1,\boldsymbol{x}_1),\dots,(\boldsymbol{x}_n,\boldsymbol{x}_n))\in\R^d yi=f(xi,(x1,x1),,(xn,xn))Rd

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,  # 基于多头注意力对一个张量完成自注意力的计算
                                   num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))  # 张量的形状为(批量大小,时间步的数目或词元序列的长度,d)。
attention(X, X, X, valid_lens).shape  # 输出与输入的张量形状相同
torch.Size([2, 4, 100])

10.6.2 比较卷积神经网络、循环神经网络和自注意力

在这里插入图片描述

  • 卷积神经网络

    • 计算复杂度为 O ( k n d 2 ) O(knd^2) O(knd2)

      • k k k 为卷积核大小

      • n n n 为序列长度是

      • d d d 为输入和输出的通道数量

    • 并行度为 O ( n ) O(n) O(n)

    • 最大路径长度为 O ( n / k ) O(n/k) O(n/k)

  • 循环神经网络

    • 计算复杂度为 O ( n d 2 ) O(nd^2) O(nd2)

      d × d d\times d d×d 权重矩阵和 d d d 维隐状态的乘法计算复杂度为 O ( d 2 ) O(d^2) O(d2),由于序列长度为 n n n,因此循环神经网络层的计算复杂度为 O ( n d 2 ) O(nd^2) O(nd2)

    • 并行度为 O ( 1 ) O(1) O(1)

      O ( n ) O(n) O(n) 个顺序操作无法并行化。

    • 最大路径长度也是 O ( n ) O(n) O(n)

  • 自注意力

    • 计算复杂度为 O ( n 2 d ) O(n^2d) O(n2d)

      查询、键和值都是 n × d n\times d n×d 矩阵

    • 并行度为 O ( n ) O(n) O(n)

      每个词元都通过自注意力直接连接到任何其他词元。因此有 O ( 1 ) O(1) O(1) 个顺序操作可以并行计算

    • 最大路径长度也是 O ( 1 ) O(1) O(1)

总而言之,卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。

10.6.3 位置编码

在处理词元序列时,循环神经网络是逐个的重复地处理词元的,而自注意力则因为并行计算而放弃了顺序操作。为了使用序列的顺序信息,通过在输入表示中添加 位置编码(positional encoding) 来注入绝对的或相对的位置信息。位置编码可以通过学习得到也可以直接固定得到。

基于正弦函数和余弦函数的固定位置编码的矩阵第 i i i 行、第 2 j 2j 2j 列和 2 j + 1 2j+1 2j+1 列上的元素为:

p i , 2 j = sin ⁡ ( i 1000 0 2 j / d ) p i , 2 j + 1 = cos ⁡ ( i 1000 0 2 j / d ) \begin{align} p_{i,2j}&=\sin{\left(\frac{i}{10000^{2j/d}}\right)}\\ p_{i,2j+1}&=\cos{\left(\frac{i}{10000^{2j/d}}\right)} \end{align} pi,2jpi,2j+1=sin(100002j/di)=cos(100002j/di)

#@save
class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

在位置嵌入矩阵 P \boldsymbol{P} P 中,行代表词元在序列中的位置,列代表位置编码的不同维度。从下面的例子中可以看到位置嵌入矩阵的第 6 列和第 7 列的频率高于第 8 列和第 9 列。第 6 列和第 7 列之间的偏移量(第 8 列和第 9 列相同)是由于正弦函数和余弦函数的交替。

encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])


在这里插入图片描述

10.6.3.1 绝对位置信息

打印出 0 , 1 , … , 7 0,1,\dots,7 0,1,,7 的二进制表示形式即可明白沿着编码维度单调降低的频率与绝对位置信息的关系。

每个数字、每两个数字和每四个数字上的比特值在第一个最低位、第二个最低位和第三个最低位上分别交替。

for i in range(8):
    print(f'{i}的二进制是:{i:>03b}')
0的二进制是:000
1的二进制是:001
2的二进制是:010
3的二进制是:011
4的二进制是:100
5的二进制是:101
6的二进制是:110
7的二进制是:111

在二进制表示中,较高比特位的交替频率低于较低比特位,与下面的热图所示相似,只是位置编码通过使用三角函数在编码维度上降低频率。由于输出是浮点数,因此此类连续表示比二进制表示法更节省空间。

P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
                  ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')


在这里插入图片描述

10.6.3.2 相对位置信息

除了捕获绝对位置信息之外,上述的位置编码还允许模型学习得到输入序列中相对位置信息。这是因为对于任何确定的位置偏移 δ \delta δ,位置 i + δ i+\delta i+δ 处的位置编码可以线性投影位置 i i i 处的位置编码来表示。

这种投影的数学解释是,令 ω j = 1 / 1000 0 2 j / d \omega_j=1/10000^{2j/d} ωj=1/100002j/d,对于任何确定的位置偏移 δ \delta δ,上个式子中的任何一对 ( p i , 2 j , p i , 2 j + 1 ) (p_{i,2j},p_{i,2j+1}) (pi,2j,pi,2j+1) 都可以线性投影到 ( p i + δ , 2 j , p i + δ , 2 j + 1 ) (p_{i+\delta,2j},p_{i+\delta,2j+1}) (pi+δ,2j,pi+δ,2j+1)

[ cos ⁡ ( δ ω j ) sin ⁡ ( δ ω j ) − sin ⁡ ( δ ω j ) cos ⁡ ( δ ω j ) ] [ p i , 2 j p i , 2 j + 1 ] = [ cos ⁡ ( δ ω j ) sin ⁡ ( i ω j ) + sin ⁡ ( δ ω j ) cos ⁡ ( i ω j ) − sin ⁡ ( δ ω j ) sin ⁡ ( i ω j ) + cos ⁡ ( δ ω j ) cos ⁡ ( i ω j ) ] = [ sin ⁡ ( ( i + δ ) ω j ) cos ⁡ ( ( i + δ ) ω j ) ] = [ p i , 2 j p i , 2 j + 1 ] \begin{align} &\begin{bmatrix} \cos{(\delta\omega_j)} & \sin{(\delta\omega_j)}\\ -\sin{(\delta\omega_j)} & \cos{(\delta\omega_j)} \end{bmatrix} \begin{bmatrix} p_{i,2j}\\ p_{i,2j+1} \end{bmatrix}\\ =&\begin{bmatrix} \cos{(\delta\omega_j)}\sin{(i\omega_j)}+\sin{(\delta\omega_j)}\cos{(i\omega_j)}\\ -\sin{(\delta\omega_j)}\sin{(i\omega_j)}+\cos{(\delta\omega_j)}\cos{(i\omega_j)} \end{bmatrix}\\ =&\begin{bmatrix} \sin{((i+\delta)\omega_j)}\\ \cos{((i+\delta)\omega_j)} \end{bmatrix}\\ =&\begin{bmatrix} p_{i,2j}\\ p_{i,2j+1} \end{bmatrix} \end{align} ===[cos(δωj)sin(δωj)sin(δωj)cos(δωj)][pi,2jpi,2j+1][cos(δωj)sin(iωj)+sin(δωj)cos(iωj)sin(δωj)sin(iωj)+cos(δωj)cos(iωj)][sin((i+δ)ωj)cos((i+δ)ωj)][pi,2jpi,2j+1]

2 × 2 2\times 2 2×2 投影矩阵不依赖于任何位置的索引 i i i

练习

(1)假设设计一个深度架构,通过堆叠基于位置编码的自注意力层来表示序列。可能会存在什么问题?


(2)请设计一种可学习的位置编码方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1137935.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

react-typescript-demo

1.使用 Context 来存储数据

rabbitMQ入门指南:管理页面全面指南及实战操作

文章目录 1. 引言2. RabbitMQ 管理页面的概览3. 探索 RabbitMQ 管理页面的主要功能3.1 连接3.2 通道3.3 交换机3.4 队列3.5 生产者3.6 消费者 4. RabbitMQ 的实战例子4.1 创建虚拟机4.2 连接和通道4.3 建立交换机4.3.1 交换机类型4.3.2 创建一个交换机 4.4 创建队列4.4.1 队列类…

海外广告投放保姆级教程,如何使用Quora广告开拓新流量市场?

虽然在Quora 上学习广告相对容易,但需要大量的试验和错误才能找出最有效的方法。一些广告技巧可以让您的工作更有效率。这篇文章将介绍如何有效进行quora广告投放与有价值的 Quora 广告要点,这将为您节省数万美元的广告支出和工作时间!往下看…

五大绩效指标,为企业新媒体矩阵管理注入新动力

⭐关注矩阵通服务号,探索企业新媒体矩阵搭建与营销策略 新媒体矩阵的建设运营是为了能实现企业确定的业务目标。 那如何才能让推动最后业务目标的实现? 对于企业来说,可以将业务目标拆解为短期绩效目标,通过适当调整短期指标来保证…

【Java 进阶篇】Java HTTP 请求消息详解

HTTP(Hypertext Transfer Protocol)是一种用于传输超文本的应用层协议,广泛用于构建互联网应用。在Java中,我们经常需要发送HTTP请求来与远程服务器进行通信。本文将详细介绍Java中HTTP请求消息的各个部分,包括请求行、…

spring-初识spring

初识spring Spring简介 初识spring一、Spring 特性二、IOC容器(反转控制)1、IOC容器1.1、IOC思想1.2、IOC容器在Spring中的实现 2、基于XML管理bean2.1入门案例2.2获取bean2.3依赖注入(DI)2.3.1set方法注入2.3.1构造器注入 2.4 为类类型属性赋值2.4.1 引用外部已声明bean2.4.2 …

一键同步,无处不在的书签体验:探索多电脑Chrome书签同步插件

说在前面 平时大家都是怎么管理自己的浏览器书签数据的呢?有没有过公司和家里的电脑浏览器书签不同步的情况?有没有过电脑突然坏了但书签数据没有导出,导致书签数据丢失了?解决这些问题的方法有很多,我选择自己写个chr…

数据结构与算法基础(青岛大学-王卓)(9)

终于迎来了最后一部分(排序)了,整个王卓老师的数据结构就算是一刷完成了,但是也才是数据结构的开始而已,以后继续与诸位共勉 😃 (PS.记得继续守护家人们的健康当然还有你自己的)。用三根美味的烤香肠开始吧。。。 文章目录 排序基…

高级深入--day41

用Pymongo保存数据 爬取豆瓣电影top250movie.douban.com/top250的电影数据,并保存在MongoDB中。 items.py class DoubanspiderItem(scrapy.Item):# 电影标题title scrapy.Field()# 电影评分score scrapy.Field()# 电影信息content scrapy.Field()# 简介info …

基于SSM的OA办公系统

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:Vue 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目:是 目录…

nginx浏览器缓存和上流缓存expires指令_nginx配置HTTPS

1.nginx控制浏览器缓存是针对于静态资源[js,css,图片等] 1.1 expires指令 location /static {alias/home/imooc;#设置浏览器缓存10s过期expires 10s;#设置浏览器缓存时间晚上22:30分过期expires @22h30m;#设置浏览器缓存1小时候过期expires -1h;#设置浏览器不缓存expires …

【JavaEE初阶】 线程安全的集合类

文章目录 🍀前言🌲多线程环境使用 ArrayList🚩自己使用同步机制 (synchronized 或者 ReentrantLock)🚩Collections.synchronizedList(new ArrayList);🚩使用 CopyOnWriteArrayList 🎍多线程环境使用队列&am…

通过jdk自制https证书并配置到nginx并配置http2

生成证书 这里使用自己生成的免费证书。在${JAVA_HOME}/bin 下可以看到keytool.exe,在改目录打开cmd然后输入: keytool -genkey -v -alias lgq.com -keyalg RSA -keystore d:/zj/ssl/fastfly.com.keystore -validity 3650生成证书过程中:【你的名字】对…

美国海运专线冬季通航情况

随着全球经济一体化的推进,海运业务的重要性日益凸显。特别是在美国,作为世界上最大的经济体之一,其海运业务的发展对全球物流供应链有着重要影响。而在冬季,美国海运专线又会面临怎样的通航情况呢?下面就让我们一起来探讨一下。…

区块链技术与应用 【全国职业院校技能大赛国赛题目解析】第二套区块链系统部署与运维

第二套区块链系统部署与运维题目 环境 : ubuntu20 fisco : 2.8.0 docker: 20.10.21 webase-deploy : 1.5.5 mysql: 8.0.34 子任务1-2-1: 搭建区块链系统并验证(4分) 使用build_chain.sh 脚本文件进行搭建区块链 ,要求: 四节点,默认配置,单机,docker root@192-168-19…

【推荐】如何使用drawio的编辑器,看本文就够

学会使用drawio的编辑器 drawio是免费的开源的图表应用,你可以通过app.diagrams.net使用在线版本或者下载离线安装版在桌面端使用。 对于团队来讲, 把数据存储作为安全第一的图表应用, 你可以选择存储数据到不同的介质。例如Atlassian Conf…

deepxde更改backend

deepxde更改backend 1. 问题描述2. 解决方案 1. 问题描述 如果在使用deepxde库时,想使用除tensorflow.compat.v1之外的后端依赖,例如pytorch,会出现如下情况: 这时提示目前的后端依赖是tensorflow.compat.v1,需要更改…

RHCE---正则表达式

文章目录 前言一、pandas是什么?二、使用步骤 1.引入库2.读入数据总结 前言 一. 文本搜索工具 grep是linux中一种强大的文件搜索过滤工具,可以按照正 则表达式检索文件内容,并把匹配的结果显示到屏幕上 (匹配的内容会标红&#x…

[Python]Selenium-自动化测试

Selenium是一个web自动化测试的工具,在使用之前先在对应的项目添加工具包噢. 本文章主要简单的介绍了selenium对于自动化测试的使用 目录 添加浏览器驱动 get函数来到对应网站 驱动的定位 元素定位 id定位 class name定位 CSS定位 XPath定位 link text定位 定位一…

【如何去掉Unity点击运行时,Photon的警告弹窗】

如何去掉Unity点击运行时,Photon的警告弹窗 一、前言二、解决办法1、解决思路2、解决步骤 一、前言 我导入了Photon插件,但是我现在用不着,又不想将其删掉,Photon的配置也没有弄,导致现在一运行Unity,就会…