HuggingFace学习笔记--Prompt-Tuning高效微调

news2025/2/26 17:47:49

1--Prompt-Tuning介绍

        Prompt-Tuning 高效微调只会训练新增的Prompt的表示层,模型的其余参数全部固定;

        新增的 Prompt 内容可以分为 Hard Prompt 和 Soft Prompt 两类;

        Soft prompt 通常指的是一种较为宽泛或模糊的提示,允许模型在生成结果时有更大的自由度,通常用于启发模型进行创造性的生成;

        Hard prompt 是一种更为具体和明确的提示,要求模型按照给定的信息生成精确的结果,通常用于需要模型提供准确答案的任务;

        Soft Prompt 在 peft 中一般是随机初始化prompt的文本内容,而 Hard prompt 则一般需要设置具体的提示文本内容;

2--实例代码

from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from transformers import pipeline, TrainingArguments, Trainer
from peft import PromptTuningConfig, get_peft_model, TaskType, PromptTuningInit, PeftModel

# 分词器
tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-1b4-zh")

# 函数内将instruction和response拆开分词的原因是:
# 为了便于mask掉不需要计算损失的labels, 即代码labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

if __name__ == "__main__":
    # 加载数据集
    dataset = load_from_disk("./PEFT/data/alpaca_data_zh")
    
    # 处理数据
    tokenized_ds = dataset.map(process_func, remove_columns = dataset.column_names)
    # print(tokenizer.decode(tokenized_ds[1]["input_ids"]))
    # print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"]))))
    
    # 创建模型
    model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)
    
    # 设置 Prompt-Tuning
    # Soft Prompt
    # config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10) # soft_prompt会随机初始化
    # Hard Prompt
    config = PromptTuningConfig(task_type = TaskType.CAUSAL_LM,
                                prompt_tuning_init = PromptTuningInit.TEXT,
                                prompt_tuning_init_text = "下面是一段人与机器人的对话。", # 设置hard_prompt的具体内容
                                num_virtual_tokens = len(tokenizer("下面是一段人与机器人的对话。")["input_ids"]),
                                tokenizer_name_or_path = "Langboat/bloom-1b4-zh")
    model = get_peft_model(model, config) # 生成Prompt-Tuning对应的model
    print(model.print_trainable_parameters())
    
    # 训练参数
    args = TrainingArguments(
        output_dir = "/tmp_1203",
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 8,
        logging_steps = 10,
        num_train_epochs = 1
    )
    
    # trainer
    trainer = Trainer(
        model = model,
        args = args,
        train_dataset = tokenized_ds,
        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer, padding = True)
    )
    
    # 训练模型
    trainer.train()
    
    # 模型推理
    model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)
    peft_model = PeftModel.from_pretrained(model = model, model_id = "/tmp_1203/checkpoint-500/")
    peft_model = peft_model.cuda()
    ipt = tokenizer("Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(peft_model.device)
    print(tokenizer.decode(peft_model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True))

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1278617.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ELK配置记录

1. filebeat.yml配置 启动命令: ./filebeat -e -c filebeat.yml # 输入 filebeat.inputs: - type: logenabled: truepaths:- /soft/log/base.*#跨行日志正则,从有时间的开始,到下一个时间之前结束multiline.pattern: ^\[[0-9]{4}-[0-9]{2}…

Concurrent Security of Anonymous Credentials Light, Revisited

目录 摘要引言 摘要 我们重新审视了著名的匿名证书轻(ACL)方案(Baldimtsi和Lysyanskaya,CCS’13)的并发安全保证。该方案最初被证明在按顺序执行时是安全的,其并发安全性仍然是一个悬而未决的问题。Benham…

【PyTorch】线性回归

文章目录 1. 代码实现1.1 一元线性回归模型的训练 2. 代码解读2.1. tensorboardX2.1.1. tensorboardX的安装2.1.2. tensorboardX的使用 1. 代码实现 波士顿房价数据集下载 1.1 一元线性回归模型的训练 import numpy as np import torch import torch.nn as nn from torch.ut…

[计算机网络] 高手常用的几个抓包工具(上)

文章目录 高手常用的抓包工具一览什么是抓包工具优秀抓包工具WiresharkFiddlerTcpdumpCharles 高手常用的抓包工具一览 什么是抓包工具 抓包工具是一种可以捕获、分析和修改网络流量的软件。它可以帮助您进行网络调试、性能测试、安全审计等任务。 抓包工具可以实时地显示网…

使用trigger-forward跨流水线传递参数

参考文档:https://docs.gitlab.com/ee/ci/yaml/#triggerforward 今天给大家介绍一个gitlab CI/CD的关键字 - forward,该关键字是一个比较偏的功能,但同时也是一个很实用的功能,我们通过在gitlab的ci文件中使用forward关键字&#…

Android HCI日志分析案例1

案例1--蓝牙扫描设备过程分析 应用层发起搜索蓝牙设备,Android 官方提供的蓝牙扫描方式有三种,分别如下: BluetoothAdapter.startDiscovery(); //可以扫描经典蓝牙和BLE两种。BluetoothAdapter.startLeScan();//扫描低功耗蓝牙,…

深入理解同源限制:网络安全的守护者(上)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

超简单的node脚本,将xlsx文件转化为json

开发场景,在一个官网中,官网的设计非常简单,就是一个纯静态的页面,全网站仅一个地方调一下接口,发一下用户填写的信息到运营同学的邮箱,这些数据不会记录在数据库,我需要做一个这样的下拉框。 但…

python使用记录

1、VSCode添加多个python解释器 只需要将对应的python.exe的目录,添加到系统环境变量的Path中即可,VSCode会自动识别及添加 2、pip 使用 pip常用命令和一些坑 查看已安装库的版本号 pip show 库名称 通过git 仓库安装第三方库 pip install git仓库地…

SQL Server 2016(分离和附加数据库)

1、实验环境。 基于上一个实验《SQL Server(创建数据库)》 2、需求描述。 class数据库的数据文件和事务日志文件都位于C:\db_class目录下。现在需要把class数据库的数据文件和事务日志文件分开存放,数据文件class.mdf存放于原位置&#xff0…

DQN原理及PyTorch实现【强化学习】

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 欢迎来到我们的强化学习系列的第三部分。 在上两篇博客中,我们介绍了强化学习中…

【Linux】命令行参数

文章目录 前言一、C语言main函数的参数二、环境变量总结 前言 我们在Linux命令行输入命令的时候,一般都会跟上一些参数选项,比如l命令,ls -a -l。以前我总是觉得这是理所当然的,没深究其本质究竟是什么,今天才终于知道…

高压配电室智能运维

高压配电室智能运维是指通过运用先进的物联网、大数据、云计算等技术,对高压配电室进行智能化、远程化的运行维护,实现高压配电室的安全、高效、经济运行。以下是高压配电室智能运维的主要功能和优势: 实时监测:通过传感器和监测设…

分页助手入门以及小bug,报sql语法错误

导入坐标 5版本以上的分页助手 可以不用手动指定数据库语言&#xff0c;它会自动识别 <dependency> <groupId>com.github.pagehelper</groupId> <artifactId>pagehelper</artifactId> <version>5.3.2</version> </dependency&g…

vue中中的动画组件使用及如何在vue中使用animate.css

“< Transition >” 是一个内置组件&#xff0c;这意味着它在任意别的组件中都可以被使用&#xff0c;无需注册。它可以将进入和离开动画应用到通过默认插槽传递给它的元素或组件上。进入或离开可以由以下的条件之一触发&#xff1a; 由 v-if 所触发的切换由 v-show 所触…

OpenCV技术应用(6)— 暖色滤镜和冷色滤镜

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。本节课就手把手教大家如何将一幅图像转化成暖色滤镜和冷色滤镜&#xff0c;希望大家学习之后能够有所收获~&#xff01;&#x1f308; 目录 &#x1f680;1.技术介绍 &#x1f680;2.暖色滤镜 &#x1f680;3.冷色滤…

JVM运行时数据区域

文章目录 内存结构程序计数器&#xff08;寄存器&#xff09;虚拟机栈局部变量表两类异常状况 线程运行诊断 本地方法栈堆方法区运行时常量池串池&#xff08;StringTable&#xff09;字符串的拼接串池的位置StringTable垃圾回收StringTable性能调优 直接内存 内存结构 程序计…

关于标准库中的vector - (涉及迭代器失效,深浅拷贝,构造函数,内置类型构造函数,匿名对象)

目录 关于vector vector中的常见接口 vector常见接口的实现 迭代器失效 关于深浅拷贝 关于vector 关于vector的文档介绍 1. vector是表示可变大小数组的序列容器。 2. 就像数组一样&#xff0c;vector也采用的连续存储空间来存储元素。也就是意味着可以采用下标对vector的元…

【ArcGIS Pro微课1000例】0040:ArcGIS Pro创建北极点、南极点

文章目录 一、创建北极点图层二、创建北极点三、不同投影系下北极点的位置一、创建北极点图层 选择一个数据库,在上面右键→新建→要素类。 输入名称:北极点。 空间参考:WGS 1984 点击创建。 二、创建北极点 在编辑选项卡下,点击【创建】。 在创建要素窗口中,点击北极点…

docker容器中创建非root用户

简介 用 docker 也有一段时间了&#xff0c;一直在 docker 容器中使用 root 用户肆意操作。直到部署 stable diffusion webui 我才发现无法使用 root 用户运行它&#xff0c;于是才幡然醒悟&#xff1a;是时候搞个非 root 用户了。 我使用的 docker 镜像文件是 centos:centos…