LLaMA中ROPE位置编码实现源码解析

news2025/1/24 22:32:46

1、Attention中q,经下式,生成新的q。m为句长length,d为embedding_dim/head
θ i = 1 1000 0 2 i d \theta_i=\frac{1}{10000^\frac{2i}{d}} θi=10000d2i1
在这里插入图片描述

2、LLaMA中RoPE源码

import torch

def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0):
    '''
    计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx
    :param dim: q,k,v的最后一维,一般为emb_dim/head_num
    :param end: 句长length
    :param constant: 这里指10000
    :return:
    复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t))
    '''
    # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta
    # 形式化为 [theta_0, theta_1, ..., theta_d/2]
    freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2]

    # 计算m
    t = torch.arange(end, device=freqs.device)  # [length]
    # 计算m*theta
    freqs = torch.outer(t, freqs).float()  # [length, d]
    # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_d/2],其中 m=0,1,...,length-1

    # 计算cos(m*theta)+j*sin(m*theta)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0),  cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_d/2-1)+j*sin(m*theta_d/2-1)]
    # 其中j为虚数单位, m=0,1,...,length-1
    return freqs_cis # [length, d]

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2)
    return freqs_cis.view(*shape) # [1, length, 1, d/2]

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor,):
    # 先将xq维度变为[bs, length, head,  d/2, 2], 利用torch.view_as_complex转变为复数
    # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2]
    # 同样的,xk_:[k0+j*k1, k2+j*k3, ..., k(d-2)+j*k(d-1)]
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2]
    # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下
    # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0))
    # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0)
    # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部
    # 最后经flatten函数将维度拉平,即[bs, length, head, d]
    # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d]
    # 即为新生成的q

    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

if __name__=='__main__':
    # (bs, length, head, d)
    q = torch.randn((2, 10, 12, 32))  # q=[q0, q1, .., qd-1]
    k = torch.randn((2, 10, 12, 32))
    v = torch.randn((2, 10, 12, 32))
    freqs_cis= precompute_freqs_cis(dim=32, end=10, constant= 10000.0)
    # print(freqs_cis.detach().numpy())
    
    q_new, k_new = apply_rotary_emb(xq=q, xk=k, freqs_cis=freqs_cis)
    print()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/924403.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

自动化测试之Selenium

自动化测试Selenium介绍环境搭建如何操作浏览器定位元素css类选择器定位元素xpath定位元素css选择语法xpath选择语法 常用操作添加等待打印信息浏览器更多操作键盘事件鼠标事件特殊场景只选复选框iframe标签下拉框处理弹窗显示上传文件 关闭浏览器切换窗口截图 自动化测试 自动…

ubuntu20.04 直接安装vpp23.06 测试双 VPP Tunnel Ike2

环境信息&#xff1a;VMware Workstation 17 Pro ubuntu20.04 (清华源) ubuntu 源点进去选&#xff1a;ubuntu-22.04.3-desktop-amd64.iso 如果之前装过VPP&#xff0c;用以下命令确定是否卸载干净&#xff1a; dpkg -l | grep vpp dpkg -l | grep DPDK 卸载&#xff1a; …

ChatGPT + Flutter快速开发多端聊天机器人App

下载地址&#xff1a;ChatGPT Flutter快速开发多端聊天机器人App 下载地址&#xff1a;ChatGPT Flutter快速开发多端聊天机器人App

SpringCloud学习笔记(五)_Consul注册中心

本章使用的Consul版本是 1.7.2 项目架构图如下&#xff1a; 搭建服务提供者 1、新建一个maven项目&#xff08;test-springcloud-provider-payment8006&#xff09; 结构如下&#xff1a; 2、引入依赖&#xff0c;编辑pom文件 1 <!-- spring-cloud 整合 consul --> 2…

MyBatis分页与特殊字符处理

文章目录 一、分页1.1 分页插件PageHelper1.2 使用1.2.1 导入pom依赖1.2.2 Mybatis.cfg.xml配置拦截器1.2.3. 配置 Mapper.xml1.2.4 测试 二、特殊字符处理2.1 使用CDATA区段2.2 使用实体引用 一、分页 1.1 分页插件PageHelper PageHelper 是 Mybatis 的一个插件。官网 Page…

SOLIDWORKS提高装配效率的方法:阵列特征驱动阵列

相信SOLIDWORKS用户很喜欢SOLIDWORKS阵列命令&#xff0c;因为阵列可以提高设计效率&#xff0c;减少错误&#xff0c;修改也很方便&#xff0c;但是大家一般在零件里用阵列来阵列特征&#xff0c;在装配体里用阵列来阵列零件&#xff0c;那有没有办法用零件中的阵列特征来驱动…

开源与专有软件:比较与对比

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…

173935-46-1/Fmoc-L-Ser(Ac3-L-Fucα)-OH的保存建议

Fmoc-L-Ser((Ac)3-α-Fuc)-OH是一种支链N-连接寡糖的碱性单糖片段&#xff0c;该片段在糖蛋白中发现。该片段由α-L-fucose和N-acetylated α-L-fucose通过乙酰基连接而成&#xff0c;并被包含在一个Fmoc-L-serine分子中。西安凯新生物科技有限公司的这种支链寡糖结构在生物体内…

linux安装JDK及hadoop运行环境搭建

1.linux中安装jdk &#xff08;1&#xff09;下载JDK至opt/install目录下&#xff0c;opt下创建目录soft&#xff0c;并解压至当前目录 tar xvf ./jdk-8u321-linux-x64.tar.gz -C /opt/soft/ &#xff08;2&#xff09;改名 &#xff08;3&#xff09;配置环境变量&#xf…

【Java架构-包管理工具】-Maven进阶(二)

本文摘要 Maven作为Java后端使用频率非常高的一款依赖管理工具&#xff0c;在此咱们由浅入深&#xff0c;分三篇文章&#xff08;Maven基础、Maven进阶、私服搭建&#xff09;来深入学习Maven&#xff0c;此篇为开篇主要介绍Maven进阶知识&#xff0c;包含坐标、依赖、仓库、生…

pgadmin4中的备份与恢复

一&#xff0c;postgresql 数据的备份与恢复 &#xff08;一&#xff09;数据库备份与恢复 1&#xff0c;备份 windows环境 1> dump 逻辑备份 1&#xff0c;用管理员身份打开power shell 2&#xff0c;切换到本机 postgresql 安装目录下的 bin 目录&#xff1a; PS C…

目标检测笔记(十一):如何在特定区域进行人脸检测实操

文章目录 背景代码 背景 由于我们在做项目的时候可能会涉及到某个指定区域进行目标检测或者人脸识别等任务&#xff0c;所以这篇博客是为了探究如何在传统目标检测的基础上来结合特定区域进行检测&#xff0c;以Opencv自带的包为例。 代码 import cv2face_cascade cv2.Casc…

云服务器 宝塔(每次更新)

su root 输入密码 使用 root 权限 /etc/init.d/bt default 获取宝塔登录 位置和账号密码。进入宝塔 删除数据库 删除php前端站点 删除PM2后端项目 前端更改完配置打包dist文件 后端更改完配置项目打包 数据库结构导出 导入数据库 配置 PM2 后端 安装依赖

在百度地图中添加自定义全屏控件

百度地图中添加全屏控件 前置知识&#xff1a; 进入整个页面的全屏模式 &#xff1a;document.documentElement.requestFullscreen() 进入特定元素的全屏模式 &#xff1a; document.getElementById("ID").requestFullscreen() 退出全屏&#xff1a;document.exitFu…

HTML番外篇(四)-HTML5新增元素-CSS常见函数-理解浏览器前缀-BFC

一、HTML5新增元素 1.HTML5语义化元素 在HMTL5之前&#xff0c;我们的网站分布层级通常包括哪些部分呢&#xff1f; header、nav、main、footer ◼ 但是这样做有一个弊端&#xff1a; 我们往往过多的使用div, 通过id或class来区分元素&#xff1b;对于浏览器来说这些元素不…

django自动创建model数据

目前使用的环境&#xff1a;django4.2.3&#xff0c;python3.10 django通过一些第三方库&#xff0c;可以轻易的自动生成一系列的后台数据。 首先先创建一个数据库&#xff1a; 然后&#xff0c;在setting.py中就可以指定我们新创建的数据库了。 DATABASES {default: {ENGI…

基于Pytorch的神经网络部分自定义设计

一、基础概念&#xff08;学习笔记&#xff09; &#xff08;1&#xff09;训练误差和泛化误差[1] 本质上&#xff0c;优化和深度学习的目标是根本不同的。前者主要关注的是最小化目标&#xff0c;后者则关注在给定有限数据量的情况下寻找合适的模型。训练误差和泛化误差通常不…

【前端】深入解析CSS:选择器、显示模式、背景属性和特征剖析

目录 一、前言二、CSS的复合选择器1、后代选择器①、语法②、注意事项 2、子选择器①、语法②、注意事项 3、并集选择器①、语法②、注意事项 4、链接伪类选择器①、语法②、注意事项 三、CSS元素显示模式转换1、转换为块元素display:block2、转换为行内元素display:inline3、转…

AIGC ChatGPT 制作地图可视化分析

地图可视化分析是一种将数据通过地图的形式进行展示的方法&#xff0c;可以让人们更加直观、快速、准确的理解和分析数据。以下是地图可视化分析的一些主要好处&#xff1a; 加强数据理解&#xff1a;地图可视化可以将抽象的数字转化为直观的图形&#xff0c;帮助我们更好地理解…

LLMs对单个任务进行微调Fine-tuning on a single task

虽然LLM因其在单一模型内执行多种不同语言任务的能力而变得出名&#xff0c;但您的应用程序可能只需要执行单一任务。在这种情况下&#xff0c;您可以微调一个预训练的模型&#xff0c;以仅提高您感兴趣的任务的性能。例如&#xff0c;使用该任务的示例数据集进行摘要。有趣的是…