基于prefix tuning + Bert的标题党分类器

news2025/1/13 2:51:16

文章目录

  • 背景
  • 一、Prefix-Tuning介绍
  • 二、分类
  • 三、效果
  • 四、参阅

背景

近期, CSDN博客推荐流的标题党博客又多了起来, 先前的基于TextCNN版本的分类模型在语义理解上能力有限, 于是, 便使用的更大的模型来优化, 最终准确率达到了93.7%, 还不错吧.

一、Prefix-Tuning介绍

传统的fine-tuning是在大规模预训练语言模型(如Bert、GPT2等)上完成的, 针对不同的下游任务, 需要保存不同的模型参数, 代价比较高,

解决这个问题的一种自然方法是轻量微调(lightweight fine-tunning),它冻结了大部分预训练参数,并用小的可训练模块来增强模型,比如在预先训练的语言模型层之间插入额外的特定任务层。适配器微调(Adapter-tunning)在自然语言理解和生成基准测试上具有很好的性能,通过微调,仅添加约2-4%的任务特定参数,就可以获得类似的性能。

受到prompt的启发,提出一种prefix-tuning, 只需要保存prefix部分的参数即可.

相关论文: Prefix-Tuning: Optimizing Continuous Prompts for Generation

请添加图片描述

论文中, 作者使用Prefix-tuning做生成任务,它根据不同的模型结构定义了不同的Prompt拼接方式.

请添加图片描述

对于自回归模型,加入前缀后的模型输入表示:

z = [PREFIX; x; y]

对于编解码器结构的模型,加入前缀后的模型输入表示:

z = [PREFIX; x; PREFIX; y]

原理部分不过多赘述, 对于Prompt不太熟悉的同学, 一定要看看王嘉宁老师写的综述: Prompt-Tuning——深度解读一种新的微调范式

与Prefix-Tuning类似的方法还有P-tuning V2,不同之处在于Prefix-Tuning是面向文本生成领域的,P-tuning V2面向自然语言理解.

二、与Bert对比

Bert

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
)

因为我这里是二分类, 所以最后一层多了个Linear层

Prefix Tuning Bert

PeftModelForSequenceClassification(
  (base_model): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(21128, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (intermediate): BertIntermediate(
              (dense): Linear(in_features=768, out_features=3072, bias=True)
              (intermediate_act_fn): GELUActivation()
            )
            (output): BertOutput(
              (dense): Linear(in_features=3072, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
        )
      )
      (pooler): BertPooler(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (activation): Tanh()
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (classifier): ModulesToSaveWrapper(
      (original_module): Linear(in_features=768, out_features=2, bias=True)
      (modules_to_save): ModuleDict(
        (default): Linear(in_features=768, out_features=2, bias=True)
      )
    )
  )
  (prompt_encoder): ModuleDict(
    (default): PrefixEncoder(
      (embedding): Embedding(20, 18432)
    )
  )
  (word_embeddings): Embedding(21128, 768, padding_idx=0)
)

比Bert多了prompt_encoder, 查阅peft源码可以看到PrefixEncoder的实现:

class PrefixEncoder(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        token_dim = config.token_dim
        num_layers = config.num_layers
        encoder_hidden_size = config.encoder_hidden_size
        num_virtual_tokens = config.num_virtual_tokens
        if self.prefix_projection and not config.inference_mode:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
            self.transform = torch.nn.Sequential(
                torch.nn.Linear(token_dim, encoder_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim),
            )
        else:
            self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.transform(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values

其实就是多了个embedding层

二、分类

好在huggingface的peft库对几个经典的prompt tuning都有封装, 实现起来并不难:



import logging
from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    PeftType,
    PrefixTuningConfig,
    PromptEncoderConfig,
    PromptTuningConfig,
    LoraConfig,
)
from torch.optim import AdamW
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from common.path.model.bert import get_chinese_roberta_wwm_ext

import os
import time
import torch
import math
from torch.utils.data import DataLoader
from datasets import load_dataset

import evaluate
from tqdm import tqdm
from server.tag.common.data_helper import TagClassifyDataHelper

logger = logging.getLogger(__name__)



class BlogTitleAttractorBertConfig(object):
    max_length = 32
    batch_size = 128
    p_type = "prefix-tuning"
    model_name_or_path = get_chinese_roberta_wwm_ext()
    lr = 3e-4
    num_epochs = 50
    num_labels = 2
    device = "cuda"   # cuda/cpu
    evaluate_every = 100
    print_every = 10


def get_tokenizer(model_config):
    if any(k in model_config.model_name_or_path for k in ("gpt", "opt", "bloom")):
        padding_side = "left"
    else:
        padding_side = "right"
    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, padding_side=padding_side)
    if getattr(tokenizer, "pad_token_id") is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


def get_model(model_config, peft_config):
    model = AutoModelForSequenceClassification.from_pretrained(model_config.model_name_or_path, num_labels=model_config.num_labels)
    if model_config.p_type is not None:
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    return model


def get_peft_config(model_config):
    p_type = model_config.p_type
    if p_type == "prefix-tuning":
        peft_type = PeftType.PREFIX_TUNING
        peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=20)
    elif p_type == "prompt-tuning":
        peft_type = PeftType.PROMPT_TUNING
        peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=20)
    elif p_type == "p-tuning":
        peft_type = PeftType.P_TUNING
        peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=20, encoder_hidden_size=128)
    elif p_type == "lora":
        peft_type = PeftType.LORA
        peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)
    else:
        print(f"p_type:{p_type} is not supported.")
        return None, None
    logger.info(f"训练: {p_type} bert 模型")
    return peft_type, peft_config


def get_lr_scheduler(model_config, model, data_size):
    optimizer = AdamW(params=model.parameters(), lr=model_config.lr)

    # Instantiate scheduler
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0.06 * (data_size * model_config.num_epochs),
        num_training_steps=(data_size * model_config.num_epochs),
    )
    return optimizer, lr_scheduler



class BlogTtileAttractorBertTrain(object):
    def __init__(self, config, options) -> None:
        self.config = config
        self.options = options
        self.model_path = "./data/models/blog_title_attractor/prefix-tuning"
        self.data_file = './data/datasets/blogs/title_attractor/train.csv'
        self.tag_path = "./data/models/blog_title_attractor/prefix-tuning/tag.txt"
        self.model_config = BlogTitleAttractorBertConfig()
        self.tag_dict, self.id2tag_dict = TagClassifyDataHelper.load_tag(self.tag_path)
        self.model_config.num_labels = len(self.tag_dict)

        _, peft_config = get_peft_config(self.model_config)
        self.tokenizer = get_tokenizer(self.model_config)
        self.model = get_model(self.model_config, peft_config)
        self.model.to(self.model_config.device)

    def replace_none(self, example):
        example["title"] = example["title"] if example["title"] is not None else ""
        example["label"] = self.tag_dict[example["label"].lower()]
        return example

    def load_data(self):
        # 加载数据集
        dataset = load_dataset("csv", data_files=self.data_file)
        dataset = dataset.filter(lambda x: x["title"] is not None)
        dataset = dataset["train"].train_test_split(0.2, seed=123)
        dataset = dataset.map(self.replace_none)
        return dataset
    
    def process_function(self, examples):
        tokenized_examples = self.tokenizer(examples["title"], truncation=True, max_length=self.model_config.max_length)
        return tokenized_examples
    
    def collate_fn(self, examples):
        return self.tokenizer.pad(examples, padding="longest", return_tensors="pt")

    def __call__(self):
        datasets = self.load_data()
        tokenized_datasets = datasets.map(self.process_function, batched=True, remove_columns=['title'])
        tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
        train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=self.collate_fn, batch_size=self.model_config.batch_size)
        eval_dataloader = DataLoader(
            tokenized_datasets["test"], shuffle=False, collate_fn=self.collate_fn, batch_size=self.model_config.batch_size
        )
        optimizer, lr_scheduler = get_lr_scheduler(self.model_config, self.model, len(train_dataloader))
        metric = evaluate.load("accuracy")
        start_time = time.time()
        best_loss = math.inf
        for epoch in range(self.model_config.num_epochs):
            self.model.train()
            for step, batch in enumerate(train_dataloader):
                batch.to(self.model_config.device)
                outputs = self.model(**batch)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                if step > 0 and step % self.model_config.print_every == 0:
                    now_time = time.time()
                    predictions = outputs.logits.argmax(dim=-1)
                    eval_metric = metric.compute(predictions=predictions, references=batch["labels"])
                    if loss < best_loss:
                        best_loss = loss
                        self.model.save_pretrained(self.model_path)

                        print("***TRAIN, steps:{%d}, loss:{%.3f}, accuracy:{%.3f}, cost_time:{%.3f}h" % (
                            step, loss, eval_metric["accuracy"], (now_time - start_time) / 3600))
                    else:
                        print("TRAIN, steps:{%d}, loss:{%.3f}, accuracy:{%.3f}, cost_time:{%.3f}h" % (
                            step, loss, eval_metric["accuracy"], (now_time - start_time) / 3600))
                if step > 0 and step % self.model_config.evaluate_every == 0:
                    self.model.eval()
                    total_loss = 0.
                    for step, batch in enumerate(tqdm(eval_dataloader)):
                        batch.to(self.model_config.device)
                        with torch.no_grad():
                            outputs = self.model(**batch)
                            loss = outputs.loss
                            total_loss += loss
                        predictions = outputs.logits.argmax(dim=-1)
                        predictions, references = predictions, batch["labels"]
                        metric.add_batch(
                            predictions=predictions,
                            references=references,
                        )

                    eval_metric = metric.compute()
                    print("epoch:%d, VAL, loss:%.3f, accuracy:%.3f" % (epoch, total_loss, eval_metric["accuracy"]))

三、效果

Prefix Tuning Bert

在标题党上的准确率为: 0.9368932038834952
在非标题党上的准确率为: 0.937015503875969

Bert

在标题党上的准确率为: 0.7006472491909385
在非标题党上的准确率为: 0.9728682170542635

直接微调的效果与Prefix Tuning 相比差距有点大, 理论上不应该有这么大差距才对, 应该和数据集有关系

四、参阅

  • Prompt-Tuning——深度解读一种新的微调范式
  • Continuous Optimization:从Prefix-tuning到更强大的P-Tuning V2
  • [代码学习]Huggingface的peft库学习-part 1- prefix tuning

温馨提示: 标题党文章在推荐流会被降权哦

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/634443.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis之Redisson原理详解

文章目录 1 Redisson1.1 简介1.2 与其他客户端比较1.3 操作使用1.3.1 pom.xml1.3.2 配置1.3.3 启用分布式锁 1.4 大致操作原理1.5 RLock1.5.1 RLock如何加锁1.5.2 解锁消息1.5.3 锁续约1.5.4 流程概括 1.6 公平锁1.6.1 java中公平锁1.6.2 RedissonFairLock1.6.3 公平锁加锁步骤…

50 Projects 50 Days - Form Input Wave 学习记录

项目地址 Form Input Wave 展示效果 Form Input Wave 实现思路 简单的登陆界面结构&#xff0c;只是在输入框聚焦时标题提示文字会有一个字母逐渐向上跳动的动画效果&#xff0c;这需要针对每个字符单独设置变换的延时&#xff0c;可以考虑在JavaScript中处理这部分逻辑&am…

2017~2018学年《信息安全》考试试题(A1卷)

北京信息科技大学 2017 ~2018 学年第二学期《信息安全》考试试题 (A 卷) 课程所在学院:计算机学院 适用专业班级:计科 1504-6、重修 考试形式:闭卷 一、单选题(本题满分 20 分&#xff0c;共含 10 道小题&#xff0c;每小题 2 分) 网络安全是指网络系统的硬件、软件及( C )的…

【头歌-Python】Python第九章作业(初级)第5关

第5关&#xff1a;绘制程序设计语言饼图 任务描述 列表labels和sizes中的数据分别是目前主流程序设计语言及其热度数据&#xff08;百分比&#xff09;&#xff0c;请根据这些数据绘制饼图&#xff0c;并将Python程序设计语言所在区域突出0.1显示。 labels [C语言, Python…

Java ~ Reference ~ ReferenceQueue【总结】

前言 文章 相关系列&#xff1a;《Java ~ Reference【目录】》&#xff08;持续更新&#xff09;相关系列&#xff1a;《Java ~ Reference ~ ReferenceQueue【源码】》&#xff08;学习过程/多有漏误/仅作参考/不再更新&#xff09;相关系列&#xff1a;《Java ~ Reference ~ …

一篇就能学懂的散列表,让哈希表数据结构大放光彩

目录 1.散列表的基本概念 2.散列表的查找 3.散列函数的构造方法 1.直接定址法 2.除留余数法 4.散列表解决冲突的方法 1.开放定址法 2.链地址法 1.散列表的基本概念 基本思想&#xff1a;记录的存储位置与关键字之间存在的对应关系 对应关系——hash函数 Loc(i) H(k…

关于外包被开要怎么维护自己的权益

我一直以为外包被开都是没有任何赔偿的&#xff0c;之前网上对于外包的消息都是说&#xff0c;没有任何赔偿或者是怕麻烦然后就干脆放弃了的各种评论。。。但是最近我在问到一个朋友的时候&#xff0c;他很好的维护了自己的权益。最后获得了N1 保留证据 当被告知外包需要你离场…

牛客网语法篇刷题(C语言) — 运算

作者主页&#xff1a;paper jie的博客_CSDN博客-C语言,算法详解领域博主 本文作者&#xff1a;大家好&#xff0c;我是paper jie&#xff0c;感谢你阅读本文&#xff0c;欢迎一建三连哦。 本文录入于《C语言-语法篇》专栏&#xff0c;本专栏是针对于大学生&#xff0c;编程小白…

单链表OJ题:LeetCode--138.复制带随即指针的链表

朋友们、伙计们&#xff0c;我们又见面了&#xff0c;本期来给大家解读一下LeetCode中第138道单链表OJ题&#xff0c;如果看完之后对你有一定的启发&#xff0c;那么请留下你的三连&#xff0c;祝大家心想事成&#xff01; 数据结构与算法专栏&#xff1a;数据结构与算法 个 人…

python使用Faker库进行生成模拟mock数据(基本使用+五个小案例)

使用faker进行生成模拟(mock))数据 文章目录 使用faker进行生成模拟(mock))数据一、Faker库安装二、Faker库基本介绍三、案例1&#xff1a;Faker库生成核酸数据四、案例2&#xff1a;生成不重复的人名和地名五、案例3&#xff1a;生成有时间期限的低保数据六、案例4&#xff1a…

01-Vue 项目环境搭建和创建准备工作

一. 学习目标 掌握 Vue 项目创建的依赖环境掌握 Vue 项目创建过程 二. 学习内容 掌握搭建 Vue 项目准备环境掌握 Vue 项目创建过程了解 Vue 项目各子目录 三. 学习过程 1. 准备工作 &#xff08;1&#xff09;安装Node.js 打开node.js官网&#xff1a;Node.js &#xff0…

【状态估计】无迹卡尔曼滤波(UKF)应用于FitzHugh-Nagumo神经元动力学研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

详解ASP.NET Core 在 IIS 下的两种部署模式

KestrelServer最大的优势体现在它的跨平台的能力&#xff0c;如果ASP.NET CORE应用只需要部署在Windows环境下&#xff0c;IIS也是不错的选择。ASP.NET CORE应用针对IIS具有两种部署模式&#xff0c;它们都依赖于一个IIS针对ASP.NET CORE Core的扩展模块。 一、ASP.NET CORE C…

UML类图入门

UML类图入门 UML是一个通用的可视化建模描述语言&#xff0c;通过图形符号和文字来对系统进行建模。适用于各种软件的开发方法、生命周期的各个阶段。 类的UML图示 类使用包含类型、属性和操作&#xff08;方法&#xff09;且带有分割线的长方形来表示&#xff0c;如&#x…

人际关系的学习改进

表达的目的&#xff1a;让别人对你感兴趣 不要有苦劳而无功劳 爱的五种语言&#xff1a;表达爱的语言 人类存在的中心&#xff0c;是渴望和人亲近&#xff0c;被人所爱。婚姻即是被设计满足这种亲密关系和爱的需求的&#xff1b;把注意力集中在情绪健康所需的那片爱土上&…

【C++ 程序设计】第 4 章:运算符重载

目录 一、运算符重载的概念 &#xff08;1&#xff09;重载运算符的概念 ① 重载运算符的概念 ② 可重载的运算符 ③ 不可重载的运算符 ④ 运算符的优先级 &#xff08;2&#xff09;重载运算符为类的成员函数 &#xff08;3&#xff09;重载运算符为友元函数 &#…

【Linux】Docker部署镜像环境 (持续更新ing)

防火墙 1、查看防火墙状态 sudo systemctl status ufw 2、开启防火墙 sudo systemctl start ufw 3、关闭防火墙 sudo systemctl stop ufw 4、开机禁止开启防火墙 sudo systemctl disabled ufw 5、开启自启防火墙 sudo systemctl enabled ufw Elasticsearch 1、安装指定版本 比…

使用Pillow库轻松实现图像尺寸调整——>使每个图像具有相同的大小,方便模型处理和训练

在计算机视觉领域,对图像进行尺寸调整是一项非常常见的操作。在训练深度神经网络时,因为计算资源和内存限制的原因,我们通常需要将图像缩放到相同的尺寸。 在本文中,我们将介绍如何使用Python中的Pillow库对图像进行尺寸调整,并提供一个示例程序resize_images。 1. Pytho…

VulnHub靶场-Chronos

目录 0x01 声明&#xff1a; 0x02 简介&#xff1a; 0x03 环境准备&#xff1a; 0x04 信息收集&#xff1a; 1、主机发现 2、NMAP扫描 3、访问业务 4、目录探测 5、查看网页代码 0x05 渗透测试过程&#xff1a; 1、Burp Suite抓包 2、构造payload 测试是否可以外联…

CSS基础学习--5 background背景

一、介绍&#xff1a; CSS 背景属性用于定义HTML元素的背景。 CSS 属性定义背景效果: background-color 背景颜色background-image 背景图片background-repeatbackground-attachmentbackground-position 二、属性 2.1、background-color 属性定义了元素的背景颜色 <s…