【论文笔记】LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models

news2025/1/7 16:04:30

🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


基本信息

标题: LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models
作者: Yukang Chen, Shengju Qian, Haotian Tang, Xin Lai, Zhijian Liu, Song Han, Jiaya Jia
发表: ICLR 2024
arXiv: https://arxiv.org/abs/2309.12307

基本信息

摘要

我们提出了LongLoRA,一种高效的微调方法,它通过有限的计算成本扩展了预训练大型语言模型(LLM)的上下文大小。

通常,使用长上下文大小训练LLM在计算上非常昂贵,需要大量的训练时间和GPU资源。例如,在 8192 8192 8192 个上下文长度的训练中,自注意力层的计算成本是 2048 2048 2048 个上下文长度的 16 16 16 倍。

在本文中,我们从两个方面加速了LLM上下文扩展。

一方面,尽管在推理过程中需要密集的全局注意力,但通过稀疏局部注意力可以有效地进行模型微调。提出的移位稀疏注意力(S2-Attn)有效地实现了上下文扩展,与使用标准注意力微调具有相似的性能,同时实现了显著的计算节省。特别是,它可以在训练中仅用两行代码实现,而在推理中是可选的。

另一方面,我们重新审视了参数高效的上下文扩展微调机制。值得注意的是,我们发现LoRA在可训练嵌入和归一化的前提下,对于上下文扩展效果良好。LongLoRA将这种改进的LoRA与 S 2 S^2 S2-Attn相结合。

LongLoRA在Llama2模型(从7B/13B到70B)的各种任务上展示了强大的实证结果。LongLoRA将Llama2 7B的上下文从4k扩展到100k,或将Llama2 70B扩展到32k,在单个8×A100机器上完成。

LongLoRA在保持原始架构的同时扩展了模型的上下文,并且与大多数现有技术兼容,如Flash-Attention2。

此外,我们还使用LongLoRA和我们的长指令遵循LongAlpaca数据集进行了监督微调。

我们所有的代码、模型、数据集和演示代码都可在github.com/dvlab-research/LongLoRA上找到。

简介

LongLoRA closes the accuracy gap that between conventional LoRA and full fine-tuning, while still maintaining up to 1.8× lower memory cost than full fine-tuning. Furthermore, LongLoRA improves the training speed of LoRA by up to 1.8× with S2-Attn. Llama2-7B are fine-tuned to various context lengths with Flash-Attention2 (Dao, 2023) and DeepSpeed (Rasley et al., 2020) stage 2 and evaluated on the proof-pile (Azerbayev et al., 2022) test set in perplexity.

LongLoRA缩小了传统LoRA和全量微调之间的精度差距,同时保持了比全量微调低1.8倍的内存成本。此外,LongLoRA通过 S 2 S^2 S2-Attn 将LoRA的训练速度提高了高达1.8倍。Llama2-7B使用Flash-Attention2和DeepSpeed的第二阶段进行微调,并在 proof-pile 测试集上评估了困惑度。

Illustration of S2-Attn. It involves three steps. First, it splits features along the head dimension into two chunks. Second, tokens in one of the chunks are shifted by half of the group size. Third, we split tokens into groups and reshape them into batch dimensions. Attention only computes in each group in ours while the information flows between groups via shifting. Potential information leakage might be introduced by shifting, while this is easy to prevent via a small modification on the attention mask. We ablate this in the variant 2 in Section B.3 in the appendix.

S 2 S^2 S2-Attn的示意图涉及三个步骤。首先,它将特征沿头部维度分为两个部分。其次,其中一个部分中的token向右移动了组大小的一半。第三,我们将token分成组,并将它们重塑为批量维度。在我们的模型中,注意力仅在每组中计算,而信息通过移动在组之间流动。移动可能会引入潜在的信息泄露,但通过在注意力掩码上进行微小修改可以轻松防止。

LongLoRA

背景

Transformer

大型语言模型(LLMs)通常是基于 Transformer 构建的。例如,以 Llama2 为例,一个 LLM 模型由一个嵌入输入层和若干解码器层组成。每个解码器层包含一个自注意力模块。它通过带有权重矩阵 { W q , W k , W v } \{W_q, W_k, W_v\} {Wq,Wk,Wv} 的线性投影层将输入特征映射为一组查询、键和值 { q , k , v } \{q, k, v\} {q,k,v}。给定 { q , k , v } \{q, k, v\} {q,k,v},它计算输出 o o o

o = softmax ( q k T ) v o = \text{softmax}(qk^T)v o=softmax(qkT)v

输出随后通过一个权重矩阵 W o W_o Wo 的线性层进行投影,接着是多层感知机(MLP)层。在自注意力模块之前和之后,会应用层归一化。所有解码器层完成后,还会进行一次最终归一化。

对于较长的序列,自注意力在计算成本方面表现出困难,其计算复杂度与序列长度成平方关系。这大幅减慢了训练过程,并增加了 GPU 内存的使用成本。

Low-rank Adaptation

LoRA假设预训练模型中的权重更新在适配期间具有较低的内在秩(intrinsic rank)。对于一个预训练权重矩阵 W ∈ R d × k W \in \mathbb{R}^{d \times k} WRd×k,它通过低秩分解 W + Δ W = W + B A W + \Delta W = W + BA W+ΔW=W+BA 来更新,其中 B ∈ R d × r B \in \mathbb{R}^{d \times r} BRd×r A ∈ R r × k A \in \mathbb{R}^{r \times k} ARr×k。秩 r ≪ min ⁡ ( d , k ) r \ll \min(d, k) rmin(d,k)。在训练期间, W W W 被冻结(没有梯度更新),而 A A A B B B 是可训练的。这就是 LoRA 训练比完全微调更高效的原因。

在 Transformer 结构中,LoRA 仅适配注意力权重 { W q , W k , W v , W o } \{W_q, W_k, W_v, W_o\} {Wq,Wk,Wv,Wo},并冻结所有其他层,包括 MLP 和归一化层。这种方式简单且参数高效。然而,我们通过实验证明,仅在注意力权重中的低秩适配并不能很好地适用于长上下文扩展任务。

Shifted Sparse Attention

标准的自注意力计算成本为 O ( n 2 ) \mathcal{O}(n^2) O(n2),这使得长序列上的LLM具有高内存成本和低速。为了在训练期间避免这一问题,我们提出了移位稀疏注意力( S 2 S^2 S2-Attn),如图2所示。接下来,我们将进行一项初步研究,并逐步解释我们的设计。

Overview of LongLoRA. We introduce Shifted Sparse Attention (S2-Attn) during finetuning. The trained model retains original standard self-attention at inference time. In addition to training LoRA weights in linear layers, LongLoRA further makes embedding and normalization layers trainable. This extension is pivotal for context extension, and only introduces a minimal number of additional trainable parameters.

Pilot Study

在表1中,我们建立了一个标准基线,该基线经过完整注意力和微调训练和测试,在各种上下文长度下表现出一致的良好质量。第一次试验是使用短注意力进行训练,仅模式1如图2所示。正如我们所知,在长上下文中,高昂的成本主要来自自注意力模块。因此,在这次试验中,由于输入很长,我们在自注意力中将其分为几个组。例如,模型在训练和测试阶段都以8192个token作为输入,但在每个组中进行自注意力操作,组大小为2048,组数为4。这种模式效率很高,但在非常长的上下文中仍然不起作用,如表1所示。随着上下文长度的增加,困惑度变大。其背后的原因是没有不同组之间的信息交换。

Effectiveness of S2-Attn under different context lengths

为了引入组之间的通信,我们包括了一个移位模式,如图2所示。我们在半注意力头中将组分区移位半个组大小。以总体8192个上下文长度为例,在模式1中,第一组从第1个到第2048个token进行自注意力。在模式2中,组分区移位1024。第一个注意力组从第1025个开始到第3072个token结束,而前1024个和最后1024个token属于同一组。我们在每个半自注意力头中分别使用模式1和模式2。这种方式不会增加额外的计算成本,但能够实现不同组之间的信息流。我们在表1中展示了它接近标准注意力基线的结果。

Consistency to Full Attention

现有的高效注意力设计也可以提高长上下文LLM的效率。然而,大多数这些设计并不适合长上下文微调。因为这些从头开始训练的Transformer与预训练中使用的标准全注意力存在差距。在表6中,我们展示了 S 2 S^2 S2-Attn 不仅能够实现高效的微调,还支持全注意力测试。尽管其他注意力机制也可以用于长上下文微调,但模型必须使用微调期间使用的注意力进行测试。移位防止了模型对特定注意力模式的过度拟合。

Easy Implementation

S 2 S^2 S2-Attn易于实现。它仅涉及两个步骤:

  1. 在半注意力头中移位token;
  2. 将特征从token维度转置到批次维度。

两行代码就足够了。我们在算法1中提供了一个PyTorch风格的代码示例。

Algorithm 1: Pseudocode of S2-Attn in PyTorch-like style.

Improved LoRA for Long Context

LoRA是一种高效且流行的将LLMs适应其他数据集的方法。与全微调相比,它节省了大量的可训练参数和内存成本。然而,将LLMs从短上下文长度适应到长上下文长度并不容易。我们观察到LoRA与全微调之间存在明显的差距。如表2所示,随着目标上下文长度的增加,LoRA与全微调之间的差距逐渐增大。并且,具有更大秩的LoRA无法缩小这个差距。

Finetuning normalization and embedding layers is crucial for low-rank long-context adaptation

为了弥合这一差距,我们为训练打开了嵌入层和归一化层。如表2所示,它们占用的参数有限,但对长上下文适应有显著效果。特别是对于归一化层,参数在整个Llama2 7B中仅占0.004%。在实验中,我们将这种改进版的LoRA称为LoRA+。

实验

主实验

Perplexity evaluation on proof-pile (Rae et al., 2020) test split

Maximum context length that we can fine-tune for various model sizes on a single 8× A100 machine

Topic retrieval evaluation with LongChat (Li et al., 2023)

Accuracy comparison on passkey retrieval between Llama2 7B and our 7B model fine-tuned on 32768 context length

消融实验

Ablation on fine-tuning steps in both full fine-tuning and LoRA+

Comparisons among S2-Attn and alternative attention patterns during fine-tuning

总结

在这项工作中,我们提出了LongLoRA,它能够高效地扩展LLMs的上下文长度,使其显著更大。

与标准全微调相比,LongLoRA具有更低的GPU内存成本和训练时间,同时精度损失最小。

在架构层面,我们提出了 S 2 S^2 S2-Attn,用于在训练过程中近似标准自注意力模式。 S 2 S^2 S2-Attn 易于实现,仅需两行代码。

此外,通过 S 2 S^2 S2-Attn 训练的模型在推理过程中保留了原始的标准注意力架构,使得大多数现有基础设施和优化可以重用。

在训练层面,我们通过可训练归一化和嵌入弥合了LoRA与全微调之间的差距。

我们的方法可以将Llama2 7B扩展到100k上下文长度,将70B模型扩展到32k上下文长度,在单个8×A100机器上实现。

我们还提出了一个长指令遵循数据集LongAlpaca,并使用LongLoRA进行了监督微调。

我们相信LongLoRA是一种通用方法,可以与更多类型的LLMs和位置编码兼容。

我们计划在未来工作中调查这些问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2271970.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LabVIEW四边形阻抗继电器

基于LabVIEW开发了四边形阻抗继电器,该系统主要应用于电力系统的距离保护中。四边形阻抗继电器在克服短路点过渡电阻的影响及躲避负荷阻抗方面展现出优良的特性。通过LabVIEW图形化编程环境实现的该系统,具备用户友好界面和简便的操作流程,有…

计算机网络--路由器问题

一、路由器问题 1.计算下一跳 计算机网络--根据IP地址和路由表计算下一跳-CSDN博客 2.更新路由表 计算机网络--路由表的更新-CSDN博客 3.根据题目要求给出路由表 4.路由器收到某个分组,解释这个分组是如何被转发的 5.转发分组之路由器的选择 二、举个例子 …

Python安装(新手详细版)

前言 第一次接触Python,可能是爬虫或者是信息AI开发的小朋友,都说Python 语言简单,那么多学一些总是有好处的,下面从一个完全不懂的Python 的小白来安装Python 等一系列工作的记录,并且遇到的问题也会写出&#xff0c…

2025 年前端新技术如何塑造未来开发生态?

开发领域:前端开发 | AI 应用 | Web3D | 元宇宙 技术栈:JavaScript、React、ThreeJs、WebGL、Go 经验经验:6 年 前端开发经验,专注于图形渲染和 AI 技术 开源项目:AI智简未来、晓智元宇宙、数字孪生引擎 大家好&#x…

1-markdown转网页样式页面 --[制作网页模板] 【测试代码下载】

markdown转网页 将Markdown转换为带有样式的网页页面通常涉及以下几个步骤:首先,需要使用Markdown解析器将Markdown文本转换为HTML;其次,应用CSS样式来美化HTML内容。此外,还可以加入JavaScript以增加交互性。下面我将…

数据逻辑(十)——逻辑函数的两种标准形式

目录 1 最小项和最大项 1.1 最小项 1.2 最大项 2 逻辑函数的最小项之和 3 逻辑函数的最大项之积 4 最小项之和以及最大项之积的联系和应用场景 4.1 最小项之和以及最大项目之积的联系 4.2 最小项之和以及最大项之积的应用场景 逻辑函数的两种标准形式分别是以最小项之和…

【Ubuntu使用技巧】Ubuntu22.04无人值守Crontab工具实战详解

一个愿意伫立在巨人肩膀上的农民...... Crontab是Linux和类Unix操作系统下的一个任务调度工具,用于周期性地执行指定的任务或命令。Crontab允许用户创建和管理计划任务,以便在特定的时间间隔或时间点自动运行命令或脚本。这些任务可以按照分钟、小时、日…

鸿蒙Flutter实战:15-Flutter引擎Impeller鸿蒙化、性能优化与未来

Flutter 技术原理 Flutter 是一个主流的跨平台应用开发框架,基于 Dart 语言开发 UI 界面,它将描述界面的 Dart 代码直接编译成机器码,并使用渲染引擎调用 GPU/CPU 渲染。 渲染引擎的优势 使用自己的渲染引擎,这也是 Flutter 与其…

UniApp | 从入门到精通:开启全平台开发的大门

UniApp | 从入门到精通:开启全平台开发的大门 一、前言二、Uniapp 基础入门2.1 什么是 Uniapp2.2 开发环境搭建三、Uniapp 核心语法与组件3.1 模板语法3.2 组件使用四、页面路由与导航4.1 路由配置4.2 导航方法五、数据请求与处理5.1 发起请求5.2 数据缓存六、样式与布局6.1 样…

法拉利F80发布 360万欧元限量799辆 25年Q4交付

今日,法拉利旗下全新超级跑车——F80正式发布,新车将作为法拉利GTO和法拉利LaFerrari(参数丨图片) Aterta的继任者,搭载V6混合动力系统,最大综合输出功率高达1632马力。售价360万欧元,全球限量生…

【pytorch练习】使用pytorch神经网络架构拟合余弦曲线

在本篇博客中,我们将通过一个简单的例子,讲解如何使用 PyTorch 实现一个神经网络模型来拟合余弦函数。本文将详细分析每个步骤,从数据准备到模型的训练与评估,帮助大家更好地理解如何使用 PyTorch 进行模型构建和训练。 一、背景 …

电脑steam api dll缺失了怎么办?

电脑故障解析与自救指南:Steam API DLL缺失问题的全面解析 在软件开发与电脑维护的广阔天地里,我们时常会遇到各种各样的系统报错与文件问题,其中“Steam API DLL缺失”便是让不少游戏爱好者和游戏开发者头疼的难题之一。作为一名深耕软件开…

Conda 安装 Jupyter Notebook

文章目录 1. 安装 Conda下载与安装步骤: 2. 创建虚拟环境3. 安装 Jupyter Notebook4. 启动 Jupyter Notebook5. 安装扩展功能(可选)6. 更新与维护7. 总结 Jupyter Notebook 是一款非常流行的交互式开发工具,尤其适合数据科学、机器…

组合的能力

在《德鲁克最后的忠告》一书中,有这样一段话: 企业将由各种积木组建而成:人员、产品、理念和建筑。积木的设计组合至少和其供给一样重要。……对于一切程序、应用软件以及附件来说,重要的是掌握将已有的软件模块组合的能力&…

去掉el-table中自带的边框线

1.问题:el-table中自带的边框线 2.解决后的效果: 3.分析:明明在el-table中没有添加border,但是会出现边框线. 可能的原因: 由 Element UI 的默认样式或者表格的某些内置样式引起的。比如,<el-table> 会通过 border-collapse 或 border-spacing 等属性影响边框的显示。 4…

大模型与EDA工具

EDA工具&#xff0c;目标是硬件设计&#xff0c;而硬件设计&#xff0c;您也可以看成是一个编程过程。 大模型可以辅助软件编程&#xff0c;相信很多人都体验过了。但大都是针对高级语言的软件编程&#xff0c;比如&#xff1a;C&#xff0c;Java&#xff0c;Python&#xff0c…

【HarmonyOS之旅】基于ArkTS开发(一) -> Ability开发一

目录 1 -> FA模型综述 1.1 -> 整体架构 1.2 -> 应用包结构 1.3 -> 生命周期 1.4 -> 进程线程模型 2 -> PageAbility开发 2.1 -> 概述 2.1.1 ->功能简介 2.1.2 -> PageAbility的生命周期 2.1.3 -> 启动模式 2.2 -> featureAbility接…

BART:用于自然语言生成、翻译和理解的去噪序列到序列预训练

摘要&#xff1a; 我们提出了BART&#xff0c;一种用于预训练序列到序列模型的去噪自编码器。BART通过以下方式训练&#xff1a;(1) 使用任意的噪声函数对文本进行破坏&#xff0c;(2) 学习一个模型来重建原始文本。它采用了一种标准的基于Transformer的神经机器翻译架构&#…

Promise编码小挑战

题目 我们将实现一个 createImage 函数&#xff0c;该函数返回一个 Promise&#xff0c;用于处理图片加载的异步操作。此外&#xff0c;还会实现暂停执行的 wait 函数。 Part 1: createImage 函数 该函数会&#xff1a; 创建一个新的图片元素。将图片的 src 设置为提供的路径…

Dubbo扩展点加载机制

加载机制中已经存在的一些关键注解&#xff0c;如SPI、©Adaptive> ©Activateo然后介绍整个加载机制中最核心的ExtensionLoader的工作流程及实现原理。最后介绍扩展中使用的类动态编译的实 现原理。 Java SPI Java 5 中的服务提供商https://docs.oracle.com/jav…