AIGC:【LLM(一)】——LoRA微调加速技术

news2024/11/22 23:50:40

文章目录

    • 一.微调方法
      • 1.1 Instruct微调
      • 1.2 LoRA微调
    • 二.LoRA原理
    • 三.LoRA使用

一.微调方法

Instruct微调和LoRA微调是两种不同的技术。

1.1 Instruct微调

Instruct微调是指在深度神经网络训练过程中调整模型参数的过程,以优化模型的性能。在微调过程中,使用一个预先训练好的模型作为基础模型,然后在新的数据集上对该模型进行微调。Instruct微调是一种通过更新预训练模型的所有参数来完成的微调方法,通过微调使其适用于多个下游应用。

1.2 LoRA微调

LoRA(Low-Rank Adaptation)微调冻结了预训练的模型权重,并将可训练的秩分解矩阵注入到 Transformer 架构的每一层,极大地减少了下游任务的可训练参数的数量。与Instruct微调相比,LoRA在每个Transformer块中注入可训练层,因为不需要为大多数模型权重计算梯度,大大减少了需要训练参数的数量并且降低了GPU内存的要求。研究发现,使用LoRA进行的微调质量与全模型微调相当,速度更快并且需要更少的计算。
基于LoRA的微调产生保存了新的权重,可以将生成的LoRA权重认为是一个原来预训练模型的补丁权重 。所以LoRA模型无法单独使用,需要搭配原模型,两者进行合并即可获得完整版权重。
在这里插入图片描述

二.LoRA原理

神经网络包含许多密集的层,这些层执行矩阵乘法。这些层中的权重矩阵通常具有全秩。当适应特定的任务时,预训练的语言模型往往具有较低的“instrisic dimension”,尽管随机投影到较小的子空间,但仍然可以有效地学习。
LoRA的实现原理在于,冻结预训练模型权重,并将可训练的秩分解矩阵注入到Transformer层的每个权重中,大大减少了下游任务的可训练参数数量。直白的来说,实际上是增加了右侧的“旁支”,也就是先用一个Linear层A,将数据从 d维降到r,再用第二个Linear层B,将数据从r变回d维。最后再将左右两部分的结果相加融合,得到输出的hidden_state。
如上图所示,左边是预训练模型的权重,输入输出维度都是d,在训练期间被冻结,不接受梯度更新。右边部分对A使用随机的高斯初始化,B在训练开始时为零,r是秩,会对△Wx做缩放 α/r。
在这里插入图片描述

Lora 性质 1:全面微调的推广
通过将 LoRA 秩 r 设置为预训练的权重矩阵的秩,大致恢复了完整微调的表达性。换句话说,当增加可训练参数的数量时,训练LoRA大致收敛于训练原始模型。

Lora 性质 2:没有额外的推断延迟
在生产中部署时,可以显式地计算和存储 W = W 0 + B A W = W_0 + BA W=W0+BA,并像往常一样执行推理,也就是将 LoRA 权重和原始模型权重合并,不增加任何的推断耗时 W 0 W_0 W0 B A BA BA 都是 R d × k R^{d×k} Rd×k。当需要切换到另一个下游任务时,可以通过减去 BA 然后添加不同的 B ′ A ′ B'A' BA 来恢复 W 0 W_0 W0,这是一个内存开销很小的快速操作。

LoRA的优势
1)一个预先训练好的模型可以被共享,并用于为不同的任务建立许多小的LoRA模块。可以冻结共享模型,并通过替换图中的矩阵A和B来有效地切换任务,大大降低了存储需求和任务切换的难度。
2)在使用自适应优化器时,LoRA使训练更加有效,并将硬件进入门槛降低了3倍,因为我们不需要计算梯度或维护大多数参数的优化器状态。相反,我们只优化注入的、小得多的低秩矩阵。
3)简单的线性设计允许在部署时将可训练矩阵与冻结权重合并,与完全微调的模型相比,在结构上没有引入推理延迟。
4)LoRA与许多先前的方法是正交的,可以与许多方法结合,如前缀调整。

三.LoRA使用

HuggingFace的PEFT(Parameter-Efficient Fine-Tuning)中提供了模型微调加速的方法,参数高效微调(PEFT)方法能够使预先训练好的语言模型(PLMs)有效地适应各种下游应用,而不需要对模型的所有参数进行微调。

对大规模的PLM进行微调往往成本过高,在这方面,PEFT方法只对少数(额外的)模型参数进行微调,基本思想在于仅微调少量 (额外) 模型参数,同时冻结预训练 LLM 的大部分参数,从而大大降低了计算和存储成本,这也克服了灾难性遗忘的问题,这是在 LLM 的全参数微调期间观察到的一种现象PEFT 方法也显示出在低数据状态下比微调更好,可以更好地泛化到域外场景。

例如,使用PEFT-lora进行加速微调的效果如下,从中我们可以看到该方案的优势:
在这里插入图片描述

## 1、引入组件并设置参数
from transformers import AutoModelForSeq2SeqLM
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType
import torch
from datasets import load_dataset
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import default_data_collator, get_linear_schedule_with_warmup
from tqdm import tqdm

## 2、搭建模型

peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

## 3、加载数据
dataset = load_dataset("financial_phrasebank", "sentences_allagree")
dataset = dataset["train"].train_test_split(test_size=0.1)
dataset["validation"] = dataset["test"]
del dataset["test"]

classes = dataset["train"].features["label"].names
dataset = dataset.map(
    lambda x: {"text_label": [classes[label] for label in x["label"]]},
    batched=True,
    num_proc=1,
)

## 4、训练数据预处理
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

def preprocess_function(examples):
    inputs = examples[text_column]
    targets = examples[label_column]
    model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    labels = tokenizer(targets, max_length=3, padding="max_length", truncation=True, return_tensors="pt")
    labels = labels["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100
    model_inputs["labels"] = labels
    return model_inputs


processed_datasets = dataset.map(
 
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation"]

train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
)
eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)

## 5、设定优化器和正则项
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

## 6、训练与评估

model = model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
 
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
 
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )
	    eval_epoch_loss = eval_loss / len(eval_dataloader)
	    eval_ppl = torch.exp(eval_epoch_loss)
	    train_epoch_loss = total_loss / len(train_dataloader)
	    train_ppl = torch.exp(train_epoch_loss)
	    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

## 7、模型保存
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
model.save_pretrained(peft_model_id)

## 8、模型推理预测

from peft import PeftModel, PeftConfig
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
 
model = PeftModel.from_pretrained(model, peft_model_id)
model.eval()

inputs = tokenizer(dataset["validation"][text_column][i], return_tensors="pt")
print(dataset["validation"][text_column][i])
print(inputs)
with torch.no_grad():
    outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=10)
    print(outputs)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/489283.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flutter——最详细(TextField)使用教程

TextField简介 文本输入框,拥有复杂的属性。可指定控制器、文字样式、装饰线、行数限制、游标样式等。监听输入框变动事件。 使用场景: 搜索框,输入账号密码等 属性作用controller输入框监听器decoration输入框装饰属性textAlign内容对齐方式…

UE5.1.1 C++ 从0开始 (1.人物移动)

开个天坑,UE5.1.1的移动代码做了一个大更新,对于我这种万年蓝图然后正在转C的人来说可以说是个挑战也可以说是个更方便我去工作的一个点。同时斯坦福大学的那个教程的开头几个章节就不适用了,对于学习UE5.1.1的同学来说。所以我这里会尽量每天…

[230506] 2021年托福阅读真题第6篇|Water and Life on Mars|15:30~16:30|16:30~19:19

正确率:6/10 ​​​​​​​ Water and Life on Mars Paragraph 1: The question of life on Mars depends heavily on the characteristics of its air and water. Mars has a relatively thin and dry atmosphere, with a high percentage of carbon dioxide com…

想转行大数据,需要学习什么?

Python近段时间一直涨势迅猛,在各大编程排行榜中崭露头角,得益于它多功能性和简单易上手的特性,让它可以在很多不同的工作中发挥重大作用。 正因如此,目前几乎所有大中型互联网企业都在使用 Python 完成各种各样的工作&#xff0…

Spark大数据处理讲课笔记3.7 Spark任务调度

文章目录 零、本节学习目标一、有向无环图(一)DAG概念(二)实例讲解 二、Stage划分依据(一)两阶段案例(二)三阶段案例 三、RDD在Spark中的运行流程(一)RDD Obj…

buuctf7

目录 Crypto MD5 Url编码 看我回旋踢 web [极客大挑战 2019]BuyFlag​ [BJDCTF2020]Easy MD5 Crypto MD5 1.下载文件 2.md5在线解密 3.外包flag Url编码 使用url在线解码 看我回旋踢 下载,得到这串字符,搜一下synt编码 看到使用凯撒密码&#x…

2023-05-04:用go语言重写ffmpeg的scaling_video.c示例,用于实现视频缩放(Scaling)功能。

2023-05-04:用go语言重写ffmpeg的scaling_video.c示例,用于实现视频缩放(Scaling)功能。 答案2023-05-04: 这段代码实现了使用 libswscale 库进行视频缩放的功能。下面是程序的主要流程: 1.获取命令行参…

唐书计组第三章总线部分课后习题和解答

我自己的一些总结 总线周期分为哪四个阶段 申请分配阶段寻址阶段存数阶段结束阶段 总线分为哪四种通信方式 同步通信异步通信半同步通信分离式通信 总线有哪几种判优方式 链式查询 计数器定时查询 独立请求方式 计算数据传输率 3.14设总线的时钟频率为8MHz,一个总线周期…

(3)信号槽

目录 1.信号槽的概念 2.信号槽的连接 2.1自带信号 → 自带槽 2.2 自带信号 → 自定义槽 2.3 自定义信号 1.信号槽的概念 信号槽指的是信号函数与槽函数的连接,可以使用不同的对象通过信号槽连接在一起,从而实现对象之间的通信。 可以把信号槽的连接…

数字化经营3.0阶段,云徙科技如何定义“为增长而生”?

作者:Lucky 新时代风云变幻中,通过数字化转型,驱动业务增长、提升运营效率是企业升级的必由之路。如今,数字化经营也已经进入3.0阶段,企业对“人、货、场”三位一体的前端数字化的要求更高,行业也需要更有效…

Java设计模式-建造者模式

简介 建造者模式是一种创建型设计模式,用于将复杂对象的构建过程与其表示分离,使得同样的构建过程可以创建不同的表示。建造者模式通过将复杂对象的构建过程分解为多个简单的步骤来实现。 与其他创建型模式不同,建造者模式强调的是将构建过…

QML路径视图(The PathView)

路径视图(PathView)非常强大,但也非常复杂,这个视图由QtQuick提供。它创建了一个可以让子项沿着任意路径移动的视图。沿着相同的路径,使用缩放(scale),透明(opacity&…

nssctf (1)

[NISACTF 2022]popchains Happy New Year~ MAKE A WISH <?phpecho Happy New Year~ MAKE A WISH<br>;if(isset($_GET[wish])){ #通过get获取wish的值 并判断是不是空@unserialize($_GET[wish]); #反序列化wish } else{$a=new Road_is_Long; #实例化Road_is…

YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制

YOLOv5&#xff1a;添加SE、CBAM、CoordAtt、ECA注意力机制 前言前提条件相关介绍注意力机制SE添加SE注意力机制到YOLOv5 CBAM添加CBAM注意力机制到YOLOv5 CoordAtt添加CoordAtt注意力机制到YOLOv5 ECA添加ECA注意力机制到YOLOv5 参考 前言 记录在YOLOv5添加注意力机制&#xf…

原神3.2真端完整版架设教程

想必在座的各位都玩过这款游戏吧、开放世界的玩法、折磨人的剧情、做不完的任务、话多且烦人的派蒙、没眼看的伤害、贵到爆的抽卡、打不动的深渊、树脂刷空也刷不到想要的圣遗物、打不动的BOSS、这怎么受得了呀!反正我是受不了。废话不多说、教程开始。 准备工具: 一台16H 3…

【经典面试题】请使用C语言编程实现对IPV4地址的合法性判断

C语言编程实现对IPV4地址的合法性判断 有了解过我的朋友&#xff0c;可能有点印象&#xff0c;我在N年前的博客中&#xff0c;就写了这个主题&#xff0c;当时确实是工作中遇到了这个问题。本想着等工作搞完之后&#xff0c;就把这个问题的解决代码补上&#xff0c;结果一鸽&am…

MATLAB实现工业PCB电路板缺陷识别和检测

PCB&#xff08;PrintedCircuitBoard印刷电路板&#xff09;是电子产品中众多电子元器件的承载体&#xff0c;它为各电子元器件的秩序连接提供了可能&#xff0c;PCB已成为现代电子产品的核心部分。随着现代电子工业迅猛发展&#xff0c;电子技术不断革新&#xff0c;PCB密集度…

K8S常见异常事件与解决方案

集群相关 Coredns容器或local-dns容器重启 集群中的coredns组件发生重启(重新创建)&#xff0c;一般是由于coredns组件压力较大导致oom&#xff0c;请检查业务是否异常&#xff0c;是否存在应用容器无法解析域名的异常。 如果是local-dns重启&#xff0c;说明local-dns的性能…

fastai2 实现SSD

https://github.com/search?qfastaissd 有几个值得参考的代码&#xff0c;好好学习。 GitHub - Samjoel3101/SSD-Object-Detection: I am working on a SSD Object Detector using fastai and pytorch fastai2实现的SSD&#xff0c;终于找到了code。https://github.com/sidrav…

等保定级怎么做

Q25:现在还没做等保还来得及吗?有什么影响? 答:来得及。种一棵树,最好的时间是十年前,其次是现在。可先根据定级备案要求和流程,先向公安递交定级备案文件,测评与整改预算提上日程,在经费未落实前,可以先进行系统定级、差距分析、整改计划制订等工作。 根据《等保工…