大模型常见的LoRA算法原理、实现和运用详解

news2024/9/24 17:17:06

文章目录

  • 1. 前言
  • 2. 算法介绍
    • 2.1 微调
    • 2.2 核心思想
  • 3. 几个关键问题
  • 4. 源码
  • 5. 实际应用
  • 6. 总结
  • 7. 参考


1. 前言

本文是常用算法的快速浏览入门(扫盲),结合论文+代码,从原理、实现以及实际应用上深入介绍。

LoRA(Low-Rank Adaptation)是一种用于在预训练模型基础上进行高效微调(Fine-Tuning)的算法,特别适用于大规模语言模型(LLMs)。LoRA 通过引入低秩矩阵的方式来适应和调整模型参数,从而在保持预训练模型原有能力的同时,显著减少微调的计算成本和存储需求。(以上如果看不明白,往后看完一定能懂!)

简言 LoRA:

(1)适用:预训练之后的微调阶段

(2)优势:减少计算成本和存储需求,同时不引入推理延迟(Inference Latency),方便在不同的微调任务灵活切换

论文原文:https://arxiv.org/abs/2106.09685

2. 算法介绍

2.1 微调

首先理解,为什么大模型需要微调。
(1)预训练的语言模型通常在大规模的通用语料库上进行训练,具备广泛的语言理解能力。微调的目的是让这些模型能够适应特定的任务,如情感分析、文本分类、机器翻译、问答系统等。
(2)某些应用可能涉及到专业领域的语言和术语,如医学、法律、金融等。通过在领域特定的数据集上进行微调,模型能够更好地理解和处理这些特定领域的语言和内容。

我们假设预训练得到的参数权重为 W 0 W_0 W0,在微调阶段,更新权重:

W = W 0 + Δ W W = W_0 + \Delta W W=W0+ΔW

其中 Δ W \Delta W ΔW 是微调训练阶段更新的权重大小。

如果直接训练,当预训练模型的 W 0 W_0 W0 矩阵的参数量太大时,导致微调 Δ W \Delta W ΔW 需要计算梯度并维护如此大量的优化器状态,对于动辄几百上千亿的参数量,会消耗大量存储空间。

2.2 核心思想

LoRA 将 Δ W \Delta W ΔW 进行矩阵变换(忽略缩放 α / r \alpha /r α/r等):
Δ W = B A \Delta W = BA ΔW=BA

其中 Δ W ∈ R d × k \Delta W \in \mathbb{R^{d \times k}} ΔWRd×k B ∈ R d × r B \in \mathbb{R^{d\times r}} BRd×r A ∈ R r × k A \in \mathbb{R^{r\times k}} ARr×k r ≪ m i n ( d , k ) r\ll min(d, k) rmin(d,k) r r r 取值一般 1,2,4,8等。

如论文中的下图:

其中 Δ W \Delta W ΔW

在微调训练中,仅需要计算右边的 A A A B B B 矩阵,而不需要变动左边的 W W W(预训练的权重)

这样的权重参数总量 Θ \Theta Θ
Θ = ( d + k ) × r ≪ d × k \Theta=(d+k) \times r \ll d\times k Θ=(d+k)×rd×k

需要维护的权重参数量远远降低了,这个思路是不是有点类似 SVD(奇异值分解)。

论文中表明在 GPT-3 175B 模型上,LoRA 将训练期间的 VRAM 消耗从1.2TB 减少到 350GB。一般对于 Transformer 可以减少 2/3 的内存占用。

训练过程:

3. 几个关键问题

以下也是论文中提到的在更新权重时的几个重要问题,下面结合本人的理解进行阐述!

3.1 LoRA如何应用在Transformer上

Transformer 的详细介绍参考:

《NLP深入学习:大模型背后的Transformer模型究竟是什么?(一)》

《NLP深入学习:大模型背后的Transformer模型究竟是什么?(二)》

Transformer 有几个权重变换矩阵, W q W_q Wq W k W_k Wk W v W_v Wv,实际应用 LoRA 是作用在哪些矩阵效果最佳呢?

在 WikiSQL 和 MultiNLI 数据集进行验证:

根据实验结果,得出以下结论:

(1)只作用于 W q W_q Wq 或 只作用于 W k W_k Wk 效果最差;

(2)同时作用于 W q W_q Wq W v W_v Wv 效果最好;

(3)秩大、作用于单一矩阵,不如秩小、作用于多种矩阵。

3.2 LoRA的最佳秩r如何选择

以下是实验:

根据实验结果,得出以下结论:

(1)秩非常小都表现出很好的效果,表明更新矩阵 Δ W \Delta W ΔW 可能具有非常小的“内在秩”;

(2)LoRA 作用于越多的矩阵,表现效果越好;

(3)对于简单模型一味增大秩不具备太大的意义,r=4 和 r=64 效果差距不大,但是 r=64 反而消耗了更大的资源;

经验法则

在实践经验中,选择 r r r 的典型范围在 1 到 64 之间(不绝对)。这个范围内的 r r r 通常可以在性能和效率之间取得较好的平衡。

实验确定

  1. 初始尝试:可以从 r = 4 r = 4 r=4 r = 8 r = 8 r=8 开始。这些值通常可以提供良好的性能,同时保持低计算成本。

  2. 性能评估:在验证集上评估模型性能,记录各个 r r r 值对应的性能指标。

  3. 优化调整:根据评估结果,逐步调整 r r r,如从 r = 4 r = 4 r=4 增加到 r = 128 r = 128 r=128(或更大),观察性能提升情况。

3.3 W W W Δ W \Delta W ΔW有什么关系

论文通过计算降维后的 W W W Δ W \Delta W ΔW 的相似程度,得出以下结论:

(1)与随机矩阵相比, Δ W \Delta W ΔW W W W 的相关性更强,这表明 Δ W \Delta W ΔW 放大了 W W W 中已经存在的一些特征;

(2)低秩矩阵潜在地放大了特定下游任务的重要特征,这些特征在一般的预训练模型中被学习但没有被强调。

4. 源码

源码:https://github.com/microsoft/LoRA

以 Linear 为例,代码比较简单:

5. 实际应用

LoRA 集成到 peft 库中。使用 LoRA,需要增加一个配置,其他代码没有变化。

(1)配置 LoRA 参数:

from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
# Define LoRA Config
lora_config = LoraConfig(
 r = LORA_R, # the dimension of the low-rank matrices
 lora_alpha = LORA_ALPHA, # scaling factor for the weight matrices
 lora_dropout = LORA_DROPOUT, # dropout probability of the LoRA layers
 bias="none",
 task_type="CAUSAL_LM",
 target_modules=["query_key_value"],
)

# Prepare int-8 model for training - utility function that prepares a PyTorch model for int8 quantization training. <https://huggingface.co/docs/peft/task_guides/int8-asr>
model = prepare_model_for_int8_training(model)
# initialize the model with the LoRA framework
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

(2)训练部分(不变):

from transformers import TrainingArguments, Trainer
import bitsandbytes

EPOCHS = 3
LEARNING_RATE = 1e-4 
MODEL_SAVE_FOLDER_NAME = "dolly-3b-lora"

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=1e-5,
    logging_steps=100,
    num_train_epochs=EPOCHS,
)

trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_data_tokenized,
    eval_dataset=val_data_tokenized,
    data_collator=data_collator,
    tokenizer=tokenizer,
)
trainer.train()
# only saves the incremental   PEFT weights (adapter_model.bin) that were trained, meaning it is super efficient to store, transfer, and load.
trainer.model.save_pretrained(MODEL_SAVE_FOLDER_NAME)
# save the full model and the training arguments
trainer.save_model(MODEL_SAVE_FOLDER_NAME)
trainer.model.config.save_pretrained(MODEL_SAVE_FOLDER_NAME)

6. 总结

(1)低秩近似:LoRA 假设预训练模型的权重变化可以通过低秩矩阵来表示。这意味着相对于直接微调所有模型参数,通过低秩矩阵的调整,可以用更少的参数达到近似效果。

(2)参数效率:LoRA 只引入了少量的附加参数,这些参数是低秩矩阵的元素。由于这些矩阵的秩较低,所需的附加参数数量远小于模型的原始参数数量,显著减少了存储和计算开销。

(3)模块化设计:LoRA 通过在模型的特定层或模块中插入低秩矩阵,使其易于集成到各种预训练模型中,不需要对模型的原始结构进行大幅度修改。

LoRA 的设计和应用有效地解决了大规模模型微调中的高成本问题,使得在资源有限的环境中进行高效的模型适应成为可能。

7. 参考

[1] https://arxiv.org/abs/2106.09685


欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;

欢迎关注知乎/CSDN:SmallerFL

也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1910303.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

升级iOS18有问题?学会这2招能解决90%iOS问题!

在 iOS 18beta发布后&#xff0c;有部分朋友升级后表示遇到了各种奇怪问题&#xff0c;比如升级卡在Apple Logo&#xff0c;黑屏&#xff0c;无限重启&#xff0c;卡在恢复模式&#xff0c;程序闪退&#xff0c;电池消耗过快&#xff0c;发烫等问题。 于是&#xff0c;小编决定…

HTML(27)——渐变

渐变是多个颜色逐渐变化的效果&#xff0c;一般用于设置盒子模型 线性渐变 属性&#xff1a;background-image : linear-gradient( 渐变方向 颜色1 终点位置, 颜色2 终点位置, ......&#xff09;&#xff1b; 取值: 渐变方向:可选 to 方位名词角度度数 终点位置:可选 百分…

生物素标记的柚皮苷探针;Biotin-Naringin

生物素标记的柚皮苷探针&#xff08;Biotin-Naringin&#xff09;是一种结合了生物素&#xff08;Biotin&#xff09;和柚皮苷&#xff08;Naringin&#xff09;特性的化合物&#xff0c;它在有机合成及药物化学技术领域具有重要意义。以下是对该探针的详细解析&#xff1a; 一…

研华工控机 UNO-2473G WIN7专业版系统下安装网卡驱动异常

基本配置&#xff1a;UNO-2473G、Windows 7 Pro 64bit 常规型嵌入式工控机&#xff0c;搭配Intel Atom™ E3845/Celeron J1900 处理器 第四代Intel Atom/Celeron J1900处理器&#xff0c;最高可达1.91/2.0 GHz&#xff0c;4GB DDR3L存储4/2 x GbE, 3 x USB 2.01 x USB 3.0或4…

VOS历史话单的非法呼叫话单解决方案,IPSS模块安装到VOS服务器,可大幅度提高安全性!

由于VOS的普及性&#xff0c;不得不承认VOS确实是非常优秀的软交换&#xff0c;但是很多客户在使用过程中都会遇到各种安全问题&#xff0c;比如话费被盗用了&#xff0c;历史话单一堆的非法呼叫话单&#xff0c;严重的影响到了话务安全&#xff0c;并不是那点话费的事了&#…

留学Assignment写作如何正确选择topic?

留学Assignment在写作之前首先需要选好topic&#xff0c;一个好的topic能让你的Assignment写起来更加顺畅&#xff0c;俗话说“文好题一半”&#xff0c;好的创作主题&#xff0c;不但对于写作帮助颇大&#xff0c;对于Assignment总体也是加分不少的。 添加图片注释&#xff0c…

python-课程满意度计算(赛氪OJ)

[题目描述] 某个班主任对学生们学习的的课程做了一个满意度调查&#xff0c;一共在班级内抽取了 N 个同学&#xff0c;对本学期的 M 种课程进行满意度调查。他想知道&#xff0c;有多少门课是被所有调查到的同学都喜欢的。输入格式&#xff1a; 第一行输入两个整数 N , M 。 接…

高,实在是高

go&#xff0c;去 //本义音通义通汉字“高”&#xff0c;指太阳升起、上升&#xff0c;即高上去 god | God&#xff0c;神&#xff0c;上帝 //本义音通义通“高的”&#xff0c;指太阳高高在上的&#xff0c;至高无上的 glad&#xff0c;高兴的 //本义音通义通“高了的”&#…

【CUDA|CUDNN】安装

every blog every motto: You can do more than you think. https://blog.csdn.net/weixin_39190382?typeblog 0. 前言 显卡驱动安装参考之前的文章 cuda、cudnn 安装 1. cuda 安装 访问https://developer.nvidia.com/cuda-toolkit-archive 选择需要的版本&#xff1a;h…

【瑞吉外卖 | day03】公共字段自动填充+分类信息的增删改查

文章目录 1. 公共字段自动填充1.1 问题分析1.2 代码实现1.3 代码完善 2. 新增分类2.1 需求分析2.2 数据模型2.3 代码开发 3. 分类信息分页查询3.1 代码开发 4. 删除分类4.1 需求分析4.2 代码开发4.3 功能完善 5. 修改分类 1. 公共字段自动填充 1.1 问题分析 在后台系统的员工管…

【VUE基础】VUE3小技巧(持续更新)

一键快速生成自定义vue3模板代码 根据上图打开配置用户代码片段 搜索vue.jsond打开 "Print to console": {"prefix": "vue","body": ["<!-- $0 -->","<template>"," <div></div>&…

加油卡APP系统开发,优惠加油收益

目前&#xff0c;汽车已经成为了不可或缺的出行工具&#xff0c;汽车加油更是成为了家家户户要做的事。不过随着油价的波动&#xff0c;车主急需能够进行优惠加油的渠道&#xff0c;因此&#xff0c;加油卡APP成为了大众汽车加油新的选择方式&#xff0c;用户在下载APP后即可享…

220V降5V芯片输出电压电流封装选型WT

220V降5V芯片输出电压电流封装选型WT 220V降5V恒压推荐&#xff1a;非隔离芯片选型及其应用方案 在考虑220V转低压应用方案时&#xff0c;以下非隔离芯片型号及其封装形式提供了不同的电压电流输出能力&#xff1a; 1. WT5101A&#xff08;SOT23-3封装&#xff09;适用于将2…

客户案例|某大型证券公司数据库运维场景数据安全实践

证券行业涉及股票、债券、基金等金融产品的发行、交易和监管&#xff0c;业务具有数据规模大、数据价值高、数据应用场景复杂的显著特点&#xff0c;其中高速流转的业务系统中含有海量的客户个人信息、交易、行情、咨询等高敏感高价值信息。由于证券期货业务场景所具有的特殊性…

CC5利用链分析

分析版本 Commons Collections 3.2.1 JDK 8u65 环境配置参考JAVA安全初探(三):CC1链全分析 分析过程 CC6是在CC1 LazyMap利用链(引用)的基础上。 CC5和CC6相似都是CC1 LazyMap利用链(引用)的基础上&#xff0c;改变了到LazyMap的入口类。 CC6是用TiedMapEntry的hashCode方…

系统吃swap问题排查

目录 背景 问题 分析并解决 1.控制线程数 2.更换IO组件 3.Linux进程信息文件分析 总结加餐 参考文档 背景 隔壁业务组系统是简单的主从结构&#xff0c;写索引的服务(主)叫primary&#xff0c; 读索引并提供搜索功能的服务(从)叫replica。业务线同步数据并不是平滑的&…

DIF-Gaussian 代码讲解

这篇论文的标题是《Learning 3D Gaussians for Extremely Sparse-View Cone-Beam CT Reconstruction》&#xff0c;作者是Yiqun Lin, Hualiang Wang, Jixiang Chen和Xiaomeng Li&#xff0c;来自香港科技大学以及HKUST深圳-香港协同创新研究院。 这篇论文主要探讨了一种新的锥…

关于MySQL mvcc

innodb mvcc mvcc 多版本并发控制 在RR isolution 情况下 trx在启动的时候就拍了个快照。这个快照是基于整个数据库的。 其实这个快照并不是说拷贝整个数据库。并不是说要拷贝出这100个G的数据。 innodb里面每个trx有一个唯一的trxID 叫做trx id .在trx 开始的时候向innodb系…

录音的内容怎么做二维码?支持多种音频格式使用的制作技巧

怎么把录制的音频文件做成二维码呢&#xff1f;现在用二维码来存储内容是一种很常用的方式&#xff0c;让其他人扫描二维码来查看内容&#xff0c;从而提升内容传输的速度。比如现在很多人会将音频生成二维码&#xff0c;其他人可以通过扫码在手机上播放音频内容&#xff0c;那…

kafka的副本replica

指定topic的分区和副本 通过kafka命令行工具 kafka-topics.sh --create --topic myTopic --partitions 3 --replication-factor 1 --bootstrap-server localhost:9092 执行代码时指定分区个数