LORA概述: 大语言模型的低阶适应

news2024/12/24 3:27:39

LORA概述: 大语言模型的低阶适应

  • LORA: 大语言模型的低阶适应
    • 前言
    • 摘要
    • 论文十问
    • 实验
      • RoBERTa
      • DeBERTa
      • GPT-2
      • GPT-3
    • 结论
    • 代码调用

LORA: 大语言模型的低阶适应

前言

LoRA的核心思想在于优化预训练语言模型的微调过程,通过有效地处理权重矩阵的变化(即梯度更新的累积),使其具有“低秩”结构。简而言之,这意味着可以通过低秩分解有效地表示变化矩阵。

具体来说,对于预训练权重矩阵W₀,其更新量可以表示为∆W = BA,其中B和A都是低秩矩阵(例如,秩为r,r明显小于矩阵维度d)。在训练期间,W₀被冻结,而B和A中的参数是可训练的。这明显减少了适应的可训练参数数量。

如下图所示,在原始预训练语言模型旁边添加一个旁路,执行降维再升维的操作,以模拟内在秩。在训练过程中,固定预训练语言模型的参数,只训练降维矩阵A和升维矩阵B。模型的输入输出维度保持不变,输出时将BA与预训练语言模型的参数叠加。矩阵A使用随机高斯分布进行初始化,而矩阵B则使用零矩阵进行初始化,以确保在训练开始时,该旁路矩阵仍然是零矩阵。

在这里插入图片描述

摘要

自然语言处理的一个重要范式包括在通用域数据上进行大规模预训练,以及针对特定任务或域进行适配。随着我们预训练更大的模型,全面微调,即重新训练所有模型参数,变得更加不可行。

以GPT-3 175B为例,单独部署经过微调的独立实例模型,每个实例拥有1750亿个参数,是极其昂贵的。我们提出了低秩适应(LoRA)方法,其中冻结预训练模型权重,并在transformer体系结构的每个层中插入可训练的低秩分解矩阵,从而大大减少下游任务的可训练参数数量。

与使用Adam微调GPT-3 175B相比,LoRA可以将可训练参数数量减少10000倍,GPU内存需求减少3倍。尽管只有更少的可训练参数和更高的训练吞吐量,但LoRA在RoBERTa、DeBERTa、GPT-2和GPT-3上的性能优于或等同于微调。

我们还对语言模型适配中的秩缺失进行了实证研究,这解释了LoRA的功效。我们发布了一个软件包,可以方便地将LoRA与PyTorch模型集成,并为RoBERTa、DeBERTa和GPT-2提供了我们的实现和模型检查点。

论文十问

  1. 论文试图解决什么问题?

这篇论文试图解决大规模预训练语言模型(如GPT-3)微调(fine-tuning)所带来的巨大的存储、部署和任务切换成本的问题。

  1. 这是否是一个新的问题?

这不是一个全新的问题,但随着 transformer 语言模型规模的不断增长(如 GPT-3 175B 参数),这个问题的严重性在增加。论文中也提到了许多相关的已有工作。

  1. 这篇文章要验证一个什么科学假设?

这篇文章的主要科学假设是微调过程中模型参数的变化矩阵具有低秩结构(rank-deficient)。基于这个假设,作者提出了低秩适应(LoRA)方法来有效地适应下游任务。

  1. 有哪些相关研究?如何归类?谁是这一课题在领域内值得关注的研究员?

相关的工作包括适配器模块、prompt tuning等参数高效适应方法。

  1. 论文中提到的解决方案之关键是什么?

LoRA的关键是只训练插入到每个transformer层中的低秩分解矩阵,而保持预训练权重固定。这大大降低了适应的可训练参数数量。

  1. 论文中的实验是如何设计的?

在 RoBERTa、DeBERTa、GPT-2 和 GPT-3 等模型上进行了大量实验比较。实验设计针对性强,测试了性能和参数数量的权衡。

  1. 用于定量评估的数据集是什么?代码有没有开源?

使用的数据集包括 GLUE、WikiSQL、SAMSum 等。实验代码和模型检查点开源。

  1. 论文中的实验及结果有没有很好地支持需要验证的科学假设?

是的,丰富的实验验证了 LoRA 在性能、存储效率、训练速度等方面都优于或匹敌全微调基线,支持了低秩适应的有效性。

  1. 这篇论文到底有什么贡献?

主要贡献是提出 LoRA 方法,大幅降低大模型微调的成本,并给出可复现的实验验证。

  1. 下一步呢?有什么工作可以继续深入?

下一步可以考虑与其他高效适应方法(如prompt tuning)的结合,解释微调过程中模型内部表示的变化,进一步提高 LoRA 的泛化性等。

实验

评估了 LoRA 在 RoBERTa (Liu et al., 2019)、DeBERTa (He et al., 2021) 和 GPT-2 (Radford etal., b) 上的下游任务性能,然后再扩展到 GPT- 3 175B(布朗等人,2020 年)

RoBERTa

RoBERTa是Facebook AI于2019年提出的语言表示模型。相比BERT有更优化的预训练步骤,性能更好,参数规模类似,分Base和Large两个版本。实验中分别使用了1.25亿参数和3.55亿参数的RoBERTa模型。

DeBERTa

微软于2020年提出的改进型BERT模型。采用多任务预训练、增强型注意力机制等技术。实验中使用了极大规模的DeBERTa XXL,包含了1500亿参数。
在这里插入图片描述

GPT-2

OpenAI于2019年提出的基于Transformer的语言生成模型GPT-2。模型架构采用了解码器,支持自回归文本生成。实验分别基于中等规模(3.54亿参数)和大规模(7.74亿参数)的GPT-2进行。
在这里插入图片描述

GPT-3

OpenAI于2020年发布的巨大语言模型, Transformer规模达到了1750亿参数,是当时最大的神经语言模型。论文使用了这个极具挑战性的大模型进行扩展实验。这几种模型的选择,可以让作者全面验证LoRA适配方法在不同规模的Transformer类模型中的有效性。覆盖了目前最典型和最前沿的语言表示与生成模型。
在这里插入图片描述

结论

实际好处:

  1. 内存和存储使用减少: 在使用Adam训练的大型Transformer中,通过使用LoRA,显著减少了VRAM(显存)和存储的使用量。例如,在GPT-3 175B上,将训练期间的VRAM消耗从1.2TB减少到350GB。
  2. 检查点大小减小: 在一定条件下,检查点大小减少了大约10,000倍,从350GB减少到35MB。这降低了GPU训练的硬件需求,并避免了I/O瓶颈。
  3. 任务切换成本降低: LoRA允许在任务之间进行切换,通过仅交换LoRA权重而不是所有参数,降低了部署的成本。这使得可以在机器上动态换入和换出预训练权重,创建自定义模型。
  4. 加速训练: 在GPT-3 175B的训练中,相较于完全微调,观察到25%的加速,因为不需要计算绝大多数参数的梯度。

局限性:

  1. 前向传递复杂性: 吸收不同任务的A和B到W中,以消除额外推理延迟,在单个前向传递中批量输入并不简单。需要考虑不同任务的权重合并和动态选择LoRA模块的复杂性。
  2. 推理延迟问题: 尽管可以动态选择LoRA模块以处理不同任务的推理延迟,但在一些场景中,合并权重可能引入不可避免的问题。

代码调用

使用 🤗 PEFT 训练您的模型

下面的示例是使用 LoRA 进行微调的情况。

from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig

peft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

model = model.to(device)
model.eval()
inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label :", return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=10)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1277687.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Docker中部署并启动RabbitMQ

目的 由于最近频繁更换云服务器,导致环境啥的都需要重新配置,关于RabbitMQ,我在看其他博主的文章时,总是不能第一时间找到想要的配置方法(也不是没有,只是花的时间太久),于是打算自己…

接口响应时长几十秒问题排查以及解决

目录 背景 解决方案 总结 背景 线上系统运行几年后,被项目上提bug,有些接口响应很慢,加载页面要几十秒 解决方案 1、步骤一,加索引 性能优化成本高,需要开发周期,临时方案先分析慢sql,通过增…

C语言——深入理解指针(4)

目录 1.回调函数 2. qsort 函数的使用 2.1 排序整型数据 2.2 排序结构体数据 3. qsort 函数的模拟实现 1.回调函数 回调函数就是通过一个函数指针调用的函数。 你把函数的地址作为参数传递给另一个函数,当这个指针被用来调用其所指向的函数时,被调…

Linux:查看端口占用的进程

命令 netstat -tunlp可以从图中看到,端口被那个进程占用,对应进程的pid是多少。

软件测试工程师如何面试?

首先作为HR的角度: 一般我们面试的时候都会问应聘者一些问题,但是问什么?怎么问?每个HR都会有不同的做法。 有的HR问的比较广泛,有的HR比较注重专业度,还有的HR喜欢问一些开放性的问题,没有标…

版本控制系统Git学习笔记-Git分支操作

文章目录 概述一、Git分支简介1.1 基本概念1.2 创建分支1.3 分支切换1.4 删除分支 二、新建和合并分支2.1 工作流程示意图2.2 新建分支2.3 合并分支2.4 分支示例2.4.1 当前除了主分支,再次创建了两个分支2.4.2 先合并test1分支2.4.3 合并testbranch分支 2.5 解决合并…

智慧工地一体化解决方案(里程碑管理)源码

智慧工地为管理人员提供及时、高效、优质的远程管理服务,提升安全管理水平,确保施工安全提高施工质量。实现对人、机、料、法、环的全方位实时监控,变被动“监督”为主动“监控”。 一、建设背景 施工现场有数量多、分布广,总部统…

C++11 左值 右值

什么是左值?什么是左值引用? 左值是一个表示数据的表达式(如变量名或解引用的指针),我们可以获取它的地址可以对它赋 值,左值可以出现赋值符号的左边,右值不能出现在赋值符号左边。定义时const修饰符后的左 值&am…

亚马逊云科技re:Invent大会,助力安全构建规模化生成式AI应用

2023亚马逊云科技re:Invent全球大会进入第三天,亚马逊云科技数据和人工智能副总裁Swami Sivasubramanian博士在周三的主题演讲中,为大家带来了关于亚马逊云科技生成式AI的最新能力、面向生成式AI时代的数据战略以及借助生成式AI应用提高生产效率的精彩分…

变量和引用

变量和引用 2.1.深入认识变量 2.1.1.什么是变量 变量是在程序中保存用户数据的一段内存存储空间,变量名是内存空间的首地址 变量三要素:名称、类型、值 2.1.2.变量的名称 组成: 字母、数字、下划线组成,不能以数字开头 变量名称的长…

Android 获取应用签名

Android 获取应用签名 本文主要讲下在android中如何获取应用签名. 也方便平时用来区分一个应用是不是原包应用. 1: 通过PackageManager获取签名信息 首先,通过packageManager获取到指定应用的PackageInfo. 这里需要传入的flag是PackageManager.GET_SIGNATURES /*** {link P…

scrapyd及gerapy的使用及docker-compse部署

一、scrapyd的介绍 scrapyd是一个用于部署和运行scrapy爬虫的程序,它允许你通过JSON API(也即是web api)来部署爬虫项目和控制爬虫运行,scrapyd是一个守护进程,监听爬虫的运行和请求,然后启动进程来执行它们 scrapyd的安装 scr…

蓝桥第一期模拟总结

文章目录 1.动态的 Tab 栏2.地球漫游3.迷惑的this4.燃烧卡路里5.魔法失灵了6.年龄统计 所有题目链接 1.动态的 Tab 栏 本题要实现一个tab栏固定效果,看见题目就想到css中的 position: fixed; 尝试了很久都没能实现效果,后来又想到了粘性定位 position: …

网络编程之套接字

端口 && IP 在学习套接字编程之前,我们必须了解一下前缀知识。首先是IP和端口的作用。 在这之前,我们要明白一件事。那就是把数据从一台主机发送到另一台主机,是目的吗???当然不是!&a…

TOP-K问题和向上调整算法和向下调整算法的时间复杂度问题的分析

TOP-K问题 TOP-K问题:即求数据结合中前K个最大的元素或者最小的元素,一般情况下数据量都比较大 比如:专业前10名、世界500强、富豪榜、游戏中前100的活跃玩家等 对于Top-K问题,能想到的最简单直接的方式就是排序,但是…

聊聊测试for Jeffky

什么是测试 测试是一个系统性的过程,它涉及到在已开发的软件中执行程序、应用工具和技术来评估其质量、功能和性能。这个过程的目的是发现并纠正程序中的错误,提高软件的可靠性和稳定性,以满足用户的需求。 测试的分类 什么是自动化测试 自动…

Android Termux 安装Kali Linux 或 kali Nethunter史诗级详细教程

Android Termux 安装Kali Linux 或 kali Nethunter史诗级详细教程 一、Termux配置1、下载安装2、配置存储和换源3、基本工具安装 二、Kali Linux安装1、下载安装脚本2、更换apt源3、图形化安装 三、Kali Nethunter安装1、下载安装脚本2、更换apt源3、图形化连接 四、报错汇总1、…

五、关闭三台虚拟机的防火墙和Selinux

目录 1、关闭每台虚拟机的防火墙 2、关闭每台虚拟机的Selinux 2.1 什么是SELinux

《第一行代码:Android》第三版4.2常用控件的使用方法(1)

概述 本文主要讲解常用控件的使用&#xff0c;包括&#xff1a;TextView、Button、EditText、ImageView、ProgressBar、AlertDialog。 布局文件 布局文件是activity_main.xml,内容如下&#xff1a; <?xml version"1.0" encoding"utf-8"?> <…

Spring Cloud笔记 —— 什么是Spring Cloud?

引言&#xff1a; 在写这篇博客之前&#xff0c;其实吧&#xff0c;博主很久之前有过一段时间的Spring Cloud的案例项目开发经验&#xff0c;就是一个案例项目开发而已&#xff0c;也说不上有多高大上&#xff0c;那个时候&#xff0c;我其实也是从众而已罢了&#xff0c;毕竟现…