预训练大模型LLM的PEFT之—— Prefix Tuning

news2025/2/24 9:04:00

简介

Prefix Tuning是2021.01提出来的,在它之前,我们使用prompt主要是人工设计模板或者自动化搜索模板,也就是prompt范式的第一阶段,就是在输入上加上prompt文本,再对输出进行映射。这种离散模板对模型的鲁棒性很差。所以后续的研究都将离散生成prompt方式转成连续的方式。

Prefix Tuning是在模型输入前添加一个连续的且是任务特定的向量序列,该序列称之为prefix,然后在训练的时候固定PLM的所有参数,只更新优化特定任务的prefix。

实现思路

1、与Full Fine Tuning对比

从下图上看,对于FFT,我们在微调的时候,需要更新所有的PLM的参数,训练时需要大量的数据,而且耗费的资源和时间比较多。

Prefix Tuning针对特定的任务,只更新前缀那部分的参数,训练时只需要很少的数据,训练速度也非常快。

2、自回归模型与Encoder-Decoder模型的实现区别:

对于类似于GPT-2的自回归模型,最终的结果 z = [PREFIX; x; y],参考下图的上半部分

对于类似于BART的Encoder-Decoder模型,最终的结果z = [PREFIX; x; PREFIX0 ; y],参考下图的下半部分

实验效果

从实验的效果看,还是非常不错的

在HF的PEFT中的实现

1、下面是代码实现

2、我们需要关注的仅仅是PrefixTuningConfig类的num_virtual_tokens

from peft import PromptTuningConfig, PromptTuningInit, get_peft_model
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm

peft_config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=20)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
#"trainable params: 983,040 || all params: 560,197,632 || trainable%: 0.1754809274167014"

lr = 3e-2
num_epochs = 50

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

device = "cuda"
model = model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )

    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

结论:

1、prefix部分到底使用多少个虚拟token,直接影响模型微调的参数量级,以及处理长文本的能力。默认的prefix长度为10

2、不同的前缀长度有不一样的性能表现,在一定程度上长度越长,prefix的效果越明显,但也可能出现降低的问题。

3、实验表明,prefix长度对推理速度影响不大,因为prefix上的attention计算是并行的。

4、前缀调整通过向输入序列插入特定于任务的前缀来修改模型的更多层,因此需要对更多参数进行微调

5、与Fine-tuning对比,Fine-tuning需要根据不同的任务训练所有的参数,所以所有的任务都需要copy一份model,但是prefix-tuning只需要优化prefix,所以只需要存储一个大的Transformer和一个学习得来的特定任务的prefix。

6、此外,prefix-tuning是模块化的:训练一个上游LM,引导下游的LM,而下游的LM保持不变。因此一个单一的LM可以一次性支持多个任务。所以基于prefix的体系结构可以使我们能再单个batch中处理多个用户的任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1477931.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【YOLO v5 v7 v8 小目标改进】RFB:组合不同大小的卷积核和扩张卷积来模拟人类视觉感受野的多尺度特性

RFB:组合不同大小的卷积核和扩张卷积来模拟人类视觉感受野的多尺度特性 提出背景RFB 原理空间感受野结构RFB-Net 小目标涨点YOLO v5 魔改YOLO v7 魔改YOLO v8 魔改 提出背景 当前表现最好的目标检测器依赖于深层CNN骨干网络,如ResNet-101和Inception&am…

qt5与qt6的cmake区别

文章目录 使用cmake构建qt项目,坑很多。一是本身就麻烦,二是,确实坑,因为不同的qtcreator版本,选了不同的kits(套件) 生成的CMakeList.txt文件也不一样。 如果可以的话都选择Qt6的相关选项&…

【C++】认识类和对象

🔥博客主页: 小羊失眠啦. 🎥系列专栏:《C语言》 《数据结构》 《C》 《Linux》 《Cpolar》 ❤️感谢大家点赞👍收藏⭐评论✍️ 文章目录 一、什么是面向对象?二、类的引入三、类的定义四、类的访问限定符与…

一文速览深度伪造检测(Detection of Deepfakes):未来技术的守门人

一文速览深度伪造检测(Detection of Deepfakes):未来技术的守门人 前言一、Deepfakes技术原理卷积神经网络(CNN):细致的艺术学徒生成对抗网络(GAN):画家与评审的双重角色…

异常网络下TCP的可靠服务机制(慢启动、拥塞避免、快重传、快恢复)

目录 TCP超时重传拥塞控制概述慢启动和拥塞避免下面讲解发送端如何判断拥塞发生。 快速重传和快速恢复 本文描述TCP在异常网络下的处理方式 以保证其可靠的数据传输的服务 TCP超时重传 tcp服务能够重传其超时时间内没有收到确认的TCP报文段,tcp模块为每一个报文段都…

【机器学习】特征选择之包裹式特征选择法

🎈个人主页:豌豆射手^ 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:机器学习 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进…

matlab实现不同窗滤波器示例

1 汉明窗低通滤波器 : 在Matlab中使用汉明窗设计低通滤波器可以通过fir1函数实现。汉明窗通常用于设计滤波器,可以提供更突出的频率特性。 下面是一个示例代码,演示如何在Matlab中使用汉明窗设计低通滤波器: % 定义滤波器参数 fs …

数据结构——lesson4带头双向循环链表实现

前言✨✨ 💥个人主页:大耳朵土土垚-CSDN博客 💥 所属专栏:数据结构学习笔记​​​​​​ 💥双链表与单链表的区分:单链表介绍与实现 💥对于malloc函数有疑问的:动态内存函数介绍 感谢大家的观看…

J1—Vivado调试技巧VIO IP

1.简介 VIO(Virtual Input/Output)IP核是一种用于FPGA设计的IP核,它可以模拟输入/输出设备的功能,如键盘、鼠标、显示器等。VIO IP核可以在FPGA设计中用于调试和验证,帮助工程师快速定位问题并进行调试。如图所示&…

vue 解决:点击左侧相同菜单,右侧页面不重新加载的问题

1、问题描述: 其一、需求为: 无论是通过路由组件形成的平台管理系统,还是通过文件配置形成的平台管理系统,都存在通过切换左侧的导航栏而使右侧的页面切换的业务需求; 其二、问题描述为: A、步骤一&#…

全国产飞腾E2000Q +复旦微FPGA的轨道交通、电力解决方案

产品概述 ITX-XMF201是一款高性能边缘计算网关主板,采用飞腾E2000Q 4核处理器,国产化率达到95%国产化。 板载2电口,2路CAN,6路RS232接口,1路RS485接口,16路GPIO,可以满足银行、轨道交通、电力等…

springboot2入门到实战 - JWT

JWT是什么? JSON Web Token (JWT) is an open standard (RFC 7519) that defines a compact and self-contained way for securely transmitting information between parties as a JSON object。 This information can be verified and trusted because it is digi…

便签软件哪个好用?好用便签怎么设置提醒?

在当今信息爆炸的时代,便签软件成为了人们生活中不可或缺的工具之一。那么,便签软件哪个好用呢?下面为您推荐几款备受好评的便签软件。首先是知名度极高的好用便签,它拥有强大的笔记功能、提醒功能和多端同步,让您随时…

COMPOSER安装使用WIN下升级PHP-V

想用TP6使用phpspreadsheet但是说我PHP版本低,原来是PHP7.0 composer要求至少7.4 直接修改环境变量,把PHP目录切换到7.4 composer升级比较简单,在PHP目录下CMD然后官网的命令执行下即可 下面就可以在TP根目录下执行命令安装PHPSPREADSHEET…

Java进阶-集合(3)与泛型

这次介绍集合中的Iterator迭代器,以及泛型。简单来说,泛型对集合的元素类型进行了限制,使用泛型可以在编译时检查类型安全,提高代码的重用率。内容如下 一、Iterator迭代器 1、概念 Iterator迭代器是一个接口,作用…

MATLAB环境下脑电信号EEG的谱分析

脑电信号一直伴随着人类的生命,脑电波是脑神经细胞发生新陈代谢、离子交换时细胞群兴奋突触电位总和,脑电信号的节律性则和丘脑相关,含有丰富的大脑活动信息。通常我们所接触的脑电图都是头皮脑电图,在有些特殊场合还需要皮下部位…

TikTok网络相关问题详解来了,附原生住宅代理IP供应商推荐,

想要迈过TikTok新手门槛,首先必须要学习的就是网络问题。很多人开始做TikTok账号或者TikTok小店时,都会遇到一些先前没有遇到的词汇和概念,比如原生IP,独享IP,甚至专线,那么一个IP可以做几个账号呢&#xf…

0粉低成本带货!职人号矩阵正成为商家的香饽饽

近年来,中国消费市场变化不断,线上消费持续上涨,线上线下一体化成为零售行业的发展新趋势。 加上抖音等平台都在大力发展本地生活,众多连锁商家、本地商家、百货商场纷纷加快数字化转型步伐,掘金线上海量流量&#xff…

机器学习:原理、应用与未来展望

第一章 是什么 机器学习(Machine Learning)是一门跨学科的学科,它使用计算机模拟或实现人类学习行为,通过不断地获取新的知识和技能,重新组织已有的知识结构,从而提高自身的性能。机器学习涉及多个学科&am…

HGAME 2024 WEEK4 WP

文章目录 IOTez7621 MISCezKeyboardMaybezip**Mondrians 🔑 REchange webReverse and Escalation. 想念21和22年的平台和week4的 6557225了 IOT ez7621 拿到固件直接binwalk解,之后grep出hgame 在usr/lib/opkg/info/kmod-flag.control找到这个&#x…