LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

news2024/11/26 14:42:57

LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS (Paper reading)

Edward H, Microsoft, arXiv2021, Cited: 354, Code, Paper

1. 前言

自然语言处理的一个重要范式是在通用领域数据上进行大规模预训练,然后根据特定任务或领域进行适应性训练。随着我们对模型进行更大规模的预训练,完全微调(重新训练所有模型参数)变得越来越不可行。以GPT-3 175B为例,部署独立的微调模型实例,每个实例有175B个参数,成本过高。我们提出了低秩适应(Low-Rank Adaptation,LoRA)方法,该方法冻结预训练模型的权重,并将可训练的秩分解矩阵注入到Transformer架构的每个层中,从而大大减少了下游任务的可训练参数数量。与使用Adam进行微调的GPT-3 175B相比,LoRA可以将可训练参数数量减少10000倍,GPU内存需求减少3倍。LoRA在RoBERTa、DeBERTa、GPT-2和GPT-3的模型质量上表现出与微调相当或更好的水平,它的可训练参数更少,训练吞吐量更高,并且与适配器相比,不会引入额外的推理延迟。我们还对语言模型适应性中的秩缺失进行了实证研究,这有助于揭示LoRA的有效性。

2. 整体思想

有个大模型,直接微调很难,本文的想法是适应预训练模型的权重。具体的方法是,如下图哈。在大模型旁边训练一个编码器(?),去适应特定的任务,然后将输出值和原来大模型融合就可以了。

3. 介绍

许多人试图通过仅调整一些参数或为新任务学习外部模块来减轻这个问题。这样,我们只需要为每个任务存储和加载一小部分特定于任务的参数,以及预训练模型,从而在部署时大大提高操作效率。然而,现有技术往往通过增加模型深度或减少模型可用的序列长度来引入推理延迟。

推理延迟(Inference Latency)是指从输入数据传入模型,到模型输出结果可用的时间间隔。
在机器学习中,模型的推理过程通常包括将输入数据传递给模型,模型执行计算并生成输出结果。
推理延迟是指从输入数据传入模型开始,到最终输出结果可用的时间,包括数据传输、计算和处理的时间。

更重要的是,这些方法通常无法与微调基线相匹配,存在效率和模型质量之间的权衡。学习得到的超参数化模型实际上存在于一个低内在维度上。我们假设模型在适应过程中的权重变化也具有低的 “内在秩”,从而提出了我们的低秩自适应(LoRA)方法。LoRA允许我们通过优化适应过程中密集层权重变化的秩分解矩阵来间接训练神经网络中的一些密集层,同时保持预训练权重不变,如图1所示。以GPT-3 175B为例,我们证明即使完整秩(即d)高达12,288,一个非常低的秩(即图1中的r可以是一或二)就足够了,使得LoRA在存储和计算效率上都非常高。

在这里插入图片描述
LoRA具有几个关键优势:
• 一个预训练模型可以被共享并用于构建许多不同任务的小型LoRA模块。我们可以冻结共享模型,并通过替换图1中的矩阵A和B来高效地切换任务,从而显著降低存储需求和任务切换开销。
• LoRA通过使用自适应优化器,使训练更加高效,并将硬件准入门槛降低了最多3倍,因为我们不需要计算大多数参数的梯度或维护优化器状态。相反,我们只优化注入的、远小于原始模型的低秩矩阵。
• 我们简单的线性设计使我们能够在部署时将可训练矩阵与冻结的权重合并,与完全微调的模型相比,不会引入推理延迟。
• LoRA与许多先前的方法是正交的,可以与其中许多方法结合使用,如前缀微调(prefix-tuning)。

3.1 问题声明

虽然我们的提议与训练目标无关,但我们将重点放在语言建模上,作为我们的动机案例。以下是对语言建模问题的简要描述,特别是在给定任务特定提示的情况下,对条件概率的最大化。

假设我们有一个由参数 Φ Φ Φ参数化的预训练自回归语言模型 P Φ ( y ∣ x ) P_Φ(y|x) PΦ(yx)。例如, P Φ ( y ∣ x ) P_Φ(y|x) PΦ(yx)可以是基于Transformer架构的通用多任务学习器,如GPT。考虑将这个预训练模型适应到下游的条件文本生成任务,例如摘要生成、机器阅读理解(MRC)和自然语言到SQL(NL2SQL)。每个下游任务由一个上下文-目标对的训练数据集表示: Z = ( x i , y i ) , i = 1 , . . , N Z = {(x_i, y_i)}, i=1,..,N Z=(xi,yi),i=1,..,N,其中xi和yi都是标记序列。例如,在NL2SQL中, x i x_i xi是一个自然语言查询, y i y_i yi是其对应的SQL命令;在摘要生成中, x i x_i xi是一篇文章的内容, y i y_i yi是其摘要。

在完全微调过程中,模型初始化为预训练的权重 Φ 0 Φ_0 Φ0,并通过反复追踪梯度来最大化条件语言建模目标,更新为 Φ 0 + ∆ Φ Φ_0 + ∆Φ Φ0+∆Φ。完全微调的一个主要缺点是对于每个下游任务,我们学习了一组不同的参数 ∆ Φ ∆Φ ∆Φ,其维度 ∣ ∆ Φ ∣ |∆Φ| ∣∆Φ∣等于 ∣ Φ 0 ∣ |Φ_0| Φ0。因此,如果预训练模型很大(例如GPT-3, ∣ Φ 0 ∣ |Φ_0| Φ0约为1750亿),存储和部署许多独立的微调模型实例可能具有挑战性,甚至不可行。在本文中,我们采用一种更加参数高效的方法,其中任务特定参数增量 ∆ Φ = ∆ Φ ( Θ ) ∆Φ = ∆Φ(Θ) ∆Φ=∆Φ(Θ)通过一个大小远小于 Φ 0 Φ_0 Φ0的参数集合 Θ Θ Θ进行编码,即 ∣ Θ ∣ < < ∣ Φ 0 ∣ |Θ| << |Φ_0| ∣Θ∣<<Φ0。因此,寻找 ∆ Φ ∆Φ ∆Φ的任务变成了对 Θ Θ Θ进行优化。我们建议使用低秩表示来对 ∆ Φ ∆Φ ∆Φ进行编码,这既具有计算效率又具有内存效率。当预训练模型为GPT-3 175B时,可训练参数的数量 ∣ Θ ∣ |Θ| ∣Θ∣可以小到 Φ 0 Φ_0 Φ0的0.01%。

3.2 相关工作

我们要解决的问题绝不是新问题。自迁移学习的出现以来,已经有数十种方法致力于使模型自适应更加参数和计算效率高。以语言建模为例,存在两种主要的高效自适应策略:添加适配器层(Adapter)或优化输入层激活的某些形式。然而,这两种策略都有其局限性,尤其是在大规模和对延迟敏感的生产场景中。

Prompt Engineering and Fine-Tuning:虽然GPT-3 175B可以通过仅使用少量额外的训练示例来调整其行为,但结果在很大程度上取决于输入提示。这需要一种经验性的技巧,即组合和格式化提示,以最大化模型在所需任务上的性能,这被称为提示工程或提示改进。微调是将预先在通用领域上进行预训练的模型重新训练到特定任务上。它的变种包括仅学习一部分参数,但实践者通常重新训练所有参数以最大化下游性能。然而,GPT-3 175B的庞大规模使得以通常的方式进行微调变得困难,这是由于它产生的大型检查点和高硬件准入门槛,因为它具有与预训练相同的内存占用。

参数高效的自适应:许多人提出在神经网络的现有层之间插入适配器层。我们的方法使用类似的瓶颈结构对权重更新施加低秩约束。关键的功能差异在于我们学习的权重可以在推理过程中与主要权重合并,因此不会引入任何延迟,而适配器层则不具备这个特性。适配器的一个现代扩展是COMPACTER,它基本上使用克罗内克积和一些预定的权重共享方案对适配器层进行参数化。类似地,将LoRA与其他基于张量积的方法结合可能有助于提高其参数效率,这留待将来的研究。最近,许多人提出了在没有进行微调的情况下优化输入词嵌入,类似于连续且可微的提示工程的泛化。

在深度学习中的低秩结构:低秩结构在机器学习中非常常见。许多机器学习问题具有一定的内在低秩结构。此外,已知对于许多深度学习任务,特别是那些具有过度参数化的神经网络,经过训练后学习到的神经网络会具有低秩性质。一些先前的工作甚至在训练原始神经网络时明确地施加了低秩约束;然而,据我们所知,这些工作中没有一项考虑了对冻结模型进行低秩更新以适应下游任务。在理论文献中,已知当底层概念类具有一定低秩结构时,神经网络优于其他经典学习方法,包括相应的(有限宽度)神经切向核。低秩自适应对于对抗训练可能是有用的。总之,我们相信我们提出的低秩自适应更新受到了文献的良好启发。

4. 方法

4.1 术语和约定

我们经常引用Transformer架构,并使用其维度的常规术语。我们将Transformer层的输入和输出维度大小称为 d m o d e l d_{model} dmodel。我们使用 W q , W k , W v W_q,W_k,W_v WqWkWv W o W_o Wo来表示自注意模块中的查询/键/值/输出投影矩阵。 W W W W 0 W_0 W0指的是预训练的权重矩阵, ∆ W ∆W W指的是在自适应过程中累积的梯度更新。我们使用 r r r表示LoRA模块的秩。我们使用Adam进行模型优化,并使用Transformer MLP前馈维度 d f f n = 4 × d m o d e l d_{ffn} = 4×d_{model} dffn=4×dmodel

4.2 低秩参数化的更新矩阵

神经网络包含许多密集层,这些层执行矩阵乘法。这些层中的权重矩阵通常具有满秩。在适应特定任务时,预训练的语言模型具有低的“内在维度”,即使在随机投影到较小子空间时也能有效学习。受此启发,我们假设权重的更新在适应过程中也具有低的“内在秩”。对于预训练的权重矩阵 W 0 ∈ R d × k W_0 \in \mathbb{R}^{d \times k} W0Rd×k,我们通过低秩分解 W 0 + ∆ W = W 0 + B A W_0 + ∆W = W_0 + BA W0+W=W0+BA来约束其更新,其中 B ∈ R d × r , A ∈ R r × k B \in \mathbb{R}^{d \times r}, A \in \mathbb{R}^{r \times k} BRd×r,ARr×k。在训练过程中, W 0 W_0 W0被冻结并且不接收梯度更新,而 A A A B B B包含可训练参数。请注意, W 0 W_0 W0 ∆ W = B A ∆W = BA W=BA都与相同的输入进行乘法运算,它们的输出向量在坐标上求和。对于 h = W 0 x h = W_0x h=W0x,我们修改后的前向传递结果为:
h = W 0 x + ∆ W x = W 0 x + B A x h=W_0x+∆Wx=W_0x+BAx h=W0x+Wx=W0x+BAx
我们对A使用随机高斯初始化,对B使用零初始化,因此在训练开始时 ∆ W = B A ∆W = BA W=BA为零。然后,我们通过 α / r α/r α/r ∆ W x ∆W x Wx进行缩放,其中 α α α r r r中的一个常数。在使用Adam进行优化时,调整 α α α与调整学习率的过程大致相同,只需适当地缩放初始化即可。因此,我们简单地将 α α α设置为我们尝试的第一个 r r r,并且不对其进行调整。

一种更一般的微调形式允许训练预训练参数的子集。LoRA更进一步,不要求在适应过程中对权重矩阵的累积梯度更新具有满秩。这意味着,当将LoRA应用于所有权重矩阵并训练所有偏差时,通过将LoRA的秩 r r r设置为预训练权重矩阵的秩,我们大致恢复了完全微调的表达能力。换句话说,随着可训练参数的数量增加,训练LoRA大致收敛于训练原始模型,而基于适配器的方法收敛于多层感知器(MLP),而基于前缀的方法收敛于无法处理长输入序列的模型

4.3 应用loRA于Transformer

原则上,我们可以将LoRA应用于神经网络中的任何权重矩阵子集,以减少可训练参数的数量。在Transformer架构中,自注意力模块中有四个权重矩阵( W q 、 W k 、 W v 、 W o W_q、W_k、W_v、W_o WqWkWvWo),MLP模块中有两个权重矩阵。我们将 W q W_q Wq(或 W k 、 W v W_k、W_v WkWv)视为一个维度为 d m o d e l × d m o d e l d_{model}×d_{model} dmodel×dmodel的单个矩阵,尽管输出维度通常被切分成多个注意力头。出于简单和参数效率的考虑,我们将研究限制在仅适应下游任务的注意力权重上,并冻结MLP模块(因此它们在下游任务中不会被训练)。

Transformer中的哪些权重矩阵应该使用Lora?在有限的参数预算下,为了在下游任务中获得最佳性能,我们应该使用LoRA来适应哪些类型的权重?我们只考虑自注意力模块中的权重矩阵。在GPT-3 175B上,我们将参数预算设置为18M(在FP16中存储时大约为35MB),这对应于适应一种类型的注意力权重时的r = 8,或者适应两种类型时的r = 4,对于所有96个层。结果见表5。
在这里插入图片描述
将适应MLP层、LayerNorm层和偏差的经验研究留给未来的工作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/598574.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Hightopo 使用心得(2)- 2D 图纸 GraphView,节点 Node, 连线 Edge,与基本动画 ht.Default.startAnim()

概括来说&#xff0c;用 HT for Web 做可视化主要分为两部分&#xff0c;也就是 2D 和 3D。这两部分需要单独创建。在它们被创建完成后&#xff0c;我们再把它们集成到一起。 HT for Web 的 2D 部分主要是指 ht.graph.GraphView (简称 GraphView&#xff0c;也就是 2D 图纸)。…

匿名管道通信

目录 一、进程通信原理 二、什么是管道 三、创建一个匿名管道 四、fork共享管道的原理 五、管道的特点 六、4中场景 一、进程通信原理 我们知道进程间相互独立&#xff0c;具有独立性。那么我们要实现两个进程之间的通信就需要&#xff0c;让这两个进程看到同一个文件。然…

设计模式-访问者模式

访问者模式 问题背景解决方案&#xff1a;传统方案 访问者模式基本介绍原理UML类图 使用访问者模式解决问题UML类图示例代码运行结果 注意事项和细节 问题背景 我们来制作一台电脑&#xff0c;他的硬件有CPU和磁盘&#xff0c;CPU和磁盘类都有一个常量作为他们各自的数据&…

java企业级信息系统开发学习笔记10 利用MyBatis实现关联查询

文章目录 一、学习目标&#xff08;一&#xff09;针对三张表关联查询&#xff08;二&#xff09;按班级编号查询班级信息&#xff08;三&#xff09;查询全部班级信息 二、创建数据库&#xff08;一&#xff09;创建教师表&#xff08;二&#xff09;创建班级表&#xff08;三…

Linux系统搭建Java的运行环境

目录 JDKTomcatMySQL JDK 对于Linux安装JDK有很多方法~ 这里就掌握最简单的办法—基于yum来进行安装~ yum是“包管理器”&#xff0c;相当于应用商店~ 首先&#xff0c;先搜索一下&#xff0c;看看yum上关于jdk有没有&#xff0c;以及叫啥名字~ 通过 yum list命令&#xff0…

六一亲子嘉年华 | 来迅镭激光过一个五彩缤纷的儿童节!

童年是梦&#xff0c;如七彩的画卷&#xff1b; 童年是诗&#xff0c;如璀璨的星空&#xff1b; 童年是歌&#xff0c;如跳跃的音符&#xff01; 在“六一”儿童节到来之际 为给员工及子女创造一个难忘的亲子时光 迅镭激光开展了六一亲子嘉年华主题活动 让孩子们在迅镭大家庭的…

Minigpt4实战搭建

简介 Minigpt4虽然放出了网页版但是使用后发现网页体验的话&#xff0c;由于并发量比较大&#xff0c;很容易突然卡顿的现象&#xff0c;所以下面我主要讲解一下如何进行本地部署。 之前文章已经介绍过Minigpt4了这里就不重复赘述了&#xff0c;不了解的可以去看看https://bl…

使用python开发“魂斗罗”游戏

使用python开发“魂斗罗”游戏 开发完整的魂斗罗&#xff08;Contra&#xff09;游戏是一个庞大的任务&#xff0c;它涉及到图形渲染、物理碰撞、敌人AI、游戏关卡等多个方面。在这个简短的交互中&#xff0c;我将向你展示一个基本的魂斗罗风格的游戏框架&#xff0c;你可以在此…

结构化文档发布的故事和性能调优

前阵子一个TW朋友跟我抱怨他们的文档发布很慢。正常发布需要一个晚上才能完成发布。中间如果出点错&#xff0c;就得重新发布&#xff0c;那么中间是漫长的等待。 不像MS Word或者InDesign这样所见即所得的软件&#xff0c;结构化文档源文件是XML格式的&#xff0c;就像计算机…

C语言——数据在内存中的存储(下)

数据在内存中的存储&#xff08;下&#xff09; 1. 浮点数在内存中的存储 浮点数家族&#xff1a; float double long double 浮点数的表示范围&#xff1a; 这里要引用float.h头文件 【实例一】 //输出结果是什么&#xff1f; int main() {int n 9;float *pFloat (float…

【代码规范】Google开源项目风格指南

系列综述&#xff1a; &#x1f49e;目的&#xff1a;本系列是个人整理为了秋招面试的&#xff0c;整理期间苛求每个知识点&#xff0c;平衡理解简易度与深入程度。 &#x1f970;来源&#xff1a;材料主要源于Google开源项目风格指南进行的&#xff0c;每个知识点的修正和深入…

基于卡尔曼滤波实现线性目标跟踪

文章目录 前言卡尔曼滤波基本推导运算 实现目标检测卡尔曼预测器ID分配器&#xff08;跟踪器&#xff09; 完整代码代码总结 前言 一个需求&#xff0c;在一个稳定的场景当中&#xff0c;实现目标检测计数算法。 任务点&#xff1a; 实现目标检测完成对不同类别的物品进行计数…

Three.js--》实现3d字体模型展示

目录 项目搭建 初始化three.js基础代码 设置环境纹理 加载字体模型 今天简单实现一个three.js的小Demo&#xff0c;加强自己对three知识的掌握与学习&#xff0c;只有在项目中才能灵活将所学知识运用起来&#xff0c;话不多说直接开始。 项目搭建 本案例还是借助框架书写…

前后端交互模型http协议Ajax简介

0、前言&#xff1a;本文只是对“前后端交互模型&http协议&Ajax简介”当中的理论&#xff0c;作用&#xff0c;方法进行总结说明&#xff0c;用于回顾知识&#xff0c;做概括总结&#xff0c;没有具体实现代码。 1、前后端交互模型&#xff1a; 前端发送请求&#xff…

信号机制上(信号概念、发送、定时器、信号捕捉、SIGCHLD)

一、信号机制 概念&#xff1a;信号是在软件层次上对中断机制的一种模拟&#xff0c;是一种异步通信方式 所有信号的产生及处理全部都是由内核完成的 信号的产生&#xff1a; 1 按键产生 2 系统调用函数产生&#xff08;比如raise&#xff0c; kill&#xff09; 3 硬件异…

连接MQTT服务端

MQTT客户端之间要想实现通讯&#xff0c;必须要通过MQTT服务端。因此MQTT客户端无论是发布消息还是订阅消息&#xff0c;首先都要连接MQTT服务端。 MQTT客户端连接服务端一共有两步。 第一步&#xff08;CONNECT请求&#xff09; 首先MQTT客户端将会向服务端发送连接请求。该…

HBase 的关键流程解析

前言 本文隶属于专栏《大数据技术体系》&#xff0c;该专栏为笔者原创&#xff0c;引用请注明来源&#xff0c;不足和错误之处请在评论区帮忙指出&#xff0c;谢谢&#xff01; 本专栏目录结构和参考文献请见大数据技术体系 正文 HBase 客户端会将查询过的 HRegion 的位置信息…

【Python爬虫】采集电商商品评价信息

目录 一、数据采集逻辑二、数据Schema三、数据爬取1.导入库2.对爬虫程序进行伪装3.抓取商品评论信息4.防止反爬&#xff0c;每爬取一页数据后&#xff0c;设置程序休眠环节 四、数据存储1. 存储到csv 2.存储到数据库 一、数据采集逻辑 在进行数据采集之前&#xff0c;明确哪些…

Linux下C语言文件描述符操作(dup / dup2 / sendfile / splice / tee)

Linux的哲学是一切皆文件&#xff0c;而操作文件是通过文件描述符来进行。本文梳理一下dup / dup2 / sendfile / splice/ tee函数对文件描述符的操作。 目录 1.dup 2.dup2 3.sendfile 4.splice 5.tee 1.dup #include <unistd.h> int dup(int fd); 复制一个现有的…

Java基础(maven)——maven新建项目 常用IO工具 Durid数据库工具 案例

目录 引出用Maven建项目0.Maven配置方式1.io流的工具IOUtils/FileUtils1&#xff09;可以读文件、按照行读、读网页等&#xff1b;2&#xff09;配合hasmap进行简体繁体转换 2.durid数据库连接工具1&#xff09;创建连接&#xff0c;durid进行连接管理2&#xff09;查询的方式q…