文章目录
- 0、prompt-tuning基本原理
- 1、实战
- 1.1、导包
- 1.2、加载数据
- 1.3、数据预处理
- 1.4、创建模型
- 1.5、Prompt Tuning*
- 1.5.1、配置文件
- 1.5.2、创建模型
- 1.6、配置训练参数
- 1.7、创建训练器
- 1.8、模型训练
- 1.9、推理:加载预训练好的模型
0、prompt-tuning基本原理
prompt-tuning的基本思想就是冻结主模型的全部参数,在训练数据前加入一小段Prompt,只训练Prompt的表示向量,即一个Embedding模块。其中,prompt又存在两种形式,一种是hard prompt,一种是soft prompt。
1、实战
1.1、导包
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
1.2、加载数据
ds = Dataset.load_from_disk("../Data/alpaca_data_zh/")
1.3、数据预处理
tokenizer = AutoTokenizer.from_pretrained("../Model/bloom-389m-zh")
tokenizer
def process_func(example):
MAX_LENGTH = 256
input_ids, attention_mask, labels = [], [], []
instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
response = tokenizer(example["output"] + tokenizer.eos_token)
input_ids = instruction["input_ids"] + response["input_ids"]
attention_mask = instruction["attention_mask"] + response["attention_mask"]
labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
if len(input_ids) > MAX_LENGTH:
input_ids = input_ids[:MAX_LENGTH]
attention_mask = attention_mask[:MAX_LENGTH]
labels = labels[:MAX_LENGTH]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels
}
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds
1.4、创建模型
model = AutoModelForCausalLM.from_pretrained("../Model/bloom-389m-zh",low_cpu_mem_usage=True)
1.5、Prompt Tuning*
1.5.1、配置文件
#soft prompt
# config = PromptTuningConfig(
# task_type=TaskType.CAUSAL_LM,
# num_virtual_tokens=10,
# )
# config
#hard prompt
config = PromptTuningConfig(
task_type=TaskType.CAUSAL_LM,
prompt_tuning_init = PromptTuningInit.TEXT,
prompt_tuning_init_text = '下面是一段机器人的对话:',
num_virtual_tokens=len(tokenizer('下面是一段机器人的对话:')['input_ids']),
tokenizer_name_or_path='../Model/bloom-389m-zh',
)
config
1.5.2、创建模型
model= get_peft_model(model,config)
model
打印模型训练参数
model.print_trainable_parameters()
1.6、配置训练参数
args = TrainingArguments(
output_dir="./chatbot",
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
logging_steps=10,
num_train_epochs=1
)
1.7、创建训练器
trainer = Trainer(
args=args,
model=model,
train_dataset=tokenized_ds,
data_collator=DataCollatorForSeq2Seq(tokenizer, padding=True, )
)
1.8、模型训练
trainer.train()
1.9、推理:加载预训练好的模型
from peft import PeftModel
peft_model = PeftModel.from_pretrained(model=model,model_id='./chat_bot/checkpoint500/')
from transformers import pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
pipe(ipt, max_length=256, do_sample=True, temperature=0.5)