LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录

news2024/11/24 20:47:37

前言

最近,开源了可商用的llama2,支持长度相比llama1的1024,拓展到了4096长度,然而,相比GPT-4、Claude-2等支持的长度,llama的长度外推显得尤为重要,本文记录了三种网络开源的RoPE改进方式及相关源码的阅读。

关于长度外推性:https://kexue.fm/archives/9431

关于RoPE:https://kexue.fm/archives/8265

1、线性插值法

论文:EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION

链接:https://arxiv.org/pdf/2306.15595.pdf

思想:不进行长度外推,而是直接缩小位置索引。即:将4096的位置编码通过线性插值法压缩到2048内,这样只需在少量的4096长度的数据上继续预训练,便可达到不错的效果。

在这里插入图片描述

源码阅读(附注释)

class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
        super().__init__()
        # 相比RoPE增加scale参数
        self.scale = scale
        # inv_freq为基值向量
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        # 构建max_seq_len_cached大小的张量t
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # 张量t归一化,RoPE没有这一步
        t /= self.scale
        # einsum计算频率矩阵
        # 'i, j->i j’表示分别输入尺寸为[i]、[j]的向量,做笛卡尔运算得到尺寸为[i, j]的矩阵。
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # 在-1维做一次拷贝、拼接
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        # 注册为模型的缓冲区cos_cached和sin_cached
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        # seq_len为序列长度,seq_len大于max_seq_len_cached,则重新计算频率矩阵,并更新cos_cached和sin_cached的缓冲区
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            t /= self.scale
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        # 长度裁剪:返回cos_cached和sin_cached中与seq_len(序列长度)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

线性插值法的相关实验效果:https://lmsys.org/blog/2023-06-29-longchat/

2、NTK插值法

NTK插值改进llama中使用的RoPE插值方法,同样,对于RoPE代码改动更小,其他地方与线性插值法实现一致。

reddit原帖:NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation

链接:https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/?rdt=58346

源码阅读:

class LlamaNTKScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
        super().__init__()
        # 与线性插值法相比,实现更简单,alpha仅用来改变base
        base = base * alpha ** (dim / (dim-2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

3、动态插值法

动态插值法又是对NTK插值法和线性插值法的改进,可以看作是上述两者的一种结合思想,旨在减少困惑度损失并实现更大的缩放。

reddit原帖:Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning

链接:https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/

源码阅读

class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
        super().__init__()
        # 是否开启NTK(Neural Tangent Kernel)
        self.ntk = ntk
        self.base = base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # inv_freq为基值向量
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # emb:[max_seq_len_cached, dim]
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            if self.ntk:
                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))
                # 计算新的inv_freq
                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
                self.register_buffer("inv_freq", inv_freq)
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            if not self.ntk:
                # 缩放
                t *= self.max_position_embeddings / seq_len
            # 得到新的频率矩阵freqs
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            # freqs与自身拼接得到新的emb
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            # 注册为模型的缓冲区cos_cached和sin_cached
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)

        # 长度裁剪
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

网友对于困惑度的实验并取得了一定的效果:https://github.com/turboderp/exllama/pull/118

总结

本文介绍了llama通过线性插值法及相关改进方案进行长度外推的trcik,并对相关源码阅读及网络资源进行记录,个人粗浅认为,相比LongLLaMA,基于线性插值法+Finetune的方式,是一种高性价比的长度外推方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/860443.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用appuploader怎么安装测试​

使用appuploader怎么安装测试​ 一.安装测试​ 首先我们来看安装测试这个模块,注意按照上面提示内容操作。 点击首页的测试设备管理 二.选择IPA​ 进入“安装测试”页面,选择一个(必须是开发类型描述文件编译,且描述文件包含设…

f1tenth仿真设置

文章目录 一、安装依赖二、进入工作空间克隆三、编译四、运行 一、安装依赖 tf2_geometry_msgs ackermann_msgs joy map_server sudo apt-get install ros-noetic-tf2-geometry-msgs ros-noetic-ackermann-msgs ros-melodic-joy ros-noetic-map-server 二、进入工作空间克隆…

详解Linux文本三剑客

目录 一、grep 1.什么是grep? 2.如何使用? 3.正则 二、sed 1.认识sed? 2.如何使用? 三、awk(重点) 1.awk变量 1.1内置变量 1.2自定义变量 2.awk数组 四、经典实战案例 案例一:筛选IPv4地址 案例二&am…

为MySQL新增一张performance_schema表 | StoneDB 技术分享会 #4

StoneDB开源地址 https://github.com/stoneatom/stonedb 设计:小艾 审核:丁奇、李浩 编辑:宇亭 作者:王若添 中国科学技术大学-软件工程-在读硕士、StoneDB 内核研发实习生 performance_schema 简介 MySQL 启动后会自动创建四…

基于SpringBoot+LayUI的宿舍管理系统 001

项目简介 源码来源于网络,项目文档仅用于参考,请自行二次完善哦。 系统以MySQL 8.0.23为数据库,在Spring Boot SpringMVC MyBatis Layui框架下基于B/S架构设计开发而成。 系统中的用户分为三类,分别为学生、宿管、后勤。这三…

【MySQL常见面试题】

索引的基本原理 索引⽤来快速地寻找那些具有特定值的记录。如果没有索引,⼀般来说执⾏查询时遍历整张表。 索引的原理:就是把⽆序的数据变成有序的查询 把创建了索引的列的内容进⾏排序 对排序结果⽣成倒排表 在倒排表内容上拼上数据地址链 在查询的…

计算机网络:网络通信相关概念入门

目录 一、网络发展背景二、理解网络通信三、理解IP地址1.简述IP地址2.IP地址的版本3.提高地址利用率的技术 四、理解端口1.简述端口2.使用端口的原因 五、理解网络通信协议 一、网络发展背景 网络发展背景: 最初的计算机是单机,那么单机是这样传输数据的…

谁是5G应用的狮子座?

8月5日,广和通公布2023年上半年财报,2023上半年总营业收入38.65亿元,同比增长59.87%。在全球经济存在诸多不确定性的背景下,广和通精准预判市场风向,制定稳健发展策略,业绩仍保持逆势增长。其中&#xff0c…

详谈基于布局分析的表格识别方法

基于布局分析的OCR(Optical Character Recognition)是一种基于页面布局信息的文本识别方法。传统的OCR系统通常依赖于表格线或者特定的格式来进行文本区域检测和字符识别,但对于一些表格线不全或线不清晰,甚至没表格线&#xff0c…

协程(一)单机--》并发--》协程

目录 一 协程的概述1.1 并行与并发1.2 线程1.3 新的思路1.4 Goroutine 二 第一个入门程序 一 协程的概述 我查看了网上的一些协程的资料,发现每个人对协程的概念都不一样,但是我认可的一种说法是:协程就是一种轻量级的线程框架(K…

去趋势化一个心电图信号、信号功率谱、低通IIR滤波器并平滑信号、对滤波器引起的延迟进行补偿研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

【C语言进阶篇】关于指针的八个经典笔试题(图文详解)

🎬 鸽芷咕:个人主页 🔥 个人专栏:《C语言初阶篇》 《C语言进阶篇》 ⛺️生活的理想,就是为了理想的生活! 文章目录 📋 前言💬 指针笔试题💭 笔试题 1:✅ 代码解析⁉️ 检验结果&…

二叉树(4)------收尾

1)最大二叉树 654. 最大二叉树 - 力扣(LeetCode) 题目解析: 1)首先我们找到了整个数组中最大的元素作为我们的根节点,然后再从左区间中找到最大的元素作为当前根节点的左子树,然后再从右区间里面找到最大的元素作为根节点的右子树…

【OpenVINOSharp】 基于C#和OpenVINO2023.0部署Yolov8全系列模型

基于C#和OpenVINO2023.0部署Yolov8全系列模型 1 项目简介1.1 OpenVINOTM 2 OpenVinoSharp2.1 OpenVINOTM 2023.0安装配置2.2 C 动态链接库2.3 C#构建Core推理类2.4 NuGet安装OpenVinoSharp 3 获取和转换Yolov8模型3.1 安装ultralytics3.2 导出yolov8模型3.3 安装OpenVINOTM Pyt…

杭电多校 Rikka with Square Numbers 费马平方和定理

👨‍🏫 Rikka with Square Numbers 🧀 参考题解 🍻 AC code import java.util.Scanner;public class Main {static boolean isSqu(int x){int t (int) Math.sqrt(x);return t * t x;}public static void main(String[] args…

Vue2:组件基础(上)

Vue2:组件基础(上) Date: July 29, 2023 Sum: 生命周期、Vue-cli、组件的使用、小黑记账清单、小兔鲜首页 生命周期: 生命周期介绍 思考: 什么时候可以发送初始化渲染请求?(越早越好&#x…

Aviator这么丝滑,怎么实现的呢?

大家好,我是老三,在上期 里我们介绍了轻量级规则引擎AviatorScript的基本用法和一些使用案例,这期我们来研究一下,这么丝滑的规则脚本是怎么实现的。 概览 我们先来回顾一个简单的例子: Testpublic void test(){//表…

一道RSA题(忘了名字)-云上贵州-网络安全攻防竞赛个人赛

题目: 很遗憾,这道题当时没做出来。 话不多说,直接开始,题目只给了一个式子,这里就命名为hint: 最开始我的想法是化为模q的形式,也就是: 然后就变成了: 接着就一直卡在这…

国产数据库排行

目录 一、理论 1.国产数据库排行 2.数据 一、理论 1.国产数据库排行 (1)墨天轮榜单 墨天轮国产数据库流行度排行于2019年6月推出,通过近50个维度的数据来考察近300个国产数据库的流行度排行,每月1日更新排行数据&#xff0c…

js 使用 Object.defineProperty() 对属性进行限制 06

小夏小夏,可爱到爆炸 🤣 💕💕💕 文章目录 一、对属性操作的控制二、属性描述符三、数据属性描述符四、存取属性描述符五、vue2 响应式原理六、defineProerties 同时定义多个属性七、对象方法补充 一、对属性操作的控制…