2023年的深度学习入门指南(12) - PEFT与LoRA

news2024/11/23 22:58:48

2023年的深度学习入门指南(12) - PEFT与LoRA

大家都知道,大模型的训练需要海量的算力。其实,即使是只对大模型做微调训练,也是需要大量的计算资源的。

有没有用更少的计算资源来进行微调的方法呢?研究者研发出了几种被Hugging Face统称为参数高效微调PEFT(Parameter-Efficient Fine-Tuning)的技术。

这其中常用的几个大家应该已经耳熟能详了,比如广泛应用的LoRA技术(Low Rank Adapters,低秩适配),Prefix Tuning技术,Prompt Tuning技术等等。

我们先学习如何使用,然后我们再学习其背后的原理。

用Huggingface PEFT库进行低秩适配

首先我们先安装相关的库,主要有量化用的bitsandbytes库,低秩适配器loralib库,以及加速库accelerate。
另外,PEFT库和transformers库都用最新的版本。

pip install -q bitsandbytes datasets accelerate loralib
pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git

我们来尝试训练一个7B左左的模型,我们选用opt-6.7b模型,它以float16的精度存储,大小大约为13GB!如果我们使用bitsandbytes库以8位加载它们,我们需要大约7GB的显存。

但是,这只是加载用的,在实际训练的时候,16G显存都照样不够用。最终的消耗大约在20G左右。

加载大模型仍然使用我们前面学过的AutoModelForCausalLM.from_pretrained()函数,只是我们需要加上load_in_8bit=True参数来调用bitsandbytes库进行8位量化。

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-6.7b",
    load_in_8bit=True,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-6.7b")

下面PEFT就正式出场了,我们先针对所有非int8的模块进行预处理以提升精度:

from peft import prepare_model_for_int8_training

model = prepare_model_for_int8_training(model)

我们再配置下LoRA的参数,参数的具体含义我们后面结合原理再讲。

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)

我们选用名人名言数据集作为训练数据:

from datasets import load_dataset

data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

然后就可以开始训练了:

trainer = transformers.Trainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=100,
        max_steps=200,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()

最后,我们做一个推理测试下效果:

batch = tokenizer("Two things are infinite: ", return_tensors="pt")

with torch.cuda.amp.autocast():
    output_tokens = model.generate(**batch, max_new_tokens=50)

print("\n\n", tokenizer.decode(output_tokens[0], skip_special_tokens=True))

输出的结果如下:

 Two things are infinite:  the universe and human stupidity; and I'm not sure about the universe.  -Albert Einstein
I'm not sure about the universe either.

基本上,我们除了配置了一个LoRA参数之外什么也没干。

LoRA的原理

LoRA的思想是将原始的权重矩阵分解为两个低秩矩阵的乘积,这样就可以大大减少参数量。其本质思想还是将复杂的问题拆解为简单的问题的组合。
LoRA通过注入优化后的秩分解矩阵,将预训练模型参数冻结,减少了下游任务的可训练参数数量,使得训练更加高效。并且在使用适应性优化器时,降低了硬件进入门槛。
因为我们不需要计算大多数参数的梯度或维护优化器状态,而是仅优化注入的、远小于原参数量的秩分解矩阵。

光定量地这么讲,大家没有观感,我们以上面训练的例子来看看LoRA的效果。

我们写一个函数来计算模型中的可训练参数数量:

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

运行一下:

print_trainable_parameters(model)

输出结果如下:

trainable params: 8388608 || all params: 6666862592 || trainable%: 0.12582542214183376

我们看到,原始的模型参数有66亿多个,但是我们只训练了838多万个,只占了0.125%。

所以这也就是为什么我们经常看到有6b,7b,还有13b参数的大模型了。因为这个量级的模型,刚好可以在一张40G或者80G的A100显卡上训练。甚至在24G的3090上也能训练。

下面我们来解释一下低秩更新的原理。

如图所示,输入为x,x是d维的向量,输出是h。

我们将参数分为冻结的权重 W 0 W_0 W0和可以训练的参数 Δ W \Delta W ΔW。然后我们把 Δ W \Delta W ΔW分解成A和B两个可训练参数的矩阵,其中A矩阵取随机值,而B矩阵全取0.

h = W 0 x + Δ W x = W 0 x + B A x h=W_0 x+\Delta W x=W_0 x+B A x h=W0x+ΔWx=W0x+BAx

其中, W 0 W_0 W0是一个d乘以r维的矩阵, W 0 ∈ R d × k W_0 \in \mathbb{R}^{d \times k} W0Rd×k

为了让B乘以A的结果为输入是d维而输出为k维,B矩阵我们取d行r列,而A矩阵为r行k列,这样一相乘就是d行k列:
B ∈ R d × r , A ∈ R r × k B \in \mathbb{R}^{d \times r}, A \in \mathbb{R}^{r \times k} BRd×r,ARr×k

为了让低秩后的效果更好,r要取一个远小于d和k的值。

为了减少更换r给训练带来的影响,我们再引入一个缩放参数 α \alpha α。我们给 Δ W x \Delta W x ΔWx乘以 α r \frac{\alpha}{r} rα。当使用Adam优化时,如果我们适当地缩放初始化,调整α就大致相当于调整学习率。因此,我们简单地将α设置为我们尝试的第一个r,并不对其进行调整。这种缩放有助于减少在改变r时重新调整超参数的需要。

我们来参照一下前面配置的LoRA config:

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)

我们可以看到,r选择的是16,而alpha为32。说明最开始是用32作为r来进行尝试的。后面我们再调参数的时候,就改r而不调整alpha了。

那么,我们为什么只选择了q和v两个参数进行LoRA呢?

我们来看论文中的数据:

取q和k两组参数的效果,还不如只取v一个的效果好。而把q,k,v,o全都训练了,也没有明显的优势。所以就取相对最有效率的q,v两组。

当然,这也不是金科玉律,大家可以在实践中去探索更好的LoRA策略。

小结

LoRA的一个例子就是alpaca-lora项目,其网址为:https://github.com/tloen/alpaca-lora

alpaca-lora是一个使用LoRA技术对Alpaca模型进行轻量化的项目。Alpaca模型是一个基于LLaMA 7B模型的聊天机器人,使用了Instruct数据集进行微调。alpaca-lora的优点是可以在低成本和低资源的情况下,获得与Alpaca模型相当的效果,并且可以在MacBook、Google Colab、Raspberry Pi等设备上运行。alpaca-lora使用了Hugging Face的PEFT和bitsandbytes来加速微调过程,并提供了一个脚本来下载和推理基础模型和LoRA模型。

现在,PEFT和LoRA对我们来说,已经不再陌生了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/500007.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

fastCGI使用

1.http解释 在使用fastCGI之前需要先了解什么是http,以及静态请求和动态请求。 1.什么是http HTTP是超文本传输协议,它定义了客户端和服务器端之间文本传输的规范。HTTP通常运行在TCP之上,使用80端口。HTTP是一种简单的请求-响应协议&#x…

GUN C编译器拓展语法学习笔记(二)属性声明

一、属性声明 1、存储段:section 1.1 GNU C编译器扩展关键字:__attribute__ GNU C增加了一个__attribute__关键字用来声明一个函数、变量或类型的特殊属性。主要用途就是指导编译器在编译程序时进行特定方面的优化或代码检查。例如,我们可以…

C语言三子棋小游戏

哈喽,大家好,今天我们要利用之前所学习的C语言知识来写一个三子棋小游戏。 目录 1.游戏 2.函数部分 2.1.菜单 2.2.初始化棋盘 2.3.打印棋盘 2.4.玩家下棋 2.5.电脑下棋 2.6.判断输赢 2.7.判断棋盘是否已满 3.完整代码展示 1.游戏 今天我们写的…

未知时间信息下雷达运动目标的计算高效重聚焦与估计方法

论文背景 在雷达成像中,回波信号在接收到之前可能已经被多次反射或散射,这样会导致回波信号的时间和频率发生变化。其中,距离向维度上的变化称为距离单元迁移(range cell migration,RCM),频率向…

Spring笔记

文章目录 1、什么是Spring?2、如何创建Spring3、Spring简单的读和取操作1.直接在spring-config.xml里面放置对象2.通过配置扫描路径和添加注解的方式添加Bean对象3.为什么需要五个类注解4.从spring中简单读取 Bean对象5.Resource和Autowired的异同 1、什么是Spring&…

Transformer结构细节

一、结构 Transformer 从大的看由 编码器输入、编码器、解码器、解码器输入和解码器输出构成。 编码器中包含了词嵌入信息编码、位置编码、多头注意力、Add&Norm层以及一个全连接层; 解码器中比编码器多了掩码的多头注意力层。 二、模块 2.1 Input Embeddi…

canvas学习之华丽小球滚动电子时钟

教程来自 4-3 华丽的小球滚动效果 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>华丽小球滚动时钟…

【AVL树的模拟实现】

1 AVL树的概念 二叉搜索树虽可以缩短查找的效率&#xff0c;但如果数据有序或接近有序二叉搜索树将退化为单支树&#xff0c;查找元素相当于在顺序表中搜索元素&#xff0c;效率低下。因此&#xff0c;两位俄罗斯的数学家G.M.Adelson-Velskii和E.M.Landis在1962年发明了一种解决…

人工智能基础部分14-蒙特卡洛方法在人工智能中的应用及其Python实现

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下人工智能基础部分14-蒙特卡洛方法在人工智能中的应用及其Python实现&#xff0c;在人工智能领域&#xff0c;蒙特卡洛方法&#xff08;Monte Carlo Method, MCM&#xff09;被广泛应用于各种问题的求解。本文首先将…

wvp-GB28181-pro录像功能开发环境搭建、配置、使用

开发环境、调试环境搭建 开发wvp平台搭建 离线安装脚本&#xff1a;https://gitcode.net/zenglg/ubuntu_wvp_online_install.git 下载离线安装脚本&#xff0c;完成wvp平台的部署 开发环境要求 操作系统&#xff1a;包管理工具是apt ky10桌面版uos桌面版deepin桌面版ubuntu桌面…

ArmDot.NET Crack

ArmDot.NET Crack ArmDot是一个.NET加密工具&#xff0c;用于保护使用.NET编写的程序。 企业需要保护他们的知识产权&#xff0c;包括他们的算法、产品和使用的资源的源代码。 然而&#xff0c;.NET编译器会生成一个通用的可访问代码。代码中嵌入的资源很容易访问&#xff0c;并…

RocketMQ不同的类型消息

目录 普通消息 可靠同步发送 可靠异步发送 单向发送 三种发送方式的对比 顺序消息 事物消息 两个概念 事务消息发送步骤 事务消息回查步骤 消息消费要注意的细节 RocketMQ支持两种消息模式: 普通消息 RocketMQ提供三种方式来发送普通消息&#xff1a;可靠同步发送、…

剑指Offer题集(力扣)

文章目录 剑指Offer题集&#xff08;[力扣题单](https://leetcode.cn/problemset/all/?listIdlcof&page1)&#xff09;[剑指 Offer 03. 数组中重复的数字](https://leetcode.cn/problems/shu-zu-zhong-zhong-fu-de-shu-zi-lcof/)[剑指 Offer 04. 二维数组中的查找](https:…

SSM框架练习一(登录后关联数据表的业务模型)

需要实现的整体功能&#xff1a; 登录反馈信息列表展示查询反馈信息发表反馈 1.数据库设计 创建数据库 创建表结构及其约束 添加测试数据 工具&#xff1a;PHP、Navicat create table tab_user(id int primary key auto_increment,uname varchar(30) not null,pwd varc…

Weblogic XMLDecoder 反序列化漏洞(CVE-2017-10271复现)

文章目录 前言影响版本环境搭建漏洞复现深度利用 前言 CVE-2017-10271漏洞产生的原因大致是Weblogic的WLS Security组件对外提供webservice服务&#xff0c;其中使用了XMLDecoder来解析用户传入的XML数据&#xff0c;在解析的过程中出现反序列化漏洞&#xff0c;导致可执行任意…

从搬砖工到架构师,Java全栈学习路线总结

&#x1f307;文章目录 前言一、前置知识二、 Web前端基础示例&#xff1a;1.文本域2.密码字段 三、后端基础一. Java基础二. 数据库技术三. Web开发技术四. 框架技术五. 服务器部署 四、其他技术五、全栈开发六、综合实践七、学习教程一、前端开发二、后端开发三、数据库开发四…

springboot+jsp乡村中小学校园网站建设

随着科学技术的飞速发展&#xff0c;社会的方方面面、各行各业都在努力与现代的先进技术接轨&#xff0c;通过科技手段来提高自身的优势&#xff0c;乡村小学校园网当然也不能排除在外&#xff0c;从校园概况、学校风采、招生信息的统计和分析&#xff0c;在过程中会产生大量的…

Maven依赖原则及如何解决Maven依赖冲突

前言 在大数据应用中&#xff0c;现在发现依赖关系非常复杂&#xff0c;在上线之前很长测试&#xff0c;前一段时间在部署udf 出现了导致生产Hiveserver2 宕机问题&#xff0c;出现严重事故。现在就咨询研究一下。Maven虽然已经诞生多年&#xff0c;但仍然是当前最流行的Java系…

Arrays:点燃你的数组操作技巧的隐秘武器。

前言 数组在 Java 中是一种常用的数据结构&#xff0c;用于存储和操作大量数据。但是在处理数组中的数据&#xff0c;可能会变得复杂和繁琐。Arrays 是我们在处理数组时的一把利器。它提供了丰富的方法和功能&#xff0c;使得数组操作变得更加简单、高效和可靠。无论是排序、搜…

【c语言】字符串类型转换 | itoa函数的使用

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; 给大家跳段街舞感谢支持&#xff01;ጿ ኈ ቼ ዽ ጿ ኈ ቼ ዽ ጿ ኈ ቼ …