昇思25天学习打卡营第12天|应用实践之基于MindSpore通过GPT实现情感分类

news2024/9/23 21:25:51

基本介绍

        今天的应用实践是基于MindSpore通过GPT实现情感分类,这与之前的使用BERT模型实现情绪分类有异曲同工之妙,本次使用的模型是OpenAI开源的GPT,数据集是MindNLP内置的数据集imdb。我们将会使用该数据集对GPT进行训练,然后进行测试。由于数据集是内置的数据集,可以直接进行加载即可,若本地没有该数据集,则会先下载,再加载到内存,具体代码如下:

imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']

数据集一般不会直接符合模型的输入,所以要对数据集进行预处理,主要预处理就是batch划分和Token化,处理完毕进行数据集划分即可。具体代码如下:

import numpy as np

def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'
    def tokenize(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['attention_mask']

    if shuffle:
        dataset = dataset.shuffle(batch_size)

    # map dataset
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})

    return dataset


from mindnlp.transformers import GPTTokenizer
# tokenizer
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')

# add sepcial token: <PAD>
special_tokens_dict = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])


dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

有了数据集,就需要模型,而模型是一个开源模型,MindNLP可以很方便加载该模型,加载了模型,配置训练相关参数,然后就可以训练模型了,具体代码如下:

# set bert config and define parameters for training
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)

metric = Accuracy()

# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_train, metrics=metric,
                  epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],
                  jit=False)

trainer.run(tgt_columns="labels")

GPT的数据量比GPT2少很多,训练+验证大概用了1个小时即可

训练完毕,可使用测试集进行测试,看模型效果,测试结果如下:

Jupyter运行情况

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1915439.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

鲁棒控制器设计方法:systune,hinfsyn,musyn,slTuner

systune和hinfsyn更侧重于基于数学模型的控制器设计&#xff0c;而musyn则特别考虑了系统的不确定性。slTuner则提供了在Simulink环境中进行控制器设计和调整的能力。 指定结构的控制器整定&#xff1a;systune, hinfstruct广义控制对象整定&#xff1a;musyn, mixed musyn, h…

LabVIEW自动测控与故障识别系统

使用LabVIEW 2019在Win10 64位系统上开发自动测控软件&#xff0c;通过与基恩士NR-X100数据采集仪通讯&#xff0c;实时采集和分析数据&#xff0c;自动识别判断产品是否合格&#xff0c;并增加数据记录和仿真功能。 具体解决方案&#xff1a; 1. 系统架构设计 硬件接口&#…

设计模式之工厂模式(简单工厂、工厂方法、抽象工厂)

写在前面&#xff1a;本文是个人在学习设计模式时的所思所想&#xff0c;汇总了其他博主及自己的感悟思考&#xff0c;可能存在出入&#xff0c;请大家理性食用~~ 工厂模式 在工厂模式中&#xff0c;父类决定实例的生成方式&#xff0c;但并不决定所要生成的具体的类&#xf…

带你了解“Java新特性——模块化”

Java平台从Java 8向Java 9及更高版本的进化&#xff0c;其中引入了一个重要的新特性——模块系统&#xff08;Project Jigsaw&#xff09;。模块系统的目的是解决大型应用的依赖管理问题&#xff0c;提升性能&#xff0c;简化JRE&#xff0c;增强兼容性和安全性&#xff0c;并提…

求整数数组的子集【C语言】

方法1&#xff1a;通过二进制位&#xff0c;因为n个整数数组的子集有2的n次方个&#xff0c;例如整数数组为{1,2,3},子集有2的3次方&#xff0c;8个&#xff1b; 期望的输出形式 其中需要了解关注的是 n&1判断最低位是否有数。如果一个子集为{2}&#xff0c;利用二进制位…

C++初阶:类与对象(一)

✨✨所属专栏&#xff1a;C✨✨ ✨✨作者主页&#xff1a;嶔某✨✨ 类的定义 定义格式 • class为定义类的关键字&#xff0c;后面跟类的名字&#xff0c;{}中为类的主体&#xff0c;注意类定义结束时后⾯分号不能省略。类体中内容称为类的成员&#xff1b;类中的变量称为类的…

2024最新PyCharm下载安装

&#xff08;1&#xff09;打开官网&#xff1a;https://www.jetbrains.com/ &#xff08;2&#xff09;点击pycharm &#xff08;3&#xff09;进入后点击下载按钮 &#xff08;4&#xff09;此时有两个选择&#xff1a;有专业版和社区版 PyCharm有专业版&#xff08;Prof…

zynq启动和程序固化流程

普通FPGA启动 FPGA的启动方式主要包含主动模式、被动模式和JTAG模式。 主动模式&#xff08;AS模式&#xff09; 当FPGA器件上电时&#xff0c;它作为控制器从配置器件EPCS中主动发出读取数据信号&#xff0c;并将EPCS的数据读入到自身中&#xff0c;实现对FPGA的编程。这种…

公众号运营秘籍:8 大策略让你的粉丝翻倍!

在当今信息爆炸的时代&#xff0c;微信公众号的运营者们面临着前所未有的挑战&#xff1a;如何在这个充满竞争的红海中脱颖而出&#xff0c;吸引并留住粉丝&#xff1f;事实上&#xff0c;微信公众号的红利期并未完全过去&#xff0c;关键在于我们如何策略性地运营&#xff0c;…

关于复现StableDiffusion相关项目时踩坑的记录

研究文生图也有了一段时间&#xff0c;复现的论文也算是不少&#xff0c;这篇博客主要记录我自己踩的坑。 目前实现文生图的项目主要分为两类&#xff1a; 一、基于Stable-diffusion原项目文件实现 原项目地址&#xff1a;https://github.com/Stability-AI/stablediffusion …

【自监督学习】DINO in ICCV 2021

一、引言 论文&#xff1a; DINO: Emerging Properties in Self-Supervised Vision Transformers 作者&#xff1a; Facebook AI Research 代码&#xff1a; DINO 特点&#xff1a; 对于一张图片&#xff0c;该方法首先进行全局和局部的裁剪与增强并分别送入教师和学生网络&am…

YOLOv10改进 | 图像去雾 | MB-TaylorFormer改善YOLOv10高分辨率和图像去雾检测(ICCV,全网独家首发)

一、本文介绍 本文给大家带来的改进机制是图像去雾MB-TaylorFormer&#xff0c;其发布于2023年的国际计算机视觉会议&#xff08;ICCV&#xff09;上&#xff0c;可以算是一遍比较权威的图像去雾网络&#xff0c; MB-TaylorFormer是一种为图像去雾设计的多分支高效Transformer…

WordPress PHP Everywhere <= 2.0.3 远程代码执行漏洞(CVE-2022-24663)

前言 CVE-2022-24663 是一个影响 WordPress 插件 PHP Everywhere 的远程代码执行&#xff08;RCE&#xff09;漏洞。PHP Everywhere 插件允许管理员在页面、文章、侧边栏或任何 Gutenberg 块中插入 PHP 代码&#xff0c;以显示基于评估的 PHP 表达式的动态内容。然而&#xff…

FreeCAD: 将STL格式文件转换为step格式文件的记录

首先我们需要下载开源的FreeCAD软件&#xff0c;官网链接如下&#xff1a; FreeCAD: Your own 3D parametric modeler 傻瓜式安装&#xff0c;跳过~ FreeCAD 是一款免费的开源CAD软件&#xff0c;支持多种文件格式转换&#xff0c;包括STL到STEP。 步骤&#xff1a; 打开Free…

PTrade常见问题系列7

获取可转债数据为空。 量化交易内&#xff0c;获取可转债标的行情&#xff0c;提示报错12319*.SZ不支持。 1、建议客户在研究内执行get_price&#xff0c;返回无数据&#xff1b; 2、怀疑asset.pk内不存在该可转债代码&#xff0c;再研究内执行import pandas as pd df pd.re…

前端使用pinia中存入的值

导入pinia,创建pinia实例 使用pinia中的值

Rust: 高性能序列化库Fury PK bincode

在序列化库中&#xff0c;传统的有Json,XML&#xff0c;性能好的有thrift&#xff0c;protobuf等。 对于二进制库来讲&#xff0c;据Fury官网的介绍&#xff0c;Fury性能要远远好于protobuf&#xff0c;且不象protobuf还需要定义IDL(即写.proto文件)&#xff0c;非常轻便&#…

数据库-ubuntu环境下安装配置mysql

文章目录 什么是数据库&#xff1f;一、ubuntu环境下安装mysql二、配置mysql配置文件1.先登上root账号2.配置文件的修改show engines \G; mysql和mysqld数据库的基础操作登录mysql创建数据库显示当前数据库使用数据库创建表插入students表数据打印students表数据select * from …

【ArcGIS 小技巧】为国空用地字段设置属性域,快速填充属性值并减少出错

属性域属性是描述字段类型可用值的规则。可用于约束表或要素类的任意特定属性中的允许值。——ArcGIS Pro 帮助文档 简单理解属性域&#xff1a;对于一个含义为性别的字段&#xff0c;我们一般会给的属性值有男、女两种。我们可以将这两种属性值制作成属性域并指定给该字段&…

05STM32EXIT外部中断中断系统

STM32EXIT外部中断&中断系统 中断系统中断触发条件&#xff1a;中断处理流程和用途&#xff1a; STM32中断NVIC嵌套中断向量控制器基本结构 中断系统 中断触发条件&#xff1a; 对外部中断来说&#xff0c;可以是引脚发生了电平跳变 对定时器来说&#xff0c;可以是定时的…