Mindspore 公开课 - GPT

news2024/11/16 1:39:50

GPT Task

在模型 finetune 中,需要根据不同的下游任务来处理输入,主要的下游任务可分为以下四类:

  • 分类(Classification):给定一个输入文本,将其分为若干类别中的一类,如情感分类、新闻分类等;
  • 蕴含(Entailment):给定两个输入文本,判断它们之间是否存在蕴含关系(即一个文本是否可以从另一个文本中推断出来);
  • 相似度(Similarity):给定两个输入文本,计算它们之间的相似度得分;
  • 多项选择题(Multiple choice):给定一个问题和若干个答案选项,选择最佳的答案。

在这里插入图片描述
我们使用IMDb数据集,通过finetune GPT进行情感分类任务。

IMDb数据集是一个常用的情感分类数据集,其中包含50,000条影评文本,其中25,000条用作训练数据,另外25,000条用作测试数据。每个样本都有一个二元标签,表示影评的情感是正面还是负面。

import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn

from mindnlp import load_dataset
from mindnlp.transforms import PadTransform, GPTTokenizer

from mindnlp.engine import Trainer, Evaluator
from mindnlp.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp.metrics import Accuracy

数据预处理

# 加载数据集
imdb_train, imdb_test = load_dataset('imdb', shuffle=False)

通过load_dataset加载IMDb数据集后,我们需要对数据进行如下处理:

  • 将文本内容进行分词,并映射为对应的数字索引;
  • 统一序列长度:超过进行截断,不足通过占位符进行补全;
  • 按照分类任务的输入要求,在句首和句末分别添加Start与Extract占位符(此处用 <bos>与 <eos>表示);
  • 批处理。
import numpy as np

def process_dataset(dataset, tokenizer, max_seq_len=256, batch_size=32, shuffle=False):
    """数据集预处理"""
    def pad_sample(text):
        if len(text) + 2 >= max_seq_len:
            return np.concatenate(
                [np.array([tokenizer.bos_token_id]), text[: max_seq_len-2], np.array([tokenizer.eos_token_id])]
            )
        else:
            pad_len = max_seq_len - len(text) - 2
            return np.concatenate( 
                [np.array([tokenizer.bos_token_id]), text,
                 np.array([tokenizer.eos_token_id]),
                 np.array([tokenizer.pad_token_id] * pad_len)]
            )

    column_names = ["text", "label"]
    rename_columns = ["input_ids", "label"]

    if shuffle:
        dataset = dataset.shuffle(batch_size)

    dataset = dataset.map(operations=[tokenizer, pad_sample], input_columns="text")
    dataset = dataset.rename(input_columns=column_names, output_columns=rename_columns)
    dataset = dataset.batch(batch_size)

    return dataset

加载 GPT tokenizer,并添加上述使用到的 <bos>, <eos>, <pad>占位符。

gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')

special_tokens_dict = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

由于IMDb数据集本身不包含验证集,我们手动将其分割为训练和验证两部分,比例取0.7, 0.3。

imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

模型训练

同BERT课件中的情感分类任务实现,这里我们依旧使用了混合精度。另外需要注意的一点是,由于在前序数据处理中,我们添加了3个特殊占位符,所以在token embedding中需要调整词典的大小(vocab_size + 3)。

from mindnlp.models import GPTForSequenceClassification
from mindnlp._legacy.amp import auto_mixed_precision

model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)
model = auto_mixed_precision(model, 'O1')

loss = nn.CrossEntropyLoss()
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)

metric = Accuracy()

ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='sentiment_model', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=3, loss_fn=loss, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb], jit=True)

trainer.run(tgt_columns="label")

模型评估

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="label")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1387047.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

手写一个starter来理解SpringBoot的自动装配

自动装配以及简单的解析源码 自动装配是指SpringBoot在启动的时候会自动的将系统中所需要的依赖注入进Spring容器中 我们可以点开SpringBootApplication这个注解来一探究竟 点开这个注解可以发现这些 我们点开SpringBootConfiguration这个注解 可以发现实际上SpringBootApp…

【Python学习】Python学习20- 面向对象(1)

目录 【Python学习】Python学习20- 面向对象&#xff08;1&#xff09; 前言面向对象技术简介类的创建实例&#xff1a;创建实例对象访问属性 Python内置类属性完整代码输出 参考 文章所属专区 Python学习 前言 本章节主要说明Python的面向对象的处理。Python从设计之初就已经…

2024-01-05 C语言定义的函数名里面插入宏定义,对函数名进行封装,可以通过宏定义批量修改整个文件的函数名里面的内容

一、C语言定义的函数名里面插入宏定义&#xff0c;对函数名进行封装&#xff0c;可以通过宏定义批量修改整个文件的函数名里面的内容。使用下面的代码对函数进行封装&#xff0c;这样移植的时候可以根据包名和类名进行批量修改&#xff0c;不用一个函数一个函数的修改。。 #de…

2023年全球软件开发大会(QCon北京站2023)9月:核心内容与学习收获(附大会核心PPT下载)

随着科技的飞速发展&#xff0c;全球软件开发大会&#xff08;QCon&#xff09;作为行业领先的技术盛会&#xff0c;为世界各地的专业人士提供了交流与学习的平台。本次大会汇集了全球的软件开发者、架构师、项目经理等&#xff0c;共同探讨软件开发的最新趋势、技术与实践。本…

Linux下MySQL用户管理、权限、密码

一、原理 MySQL的用户管理实质上是对用户表的管理&#xff0c;系统中的数据库mysql存在一张用户表&#xff08;user&#xff09;&#xff0c;所有的用户都在该表内&#xff0c;对用户的管里也就是对该表进行增删查改的操作。 show databases; 如图中的mysql数据库&#xff0c;…

一致性协议浅析

Paxos 简介 Paxos 发明者是大名鼎鼎的 Lesile Lamport。Lamport 虚拟了一个叫做 Paxos 的希腊城邦&#xff0c;城邦按照议会民主制的政治模式制定法律。在 Lesile Lamport 的论文中&#xff0c;提出了 Basic Paxos、Multi Paxos、Fast Paxos 三种模型。 Basic Paxos 角色介绍…

网络安全等级保护测评规划与设计

笔者单位网络结构日益复杂&#xff0c;应用不断增多&#xff0c;使信息系统面临更多的风险。同时&#xff0c;网络攻防技术发展迅速&#xff0c;攻击的技术门槛随着自动化攻击工具的应用也在不断降低&#xff0c;勒索病毒等未知威胁也开始泛滥。基于此&#xff0c;笔者单位拟进…

c语言嵌套循环

c语言嵌套循环 c语言嵌套循环 c语言嵌套循环一、c语言嵌套循环类型二、嵌套循环案例九九惩罚口诀 一、c语言嵌套循环类型 for(初始值&#xff1b;表达式&#xff1b;表达式) {for&#xff08;初始值&#xff1b;表达式&#xff1b;表达式&#xff09;{代码} }int main() {for (…

报错消息号M3318:数字***没有定义对于物料类型原材料***

消息号M3318&#xff1a;数字***没有定义对于物料类型原材料*** 报错说明自定的物料编码&#xff0c;不在编码范围内&#xff0c;外部给号不在后台配置的范围内。MMNR可以查看后台定义的范围。 消息号&#xff1a;M3565对于物料类型&#xff0c;外部物料数一定不能只包含数量 …

获取本地IP网卡信息

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、获取本地IP&#xff0c;以及全部网卡信息总结 前言 一、获取本地IP&#xff0c;以及全部网卡信息 const os require(node:os) function getIPAdress(){/…

十三、Three场景物体增加发光特效

物体发光效果非常炫酷,本期来讲three场景内物体自带发光效果怎么来实现。本次使用的是threejs138版本,在vue3+vite+ant的项目中使用。 下面来看看实现的效果。绿色罐体有了明显的发光效果。 实现步骤 增加composer.js import { UnrealBloomPass } from three/examples/jsm/po…

C++ | 四、指针、链表

指针 指针用来储存地址定义方式&#xff0c;int *ptr;&#xff0c;使用*来表示所定义的变量是指针取地址符&#xff0c;ptr &a;&#xff0c;通过&来取得一个普通变量的地址&#xff0c;并储存到指针中取值&#xff08;解引用&#xff09;&#xff0c;想要取得一个指针…

软件测试|Beautiful Soup库详细使用指南

简介 Beautiful Soup是一款强大的Python库&#xff0c;广泛用于解析HTML和XML文档&#xff0c;从中提取数据并进行处理。它的灵活性和易用性使得数据抽取变得简单&#xff0c;本文将详细介绍Beautiful Soup库的基本用法和示例。 安装Beautiful Soup 首先&#xff0c;需要确保…

一文详解JAVA的字节流,BufferedReader和BufferedWriter

目录 一、什么是Java的字节流 二、BufferedReader介绍 三、BufferedWriter介绍 一、什么是Java的字节流 Java的字节流是一种用于处理二进制数据的输入输出流。在Java中&#xff0c;字节流以字节为单位进行读取和写入操作。字节流分为输入字节流和输出字节流。 输入字节流&…

springCloude中Eureka模拟搭建集群

开三个不同端口号的服务&#xff0c; 而且还得模拟出三个不同的ip&#xff0c;由于时本机&#xff0c;所以只能去做三个本地域名&#xff0c;不要乱来&#xff0c;弄不好会出事的! eureka8886.com eureka8887.com eureka8888.com 这个是eureka的集群模块。 提供模块&#xff0…

Redis图形界面闪退/错误2系统找不到指定文件/windows无法启动Redis/不是内部或外部命令,也不是可运行的程序

Redis图形界面闪退/错误2系统找不到指定文件/windows无法启动Redis/不是内部或外部命令&#xff0c;也不是可运行的程序 我遇到了以上的问题。 其实&#xff0c;最重要的原因是我打开不了another redis desktop mannager&#xff0c;就是我安装了之后&#xff0c;无法打开它…

FFmpeg解决视频播放加载卡顿问题(FFmpeg+M3U8分片)

FFmpeg解决视频播放加载卡顿问题(FFmpegM3U8分片) 在这静谧的时光里&#xff0c;我们能够更清晰地审视自己&#xff0c;思考未来的方向。每一步的坚实&#xff0c;都是对勇气的拥抱&#xff0c;每一个夜晚的努力&#xff0c;都是对未来的信仰。不要害怕独行&#xff0c;因为正是…

机器学习算法实战案例:时间序列数据最全的预处理方法总结

文章目录 1 缺失值处理1.1 统计缺失值1.2 删除缺失值1.3 指定值填充1.4 均值/中位数/众数填充1.5 前后项填充 2 异常值处理2.1 3σ原则分析2.2 箱型图分析 3 重复值处理3.1 重复值计数3.2 drop_duplicates重复值处理 3 数据归一化/标准化3.1 0-1标准化3.2 Z-score标准化 技术交…

使用Openssl生成Https免费证书以及Nginx配置

1 证书和私钥的生成 1.创建服务器证书密钥文件 server.key&#xff1a; openssl genrsa -des3 -out server.key 2048 输入密码&#xff0c;确认密码&#xff0c;自己随便定义&#xff0c;但是要记住&#xff0c;后面会用到。 2.创建服务器证书的申请文件 server.csr openssl r…

外包干了4年,废了···

有一说一&#xff0c;外包没有给很高的薪资&#xff0c;是真不能干呀&#xff01; 先说一下自己的情况&#xff0c;大专生&#xff0c;19年通过校招进入湖南某软件公司&#xff0c;干了接近4年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0…