几种预训练模型微调方法和peft包的使用介绍

news2024/11/18 20:47:09

文章目录

  • 微调方法
    • Lora(在旁边添加训练参数)
    • Adapter(在前面添加训练参数)
    • Prefix-tuning(在中间添加训练参数)
    • Prompt tuning
  • PEFT
    • PEFT 使用
      • PeftConfig
      • PeftModel
      • 保存和加载模型

微调方法

现流行的微调方法有:Lora,prompt,p-tunning v1,p-tunning v2,prefix,adapter等等,下面抱着学习的心态进行宏观层面的介绍
如有错误,欢迎指出

Lora(在旁边添加训练参数)

LoRA(Low-Rank Adaptation)是一种技术,通过低秩分解将权重更新表示为两个较小的矩阵(称为更新矩阵),从而加速大型模型的微调,并减少内存消耗。

为了使微调更加高效,LoRA的方法是通过低秩分解,使用两个较小的矩阵(称为更新矩阵)来表示权重更新。这些新矩阵可以通过训练适应新数据,同时保持整体变化的数量较少。原始的权重矩阵保持冻结,不再接收任何进一步的调整。为了产生最终结果,同时使用原始和适应后的权重进行合并。
在这里插入图片描述
假设权重的更新在适应过程中也具有较低的“内在秩”。对于一个预训练的权重矩阵W0 ∈ Rd×k,我们通过低秩分解W0 + ∆W = W0 + BA来表示其更新,其中B ∈ Rd×r,A ∈ Rr×k,且秩r ≤ min(d,k)。在训练过程中,W0被冻结,不接收梯度更新,而A和B包含可训练参数。需要注意的是,W0和∆W = BA都与相同的输入进行乘法运算,它们各自的输出向量在坐标上求和。前向传播公式如下:h = W0x + ∆Wx = W0x + BAx

在上图中我们对A使用随机高斯随机分布初始化,对B使用零初始化,因此在训练开始时∆W = BA为零。然后,通过αr对∆Wx进行缩放,其中α是r中的一个常数。在使用Adam优化时,适当地缩放初始化,调整α的过程与调整学习率大致相同。因此,只需将α设置为我们尝试的第一个r,并且不对其进行调整。这种缩放有助于在改变r时减少重新调整超参数的需求。

Adapter(在前面添加训练参数)

2019 年,Houlsby N 等人将 Adapter 引入 NLP 领域,作为全模型微调的一种替代方案。Adapter 主体架构下图所示。
在这里插入图片描述
在预训练模型每一层(或某些层)中添加 Adapter 模块(如上图左侧结构所示),微调时冻结预训练模型主体,由 Adapter 模块学习特定下游任务的知识。每个 Adapter 模块由两个前馈子层组成,第一个前馈子层将 Transformer 块的输出作为输入,将原始输入维度 d 投影到 m,通过控制 m 的大小来限制 Adapter 模块的参数量,通常情况下 m<<d。在输出阶段,通过第二个前馈子层还原输入维度,将 m 重新投影到 d,作为 Adapter 模块的输出(如上图右侧结构)。通过添加 Adapter 模块来产生一个易于扩展的下游模型,每当出现新的下游任务,通过添加 Adapter 模块来避免全模型微调与灾难性遗忘的问题。Adapter 方法不需要微调预训练模型的全部参数,通过引入少量针对特定任务的参数,来存储有关该任务的知识,降低对模型微调的算力要求。

Prefix-tuning(在中间添加训练参数)

前缀微调(prefix-tunning),用于生成任务的轻量微调。前缀微调将一个连续的特定于任务的向量序列添加到输入,称之为前缀,如下图中的红色块所示。与提示(prompt)不同的是,前缀完全由自由参数组成,与真正的 token 不对应。相比于传统的微调,前缀微调只优化了前缀。因此,我们只需要存储一个大型 Transformer 和已知任务特定前缀的副本,对每个额外任务产生非常小的开销。

Prompt tuning

提示通过包括描述任务的文本提示或甚至演示任务示例的文本提示来为特定的下游任务准备一个冻结的预训练模型。具体的,给每个任务定义 Prompt,拼接到数据上作为输入,同时 freeze 预训练模型进行训练,在没有加额外层的情况下,可以看到随着模型体积增大效果越来越好,最终追上了精调的效果。

提示方法可以分为两类:

  • 硬提示(Hard Prompts):手工制作的具有离散输入标记的文本提示;缺点是需要花费很多精力来创建一个好的提示。
  • 软提示(Soft Prompts):可与输入嵌入连接并进行优化以适应数据集的可学习张量;缺点是它们不太易读,因为您不是将这些“虚拟标记”与实际单词的嵌入进行匹配。

PEFT

PEFT(Parameter-Efficient Fine-Tuning,参数高效微调),是一个用于在不微调所有模型参数的情况下,高效地将预训练语言模型(PLM)适应到各种下游应用的库。

PEFT方法仅微调少量(额外的)模型参数,显著降低了计算和存储成本,因为对大规模PLM进行完整微调的代价过高。最近的最先进的PEFT技术实现了与完整微调相当的性能。

代码:
https://github.com/huggingface/peft
文档:
https://huggingface.co/docs/peft/index

PEFT 使用

接下来将展示 PEFT 的主要特点,并帮助在消费设备上通常无法访问的情况下训练大型预训练模型。您将了解如何使用LoRA来训练1.2B参数的bigscience/mt0-large模型,以生成分类标签并进行推理。

PeftConfig

每个 PEFT 方法由一个PeftConfig类来定义,该类存储了用于构建PeftModel的所有重要参数。

由于您将使用LoRA,您需要加载并创建一个LoraConfig类。在LoraConfig中,指定以下参数:

task_type,在本例中为序列到序列语言建模
inference_mode,是否将模型用于推理
r,低秩矩阵的维度
lora_alpha,低秩矩阵的缩放因子
lora_dropout,LoRA层的dropout概率
from peft import LoraConfig, TaskType
peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

有关您可以调整的其他参数的更多详细信息,请参阅LoraConfig参考。

PeftModel

使用 get_peft_model() 函数可以创建PeftModel。它需要一个基础模型 - 您可以从 Transformers 库加载 - 以及包含配置特定 PEFT 方法的PeftConfig。

首先加载您要微调的基础模型。

from transformers import AutoModelForSeq2SeqLM

model_name_or_path = "bigscience/mt0-large"
tokenizer_name_or_path = "bigscience/mt0-large"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

使用get_peft_model函数将基础模型和peft_config包装起来,以创建PeftModel。要了解您模型中可训练参数的数量,可以使用print_trainable_parameters方法。在这种情况下,您只训练了模型参数的0.19%!

from peft import get_peft_model

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
输出示例: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282

至此,我们已经完成了!现在您可以使用Transformers的Trainer、 Accelerate,或任何自定义的PyTorch训练循环来训练模型。

保存和加载模型

在模型训练完成后,您可以使用save_pretrained函数将模型保存到目录中。您还可以使用push_to_hub函数将模型保存到Hub(请确保首先登录您的Hugging Face帐户)。

model.save_pretrained("output_dir")

如果要推送到Hub

from huggingface_hub import notebook_login

notebook_login()
model.push_to_hub("my_awesome_peft_model")

这只保存了已经训练的增量PEFT权重,这意味着存储、传输和加载都非常高效。例如,这个在RAFT数据集的twitter_complaints子集上使用LoRA训练的bigscience/T0_3B模型只包含两个文件:adapter_config.json和adapter_model.bin,后者仅有19MB!

使用from_pretrained函数轻松加载模型进行推理:

from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig

peft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)

参考文章:
http://lihuaxi.xjx100.cn/news/1428773.html?action=onClick

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1082008.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MySql】5- 实践篇(三)

文章目录 1. 日志和索引问题1. 日志相关问题1.1 两阶段提交 2. 业务设计相关问题 2. order by工作原理2.1 全字段排序2.2 rowid 排序2.3 全字段排序 VS rowid 排序 3. 正确显示随机消息3.1 内存临时表3.2 磁盘临时表3.3 随机排序方法 1. 日志和索引问题 1. 日志相关问题 1.1 …

NodeJs内置模块child_process

内置模块child_process子进程 写在前面 子进程是Nodejs的核心Api&#xff0c;如果你会shell命令&#xff0c;它会有非常大的帮助&#xff0c;或者你喜欢编写前端工程化工具之类&#xff0c;它也有很大的用处&#xff0c;以及处理CPU密集型应用。 创建子进程 Nodejs创建子进…

如何处理用户输入验证和表单提交?

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 欢迎来到前端入门之旅&#xff01;感兴趣的可以订阅本专栏哦&#xff01;这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

PCL点云处理之点云重建为Mesh模型并保存到PLY文件 ---方法二 (二百一十一)

PCL点云处理之点云重建为Mesh模型并保存到PLY文件 ---方法二 (二百一十一) 一、算法介绍二、算法实现1.代码2.效果一、算法介绍 离散点云重建为mesh网格模型,并保存到PLY文件中,用于其他软件打开查看,代码非常简短,复制粘贴即可迅速上手使用,具体参数根据自己的点云数据…

Unity可视化Shader工具ASE介绍——5、ASE快捷键和常用节点介绍

大家好&#xff0c;我是阿赵。   继续介绍Unity可视化Shader插件ASE。这次来说一些常用节点的快捷键&#xff0c;顺便介绍一些常用的节点。   用过UE引擎的朋友可能会发现&#xff0c;ASE的整体用法和UE的材质节点编辑器非常的像&#xff0c;甚至连很多节点的快捷键都和UE的…

【Vue面试题十六】、Vue.observable你有了解过吗?说说看

文章底部有个人公众号&#xff1a;热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享&#xff1f; 踩过的坑没必要让别人在再踩&#xff0c;自己复盘也能加深记忆。利己利人、所谓双赢。 面试官&#xff1a;Vue.observable你有了解…

Python笔记;库,包,模块

在Python中库没有官方说法。 是其他地方沿用过来的。 姑且认为他是一个包或多个包的集合。 包里有子包和模块。 模块以.py格式存储。 下图是一个例子&#xff0c;对于Robot包&#xff1a; import math a math.sqrt(9) 等价于 from math import * a sqrt(9) from math im…

【数据库——MySQL(实战项目1)】(2)图书借阅系统——数据库测试、视图以及存储过程

目录 1. 简述2. 数据表 增、删、改 测试2.1 借阅人表2.2 图书表2.3 借阅信息表 3. 功能代码3.1 创建视图显示所有逾期未归还的借阅信息&#xff08;包括借阅人姓名&#xff0c;借阅人类别&#xff0c;书名&#xff0c;借出日期&#xff0c;应归还日期&#xff0c;逾期时长&…

二叉搜索树--新增节点-力扣 701 题

例题细节二叉搜索树的基础操作-CSDN博客也讲过了&#xff08;put&#xff09;&#xff0c;下面给出递归实现 public TreeNode insertIntoBST(TreeNode node, int val) {//找到空位了if(node null) {return new TreeNode(val);}if(val < node.val) {//一直找到有null的位置…

草柴返利APP如何查询领取天猫内部隐藏优惠券购物拿天猫返利?

草柴返利APP是一种简单、快捷的购物省钱工具&#xff0c;可以帮助你在天猫上查询并领取内部隐藏优惠券&#xff0c;确认收货后拿购物返利。草柴返利APP可以轻松查询到天猫优惠券&#xff0c;让你购物更加方便&#xff0c;享受更多的折扣优惠。 草柴返利APP如何查询领取天猫优惠…

2、使用阿里云镜像加速器提升Docker的资源下载速度

1、注册阿里云账号并登录 https://www.aliyun.com/ 2、进入个人控制台&#xff0c;找到“容器镜像服务” 3、在“容器镜像服务”中找到“镜像加速器” 4、在右侧列表中会显示你的加速器地址&#xff0c;复制地址 5、进入/etc/docker目录&#xff0c;编辑daemon.json&#xff0…

jumpserver如何录入web资产

需要部署远程应用发布机&#xff0c;此机器需新建一台Windows机器&#xff0c;不要加域 本次环境&#xff1a;Windows 2019 server标准版&#xff0c;8U16G 系统设置-远程应用 设置完成后提交。 此发布机上需预先安装openssh&#xff0c;否则jumpserver无法部署应用发布机 …

第二章 进程与线程 二十、死锁的处理策略(预防死锁、避免死锁、死锁的检测和解除)

目录 一、分类 二、预防死锁 1、破坏互斥条件 2、破坏不剥夺条件 3、破坏请求和保持条件 4、破坏循环等待条件 5、总结 三、避免死锁 1、什么是安全序列 2、安全状态和不安全状态 3、银行家算法 &#xff08;1&#xff09;核心思想 &#xff08;2&#xff09;例子 …

【ElasticSearch】使用 Java 客户端 RestClient 实现对文档的查询操作,以及对搜索结果的排序、分页、高亮处理

文章目录 前言&#xff1a;RestClient 查询文档的 RestAPI一、全文检索查询1.1 match_all 查询1.2 match 查询1.3 multi_match 查询 二、精确查询2.1 term 查询2.2 range 查询 三、复合查询&#xff1a;Boolean 查询与 function score 查询的综合案例四、对查询结果的处理4.1 将…

关于Qualifier你要知道的二三事

&#x1f35e; Qualifier注解的作用-定义Bean-指定Bean的名称 Qualifier注解可以区分具有相同类型的多个Bean&#xff0c;用于明确指定要注入的Bean的名称或限定符。通过为要注入的Bean添加 Qualifier注解&#xff0c;你可以告诉Spring应该使用哪个Bean&#xff0c;以解决Spri…

黑马JVM总结(三十一)

&#xff08;1&#xff09;类加载器-概述 启动类加载器-扩展类类加载器-应用程序类加载器 双亲委派模式&#xff1a; 类加载器&#xff0c;加载类的顺序是先依次请问父级有没有加载&#xff0c;没有加载自己才加载&#xff0c;扩展类加载器在getParent的时候为null 以为Boots…

zabbix监控实战1

1、zabbix监控平台部署 重新克隆纯净虚拟机 数据库初始化 修改密码为WHqwerty123 初始化完成 创建zabbix数据库 基础配置和服务启动 访问 2、zabbix添加监控节点 修改字体文件 在客户端 手动添加监控节点 自动添加监控节点 3、zabbix api 自动注册 停掉自动发现 删掉serve…

VMware 下模拟软 RAID 的创建及故障恢复

Author&#xff1a;rab 目录 前言一、创建 RAID1.1 环境1.2 什么是 RAID&#xff1f;1.3 软 RAID 和硬 RAID1.4 如何创建软 RAID&#xff1f; 二、故障模拟与数据恢复2.1 故障模拟2.2 故障恢复 思考&#xff1f; 前言 一块物理硬盘要投入生产使用&#xff0c;一般会经历一下这…

面试经典 150 题 2 —(滑动窗口)— 3. 无重复字符的最长子串

3. 无重复字符的最长子串 方法 class Solution { public:int lengthOfLongestSubstring(string s) {int result 0, length s.length();int start 0, end 0;while(end < length){// 发现有重复字符时&#xff0c;可以直接把左指针移动到第一个重复字符的下一个位置for(i…

Web前端-Vue2+Vue3基础入门到实战项目-Day3(生命周期, 案例-小黑记账清单, 工程化开发入门)

Web前端-Vue2Vue3基础入门到实战项目-Day3 生命周期生命周期 & 生命周期四个阶段生命周期钩子生命周期案例created应用mounted应用 案例 - 小黑记账清单工程化开发入门工程化开发和脚手架项目运行流程index.htmlmain.js 组件化组件注册局部注册全局注册 来源 生命周期 生命…