Transformer的代码实现 day03(Positional Encoding)

news2025/1/12 6:12:01

Positional Encoding的理论部分

  • 注意力机制是不含有位置信息,这也就表明:“我爱你”,“你爱我”这两者没有区别,而在现实世界中,这两者有区别。
  • 所以位置编码是在进行注意力计算之前,给输入加上一个位置信息,如下图:
    在这里插入图片描述
  • 位置编码的公式如下:
    • 注意,pos表示该单词在句子中的位置,i表示该单词的输入向量的第i维度
      在这里插入图片描述
  • 由此我们可以得出不同位置之间的位置编码关系:
    在这里插入图片描述

Positional Encoding代码

  • 由于位置编码的公式固定,所以对于相同位置的位置编码也固定,即“我爱你”中的我,和“你爱我”中的你的位置编码相同
  • 所以我们可以一次将所有要输入信息的位置编码都生成出来,之后需要哪个就传哪个
class PositionalEncoding(nn.Module):

    def __init__(self, dim, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
		# 确保每个单词的输入维度为偶数,这样sin和cos能配对
        if dim % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(dim))

        """
        构建位置编码pe
        pe公式为:
        PE(pos,2i/2i+1) = sin/cos(pos/10000^{2i/d_{model}})
        """
        pe = torch.zeros(max_len, dim)  # max_len 是解码器生成句子的最长的长度,假设是 10,dim为单词的输入维度
       
        # 将位置序号从一维变为只有一列的二维,方便与div_term进行运算,
        # 如将[0, 1, 2, 3, 4]变为:
        #[  
		#  [0],  
		#  [1],  
 		#  [2],  
		#  [3],  
		#  [4]  
		#]
        position = torch.arange(0, max_len).unsqueeze(1)
       
        # 这里使用a^b = e^(blna)公式,来简化运算
        # torch.arange(0, dim, 2, dtype=torch.float)表示从0到dim-1,步长为2的一维张量
        # 通过以下公式,我们可以得出全部2i的(pos/10000^2i/dim)方便接下来的pe计算
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                              -(math.log(10000.0) / dim)))
		# 得出的div_term为从0开始,到dim-1,长度为dim/2,步长为2的一维张量
		# 将position与div_term做张量乘法,得到的张量形状为(max_len,dim/2)
		# 将结果取sin赋给pe中偶数维度,取cos赋给pe中奇数维度
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        # 将pe的形状从(max_len,dim)变成(max_len,1,dim),在第二个维度上增加一个大小为1的新维度
        # 如从原始 pe 张量形状: (5, 4)  
		#[  
		# [a1, b1, c1, d1],  
		# [a2, b2, c2, d2],  
		# [a3, b3, c3, d3],  
		# [a4, b4, c4, d4],  
		# [a5, b5, c5, d5]  
		#]
		# 转换为:执行 unsqueeze(1) 后的 pe 张量形状: (5, 1, 4)  
		#[  
		# [[a1, b1, c1, d1]],  
		# [[a2, b2, c2, d2]],  
		# [[a3, b3, c3, d3]],  
		# [[a4, b4, c4, d4]],  
		# [[a5, b5, c5, d5]]  
		#]
        pe = pe.unsqueeze(1)
        # 将pe张量注册为模块的buffer。在PyTorch中,buffer是模型的一部分,但不包含可学习的参数(即不需要梯度)。
        # 这样做是因为位置编码在训练过程中是固定的,不需要更新。
        self.register_buffer('pe', pe)
        self.drop_out = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
		# 做乘法是因为在 Transformer 模型中,位置编码被加到输入张量中,而输入张量通常是词嵌入的向量,其值通常在较小的范围内。
		# 但是,在将位置编码添加到输入张量之前,我们希望将其值扩大到一个较大的范围,以便位置编码对输入的影响更加显著。
		# 注意:emb为输入张量,形状为(seq_len, dim),seq_len 表示输入的句子的长度,dim为单词的输入维度
        emb = emb * math.sqrt(self.dim)
		# 根据step来选择加入pe的哪一部分
        if step is None:
        # 如果pe的形状为(max_len, dim),那么pe[:a]表示:取pe的第0行到第a-1行的全部元素,得到的新二维张量的形状为(a, dim)
        # 而pe[:, a]表示:取pe的第a-1列的全部元素,得到的新一维张量的形状为(max_len)
        # 而pe[:, :a]表示:取pe的第0列到第a-1列的全部元素,得到的新二维张量的形状为(max_len,a)
            emb = emb + self.pe[:emb.size(0)]
        else:
            emb = emb + self.pe[step]
        emb = self.drop_out(emb)
        return emb

参考文献

  1. 04 Transformer 中的位置编码的 Pytorch 实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1565865.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

真·面试题总结——JVM虚拟机

JVM虚拟机 JVM虚拟机规范与实现 JVM虚拟机规范 JVM虚拟机实现 JVM的常见实现 JVM虚拟机物理架构 JVM虚拟机的运转流程 JVM类加载过程 JVM类加载器及类加载器类型 JVM类加载器双亲委派机制 JVM运行时数据区的内存模型 JVM运行时数据区的内存模型:程序计数器…

使用OpenCV4.9的随机生成器和文本

返回:OpenCV系列文章目录(持续更新中......) 上一篇:OpenCV 4.9基本绘图 下一篇:OpenCV系列文章目录(持续更新中......) 目标 在本教程中,您将学习如何: 使用随机数生…

Java中的可变字符串

Java中的可变字符串 一、什么是可变字符串二、可变字符串的使用场景以及使用步骤1.新建一个可变字符串2.可变字符串的一系列方法 一、什么是可变字符串 可变字符串是Java.lang包下的 在我们学习到JDBC的时候需要将原有的sql语句根据不同的差异添加一段新的关键字或者单词&…

C语言_第一轮笔记_指针

8.1 密码开锁 地址和指针 一般以变量所在的内存单元的第一个字节的地址作为他的地址NULL的值为0,代表空指针 指针变量的定义 类型名 *指针变量名类型名指定指针变量所指向变量的类型指针声明符*在定义指针变量时被使用,说明被定义的那个变量是指针指针变…

护眼台灯十大排名品牌有哪些?2024护眼台灯十大排名品牌推荐

在当今的教育环境中,学生们面临着相当沉重的学业压力。放学后,许多孩子便投入到无休止的作业之中,常常夜深人静时还未完成。作为家长,孩子的视力健康自然成为了我们心中的一块大石。夜间学习时,灯光的质量至关重要。标…

批量转换图片神器,支持tiff图片转换成png格式,图片高效转换

在数字图像处理领域,格式转换一直是关键且必要的环节。尤其对于设计师、摄影师、网站开发者等专业人士来说,能够快速、高效地将图片从一种格式转换为另一种格式,是提升工作效率和保障项目质量的关键。今天,我们荣幸地向您推荐一款…

低压配电室数字孪生实现区域内的无人值守

众所周知,电力设备的精益管控、精益检修与精益维护对于电网智慧化转型的重要性。因此数字孪生公司深圳华锐视点利用精湛的数字孪生、虚拟仿真、3D建模和图形图像技术,集成数据采集、监控预警、计划维护、数据分析、决策支持等核心模块,为电力…

从零开始构建gRPC的Go服务

介绍 Protocol Buffers and gRPC是用于定义通过网络有效通信的微服务的流行技术。许多公司在Go中构建gRPC微服务,发布了他们开发的框架,本文将从gRPC入门开始,一步一步构建一个gRPC服务。 背景 之前在B站看过一个gRPC教学视频,…

Linux进程控制(改)

Linux进程控制 进程 内核数据结构(struct task_struct,struct mm_struct,页表) 代码和数据 在Linux中fork函数时非常重要的函数,它从已存在进程中创建一个新进程。新进程为子进程,而原进程为父进程 1.进程创建 ./程序fork&am…

GPS坐标转换为百度地图坐标并显示到百度地图上

百度地图有个坐标识取系统:https://api.map.baidu.com/lbsapi/getpoint/index.html,打开链接如下: 如上图,可以搜索某一个位置,然后会出现该位置的许多选择,选择一个就会显示出对应的百度地图的坐标&#x…

抖音小店正确的起店方法是什么?别再闭门造车了,快来学习!

大家好,我是电商糖果 随着抖音卖货的火爆的,开抖音小店的商家也越来越多。 很多没有电商经验的朋友就发现,想要起店非常难。 有的好一两个月了,都不出单。 糖果做抖音小店有四年时间了,也经营了多家小店。 这里就…

从零开始:如何进入IT行业

微信扫码体验我自己做的小程序(很有意思哦~~【坏笑】): 随着科技的飞速发展,IT行业已经成为了许多人梦寐以求的职业之一。不过,对于那些没有任何相关经验或技能的人来说,进入这个领域…

坦克大战_java源码_swing界面_带毕业论文

一. 演示视频 坦克大战_java源码_swing界面_带毕业论文 二. 实现步骤 完整项目获取 https://githubs.xyz/y22.html 部分截图 启动类是 TankClinet.java,内置碰撞检测算法,线程,安全集合,一切皆对象思想等,是java进阶…

filetype: python中判断图像格式库imghdr替代库

引言 imghdr库是python中的一个内置库,用来判断图像原本格式的。自己一直有在用,不过近来看到这个库在python 3.13中会被移除。 自己感觉一直被python版本赶着走。这不找了好久,才找到一个替代库–filetype Python各个版本将要移除和可替代…

IoT数采平台4:测试

IoT数采平台1:开篇IoT数采平台2:文档IoT数采平台3:功能IoT数采平台4:测试 Modbus RTU串口测试 OPC测试 HTTP测试 MQTT透传测试 MQTT网关测试及数据上报 TCP / UDP 监听,客户端连上后发送信息,客户端上报数据…

C语言杂谈

努力扩大自己,以靠近,以触及自身以外的世界 文章目录 什么是定义?什么是声明?什么是赋值?什么是初始化?什么是生命周期?什么是作用域?全局变量?局部变量?size…

MySQL数据库(数据库连接池)

文章目录 1.批处理应用1.基本介绍2.批处理演示1.创建测试表2.修改url3.编写java代码 3.批处理源码分析 2.数据库连接池1.传统连接弊端分析2.数据库连接池基本介绍1.概念介绍2.数据库连接池示意图3.数据库连接池种类 3.C3P0连接池1.环境配置1.导入jar包2.将整个lib添加到项目中3…

云存储属性级用户撤销可追溯的密文策略属性加密方案论文阅读

参考文献为2018年发表的Traceable ciphertext-policy attribute-based encryption scheme with attribute level user revocation for cloud storage 贡献 本篇路提出了一个可追踪、实现属性级用户撤销(删除用户的某一属性)、支持密钥更新和密文更新、外…

图片二维码如何制作生成?常规图片格式的二维码制作技巧

图片是展示信息很常用的一种方式,而现在查看图片很多人会通过二维码的形式来展现,这种方式优势在于更加的灵活,能够通过一个二维码展示大量的图片内容。那么图片二维码是如何制作生成的呢? 想要快速的将图片转二维码使用&#xf…

mysql-FIND_IN_SET包含查询

如图所示,需要查询字段ancestorid中包含14的所有数据,使用FIND_IN_SET即可实现,不需要使用模糊查找like 示例sql: SELECT * FROM mt_fire_template WHERE FIND_IN_SET(14,ancestorid) 结果