Langchain-Chatchat项目:4.1-P-Tuning v2实现过程

news2025/1/23 17:32:47

  常见参数高效微调方法(Parameter-Efficient Fine-Tuning,PEFT)有哪些呢?主要是Prompt系列和LoRA系列。本文主要介绍P-Tuning v2微调方法。如下所示:

  • Prompt系列比如,Prefix Tuning(2021.01-Stanford)、Prompt Tuning(2021.09-Google)、P-Tuning(2021.03-Tsinghua)、P-Tuning v2(2022.03-Tsinghua);
  • LoRA系列比如,LoRA(2021.11-Microsoft)、AdaLoRA(2023.03-Microsoft)、QLoRA(2023.05-Washington)。
  • 还有不知道如何分类的比如,BitFit、Adapter Tuning及其变体、MAM Adapter、UniPELT等。


一.P-Tuning v2工作原理
1.Hard/Soft Prompt-Tuning如何设计
  提示工程发展经过了从人工或半自动离散空间的hard prompt设计,到采用连续可微空间soft prompt设计的过程,这样的好处是可通过端到端优化学习不同任务对应的prompt参数。
2.P-Tuning工作原理和不足
  主要是将continuous prompt应用于预训练模型的输入层,预训练模型后面的每一层都没有合并continuous prompt。

3.P-Tuning v2如何解决P-Tuning不足
  P-Tuning v2把continuous prompt应用于预训练模型的每一层,而不仅仅是输入层。


二.P-Tuning v2实现过程
1.整体项目结构
  源码参考文献[4],源码结构如下所示:

参数解释如下所示:
(1)–model_name_or_path L:/20230713_HuggingFaceModel/20231004_BERT/bert-base-chinese:BERT模型路径
(2)–task_name qa:任务名字
(3)–dataset_name squad:数据集名字
(4)–do_train:训练过程
(5)–do_eval:验证过程
(6)–max_seq_length 128:最大序列长度
(7)–per_device_train_batch_size 2:每个设备训练批次大小
(8)–learning_rate 5e-3:学习率
(9)–num_train_epochs 10:训练epoch数量
(10)–pre_seq_len 128:前缀序列长度
(11)–output_dir checkpoints/SQuAD-bert:检查点输出目录
(12)–overwrite_output_dir:覆盖输出目录
(13)–hidden_dropout_prob 0.1:隐藏dropout概率
(14)–seed 11:种子
(15)–save_strategy no:保存策略
(16)–evaluation_strategy epoch:评估策略
(17)–prefix:P-Tuning v2方法
执行代码如下所示:

python3 run.py --model_name_or_path L:/20230713_HuggingFaceModel/20231004_BERT/bert-base-chinese --task_name qa --dataset_name squad --do_train --do_eval --max_seq_length 128 --per_device_train_batch_size 2 --learning_rate 5e-3 --num_train_epochs 10 --pre_seq_len 128 --output_dir checkpoints/SQuAD-bert --overwrite_output_dir --hidden_dropout_prob 0.1 --seed 11 --save_strategy no --evaluation_strategy epoch --prefix

2.代码执行流程
(1)P-tuning-v2/run.py

  • 根据task_name=="qa"选择tasks.qa.get_trainer
  • 根据get_trainer得到trainer,然后训练、评估和预测

(2)P-tuning-v2/tasks/qa/get_trainer.py

  • 得到config、tokenizer、model、squad数据集、QuestionAnsweringTrainer对象trainer
  • 重点关注model是如何得到的
# fix_bert表示不更新bert参数,model数据类型为BertPrefixForQuestionAnswering
model = get_model(model_args, TaskType.QUESTION_ANSWERING, config, fix_bert=True)
  • 重点关注QuestionAnsweringTrainer具体实现
trainer = QuestionAnsweringTrainer(  # 读取trainer
    model=model,  # 模型
    args=training_args,  # 训练参数
    train_dataset=dataset.train_dataset if training_args.do_train else None,  # 训练集
    eval_dataset=dataset.eval_dataset if training_args.do_eval else None,  # 验证集
    eval_examples=dataset.eval_examples if training_args.do_eval else None,  # 验证集
    tokenizer=tokenizer,  # tokenizer
    data_collator=dataset.data_collator,  # 用于将数据转换为batch
    post_process_function=dataset.post_processing_function,  # 用于将预测结果转换为最终结果
    compute_metrics=dataset.compute_metrics,  # 用于计算评价指标
)

(3)P-tuning-v2/model/utils.py
选择P-tuning-v2微调方法,返回BertPrefixForQuestionAnswering模型,如下所示:

def get_model(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False):
    if model_args.prefix:  # 训练方式1:P-Tuning V2(prefix=True)
        config.hidden_dropout_prob = model_args.hidden_dropout_prob  # 0.1
        config.pre_seq_len = model_args.pre_seq_len  # 128
        config.prefix_projection = model_args.prefix_projection  # False
        config.prefix_hidden_size = model_args.prefix_hidden_size  # 512
        # task_type是TaskType.QUESTION_ANSWERING,config.model_type是bert,model_class是BertPrefixForQuestionAnswering
        model_class = PREFIX_MODELS[config.model_type][task_type]
        # model_args.model_name_or_path是bert-base-chinese,config是BertConfig,revision是main
        model = model_class.from_pretrained(model_args.model_name_or_path, config=config, revision=model_args.model_revision,)

(4)P-tuning-v2/model/question_answering.py(重点)
主要是BertPrefixForQuestionAnswering(BertPreTrainedModel)模型结构,包括构造函数、前向传播和获取前缀信息。
(5)P-tuning-v2/model/prefix_encoder.py(重点)
BertPrefixForQuestionAnswering(BertPreTrainedModel)构造函数中涉及到前缀编码器PrefixEncoder(config)
(6)P-tuning-v2/training/trainer_qa.py
继承关系为QuestionAnsweringTrainer(ExponentialTrainer)->ExponentialTrainer(BaseTrainer)->BaseTrainer(Trainer)->Trainer,最核心训练方法如下所示:

3.P-tuning-v2/model/prefix_encoder.py实现
  该类作用主要是根据前缀prefix信息对其进行编码,假如不考虑batch-size,那么编码后的shape为(prefix-length, 2*layers*hidden)。假如prefix-length=128,layers=12,hidden=768,那么编码后的shape为(128,2*12*768)。

class PrefixEncoder(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection  # 是否使用MLP对prefix进行投影
        if self.prefix_projection:  # 使用两层MLP对prefix进行投影
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
            )
        else:  # 直接使用Embedding进行编码
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:  # 使用MLP对prefix进行投影  
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:  # 不使用MLP对prefix进行投影
            past_key_values = self.embedding(prefix)
        return past_key_values

  这里面可能会有疑问,为啥还要乘以2呢?因为past_key_values前半部分要和key_layer拼接,后半部分要和value_layer拼接,如下所示:

key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

  说明:代码路径为transformers/models/bert/modeling_bert.py->class BertSelfAttention(nn.Module)的forward()函数中

4.P-tuning-v2/model/question_answering.py
  简单理解,BertPrefixForQuestionAnswering就是在BERT上添加了PrefixEncoder,get_prompt功能主要是生成past_key_values,即前缀信息的编码表示,用于与主要文本序列一起输入BERT模型,以帮助模型更好地理解问题和提供答案。因为选择的SQuAD属于抽取式QA数据集,即根据question从context中找到answer的开始和结束位置即可。

class BertPrefixForQuestionAnswering(BertPreTrainedModel):
    def __init__(self, config):
        self.bert = BertModel(config, add_pooling_layer=False)  # bert模型
        self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)  # 线性层
        self.prefix_encoder = PrefixEncoder(config)  # 前缀编码器

    def get_prompt(self, batch_size):  # 根据前缀token生成前缀的编码,即key和value值
        past_key_values = self.prefix_encoder(prefix_tokens)
        past_key_values = past_key_values.view(
            bsz,                 # batch_size
            seqlen,              # pre_seq_len
            self.n_layer * 2,    # n_layer表示BERT模型的层数
            self.n_head,         # n_head表示注意力头的数量
            self.n_embd          # n_embd表示每个头的维度
        )
        return past_key_values

    def forward(self, ..., return_dict=None):
        past_key_values = self.get_prompt(batch_size=batch_size)  # 获取前缀信息
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        outputs = self.bert(
            ......
            past_key_values=past_key_values,
        )
        return QuestionAnsweringModelOutput(  # 返回模型输出,包括loss,开始位置的logits,结束位置的logits,hidden states和attentions
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

  重点是outputs = self.bert(past_key_values=past_key_values),将past_key_values传入BERT模型中,起作用的主要是transformers/models/bert/modeling_bert.py->class BertSelfAttention(nn.Module)的forward()函数中。接下来看下past_key_values数据结构,如下所示:

5.BertSelfAttention实现
  BERT网络结构参考附件1,past_key_values主要和BertSelfAttention部分中的key和value进行拼接,如下所示:

(self): BertSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

  具体past_key_values和key、value拼接实现参考代码,如下所示:

  经过BertSelfAttention部分后,输出outputs的shape和原始输入的shape是一样的,即都不包含前缀信息。

附件1:BERT网络结构
  打印出来BERT模型结构,如下所示:

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) #embeddings层做了LayerNorm
    (dropout): Dropout(p=0.1, inplace=False) #embeddings层做了Dropout
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer( #BertLayer包括BertAttention、BertIntermediate和BertOutput
        (attention): BertAttention( #BertAttention包括BertSelfAttention和BertSelfOutput
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)

  BERT模型相关类结构在文件D:\Python310\Lib\site-packages\transformers\models\bert\modeling_bert.py中,如下所示:

附件2:SQuAD数据集
  SQuAD是斯坦福大学推出的机器阅读理解问答数据集,其中每个问题的答案来自于对应阅读段落的一段文本,即(问题,原文,答案)。一共有107,785问题,以及配套的536篇文章。除了SQuAD 1.1之外,还推出了难度更大的新版本SQuAD 2.0(《Know What You Don’t Know: Unanswerable Questions for SQuAD》_ACL2018)。
(1)训练集数据

(2)验证集数据

(3)加载SQuAD数据集

"""
执行脚本:python3 dataset_test.py --model_name_or_path L:/20230713_HuggingFaceModel/20231004_BERT/bert-base-chinese --task_name qa --dataset_name squad --do_train --do_eval --max_seq_length 128 --per_device_train_batch_size 2 --learning_rate 5e-3 --num_train_epochs 10 --pre_seq_len 128 --output_dir checkpoints/SQuAD-bert --overwrite_output_dir --hidden_dropout_prob 0.1 --seed 11 --save_strategy no --evaluation_strategy epoch --prefix
"""
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments

from arguments import get_args, ModelArguments, DataTrainingArguments, QuestionAnwseringArguments
from tasks.qa.dataset import SQuAD

if __name__ == '__main__':
    args = get_args()  # 从命令行获取参数

    model_args, data_args, training_args, qa_args = args  # model_args是模型相关参数,data_args是数据相关的参数,training_args是训练相关的参数
    tokenizer = AutoTokenizer.from_pretrained(  # 读取tokenizer
            model_args.model_name_or_path,  # 模型名称
            revision=model_args.model_revision,  # 模型版本
            use_fast=True,  # 是否使用fast tokenizer
        )
    dataset = SQuAD(tokenizer, data_args, training_args, qa_args)
    print(dataset)

  打个断点看下dataset数据结构如下所示:


  • input_ids:经过tokenizer分词后的subword对应的下标列表
  • attention_mask:在self-attention过程中,这一块mask用于标记subword所处句子和padding的区别,将padding部分填充为0
  • token_type_ids:标记subword当前所处句子(第一句/第二句/ padding)
  • position_ids:标记当前词所在句子的位置下标
  • head_mask:用于将某些层的某些注意力计算无效化
  • inputs_embeds:如果提供了,那就不需要input_ids,跨过embedding lookup过程直接作为Embedding进入Encoder计算
  • encoder_hidden_states:这一部分在BertModel配置为decoder时起作用,将执行cross-attention而不是self-attention
  • encoder_attention_mask:同上,在cross-attention中用于标记encoder端输入的padding
  • past_key_values:在P-Tuning V2中会用到,主要是把前缀编码和预训练模型每层的key、value进行拼接。
  • use_cache:将保存上一个参数并传回,加速decoding
  • output_attentions:是否返回中间每层的attention输出
  • output_hidden_states:是否返回中间每层的输出
  • return_dict:是否按键值对的形式返回输出,默认为真。

  觉得P-Tuning v2里面还有很多知识点没有讲解清楚,只能后续逐个讲解。仅仅一个P-Tuning v2仓库代码涉及的知识点非常之多,首要就是把Transformer和BERT标准网络结构非常熟悉,还有对各种任务及其数据集要熟悉,对BERT变体网络结构要熟悉,对于PyTorch和Transformer库的深度学习模型训练、验证和测试流程要熟悉,对于Prompt系列微调方法要熟悉。总之,对于各种魔改Transformer和BERT要了如指掌。

参考文献:
[1]P-Tuning论文地址:https://arxiv.org/pdf/2103.10385.pdf
[2]P-Tuning代码地址:https://github.com/THUDM/P-tuning
[3]P-Tuning v2论文地址:https://arxiv.org/pdf/2110.07602.pdf
[4]P-Tuning v2代码地址:https://github.com/THUDM/P-tuning-v2
[5]BertLayer及Self-Attention详解:https://zhuanlan.zhihu.com/p/552062991
[6]https://rajpurkar.github.io/SQuAD-explorer/
[7]https://huggingface.co/datasets/squad

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1164883.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OpenGL_Learn04

我这边并不是教程&#xff0c;只是学习记录&#xff0c;方便后面回顾&#xff0c;代码均是100%可以运行成功的。 1. 渐变三角形 #include <glad/glad.h> #include <GLFW/glfw3.h>#include <iostream> #include <cmath>void framebuffer_size_callba…

科学计数法 [极客大挑战 2019]BuyFlag1

打开题目 注意中说&#xff0c;我们需要买flag&#xff0c;首先必须是cuit的学生&#xff0c;其次必须输对正确的密码 查看源代码得到 代码审计 首先&#xff0c;检查是否存在名为 password 的POST请求。 如果 password 存在&#xff0c;将其存储在变量 $password 中。 然后…

你没有见过的 git log 风格

背景 git大家都不陌生&#xff0c;git log 也是大家经常用的指令&#xff0c;今天分享三种 git log的美化格式&#xff0c;大家看看哪种更易读。 git log -15 --graph --decorate --oneline 带有 pretty 格式的git log 风格 log --color --graph --prettyformat:‘%Cred%h%C…

【音视频 | Ogg】RFC3533 :Ogg封装格式版本 0(The Ogg Encapsulation Format Version 0)

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…

Vue3入门指南:零基础小白也能轻松理解的学习笔记

文章目录 创建项目开发环境项目目录模板语法属性绑定条件渲染列表渲染事件处理内联事件处理器方法事件处理器&#xff08;常用&#xff09; 事件参数获取 event 事件事件传参 事件修饰符阻止默认事件阻止事件冒泡 数组变化侦测变更方法替换一个数组 计算属性class 绑定单对象绑…

汽车标定技术(一):XCP概述

目录 1.汽车标定概述 2.XCP协议由来及版本介绍 3.XCP技术通览 3.1 XCP上下机通信模型 3.2 XCP指令集 3.2.1 XCP帧结构定义 3.2.2 标准指令集 3.2.3 标定指令集 3.2.4 页切换指令集 3.2.5 数据采集指令集 3.2.6 刷写指令集 3.3 ECU描述文件(A2L)概述 3.3.1 标定上位…

无限上下文,多级内存管理!突破ChatGPT等大语言模型上下文限制

目前&#xff0c;ChatGPT、Llama 2、文心一言等主流大语言模型&#xff0c;因技术架构的问题上下文输入一直受到限制&#xff0c;即便是Claude 最多只支持10万token输入&#xff0c;这对于解读上百页报告、书籍、论文来说非常不方便。 为了解决这一难题&#xff0c;加州伯克利…

物联网AI MicroPython传感器学习 之 QMC5883指南针罗盘传感器

学物联网&#xff0c;来万物简单IoT物联网&#xff01;&#xff01; 一、产品简介 QMC5883是一款表面贴装的集成了信号处理电路的三轴磁性传感器&#xff0c;应用场景主要包括罗盘、导航、无人机、机器人和手持设备等一些高精度的场合。 引脚定义 VCC&#xff1a;3V3&#…

【Java 进阶篇】Java ServletContext详解:获取MIME类型

MIME&#xff08;Multipurpose Internet Mail Extensions&#xff09;类型是一种标识文件类型的文本标签&#xff0c;通常用于指示浏览器如何处理Web服务器返回的文件。在Java Web应用程序中&#xff0c;ServletContext对象提供了一种方便的方法来获取文件的MIME类型。本篇博客…

【实战Flask API项目指南】之五 RESTful API设计

实战Flask API项目指南之 RESTful API设计 本系列文章将带你深入探索实战Flask API项目指南&#xff0c;通过跟随小菜的学习之旅&#xff0c;你将逐步掌握 Flask 在实际项目中的应用。让我们一起踏上这个精彩的学习之旅吧&#xff01; 前言 当小菜踏入Flask后端开发的世界时…

Redis-命令操作Redis

&#x1f3ac; 艳艳耶✌️&#xff1a;个人主页 &#x1f525; 个人专栏 &#xff1a;《Spring与Mybatis集成整合》《Vue.js使用》 ⛺️ 越努力 &#xff0c;越幸运。 1.Redis简介 1.1.什么是Redis Redis是一个开源&#xff08;BSD许可&#xff09;&#xff0c;内存存储的数据…

费用预算管理系统

费用预算管理系统 1. 模块概述 《费用管理》以企业费用管理为核心&#xff0c;围绕费用支出审批流程&#xff0c;从费用发生前的事前申请&#xff0c;报销单据审批、付款单据审批&#xff0c;再到出纳付款、会计记账等所有工作流程都在系统中全员、协同完成&#xff1b;并且能…

el-table中的el-input标签修改值,但界面未更新,解决方法

el-table中的el-input标签修改值&#xff0c;界面未更新 在el-table中的el-input里面写的change事件根本不触发&#xff0c;都不打印&#xff0c;试了网络上各种方法都没用 然后换成input事件&#xff0c;input事件会触发&#xff0c;但界面也未更新。我在触发事件的时候&…

微信小程序之开发工具介绍

一、微信小程序开发工具下载 微信小程序开发工具下载可以参考这篇博客《微信小程序开发者工具下载-CSDN博客》 二、开发工具组成部分 如下图所示&#xff0c;开发者工具主要由菜单栏、工具栏、模拟器、编辑器和调试器 5 个部分组成。。 1、菜单栏 菜单栏中主要包括项目、文…

听GPT 讲Rust源代码--library/std(13)

题图来自 Decoding Rust: Everything You Need to Know About the Programming Language[1] File: rust/library/std/src/os/horizon/raw.rs 在Rust源代码中&#xff0c;rust/library/std/src/os/horizon/raw.rs这个文件的作用是为Rust的标准库提供与Horizon操作系统相关的原始…

STM32HAL-完全解耦面向对象思维的架构-时间轮片法使用(timeslice)

目录 概述 一、开发环境 二、STM32CubeMx配置 三、编码 四、运行结果 五、代码解释 六、总结 概述 timeslice是一个时间片轮询框架&#xff0c;完全解耦的时间片轮询框架&#xff0c;非常适合裸机单片机引用。接下来将该框架移植到stm32单片机运行&#xff0c;单片机…

王道计算机网络

一、计算机网络概述 (一)计算机网络基本概念 计算机网络的定义、组成与功能 定义&#xff1a;以能够相互共享资源的方式互连起来的自治计算机系统的集合。 目的&#xff1a;资源共享&#xff0c; 组成单元&#xff1a;自治、互不影响的计算机 网络协议 从不同角度计算机网络…

【Python入门二】安装第三方库(包)

安装第三方库/包 1 使用pip安装2 使用PyCharm软件安装3 离线安装&#xff0c;使用whl文件安装参考 在Python中&#xff0c;有多种安装第三方库的方法&#xff0c;下面是一些常用的方法&#xff1a; 1 使用pip安装 pip是Python中最常用的包管理工具&#xff0c;也是最常用的在线…

PASCAL VOC 格式

文章目录 ImageSets 文件夹Main 文件夹:Segmentation 文件夹:Layout 文件夹:Action 文件夹: Annotations 文件夹主要标签&#xff1a;物体标签&#xff1a; SegmentationClass 文件夹SegmentationObject 文件夹 PASCAL VOC&#xff08;Visual Object Classes&#xff09;是一个…

计算流体动力学(CFD)软件

CFD&#xff0c;英语全称 (Computational Fluid Dynamics&#xff09;&#xff0c;即计算流体动力学。CFD 是近代流体力学&#xff0c;数值数学和计算机科学结合的产物&#xff0c;是一门具有强大生命力的交叉科学。它是将流体力学的控制方程中积分、微分项近似地表示为离散的代…