昇思大模型平台打卡体验活动:项目4基于MindSpore实现Roberta模型Prompt Tuning

news2024/12/27 10:55:11

基于MindNLP的Roberta模型Prompt Tuning

本文档介绍了如何基于MindNLP进行Roberta模型的Prompt Tuning,主要用于GLUE基准数据集的微调。本文提供了完整的代码示例以及详细的步骤说明,便于理解和复现实验。

环境配置

在运行此代码前,请确保MindNLP库已经安装。本文档基于大模型平台运行,因此需要进行适当的环境配置,确保代码可以在相应的平台上运行。

模型与数据集加载

在本案例中,我们使用 roberta-large 模型并基于GLUE基准数据集进行Prompt Tuning。GLUE (General Language Understanding Evaluation) 是自然语言处理中的标准评估基准,包括多个子任务,如句子相似性匹配、自然语言推理等。Prompt Tuning是一种新的微调技术,通过插入虚拟的“提示”Token在模型的输入中,以微调较少的参数达到较好的性能。

import mindspore
from tqdm import tqdm
from mindnlp import evaluate
from mindnlp.dataset import load_dataset
from mindnlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
from mindnlp.core.optim import AdamW
from mindnlp.transformers.optimization import get_linear_schedule_with_warmup
from mindnlp.peft import (
    get_peft_model,
    PeftType,
    PromptTuningConfig,
)

1. 定义训练参数

首先,定义模型名称、数据集任务名称、Prompt Tuning类型、训练轮数等基本参数。

batch_size = 32
model_name_or_path = "roberta-large"
task = "mrpc"
peft_type = PeftType.PROMPT_TUNING
num_epochs = 20

2. 配置Prompt Tuning

在Prompt Tuning的配置中,选择任务类型为"SEQ_CLS"(序列分类任务),并定义虚拟Token的数量。虚拟Token即为插入模型输入中的“提示”Token,通过这些Token的微调,使得模型能够更好地完成下游任务。

peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
lr = 1e-3

3. 加载Tokenizer

根据模型类型选择padding的侧边,如果模型为GPT、OPT或BLOOM类模型,则从序列左侧填充(padding),否则从序列右侧填充。

if any(k in model_name_or_path for k in ("gpt", "opt", "bloom")):
    padding_side = "left"
else:
    padding_side = "right"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)
if getattr(tokenizer, "pad_token_id") is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

4. 加载数据集

通过MindNLP加载GLUE数据集,并打印样本以便确认数据格式。在此示例中,我们使用GLUE的MRPC(Microsoft Research Paraphrase Corpus)任务,该任务用于句子匹配,即判断两个句子是否表达相同的意思。

datasets = load_dataset("glue", task)
print(next(datasets['train'].create_dict_iterator()))

5. 数据预处理

为了适配MindNLP的数据处理流程,我们定义了一个映射函数 MapFunc,用于将句子转换为 input_idsattention_mask,并对数据进行padding处理。

from mindnlp.dataset import BaseMapFunction

class MapFunc(BaseMapFunction):
    def __call__(self, sentence1, sentence2, label, idx):
        outputs = tokenizer(sentence1, sentence2, truncation=True, max_length=None)
        return outputs['input_ids'], outputs['attention_mask'], label

def get_dataset(dataset, tokenizer):
    input_colums=['sentence1', 'sentence2', 'label', 'idx']
    output_columns=['input_ids', 'attention_mask', 'labels']
    dataset = dataset.map(MapFunc(input_colums, output_columns),
                          input_colums, output_columns)
    dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                         'attention_mask': (None, 0)})
    return dataset

train_dataset = get_dataset(datasets['train'], tokenizer)
eval_dataset = get_dataset(datasets['validation'], tokenizer)

6. 设置评估指标

我们使用 evaluate 模块加载评估指标(accuracy 和 F1-score)来评估模型的性能。

metric = evaluate.load("./glue.py", task)

7. 加载模型并配置Prompt Tuning

加载 roberta-large 模型,并根据配置进行Prompt Tuning。可以看到,微调的参数量仅为总参数量的0.3%左右,节省了大量计算资源。

model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

模型微调(Prompt Tuning)

在Prompt Tuning中,训练过程中仅微调部分参数(主要是虚拟Token相关的参数),相比于传统微调而言,大大减少了需要调整的参数量,使得模型能够高效适应下游任务。

1. 优化器与学习率调整

使用 AdamW 优化器,并设置线性学习率调整策略。

optimizer = AdamW(params=model.parameters(), lr=lr)

# Instantiate scheduler
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0.06 * (len(train_dataset) * num_epochs),
    num_training_steps=(len(train_dataset) * num_epochs),
)

2. 训练逻辑定义

训练步骤如下:

  1. 构建正向计算函数 forward_fn
  2. 定义梯度计算函数 grad_fn
  3. 定义每一步的训练逻辑 train_step
  4. 遍历数据集进行训练和评估,在每个 epoch 结束时,计算评估指标。
def forward_fn(**batch):
    outputs = model(**batch)
    loss = outputs.loss
    return loss

grad_fn = mindspore.value_and_grad(forward_fn, None, tuple(model.parameters()))

def train_step(**batch):
    loss, grads = grad_fn(**batch)
    optimizer.step(grads)
    return loss

for epoch in range(num_epochs):
    model.set_train()
    train_total_size = train_dataset.get_dataset_size()
    for step, batch in enumerate(tqdm(train_dataset.create_dict_iterator(), total=train_total_size)):
        loss = train_step(**batch)
        lr_scheduler.step()

    model.set_train(False)
    eval_total_size = eval_dataset.get_dataset_size()
    for step, batch in enumerate(tqdm(eval_dataset.create_dict_iterator(), total=eval_total_size)):
        outputs = model(**batch)
        predictions = outputs.logits.argmax(axis=-1)
        predictions, references = predictions, batch["labels"]
        metric.add_batch(
            predictions=predictions,
            references=references,
        )

    eval_metric = metric.compute()
    print(f"epoch {epoch}:", eval_metric)

在每个 epoch 后,程序输出当前模型的评估指标(accuracy 和 F1-score)。从结果中可以看到,模型的准确率和 F1-score 会随着训练的进展逐渐提升。
7797b4532920b53cb41371e07cfa81c6.png
7797b4532920b53cb41371e07cfa81c6.png

总结

本案例通过Prompt Tuning技术,在Roberta模型上进行了微调以适应GLUE数据集任务。通过控制微调参数量,Prompt Tuning展示了较强的高效性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2238254.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

后悔没早点知道,Coze 插件 + Cursor 原来可以这样赚钱

最近智能体定制化赛道异常火爆。 打开闲鱼搜索"Coze 定制",密密麻麻的服务报价直接刷屏,即使表明看起来几十块的商家,一细聊,都是几百到上千不等的报价。 有趣的是,这些智能体定制化服务背后,最核心的不只是工作流设计,还有一个被很多人忽视的重要角色 —— …

基于STM32的节能型路灯控制系统设计

引言 本项目基于STM32微控制器设计了一个智能节能型路灯控制系统,通过集成多个传感器模块和控制设备,实现对路灯的自动调节。该系统能够根据周围环境光照强度、车辆和行人活动等情况,自动控制路灯的开关及亮度调节,从而有效减少能…

Qml 模型-视图-代理(贰)之 动态视图学习

目录 动态视图 动态视图用法 ⽅向(Orientation) 键盘导航和⾼亮 页眉与页脚 网格视图 动态视图 动态视图用法 Repeater 元素适合有限的静态数据, QtQuick 提供了 ListView 和 GridView, 这两个都是基于 Flickable(可滑动) 区域的元素…

新标准大学英语综合教程1课后习题答案PDF第三版

《新标准大学英语(第三版)综合教程1 》是“新标准大学英语(第三版)”系列教材之一。本书共包含6个单元,从难度和话题上贴近大一上学生的认知和语言水平,包括与学生个人生活领域和社会文化等相关内容&#x…

Python闭包|你应该知道的常见用例(下)

引言 在 Python 编程语言中,闭包通常指的是一个嵌套函数,即在一个函数内部定义的另一个函数。这个嵌套的函数能够访问并保留其外部函数作用域中的变量。这种结构就构成了一个闭包。 闭包在函数式编程语言中非常普遍。在 Python 中,闭包特别有…

Rocky、Almalinux、CentOS、Ubuntu和Debian系统初始化脚本v9版

Rocky、Almalinux、CentOS、Ubuntu和Debian系统初始化脚本 Shell脚本源码地址: Gitee:https://gitee.com/raymond9/shell Github:https://github.com/raymond999999/shell脚本可以去上面的Gitee或Github代码仓库拉取。 支持的功能和系统&am…

AUTOSAR OS模块详解(一) 概述

AUTOSAR OS模块详解(一) 概述 本文主要介绍AUTOSAR架构下的OS概述。 文章目录 AUTOSAR OS模块详解(一) 概述1 前言1.1 操作系统1.2 嵌入式操作系统1.3 AUTOSAR操作系统 2 AUTOSAR OS2.1 AUTOSAR OS组成2.2 AUTOSAR OS类别2.3 任务管理2.4 调度表2.5 资源管理2.6 多核特性2.7 …

5位机械工程师如何共享一台工作站的算力?

在现代化的工程领域中,算力已成为推动创新与技术进步的关键因素之一。对于机械工程师而言,强大的计算资源意味着能够更快地进行复杂设计、模拟分析以及优化工作,从而明显提升工作效率与项目质量。然而,资源总是有限的,…

Scala 中 set 的实战应用 :图书管理系统

1. 创建书籍集合 首先,我们创建一个可变的书籍集合,用于存储图书馆中的书籍信息。在Scala中,mutable.Set可以用来创建一个可变的集合。 val books mutable.Set("朝花惜拾", "活着") 2. 添加书籍 我们可以使用操作符…

DevCheck Pro手机硬件检测工具v5.33

前言 DevCheck Pro是一款手机硬件和操作系统信息检测查看工具,该软件的功能非常强大,为用户提供了系统、硬件、应用程序、相机、网络、电池等一系列信息查看功能 安装环境 [名称]:DevCheckPro [版本]:5.33 [大小]&a…

cv::intersectConvexConvex返回其中一个输入点集,两个点集不相交

问题:cv::intersectConvexConvex返回其中一个输入点集,但两个点集并不相交 版本:opencv 3.1.0 git上也有人反馈了intersectConvexConvex sometimes returning one of the input polygons in case of empty intersection #10044 是凸包嵌套判…

【刷题12】ctfshow刷题

来源:ctfshow easyPytHon_P 考点:代码审计,源代码查看 打开后查看源码,发现一个源码地址,打开看看 可以知道在此目录下有个flag.txt文件,再观察源码 from flask import request cmd: str request.form.get…

spark的学习-03

RDD的创建的两种方式: 方式一:并行化一个已存在的集合 方法:parallelize 并行的意思 将一个集合转换为RDD 方式二:读取外部共享存储系统 方法:textFile、wholeTextFile、newAPIHadoopRDD等 读取外部存储系统的数…

axios平替!用浏览器自带的fetch处理AJAX(兼容表单/JSON/文件上传)

fetch 是啥? fetch 函数是 JavaScript 中用于发送网络请求的内置 API,可以替代传统的 XMLHttpRequest。它可以发送 HTTP 请求(如 GET、POST 等),并返回一个 Promise,从而简化异步操作 基本用法 /* 下面是…

Linux(CentOS)安装 Nginx

CentOS版本:CentOS 7 Nginx版本:1.24.0 有两种安装方式 一、通过 yum 安装 需要 root 权限,普通用户使用 sudo 进行命令操作 参考:https://nginx.org/en/linux_packages.html#RHEL 1、安装依赖 sudo yum install yum-utils 2…

[原创]手把手教学之前端0基础到就业——day11( Javascript )

文章目录 day11(Javascript)01Javascript①Javascript是什么②JavaScript组成③ Javascript的书写位置1. 行内式 (不推荐)2 . 内部位置使用 ( 内嵌式 )3. 外部位置使用 ( 外链式 ) 02变量1. 什么是变量2. 定义变量及赋值3. 注意事项4. 命名规范 03输入和输出1) 输出形式12) 输出…

【C++笔记】C++三大特性之继承

【C笔记】C三大特性之继承 🔥个人主页:大白的编程日记 🔥专栏:C笔记 文章目录 【C笔记】C三大特性之继承前言一.继承的概念及定义1.1 继承的概念1.2继承的定义1.3继承基类成员访问方式的变化1.4继承类模板 二.基类和派生类间的转…

Colorful/七彩虹iGame G-ONE Plus 12代处理器 Win11原厂OEM系统 带COLORFUL一键还原

安装完毕自带原厂驱动和预装软件以及一键恢复功能,自动重建COLORFUL RECOVERY功能,恢复到新机开箱状态。 【格式】:iso 【系统类型】:Windows11 原厂系统下载网址:http://www.bioxt.cn 注意:安装系统会…

【LeetCode】分发糖果 解题报告

135. 分发糖果 - 题目链接 n个孩子站成一排。给你一个整数数组ratings表示每个孩子的评分。 你需要按照以下要求,给这些孩子分发糖果: 每个孩子至少分配到1个糖果。相邻两个孩子评分更高的孩子会获得更多的糖果。 请你给每个孩子分发糖果,…

ArcGIS从Excel表格文件导入XY数据并定义坐标系与投影的方法

本文介绍在ArcMap软件中,从Excel表格文件中批量导入坐标点数据,将其保存为.shp矢量格式,并定义坐标系、转为投影坐标系的方法。 已知我们有一个Excel表格文件(可以是.xls、.xlsx、.csv等多种不同的表格文件格式)&#…