用通俗的方法讲解:大模型微调训练详细说明(附理论+实践代码)

news2025/1/11 12:37:12

本文内容如下

  • 介绍了大模型训练的微调方法,包括prompt tuning、prefix tuning、LoRA、p-tuning和AdaLoRA等。

  • 介绍了使用deepspeed和LoRA进行大模型训练的相关代码。

  • 给出了petals的介绍,它可以将模型划分为多个块,每个用户的机器负责其中一块,分摊了计算压力。

理解篇

prompt tuning

图片

固定预训练参数,为每一个任务额外添加一个或多个embedding,之后拼接query正常输入LLM,并只训练这些embedding。左图为单任务全参数微调,右图为prompt tuning。

图片

  • 标准的T5模型(橙色线)多任务微调实现了强大的性能,但需要为每个任务存储单独的模型副本。

  • prompt tuning也会随着参数量增大而效果变好,同时使得单个冻结模型可重复使用于所有任务。

  • 显著优于使用GPT-3进行fewshot prompt设计。

  • 当参数达到100亿规模与全参数微调方式效果无异。

代码样例:

from peft import PromptTuningConfig, get_peft_model
peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)

prefix tuning

图片

prefix tuning依然是固定预训练参数,但除为每一个任务额外添加一个或多个embedding之外,利用多层感知编码prefix,注意多层感知机就是prefix的编码器,不再像prompt tuning继续输入LLM。

embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
transform = torch.nn.Sequential(
    torch.nn.Linear(token_dim, encoder_hidden_size),
    torch.nn.Tanh(),
    torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim),
)

在三个数据集中prefix和全参数微调的表现对比:

图片

代码样例:

peft_config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=20)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)

LoRA

图片

LoRA冻结了预训练模型的参数,并在每一层decoder中加入dropout+Linear+Conv1d额外的参数

那么,LoRA是否能达到全参数微调的性能呢?

根据实验可知,全参数微调要比LoRA方式好的多,但在低资源的情况下也不失为一种选择

图片

细致到每个任务中的差距如下图:

图片

代码样例:

peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)

p-tuning

图片

手动尝试最优的提示无异于大海捞针,于是便有了自动离散提示搜索的方法(作图),但提示是离散的,神经网络是连续的,所以寻找的最优提示可能是次优的。p-tuning依然是固定LLM参数,利用多层感知机和LSTM对prompt进行编码,编码之后与其他向量进行拼接之后正常输入LLM。注意,训练之后只保留prompt编码之后的向量即可,无需保留编码器。

self.lstm_head = torch.nn.LSTM(
                    input_size=self.input_size,
                    hidden_size=self.hidden_size,
                    num_layers=num_layers,
                    dropout=lstm_dropout,
                    bidirectional=True,
                    batch_first=True,
  )

self.mlp_head = torch.nn.Sequential(
    torch.nn.Linear(self.hidden_size * 2, self.hidden_size * 2),
    torch.nn.ReLU(),
    torch.nn.Linear(self.hidden_size * 2, self.output_size),
)
self.mlp_head(self.lstm_head(input_embeds)[0])

以上代码可清晰展示出prompt编码器的结构。

图片

如上图所示,GPT在P-tuning的加持下可达到甚至超过BERT在NLU领域的性能。下图是细致的对比:图片

MP: Manual prompt

FT: Fine-tuning

MP+FT: Manual prompt augmented fine-tuning

PT: P-tuning

代码样例:

peft_config = PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)

p-tuning v2

图片

p-tuning的问题是在小参数量模型上表现差(如上图所示),于是有了V2版本,类似于LoRA每层都嵌入了新的参数(称之为Deep FT),下图中开源看到p-tuning v2 集合了多种微调方法。p-tuning v2 在多种任务上下进行微调,之后对于不同的任务如token classification与sentence classification添加了随机初始化的任务头(AutoModelForTokenClassification、AutoModelForSequenceClassification),而非使用自然语言的方式,可以说V2是集大成者。

图片

KP: Knowledge Probe,知识探针,用于检测LLM的世界知识掌握能力:https://github.com/facebookresearch/LAMA

SeqTag: Sequence Tagging,如抽取式问答、命名实体识别

Re-param.:Reparameterization,对提示词做单独的编码器

No verb.: No verbalizer,不直接使用LLM head而接一个随机初始化的linear head

以下表格对比了[CLS] label linear head 和 verbalizer with LM head,[CLS] label linear head的方式药略好。

图片

v1到v2的可视化:蓝色部分为参数冻结,橙色部分为可训练部分

图片

下图中对比了FT、PT、PT-2三种方法,粗体为性能最好的,下划线为性能次好的。

图片

代码样例:

peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=20)
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)

AdaLoRA

预训练语言模型中的不同权重参数对下游任务的贡献是不同的。因此需要更加智能地分配参数预算,以便在微调过程中更加高效地更新那些对模型性能贡献较大的参数。

具体来说,通过奇异值分解将权重矩阵分解为增量矩阵,并根据新的重要性度量动态地调整每个增量矩阵中奇异值的大小。这样可以使得在微调过程中只更新那些对模型性能贡献较大或必要的参数,从而提高了模型性能和参数效率。

详细的算法如下:

图片

对比不同方法的性能:

图片

代码样例:

peft_config = AdaLoraConfig(peft_type="ADALORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"],lora_dropout=0.01)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)

代码篇

注:以下代码在pytorch 1.12.1版本下运行,其他包都是最新版本

deepspeed

官方的demo所需要的配置如下:
在这里插入图片描述

注意到官方给的样例单卡V100只能训练13亿规模的模型,如果换成67亿是否能跑起来呢?

按照官方文档搭建环境:

pip install deepspeed>=0.9.0

git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/applications/DeepSpeed-Chat/
pip install -r requirements.txt

请注意如果你之前装了 deepspeed,请更新至0.9.0

试试全参数微调,这毫无疑问OOM

deepspeed --num_gpus 1 main.py \
      --data_path Dahoas/rm-static \
      --data_split 2,4,4 \
      --model_name_or_path facebook/opt-6.5b \
       --gradient_accumulation_steps 2 \
     --lora_dim 128 \
     --zero_stage 0 \
       --deepspeed \
      --output_dir $OUTPUT \
      &> $OUTPUT/training.log

答案是:我们需要卸载,这次便能愉快的run起来了

deepspeed main.py \
   --data_path Dahoas/rm-static \
   --data_split 2,4,4 \
   --model_name_or_path facebook/opt-6.7b \
   --per_device_train_batch_size 4 \
   --per_device_eval_batch_size 4 \
   --max_seq_len 512 \
   --learning_rate 9.65e-6 \
   --weight_decay 0.1 \
   --num_train_epochs 2  \
   --gradient_accumulation_steps 1 \
   --lr_scheduler_type cosine \
   --num_warmup_steps 0 \
   --seed 1234 \
   --lora_dim 128 \
   --gradient_checkpointing \
   --zero_stage 3 \
   --deepspeed \
   --output_dir $OUTPUT_PATH \
   &> $OUTPUT_PATH/training.log

可以加上LoRA

deepspeed --num_gpus 1 main.py \
   --data_path Dahoas/rm-static \
   --data_split 2,4,4 \
   --model_name_or_path facebook/opt-6.7b \
   --per_device_train_batch_size 8 \
   --per_device_eval_batch_size 8 \
   --max_seq_len 512 \
   --learning_rate 1e-3 \
   --weight_decay 0.1 \
   --num_train_epochs 2 \
   --gradient_accumulation_steps 16 \
   --lr_scheduler_type cosine \
   --num_warmup_steps 0 \
   --seed 1234 \
   --gradient_checkpointing \
   --zero_stage 0 \
   --lora_dim 128 \
   --lora_module_name decoder.layers. \
   --deepspeed \
   --output_dir $OUTPUT_PATH \
   &> $OUTPUT_PATH/training.log

peft

以下代码省略了数据处理

初始化

from datasets import load_dataset,load_from_disk
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer,default_data_collator
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model

MICRO_BATCH_SIZE = 1  
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = 3  
LEARNING_RATE = 3e-6  
CUTOFF_LEN = 256  
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

模型加载,并使用int8进行训练

model_path = "facebook/opt-6.7b"
output_dir = "model"
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, add_eos_token=True)
model = prepare_model_for_int8_training(model)  
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=None,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
tokenizer.pad_token_id = 0  
data = load_from_disk("data")

训练与保存

trainer = transformers.Trainer(
    model=model,
    train_dataset=data["train"],
    eval_dataset=data["validation"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=MICRO_BATCH_SIZE,
        per_device_eval_batch_size=MICRO_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=1000,
        num_train_epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        # bf16=True,  
        fp16=True,  
        logging_steps=1,
        output_dir=output_dir,
        save_total_limit=4,
    ),
    data_collator=default_data_collator,
)
model.config.use_cache = False
trainer.train(resume_from_checkpoint=False)
model.save_pretrained(output_dir)

直接这么启动当然会OOM,依然需要卸载

编写accelerate配置文件accelerate.yaml

compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'yes'
dynamo_backend: 'yes'
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
use_cpu: true

deepspeed配置文件:ds.json

{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 500,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": 1e-8,
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 2e-05,
            "warmup_num_steps": 0
        }
    },

    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": false
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },

    "gradient_accumulation_steps":2,
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": 4,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": false
}

启动

accelerate launch --dynamo_backend=nvfuser  --config_file accelearte.yaml finetune.py

注:其他方法与Lora使用方法差距不大,不再赘述,在peft项目中均有代码样例。

顺便提一嘴:petals

图片

petals将模型划分为多个块,每个用户的机器负责其中一块,分摊了计算压力,类似于某磁力链接下载工具,利用hivemind库进行去中心化的训练与推理。当然你也可以创建自己局域网的群组,对自己独有的模型进行分块等自定义操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1276449.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于SpringBoot校园周边美食探索及分享平台的设计与实现

摘要: 美食一直是与人们日常生活息息相关的产业。传统的电话订餐或者到店消费已经不能适应市场发展的需求。随着网络的迅速崛起,互联网日益成为提供信息的最佳俱渠道和逐步走向传统的流通领域,传统的美食业进而也面临着巨大的挑战&#xff0c…

Todesk 无法登录,无法联网

前言 我习惯用todesk远程ubuntu,但是突然发现掉线了,但是ssh还能连接 问题查找 1.ping 一下主机ip 2.ssh连接后,ping 一下百度,查看是否外网正常 3.输入一下命令 ps -ef | grep todesk #查看todesk 进程 sudo kill -9 ....…

红队攻防实战之某商城Getshell

此后如竟没有炬火,我便是唯一的光 信息收集 端口扫描 nmap -T4 -A -p 1-65535 可以看到目标系统开放22、80、888、3306、8800端口 敏感文件扫描 http:///admin/login.html 后台登陆地址泄露 漏洞挖掘 phpinfo信息泄露 phpinfo信息泄露,此站为Linu…

Linux环境下ARM开发

目录 前言ARM启动及开发基础1.Cortex-A架构2.启动方式3.汇编基础4.Makefile语法基础5.Makefile补充6.编译下载 结语 前言 主要介绍基于linux开发环境下,如何开发ARM A7 ARM启动及开发基础 1.Cortex-A架构 1)Cortex-A7运行模式 模式说明User(USR)用户模…

element中el-input限制只输入正整数或保留两位小数

文章目录 一、前言二、实现2.1、HTML2.2、只输入正整数2.3、只能输入数字或小数点 三、最后 一、前言 常见的el-input可能会限制用户只能输入正整数或保留两位小数,达到输入金额的需求,点击【跳转】访问el-input的官方文档 element-ui是有el-input-numb…

Git的介绍和下载安装

Git的介绍和下载安装 概述 Git是一个分布式版本控制工具, 通常用来管理项目中的源代码文件(Java类、xml文件、html页面等)进行管理,在软件开发过程中被广泛使用 Git可以记录文件修改的历史记录并形成备份从而实现代码回溯, 版本切换, 多人协作, 远程备份的功能Git具有廉价的…

VSC++=》 拆解整数对号入座重组

void 拆解整数对号入座重组(int& 数, bool 选 true) {int 对号[10]{}, j 选 ? 9 : 0, 反 0, 基 1;while (数)对号[数 % 10], 数 / 10;if (选)while (j > 0)if (对号[j])数 * 10, 数 j, (反 ? 基 * 10 : 0), 反 基 * j, --对号[j]; else --j;else while (j < …

UI自动化Selenium find_elements和find_element的区别

# 如果获取的element是list&#xff0c;那么需要用find_elements方法&#xff1b;此方法会返回list&#xff0c;然后使用len() 方法&#xff0c;计算对象的个数&#xff1b; # find_element方法返回的不是list对象&#xff0c;所以导致没办法计算对象个数 # 1.返回值类型不同…

go开发之个人微信号机器人开发

简要描述&#xff1a; 下载消息中的文件 请求URL&#xff1a; http://域名地址/getMsgFile 请求方式&#xff1a; POST 请求头Headers&#xff1a; Content-Type&#xff1a;application/jsonAuthorization&#xff1a;login接口返回 参数&#xff1a; 参数名必选类型…

okhttp导致的内存溢出(OOM)sun.security.ssl.SSLSocketImpl

使用分析工具&#xff1a;MAT(Memory Analyzer Tool)、JvisualVM占用内存&#xff1a;sun.security.ssl.SSLSocketImpl 一、 项目场景&#xff1a; 功能&#xff1a;一个定时任务(xxl-job)采用线程池的方式多线程请求第三方拉取数据&#xff0c;网络框架使用okhttp3。 问题&am…

改进YOLO5:结合CVPR2023最新 PConv |包含 YOLOv5 / YOLOv8 模型 YAML 文件

改进YOLO5:结合CVPR2023最新 PConv |包含 YOLOv5 / YOLOv8 模型 YAML 文件 一、论文总结PConv模块优势二、YOLOv51. yaml文件2. common代码文件三、YOLOv81. yaml2. modules文件添加3. Task文件4. 测试

MSSQL注入

目录 基本的UNION注入&#xff1a; 错误基于的注入&#xff1a; 时间基于的盲注入&#xff1a; 堆叠查询&#xff1a; 理解MSSQL注入是学习网络安全的一部分&#xff0c;前提是您在合法、授权的环境中进行&#xff0c;用于了解如何保护您的应用程序免受此类攻击。以下是有关…

蓝桥杯day03——二进制间距

1.题目 给定一个正整数 n&#xff0c;找到并返回 n 的二进制表示中两个 相邻 1 之间的 最长距离 。如果不存在两个相邻的 1&#xff0c;返回 0 。 如果只有 0 将两个 1 分隔开&#xff08;可能不存在 0 &#xff09;&#xff0c;则认为这两个 1 彼此 相邻 。两个 1 之间的距离…

安装获取mongodb

本文适合初学者或mongodb感兴趣的同学来准备学习测试环境&#xff0c;或本地临时开发环境。mongodb是一个对用户非常友好的数据库。这种友好&#xff0c;不仅仅体现在灵活的数据结构和操作工具&#xff0c;还有一把大大的羊毛可以薅&#xff0c; mongodb提供了一个云上3节点的复…

服务器和Linux ,安装R rstudio ,常用软件

服务器的基本概念&#xff1a; 如服务器的基本结构&#xff0c;节点&#xff0c;端口的概念等。 服务器的基本设置和管理&#xff1a; 如何配置新服务器&#xff0c; 如何管理服务器&#xff0c;如何分配账户并确保他们互不访问&#xff0c; 如何全局安装软件让所有人都可以…

基于python的FMCW雷达工作原理仿真

这篇文章将介绍如何使用python来实现FMCW工作原理的仿真&#xff0c;第1章内容将介绍距离检测原理&#xff0c;第2章内容会介绍速度检测原理。 第1章 第1部分: 距离检测原理 调制的连续波雷达通常也被叫做调频连续波&#xff08;FMCW&#xff09;雷达是一个使用频率调制来测量…

C#,数值计算——插值和外推,谢别德(Shep)插值方法的计算方法与源程序

1 文本格式 using System; namespace Legalsoft.Truffer { /// <summary> /// 谢别德插值方法 /// Object for Shepard interpolation using n points in dim dimensions. Call /// constructor once, then interp as many times as desired. /// &…

nodejs微信小程序+python+PHP金融产品销售系统的设计与实现-计算机毕业设计推荐

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性&#xff1a;…

NSDT场景编辑器实现真数字孪生

在线工具推荐&#xff1a; 三维数字孪生场景工具 - GLTF/GLB在线编辑器 - Three.js AI自动纹理化开发 - YOLO 虚幻合成数据生成器 - 3D模型在线转换 - 3D模型预览图生成服务 1、什么是数字孪生&#xff1f; 数字孪生是资产或系统的实时虚拟模型&#xff0c;它使用来自连…

15、 深度学习之正向传播和反向传播

上一节介绍了训练和推理的概念,这一节接着训练和推理的概念讲一下,神经网络的正向传播和反向传播。 其实单看正向传播和反向传播这两个概念,很好理解。 正向传播(Forward Propagation)是指从输入层到输出层的数据流动过程,而反向传播(Backpropagation)是指数据从输出…