【LLM】Prompt tuning大模型微调实战

news2024/9/24 3:25:44

文章目录

  • 一、Propmt tuning
    • 1. peft库中的tuning
    • 2. prompt tuning怎么搞
  • 二、Prompt tuning代码实战
    • 1. tuning训练
    • 2. 模型推理比较
    • 3. 其他tuning技术
  • Reference

一、Propmt tuning

1. peft库中的tuning

  • 之前提到过可以借助peft库(Parameter-Efficient Fine-Tuning)进行微调,支持如下tuning:
    • Adapter Tuning(固定原预训练模型的参数 只对新增的adapter进行微调)
    • Prefix Tuning(在输入token前构造一段任务相关的virtual tokens作为prefix,训练时只更新Prefix不分的参数,而Transformer的其他不分参数固定,和构造prompt类似,只是prompt是人为构造的即无法在模型训练时更新参数,而Prefix可以学习<隐式>的prompt)
    • Prompt Tuning(Prefix Tuning的简化版,只在输入层加入prompt tokens,并不需要加入MLP)
    • P-tuning(将prompt转为可学习的embedding层,v2则加入了prompts tokens作为输入)
    • LoRA(Low-Rank Adaption,为了解决adapter增加模型深度而增加模型推理时间、上面几种tuning中prompt较难训练,减少模型的可用序列长度)
      • 该方法可以在推理时直接用训练好的AB两个矩阵和原预训练模型的参数相加,相加结果替换原预训练模型参数。
      • 相当于用LoRA模拟full-tunetune过程

2. prompt tuning怎么搞

  • 给出好的prompt可以让LLM生成更好的答案,反过来想通过LLM帮我们找到好的prompt就是prompt tuning的思路,训练让模型看到新的例子生成prompt,并把该段prompt作为前缀拼接到我们自己的prompt上,送入LLM得到结果
    • prompt tuning训练的前缀是向量,所以解释性略差
  • 和few show比较:LLM的上下文context长度是有限的(prompt中给出有限的例子,业务复杂时难让模型学习足够多知识),prompt tuning就没有这个限制,只需在训练LLM时给他看足够多例子,之后提问带上一个短的prompt前缀(一般8~20个token)即可
  • 和fine tuning比较:prompt tuning是完全冻结LLM模型参数,只需训练一个几个token的prompt前缀;但是fine tuning精调一个模型很耗资源
  • 为每一个任务额外添加一个或多个embedding,之后拼接query正常输入LLM,并只训练这些embedding。如下图,左图为单任务全参数微调,右图为prompt tuning。
    • prompt tuning将fine tune任务转为mlm任务。自动学习模板:离散的主要包括 Prompt Mining, Prompt Paraphrasing, Gradient-based Search, Prompt Generation 和 Prompt Scoring;连续的则主要包括Prefix Tuning, Tuning Initialized with Discrete Prompts 和 Hard-Soft Prompt Hybrid Tuning。
    • 正常微调举例:[cls]今天天上都出太阳了,阳光明媚。[SEP]
      prompt输入举例:[cls]今天天气是[MASK]。[SEP] 今天天上都出太阳了,阳光明媚。[SEP]

在这里插入图片描述

二、Prompt tuning代码实战

1. tuning训练

  • 数据:twitter_complaints
  • 模型:bigscience/bloomz-560m模型
  • PromptTuningConfig设置Prompt tuning配置,下面num_virtual_tokens设置prompt前缀的token数,因为token初始化用任务相关文字效果更好,所以下面用Classify if the tweet is a complaint or not:初始化,
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author : andy
@Date   : 2023/7/10 20:37
@Contact: 864934027@qq.com 
@File   : prompt_tuning.py 
"""
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType
import torch
from datasets import load_dataset
import os
from torch.utils.data import DataLoader
from tqdm import tqdm

device = "mps"
# device = "cuda"
model_name_or_path = "bigscience/bloomz-560m"
tokenizer_name_or_path = "bigscience/bloomz-560m"
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=8,
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    tokenizer_name_or_path=tokenizer_name_or_path,
)

dataset_name = "twitter_complaints"
text_column = "Tweet text"
label_column = "text_label"
max_length = 64
learning_rate = 3e-2
num_epochs = 20
batch_size = 8
output_dir = './output'

# 1. load a subset of the RAFT dataset at https://huggingface.co/datasets/ought/raft
dataset = load_dataset("ought/raft", dataset_name)

# get lable's possible values
label_values = [name.replace("_", "") for name in dataset["train"].features["Label"].names]
# append label value to the dataset to make it more readable
dataset = dataset.map(
    lambda x: {label_column: [label_values[label] for label in x["Label"]]},
    batched=True,
    num_proc=1
)
# have a look at the data structure
dataset["train"][0]

在这里插入图片描述

# 2. dataset
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

def preprocess_fn(examples):
    tweets = examples[text_column]
    # pad labels with a pad token at the end
    labels = [str(x) + tokenizer.pad_token for x in examples[label_column]]
    # concatenate the tweet with it label
    inputs = [f"{text_column} : {tweet}\nLabel :{label}"
              for tweet, label in zip(tweets, labels)]
    # tokenize input
    model_inputs = tokenizer(inputs,
                           padding='max_length',
                           max_length=max_length,
                           truncation=True,)
    # tokenize label, as -100 not a valid token id, do the padding manually here
    labels_input_ids = []
    for i in range(len(labels)):
        ids = tokenizer(labels[i])["input_ids"]
        padding = [-100] * (max_length - len(ids))
        labels_input_ids.append(padding + ids)
        model_inputs["labels"] = labels_input_ids
        # make model inputs tensor
        model_inputs["input_ids"] = [torch.tensor(ids) for ids in model_inputs["input_ids"]]
        model_inputs["attention_mask"] = [torch.tensor(ids) for ids in model_inputs["attention_mask"]]
        model_inputs["labels"] = [torch.tensor(ids) for ids in model_inputs["labels"]]

    return model_inputs

# have a look at the preprocessing result
# print(preprocess_fn(dataset["train"][:2]))

processed_datasets = dataset.map(
    preprocess_fn,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names, #remove unprocessed column for training
    load_from_cache_file=False,
    desc="Running tokenizer on datasset"
)

test_size = round(len(processed_datasets["train"]) * 0.2)
train_val = processed_datasets["train"].train_test_split(
    test_size=test_size, shuffle=True, seed=42)
train_data = train_val["train"]
val_data = train_val["test"]


# 3. model
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
print(model.print_trainable_parameters())
trainable params: 8192 || all params: 559222784 || trainable%: 0.0014648902430985358

从上面打印结果看出,模型的参数有5.6亿左右,但是需要训练的参数只占0.001%,只有8192个。

# 4. trainer
from transformers import Trainer, TrainingArguments
trainer = Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=default_data_collator,
    args=TrainingArguments(
      output_dir='./output',
      per_device_train_batch_size=batch_size,
      num_train_epochs=num_epochs,
      learning_rate=learning_rate,
      load_best_model_at_end=True,
      logging_strategy='steps',
      logging_steps=10,
      evaluation_strategy='steps',
      eval_steps=10,
      save_strategy='steps',
      save_steps=10,
    )
  )
trainer.train()

在这里插入图片描述

2. 模型推理比较

# 5. inference
def  inference():
    def generate(inputs, infer_model):
        with torch.no_grad():
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = infer_model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=20,
                eos_token_id=3
            )
            print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])

    # (1) base model_inference
    base_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    base_model.to(device)
    inputs = tokenizer(
        f'{text_column} : {"@denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?"}\nLabel :',
        return_tensors="pt",  # Return PyTorch torch.Tensor objects.
    )
    generate(inputs, base_model)
    print("----------------------------------------")
    shot1 = f'{text_column} : {"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?"}\nLabel :complaint\n'
    shot2 = f'{text_column} : {"@HMRCcustomers No this is my first job"}\nLabel :no complaint\n'
    input = f'{text_column} : {"@denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?"}\nLabel :'
    inputs_few_shot = tokenizer(
        shot1 + shot2 + input,
        return_tensors="pt",
    )
    generate(inputs_few_shot, base_model)

    # (2) prompt-tuned model_inference
    from peft import PeftModel, PeftConfig
    path = "/content/drive/MyDrive/prompt_tuning"
    config = PeftConfig.from_pretrained(path)
    pretrained_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
    prompt_tuned_model = PeftModel.from_pretrained(pretrained_model, path)
    prompt_tuned_model.to(device)
    inputs = tokenizer(
        f'{text_column} : {"@denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?"}\nLabel :',
        return_tensors="pt",  # Return PyTorch torch.Tensor objects.
    )
    generate(inputs, prompt_tuned_model)

inference()
  • 上面base model推理结果:
Tweet text : @denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?
Label : @denny the grocery<?php
/**
 * Copyright © 2016 Google Inc.

----------------------------------------
Tweet text : @nationalgridus I have no water and the bill is current and paid. Can you do something about this?
Label :complaint
Tweet text : @HMRCcustomers No this is my first job
Label :no complaint
Tweet text : @denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?
Label :complaint<?php
/**
 * Copyright © Magento, Inc. All rights reserved.
  • prompt-tuned model推理结果:
Tweet text : @denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?
Label :complaint

3. 其他tuning技术

在这里插入图片描述

  • prefix tuning和prompt tuning都不需要改LLM模型参数本身,但prefix tuning不进在用户输入该层找到一个前缀,还要在LLM的每层都找到一个前缀并加上,训练成本明显更高
  • p-tuning则不仅可在用户输入的开头加附加信息,也可以在中间或结尾附加信息
  • lora tuning如下图,上一篇博客也讲过

在这里插入图片描述

Reference

[1] https://github.com/jxhe/unify-parameter-efficient-tuning
[2] Continuous Optimization:从Prefix-tuning到更强大的P-Tuning V2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/739270.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于MeanShift的图像滤波方法

前言 在视觉领域中&#xff0c;图像滤波是一种常用的技术&#xff0c;用于去除图像中的噪声和平滑图像。其中&#xff0c;MeanShift滤波是一种基于颜色和空间信息的非参数化滤波算法。 MeanShift滤波原理 MeanShift滤波是一种基于密度估计的非参数化滤波技术&#xff0c;它对每…

【Go】Go 语言教程--GO语言数组(十一)

往期回顾&#xff1a; Go 语言教程–介绍&#xff08;一&#xff09;Go 语言教程–语言结构&#xff08;二&#xff09;Go 语言教程–语言结构&#xff08;三&#xff09;Go 语言教程–数据类型&#xff08;四&#xff09;Go 语言教程–语言变量&#xff08;五&#xff09;Go …

希尔排序及其时间复杂度(图文详解)

&#x1f63e; 博客主页: 爱吃bug的猿 &#x1f680;博客专栏: 数据结构,C语言初阶进阶全流程讲解 &#x1f63d;&#x1f63d;&#x1f63d;如果喜欢博主的文章&#xff0c;可以给博主点波赞和关注加速博主更新 文章目录 前言1. 代码思路代码实现法1代码实现法2&#xff08;不…

【KingbaseES】数据库如何查询数据库,模式及表大小

新建数据kingbase及kingbase模式 CREATE DATABASE kingbase OWNER kingbase; CREATE SCHEMA kingbase AUTHORIZATION "kingbase";在数据库kingbase的kingbase模式下新建两张测试表test_size,test_size1并插入数据 CREATE TABLE "kingbase"."test_sz…

课时9:PKI证书基础知识

快速链接: . 👉👉👉 个人博客笔记导读目录(全部) 👈👈👈 付费专栏-付费课程 【购买须知】:Secureboot从入门到精通-[目录] 👈👈👈目录 1、CA证书的一般用法2、x509证书3、openssl证书命令行4、X509v3证书格式

第十五章 原理篇:YOLOv8

找工作也太难了吧根本找不到工作我哭死。 参考教程&#xff1a; https://mmyolo.readthedocs.io/en/latest/recommended_topics/algorithm_descriptions/yolov8_description.html https://zhuanlan.zhihu.com/p/599761847【这个写的挺不错】 https://zhuanlan.zhihu.com/p/63…

Linux进程监控及控制【命令ps的使用】

ps命令详解 Linux中的ps命令是Process Status的缩写。ps命令用来列出系统中当前运行的那些进程。ps命令列出的是当前那些进程的快照&#xff0c;就是执行ps命令的那个时刻的那些进程&#xff0c;如果想要动态的显示进程信息&#xff0c;就可以使用top命令。 要对进程进行监测…

EcoVadis 2023年最新评分细则

【EcoVadis 2023年最新评分细则】 Ecovadis 的四大主题 EcoVadis 企业社会责任评级方法的目标是通过其方针政策、实施执行和绩效反馈来衡量一家公司的企业社会责任管理系统的质量。 EcoVadis企业社会责任&#xff08;CSR&#xff09;评估方法基于七项基本原则&#xff08;如图&…

Blazor前后端框架Known-V1.2.3

V1.2.3 Known是基于C#和Blazor开发的前后端分离快速开发框架&#xff0c;开箱即用&#xff0c;跨平台&#xff0c;一处代码&#xff0c;多处运行。 Gitee&#xff1a; https://gitee.com/known/KnownGithub&#xff1a;https://github.com/known/Known 概述 基于C#和Blazor…

1.SpringBoot编写第一个接口(保姆级)

1.下载SpringBoot框架 下载地址&#xff1a;https://start.spring.io/ 选择对应的springboot 版本&#xff0c;工具&#xff0c;依赖等。 2.用Idea打开项目 下载完后&#xff0c;解压文件后&#xff0c;用Idea打开,进行项目的JDK和Maven的相关配置。 将项目的JDK配置成自己…

数据结构与算法——最小生成树问题(什么是最小生成树、Prim算法、Kruskal算法)

目录 什么是最小生成树 贪心算法 Prim算法 思路 代码&#xff08;C语言&#xff09; Kruskal算法 思路 代码 什么是最小生成树 1.是一颗树 无回路个顶点一定有条边 2.是生成树 包含全部顶点条边都在图里 3.边的权重和最小 向生成树中任加一条边都一定构成回路 最…

前端vue入门(纯代码)25_路由vueRouter

如果耐不住寂寞&#xff0c;你就看不到繁华。 【23.Vue Router--路由】 [可以去官网看看Vue Router文档](入门 | Vue Router (vuejs.org)) 用 Vue Vue Router 创建单页应用非常简单&#xff1a;通过 Vue.js&#xff0c;我们已经用组件组成了我们的应用。当加入 Vue Router …

Linux基础服务9——lamt架构

文章目录 一、基本了解二、安装tomcat三、 常用文件3.1 bin目录文件3.2 conf目录文件3.2.1 server.xml文件3.2.2 tomcat-users.xml 文件 四、分离部署lamt架构4.1 安装httpd4.2 安装mysql4.3 部署tomcat4.4 配置apache 一、基本了解 前提背景&#xff1a; Tomcat是Apache 软件基…

Git开发项目完整流程使用(图文超详解)

前世今生 自2002年开始&#xff0c;林纳斯托瓦兹&#xff08;Linus Torvalds&#xff09;决定使用BitKeeper作为Linux内核主要的版本控制系统用以维护代码。因为BitKeeper为专有软件&#xff0c;这个决定在社群中长期遭受质疑。在Linux社群中&#xff0c;特别是理查德斯托曼与…

springboot电子招投标系统

开发语言&#xff1a;Java 框架&#xff1a;springboot JDK版本&#xff1a;JDK1.8 服务器&#xff1a;tomcat7 数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09; 数据库工具&#xff1a;Navicat11 开发软件&#xff1a;eclipse/myeclipse/idea Maven…

MySQL基础篇第5章(排序与分页)

文章目录 1、排序1.1 排序规则1.2 单列排序1.3 多列排序 2、分页2.1 背景2.2 实现规则2.3 拓展 1、排序 1.1 排序规则 使用 ORDER BY 子句排序 ASC&#xff08;ascend&#xff09;: 升序DESC&#xff08;descend&#xff09;:降序 ORDER BY 子句在SELECT语句的结尾。 1.2 …

lumpy_sv的安装

目前利用lumpy-SV进行鉴定SV是比较常用的软件&#xff0c;但是其依赖python2.7版本的环境 ## 创建python 2.7版本conda环境 conda create -n python27 python2.7 ## 进入py27环境 conda activate python27 ## 下载lumpy软件并安装 git clone --recursive https://github.com/a…

深度学习——搭建神经网络的两种方式

方法一&#xff1a;定义神经网络类&#xff0c;然后实例化 代码&#xff1a; import torch import torch.nn.functional as F"""方法一&#xff1a;定义神经网络类&#xff0c;然后再实例化 """ # 神经网络类 class Net(torch.nn.Module):def …

CRYPTO-36D-justShow

0x00 前言 CTF 加解密合集&#xff1a;CTF 加解密合集 0x01 题目 hlcgoyfsjqknyifnpd:bcdefghijklmnopqrstuvwxyza0x02 Write Up 这道题也是完全击中了我的软肋&#xff0c;用到了新的加密方式&#xff0c;并且感觉到自己对这种类型的题目并不是很敏感。 首先看到hlcgoyfs…

基于matlab自动检测图像中的圆形目标并可视化检测到的圆(附源码)

一、前言 此示例说明如何自动检测图像中的圆或圆形目标并可视化检测到的圆。 二、实现步骤 步骤 1&#xff1a;加载图像 读取并显示包含各种颜色的圆形塑料片的图像。除了有大量要检测的圆之外&#xff0c;从圆检测的角度来看&#xff0c;此图像还有一些有趣的特点&#xf…