昇思25天学习打卡营第11天|基于MindSpore通过GPT实现情感分类

news2024/11/23 11:58:05

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)

基于MindSpore通过GPT实现情感分类

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com
import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn

from mindnlp.dataset import load_dataset

from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']
imdb_train.get_dataset_size()

加载IMDB数据集。将IMDB数据集分为训练集和测试集。IMDB (Internet Movie Database) 数据集包含来自著名在线电影数据库 IMDB 的电影评论。每条评论都被标注为正面(positive)或负面(negative),因此该数据集是一个二分类问题,也就是情感分类问题。

import numpy as np

def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'
    def tokenize(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['attention_mask']

    if shuffle:
        dataset = dataset.shuffle(batch_size)

    # map dataset
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})

    return dataset

定义数据预处理函数。这个函数输入参数为数据集、分词器(GPT Tokenizer)以及一些可选参数,如最大序列长度、批量大小和是否打乱数据。预处理包括将文本转换为模型可以理解的输入格式(如input_ids和attention_mask),并将标签转换为整数类型。

from mindnlp.transformers import GPTTokenizer
# tokenizer
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')

# add sepcial token: <PAD>
special_tokens_dict = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

加载GPT分词器并增加特殊标记。

# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

将训练集划分为训练集和验证集。

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

用 process_dataset 函数对训练集、验证集和测试集进行处理,得到相应的数据集对象。

next(dataset_train.create_tuple_iterator())
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam

# set bert config and define parameters for training
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)

metric = Accuracy()

# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_train, metrics=metric,
                  epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],
                  jit=False)

导入 GPTForSequenceClassification 模型和 Adam 优化器。设置GPT模型的配置信息,包括pad_token_id和词汇表大小。使用Adam优化器对模型的可训练参数进行优化(从这里没有看出是更新部分参数,还是全部参数,有可能是部分参数。通常会改变最后一层分类器的权重和偏置,其他层的权重被冻结不变或者只微小更新些许参数。)。

Accuracy作为评价指标。

定义回调函数用于保存检查点:

   - CheckpointCallback:用于定期保存模型权重,save_path 指定了保存路径,ckpt_name保存文件的前缀,epochs=1 每个epoch保存一次,keep_checkpoint_max=2 表示最多保留2个检查点文件。
   - BestModelCallback:用于保存验证集上表现最好的模型,auto_load=True表示在训练结束后自动加载最优模型的权重。

创建 Trainer 对象,传入以下参数:
      - network:要训练的模型。
      - train_dataset:训练数据集。
      - eval_dataset:验证数据集。
      - metrics:评估指标。
      - epochs:训练轮数。
      - optimizer:优化器。
      - callbacks:回调函数列表,包括检查点保存和最佳模型保存。
      - jit:是否启用JIT编译,这里设置为False。

trainer.run(tgt_columns="labels")

通过 Trainer 的 run 方法启动训练,指定了训练过程中的目标标签列为 "labels"。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

创建 Evaluator 对象,传入以下参数:
      - network:要评估的模型。
      - eval_dataset:测试数据集。
      - metrics:评估指标。

用MindSpore通过GPT实现情感分类(Sentiment Classification)的示例。首先加载了IMDB影评数据集,并将其划分为训练集、验证集和测试集。然后使用GPTTokenizer对文本进行了标记化和转换。接下来,使用GPTForSequenceClassification构建了情感分类模型,并定义了优化器和评估指标。使用Trainer进行模型的训练,并设置了保存检查点的回调函数。训练完成后,通过Evaluator对测试集进行评估,输出分类准确率。通过对IMDB影评数据集进行训练和评估,模型可以自动进行情感分类,识别出正面或负面情感。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1890090.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【深海王国】小学生都能玩的语音模块?ASRPRO打造你的第一个智能语音助手(4)

Hi~ (o^^o)♪, 各位深海王国的同志们&#xff0c;早上下午晚上凌晨好呀~ 辛勤工作的你今天也辛苦啦(/≧ω) 今天大都督继续为大家带来系列——小学生都能玩的语音模块&#xff0c;帮你一周内快速学会语音模块的使用方式&#xff0c;打造一个可用于智能家居、物联网领域的语音助…

01 Docker 概述

目录 1.Docker简介 2.传统虚拟机 vs 容器 3.Docker运行速度快的原因 4.Docker基本组成三要素 5.Docker 平台架构 入门版 架构版 1.Docker简介 Docker是基于Go语言实现的云开源项目。 Docker的主要目标是&#xff1a;Build, Ship and Run Any App, Anywhere&#xff0c…

抖音常用的视频剪辑软件有哪些,变速视频如何制作?

抖音是一款当下流行的短视频软件。很多人都想在上面发表自己的作品&#xff0c;但是也还有人因为不会剪辑&#xff0c;找不到合适的视频制作软件&#xff0c;一直没能行动。今天就为大家解答抖音常用的制作视频软件有哪些&#xff0c;如何调整抖音制作视频的速度。 希望大家看完…

AzureDataFactory 实体间的关联如何处理(Lookup)

使用ADF从外部数据源(例如Sql Server)往D365推数时&#xff0c;实体间的Lookup一定是要做的&#xff0c;本篇以我项目中的设备为例&#xff0c;设备表中有产品的lookup字段 设备表结构如下 msdyn_customerasset 表名ID 设备表guidSerialNumber设备序列号ProductCode设备对应的…

Hadoop3:NameNode和DataNode多目录配置(扩充磁盘的技术支持)

一、NameNode多目录 1、说明 NameNode多目录&#xff0c;需要在刚搭建Hadoop集群的时候&#xff0c;就配置好 因为&#xff0c;配置这个&#xff0c;需要格式化NameNode 所以&#xff0c;如果一开始没配置NameNode多目录&#xff0c;后面&#xff0c;就不要配置了。 2、配置…

Linux环境下的字节对齐现象

在Linux环境下&#xff0c;字节对齐是指数据在内存中的存储方式。字节对齐是为了提高内存访问的效率和性能。 在Linux中&#xff0c;默认情况下&#xff0c;结构体和数组的成员会进行字节对齐。具体的对齐方式可以通过编译器选项来控制。 在使用C语言编写程序时&#xff0c;可…

技术市集 | 如何通过WSL 2在Windows上挂载Linux磁盘?

你是否常常苦恼&#xff0c;为了传输或者共享不同系统的文件需要频繁地在 Windows 和 Linux 系统之间切换&#xff0c;既耽误工作效率&#xff0c;也容易出错。 那么有没有一种办法&#xff0c;能够让你在Windows系统中像访问本地硬盘一样来操作Linux系统中的文件呢&#xff1…

jni原理和实现

一、jni原理 主要就是通过数据类型签名和反射来实现java与c/c方法进行交互的 数据类型签名对应表 javac/cbooleanZbyteBcharCshortSintIlongLfloatFdoubleDvoidVobjectL开头&#xff0c;然后以/分割包的完整类型&#xff0c;后面再加; 比如String的签名就是Ljava/long/Strin…

基于jeecgboot-vue3的Flowable流程-集成仿钉钉流程(一)一些样式的调整使用

因为这个项目license问题无法开源&#xff0c;更多技术支持与服务请加入我的知识星球。 1、比如下面的发起人双击后出现的界面不正常&#xff0c; 看它的样式主要是这个里面的margin-left应该太小了&#xff0c; [data-v-45b533d5] .el-tabs__content { margin-top: 50px;mar…

实用麦克风话筒音频放大器电路设计和电路图

设计目标 输入电压最大值输出电压最大值电源Vcc电源Vee频率响应偏差20Hz频率响应偏差20kHz100dB SPL(2Pa)1.228Vrms5V0V–0.5dB–0.1dB 设计说明 此电路使用跨阻抗放大器配置中的运算放大器将驻极体炭精盒麦克风的输出电流转换为输出电压。此电路的共模电压是固定的&#xf…

第15届蓝桥杯Python青少组选拔赛(STEMA)2023年8月真题-附答案

第15届蓝桥杯Python青少组选拔赛&#xff08;STEMA&#xff09;2023年8月真题 题目总数&#xff1a; 11 总分数&#xff1a; 400 一、单选题 第 1 题 单选题 以下不符合 Python 语言变量命名规则的是&#xff08; &#xff09;。 A. k B. 2_k C. _k D. ok 答案 B …

cesium方案论证实现功能

仓库地址&#xff1a;Harvey-Andrew 演示地址&#xff1a;哔哩哔哩-满分观察网友z 文章目录 1. 场景加载2. 3D 模型2.1. 坐标转换2.2. 放置模型2.3. 调整模型2.4. 提交方案 3. 查看方案3.1. 场景还原3.2. 删除 1. 场景加载 加载Cesium的Melbourne Photogrammetry的倾斜摄影作…

【Threejs进阶教程-着色器篇】1. Shader入门(ShadertoyShader和ThreejsShader入门)

ThreejsShader入门 关于本Shader教程认识ShaderShader和Threejs的关系WebGLShaderThreejsShaderShadertoyShader其他Shader 再次劝退数学不好的人从ShaderToy开始Shader的代码是强类型glsl的类型&#xff0c;变量&#xff0c;内置函数&#xff0c;关键字关于uv基于UV的颜色处理…

Linux——高级IO

目录 IO 五种IO模型 阻塞式IO 非阻塞式IO 信号驱动IO 多路转接 异步IO 阻塞IO VS 非阻塞IO IO 网络的知识我们已经介绍完了&#xff0c;网络通信的本质就是IO&#xff0c;一方要发送数据&#xff0c;还要接收数据&#xff0c;这就是一次IO&#xff0c;所以我们原来说过…

解决VSCode中导入PyTorch时报错的HTTP错误与Channel冲突

问题描述与解释 在Anaconda中成功安装PyTorch&#xff0c;并进行了验证&#xff1a; (base) C:\Users\Hui>conda activate pytorch(pytorch) C:\Users\\Hui>python Python 3.8.19 (default, Mar 20 2024, 19:55:45) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on …

安装windows服务,细节

1、选中服务代码&#xff0c;右键添加安装程序。 2、安装程序的权限一定改为local,否则安装时会提示null错误。 3、安装服务 InstallUtil D:\vs2022work\testFW\testFW\bin\Debug\testFW.exe p:InstallUtil 需要新建环境变量才能直接使用&#xff08;找到InstallUtil 工具所在…

沃尔核材:价值重估

当英伟达这个曾经的GPU行业龙头&#xff0c;伴随AI的发展成为AI芯片架构的供应商时&#xff0c;他就跳出了原本行业的竞争格局&#xff0c;曾经还能与之一战的超威半导体被远远甩在身后&#xff0c;成为宇宙第一公司。 这说的就是一家公司价值的重估。今天给大家聊的也是这样一…

【C++】相机标定源码笔记- RGB 相机与 ToF 深度传感器校准类

类的设计目标是为了实现 RGB 相机与 ToF 深度传感器之间的高精度校准&#xff0c;从而使两种类型的数据能够在同一个坐标框架内被整合使用。这在很多场景下都是非常有用的&#xff0c;比如在3D重建、增强现实、机器人导航等应用中&#xff0c;能够提供更丰富的场景信息。 -----…

学习笔记(linux高级编程)11

进程间通信 》信号通信 应用&#xff1a;异步通信。 中断&#xff0c;&#xff0c; 1~64&#xff1b;32应用编程。 如何响应&#xff1a; Term Default action is to terminate the process. Ign Default action is to ignore the signal. wait Core Default action is …

Eclipse运行main函数报 launch error

右键run as java application&#xff0c;运行main函数的时候报launch error 解决方式&#xff1a;文件右键run configurations 旧的是Project JRE&#xff0c;改成下图这个样子