transfomer中正余弦位置编码的源码实现

news2025/2/24 21:15:19

简介

Transformer模型抛弃了RNN、CNN作为序列学习的基本模型。循环神经网络本身就是一种顺序结构,天生就包含了词在序列中的位置信息。当抛弃循环神经网络结构,完全采用Attention取而代之,这些词序信息就会丢失,模型就没有办法知道每个词在句子中的相对和绝对的位置信息。因此,有必要把词序信号加到词向量上帮助模型学习这些信息,位置编码(Positional Encoding)就是用来解决这种问题的方法。
关于位置编码更多介绍参考bev感知专栏的博客

源码实现:

import torch
import matplotlib.pyplot as plt


def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)


def posemb_sincos_1d(len, dim, temperature: int = 10000, dtype=torch.float32):
    x = torch.arange(len)
    assert (dim % 2) == 0, "feature dimension must be multiple of 2 for sincos emb"
    omega = torch.arange(dim // 2) / (dim // 2 - 1)
    omega = 1.0 / (temperature ** omega)

    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos()), dim=1)  # 这里不用担心,不交叉无所谓,
    return pe.type(dtype)


if __name__ == '__main__':
    pos = posemb_sincos_1d(200, 256)
    # pos = posemb_sincos_2d(20,20,256)

    # 创建一个热力图
    plt.imshow(pos, cmap='hot', interpolation='nearest')
    # 添加颜色条
    plt.colorbar()
    # 显示图形
    plt.show()
    pass

可视化结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1391465.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

子域名的介绍及收集

1、子域名作用编辑 收集子域名可以扩大测试范围,同一域名下的二级域名都属于目标范围。 2、 常用方式编辑 子域名中的常见资产类型一般包括办公系统,邮箱系统,论坛,商城,其他管理系统,网站管理后台也有可…

获得店铺的所有商品 API、店铺列表api

taobao.item_search_shop 公共参数 名称类型必须描述keyString是调用key(必须以GET方式拼接在URL中)secretString是调用密钥api_nameString是API接口名称(包括在请求地址中)[item_search,item_get,item_search_shop等]cacheStrin…

transfomer的位置编码

什么是位置编码 在transformer的encoder和decoder的输入层中,使用了Positional Encoding,使得最终的输入满足: input_embeddingpositional_encoding 这里,input_embedding的shape为[n,b,embed_dim],positional_encoding和input_…

推荐一个页面引导库 driver.js

页面引导功能是 web 开发中常见的一个功能。通过页面引导功能,你可以让用户第一时间熟悉你的页面功能。今天给大家推荐一个页面引导库 driver.js。 简介 driver.js 是一款用原生 js 实现的页面引导库,上手非常简单,体积在 gzip 压缩下仅仅 5…

《手把手教你》系列技巧篇(十)-java+ selenium自动化测试-元素定位大法之By class name(详细教程)

1.简介 按宏哥计划,本文继续介绍WebDriver关于元素定位大法,这篇介绍By ClassName。看到ID,NAME这些方法的讲解,小伙伴们和童鞋们应该知道,要做好Web自动化测试,最好是需要了解一些前端的基本知识。有了前端…

DDOS攻击,一篇文章给你讲清!

1、互联网安全现状 随着网络世界的高速发展,各行业数字化转型也在如火如荼的进行。但由于TCP/IP网络底层的安全性缺陷,钓鱼网站、木马程序、DDoS攻击等层出不穷的恶意攻击和高危漏洞正随时入侵企业的网络,如何保障网络安全成为网络建设中的刚…

如何实现扫码填报信息,并且可以做统计和导出excel

日常工作中经常遇到需要收集信息的情况,如果能实现扫一下二维码,就可以直接填写信息,不用登录,不用开账号,填写完直接可以生成excel,那就非常好了。 我试用了很多平台,有的收费的,也…

VL53L5CX距离传感器

I2C接口的飞行时间多区测距传感器 意法半导体VL53L5CX是一款先进的飞行时间 (ToF) 多区域测距传感器 VL53L5CX 采用意法半导体最新一代的直接 ToF 技术,无论目标颜色和反射率如何,都可以进行绝对距离测量。它提供高达 400 cm的精确测距,并且…

transbigdata笔记:栅格参数优化

在transbigdata中,栅格参数有如下几个 params(lonStart,latStart,deltaLon,deltaLat,theta) 如何选择合适的栅格参数是很重要的事情,这会对最终的分析结果产生很大的影响。 怎么选择参数,和数据以及分析的目的息息相关,transbi…

C语言爬虫程序编写的爬取APP通用模板

互联网的飞快发展,尤其是手机终端业务的发展,让越来越多的事情都能通过手机来完成,电脑大部分的功能也都能通过手机实现,今天我就用C语言写一个手机APP类爬虫教程,方便后期拓展APP爬虫业务。而且这个模板是通用的适合各…

【PyTorch】在PyTorch中使用线性层和交叉熵损失函数进行数据分类

在PyTorch中使用线性层和交叉熵损失函数进行数据分类 前言: 在机器学习的众多任务中,分类问题无疑是最基础也是最重要的一环。本文将介绍如何在PyTorch框架下,使用线性层和交叉熵损失函数来解决分类问题。我们将以简单的Iris数据集作为起点…

Matlab交互式的局部放大图

在数据可视化中,很多时候需要对某一区间的数据进行局部放大,以获得对比度更高的可视化效果。下面利用 MATLAB 语言实现一个交互式的局部放大图绘制。 源码自行下载: 链接:https://pan.baidu.com/s/1yItVSinh6vU4ImlbZW6Deg?pwd9d…

使用 Python 创造你自己的计算机游戏(游戏编程快速上手)第四版:第十九章到第二十一章

十九、碰撞检测 原文:inventwithpython.com/invent4thed/chapter19.html 译者:飞龙 协议:CC BY-NC-SA 4.0 碰撞检测涉及确定屏幕上的两个物体何时相互接触(即发生碰撞)。碰撞检测对于游戏非常有用。例如,如…

iphone 5s的充电时序原理图纸,iPAD充电讲解

上一篇写了iphone 5的时序。那是电池供电的开机时序。iphone 5s也是差不多的过程,不说了。现在看iphone5s手机充电时候的时序。iphone5s充电比iphone5充电简单了很多。 首先是usb接口接到手机上,usb线连接到J7接口上。J7接口不只是接usb,还能…

ZooKeeper 实战(五) Curator实现分布式锁

文章目录 ZooKeeper 实战(五) Curator实现分布式锁1.简介1.1.分布式锁概念1.2.Curator 分布式锁的实现方式1.3.分布式锁接口 2.准备工作3.分布式可重入锁3.1.锁对象3.2.非重入式抢占锁测试代码输出日志 3.3.重入式抢占锁测试代码输出日志 4.分布式非可重入锁4.1.锁对象4.2.重入…

canvas绘制美队盾牌

查看专栏目录 canvas示例教程100专栏,提供canvas的基础知识,高级动画,相关应用扩展等信息。canvas作为html的一部分,是图像图标地图可视化的一个重要的基础,学好了canvas,在其他的一些应用上将会起到非常重…

项目管理十大知识领域之项目整体管理

1. 项目整体管理的定义和范畴 项目整体管理是指在整个项目生命周期中对项目进行全面规划、组织、协调、控制和监督的过程。这包括对项目目标、范围、时间、成本、质量、沟通、风险和采购等方面进行统一的管理和协调。项目整体管理的范畴涵盖了项目管理的方方面面,旨…

【特征工程】分类变量:MultiLabelBinarizer对多标签数据进行编码

MultiLabelBinarizer 说明介绍 1. MultiLabelBinarizer 是什么? MultiLabelBinarizer是scikit-learn库中的一个用于处理多标签数据的编码器。通常用于将多标签的分类任务中的标签转化为二进制形式,便于机器学习模型的处理。该编码器的主要目标是将每个…

leecode1011 | 在D天内送达包裹的能力

传送带上的包裹必须在 days 天内从一个港口运送到另一个港口。 传送带上的第 i 个包裹的重量为 weights[i]。每一天,我们都会按给出重量(weights)的顺序往传送带上装载包裹。我们装载的重量不会超过船的最大运载重量。 返回能在 days 天内将传…

压力测试+接口测试(工具jmeter)

jmeter是apache公司基于java开发的一款开源压力测试工具,体积小,功能全,使用方便,是一个比较轻量级的测试工具,使用起来非常简单。因 为jmeter是java开发的,所以运行的时候必须先要安装jdk才可以。jmeter是…