【Python】科研代码学习:十二 PEFT(高效参数的训练,Adapter适配器)

news2024/9/28 3:18:49

【Python】科研代码学习:十二 PEFT

  • PEFT
    • 简单训练教程
    • 简单推理教程
    • Adapter 适配器
    • Merge Adapter
  • 架构关系

PEFT

  • 【HF官网-Doc-PEFT:API】
    首先日常问题,是什么,为什么,怎么用
    PEFT (Prameter-Efficient Fine-Tuning):参数高效的微调
    这里特指 HF 提供的 PEFT
    PEFT 让大的预训练模型可以很快适应到各种下游的任务中,并且没有进行全参微调,因为全参微调的时间、算力花费比较大。

简单训练教程

  • 两个很重要的模块:
    PeftConfig :提供 peft 的配置
    PeftModel:提供 peft 的模型
  • 最常见的是使用 LoRA (Low-Rank Adaptation ) 作为 PEFT 技术
    这里,PeftConfig 就使用了 LoraConfig
    然后给了一些必要的参数,比如任务类型,设定模式(训练还是推理),低秩矩阵的秩,和lora的俩参数:
from peft import LoraConfig, TaskType

peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
  • 然后,加载一个预训练模型
    接着,使用 get_peft_model,把模型和 peft_config 传进去,变成 peftmodel
    我们发现,这里只用训练 0.19 % 0.19\% 0.19% 的参数
from transformers import AutoModelForSeq2SeqLM
from peft import get_peft_model

model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/mt0-large")


model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
"output: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282"
  • 然后直接提供 TrainingArgumentsTrainer 训练即可
training_args = TrainingArguments(
    output_dir="your-name/bigscience/mt0-large-lora",
    learning_rate=1e-3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()	
  • 保存部分,跟一般的模型一样。但它只存储那些额外训练的参数,因此保存后的文件很小。
model.save_pretrained("output_dir")

简单推理教程

  • 我们加载 peftmodel 的话,需要使用比如 AutoPeftModel
    同理,使用 .from_pretrained 方法加载
    其他步骤没啥区别
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
import torch

model = AutoPeftModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

model = model.to("cuda")
model.eval()
inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors="pt")

outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=50)
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])

"Preheat the oven to 350 degrees and place the cookie dough in the center of the oven. In a large bowl, combine the flour, baking powder, baking soda, salt, and cinnamon. In a separate bowl, combine the egg yolks, sugar, and vanilla."

Adapter 适配器

  • Adapter-based 方法在冻结的注意力层和全连接层之后添加了额外的可训练参数
    这里简单介绍一下 PEFT 支持的几个 Adapter
  • LoRA (Low-Rank Adaptation):最受欢迎的一个PEFT方法
    主要是高秩到低秩的映射,然后再映射回高秩矩阵。
    一开始在NLP中,后来CV也有用
  • LoHa (Low-Rank Hadamard Product):使用了 Hadamard product 方法
    在CV中用,NLP中的嵌入层代码还没实现
  • LoKr (Low-Rankd Kronecker Product) :使用了 Kronecker Product 方法
    主要给 diffusion model 使用
    在这里插入图片描述
  • OFT (Orthogonal Finetuning):方法如下图
    一开始聚焦在微调阶段,预训练模型的生成能力
    在这里插入图片描述
  • Llama-Adapter:让 Llama 适配成接受指令模型 (instruction-following model)
    在这里插入图片描述
  • PEFT 库中,可以按照对应的模型和任务,选择想用的 Adapter
    不同的 Adapter 都有它自己的 SpecificPeftModelSpecificPeftConfig
    去查阅相关的参数即可。
    比较常用的有:
    IA3
    LoRA
    P-tuning
    Prefix tuning
    Prompt tuning
    在这里插入图片描述

Merge Adapter

  • 在实际过程中,由于基座模型和 adapter 适配器 分开加载,可能会遇到延迟问题
    这个时候,可以选择使用 merge_and_unload() 方法,把 adapter 权重与底座模型权重融合起来。这样的话,使用新的模型就和一开始单独的模型没有区别了。
  • 比如我使用的是 LoraAdapter,查阅该方法
    progressbar :是否显示进度条
    safe_merge:使用安全合并,检查适配器中是否有 Nan 权重
    adapter_names:要合并的适配器名字的列表
    在这里插入图片描述
  • 当然这些参数都可以用默认值。我们只要对 PeftModel 调用该方法即可返回合并后的 model 。
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b")
peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample"
model = PeftModel.from_pretrained(base_model, peft_model_id)
merged_model = model.merge_and_unload()

架构关系

  • 粗看上面关系有点乱,还是得看一下源码
    PeftModel 是从 torch.nn 继承过来的,按照不同的任务,使用不同的子类,比如 PeftModelForCausalLM
    LoRAModel 等,是从 BaseTuner 继承过来的,Tuner 也是继承自 torch.nn,但这个是按照使用不同的适配器分类的,并且它建议是使用 LoRAConfig,这个是 PeftConfig 的子类
  • PeftModel 更靠近 PretrainedModel,有 save_pretrained, from_pretrained 等方法。PeftModelForCausalLM 还有 generate 方法
    LoRAModel 更靠近 Adapter,有 merge_and_unload, delete_adapter 等方法
  • 它里面大部分的基类和使用到的网络几乎都是 torch.nn,因此大部分跟 PretrainedModel 可以接壤
  • 即根据我的查询,LoRAModel 等并不是 PeftModelForCausalLM / PeftModel 的子类(有待存疑)
    LoRAModel 来训练,PeftModel 来推理,是可以的。
    并且 LoRAModel 可以通过 merge_and_unload() 方法转成 torch.nn,也就相当于 PretrainedModel
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1517028.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何在数据库中存储小数:FLOAT、DECIMAL还是BIGINT?

前言 这里还是用前面的例子: 在线机票订票系统的数据表设计。此时已经完成了大部分字段的设计,可能如下: CREATE TABLE flights ( flight_id INT AUTO_INCREMENT PRIMARY KEY, flight_number VARCHAR(10), departure_airport_code VARCHAR(3), arrival_air…

HAProxy——高性能负载均衡器

目录 一.常见的Web集群调度器 二.HAProxy基本介绍 1.HAProxy是什么? 2.HAProxy的特性 3.HAProxy常用的8种负载均衡调度算法 3.1 轮询:RR(Round Robin) 3.2 最小连接数:LC(Least Connections&#xff…

【TB作品】MSP430单片机,音乐播放器,四首音乐,八音盒,Proteus仿真

文章目录 题目要求仿真结果实验报告:基于MSP430单片机的八音盒设计实验目的实验设备实验原理总结 代码和仿真图 题目要求 八音盒 本设计利用MSP430单片机结合内部定时器及LED/LCD,设计一个八音盒,按下单键可以演奏预先设置的歌曲旋律。 基本要求: 使用LED/LCD显示器…

分销商城小程序开发可以为商家带来哪些好处

分销小程序的开发帮助商家更多地维系客户,市场竞争越来越激烈,各大商家争抢流量,拼命获客,小程序分销堪比商家的营销神器。 分销商城小程序是指商家通过小程序分销与分销商建立利润分享合作伙伴关系,允许分销商将参与小…

C语言例3-11:使用算术运算符的例子。

代码如下: int main(void) {int a12, b10;float c2.0, d0.5;double e6.5, f13.0;printf("-a %d\n",-a);printf("ab %d\n",ab);printf("a-b %d\n",a-b);printf("a*b %d\n",a*b);printf("a/b %d\n"…

第 7 场 小白入门赛

第5题 &#xff1a;兽之泪【算法赛】 AC_Code:C #include <iostream> #include <cstring> #include <algorithm> #include <vector> #include <queue> #include<stack> #include<cmath> #include <unordered_set> #include &…

【数据结构高阶】图

目录 一、图的基本概念 二、 图的存储结构 2.1 邻接矩阵 2.2.1 邻接矩阵存储模式的代码实现 2.2.2 邻接矩阵存储的优缺点 2.2 邻接表 2.2.1 无向图的邻接表 2.2.2 有向图的邻接表 2.2.3 邻接表存储模式的代码实现 2.2.4 邻接表存储的优缺点 三、图的遍历 3.1 图的…

[linux]信号处理:信号编码、基本API、自定义函数和集合操作的详解

一、信号的概述 1、定义 信号是 Linux 进程间通信的最古老的方式。信号是软件中断&#xff0c;它是在软件层次 上对中断机制的一种模拟&#xff0c;是一种异步&#xff08;不等待&#xff09;通信的方式 。信号可以导致一个正在运行的进程被 另一个正在运行的异步进程中断&a…

RHEL8部署baichuan2环境

前置 1、安装NVIDIA驱动 https://www.nvidia.cn/Download/index.aspx?langcn 阿里云 Alibaba Cloud Linux 3.2104 LTS 64位&#xff0c;需要选择RHEL8&#xff0c;如果没有RHEL8&#xff0c;则选最下面那个选择所有操作系统 点击搜索&#xff0c;下载这里有安装步骤&#x…

Datawhale【Sora原理与技术实战】| 学习笔记3

目录 一. 训练 Sora 模型二. 数据预处理三. 视频 VQVAE四. Diffusion Transformer 一. 训练 Sora 模型 Open-Sora 在下图中总结了 Sora 可能使用的训练流程&#xff1a; 链路: 二. 数据预处理 目前主流 LLM 框架缺乏针对 video 数据 统一便捷的管理和处理能力&#xff0c;…

天水麻辣烫:麻辣鲜香,地城风情尽在其中

天水麻辣烫&#xff0c;这道源自甘肃天水的地道美食&#xff0c;早已成为当地饮食文化中不可或缺的一部分。追溯其源头&#xff0c;它脱胎于上世纪80、90年代的麻辣粉&#xff0c;那时的麻辣粉&#xff0c;以土豆粉和土豆片为主&#xff0c;辅以香辣的油泼辣子&#xff0c;简单…

【C++ 】stack 和 queue

1. 标准库中的stack stack 的介绍&#xff1a; 1. stack是一种容器适配器&#xff0c;专门用在具有后进先出操作的上下文环境中&#xff0c;其删除只能从容器的一端进行 元素的插入与提取操作 2. stack是作为容器适配器被实现的&#xff0c;容器适配器即是对特定类封装作为其…

月结常见工单异常情况处理

1. 上月已经结算的工单&#xff0c;本月打开投料或者报工&#xff0c;或者增加产出 或者撤销报工修正报工 如果针对结算的订单&#xff0c;打开重新投料。 月末对工单重新结算&#xff0c;转出差异 KKS2单个处理&#xff08;KKS1集中处理&#xff09; 差异计算 KO88单个结算…

ThreadLocal基本原理

ThreadLocal基本原理 一、定义 ThreadLocal是java中所提供的线程本地存储机制&#xff0c;可以利用改机制将数据缓存在线程内部&#xff0c;该线程可以在任意时刻、任意方法中获取数据 二、底层原理 ThreadLocal底层是通过ThreadLocalMap来实现的&#xff0c;每个Thread对象中…

短剧APP系统开发:打造全新的掌上剧场体验

随着移动互联网的普及和人们娱乐方式的多样化&#xff0c;短剧已经成为现代人生活中不可或缺的一部分。为了满足用户对高质量、便捷观看短剧的需求&#xff0c;我们致力于开发一款功能全面、操作简便的短剧APP系统&#xff0c;为用户带来前所未有的掌上剧场体验。 一、系统开发…

AJAX 04 回调函数地狱和 Promise 链式调用、async 和 await、事件循环

AJAX 学习 AJAX 04 进阶01 同步代码和异步代码02 回调函数地狱和 Promise 链式调用(1) 回调函数地狱(2) Promise 链式调用(3) Promise 链式应用 03 async 和 await(1) async 和 await 使用(2) async函数和await捕获错误 04 事件循环-EventLoop(1) 事件循环(2) 事件循环练习(3) …

FREERTOS简介、移植和系统配置(基于STM32F103)

本文基础内容参考的是正点原子的FREERTOS课程。 这是基于HAL库的 正点原子手把手教你学FreeRTOS实时系统 这是基于标准库的 正点原子FreeRTOS手把手教学-基于STM32 基础知识&#xff0c;直接参考正点原子《FreeRTOS开发指南V1.1》基于标准库的&#xff0c;此处不再赘述。 本文…

SwiftUI的context Menu

SwiftUI的 context Menu 现在来演示一下如何使用 SwiftUI 的 Context Menu 。 代码&#xff1a; import SwiftUIstruct ContextMenuBootCamp: View {State var bgColor: Color .purplevar body: some View {VStack(alignment: .leading, spacing: 10.0) {Image(systemName: …

【LeetCode】升级打怪之路 Day 21:二叉树的最近公共祖先(LCA)问题

今日题目&#xff1a; 236. 二叉树的最近公共祖先1644. 二叉树的最近公共祖先 II235. 二叉搜索树的最近公共祖先 目录 LCA 问题LC 236. 二叉树的最近公共祖先 【classic】LC 1644. 二叉树的最近公共祖先 II 【稍有难度】LC 235. 二叉搜索树的最近公共祖先 ⭐⭐⭐ 今天做了几道有…

电源常用电路—驱动电路详解

数字电源控制核心对输入输出参数进行采集后&#xff0c;利用控制算法进行分析从而产生PWM控制信号&#xff0c;PWM信号将经过驱动电路的进行功率放大和隔离&#xff0c;随后接入功率开关器件从而完成电源的输出控制。本篇将主要针对电源的驱动电路进行讲解。 一、驱动电路概述…