Transformer step by step--Positional Embedding 和 Word Embedding

news2024/12/27 11:03:29

Transformer step by step往期文章:

Transformer step by step--层归一化和批量归一化 

要把Transformer中的Embedding说清楚,那就要说清楚Positional EmbeddingWord Embedding。至于为什么有这两个Embedding,我们不妨看一眼Transformer的结构图。

从上图可以看到,我们的输入需要在Input EmbeddingPositional Encoding的共同作用下才会分别输入给EncoderDecoder,所以我们就分别介绍一下怎么样进行Input EmbeddingPositional Encoding

同时为了帮助大家更好地理解这两种Embedding方式,我们这里生成一个自己的迷你数据集。

import tiktoken
import torch
import torch.nn as nn
encoding = tiktoken.get_encoding("cl100k_base") #导入openai的开源tokenizer库
context_length = 4 #选取4个token
batch = 4 # 批处理大小
example_text1 = "I am now writing an example to show the usage of word embedding."
example_text2 = "I am now writing an sentence to show the usage of another word embedding."
total_text = example_text1 + example_text2 #形成总数据集
tokenize_text = encoding.encode(total_text) #进行tokenize
tokenize_text = torch.tensor(tokenize_text, dtype=torch.long) #这里转换成tentor是因为后续我们要用pytorch的框架
idxs = torch.randint(low = 0,high = len(tokenize_text) - context_length,size = (batch,))#这里我们随机选batch个id
x_batch = torch.stack([tokenize_text[idx:idx + context_length] for idx in idxs]) #根据id和context_length抽取训练数据

一、Word Embedding

这里的Word Embedding就是论文中提到的Input Embedding。我们在之前的文章中已经介绍使用tokenizer将原始的文本信息转换为数字,便于输入进模型。 可是使用tokenizer有一个问题,这个问题就是我们虽然用不同的ID表示了不同的子词,但是这些ID所能蕴含的语义信息非常有限,比如dog和dogs这两个单词语义非常相近,但很有可能它们的ID相隔非常远,所以为了更好地体现不同单词之间的语义关系,我们将每个ID通过Word Embedding的方式变为一个向量。

这里的实现在Pytorch框架之下变得非常简单,只需要一行代码就可以搞定,但是我们这里还是详细对代码的参数进行一些讲解。

max_token_value = tokenize_text.max().item()
d_model = 16
input_embedding_lookup_table = nn.Embedding(max_token_value + 1, d_model)
x_batch_embedding = input_embedding_lookup_table(x_batch)
print(max_token_value)
#
40188

这里的第一、二行我们一起讲:

max_token_value = tokenize_text.max().item()是取出当前token中ID最大的那个数。

input_embedding_lookup_table = nn.Embedding(max_token_value + 1, d_model)是根据最大的ID去构建Embedding层。

首先第一个问题,nn.Embedding的两个参数是什么意思?

首先第一个参数我们这里使用的是 max_token_value + 1,这个是用来表示我们词汇库的最大长度。那这里可能又有一个新的问题,就是我们上面的句子,算上标点符号一共也就十几个词,为什么我们这里要用40188+1作为这个词汇库的最大长度呢?这是因为我们这里用的tokenizer是openai的第三方库,这个库在做tokenize的时候对应着大量的原始文本,而我们example_text中的文本在经过tokenize之后,最大的token ID对应的是这个库中的原始文本的40188。那么这里还有第二个问题,这里返回的是40188,我们为什么要加上1呢?因为tokenize之后的ID是从0开始算的,也就是说40188对应的词汇表的最大长度应该是40189。

第二个参数我们使用的是d_model,这个参数会好理解一些。我们之前说过,一个ID没有什么语义信息,但变成向量之后就可以通过余弦相似度计算两个向量之间的相关性。那么这里的一个问题就在于,我用多少维的向量去表示呢?d_model这个参数就是来解决这个事情,我们想让ID变成多少维的向量,就把d_model设置成多少。

那么第三行,就是我们讲原来形状为4 * 4的x_batch变为了4 * 4 * 16 的x_batch_embedding。后面多出来的16就是我们自己设置的嵌入的维度。

二、Positional Embedding

总体来说,word_embedding还是比较通俗易懂的,接下来我们根据论文当中的公式去写一下Positional Encoding,也就是Positional Embedding(同一个意思)。

这里我们解释一下这两行公式啥意思,Positional Embedding简单来说,就是给每个token分配一个位置信息,因为 𝑠𝑒𝑙𝑓 - 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 本身无法判断不同token所在的位置。PE对应Positional Encoding的缩写,括号中的pos对应我们设置的context length的长度,2i对应嵌入维度中的偶数维度,2i+1对应嵌入维度中的奇数维度。

接下来我们就来实现一下相关的代码:

positional_encoding = torch.zeros(context_length, d_model) #首先初始化一个和token形状大小一样的positional encoding
positional = torch.arange(0, context_length).float().unsqueeze(1) #按照我们设置的context length去初始化position
_2i = torch.arange(0, d_model, 2) # 这里用生成d_model/2的步长,因为sin和cos两个加起来就变成了d_model
positional_encoding[:, 0::2] = torch.sin(torch.exp(positional/10000**(_2i/d_model))) #按照公式写一遍
positional_encoding[:, 1::2] = torch.cos(torch.exp(positional/10000**(_2i/d_model)))
positional_encoding = positional_encoding.squeeze(0).expand(batch, -1, -1) #最终根据batch的数量对维度进行扩充
input_x = x_batch_embedding + positional_encoding # 将word embedding和positional embedding相加得到模型的输入
print(positional_encoding.shape)
print(input_x.shape)
##
torch.Size([4, 4, 16]) 
torch.Size([4, 4, 16]) 

到这里,我们也就完成了两个embedding操作。

知乎原文链接:安全验证 - 知乎知乎,中文互联网高质量的问答社区和创作者聚集的原创内容平台,于 2011 年 1 月正式上线,以「让人们更好的分享知识、经验和见解,找到自己的解答」为品牌使命。知乎凭借认真、专业、友善的社区氛围、独特的产品机制以及结构化和易获得的优质内容,聚集了中文互联网科技、商业、影视、时尚、文化等领域最具创造力的人群,已成为综合性、全品类、在诸多领域具有关键影响力的知识分享社区和创作者聚集的原创内容平台,建立起了以社区驱动的内容变现商业模式。icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/691169616

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1624768.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3.1设计模式——Chain of Responsibility 责任链模式(行为型)

意图 使多个对象都有机会处理请求,从而避免请求的发送者和接受者之间的耦合关系,将这些对象练成一条链,并沿着这条链传递请求,直到有一个对象处理它为止。 实现 其中 Handle定义一个处理请求的接口:(可选…

Java中的ArrayList集合

特点: ArrayList中的一些方法: 1、add(Object element):向集合的末尾添加元素 add(int index,Object element):在列表的指定位置(从0开始)插入指定元素 2、size():返回列表的中的元素个数 3、get(int index):返回下标为index位置的…

鸿蒙应用ArkTS开发- 选择图片、文件和拍照功能实现

前言 在使用App的时候,我们经常会在一些社交软件中聊天时发一些图片或者文件之类的多媒体文件,那在鸿蒙原生应用中,我们怎么开发这样的功能呢? 本文会给大家对这个功能点进行讲解,我们采用的是拉起系统组件来进行图片…

【注解和反射】性能分析

继上一篇博客【注解和反射】通过反射动态创建对象、调用普通方法、操作属性-CSDN博客 目录 九、性能分析 代码分析 完整代码 分析结果 九、性能分析 代码分析 (1)这部分代码是一个简单的基准测试,它测量了在Java中使用普通方式调用一个…

解决Jmeter 4.x 请求到elasticsearch 中文乱码的问题

文章目录 前言解决Jmeter 4.x 请求到elasticsearch 中文乱码的问题 前言 如果您觉得有用的话,记得给博主点个赞,评论,收藏一键三连啊,写作不易啊^ _ ^。   而且听说点赞的人每天的运气都不会太差,实在白嫖的话&#…

Android kotlin 协程异步async与await介绍与使用

一、介绍 在kotlin语言中,协程是一个处理耗时的操作,但是很多人都知道同步和异步,但是不知道该如何正确的使用,如果处理不好,看似异步,其实在runBloacking模块中使用的结果是同步的。 针对如何同步和如何异…

从 MySQL 到 ClickHouse 实时数据同步 —— Debezium + Kafka 表引擎

目录 一、总体架构 二、安装配置 MySQL 主从复制 三、安装配置 ClickHouse 集群 四、安装 JDK 五、安装配置 Zookeeper 集群 六、安装配置 Kafaka 集群 七、安装配置 Debezium-Connector-MySQL 插件 1. 创建插件目录 2. 解压文件到插件目录 3. 配置 Kafka Connector …

快速搭建产品原型

今天要推荐的是一款能协助大家快速搭建产品原型的网站。对应产品经理来说,还真是个利器。 墨刀 https://modao.cc/brand 让想法快速呈现。 画APP原型图,很方便。 快速创建第一个原型文件。 在免费模版上进行修改。 感兴趣的同学去使用体验下吧~ 《Androi…

LabVIEW专栏七、队列

目录 一、队列范例二、命令簇三、队列应用1.1、并行循环队列1.2、命名队列和匿名队列1.2.1、命名队列1.2.2、匿名队列 1.3、长度为1的队列 队列是一种特殊的线性表,就是队列里的元素都是按照顺序进出。 队列的数据元素又称为队列元素。在队列中插入一个队列元素称为…

mysql reset slave reset master

mysql reset slave reset master 1、问题背景2、问题分析3、解决方法3.1、锁定主库,手动同步主库数据到从库,使得主从数据库数据一致3.1、从机执行stop slave、reset slave3.2、从机上再次指定主机的binlog文件名和偏移量3.3、从机执行 start slave3.4、…

蓝牙低能耗安全连接 – 数值比较

除了 LE Legacy 配对之外,LE Secure Connections 是另一种配对选项。 LE 安全连接是蓝牙 v4.2 中引入的增强安全功能。它使用符合联邦信息处理标准 (FIPS) 的算法(称为椭圆曲线 Diffie Hellman (ECDH))来生成密钥。对于 LE 安全连接&#xff…

MMSeg搭建模型的坑

Input type(torch.suda.FloatTensor) and weight type (torch.FloatTensor) should be same 自己搭建模型的时候,经常会遇到二者不匹配,以这种情况为例,是因为部分模型没有加载到CUDA上面造成的。 注意搭建模型的时候,所有层都应…

汽车企业安全上网解决方案

需求背景 成立于1866年的某老牌汽车服务独立运营商,目前已经是全球最大的独立汽车服务网络之一,拥有95年的历史,在全球150多个国家拥有17,000多个维修站,始终致力于为每一位车主提供高品质,可信赖的的专业汽车保养和维…

win10加入域环境

win10加入域环境 导航 文章目录 win10加入域环境导航一、关闭防火墙二、使客户端的电脑指向于域控服务器三、检验是否加入了域 一、关闭防火墙 在进行加入域服务之前,我们需要先关闭防火墙(为了不必要的麻烦) 按 winr调出运行窗口,输入 control打开控制面板 点击系统和安全点…

42. UE5 RPG 实现火球术伤害

上一篇,我们解决了火球术于物体碰撞的问题,现在火球术能够正确的和攻击目标产生碰撞。接下来,我们要实现火球术的伤害功能,在火球术击中目标后,给目标造成伤害。 实现伤害功能的思路是给技能一个GameplayEffect&#x…

JAVA毕业设计136—基于Java+Springboot+Vue的房屋租赁管理系统(源代码+数据库)

毕设所有选题: https://blog.csdn.net/2303_76227485/article/details/131104075 基于JavaSpringbootVue的房屋租赁管理系统(源代码数据库)136 一、系统介绍 本项目前后端分离,分为管理员、用户、工作人员、房东四种角色 1、用户/房东: …

正态性检验

t检验、方差分析(ANOVA)等参数检验都有一个共同的前提条件:样本数据必须服从正态分布,即样本数据必须来源于一个正态分布的总体,若样本数据不服从正态分布,就不能用以上参数检验对数据进行分析,…

OpenCV鼠标绘制线段

鼠标绘制线段 // 鼠标回调函数 void draw_circle(int event, int x, int y, int flags, void* param) {cv::Mat* img (cv::Mat*)param;if (event cv::EVENT_LBUTTONDBLCLK){cv::circle(*img, cv::Point(x, y), 100, cv::Scalar(0, 0, 255), -1);} }// 鼠标回调函数 void dra…

.NET 个人博客-添加RSS订阅功能

个人博客-添加RSS订阅功能 前言 个人博客系列已经完成了 留言板文章归档推荐文章优化推荐文章排序 博客地址 然后博客开源的原作者也是百忙之中添加了一个名为RSS订阅的功能,那么我就来简述一下这个功能是干嘛的,然后照葫芦画瓢实现一下。 RSS简述…

SpringBoot+RabbitMQ实现MQTT协议通讯

一、简介 MQTT(消息队列遥测传输)是ISO 标准(ISO/IEC PRF 20922)下基于发布/订阅范式的消息协议。它工作在 TCP/IP协议族上,是为硬件性能低下的远程设备以及网络状况糟糕的情况下而设计的发布/订阅型消息协议,为此,它需要一个消息中间件 。此…