14. PEFT:在大模型中快速应用 LoRA

news2024/9/22 18:57:03

如果你对LoRA还没有一个直观的概念,可以回看这篇文章:《3. 认识 LoRA:从线性层到注意力机制》。

我们将在这里进一步探讨如何快速地在大型预训练模型中应用 LoRA,并解答可能存在的问题,包括:

  • peftlora 之间有什么关系?
  • get_peft_model 怎么使用?
  • 如何知道应用 LoRA 后模型的参数变化量?

代码文件下载

文章目录

  • PEFT 和 LoRA 的关系
  • 在大模型中应用 LoRA
    • 安装必要的库
    • 加载预训练模型
    • 应用 LoRA
    • 查看当前模型架构
    • 查看增加的参数量
      • 理论计算
      • 使用 PEFT 查看参数
      • 自定义函数查看参数
    • 准备数据并进行微调
    • 保存和加载 LoRA 微调的模型

PEFT 和 LoRA 的关系

PEFT(Parameter-Efficient Fine-Tuning)是 Hugging Face 提供的专门用于参数高效微调的工具库。LoRA(Low-Rank Adaptation)是 PEFT 支持的多种微调方法之一,旨在通过减少可训练参数来提高微调大模型的效率。除此之外,PEFT 还支持其他几种常见的微调方法,包括:

  • Prefix-Tuning:冻结原模型参数,为每一层添加可学习的前缀向量,只学习前缀参数。
  • Adapter-Tuning:冻结原模型参数,在模型的层与层之间插入小型的 adapter 模块,仅对 adapter 模块进行训练。

在大模型中应用 LoRA

下面,我们以实际的例子来展示如何在大模型中快速应用 LoRA。

安装必要的库

首先,确保你已经安装了 transformerspeft 库。

pip install transformers peft

加载预训练模型

我们以 Hugging Face 的 transformers 库为例,加载一个预训练的 GPT-2 模型,其参数大小为 110M。

from transformers import AutoTokenizer, AutoModelForCausalLM

# 加载预训练的 GPT-2 模型和分词器
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

print(model)

打印 model,方便和应用 LoRA 后进行对比。

image-20240921111524351

应用 LoRA

使用 peft 库,我们可以轻松地将 LoRA 集成到模型中:

from peft import get_peft_model, LoraConfig, TaskType

# 配置 LoRA
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # 任务类型:因果语言模型
    inference_mode=False,          # 推理模式关闭,以进行训练
    r=8,                           # 低秩值 r
    lora_alpha=32,                 # LoRA 的缩放因子
    lora_dropout=0.1,              # Dropout 概率
)

# 将 LoRA 应用到模型中
model = get_peft_model(model, lora_config)

查看当前模型架构

print(model)

可以看到 LoRA 已经成功应用。

image-20240921111458757

查看增加的参数量

应用 LoRA 后,或许你希望了解模型参数量的变化。以下是理论计算和查看方式:

理论计算

对于每个应用了 LoRA 的层,增加的参数量为:

增加的参数量 = r × ( 输入维度 + 输出维度 ) \text{增加的参数量} = r \times (\text{输入维度} + \text{输出维度}) 增加的参数量=r×(输入维度+输出维度)

  • r:LoRA 的低秩值。
  • 输入维度:层的输入特征数。
  • 输出维度:层的输出特征数。

使用 PEFT 查看参数

peft 提供了查看模型参数的便捷方法:

# 查看 LoRA 模块
model.print_trainable_parameters()

输出:

trainable params: 294,912 || all params: 124,734,720 || trainable%: 0.23643136409814364

自定义函数查看参数

实际上直接计算所有可训练参数就行。

def print_trainable_parameters(model):
    trainable_params = 0
    all_params = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        all_params += num_params
        if param.requires_grad:
            trainable_params += num_params
    print(f"可训练参数量: {trainable_params}")
    print(f"总参数量: {all_params}")
    print(f"可训练参数占比: {100 * trainable_params / all_params:.2f}%")
    
print_trainable_parameters(model)

输出:

可训练参数量: 294912
总参数量: 124734720
可训练参数占比: 0.24%

准备数据并进行微调

假设你已经有了训练数据集 train_dataset,下面是一个简单的样例代码。

from transformers import Trainer, TrainingArguments

# 定义训练参数
training_args = TrainingArguments(
    output_dir='./results',         # 模型保存和日志输出的目录路径
    num_train_epochs=3,             # 训练的总轮数(epochs)
    per_device_train_batch_size=16, # 每个设备(如GPU或CPU)上的训练批次大小,16表示每次输入模型的数据数量
    learning_rate=5e-5,             # 学习率
    logging_steps=10,               # 每隔多少步(steps)进行一次日志记录
    save_steps=100,                 # 每隔多少步保存模型
)

# 创建 Trainer
trainer = Trainer(
    model=model,                    # 训练的模型对象,需要事先加载好
    args=training_args,             # 上面定义的训练参数配置
    train_dataset=train_dataset,    # 需要对应替换成已经处理过的dataset
)

# 开始训练
trainer.train()

保存和加载 LoRA 微调的模型

训练完成后,你可以保存或者加载 LoRA 微调的参数,下面是个简单的示例。

# 保存 LoRA 参数
model.save_pretrained('./lora_model')

在推理时,加载原始的预训练模型和 LoRA 参数。

# 加载原始模型
base_model = AutoModelForCausalLM.from_pretrained("gpt2")

# 加载 LoRA 参数
from peft import PeftModel

model = PeftModel.from_pretrained(base_model, './lora_model')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2155686.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NSSCTF刷题篇1

js类型 [SWPUCTF 2022 新生赛]js_sign 这是一道js信息泄露的题目直接查看源码,有一个main.js文件点击之后,有一串数字和一段base64编码,解开base64编码得到这个编码为敲击码 解码在线网站:Tap Code - 许愿星 (wishingstarmoye.…

828华为云征文|华为云Flexus云服务器X实例之openEuler系统下部署k8s管理面板KubePi

828华为云征文|华为云Flexus云服务器X实例之openEuler系统下部署k8s管理面板kubepi 前言一、Flexus云服务器X实例介绍1.1 Flexus云服务器X实例简介1.2 Flexus云服务器X实例特点1.3 Flexus云服务器X实例使用场景 二、 KubePi介绍2.1 KubePi简介2.2 KubePi主要特点&am…

序列化方式二——JSON之Gson

Gson 1、什么是Gson? Gson是Google提供的一个用于Java编程语言的JSON(JavaScript Object Notation)序列化和反序列化库。它允许开发者在Java对象和JSON数据之间进行高效的映射和转换。 官网地址:https://github.com/google/gson 官网文档…

小程序隐私合规自查指南

一 背景:小程序作为一种轻量级应用,广泛应用于各大互联网平台。工信部通报2022年第5批侵害用户权益名单中首次出现8款违规小程序。各监管单位对“小程序”违规收集个人信息监控手段和监控力度不断加强。 工信部APP违法违规通报 上海市委网信办查处违规小…

Python_控制循环语句

if语句单分支结构的语法形式如下&#xff1a; 【操作】输入一个数字&#xff0c;小于10&#xff0c;则打印这个数字(if_test01.py)&#xff1a; num input("输入一个数字&#xff1a;") if int(num)<10: print("小于10的数&#xff1a;"num)条件表达式…

BOE(京东方)携多场景物联网创新应用亮相2024服贸会 “屏之物联”赋能数字经济

9 月 12 日&#xff0c;以“全球服务 互惠共享”为主题的2024中国国际服务贸易交易会&#xff08;以下简称“服贸会”&#xff09;在北京拉开帷幕。作为领先的物联网创新企业&#xff0c;BOE&#xff08;京东方&#xff09;携智慧办公、智慧商显、智能车载、智慧教育、智能工厂…

设计模式例题

答案&#xff1a;D C A D 知识点&#xff1a; 观察者模式的意图&#xff1a;定义对象间的一种一对多的依赖关系&#xff0c;当一个对象的状态发生改变时&#xff0c;所有依赖它的对象都得到通知并被自动更新&#xff0c;和自媒体很相似&#xff0c;自媒体更新内容&#xff0c…

C++--C++11(下)

目录 7.5 完美转发 8 新的类功能 9 可变参数模板 10 lambda表达式 11 包装器 7.5 完美转发 模板中的 && 万能引用 void Fun(int &x){ cout << "左值引用" << endl; } void Fun(const int &x){ cout << "const 左值引用…

秒变 Vim 高手:必学的编辑技巧与隐藏功能大揭秘

文章目录 前言一、vi与vim二、Vim的三种模式1. 普通模式2. 插入模式3. 命令模式 三、Vim中的查找与替换1. 查找2. 替换 四、给Vim设置行号1. 临时显示行号2. 永久显示行号 总结 前言 在Linux系统中&#xff0c;文本编辑器是开发者和系统管理员日常工作中的重要工具之一。其中&…

DeepSeek 2.5本地部署的实战教程

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于大模型算法的研究与应用。曾担任百度千帆大模型比赛、BPAA算法大赛评委,编写微软OpenAI考试认证指导手册。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。授权多项发明专利。对机器学…

想学习下Python和深度学习,Python需要学习到什么程度呢?

想要学习Python和深度学习&#xff0c;Python的学习程度需要达到能够熟练运用这门语言进行编程&#xff0c;并能够理解和实现深度学习模型的基本构建和训练过程。以下是一些推荐的书籍&#xff0c;可以帮助你系统地学习Python和深度学习&#xff1a; Python学习推荐书籍 《Py…

kubectl 执行一条命令之后发生了什么?

kubectl 是与 Kubernetes 集群交互的命令行工具&#xff0c;用户通过它可以对集群资源进行操作和管理。你有没有想过&#xff0c;当我们执行一条 kubectl 命令之后&#xff0c;背后都发生了什么&#xff1f; 详细过程 kubectl -> kube-api-server 根据通信类型&#xff0…

【大模型】AutoDL部署AI绘图大模型Stable Diffusion使用详解

目录 一、前言 二、AI绘图大模型概述 2.1 AI绘图大模型介绍 2.2 AI绘图大模型特点 2.3 AI绘图大模型优势 三、主流的AI绘图大模型介绍 3.1 Midjourney 3.1.1 Midjourney介绍 3.1.2 Midjourney功能特点 3.1.3 Midjourney使用场景 3.2 Stable Diffusion 3.2.1 Stable …

zynq的PS端mac与RTL8211F的连接要点

目录 1 VCCO_MIO12 PS_MIO_VREF3 PS的引脚4 RXDLY TXDLY5 ZYNQ的MAC可以调整延时吗 1 VCCO_MIO1 接1.8V 2 PS_MIO_VREF 接0.9V&#xff0c;可通过电阻分压 可通过电阻分压 3 PS的引脚 4 RXDLY TXDLY RXDLY RXD[0] TXDLY RXD[1] 与XC7Z020的PS端MAC连接&#xff0c;必须…

python画正方形、平行四边形、六边形、五角星、风车(四个半圆)

画正方形、平行四边形、六边形、五角星、风车&#xff08;四个半圆&#xff09; import turtle def square(side_length):"""正方形"""for _ in range(4):turtle.forward(side_length)turtle.right(90)def parallelogram(base, height):"&q…

C++——模拟实现string

1.再谈string string为什么要被设计成模板&#xff1f;日常使用string好像都是char*&#xff0c;char*不够使用吗&#xff0c;为什么要设计成模板呢&#xff1f; 1.1 关于编码 //计算机的存储如何区分呢&#xff1f;int main() {//比如在C语言中&#xff0c;有整型//如果是有…

Linux网络——HTTP协议详解(2)

文章目录 HTTP方法GET方法POST方法 状态码与报头状态码报头 会话 HTTP方法 HTTP方法有这些 但是怎么说呢&#xff0c;这些方法只有GET和POST方法是99%情况用到的 剩下的几乎不太用&#xff0c;如果有兴趣可以找《图解HTTP》&#xff0c;是处于了解的范畴 大家肯定一看就明白…

Qt Creator项目模板介绍

在Qt Creator中创建项目时&#xff0c;用户可以从多个模板类别中进行选择&#xff0c;以满足不同的开发需求。 Application(Qt) 在Application(Qt)类别下&#xff0c;Qt Creator提供了多种用于创建不同类型Qt应用程序的模板。这些模板主要包括&#xff1a; Qt Widgets Applic…

专业解析:移动硬盘“要求格式化”背后的真相与数据救援策略

引言&#xff1a;格式化预警下的数据危机 在日常的数字生活中&#xff0c;移动硬盘作为数据存储与传输的重要工具&#xff0c;其稳定性与安全性直接关系到用户资料的安全。然而&#xff0c;不少用户遭遇过这样一个令人头疼的问题——移动硬盘在接入电脑后&#xff0c;系统突然…

Tcping:一款实用的端口存活检测工具

简介 tcping 是一个基于TCP协议的网络诊断工具,通过发送 TCP SYN/ACK包来检测目标主机的端口状态。 官网:tcping.exe - ping over a tcp connection 优点: (1)监听服务器端口状态:tcping 可以检测指定端口的状态,默认是80端口,也可以指定其他端口。 (2)显示ping返…