【大模型】基于LoRA微调Gemma大模型(1)

news2024/9/21 18:40:09

文章目录

  • 一、LoRA工作原理
    • 1.1 基本原理
    • 1.2 实现步骤
  • 二、LoRA 实现
    • 2.1 PEFT库:高效参数微调
      • LoraConfig类:配置参数
    • 2.2 TRL库
      • SFTTrainer 类
  • 三、代码实现
    • 3.1 核心代码
    • 3.2 完整代码
  • 参考资料

大模型微调技术有很多,如P-TuningLoRA 等,我们在之前的博客中也介绍过,可以参考:大模型高效参数微调技术(Prompt-Tuning、Prefix Tuning、P-Tuning、LoRA…)

在本篇文章中,我们就 LoRA (Low-Rank Adaptation) 即低秩适应的微调方法工作原理及代码实践进行介绍。

完整的微调步骤可以参考我们的博客:【大模型】基于LoRA微调Gemma大模型(2)

一、LoRA工作原理

1.1 基本原理

LoRA 是 Low-Rank Adaptation 或 Low-Rank Adaptors的首字母缩写词,它提供了一种高效且轻量级的方法,用于微调预先训练好的的大语言模型。

LoRA的核心思想是用一种低秩的方式来调整这些参数矩阵。LoRA通过保持预训练矩阵(即原始模型的参数)冻结(即处于固定状态),并且只在原始矩阵中添加一个小的增量,其参数量比原始矩阵少很多。

例如,考虑矩阵 W,它可以是全连接层的参数,也可以是来Transformer中计算自注意力机制的矩阵之一:

显然,如果 W o r i g W_{orig} Worig 的维数为 n×m,而假如我们只是初始化一个具有相同维数的新的增量矩阵进行微调,虽然我们也实现类似的功能,但是我们的参数量将会加倍。 LoRA使用的Trick就是通过训练低维矩阵 B 和 A ,通过矩阵乘法来构造 ΔW ,来使 ΔW 的参数量低于原始矩阵。

这里我们不妨定义秩 r,它明显小于基本矩阵维度 r≪n 和 r≪m。则矩阵 B 为 n×r,矩阵 A 为 r×m。将它们相乘会得到一个维度为 nxm的W 矩阵,但构建的参数量减小了很多。

LoRA原理见下图:具体来说就是固定原始模型权重,然后定义两个低秩矩阵作为新增weight参与运算,并将两条链路的结果求和后作为本层的输出,而在微调时,只梯度下降新增的两个低秩矩阵。

此外,我们希望我们的增量ΔW在训练开始时为零,这样微调就会从原始模型一样开始。因此,B通常初始化为全零,而 A初始化为随机值(通常呈正态分布)。

1.2 实现步骤

(1)选择目标层

首先,在预训练神经网络模型中选择要应用LoRA的目标层。这些层通常是与特定任务相关的,如自注意力机制中的查询Q和键K矩阵。

值得注意的是,原则上,我们可以将LoRA应用于神经网络中权矩阵的任何子集,以减少可训练参数的数量。在Transformer体系结构中,自关注模块(Wq、Wk、Wv、Wo)中有四个权重矩阵,MLP模块中有两个权重矩阵。我们将Wq(或Wk,Wv)作为维度的单个矩阵,尽管输出维度通常被切分为注意力头。

(2)初始化映射矩阵和逆映射矩阵

为目标层创建两个较小的矩阵A和B,然后进行变换。

参数变换过程:将目标层的原始参数矩阵W通过映射矩阵A和逆映射矩阵B进行变换,计算公式为: W ′ = W + A ∗ B W' = W + A * B W=W+AB,这里W’是变换后的参数矩阵。

其中,矩阵的大小由LoRA的秩(rank)和alpha值确定。
在这里插入图片描述

(3)微调模型
使用新的参数矩阵替换目标层的原始参数矩阵,然后在特定任务的训练数据上对模型进行微调。

(4)梯度更新
在微调过程中,计算损失函数关于映射矩阵A和逆映射矩阵B的梯度,并使用优化算法(如Adam、SGD等)对A和B进行更新。

注意:在更新过程中,原始参数矩阵W保持不变。其实也就是训练的时候固定原始PLM的参数,只训练降维矩阵A与升维矩阵B (W is frozen and does not receive gradient updates, while A and B contain trainableparameters )

(5)重复更新
在训练的每个批次中,重复步骤3-5,直到达到预定的训练轮次(epoch)或满足收敛条件。

且当需要切换到另一个下游任务时,可以通过减去B A然后添加不同的B’ A’来恢复W,这是一个内存开销很小的快速操作。

When we need to switch to another downstream task, we can recover W0 by subtracting BA andthen adding a different B0A0, a quick operation with very little memory overhead.

总之,LoRA的详细步骤包括:选择目标层、初始化映射矩阵和逆映射矩阵、进行参数变换和模型微调。在微调过程中,模型会通过更新映射矩阵U和逆映射矩阵V来学习特定任务的知识,从而提高模型在该任务上的性能。

二、LoRA 实现

这里主要介绍几个与 LoRA 实现相关的类库。

2.1 PEFT库:高效参数微调

Huggingface公司推出的 PEFT (Parameter-Efficient Fine-Tuning,即高效参数微调之意) 库封装了LoRA这个方法,PEFT库可以使预训练语言模型高效适应各种下游任务,而无需微调模型的所有参数,即仅微调少量(额外)模型参数,从而大大降低了计算和存储成本。

peft:全称为Parameter-Efficient Fine-Tuning,PEFT。peft是一种专门为高效调参而设计的深度学习库,其使用了类似于只是蒸馏的技术,通过在预训练模型上添加少量数据来进行微调,从而实现将预训练模型的知识迁移到新的微调模型中。
Github地址:https://github.com/huggingface/peft

LoraConfig类:配置参数

from peft import LoraConfig

LoraConfig是Hugging Face transformers库中用于配置LoRA(Low-Rank Adaptation)的类。LoraConfig允许用户设置以下关键参数来定制LoRA训练:

  • r: 低秩矩阵的秩,即添加的矩阵的第二维度,控制了LoRA的参数量。
  • alpha: 权重因子,用于在训练后将LoRA适应的权重与原始权重相结合时的缩放。
  • lora_dropout: LoRA层中的dropout率,用于正则化。
  • target_modules: 指定模型中的哪些模块(层)将应用LoRA适应。这允许用户集中资源在对任务最相关的部分进行微调。
  • bias: 是否在偏置项上应用LoRA,通常设置为’none’或’all’。
  • task_type: 指定任务类型,如’CAUSAL_LM’,以确保LoRA适应正确应用到模型的相应部分。

2.2 TRL库

trl 库:全称为Transformer Reinforcement Learning,TRL是使用强化学习的全栈Transformer语言模型。trl 是一个全栈库,其中我们提供一组工具,用于通过强化学习训练Transformer语言模型和稳定扩散模型,从监督微调步骤(SFT)到奖励建模步骤(RM)再到近端策略优化(PPO)步骤。该库建立在Hugging Face 的 transformers 库之上。
Github地址:TRL - Transformer Reinforcement Learning

SFTTrainer 类

from trl import SFTTrainer

SFTTrainertransformers.Trainer的子类,增加了处理PeftConfig的逻辑,可轻松在自定义数据集上微调语言模型或适配器。

三、代码实现

3.1 核心代码

(1)训练阶段

  • LoraConfig:定义LoRA微调参数
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    # lora_alpha=16,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    # lora_dropout=0.05,
    task_type="CAUSAL_LM",  # 因果语言模型
)
  • SFTTrainer:基于Lora进行微调
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=10,  # 最大迭代次数
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="./outputs/gemma-new",  # 微调后模型的输出路径
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)

# 开始训练
trainer.train()

(2)推理阶段

训练完成后,我们需要将 LoRA 模型基础模型 进行合并,来进行推理。核心代码如下:

base_model_path = "./model/gemma-2b"   
peft_model_path = "./outputs/gemma-new/checkpoint-500"

base_model = AutoModelForCausalLM.from_pretrained(base_model_path, return_dict=True,  device_map=device, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# print(model)

# 加载LoRA模型(基础模型+微调模型)
merged_model = PeftModel.from_pretrained(base_model, peft_model_path)
# print(model)

3.2 完整代码

这里,我们以微调gemma-2b 模型为例,完整的微调步骤可以参考博客:【大模型】基于LoRA微调Gemma大模型(2)

主要包含 train.pyinfer.py 两个文件,具体代码如下:

  • train.py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer

device = "cuda:0"

# 定义量化参数
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # 启用4位加载
    bnb_4bit_quant_type="nf4",  # 指定用于量化的数据类型。支持两种量化数据类型: fp4 (四位浮点)和 nf4 (常规四位浮点)
    bnb_4bit_compute_dtype=torch.bfloat16  # 用于线性层计算的数据类型
)

model_path = "./model/gemma-2b"   # chatglm2-6b, gemma-2b
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map=device)  # quantization_config=bnb_config


# 测试原始模型的输出
text = "Quote: Imagination is more"
inputs = tokenizer(text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))


# 加载微调数据集
# data = load_dataset(data_path)   # 加载远程数据集
data_path = "./data/english_quotes/quotes.jsonl"  # 本地数据文件路径
data = load_dataset('json', data_files=data_path)   # 加载本地数据文件
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
print(data)


# 定义格式化函数
def formatting_func(example):
    raise RuntimeError("if you can read this, formatting_func was called")
    text = f"Quote: {example['quote'][0]}\nAuthor: {example['author'][0]}<eos>"
    return [text]

print(formatting_func(data["train"]))


# 定义LoRA微调参数
lora_config = LoraConfig(
    r=8,
    # lora_alpha=16,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    # lora_dropout=0.05,
    task_type="CAUSAL_LM",  # 因果语言模型
)


# 基于Lora进行微调
trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=10,  # 最大迭代次数
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="./outputs/gemma-new",  # 微调后模型的输出路径
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)

trainer.train()
# trainer.save_model(trainer.args.output_dir)
  • infer.py
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda:1"

base_model_path = "./model/gemma-2b"   # chatglm2-6b, gemma-2b
peft_model_path = "./outputs/gemma-new/checkpoint-500"


base_model = AutoModelForCausalLM.from_pretrained(base_model_path, return_dict=True,  device_map=device, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# print(model)

# 加载LoRA模型(基础模型+微调模型)
merged_model = PeftModel.from_pretrained(base_model, peft_model_path)
# print(model)

# 测试1
text = "Quote: Imagination is more"
inputs = tokenizer(text, return_tensors="pt").to(device)

参考资料

  • google/gemma-7b官方示例:https://huggingface.co/google/gemma-7b/blob/main/examples/notebook_sft_peft.ipynb

  • 使用 Hugging Face 微调 Gemma 模型

  • 【AI大模型】Transformers大模型库(八):大模型微调之LoraConfig

  • 【机器学习】QLoRA:基于PEFT亲手量化微调Qwen2大模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1951119.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue3计算属性终极实战:可媲美Element Plus Tree组件研发之节点勾选

前面完成了JuanTree组件的节点编辑和保存功能后&#xff0c;我们把精力放到节点勾选功能实现上来。**注意&#xff0c;对于组件的开发者来说&#xff0c;要充分考虑用户的使用场景&#xff0c;组件提供的多个特性同时启用时必须要工作良好。**就拿Tree组件来说&#xff0c;用户…

数据库(MySQL)-视图、存储过程、触发器

一、视图 视图的定义、作用 视图是从一个或者几个基本表&#xff08;或视图&#xff09;导出的表。它与基本表不同&#xff0c;是一个虚表。但是视图只能用来查看表&#xff0c;不能做增删改查。 视图的作用&#xff1a;①简化查询 ②重写格式化数据 ③频繁访问数据库 ④过…

如何学习Doris:糙快猛的大数据之路(从入门到专家)

引言:大数据世界的新玩家 还记得我第一次听说"Doris"这个名字时的情景吗?那是在一个炎热的夏日午后,我正在办公室里为接下来的大数据项目发愁。作为一个刚刚跨行到大数据领域的新手,我感觉自己就像是被丢进了深海的小鱼—周围全是陌生的概念和技术。 就在这时,我的…

江苏科技大学24计算机考研数据速览,有专硕复试线大幅下降67分!

江苏科技大学&#xff08;Jiangsu University of Science and Technology&#xff09;&#xff0c;坐落在江苏省镇江市&#xff0c;是江苏省重点建设高校&#xff0c;江苏省人民政府与中国船舶集团有限公司共建高校&#xff0c;国家国防科技工业局与江苏省人民政府共建高校 &am…

pyqt designer使用spliter

1、在designer界面需要使用spliter需要父界面不使用布局&#xff0c;减需要分割两个模块选中&#xff0c;再点击spliter分割 2、在分割后&#xff0c;再对父界面进行布局设置 3、对于两边需要不等比列放置的&#xff0c;需要套一层 group box在最外层进行分割

Linux系统:date命令

1、命令详解&#xff1a; date 命令可以用来显示或设定系统的日期与时间。 2、官方参数&#xff1a; -d, --dateSTRING 通过字符串显示时间格式&#xff0c;字符串不能是now。-f, --fileDATEFILE 类似 --date 在 DATEFILE 的每一行生效-I[FMT], --iso-8601[FMT…

Redis的使用场景、持久化方式和集群模式

1. Redis的使用场景 热点数据的缓存 热点数据&#xff1a;频繁读取的数据 限时任务的操作。比如短信验证码 完成session共享的问题。因为前后端分离 完成分布式锁 商品的销售量 2. Redis的持久化方式 2.1 什么是持久化 把内存中的数据存储到磁盘的过程。同时也可以把磁盘中…

Python中的Numpy库使用方法

numpy Ndarry和创建数组的方式 NumPy数组&#xff08;ndarray&#xff09;是NumPy库的核心数据结构&#xff0c;它是一系列同类型数据的集合&#xff0c;以 0 下标为开始进行集合中元素的索引。 ndarray本质上是一个存放同类型元素的多维数组&#xff0c;其中的每个元素在内存…

TransformerEngine

文章目录 一、关于 TransformerEngine &#xff1f;亮点 二、使用示例PyTorchJAXFlax 三、安装先决条件Dockerpip从源码使用 FlashAttention-2 编译 四、突破性的变化v1.7: Padding mask definition for PyTorch 五、FP8 收敛六、集成七、其它贡献论文视频最新消息 一、关于 Tr…

美团大众点评字符验证码

声明(lianxi a15018601872) 本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的&#xff0c;抓包内容、敏感网址、数据接口等均已做脱敏处理&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff01; 前言(…

为什么优秀员工往往最先离职?

在企业管理中有很多误区&#xff0c;令企业流失优秀员工和人才&#xff0c;根据优思学院过往的经验&#xff0c;大致可以分为以下几个情况。 1. 忽视帕累托法则&#xff08;80/20法则&#xff09; 帕累托法则&#xff08;80/20法则&#xff09;是六西格玛管理的基本原则&…

好的STEM编程语言有哪些?

STEM是科学&#xff08;Science&#xff09;&#xff0c;技术&#xff08;Technology&#xff09;&#xff0c;工程&#xff08;Engineering&#xff09;&#xff0c;数学&#xff08;Mathematics&#xff09;四门学科英文首字母的缩写&#xff0c;STEM教育简单来说就是在通过在…

django_创建菜单(实现整个项目的框架,调包)

文章目录 前言代码仓库地址在线演示网址启动网站的时候出现错误渲染路径的一些说明文件结构网页显示一条错误路由顺序js打包出现问题的代码函数没有起作用关于进度开发细节显示不了图片梳理一下函数调用的流程修改一些宽度参数classjs 里面的一些细节让三个按钮可以点击设置按钮…

前端JS特效第56集:基于canvas的粒子文字动画特效

基于canvas的粒子文字动画特效&#xff0c;先来看看效果&#xff1a; 部分核心的代码如下(全部代码在文章末尾)&#xff1a; <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compat…

GPT-4O 的实时语音对话功能在处理多语言客户时有哪些优势?

最强AI视频生成&#xff1a;小说文案智能分镜智能识别角色和场景批量Ai绘图自动配音添加音乐一键合成视频百万播放量 我瞄了一眼OpenAI春季发布会&#xff0c;这个发布会只有26分钟&#xff0c;你可以说它是一部科幻短片&#xff0c;也可以说它过于“夸夸其谈”&#xff01;关于…

5个工具帮助你轻松将PDF转换成WORD

有时候编辑PDF文件确实不如编辑word文档方便&#xff0c;很多人便会选择先转换再编辑。但是如果还有人不知道要怎么将PDF文件转换成word文档的话&#xff0c;可以看一下这5款工具&#xff0c;各种类型的都有&#xff0c;总有一款可以帮助到你。 &#xff11;、福昕PDF转换软件 …

socket实现全双工通信,多个客户端接入服务器端

socket实现全双工通信 客户端&#xff1a; #define IP "192.168.127.80" //服务器IP地址 #define PORT 7266 // 服务器端口号int main(int argc, const char *argv[]) {//1.创建套接字&#xff1a;用于接收客户端链接请求int sockf…

MSQP Mysql数据库权限提升工具,UDF自动检测+快速反向SHELL

项目地址:https://github.com/MartinxMax/MSQP MSQP 这是一个关于Mysql的权限提升工具 安装依赖 $ python3 -m pip install mysql-connector-python 使用方法 $ python3 msqp.py -h 权限提升:建立反向Shell 在建立反向连接前,该工具会自动检测是否具有提权条件&#xff0…

4-4 数值稳定性 + 模型初始化和激活函数

数值稳定性 这里的 t t t表示层&#xff0c;假设 h t − 1 h^{t-1} ht−1是第 t − 1 t-1 t−1层隐藏层的输出&#xff0c;经过一个 f t f_{t} ft​得到第 t t t层隐藏层的输出 h t h^{t} ht。 y y y表示 x x x进来&#xff0c;第一层一直到第 d d d层&#xff0c;最后到一个损…

2024最新网络安全自学路线,内容涵盖3-5年技能提升

01 什么是网络安全 网络安全可以基于攻击和防御视角来分类&#xff0c;我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术&#xff0c;而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 无论网络、Web、移动、桌面、云等哪个领域&#xff0c;都有攻与防两面…