大模型微调---Prompt-tuning微调

news2024/12/22 20:28:41

目录

    • 一、前言
    • 二、Prompt-tuning实战
      • 2.1、下载模型到本地
      • 2.2、加载模型与数据集
      • 2.3、处理数据
      • 2.4、Prompt-tuning微调
      • 2.5、训练参数配置
      • 2.6、开始训练
    • 三、模型评估
    • 四、完整训练代码

一、前言

Prompt-tuning通过修改输入文本的提示(Prompt)来引导模型生成符合特定任务或情境的输出,而无需对模型的全量参数进行微调。
在这里插入图片描述
Prompt-Tuning 高效微调只会训练新增的Prompt的表示层,模型的其余参数全部固定,其核心在于将下游任务转化为预训练任务

在这里插入图片描述
新增的 Prompt 内容可以分为 Hard PromptSoft Prompt 两类:

  • Soft prompt 通常指的是一种较为宽泛或模糊的提示,允许模型在生成结果时有更大的自由度,通常用于启发模型进行创造性的生成;
  • Hard prompt 是一种更为具体和明确的提示,要求模型按照给定的信息生成精确的结果,通常用于需要模型提供准确答案的任务;

Soft Prompt 在 peft 中一般是随机初始化prompt的文本内容,而 Hard prompt 则一般需要设置具体的提示文本内容;

对于不同任务的Prompt的构建示例如下:
在这里插入图片描述

例如,假设我们有兴趣将英语句子翻译成德语。我们可以通过各种不同的方式询问模型,如下图所示。

1)“Translate the English sentence ‘{english_sentence}’ into German: {german_translation}”
2)“English: ‘{english sentence}’ | German: {german translation}”
3)“From English to German:‘{english_sentence}’-> {german_translation}”

上面说明的这个概念被称为硬提示调整

软提示调整(soft prompt tuning)将输入标记的嵌入与可训练张量连接起来,该张量可以通过反向传播进行优化,以提高目标任务的建模性能。

例如下方伪代码:

# 定义可训练的软提示参数
# 假设我们有 num_tokens 个软提示 token,每个 token 的维度为 embed_dim
soft_prompt = torch.nn.Parameter(
    torch.rand(num_tokens, embed_dim)  # 随机初始化软提示向量
)

# 定义一个函数,用于将软提示与原始输入拼接
def input_with_softprompt(x, soft_prompt):
    # 假设 x 的维度为 (batch_size, seq_len, embed_dim)
    # soft_prompt 的维度为 (num_tokens, embed_dim)
    # 将 soft_prompt 在序列维度上与 x 拼接
    # 拼接后的张量维度为 (batch_size, num_tokens + seq_len, embed_dim)
    x = concatenate([soft_prompt, x], dim=seq_len)
    return x

# 将包含软提示的输入传入模型
output = model(input_with_softprompt(x, soft_prompt))

  1. 软提示参数:

使用 torch.nn.Parameter 将随机初始化的向量注册为可训练参数。这意味着在训练过程中,soft_prompt 中的参数会随梯度更新而优化。

  1. 拼接输入:

函数 input_with_softprompt 接收原始输入 x(通常是嵌入后的 token 序列)和 soft_prompt 张量。通过 concatenate(伪代码中使用此函数代指张量拼接操作),将软提示向量沿着序列长度维度与输入拼接在一起。

  1. 传递给模型:

将包含软提示的输入张量传给模型,以引导模型在执行特定任务(如分类、生成、QA 等)时更好地利用这些可训练的软提示向量。

二、Prompt-tuning实战

预训练模型与分词模型——Qwen/Qwen2.5-0.5B-Instruct
数据集——lyuricky/alpaca_data_zh_51k

2.1、下载模型到本地

# 下载数据集
dataset_file = load_dataset("lyuricky/alpaca_data_zh_51k", split="train", cache_dir="./data/alpaca_data")
ds = load_dataset("./data/alpaca_data", split="train")

# 下载分词模型
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Save the tokenizer to a local directory
tokenizer.save_pretrained("./local_tokenizer_model")

#下载与训练模型
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # 下载模型的路径
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    cache_dir="./local_model_cache"  # 指定本地缓存目录
)

2.2、加载模型与数据集

#加载分词模型
tokenizer_model = AutoTokenizer.from_pretrained("../local_tokenizer_model")

# 加载数据集
ds = load_dataset("../data/alpaca_data", split="train[:10%]")

# 记载模型
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="../local_llm_model/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775",
    torch_dtype="auto",
    device_map="cuda:0")

2.3、处理数据

"""
并将其转换成适合用于模型训练的输入格式。具体来说,
它将原始的输入数据(如用户指令、用户输入、助手输出等)转换为模型所需的格式,
包括 input_ids、attention_mask 和 labels。
"""
def process_func(example, tokenizer=tokenizer_model):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    if example["output"] is not None:
        response = tokenizer(example["output"] + tokenizer.eos_token)
    else:
        return
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }


# 分词
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)

2.4、Prompt-tuning微调

soft Prompt

# Soft Prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10) # soft_prompt会随机初始化

Hard Prompt

# Hard Prompt
prompt = "下面是一段人与机器人的对话。"
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, prompt_tuning_init=PromptTuningInit.TEXT,
                            prompt_tuning_init_text=prompt,
                            num_virtual_tokens=len(tokenizer_model(prompt)["input_ids"]),
                            tokenizer_name_or_path="../local_tokenizer_model")

加载peft配置

peft_model = get_peft_model(model, config)

print(peft_model.print_trainable_parameters())

在这里插入图片描述
可以看到要训练的模型相比较原来的全量模型要少很多

2.5、训练参数配置

# 配置模型参数
args = TrainingArguments(
    output_dir="chatbot",   # 训练模型的输出目录
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=1,
)

2.6、开始训练

# 创建训练器
trainer = Trainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer_model, padding=True )
)
# 开始训练
trainer.train()

可以看到 ,损失有所下降

在这里插入图片描述

三、模型评估

# 模型推理
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model = AutoModelForCausalLM.from_pretrained("../local_llm_model/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775", low_cpu_mem_usage=True)
peft_model = PeftModel.from_pretrained(model=model, model_id="./chatbot/checkpoint-643")
peft_model = peft_model.cuda()

#加载分词模型
tokenizer_model = AutoTokenizer.from_pretrained("../local_tokenizer_model")
ipt = tokenizer_model("Human: {}\n{}".format("我们如何在日常生活中减少用水?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(
    peft_model.device)
print(tokenizer_model.decode(peft_model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True))

print("-----------------")
#预训练的管道流
# 构建prompt
ipt = "Human: {}\n{}".format("我们如何在日常生活中减少用水?", "").strip() + "\n\nAssistant: "
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer_model)
output = pipe(ipt, max_length=256, do_sample=True, truncation=True)
print(output)

训练了一轮,感觉效果不大,可以增加训练轮数试试
在这里插入图片描述

四、完整训练代码


from datasets import load_dataset
from peft import PromptTuningConfig, TaskType, PromptTuningInit, get_peft_model, PeftModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, TrainingArguments, \
    DataCollatorForSeq2Seq, Trainer

# 下载数据集
# dataset_file = load_dataset("lyuricky/alpaca_data_zh_51k", split="train", cache_dir="./data/alpaca_data")
# ds = load_dataset("./data/alpaca_data", split="train")
# print(ds[0])

# 下载分词模型
# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Save the tokenizer to a local directory
# tokenizer.save_pretrained("./local_tokenizer_model")

#下载与训练模型
# model = AutoModelForCausalLM.from_pretrained(
#     pretrained_model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # 下载模型的路径
#     torch_dtype="auto",
#     low_cpu_mem_usage=True,
#     cache_dir="./local_model_cache"  # 指定本地缓存目录
# )

#加载分词模型
tokenizer_model = AutoTokenizer.from_pretrained("../local_tokenizer_model")

# 加载数据集
ds = load_dataset("../data/alpaca_data", split="train[:10%]")

# 记载模型
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="../local_llm_model/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775",
    torch_dtype="auto",
    device_map="cuda:0")


# 处理数据
"""
并将其转换成适合用于模型训练的输入格式。具体来说,
它将原始的输入数据(如用户指令、用户输入、助手输出等)转换为模型所需的格式,
包括 input_ids、attention_mask 和 labels。
"""
def process_func(example, tokenizer=tokenizer_model):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    if example["output"] is not None:
        response = tokenizer(example["output"] + tokenizer.eos_token)
    else:
        return
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }


# 分词
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)




prompt = "下面是一段人与机器人的对话。"

# prompt-tuning
# Soft Prompt
# config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10) # soft_prompt会随机初始化
# Hard Prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, prompt_tuning_init=PromptTuningInit.TEXT,
                            prompt_tuning_init_text=prompt,
                            num_virtual_tokens=len(tokenizer_model(prompt)["input_ids"]),
                            tokenizer_name_or_path="../local_tokenizer_model")

peft_model = get_peft_model(model, config)

print(peft_model.print_trainable_parameters())

# 训练参数

args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=1
)

# 创建训练器
trainer = Trainer(model=peft_model, args=args, train_dataset=tokenized_ds,
                  data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer_model, padding=True))


# 开始训练
trainer.train()




本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2263880.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何使用 WebAssembly 扩展后端应用

1. WebAssembly 简介 随着互联网的发展,越来越多的应用借助 Javascript 转到了 Web 端,但人们也发现,随着移动互联网的兴起,需要把大量的应用迁移到手机端,随着手端的应用逻辑越来越复杂,Javascript 的解析…

python学习——洛谷P2010 [NOIP2016 普及组] 回文日期 三种方法

[NOIP2016 普及组] 回文日期 文章目录 [NOIP2016 普及组] 回文日期题目背景题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 样例 #2样例输入 #2样例输出 #2 提示方法一方法二方法三 题目背景 NOIP2016 普及组 T2 题目描述 在日常生活中,通过年、月、日这…

前端yarn工具打包时网络连接问题排查与解决

最近线上前端打包时提示 “There appears to be trouble with your network connection”,以此文档记录下排查过程。 前端打包方式 docker启动临时容器打包,命令如下 docker run --rm -w /app -v pwd:/app alpine-node-common:v16.20-pro sh -c "…

BenchmarkSQL使用教程

1. TPC-C介绍 Transaction Processing Performance Council (TPC) 事务处理性能委员会,是一家非盈利IT组织,他们的目的是定义数据库基准并且向产业界推广可验证的数据库性能测试。而TPC-C最后一个C代表的是压测模型的版本,在这之前还有TPC-A、…

Linux网络基础--传输层Tcp协议(上) (详细版)

目录 Tcp协议报头: 4位首部长度: 源端口号和目的端口号 32位序号和确认序号 标记位 超时重传机制: 两个问题 连接管理机制 三次握手,四次挥手 建立连接,为什么要有三次握手? 先科普一个概念&…

全志H618 Android12修改doucmentsui鼠标单击图片、文件夹选中区域

背景: 由于当前的文件管理器在我们的产品定义当中,某些界面有改动的需求,所以需要在Android12 rom中进行定制以符合当前产品定义。 需求: 在进入File文件管理器后,鼠标左击整个图片、整个文件夹可以选中该类型,进行操作,故代码分析以及客制化如下: 主要涉及的代码:…

Unity命令行传递自定义参数 命令行打包

命令行参数增加位置 -executeMethod 某脚本.某方法 参数1 参数2 参数3 ... 例如执行EditorTest.GetCommandLineArgs方法 增加两个命令行参数 Version=125 CDNVersion=100 -executeMethod EditorTest.GetCommandLineArgs Version=125 CDNVersion=100 Unity测试脚本 需要放在…

如何重新设置VSCode的密钥环密码?

故障现象: 忘记了Vscode的这个密码: Enter password to unlock An application wants access to the keyring “Default ke... Password: The unlock password was incorrect Cancel Unlock 解决办法: 1.任意terminal下,输入如下…

电子发票汇总改名,批量处理电子发票问题

今天给大家推荐一个财务方面工作的软件。可以帮你解决很多财务。发票方面的问题。 电子发票汇总改名 批量处理电子发票问题 这个软件安装之后。会在桌面上分成三个小软件,分别是修改单位信息、自定义命名规则和电子发票汇总改名。 你可以在这个软件里提取PDF或者of…

Linux——卷

Linux——卷 介绍 最近做的项目,涉及到对系统的一些维护,有些盘没有使用,需要创建逻辑盘并挂载到指定目录下。有些软件需要依赖空的逻辑盘(LVM)。 先简单介绍一下卷的一些概念,有分区、物理存储介质、物…

MySQL通用语法 -DDL、DML、DQL、DCL

SQL 全称 Structured Query Language,结构化查询语言。操作关系型数据库的编程语言,定义了 一套操作关系型数据库统一标准 。 SQL通用语法 MySQL语言的通用语法。 SQL语句可以单行或多行书写,以分号结尾。SQL语句可以使用空格/缩进来增强…

利用DnslogSqlinj工具DNSlog注入

工具下载链接 https://github.com/adooo/dnslogsqlinj 配置 将域名和API进行一个更改 之后再安装两个python2的库就可以使用dnslog进行自动化注入了 python2安装pip2 curl https://bootstrap.pypa.io/2.7/get-pip.py -o get-pip.py python2 get-pip.py库 pip2 install geven…

QT网络(一):主机信息查询

网络简介 在QT中进行网络通信可以使用QT提供的Qt Network模块,该模块提供了用于编写TCP/IP网络应用程序的各种类,如用于TCP通信的QTcpSocket和 QTcpServer,用于 UDP 通信的 QUdpSocket,还有用于网络承载管理的类,以及…

STM32-笔记5-按键点灯(中断方法)

1、复制03-流水灯项目,重命名06-按键点灯(中断法) 在\Drivers\BSP目录下创建一个文件夹exti,在该文件夹下,创建两个文件exti.c和exti.h文件,并且把这两个文件加载到项目中,打开项目工程文件 加载…

重塑医院挂号体验:SSM 与 Vue 搭建的预约系统设计与实现

4系统概要设计 4.1概述 本系统采用B/S结构(Browser/Server,浏览器/服务器结构)和基于Web服务两种模式,是一个适用于Internet环境下的模型结构。只要用户能连上Internet,便可以在任何时间、任何地点使用。系统工作原理图如图4-1所示: 图4-1系统工作原理…

mysql的事务控制和数据库的备份和恢复

事务控制语句 行锁和死锁 行锁 两个客户端同时对同一索引行进行操作 客户端1正常运行 客户端2想修改,被锁行 除非将事务提交才能继续运行 死锁 客户端1删除第5行 客户端2设置第1行为排他锁 客户端1删除行1被锁 客户端2更新行5被锁 如何避免死锁 mysql的备份和还…

C# OpenCV机器视觉:尺寸测量

转眼就是星期一了,又到了阿强该工作的时候了!阿强走进了他作为机器视觉工程师的办公室,准备迎接新一天的挑战。随着周末的结束,他心中暗想:“如果我能让机器像我一样聪明,那就太好了!” 正当他…

四川托普信息技术职业学院教案1

四川托普信息技术职业学院教案 【计科系】 周次 第 1周,第1次课 备 注 章节名称 第1章 XML语言简介 引言 1.1 HTML与标记语言 1.2 XML的来源 1.3 XML的制定目标 1.4 XML概述 1.5 有了HTML了,为什么还要发展XML 1.5.1 HTML的缺点 1.5.2 XML的特点 1.6 X…

网络安全防范

实践内容 学习总结 PDR,$$P^2$$DR安全模型。 防火墙(Firewall): 网络访问控制机制,布置在网际间通信的唯一通道上。 不足:无法防护内部威胁,无法阻止非网络传播形式的病毒,安全策略…

GhostRace: Exploiting and Mitigating Speculative Race Conditions-记录

文章目录 论文背景Spectre-PHT(Transient Execution )Concurrency BugsSRC/SCUAF和实验条件 流程Creating an Unbounded UAF WindowCrafting Speculative Race ConditionsExploiting Speculative Race Conditions poc修复flush and reload 论文 https:/…