Transformer 中 Positional Encoding 实现

news2024/11/19 15:16:30

参考博文:

https://www.cnblogs.com/nickchen121/p/16470736.html

解决问题 

位置编码的主要目的是确保模型能够理解序列中的元素之间的相对位置和顺序,从而更好地捕捉到语义信息。在Transformer模型中,位置编码通常与词嵌入(word embeddings)相加,以形成模型的输入表示。这有助于模型在处理序列数据时更好地理解元素的位置和顺序,从而提高其性能,特别是在自然语言处理任务中。

原理

这里就是拿经典款transformer举例了

这个i是维度,2i这块是告诉你是sin还是cos的,是0~dimension/2

详细过程:

sin(pos+k) = sin(pos)*cos(k)+cos(pos)*sin(k) #sin表示偶数维度

cos(pos+k) = cos(pos)cos(k) +sin(pos)sin(k) #cos表示奇数维度

!pos+k可是pos和k的线性组合!

例如

pos+K=5, 当我计算第五个单词的位置编码时:

pos=1, k=4; pos=2, k=3;

这样就可以得知几个位置之间的相对关系

代码实现

Transformer

一维绝对的位置编码

def create_1d_absolute_sincos_embeddings(n_pos_vec,dim):
    assert dim % 2 == 0, "wrong dimension" # dim must be even
    # 初始化position embedding
    position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float) #numel()返回数组元素个数
    # omega是对i进行遍历
    omega = torch.arange(dim//2, dtype=torch.float) #//是整除
    omega /= dim/2.
    omega = 1./(10000**omega)

    out = n_pos_vec[:, None]@omega[None, :] # 先把n_pos_vec变成列向量,一个维度加上None相当于扩了一维;接下来是把omega拓成一个行向量, @是矩阵乘法
    emb_sin = torch.sin(out)
    emb_cos = torch.cos(out)

    # 接下来是偶数位用sin赋值,奇数位用cos去赋值
    position_embedding[:, 0::2] = emb_sin
    position_embedding[:, 1::2] = emb_cos

    return position_embedding

if __name__ == '__main__':
    n_pos = 4
    dim = 4
    n_pos_vec = torch.arange(n_pos, dtype=torch.float)
    print(n_pos_vec)
    pe = create_1d_absolute_sincos_embeddings(n_pos_vec,dim)
    print("pe", pe)

Vision Transformer

一维的,绝对的,可训练的

这里用的也是一维的位置编码因为论文里做了实验表明二维的位置编码对模型效果并没有提升

def create_1d_absolute_trainable_embeddings(n_pos_vec,dim):
    # 传入索引
    # n_pos_vec: torch.aramge(n_pos, dtype=torch.float)
    # 因为可学习所以用nn.embedding来实现
    position_embedding = nn.Embedding(n_pos_vec.numel(), dim)
    # 初始化weight(parameter class)
    nn.init.constant_(position_embedding.weight, 0.)
    return position_embedding  # 一维的,绝对的,可学习的embedding

 Swin Transformer

二维的,相对的,基于位置偏差的

相对位置,可学习

def create_2d_relative_bias_trainable_embeddings(n_head,height,width,dim):
    # embeddings的行数就是bias的个数,列数就是num_heads
    # 横轴取值 width:5[0,1,2,3,4] bias ={-width+1, width-1 }{-4,4} 4-(-4)+1 = 9
    # 纵轴取值 height:5[0,1,2,3,4] bias ={-height+1, height-1} 1-(-1)+1 = 3
    position_embedding = nn.Embedding((2*width-1)*(2*height-1), n_head)
    # 初始化weight(parameter class)
    nn.init.constant_(position_embedding.weight, 0.)
    # 获取window中二维的,两两之间的位置偏差
    # step1:算出横轴和纵轴各自的位置偏差,用网格法把横轴的位置索引和纵轴的位置索引定义出来
    def get_2d_relative_position_index(height, width):
        m1, m2 = torch.meshgrid(torch.arange(height), torch.arange(width)) # m1行一样,m2列一样
        coords = torch.stack([m1, m2]) # 把m1和m2拼接起来,dim=-1表示最后一个维度 #2*height*width
        coords_flatten = torch.flatten(coords,1) # 把coords压缩成一维,dim=1表示第一个维度,得到2*【height*width】
        ralative_coords_bias = coords_flatten[:, :, :None]- coords_flatten[:, None, :]#得到网格里任意两点横轴纵轴坐标的差值,[2,height*width,height*width]
        # 把它们都变成正数
        ralative_coords_bias[0, :, :] += height-1 # 横轴坐标的差值,0代表高度维
        ralative_coords_bias[1, :, :] += width-1 # 纵轴坐标的差值 1代表宽度维
        # 把两个方向上的坐标转化成一个方向上的坐标,类似于把一个2dtensor赋值到1dtensor
        # A;2d,B:1d B[i*cols+j] = A[i,j]
        ralative_coords_bias[0,:,:] += ralative_coords_bias[1, :, :].max()+1 # 把横轴坐标的差值转化成一维坐标,即i*cols
        # 相对位置索引
        return ralative_coords_bias.sum(0) # [height*width,height*width] # 两个方向上的坐标相加,得到相对位置索引
    relative_position_bias = get_2d_relative_position_index(height, width) # [height*width,height*width]
    bias_embedding = position_embedding(torch.flatten(relative_position_bias)).reshape(height*width,height*width,n_head) # [height*width,height*width,n_head]
    bias_embedding.permute(2,0,1).unsqueeze(0) # [1, n_head,height*width,height*width]
    return bias_embedding # 二维的,相对的,可学习的embedding

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1095053.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端小知识之【浏览器内核】

目录 🌟前言🌟PC端浏览器内核🌟Trident内核🌟Gecko内核🌟WebKit内核(Chromium)🌟Blink内核 🌟移动端浏览器内核🌟应用🌟写在最后 🌟前言 通常所谓的浏览器内…

docker安装nessus

注册地址:https://zh-tw.tenable.com/products/nessus/nessus-essentials 临时邮箱:http://24mail.chacuo.net/ 帮助文档:https://docs.tenable.com/nessus/Content/DeployNessusDocker.htmdocker pull tenableofficial/nessusdocker run --name "my-nessus" -d -p 8…

【Go入门】编程语言比较:Golang VS Python

Golang:最佳人工智能语言,性能优于 Python 本节是学习go的引入,为了了解Python与go编程语言间比较。后续会完成相关课程,并分享笔记。 如今,世界各地有数百万用户使用 Golang 作为机器学习和人工智能的编程语言。 最好…

算法通过村第十四关-堆|白银笔记|经典问题

文章目录 前言在数组中寻找第K大的元素堆排序原理合并K个排序链表总结 前言 提示:想要从讨厌的地方飞出来,就得有藏起来的翅膀。 --三岛由纪夫《萨德侯爵夫人》 这里我们主要看一下经典的题目,这三个题目来说都是堆的热点问题。重点再理解处理…

Qt不能安装自己想要的版本,如Qt 5.15.2

使用在线安装工具安装Qt5.15.2时,发现没有Qt 5的相关版本,只有Qt 6的版本,这时选择右边的Archive,再点击筛选,这时就会出现之前的Qt版本。

vscode插件路径转移C盘之外盘

改变vscode系统路径 最近C盘路径不够了,网上的工具使用没那么精细,还不如自己手动看每个文件夹大小。在整理过长遇到vscode插件路径转移,方法如下: 桌面图标右键点击属性 改变–extensions-dir后面参数就可以了。

Web3 整理React项目 导入Web3 并获取区块链信息

上文 WEB3 创建React前端Dapp环境并整合solidity项目,融合项目结构便捷前端拿取合约 Abi 我们用react 创建了一个 dapp 项目 并将前后端代码做了个整合 那么 我们就来好好整理一下 我们的前端react的项目结构 我们在 src 目录下创建一个 components 用来存放我们的…

Python学习----Day07

函数 函数是组织好的,可重复使用的,用来实现单一,或相关联功能的代码段。函数能提高应用的模块性,和代码的重复利用率。你已经知道Python提供了许多内建函数,比如print()。但你也可以自己创建函数,这被叫做…

C++ 程序员入门需要多久,怎样才能学好?

文章目录 C学习方案有哪些推荐的在线教程或学习资源可以帮助我学习C编程?你能给我一些关于C内存管理的进阶学习资源吗? AI解答 C学习方案 C是一种功能强大且广泛应用的编程语言,作为一个初学者,学习C需要一定的时间和努力。学习…

【java学习—七】对象的实例化过程(33)

文章目录 1. 简单类对象的实例化过程2. 子类对象的实例化过程 1. 简单类对象的实例化过程 2. 子类对象的实例化过程

YOLO目标检测——打电话数据集【含对应voc、coco和yolo三种格式标签】

实际项目应用:安全监控、智能驾驶、人机交互、智能城市数据集说明:YOLO目标检测数据集,真实场景的高质量图片数据,数据场景丰富。使用lableimg标注软件标注,标注框质量高,含voc(xml)、coco(json)和yolo(txt…

快速学习MyBatisPlus

文章目录 前言一、条件构造器和常用接口1.wapper介绍2.QueryWrapper(1)组装查询条件(2)组装排序查询(3)组装删除查询(4)条件优先级(5)组装select子句&#xf…

c++命名空间,缺省参数,引用

首先为了解决命名冲突,c提出了命名空间这一功能 比如using namespace std; 就是使用std(c官方库定义的命名空间)这个命名空间里面的命名。 using就可以直接指定本文件用那个命名空间。 也可以用::域作用限定符 如std::cin>> 并且会…

Linux网络编程系列之服务器编程——信号驱动模型

一、什么是信号驱动模型 在服务器中,信号驱动模型是一种事件处理模型,它能够异步地响应来自外部的事件。服务器可以注册一组回调函数,来处理来自客户端或其他进程的信号或事件,当信号或事件触发时,操作系统会通知服务器…

云耀服务器L实例部署Nextcloud企业云盘系统|华为云云耀云服务器L实例评测使用体验

文章目录 Nextcloud简介1.1 部署华为云云耀服务器L实例1.1.1 云耀服务器L实例购买1.1.2 云耀服务器L实例初始化配置1.1.3 远程登录云耀服务器L实例 2. 云耀服务器L实例中间件部署2.1 安装配置环境2.1.1 安装基本工具2.1.2 安装MariaDB2.1.3 安装Nginx2.1.4 安装PHP 3. 安装Next…

网络通信协议-HTTP、WebSocket、MQTT的比较与应用

在今天的数字化世界中,各种通信协议起着关键的作用,以确保信息的传递和交换。HTTP、WebSocket 和 MQTT 是三种常用的网络通信协议,它们各自适用于不同的应用场景。本文将比较这三种协议,并探讨它们的主要应用领域。 HTTP&#xff…

【VSCode】Windows环境下,VSCode 搭建 cmake 编译环境(通过配置文件配置)

除了之前的使用 VSCode 插件来编译工程外,我们也可以使用配置文件来编译cmake工程,主要依赖 launch.json 和 tasks.json 文件。 目录 一、下载编译器 1、下载 Windows GCC 2、选择编译器路径 二、配置 debug 环境 1、配置 lauch.json 文件 2、配置…

FPGA project : flash_write

本实验重点学习了: flash的页编程指令pp。 在写之前要先进行擦除(全擦除和页擦除); 本实验:先传写指令,然后进入写锁存周期,然后传页编程指令,3个地址; 然后传数据&a…

分布式事务协调中间件---seata快速入门

分布式事务 Seata,之前叫做Fescar,是一个开源的分布式事务解决方案,它主要致力于提供高效和简单的分布式事务服务。Seata主要用于解决微服务架构下的数据一致性问题。 Seata 的基本原理是基于两阶段提交 (2PC) 以及三阶段提交 (3PC)&#xff…

Linux 学习的六个过程

Linux 上手难,学习曲线陡峭,所以它的学习过程更像一个爬坡模式。这些坡看起来都很陡,但是一旦爬上一阶,就会一马平川。 1、抛弃旧的思维习惯,熟练使用 Linux 命令行 在 Linux 中,无论我们做什么事情&…