Transformer模型:Postion Embedding实现

news2024/11/16 4:29:18

前言

        这是对上一篇WordEmbedding的续篇PositionEmbedding。

视频链接:19、Transformer模型Encoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili

上一篇链接:Transformer模型:WordEmbedding实现-CSDN博客


正文

        先回顾一下原论文中对Position Embedding的计算公式:pos表示位置,i表示维度索引,d_model表示嵌入向量的维度,position分奇数列和偶数列。

        Position Embedding也是二维的,行数是训练的序列最大长度,列是d_model。首先定义position的最大长度,这里定为12,也就是训练中的长度最大值都是12。

max_position_len = 12

        这里先循环遍历得到pos,构造Pos序列,pos是从0到最大长度的遍历,决定行:

pos_mat = torch.arange(max_position_len)

        但是此时得到的是一维的,我们要将它转为二维矩阵的,也就是得到目标行数,使用.reshape()函数,这样就构造好了行矩阵

pos_mat = torch.arange(max_position_len).reshape((-1,1))

tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11]]) 

        接下来要构造列矩阵,构造 i 序列,首先是是2i/d_model部分,这里的8是因为我们设定的d_model=8,2是步长:

i_mat = torch.arange(0, 8, 2)/model_dim

        这时候再把分母的完整形式实现,幂次使用pow()函数:

i_mat = torch.pow(10000, torch.arange(0, 8, 2)/model_dim)

tensor([   1.,   10.,  100., 1000.]) 

         此时就得到了列向量,这时候就有疑问了为什么列只有4列,我们的d_model不是8吗,应该有8列才对啊。这是因为区分了奇数列跟偶数列的计算,所以这里才要求步长为2生成的只有4列。

        先初始化一个max_position_len*model_dim的零矩阵(12*8),然后再分别使用sin和cos填充偶数列和奇数列:

pe_embedding_table = torch.zeros(max_position_len, model_dim)

pe_embedding_table[:, 0::2] = torch.sin(pos_mat/i_mat)   # 从第0列到结束,步长为2,也就是填充偶数列
pe_embedding_table[:, 1::2] = torch.cos(pos_mat/i_mat)   # 从第1列到结束,步长为2,也就是填充奇数列

        得到的就是Position Embedding的权重矩阵了: 

        这下面采用的是使用nn.Embedding()的方法,得到的跟上面的结果还是一样的,只不过这里的pe_embedding是可以传入位置的,之后的调用就是这样得到的:

pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table,requires_grad=False)

         这里就要构造位置索引了:

src_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in tgt_len]).to(torch.int32)

        然后传入位置索引,就得到了src跟tgt的Position Embedding:

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)

         这里我很疑惑的点是生成的结果src_pe_embedding跟tgt_pe_embedding内容是一样的,并且单个里面的一个内容也就是position embedding,刚入门听得我还是有点不太能理解。

src_pos is:
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]], dtype=torch.int32)
tgt_pos is:
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]], dtype=torch.int32)
src_pe_embedding is:
 tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01],
         [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
           9.9875e-01,  5.0000e-03,  9.9999e-01],
         [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
           9.9820e-01,  6.0000e-03,  9.9998e-01],
         [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
           9.9755e-01,  6.9999e-03,  9.9998e-01],
         [ 9.8936e-01, -1.4550e-01,  7.1736e-01,  6.9671e-01,  7.9915e-02,
           9.9680e-01,  7.9999e-03,  9.9997e-01],
         [ 4.1212e-01, -9.1113e-01,  7.8333e-01,  6.2161e-01,  8.9879e-02,
           9.9595e-01,  8.9999e-03,  9.9996e-01],
         [-5.4402e-01, -8.3907e-01,  8.4147e-01,  5.4030e-01,  9.9833e-02,
           9.9500e-01,  9.9998e-03,  9.9995e-01],
         [-9.9999e-01,  4.4257e-03,  8.9121e-01,  4.5360e-01,  1.0978e-01,
           9.9396e-01,  1.1000e-02,  9.9994e-01]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01],
         [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
           9.9875e-01,  5.0000e-03,  9.9999e-01],
         [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
           9.9820e-01,  6.0000e-03,  9.9998e-01],
         [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
           9.9755e-01,  6.9999e-03,  9.9998e-01],
         [ 9.8936e-01, -1.4550e-01,  7.1736e-01,  6.9671e-01,  7.9915e-02,
           9.9680e-01,  7.9999e-03,  9.9997e-01],
         [ 4.1212e-01, -9.1113e-01,  7.8333e-01,  6.2161e-01,  8.9879e-02,
           9.9595e-01,  8.9999e-03,  9.9996e-01],
         [-5.4402e-01, -8.3907e-01,  8.4147e-01,  5.4030e-01,  9.9833e-02,
           9.9500e-01,  9.9998e-03,  9.9995e-01],
         [-9.9999e-01,  4.4257e-03,  8.9121e-01,  4.5360e-01,  1.0978e-01,
           9.9396e-01,  1.1000e-02,  9.9994e-01]]])
tgt_pe_embedding is:
 tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01],
         [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
           9.9875e-01,  5.0000e-03,  9.9999e-01],
         [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
           9.9820e-01,  6.0000e-03,  9.9998e-01],
         [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
           9.9755e-01,  6.9999e-03,  9.9998e-01],
         [ 9.8936e-01, -1.4550e-01,  7.1736e-01,  6.9671e-01,  7.9915e-02,
           9.9680e-01,  7.9999e-03,  9.9997e-01],
         [ 4.1212e-01, -9.1113e-01,  7.8333e-01,  6.2161e-01,  8.9879e-02,
           9.9595e-01,  8.9999e-03,  9.9996e-01],
         [-5.4402e-01, -8.3907e-01,  8.4147e-01,  5.4030e-01,  9.9833e-02,
           9.9500e-01,  9.9998e-03,  9.9995e-01],
         [-9.9999e-01,  4.4257e-03,  8.9121e-01,  4.5360e-01,  1.0978e-01,
           9.9396e-01,  1.1000e-02,  9.9994e-01]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01],
         [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
           9.9875e-01,  5.0000e-03,  9.9999e-01],
         [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
           9.9820e-01,  6.0000e-03,  9.9998e-01],
         [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
           9.9755e-01,  6.9999e-03,  9.9998e-01],
         [ 9.8936e-01, -1.4550e-01,  7.1736e-01,  6.9671e-01,  7.9915e-02,
           9.9680e-01,  7.9999e-03,  9.9997e-01],
         [ 4.1212e-01, -9.1113e-01,  7.8333e-01,  6.2161e-01,  8.9879e-02,
           9.9595e-01,  8.9999e-03,  9.9996e-01],
         [-5.4402e-01, -8.3907e-01,  8.4147e-01,  5.4030e-01,  9.9833e-02,
           9.9500e-01,  9.9998e-03,  9.9995e-01],
         [-9.9999e-01,  4.4257e-03,  8.9121e-01,  4.5360e-01,  1.0978e-01,
           9.9396e-01,  1.1000e-02,  9.9994e-01]]])

 代码

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
 
# 句子数
batch_size = 2
 
# 单词表大小
max_num_src_words = 10
max_num_tgt_words = 10
 
# 序列的最大长度
max_src_seg_len = 12
max_tgt_seg_len = 12
max_position_len = 12
 
# 模型的维度
model_dim = 8
 
# 生成固定长度的序列
src_len = torch.Tensor([11, 9]).to(torch.int32)
tgt_len = torch.Tensor([10, 11]).to(torch.int32)
print(src_len)
print(tgt_len)
 
#单词索引构成的句子
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)),(0, max_src_seg_len-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L,)),(0, max_tgt_seg_len-L)), 0) for L in tgt_len])
print(src_seq)
print(tgt_seq)
 
# 构造Word Embedding
src_embedding_table = nn.Embedding(max_num_src_words+1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words+1, model_dim)
src_embedding = src_embedding_table(src_seq)  
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding_table.weight)    
print(src_embedding)    
print(tgt_embedding)

# 构造Pos序列跟i序列
pos_mat = torch.arange(max_position_len).reshape((-1, 1))    
i_mat = torch.pow(10000, torch.arange(0, 8, 2)/model_dim)

# 构造Position Embedding
pe_embedding_table = torch.zeros(max_position_len, model_dim)    
pe_embedding_table[:, 0::2] = torch.sin(pos_mat/i_mat)
pe_embedding_table[:, 1::2] = torch.cos(pos_mat/i_mat)
print("pe_embedding_table is:\n",pe_embedding_table)

pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table,requires_grad=False)
print(pe_embedding.weight)

# 构建位置索引
src_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in tgt_len]).to(torch.int32)
print("src_pos is:\n",src_pos)
print("tgt_pos is:\n",tgt_pos)

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)
print("src_pe_embedding is:\n",src_pe_embedding)
print("tgt_pe_embedding is:\n",tgt_pe_embedding)

参考

Python的reshape的用法:reshape(1,-1)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1923657.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何巧妙运用百川工作手机微信行为监控 防范员工离职带走客户

在竞争日益激烈的商业环境中,企业最宝贵的资产莫过于忠诚的客户群体与高效协作的团队。然而,当团队中不可避免地出现人员流动时,如何有效防止客户资源流失,成为众多企业管理者面临的严峻挑战。百川工作手机,作为一款专…

基于Redisson 实现 Redis 分布式锁

代码示例: GetMapping("/testJmeter")public void testJmeter() {synchronized (this){int stock Integer.parseInt(stringRedisTemplate.opsForValue().get("stock"))if (stock > 0) {int realStock stock - 1;stringRedisTemplate.opsFo…

【组件库】element-plus组件库

文章目录 0. 启动项目1. gc.sh 新增组件2. 本地验证(组件注册的方式)3. 官方文档修改3-1. 左侧菜单3-2 . 配置md文档3-3. 代码问题:文档修改----------------------------------------------4. 将naiveui的split 分割组件【 复制、迁移】到 element-ui-plus组件库4.1 naiveu…

Science Advances 仿生双模态触觉感知

研究背景 触觉感知在人类收集信息和接收周围环境反馈中扮演着至关重要的角色。随着人工智能的发展,具有类似人类感知能力的智能机器人受到越来越多的关注。现有的触觉传感器能够感知接触前的刺激和压力大小,但它们在区分物体类型、评估柔软度和量化杨氏…

go-高效处理应用程序数据

一、背景 大型的应用程序为了后期的排障、运营等,会将一些请求、日志、性能指标等数据保存到存储系统中。为了满足这些需求,我们需要进行数据采集,将数据高效的传输到存储系统 二、问题 采集服务仅仅针对某个需求开发,需要修改…

Docker容器的生命周期

引言 Docker 容器作为一种轻量级虚拟化技术,在现代应用开发和部署中扮演着重要角色。理解容器的生命周期对于有效地管理和运维容器化应用至关重要。本文将深入探讨 Docker 容器的生命周期,从创建到销毁的各个阶段,帮助读者更好地掌握容器管理…

分手后如何走出夜晚的抑郁,告别失眠困扰?

在这个快速变化的世界里,分手成为了许多人生活中不得不面对的现实。而每当夜幕降临,那种难以言表的孤独感和深深的抑郁往往让人倍感煎熬,甚至陷入失眠的漩涡。那么,分手后我们该如何应对这种情绪困扰,重新找回自己的宁…

防火墙NAT和智能选路实验详解(华为)

目录 实验概述实验拓扑实验要求要求一要求二要求三要求四要求五 实验概述 从我上面一个博客能够了解到NAT和防火墙选路原理 ——>防火墙nat和智能选路,这一章我通过实验来详解防火墙关于nat和智能选路从而能熟练使用和配置防火墙,这里使用的是华为US…

lvs集群、NAT模式和DR模式、keepalive

目录 lvs集群概念 集群的类型:三种类型 系统可靠性指标 lvs集群中的术语 lvs的工作方式 NAT模式 lvs的工具 算法 实验 数据流向 步骤 一 、调度器配置(test1 192.168.233.10) 二、RS配置(nginx1和nginx2)…

Android:如何绘制View

点击查看Android 如何绘制视图官网 一、简介 Android 框架会在 Activity 获得焦点时请求 Activity 绘制其布局。Android 框架会处理绘制流程,但该 Activity 必须提供其布局层次结构的根节点。 Android 框架会绘制布局的根节点,并测量和绘制布局树。它会…

【每日一练】python类和对象现实举例详细讲解

""" 本节课程目的: 1.掌握类描述现实世界实物思想 2.掌握类和对象的关系 3.理解什么事面向对象 """ #比如设计一个闹钟,在这里就新建一个类 class Clock:idNone #闹钟的序列号,也就是类的属性priceNone #闹…

Redis学习笔记(个人向)

Redis学习笔记(个人向) 1. 概述 是一个高性能的 key-value 数据库;其具有以下三个特点: Redis支持数据的持久化,可以将内存中的数据保存在磁盘中,重启的时候可以再次加载进行使用。Redis不仅仅支持简单的key-value类型的数据&…

Nginx+Keepalive调度的高可用

nginxkeepalive: 调度器的高可用 vip地址主备之间的切换,主在工作时,p地址只在主上,主停止工作,p飘移到备服务器。 在主备的优先级不变的情况下,主恢复工作,vp会飘回到主服务器。 1、配优先级 2、配置…

EventBus学习

视频:05_尚硅谷_EventBus_粘性事件案例_哔哩哔哩_bilibili 1.整体框架 2.demo下载地址:https://github.com/greenrobot/EventBus 3.实现非粘性时间流程: 3.1导入架包eventbus-3.0.0.jar和eventbus-3.0.0-sources.jar 3.2在接受数据页面注…

k8s(五)---名称空间

五、名称空间 名称空间是k8s划分不同工作空间的逻辑单位,是k8s资源逻辑隔离的机,。可以给不同的租户,不同的环境、不同的项目创建对应的命名空间。 1、查看名称空间 kubectl get ns kubectl get namespaces 此处展示了四个命名空间 2、管理名称空间 1…

更新商品前端接口编写

文章目录 新增页面书写写表单价格符号的显示然后状态的书写后端枚举书写时间书写使用组件 新增页面书写 书写直接复制页面 写表单的绑定信息 然后绑定表单 表单绑定还有表单数据的绑定 标签中ref的作用就是将 该组件注册到vue对象的ref属性中 那么在vue运行的时候,会加载所…

IOC、DI<4> Unity、AOP、MVCAOP、UnityAOP 区别

IOC():控制反转,把程序上层对下层的依赖,转移到第三方的容器来装配 是程序设计的目标,实现方式包含了依赖注入和依赖查找(.net里面只有依赖注入) DI:依赖注入&#xff0c…

【网络文明】关注网络安全

在这个数字化时代,互联网已成为我们生活中不可或缺的一部分,它极大地便利了我们的学习、工作、娱乐乃至日常生活。然而,随着网络空间的日益扩大,网络安全问题也日益凸显,成为了一个不可忽视的全球性挑战。认识到网络安…

Gitee简易使用流程(后期优化)

目录 1.修改用户名 2.文件管理 新建文件/文件夹流程如下: 上传文件流程如下: 以主页界面为起点 1.修改用户名 点解右上角的头像--> 点击“账号设置” 点击左边栏里的“个人资料“ 直接修改用户名即可 2.文件管理 选择一个有修改权限仓库&#…

【轻松拿捏】Java-final关键字(面试)

目录 1. 定义和基本用法 回答要点: 示例回答: 2. final 变量 回答要点: 示例回答: 3. final 方法 回答要点: 示例回答: 4. final 类 回答要点: 示例回答: 5. final 关键…