LoRA微调

news2024/11/18 1:41:05

论文:LoRA: Low-Rank Adaptation of Large Language Models

实现:microsoft/LoRA: Code for loralib, an implementation of “LoRA: Low-Rank Adaptation of Large Language Models” (github.com)

摘要

自然语言处理的一个重要的开发范式包括:

  1. 对通用领域数据进行大规模的预训练;
  2. 对特定任务或领域的适应。

问题:当预训练的模型越来越大,全参数的微调(full fine-tuning)变得比较困难了。

解决方法:Low-Rank Adaptation,简称LoRA,其冻结了预训练的模型权重,并将可训练的秩分解矩阵注入Transformer架构的每一层,大大减少了下游任务的可训练参数的数量。

简介

[1804.08838] Measuring the Intrinsic Dimension of Objective Landscapes

上述文章表明,学习到的过度参数化模型权重实际上存在于一个较低的内在维度空间上。我们假设模型适应过程中权重的变化也具有较低的“内在秩”(也就是只在内在的低维空间中变化),从而提出了低秩适应(LoRA)方法。LoRA允许我们通过优化适应过程中密集层变化的秩分解矩阵,间接地训练神经网络中的一些密集层,同时冻结预训练权重:

低秩适应微调示意图

LoRA有几个关键的优势:

  • 一个预训练模型可以被共享,并用于为不同的任务构建许多小型的LoRA模块。我们可以通过替换低秩适应示意图中的矩阵A和矩阵B来冻结共享模型并有效地切换任务,从而显著地减少了存储需求和任务切换开销。
  • 当使用自适应优化器时,LoRA使训练更高效,并将硬件准入门槛降低了3倍,因为我们不需要为大多数的参数计算梯度或维护其优化器状态。相反,我们只需要优化注入的、小得多的低秩矩阵。
  • 我们简单的线性设计允许我们在部署时通过构造将可训练矩阵与冻结权重合并,与完全微调的模型相比,不会引入推理延迟
  • LoRA与许多先前的方法互不影响,并且可以与其中的许多方法结合起来,比如前缀调优(prefix-tuning)。

问题陈述

LoRA并不特定于某个具体的训练目标,这里以语言建模(language modeling)问题为用例进行问题描述。

给定一个以 Φ \Phi Φ为参数的预训练自回归语言模型 P Φ ( y ∣ x ) P_\Phi(y|x) PΦ(yx)。比如, P Φ ( y ∣ x ) P_\Phi(y|x) PΦ(yx)可以是一个像GPT一样的基于Transformer的通用多任务学习器。考虑将这个预训练模型适应于下游的条件文本生成任务,如摘要、机器阅读理解(MRC)和自然语言转SQL(NL2SQL)。每个下游任务都由一个上下文-目标对训练数据集表示: Z = { ( x i , y i ) } i = 1 , . . . , N \mathcal{Z}=\{(x_i,y_i)\}_{i=1,...,N} Z={(xi,yi)}i=1,...,N,其中 x i x_i xi y i y_i yi是token序列。例如,在NL2SQL中, x i x_i xi是一个自然语言查询, y i y_i yi是它对应的SQL命令;对于摘要, x i x_i xi是一篇文章的内容, y i y_i yi是它的摘要。

在全微调过程中,模型被初始化为预训练权重 Φ 0 \Phi_0 Φ0,并通过不断累积梯度最终更新为 Φ 0 + Δ Φ \Phi_0 + \Delta\Phi Φ0+ΔΦ,以最大化条件语言建模目标函数:

max ⁡ Φ ∑ ( x , y ) ∈ Z ∑ t = 1 ∣ y ∣ log ⁡ ( P Φ ( y t ∣ x , y < t ) ) \max _{\Phi} \sum_{(x, y) \in \mathcal{Z}} \sum_{t=1}^{|y|} \log \left(P_{\Phi}\left(y_{t} \mid x, y_{<t}\right)\right) Φmax(x,y)Zt=1ylog(PΦ(ytx,y<t))

全微调的一个主要缺点是,对于每个下游任务,都需要学习了一组不同的参数 Δ Φ \Delta\Phi ΔΦ,其维数 ∣ Δ Φ ∣ |\Delta\Phi| ∣ΔΦ∣等于 ∣ Φ 0 ∣ |\Phi_0| Φ0。因此,如果预训练模型很大(比如175B的GPT-3),那么存储和部署许多独立的微调模型实例各方面的开销和压力会比较大。

本文采用了一种更加参数高效(parameter-efficient)的方法,将任务特定的参数增量 Δ Φ = Δ Φ ( Θ ) \Delta\Phi=\Delta\Phi\left(\Theta\right) ΔΦ=ΔΦ(Θ)进一步用小得多的参数集 Θ \Theta Θ进行编码,其中 ∣ Θ ∣ ≪ ∣ Φ 0 ∣ |\Theta| \ll |\Phi_0| ∣Θ∣Φ0。所以,寻找 Δ Φ \Delta\Phi ΔΦ的任务变成了对 Θ \Theta Θ的优化:

max ⁡ Θ ∑ ( x , y ) ∈ Z ∑ t = 1 ∣ y ∣ log ⁡ ( P Φ 0 + Δ Φ ( Θ ) ( y t ∣ x , y < t ) ) \max _{\Theta} \sum_{(x, y) \in \mathcal{Z}} \sum_{t=1}^{|y|} \log \left(P_{\Phi_0 + \Delta\Phi(\Theta)}\left(y_{t} \mid x, y_{<t}\right)\right) Θmax(x,y)Zt=1ylog(PΦ0+ΔΦ(Θ)(ytx,y<t))

现有方法存在的问题

两种高效适应下游任务的策略:

  • 添加适配器层
  • 对输入层做某种形式的优化

存在的问题:

  • 适配器层引入了推理延迟
  • 直接优化提示是困难的

本文的方法

虽然本文中只关注Transformer语言模型中的某些权重作为用例,但该方法适用于深度学习模型中的任何密集层。

低秩参数化更新矩阵(LOW-RANK-PARAMETRIZED UPDATE MATRICES)

当适应一个特定的任务时,论文Intrinsic Dimensionality Explains the Effectiveness of Language Model Fine-Tuning表明了预训练语言模型具有较低的“内在维度”,尽管随机投影到更小的子空间,但仍然可以有效地学习。受此启发,做出假设:在适应下游任务的过程中,权重的更新也有一个较低的“内在秩”。对于预训练的权重矩阵 W 0 ∈ R d × k W_0 \in \mathbb{R}^{d \times k} W0Rd×k,使用低秩分解 W 0 + Δ W = W 0 + B A W_0+\Delta W = W_0+BA W0+ΔW=W0+BA表示后者来约束其更新,其中 B ∈ R d × r B \in \mathbb{R}^{d \times r} BRd×r A ∈ R r × k A \in \mathbb{R}^{r \times k} ARr×k和秩 r ≪ min ⁡ ( d , k ) r \ll \min(d,k) rmin(d,k)

在微调过程中, W 0 W_0 W0被冻结,不接收梯度更新,而 A A A B B B包含可训练的参数。 W 0 W_0 W0 Δ W = B A \Delta W = BA ΔW=BA都与相同的输入相乘,它们各自的输出向量按坐标求和。

对于 h = W 0 x h = W_0x h=W0x,修改后的正向传播为:

h = W 0 x + Δ W x = W 0 x + B A x h = W_0x + \Delta Wx = W_0x + BAx h=W0x+ΔWx=W0x+BAx

可训练参数的初始化:

  • A A A:随机高斯
  • B B B:0

所以 Δ W = B A \Delta W = BA ΔW=BA在训练开始时为零。

We then scale Δ W x \Delta Wx ΔWx by α r \frac{\alpha}{r} rα, where α \alpha α is a constant in r r r. When optimizing with Adam, tuning α \alpha α is roughly the same as tuning the learning rate if we scale the initialization appropriately. As a result, we simply set α \alpha α to the first r r r we try and do not tune it. This scaling helps to reduce the need to retune hyperparameters when we vary r r r.

  • A Generalization of Full Fine-tuning

    一种更一般的微调形式允许训练预训练参数的一个子集。LoRA更进一步,不需要权重矩阵的累积梯度更新在适应过程中具有全秩。这意味着,当将LoRA应用于所有权重矩阵并训练所有偏置项时,通过将LoRA的秩 r r r设置为预训练权重矩阵的秩,可以大致恢复完全微调的表达性。换句话说,当增加可训练参数的数量时,训练LoRA将大致收敛为训练原始模型,而基于适配器的方法和基于前缀的方法则分别收敛到一个MLP和一个不能接受长输入序列的模型。

  • No Additional Inference Latency

    当在生产环境中部署时,可以显式地计算和存储任务特定的权重 W = W 0 + B A W = W_0 + BA W=W0+BA,然后将该权重加载进模型并像往常一样执行推理。注意 W 0 W_0 W0 B A BA BA都在 R d × k \mathbb{R}^{d \times k} Rd×k中。当需要切换到另一个下游任务时,可以通过减去 B A BA BA来恢复 W 0 W_0 W0,然后加上一个不同的 B ′ A ′ B^{\prime}A^{\prime} BA,这是一个快速的内存开销很少的操作。重要的是,这确保了与通过构造进行微调的模型相比,这种方式在推理过程中没有引入任何额外的延迟。

在Transformer上应用LoRA

原则上,LoRA可以应用于神经网络中的权重矩阵的任何子集,以减少可训练参数的数量。在Transformer架构中,自注意模块中有四个权重矩阵( W q W_q Wq W k W_k Wk W v W_v Wv W o W_o Wo),在MLP模块中有两个。我们将 W q W_q Wq(或 W k W_k Wk W v W_v Wv)视为一个形状为 d m o d e l × d m o d e l d_{model} \times d_{model} dmodel×dmodel的单一矩阵,即使输出维通常被切分成注意力头。

在这里插入图片描述

实际的获益和局限性

获益
  • 最显著的好处来自于内存和存储使用量的减少。对于使用Adam训练的大型Transformer网络,如果 r ≪ d m o d e l r \ll d_{model} rdmodel,则VRAM使用量减少 2 / 3 2/3 2/3,因为不需要存储冻结参数的优化器状态。在GPT-3 175B上,训练期间的VRAM消耗从1.2TB减少到350GB。当 r = 4 r = 4 r=4和只对Query和Value投影矩阵进行调整时,检查点的大小减少了大约10000倍(从350GB减少到35MB)。这让我们可以使用少得多的gpu进行训练,并极大地避免I/O瓶颈。
  • 另一个好处是,通过只切换LoRA的权重,而不是所有的参数,可以在部署后以低得多的开销在不同任务间切换。
  • 我们还观察到,与完全微调相比,在GPT-3 175B上的训练速度提高了25%,因为不需要计算绝大多数参数的梯度。
局限性

例如,如果选择将 A A A B B B吸收到 W W W中以消除额外的推理延迟,那么在单次正向传递中批量处理具有不同 A A A B B B的不同任务的输入是很难的。尽管在延迟不是很重要的情况下,可以不合并权重并动态选择用于批处理中的样本的LoRA模块。

理解低秩更新

作者进行了一系列的实证研究来回答以下问题:

  1. 给定一个参数预算约束,在预训练的Transformer网络中应该适应权重矩阵的哪个子集以最大化下游性能?
  2. “最优”的适应矩阵 Δ W \Delta W ΔW真的是秩亏的吗?如果是这样,在实践中使用什么秩比较好?
  3. Δ W \Delta W ΔW W W W之间有什么关系? Δ W \Delta W ΔW W W W高度相关吗?与 W W W相比, Δ W \Delta W ΔW有多大?

我们应该将LORA应用到Transformer中的哪些权重矩阵?

给定有限的参数预算,应该使用LoRA调整哪些类型的权重才能在下游任务上获得最佳性能?这里只考虑自注意力模块中的权重矩阵。在GPT-3 175B上设置了18M的参数预算(如果以FP16存储,大约为35MB),对于所有96层,如果适应一种类型的注意力权重,则对应于 r = 8 r = 8 r=8;如果适应两种类型,则对应于 r = 4 r = 4 r=4。以下是实验结果:

在这里插入图片描述

可以看到,将所有参数放入 Δ W q \Delta W_q ΔWq Δ W k \Delta W_k ΔWk会导致性能显著降低,而同时调整 W q W_q Wq W v W_v Wv会产生最佳结果。这表明,即使是值为4的秩也能捕获 Δ W \Delta W ΔW中足够的信息,因此适应更多的权重矩阵比使用更大的秩适应单一类型的权重更好

对于LoRA最优的秩 r r r是什么

在这里插入图片描述

可以看出,使用一个非常小的 r r r就足以让LoRA表现得很好了,这表明更新矩阵 Δ W \Delta W ΔW可能有一个非常小的“内在秩”。但是不能指望一个小的 r r r适用于每个任务或数据集。假设下游任务使用的语言与预训练所使用的语言不同,则重新训练整个模型(类似于 r = d m o d e l r = d_{model} r=dmodel的LoRA)肯定会优于 r r r较小的LoRA。为了进一步支持这一发现,作者检查了使用不同的 r r r和不同随机种子学习到的子空间的重叠情况,得出结论:增加 r r r不覆盖一个更有意义的子空间,这表明一个低秩适应矩阵是足够的。

适应矩阵 Δ W \Delta W ΔW W W W相比如何?

通过计算 U T W V T U^{\mathsf{T}}WV^{\mathsf{T}} UTWVT W W W投影到 Δ W \Delta W ΔW r r r维子空间上,其中 U U U/ V V V Δ W \Delta W ΔW的左/右奇异向量矩阵,然后计算相应的Frobenius norm。作为比较,我们还将 U U U V V V替换为 W W W或一个随机矩阵的前 r r r个奇异向量后计算 ∥ U T W V T ∥ F \parallel U^{\mathsf{T}}WV^{\mathsf{T}}\parallel_F UTWVTF的值。结果如下:

在这里插入图片描述

从上表可以得出几个结论:

  1. 与随机矩阵相比, Δ W \Delta W ΔW W W W有更强的相关性,这表明 Δ W \Delta W ΔW放大了 W W W中已经存在的一些特征。
  2. Δ W \Delta W ΔW没有重复 W W W靠前的奇异向量方向,而是只放大了 W W W中没有强调的方向。
  3. 放大系数相当大:当 r = 4 r=4 r=4时,为 21.5 ≈ 6.91 / 0.32 21.5 \approx 6.91/0.32 21.56.91/0.32

这表明,低秩适应矩阵潜在地放大了特定下游任务的重要特征,这些特征是通用预训练模型学习到但并未注重的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1594780.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

51单片机实验03-单片机定时/计数器实验

目录 一、实验目的 二、实验说明 1、51单片机有两个16位内部计数器/定时器&#xff08;C/T&#xff0c; Counter/Timer&#xff09;。 2、模式寄存器TMOD 1) M1M0工作模式控制位&#xff1b; 2) C/T定时器或计数器选择位&#xff1a; 3&#xff09;GATE定时器/计数器运行…

YOLOv1精读笔记

YOLO系列 摘要1. 将目标检测视为一个回归问题2. 定位准确率不如 SOTA&#xff0c;但背景错误率更低3. 泛化能力强 1.引言1.1 YOLO 速度很快1.2 全局推理 2. Unified Detection2.1 网络设计2.2 训练YOLOv1模型损失函数的选择和其潜在的问题YOLOv1模型如何改进其损失函数来更好地…

关于机器学习/深度学习的一些事-答知乎问(三)

可解释人工智能如何进行创新&#xff1f; &#xff08;1&#xff09;解释方法结合。现有的研究较少关注如何将不同的解释方法结合起来&#xff0c;未来可以考虑将不同的 解释方法结合在一起&#xff0c;如正反结合&#xff0c;事实解释侧重于 “为什么”&#xff0c;反事实解释…

回归预测 | Matlab基于RIME-SVR霜冰算法优化支持向量机的数据多输入单输出回归预测

回归预测 | Matlab基于RIME-SVR霜冰算法优化支持向量机的数据多输入单输出回归预测 目录 回归预测 | Matlab基于RIME-SVR霜冰算法优化支持向量机的数据多输入单输出回归预测预测效果基本描述程序设计参考资料 预测效果 基本描述 1.Matlab基于RIME-SVR霜冰算法优化支持向量机的数…

边缘计算【智能+安全检测】系列教程--使用OpenCV+GStreamer实现真正的硬解码,完全消除马赛克

通过现有博客的GST_URL = "rtspsrc location=rtsp://admin:abcd1234@192.168.1.64:554/h264/ch01/main/av_stream latency=150 ! rtph264depay ! avdec_h264 ! videorate ! videoconvert ! appsink sync=false" GStreamer的解码方式解码,大多情况应该存在上图马赛克…

项目实现:Boost搜索引擎

一.项目背景 当前已经有许多上市公司做了搜索引擎&#xff0c;比如说百度&#xff0c;搜狗&#xff0c;360等等&#xff0c;这些项目都是很大的项目&#xff0c;有很高的技术门槛&#xff0c;我们自己实现一个完整的搜索引擎是不可能的&#xff0c;但是我们可以写一个简单的搜…

Springboot+Vue项目-基于Java+MySQL的高校心理教育辅导系统(附源码+演示视频+LW)

大家好&#xff01;我是程序猿老A&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f49e;当前专栏&#xff1a;Java毕业设计 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python毕业设计 &…

鸿蒙开发快速入门

基本概念 ArkTS 因为ArkTS是基于Type Script扩展而来&#xff0c;是Type Script的超集&#xff0c;所以也可以关注一下Type Script的语法来理解ArkTS的语法 ArkUI HarmonyOS提供了一套UI开发框架&#xff0c;即方舟开发框架&#xff08;ArkUI框架&#xff09;。方舟开发框架…

Kafka -基本概念

认识Kafka kafka是一个多分区、多副本、基于zookeeper协调的分布式消息系统。 扮演角色 消息系统存储系统&#xff1a;把消息持久化到磁盘&#xff0c;相比于其他基于内存存储的系统而言&#xff0c;有效降低了数据丢失的风险。流式处理平台 基本概念 kafka的体系结构&…

esp32-通过wifi使用timelib库同步时间(三)

库的安装 本文基于platformIO&#xff0c;安装较为简单如下图 实例代码 完整代码如下&#xff0c;如果时间获取超时请使用time1.aliyun.com获取时间。 /** Time_NTP.pde* Example showing time sync to NTP time source** This sketch uses the Ethernet library*/#include …

Ubuntu 20.04.06 PCL C++学习记录(二十一)【切记使用rm * -rf前先确认是否是对应文件夹】

[TOC]PCL中点云分割模块的学习 学习背景 参考书籍&#xff1a;《点云库PCL从入门到精通》以及官方代码PCL官方代码链接,&#xff0c;PCL版本为1.10.0&#xff0c;CMake版本为3.16&#xff0c;测试点云下载地址 学习内容 根据欧几里得距离和需要保持的用户可自定义条件对点进…

5GNR刷题

5G帧结构 5G NR帧结构的基本时间单位是( C ) A) subframe B) slot C) Tc D) symbol 5G无线帧长是多少ms&#xff08;B&#xff09; A) 5 B) 10 C) 20 D) 40 下面哪种子载波间隔是中国移动白皮书中规定必选(B ) A) 15KHz B) 30KHz C) 60KHz D) 120KHz 5G参数集包含哪…

ASP.NET基于Ajax+Lucene构建搜索引擎的设计和实现

摘 要 通过搜索引擎从互联网上获取有用信息已经成为人们生活的重要组成部分&#xff0c;Lucene是构建搜索引擎的其中一种方式。搜索引擎系统是在.Net平台上用C#开发的&#xff0c;数据库是MSSQL Server 2000。主要完成的功能有&#xff1a;用爬虫抓取网页&#xff1b;获取有效…

什么是JAVA面向对象

一&#xff0c;什么是面向对象&#xff1a; 我们以前的项目都是面向过程的&#xff0c;一个完整的项目所有的代码都写在一个类里 这就叫面向过程。 面向对象&#xff0c;是指在写大型项目时&#xff0c;多人分工合作&#xff0c;为了代码看上去简洁美观&#xff0c;会将不同的…

常见的垃圾回收算法

文章目录 1. 标记清除算法2. 复制算法3. 标记整理算法4. 分代垃圾回收算法 1. 标记清除算法 核心思想&#xff1a; 标记阶段&#xff0c;将所有存活的对象进行标记。Java中使用可达性分析算法&#xff0c;从GC Root开始通过引用链遍历出所有存活对象。清除阶段&#xff0c;从…

详解拷贝构造

拷贝构造的功能 写法&#xff1a; 拷贝构造函数的参数为什么是引用类型 系统自动生成的拷贝构造函数 拷贝构造的深拷贝与浅拷贝 概念 浅拷贝&#xff1a; 深拷贝 小结 拷贝构造的功能 拷贝构造函数可以把曾经实例化好的对象的数据拷贝给新创建的数据 &#xff0c;可见…

书生·浦语大模型-第五节课笔记/作业

笔记 作业 原7b模型问题耗时: 4.5s lmdeploy推理耗时: 0.43s 不知道是否因为没有正确的输出 lmdeploy kv-cache推理耗时&#xff1a;2.9s 推理时新增 past_key_values 参数&#xff0c;该参数就会以追加方式保存每一轮的K V值。kvcache变量内容为((k,v), (k,v), …, (k,v))…

Node.js 中的 RSA 加密、解密、签名与验证详解

引言 在现代的网络通信中&#xff0c;数据安全显得尤为重要。RSA加密算法因其非对称的特性&#xff0c;广泛应用于数据的加密、解密、签名和验证等安全领域。本文将详细介绍RSA算法的基本原理&#xff0c;并结合Node.js环境&#xff0c;展示如何使用内置的crypto模块和第三方库…

【python】python抓取古诗文内容保存(源码)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

下载了恶意软件怎么办,用这个软件可以解决 Mac电脑卸载软件 MacBook查杀病毒

随着苹果电脑在全球市场的普及&#xff0c;它们也日益成为恶意软件制作者的目标。这种趋势打破了许多人认为Mac系统不易受到病毒或恶意软件影响的传统观念。事实上&#xff0c;苹果电脑面临的恶意软件和安全威胁正在不断增多&#xff0c;这要求用户采取更加积极的措施来保护自己…