《动手学深度学习》-67自注意力

news2024/11/18 21:27:37

沐神版《动手学深度学习》学习笔记,记录学习过程,详细的内容请大家购买书籍查阅。
b站视频链接
开源教程链接

自注意力在这里插入图片描述

在深度学习中,经常使用卷积神经网络(CNN)或循环神经网络(RNN)对序列进行编码。有了注意力机制之后,我们将词元序列输入注意力汇聚中,以便同一组次元同时充当查询、键和值。由于查询、键和值来自于同一组输入,因此被称为自注意力(self-attention)。

在这里插入图片描述

给定一个由词元组成的输入序列 x 1 , . . . , x n x1,...,xn x1,...,xn,该序列的自注意力输出为一个长度相同的序列 y 1 , . . . , y n y1,...,yn y1,...,yn

在这里插入图片描述

比较卷积神经网络、循环神经网络和自注意力

比较三种架构的计算复杂度顺序操作最大路径长度,三种架构都是将由 n n n个词元组成的序列映射到另一个长度相同的序列。

CNN的最长路径≈感受野,RNN无法并行。

在这里插入图片描述

在处理词元序列时,循环神经网络是逐个重复地处理词元的,而自注意力则是因为并行计算而放弃了顺序操作。为了使用序列的顺序信息,通过在输入表示中添加位置编码(positional encoding)来注入绝对的或相对的位置信息。位置编码可以通过学习得到,也可以直接固定。

在这里插入图片描述

奇列和偶列(6和7)位移不一样(sin变cos),奇列和奇列(7和9)周期不一样。总之每个样本加的位置编码不同。加进去的好处是不改变模型,坏处是不知道模型是否能够学到。

在这里插入图片描述

在二进制表示中,较高位的交替频率低于较低位,类似位置编码的热图,只是位置编码通过使用三角函数在编码维度上降低频率。由于输出是浮点数,因此此类连续表示比二进制表示更节省空间。

在这里插入图片描述

除了捕获绝对位置信息,上述位置编码还允许模型学习得到输入序列中的相对位置信息,这是因为对于任何确定的位置偏移 δ \delta δ,位置 i + δ i+\delta i+δ处的位置编码可以用线性投影位置i处的位置编码来表示。

总结
在这里插入图片描述
卷积神经网络和自注意力都具有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的平方,所以在很长的序列中计算会非常慢。

动手学

自注意力

import torch
from torch import nn
from d2l import torch as d2l

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
attention(X, X, X, valid_lens).shape
torch.Size([2, 4, 100])

位置编码

#@save
class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])

在这里插入图片描述

P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
                  ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/836856.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【暑期每日一练】 day14

目录 选择题 (1) 解析: (2) 解析: (3) 解析: (4) 解析: (5) 解析: 编程题 题一 …

全球十大知名看黄金即时行情的软件名单(综合榜单)

在当今的数字化时代,黄金投资已成为一种受欢迎的投资方式。为了获取即时的黄金行情信息,许多投资者开始使用黄金即时行情软件。然而,选择一款合适的软件并不是一件容易的事情。那么,如何选适合自己需求的软件呢?首先&a…

Hyper实现git bash在windows环境下多tab窗口显示

1.电脑上安装有git bash 下载链接:https://gitforwindows.org/ 安装Hyper 下载链接:官网 https://hyper.is/ 或者在百度云盘下载: https://pan.baidu.com/s/1BVjzlK0s4SgAbQgsiK1Eow 提取码:0r1f 设置 打开Hyper,依次点左上角-&g…

Arduino驱动MQ5模拟煤气气体传感器(气体传感器篇)

目录 1、传感器特性 2、硬件原理图 3、驱动程序 MQ5气体传感器,可以很灵敏的检测到空气中的液化气、天然气、煤气等气体,与Arduino结合使用,可以制作火灾液化气、天然气、煤气泄露报警等相关的作品。 1、传感器特性 MQ5用于消费和工业行业中气体泄漏检测设备,该传感器适…

【网络】DNS、ICMP、NAT

目录 一、DNS(Domain Name System) 1、DNS背景 2、域名简介 二、ICMP协议 1、ICMP功能 2、ICMP的报文格式 3、ping命令 4、traceroute命令 三、NAT技术 1、NAT技术背景 2、NAT IP转换过程 3、NAPT 4、NAT技术的缺陷 5、NAT和代理服务器 一、DNS(Domain Name Syste…

分布式应用:Zookeeper 集群与kafka 集群部署

目录 一、理论 1.Zookeeper 2.部署 Zookeeper 集群 3.消息队列 4.Kafka 5.部署 kafka 集群 6.FilebeatKafkaELK 二、实验 1.Zookeeper 集群部署​​​​​​​ 2.kafka集群部署 3.FilebeatKafkaELK 三、问题 1.解压文件异常 2.kafka集群建立失败 3.启动 filebeat报…

中国信通院发布《高质量数字化转型产品及服务全景图(2023)》

2023年7月27日,由中国信息通信研究院主办的2023数字生态发展大会暨中国信通院铸基计划年中会议在北京成功召开。 本次大会发布了中国信通院《高质量数字化转型产品及服务全景图(2023)》,中新赛克海睿思受邀出席本次大会并成功入选…

HarmonyOS应用开发者基础与高级认证题库——中级篇

系列文章目录 HarmonyOS应用开发者基础与高级认证题库——基础篇 HarmonyOS应用开发者基础与高级认证题库——中级篇 文章目录 系列文章目录前言一、判断二、单选三、多选 前言 今天刚换了台果子手机就收到了华子鸿蒙开发认证邀请(认证链接)&#xff0…

2020年江西省职业院校技能大赛高职组“信息安全管理与评估”赛项任务书

2020年江西省职业院校技能大赛高职组 “信息安全管理与评估”赛项任务书 赛项时间 9:00-12:00,共计3小时。 赛项信息 赛项内容 竞赛阶段 任务阶段 竞赛任务 竞赛时间 分值 第一阶段 平台搭建与安全设备配置防护 任务1 网络平台搭建 9:00-12:00 100 任务…

2023下半年软考初级程序员报名入口-报名流程-备考方法

软考初级程序员2023下半年考试时间: 2023年下半年软考初级程序员的考试时间为11月4日、5日。考试时间在全国各地一致,建议考生提前备考。共分两科,第一科基础知识考试具体时间为9:00到11:30;第二科应用技术考试具体时间为2:00到4…

【C++刷题集】-- day5

目录 选择题 单选 编程题 统计回文⭐ 【题目解析】 【解题思路 - 穷举】 【优化】 连续最大和⭐ 【题目解析】 【解题思路】 【空间优化】 选择题 单选 1、 在上下文和头文件均正常情况下,以下程序的输出结果是 ( ) int x 1; do {printf("%2d\n&q…

Python批量查字典和双语例句

最近,有网友反映,我的批量查字典工具换到其它的网站就不好用了。对此,我想说的是,互联网包罗万象,网站的各种设置也有所不同,并不是所有的在线字典都可以用Python爬取的。事实上,很多网站为了防…

Python入门自学进阶-Web框架——38、redis、rabbitmq、git

缓存数据库redis: NoSQL(Not only SQL)泛指非关系型的数据库。为了解决大规模数据集合多重数据类的挑战。 NoSQL数据库的四大分类: 键值(Key-Value)存储数据库列存储数据库文档型数据库图形(…

MySQL最终弹-并发(脏读,不可重复读,幻读及区别),JDBC的使用和安装,最全万字

一、💛并发基本概念 并发的基本意思: 什么是并发呢?简单的理解就是同一时间执行 服务器同一时刻,给多个客户端提供服务~~,这两个客户端都可以给服务器提交事务。 如果提交两个事务,改…

召唤神龙打造自己的ChatGPT

在之前的两篇文章中,我介绍了GPT 1和2的模型,并分别用Tensorflow和Pytorch来实现了模型的训练。具体可以见以下文章链接: 1. 基于Tensorflow来重现GPT v1模型_gzroy的博客-CSDN博客 2. 花费7元训练自己的GPT 2模型_gzroy的博客-CSDN博客 有…

C++STL——map/multimap容器详解

纵有疾风起,人生不言弃。本文篇幅较长,如有错误请不吝赐教,感谢支持。 💬文章目录 一.对组(pair)二.map/multimap基本概念三.map容器常用操作①map构造函数②map迭代器获取③map赋值操作④map大小操作⑤map…

突破视觉边界:深入探索AI图像识别的现状与挑战

图像识别作为人工智能领域的一个重要研究方向,取得了许多令人瞩目的成就。深入探索当前AI图像识别技术的现状以及所面临的挑战,讨论各种方法的优势和局限性。 目录 引言1.1 AI图像识别的背景和概述1.2 人工智能在图像识别中的应用和重要性 图像识别基础知…

RISC-V基础指令之逻辑指令 and、or、xor、not

RISC-V的逻辑指令是用于对两个寄存器或一个寄存器和一个立即数进行按位的逻辑运算,并将结果存放在另一个寄存器中的指令。按位的逻辑运算就是把两个操作数的每一位分别进行相应的逻辑运算,得到一个新的位。RISC-V的逻辑指令有以下几种: and&…

c++高性能多进程 cuda编程:GPU结构和通信速度+tiling的代码实现

根据c高性能多进程 cuda编程:GPU结构和通信速度tiling的分析,依靠pytorch的JIT进行了实现,所以在安装pytorch的环境中,直接执行test.py就能直接运行。 代码结构如下,地址 mm.h void function_mm(float *c,const float *a,cons…

一文辨析,性能分析top命令中进程NI和PR

分析 Linux 服务器性能,首先想到的命令肯定是 top, 通过它,我们可以看到当前服务器资源使用情况和进程运行资源占用情况。 如果你想学习自动化测试,我这边给你推荐一套视频,这个视频可以说是B站播放全网第一的自动化测试教程&…