Attention Is All You Need详解

news2024/11/17 13:51:07

一.背景。

在此模型之前,序列到序列的任务(如机器翻译、文本摘要等)通常采用循环神经网络(RNN)或卷积神经网络(CNN)。然而,RNN 在处理长距离依赖时存在一定的局限性(举个例子:处理第Kt个词时,需要用到K1到Kt-1的词的输出作为输入),训练时也比较耗时。而 CNN 在处理序列数据时难以捕捉到全局的依赖关系。然而这篇文章介绍的模型Transformer完全基于注意力机制,与CNN,RNN,LSTM模型对比更加简单并且高效。

二.模型架构。

Transformer 模型采用了编码器-解码器架构。先上一个论文里面的架构图,再逐步介绍其中的各个部分。
在这里插入图片描述

1.Embedding

在这里插入图片描述

Embedding是什么:

为了对字符进行计算,我们首先需要将字符(或单词)转换成一种数值表示形式。独热编码(One-Hot Encoding)是一种常用的方法之一,例如词汇表 {'猫': 0, '狗': 1, '苹果': 2} (此处的索引012一般是根据某个词典获得,即某词典0号索引处为单词‘猫’),‘ 猫’ 的独热编码就是 [1, 0, 0] ,‘狗’ 的独热编码就是 [0, 1, 0],但是这样的缺点就是向量维度高且稀疏,计算效率低。如果词汇表有 10,000 个单词,那么每个独热向量的维度就是 10,000,并且是稀疏的,即大部分元素都是 0,只有一个位置是 1。
所以在深度学习特别是自然语言处理(NLP)中,我们通常会采用更加高效的嵌入表示(Embedding)。通过Embedding,每个字符(或单词)被表示为一个低维的密集向量。这些向量是通过训练得到的,可以捕捉字符(或单词)之间的语义关系。例如通过嵌入层,单词“猫”可能被表示为一个 5 维(维数由我们定义)向量 [0.0376, -0.2343, 0.1655, -0.0053, 0.1353] 。可见其特点是维度低且密集,计算效率高,并且能够捕捉语义信息。

Embedding的例子:

在 PyTorch 中,有一个函数torch.nn.Embedding(num_embedings, embedding_dim),其中 num_embedding 表示词表总的长度,embedding_dim 表示单词嵌入的维度,此函数会创建一个嵌入矩阵(通常是随机的),其形状为 (num_embedding , embedding_dim),给定输入张量(通常是单词索引),其形状为(batch_size, sequence_length),该层会将每个索引映射到对应的嵌入向量,返回一个形状为 (batch_size, sequence_length, embedding_dim) 的张量。
以下是例子代码:

import torch
import torch.nn as nn

# 定义词汇表大小和嵌入向量维度
vocab_size = 10 #词汇表大小
embedding_dim = 5 #向量维度

# 创建嵌入层
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

# 输入张量(单词索引)
input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)  # 示例输入,形状为 (batch_size, sequence_length)

# 通过嵌入层得到嵌入向量
embedding_output = embedding_layer(input_tensor)

print(f"Input Tensor Shape: {input_tensor.shape}")
print(f"Embedding Output Shape: {embedding_output.shape}")
print(f"Embedding Output:\n{embedding_output}")
'''
输出:
Input Tensor Shape: torch.Size([2, 3])
Embedding Output Shape: torch.Size([2, 3, 5])
Embedding Output:
tensor([[[ 0.0069,  0.0465, -0.0205,  0.0080, -0.0114],
         [-0.0244,  0.0404,  0.0452, -0.0027, -0.0307],
         [ 0.0024, -0.0043,  0.0340,  0.0370, -0.0400]],

        [[ 0.0057, -0.0015, -0.0154, -0.0306, -0.0375],
         [ 0.0317, -0.0275,  0.0160,  0.0283,  0.0040],
         [-0.0331, -0.0061,  0.0452,  0.0484, -0.0350]]], grad_fn=<EmbeddingBackward0>)
'''

Transformer中的Embedding

论文原文:(在嵌入层中,我们将这些权重乘以√dmodel)
在这里插入图片描述

在Embedding中使用 math.sqrt(self.d_model)(即 d \sqrt[]{d} d ) 进行缩放,是在实际实现中的一种实践,可以保持数值稳定性,确保在随后的计算中,尤其是在与模型其他部分进行交互时,不会出现数值过大或过小的问题。

  • 数学解释:
    在这里插入图片描述
    以下是复现代码:(这里暂定input跟output的Embedding是一样的)
class Embedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        # vocab_size:词表长度    d_model:嵌入维度
        super(Embedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        # x:输入张量
        # 乘以 根号dk 保持数据稳定性 
        return self.embedding(x) * math.sqrt(self.d_model)

2.Positional Encoding

在这里插入图片描述

为什么需要Positional Encoding

论文原文:(由于我们的模型不包含递归和卷积,为了使模型利用序列的顺序,我们必须注入一些关于序列中标记的相对或绝对位置的信息)
在这里插入图片描述
也就是attention没有时序信息,需要我们自己加入。(RNN的做法是上一个时刻的输出作为此时刻的输入以此引入时序信息)

Positional Encoding的实现

论文原文:
在这里插入图片描述

其中,pos 即 position,意为 token 在句中的位置,i为向量的某一维度。借助此公式再结合三角函数的性质
在这里插入图片描述
可以得到:
在这里插入图片描述
可以看出,对于 pos+k 位置的位置向量某一维 2i 或 2i+1 而言,可以表示为,pos 位置与k位置的位置向量的2i与 2i+1维的线性组合,这样的线性组合意味着位置向量中蕴含了相对位置信息。具体可以参考视频讲解。
以下是复现代码:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        '''初始化函数,三个参数分别是:
            d_model:词嵌入维度; dropout:置0比率(位置编码与输入嵌入相加后一起作为模型的输入。模型在学习过程中会学习如何利用这些位置信息。如果位置编码没有经过 dropout 的正则化处理,模型可能会过度依赖这些位置信息,从而对训练数据记忆过深,导致在处理未见数据时表现不佳。)
            max_len:每个句子的最大长度。
        '''
        super(PositionalEncoding, self).__init__()

        #实例化dropout层,并传入参数
        self.dropout = nn.Dropout(p=dropout)
        #初始化一个位置编码矩阵,全为0,大小是max_len * d_model
        pe = torch.zeros(max_len, d_model)
        #初始化一个绝对位置矩阵,词的绝对位置即索引位置
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        #接下来就是把位置信息加入到位置编码矩阵中去,也就是把max_len * 1的position绝对位置矩阵变换成max_len * d_model形状,然后覆盖初始矩阵
        #也就是max_len * 1 的矩阵去乘以一个 1 * d_modl 的变换矩阵div_term,然后再进行覆盖,这里因为位置编码可以分成奇数和偶数两部分,故可以将变换矩阵更改为 1 * (d_model / 2)的形状
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        #按照公式给位置编码进行赋值
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        #这样子就得到了位置编码矩阵pe,但是要和embedding的输出相加就必须拓展一个维度
        pe = pe.unsqueeze(0)

        #因为无论我们输入的是什么,这个位置编码都不会改变,也就是所有的输入是公用一个位置编码的,所以这边使用self.register_buffer,其是一个用于将张量注册为模型的一部分的方法。它的主要用途是注册一些不作为模型参数的持久状态,例如在训练和推理过程中不需要更新的固定数据。
        self.register_buffer('pe', pe)

    def forward(self,x):
        #因为一个句子有长有短,所以可以位置编码只截取到句子的实际长度即可。
        x = x + self.pe[:, :x.size(1)]
        #最后使用dropout防止过拟合,并返回结果。
        return self.dropout(x)
实际例子

通过一个超参数比较小的例子输出并展示还是比较容易理解每一步骤的做法的。

import torch
import torch.nn as nn
import math

max_len = 10
d_model = 6

# 初始化位置编码矩阵
pe = torch.zeros(max_len, d_model)
print(pe)
'''
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
'''


# 初始化绝对位置矩阵
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
print(position)
'''
tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.],
        [9.]])
'''

# 计算变换矩阵 div_term
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
print(div_term)
'''
Div term matrix:
tensor([1.0000, 0.0464, 0.0022])
'''
# 计算位置编码矩阵
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
print(pe)
'''
tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0464,  0.9989,  0.0022,  1.0000],
        [ 0.9093, -0.4161,  0.0927,  0.9957,  0.0043,  1.0000],
        [ 0.1411, -0.9900,  0.1388,  0.9903,  0.0065,  1.0000],
        [-0.7568, -0.6536,  0.1846,  0.9828,  0.0086,  1.0000],
        [-0.9589,  0.2837,  0.2300,  0.9732,  0.0108,  0.9999],
        [-0.2794,  0.9602,  0.2749,  0.9615,  0.0129,  0.9999],
        [ 0.6570,  0.7539,  0.3192,  0.9477,  0.0151,  0.9999],
        [ 0.9894, -0.1455,  0.3629,  0.9318,  0.0172,  0.9999],
        [ 0.4121, -0.9111,  0.4057,  0.9140,  0.0194,  0.9998]])
'''

#最后再添加一个维度
pe = pe.unsqueeze(0)
print(pe)
'''
tensor([[[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],
         [ 0.8415,  0.5403,  0.0464,  0.9989,  0.0022,  1.0000],
         [ 0.9093, -0.4161,  0.0927,  0.9957,  0.0043,  1.0000],
         [ 0.1411, -0.9900,  0.1388,  0.9903,  0.0065,  1.0000],
         [-0.7568, -0.6536,  0.1846,  0.9828,  0.0086,  1.0000],
         [-0.9589,  0.2837,  0.2300,  0.9732,  0.0108,  0.9999],
         [-0.2794,  0.9602,  0.2749,  0.9615,  0.0129,  0.9999],
         [ 0.6570,  0.7539,  0.3192,  0.9477,  0.0151,  0.9999],
         [ 0.9894, -0.1455,  0.3629,  0.9318,  0.0172,  0.9999],
         [ 0.4121, -0.9111,  0.4057,  0.9140,  0.0194,  0.9998]]])
'''

3.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1839178.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

20240619在飞凌OK3588-C的LINUX系统启动的时候拉高3个GPIO口141-111-120【方法一】

20240619在飞凌OK3588-C的LINUX系统启动的时候拉高3个GPIO口141-111-120【方法一】 2024/6/19 16:12 缘起&#xff1a;在凌OK3588-C的LINUX R4系统启动的时候&#xff0c;需要拉高GPIO4_B5、GPIO3_B7和GPIO3_D0。 修改rcS&#xff0c;在系统启动的时候&#xff0c;即可拉高。 通…

通信系统的最佳线性均衡器(1)---维纳滤波线性均衡

本篇文章是博主在通信等领域学习时&#xff0c;用于个人学习、研究或者欣赏使用&#xff0c;并基于博主对通信等领域的一些理解而记录的学习摘录和笔记&#xff0c;若有不当和侵权之处&#xff0c;指出后将会立即改正&#xff0c;还望谅解。文章分类在通信领域笔记&#xff1a;…

全新剧场app的独特功能

全新剧场App通过引入一系列独特功能&#xff0c;旨在提升用户体验、增加用户粘性并拓宽市场范围。以下是对这些功能的详细分析&#xff1a; 1、虚拟剧场导览&#xff1a; 功能概述&#xff1a;利用增强现实技术&#xff0c;为用户提供虚拟剧场导览体验。用户可以在App中启动这…

一文读懂Java线程状态转换

Java线程有哪些状态?状态如何转换? 线程可以拥有自己的操作栈、程序计数器、局部变量表等资源,它与同一进程内的其他线程共享该进程的所有资源。Java的线程有自己的生命周期,在 Java 中线程的生命周期中一共有 6 种状态。 NewRunnableBlockedWaitingTimed WaitingTerminat…

报表工具数据源的取数处理方式大对比

根据报表的需求&#xff0c;很多报表中的指标数据需要进行预处理&#xff0c;以满足快速抽取和展示的需要。对于帆软报表类似的产品&#xff0c;一般通过建立视图、合并数据表&#xff0c;形成直接应用于模板设计的数据集&#xff0c;报表直接和数据集进行交互、关联。当用户发…

AI时代中的模型安全保护,如何通过加密和许可管理保障AI模型的安全

在进入AI时代中&#xff0c;网络安全威胁和数字版权管理变得愈发复杂&#xff0c;保护AI数据模型变得至关重要。这些模型已成为企业核心竞争力的关键&#xff0c;尤其在医疗设备和工业自动化等高敏感领域。确保数据模型的安全性和完整性不仅是保护知识产权的必要措施&#xff0…

【Linux 基础】目录结构

Linux 的目录结构&#xff08;也称为文件系统结构&#xff09;是组织文件和目录的一种逻辑方式。每个文件和目录在文件系统中都有一个唯一的位置或路径。 Linux文件系统是整个操作系统的基础架构&#xff0c;对于系统的稳定运行、数据安全以及用户操作便捷性至关重要&#xff0…

全球AI视频技术竞赛加速:Runway即将推出更优更快的第三代AI视频模型|TodayAI

Runway即将在未来几天推出其更优更快的第三代AI视频模型&#xff0c;这是新一代模型中最小的一个。据公司透露&#xff0c;这款名为Gen-3的模型将带来“在真实度、一致性和动态效果上的重大提升”&#xff0c;同时在速度上也有显著的加快。 去年六月&#xff0c;Runway首次推出…

Redis 集群 - 数据分片算法

前言 广义的集群&#xff1a;只要是多个机器构成了一个分布式系统&#xff0c;都可以被称为集群。 狭义的集群&#xff1a;redis 的集群模式&#xff0c;这个集群模式下&#xff0c;主要是解决存储空间不足的问题。 Redis 集群 redis 采用主从结构&#xff0c;可以提高系统的可…

ABAP 搜索帮助F4IF_INT_TABLE_VALUE_REQUEST

F4IF_INT_TABLE_VALUE_REQUEST 一般用于在选择屏幕提供搜索帮助 可以看到设置的是物料与物料描述的对应关系&#xff0c;而且对话类型是立即显示值&#xff0c;所以才能够实现如上的效果 有两种搜索帮助,这里选择基本索引帮助即可 然后填上对应的文本表和字段即可 然后在选…

【非常实验】Android模拟x86_64系统——安装Alpine虚拟机

安卓是一款功能强大的操作系统,为什么不试试它的极限呢? 百无聊赖中,我发现了各种 Android 修补项目。这激起了我对 DevOps 的好奇心,促使我探索在该平台上运行容器。这种好奇心又把我带入了另一个兔子洞:在 Android 上运行虚拟机。这其中经历了许多曲折,也许以后有必要…

cs144 LAB1 基于滑动窗口的碎片字节流重组器

一.StreamReassembler.capacity 的意义 StreamReassembler._capacity 的含义&#xff1a; ByteStream 的空间上限是 capacityStreamReassembler 用于暂存未重组字符串片段的缓冲区空间 StreamReassembler.buffer 上限也是 capacity蓝色部分代表了已经被上层应用读取的已重组数…

计算机专业毕设-springboot论坛系统

1 项目介绍 基于SSM的论坛网站&#xff1a;后端 SpringBoot、Mybatis&#xff0c;前端thymeleaf&#xff0c;具体功能如下&#xff1a; 基本功能&#xff1a;登录注册、修改个人信息、修改密码、修改头像查看帖子列表&#xff1a;按热度排序、按更新时间排序、查看周榜月榜查…

棱镜七彩荣获CNNVD两项大奖,专业能力与贡献再获认可!

6月18日&#xff0c;国家信息安全漏洞库&#xff08;CNNVD&#xff09;2023年度工作总结暨优秀表彰大会在中国信息安全测评中心成功举办。棱镜七彩凭借在漏洞方面的突出贡献和出色表现&#xff0c;被授予“2023年度优秀技术支撑单位”与“2023年度最佳新秀奖”。 优秀技术支撑单…

Gobject tutorial 七

The GObject base class GObject是一个fundamental classed instantiatable type,它的功能如下&#xff1a; 内存管理构建/销毁实例set/get属性方法信号 /*** GObjectClass:* g_type_class: the parent class* constructor: the constructor function is called by g_object…

最新技术:跨境电商源码,应对多国市场需求,让您轻松开展全球业务!

随着全球化进程的不断推进&#xff0c;跨境电商已成为企业拓展国际市场的重要途径。为了满足不同国家和地区消费者不断增长的需求&#xff0c;跨境电商源码应运而生&#xff0c;为企业提供了便捷高效的全球化业务发展方案。 一、全球化运营的关键 跨境电商源码的核心功能在于…

极具吸引力的小程序 UI 风格

极具吸引力的小程序 UI 风格

小白速成AI大模型就看这份资源包

前言 在数字化浪潮席卷全球的今天&#xff0c;人工智能&#xff08;AI&#xff09;技术已成为推动社会进步的重要引擎。尤其是AI大模型&#xff0c;以其强大的数据处理能力和广泛的应用前景&#xff0c;吸引了无数人的目光。然而&#xff0c;对于初学者“小白”来说&#xff0…

ProtoBuf序列化协议简介

首先&#xff0c;常见的序列化方法主要有以下几种&#xff1a; TLV编码及其变体(tag, length, value)&#xff1a; 比如ProtoBuf。文本流编码&#xff1a;XML/JSON固定结构编码&#xff1a;基本原理是&#xff0c;协议约定了传输字段类型和字段含义&#xff0c;和TLV类似&…

MyBatis框架基础

文章目录 1 MyBatis概述2 MyBatis入门2.1 相关依赖2.2 properties配置文件2.3 预编译SQL 3 基本操作3.1 新增操作3.2 删除操作3.3 更新操作3.4 查询操作 4 动态SQL4.1 XML映射文件4.2 if/set/where标签4.3 foreach标签4.4 sql/include标签 5 参考资料 1 MyBatis概述 MyBatis是…