《昇思25天学习打卡营第24天|基于MindSpore通过GPT实现情感分类》

news2024/9/20 12:16:18

在这里插入图片描述

基于MindSpore通过GPT实现情感分类

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com
import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn

from mindnlp.dataset import load_dataset

from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']
imdb_train.get_dataset_size()
import numpy as np

def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'
    def tokenize(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['attention_mask']

    if shuffle:
        dataset = dataset.shuffle(batch_size)

    # map dataset
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})

    return dataset
from mindnlp.transformers import GPTTokenizer
# tokenizer
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')

# add sepcial token: <PAD>
special_tokens_dict = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)
# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])
dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)
next(dataset_train.create_tuple_iterator())
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam

# set bert config and define parameters for training
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)

metric = Accuracy()

# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_train, metrics=metric,
                  epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],
                  jit=False)
trainer.run(tgt_columns="labels")
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1952064.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

yolov5-7环境搭建训练自己的模型

1.下载代码 git clone https://github.com/ultralytics/yolov5 # clone可以切到5-7版本&#xff0c;也可以去github选标签下载 2.配置好conda环境&#xff0c;网上教程比较多&#xff0c;不做讲解&#xff0c;python3.8即可。 3.在环境里安装pyrtorch 按自己的需求选取&am…

每日一练,java05

目录 题目知识点&#xff1a;1.12.13.1 题目 选自牛客网 1.下列表述错误的是&#xff1f;&#xff08;&#xff09; A.int是基本类型&#xff0c;直接存数值&#xff0c;Integer是对象&#xff0c;用一个引用指向这个对象。 B.在子类构造方法中使用super()显示调用父类的构造…

第T6周:使用TensorFlow实现好莱坞明星识别

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 文章目录 一、前期工作1.设置GPU&#xff08;如果使用的是CPU可以忽略这步&#xff09;2. 导入数据3. 查看数据 二、数据预处理1、加载数据2、数据可视化3、再…

【论文速读】| LLMCloudHunter:利用大语言模型(LLMs)从基于云的网络威胁情报(CTI)中自动提取检测规则

本次分享论文&#xff1a;LLMCloudHunter: Harnessing LLMs for Automated Extraction of Detection Rules from Cloud-Based CTI 基本信息 原文作者&#xff1a;Yuval Schwartz, Lavi Benshimol, Dudu Mimran, Yuval Elovici, Asaf Shabtai 作者单位&#xff1a;Ben-Gurion…

mfc100u.dll 文件缺失?两种方法快速修复丢失mfc100u.dll 文件难题

您的电脑是否遭遇了 mfc100u.dll 文件缺失的问题&#xff1f;这种情况通常由多种原因引起。在本文中&#xff0c;我们将介绍两种修复 mfc100u.dll 文件丢失问题的策略——一种是手动方法&#xff0c;另一种是自动修复的使用。我们将探讨如何有效地解决 mfc100u.dll 文件缺失的几…

Linux下git入门操作

0.创建仓库 可以按这个配置来&#xff0c;.gitignore中存放了上传时忽略的文件类型后缀。 1.clone仓库 在gitee上创建好仓库&#xff0c;点击克隆/下载&#xff0c; 复制地址fyehong/Linux_notes 。 在所需的文件夹中放置仓库。比如我在文件夹lesson9下存储仓库。就在less…

Python爬虫技术 第18节 数据存储

Python 爬虫技术常用于从网页上抓取数据&#xff0c;并将这些数据存储起来以供进一步分析或使用。数据的存储方式多种多样&#xff0c;常见的包括文件存储和数据库存储。下面我将通过一个简单的示例来介绍如何使用 Python 爬取数据&#xff0c;并将其存储为 CSV 和 JSON 文件格…

【数据结构】二叉树链式结构——感受递归的暴力美学

前言&#xff1a; 在上篇文章【数据结构】二叉树——顺序结构——堆及其实现中&#xff0c;实现了二叉树的顺序结构&#xff0c;使用堆来实现了二叉树这样一个数据结构&#xff1b;现在就来实现而二叉树的链式结构。 一、链式结构 链式结构&#xff0c;使用链表来表示一颗二叉树…

【机器学习】解开反向传播算法的奥秘

&#x1f308;个人主页: 鑫宝Code &#x1f525;热门专栏: 闲话杂谈&#xff5c; 炫酷HTML | JavaScript基础 ​&#x1f4ab;个人格言: "如无必要&#xff0c;勿增实体" 文章目录 解开反向传播算法的奥秘反向传播算法的概述反向传播算法的数学推导1. 前向传播2…

3.k8s:服务发布:service,ingress;配置管理:configMap,secret,热更新;持久化存储:volumes,nfs,pv,pvc

目录​​​​​​​ 一、服务发布 1.service &#xff08;1&#xff09;service和pod之间的关系 &#xff08;2&#xff09; service内部服务创建访问 &#xff08;3&#xff09;service访问外部服务 &#xff08;4&#xff09;基于域名访问外部 &#xff08;5&#xff…

Prometheus各类监控及监控指标和告警规则

目录 linux docker监控 linux 系统进程监控 linux 系统os监控 windows 系统os监控 配置文件&告警规则 Prometheus配置文件 node_alert.rules docker_container.rules mysql_alert.rules vmware.rules Alertmanager告警规则 consoul注册服务 Dashboard JSON…

并发编程--volatile

1.什么是volatile volatile是 轻 量 级 的 synchronized&#xff0c;它在多 处 理器开 发 中保 证 了共享 变 量的 “ 可 见 性 ” 。可 见 性的意思是当一个 线 程 修改一个共享变 量 时 &#xff0c;另外一个 线 程能 读 到 这 个修改的 值 。如果 volatile 变 量修 饰 符使用…

车载录像机:移动安全领域的科技新星

随着科技的飞速发展&#xff0c;人类社会的各个领域都在不断经历技术革新。其中&#xff0c;车载录像机作为安防行业与汽车技术结合的产物&#xff0c;日益受到人们的关注。它不仅体现了人类科技发展的成果&#xff0c;更在安防领域发挥了重要作用。本文将详细介绍车载录像机的…

Spring Boot集成canal快速入门demo

1.什么是canal&#xff1f; canal 是阿里开源的一款 MySQL 数据库增量日志解析工具&#xff0c;提供增量数据订阅和消费。 工作原理 MySQL主备复制原理 MySQL master 将数据变更写入二进制日志&#xff08;binary log&#xff09;, 日志中的记录叫做二进制日志事件&#xff…

【QT】UDP

目录 核心API 示例&#xff1a;回显服务器 服务器端编写&#xff1a; 第一步&#xff1a;创建出socket对象 第二步&#xff1a; 连接信号槽 第三步&#xff1a;绑定端口号 第四步&#xff1a;编写信号槽所绑定方法 第五步&#xff1a;编写第四步中处理请求的方法 客户端…

Simulink代码生成: 基本模块的使用

文章目录 1 引言2 模块使用实例2.1 In/Out模块2.2 Constant模块2.3 Scope/Display模块2.4 Ground/Terminator模块 3 总结 1 引言 本文中博主介绍Simulink中最简单最基础的模块&#xff0c;包括In/Out模块&#xff08;输入输出&#xff09;&#xff0c;Constant模块&#xff08…

Postman测试工具详细解读

目录 一、Postman的基本概念二、Postman的主要功能1. 请求构建2. 响应查看3. 断言与自动化测试4. 环境与变量5. 集合与文档化6. 与团队实时协作 三、Postman在API测试中的重要性1. 提高测试效率2. 保障API的稳定性3. 促进团队协作4. 生成文档与交流工具 四、Postman的使用技巧1…

CAS算法

CAS算法 1. CAS简介 CAS叫做CompareAndSwap&#xff0c;比较并交换&#xff0c;主要是通过处理器的指令来保证操作的原子性。 CAS基本概念 内存位置 (V)&#xff1a;需要进行CAS操作的内存地址。预期原值 (A)&#xff1a;期望该内存位置上的旧值。新值 (B)&#xff1a;如果旧…

VSCode python autopep8 格式化 长度设置

ctrl, 打开设置 > 搜索autopep8 > 找到Autopep8:Args > 添加项--max-line-length150

Java泛型的介绍和基本使用

什么是泛型 ​ 泛型就是将类型参数化&#xff0c;比如定义了一个栈&#xff0c;你必须在定义之前声明这个栈中存放的数据的类型&#xff0c;是int也好是double或者其他的引用数据类型也好&#xff0c;定义好了之后这个栈就无法用来存放其他类型的数据。如果这时候我们想要使用这…