LLM:RoPE位置编码

news2024/9/24 11:27:05

论文:https://arxiv.org/pdf/2104.09864.pdf

代码:https://github.com/ZhuiyiTechnology/roformer

发表:2021

 绝对位置编码:其常规做法是将位置信息直接加入到输入中(在x中注入绝对位置信息)在计算 query, key 和 value 向量之前会计算一个位置编码向量 p_{i}先加到词嵌入 x_{i}上,然后再乘以对应的变换矩阵 W:       

而经典的位置编码PE的计算方式是采用上篇文章提到的Sinusoidal 函数。

  • 优势: 实现简单,可预先计算好,不用参与训练,速度快。

  • 劣势: 没有外推性,即如果预训练最大长度为512的话,那么最多就只能处理长度为512的句子,再长就处理不了了。当然,也可以将超过512的位置向量随机初始化,然后继续微调。

相对位置编码:相对位置并没有完整建模每个输入的位置信息,而是在算Attention的时候考虑当前位置与被Attention的位置的相对距离,由于自然语言一般更依赖于相对位置,所以相对位置编码通常也有着优秀的表现。   (在k, v中注入相对位置信息)                                                          

  • 优势: 直接地体现了相对位置信号,效果更好。具有外推性,处理长文本能力更强。

RoPE旋转式位置编码(Rotary Position Embedding)

1:RoPE通过绝对位置编码的方式实现相对位置编码,综合了绝对位置编码和相对位置编码的优点。

2:主要就是对attention中的q, k向量注入了绝对位置信息,然后用更新的q,k向量做attention中的内积就会引入相对位置信息了。

旋转位置编码RoPE 是目前大模型中广泛使用的一种位置编码,包括但不限于Llama、Baichuan、ChatGLM、Qwen等。

RoPE推导

图解RoPE旋转位置编码及其特性

十分钟读懂旋转编码(RoPE)

Transformer升级之路:2、博采众长的旋转式位置编码 - 科学空间|Scientific Spaces

复杂的推导公式,可以这3个链接进行理解。着重看一下公式11和公式13:有助于理解后面的RoPE的代码实现。 

 结合论文的图例:

1:对于 token 序列中的每个词嵌入向量,首先计算其对应的 query 和 key 向量
2:然后对每个 token 位置都计算对应的旋转位置编码
3:接着对每个 token 位置的 query 和 key 向量的元素按照 两两一组 应用旋转变换
4:最后再计算 query 和 key 之间的内积得到 self-attention 的计算结果

RoPE代码实现 

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
    # (max_len, 1)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
    # (output_dim//2)
    ids = torch.arange(0, output_dim // 2, dtype=torch.float)  # 即公式里的i, i的范围是 [0,d/2]
    theta = torch.pow(10000, -2 * ids / output_dim)

    # (max_len, output_dim//2)
    embeddings = position * theta  # 即公式里的:pos / (10000^(2i/d))

    # (max_len, output_dim//2, 2)
    embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)

    # (bs, head, max_len, output_dim//2, 2)
    embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))  # 在bs维度重复,其他维度都是1不重复

    # (bs, head, max_len, output_dim)
    # reshape后就是:偶数sin, 奇数cos了
    embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))
    embeddings = embeddings.to(device)
    return embeddings


def RoPE(q, k):
    # q,k: (bs, head, max_len, output_dim)
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    max_len = q.shape[2]
    output_dim = q.shape[-1]

    # (bs, head, max_len, output_dim)
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)


    # cos_pos,sin_pos: (bs, head, max_len, output_dim)
    # 看rope公式可知,相邻cos,sin之间是相同的,所以复制一遍。如(1,2,3)变成(1,1,2,2,3,3)
    cos_pos = pos_emb[...,  1::2].repeat_interleave(2, dim=-1)  # 将奇数列信息抽取出来也就是cos 拿出来并复制
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)  # 将偶数列信息抽取出来也就是sin 拿出来并复制

    # q,k: (bs, head, max_len, output_dim)
    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
    q2 = q2.reshape(q.shape)  # reshape后就是正负交替了



    # 更新qw, *对应位置相乘
    q = q * cos_pos + q2 * sin_pos

    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
    k2 = k2.reshape(k.shape)
    # 更新kw, *对应位置相乘
    k = k * cos_pos + k2 * sin_pos

    return q, k


def attention(q, k, v, mask=None, dropout=None, use_RoPE=True):
    # q.shape: (bs, head, seq_len, dk)
    # k.shape: (bs, head, seq_len, dk)
    # v.shape: (bs, head, seq_len, dk)

    if use_RoPE:
        q, k = RoPE(q, k)

    d_k = k.size()[-1]

    att_logits = torch.matmul(q, k.transpose(-2, -1))  # (bs, head, seq_len, seq_len)
    att_logits /= math.sqrt(d_k)

    if mask is not None:
        att_logits = att_logits.masked_fill(mask == 0, -1e9)  # mask掉为0的部分,设为无穷大

    att_scores = F.softmax(att_logits, dim=-1)  # (bs, head, seq_len, seq_len)

    if dropout is not None:
        att_scores = dropout(att_scores)

    # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
    return torch.matmul(att_scores, v), att_scores


if __name__ == '__main__':
    # (bs, head, seq_len, dk)
    q = torch.randn((8, 12, 10, 32))
    k = torch.randn((8, 12, 10, 32))
    v = torch.randn((8, 12, 10, 32))

    res, att_scores = attention(q, k, v, mask=None, dropout=None, use_RoPE=True)


    # (bs, head, seq_len, dk),  (bs, head, seq_len, seq_len)
    print(res.shape, att_scores.shape)


参考

1:https://zhuanlan.zhihu.com/p/641274061 

2:十分钟读懂旋转编码(RoPE)-腾讯云开发者社区-腾讯云

3:LLM学习记录(五)--超简单的RoPE理解方式 - 知乎 

4:一文通透位置编码:从标准位置编码、旋转位置编码RoPE到ALiBi、LLaMA 2 Long_alibi位置编码-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1397641.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mysql原理--锁

1.解决并发事务带来问题的两种基本方式 上一章唠叨了事务并发执行时可能带来的各种问题,并发事务访问相同记录的情况大致可以划分为3种: (1). 读-读 情况:即并发事务相继读取相同的记录。 读取操作本身不会对记录有一毛钱影响,并不…

爬虫笔记(一):实战登录古诗文网站

需求:登录古诗文网站,账号+密码+图形验证码 第一:自己注册一个账号+密码哈 第二:图形验证码,需要一个打码平台(充钱,超能力power!)或…

“GPC爬虫池有用吗?

作为光算科技的独有技术,在深入研究谷歌爬虫推出的一种吸引谷歌爬虫的手段 要知道GPC爬虫池是否有用,就要知道谷歌爬虫这一概念,谷歌作为一个搜索引擎,里面有成百上千亿个网站,对于里面的网站内容,自然不可…

网络爬虫采集工具

在当今数字化的时代,获取海量数据对于企业、学术界和个人都至关重要。网络爬虫成为一种强大的工具,能够从互联网上抓取并提取所需的信息。本文将专心分享关于网络爬虫采集数据的全面指南,深入探讨其原理、应用场景以及使用过程中可能遇到的挑…

php array_diff 比较两个数组bug避坑 深入了解

今天实用array_diff出现的异常问题,预想的结果应该是返回 "integral_initiate">"0",实际没有 先看测试代码: $a ["user_name">"测","see_num">0,"integral_initiate&quo…

leetCode-42.接雨水

📑前言 本文主要是【算法】——算法模拟的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是听风与他🥇 ☁️博客首页:CSDN主页听风与他 🌄每日一句&#xff…

Go语言基础快速上手

1、Go语言关键字 2、Go数据类型 3、特殊的操作 3.1、iota关键字 Go中没有明确意思上的enum(枚举)定义,不过可以借用iota标识符实现一组自增常亮值来实现枚举类型。 const (a iota // 0b // 1c 100 // 100d // 100 (与上一…

H - Least Common Multiple H - 最小公倍数

题目 The least common multiple (LCM) of a set of positive integers is the smallest positive integer which is divisible by all the numbers in the set. For example, the LCM of 5, 7 and 15 is 105. 一组正整数的最小公倍数 (LCM) 是最小的正整…

rabbitmq的介绍、使用、案例

1.介绍 rabbitmq简单来说就是个消息中间件,可以让不同的应用程序之间进行异步的通信,通过消息传递来实现解耦和分布式处理。 消息队列:允许将消息发到队列,然后进行取出、处理等操作,使得生产者和消费者之间能够解耦&…

使用 Kali Linux Hydra 工具进行攻击测试和警报生成

一、Hydra 工具和 Kali Linux 简介 在网络安全领域中,渗透测试是评估系统密码强度的重要组成部分。Hydra 是一款由黑客组织“The Hackers Choice”开发的开源登录破解工具,支持50多种协议。本教程将探索如何将 Hydra 与 Kali Linux 结合使用&#xff0c…

快快销ShopMatrix 分销商城多端uniapp可编译5端 - 升级申请(可自定义申请表单)

在企业或组织中,升级申请通常涉及到员工职位、权限、设备或者其他资源的提升或更新。创建一个可自定义的升级申请表单可以帮助更高效地收集和处理这类申请信息。以下是一个基本的步骤: 确定表单字段: 申请人信息:姓名、部门、职位…

【Spring Boot 3】【Redis】基本数据类型操作

【Spring Boot 3】【Redis】基本数据类型操作 背景介绍开发环境开发步骤及源码工程目录结构 背景 软件开发是一门实践性科学,对大多数人来说,学习一种新技术不是一开始就去深究其原理,而是先从做出一个可工作的DEMO入手。但在我个人学习和工…

priority_queue的使用与模拟实现(容器适配器+stack与queue的模拟实现源码)

priority_queue的使用与模拟实现 引言(容器适配器)priority_queue的介绍与使用priority_queue介绍接口使用默认成员函数 size与emptytoppush与pop priority_queue的模拟实现构造函数size与emptytoppush与pop向上调整建堆与向下调整建堆向上调整建堆向下调…

UE5 蓝图编辑美化学习

虚幻引擎中干净整洁蓝图的15个提示_哔哩哔哩_bilibili 1.双击线段成节点。 好用,爱用 2.用序列节点 好用,爱用 3.用枚举。 好用,能避免一些的拼写错误 4.对齐节点 两点一水平线 5.节点上下贴节点 (以前不懂,现在经常…

【AJAX框架】AJAX入门与axios的使用

文章目录 前言一、AJAX是干什么的?二、AJAX的安装2.1 CDN引入2.2 npm安装 三、基础使用3.1 CDN方式3.2 node方式 总结 前言 在现代Web开发中,异步JavaScript和XML(AJAX)已经成为不可或缺的技术之一。AJAX使得网页能够在不刷新整个…

SQL注入实战操作

一:SQl注入分类 按照注入的网页功能类型分类: 1、登入注入:表单,如登入表单,注册表单 2、cms注入:CMS逻辑:index.php首页展示内容,具有文章列表(链接具有文章id)、articles.php文 章详细页&a…

NX二次开发封装自己的函数及如何导入工程

目录 一、概述 二、函数封装 三、函数引用 四、案例——在NX中运行后输出“测试”两字 一、概述 随着对NX二次开发的学习,我们在各种项目里面会积累很多函数,对于一些经常用到的函数,我们可以考虑将其封装为类库,以后在开发其…

从请购到结算,轻松搞定!云迈ERP系统助力企业采购管理全流程!

​在企业的运营过程中,采购管理是至关重要的环节之一。为了确保采购流程的顺畅和高效,许多企业选择引入ERP(企业资源规划)系统来进行采购管理。那么erp系统中的采购管理都有哪些功能呢? 云迈erp系统中的采购管理模块&a…

虚拟机下载docker

一,Docker简介 百科说:Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的容器中,然后发布到任何流行的Linux机器上,也可以实现虚拟化,容器是完全使用沙箱机制&#xff…

大数据导论(2)---大数据与云计算、物联网、人工智能

文章目录 1. 云计算1.1 云计算概念1.2 云计算的服务模式和类型1.3 云计算的数据中心与应用 2. 物联网2.1 物联网的概念和关键技术2.2 物联网的应用和产业2.3 大数据与云计算、物联网的关系 1. 云计算 1.1 云计算概念 1. 首先从商业角度给云计算下一个定义:通过网络…