昇思25天学习打卡营第13天|LLM-基于MindSpore实现的GPT对话情绪识别

news2024/9/19 17:15:46

打卡

目录

打卡

预装环境

流程简述

部分执行结果演示

词向量加载过程

模型结构

模型训练过程

模型预测过程

代码


预装环境

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
pip install mindnlp
pip install jieba 
pip install spacy
pip install ftfy 

环境变量设置:HF_ENDPOINT=https://hf-mirror.com

流程简述

任务:用IMDB开源标注数据集,微调开源的预训练模型GPT,实现对话情绪识别。

1、数据集准备:IMDB数据集,从 https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz 下载数据集并按照7:3切分为训练和验证集。

2、加载TOKEN:用 mindnlp.transformers.GPTTokenizer 加载 tokenizer,并为其添加3个特殊的TOKEN("bos_token"、"eos_token"、"pad_token")

3、预处理训练、验证、测试数据集,包括将文本数据进行tokenizer,并根据设备类型对数据进行批处理和填充,其中训练集打散。

4、预训练模型微调设置:

  1. 用 mindnlp.transformers.GPTForSequenceClassification 加载预训练的 'openai-gpt' 模型,用于序列分类,配置指定模型的输出标签数量为2(通常是二分类任务)。
  2. 基于第二个步骤的 tokenzier,为预训练模型配置填充(padding)token ID。
  3.  为预训练模型配置调整token嵌入层的尺寸(+3,因为第二个步骤手动添加了3个特殊的TOKEN)。
  4. 定义模型优化器为 nn.Adam ,用于在训练过程中更新模型的参数,学习率设置为2e-5。
  5. 定义了一个准确率指标 ( metric=mindnlp._legacy.metrics.Accuracy() ),用于评估模型的性能。
  6. 定义2个回调函数,一个用于保存每个epoch的模型检查点,另一个用于保存最佳模型。

5、开始训练:创建训练器 (mindnlp._legacy.engine.Trainer)并训练,该训练器可以接收模型、训练数据集、评估数据集、评估指标、训练轮数、优化器、回调函数列表以及是否启用JIT编译的选项。

6、创建评估器并评估模型:创建评估器(mindnlp._legacy.engine.Evaluator),用于在测试数据集dataset_test上评估模型的性能。评估器使用了之前定义的预训练模型和评估指标metric

部分执行结果演示

词向量加载过程

看到词表大小为 40478,模型维度长512,右侧截断,一共有4种特殊的token.

模型结构

模型训练过程

loss降低到了0.2599,精度达到了 0.9421 。一般水平。 

模型预测过程

代码

import os
import numpy as np
import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindnlp.dataset import load_dataset
from mindnlp.transformers import GPTTokenizer
from mindspore import nn
from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam


def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):
    """
    dataset: 待处理的数据集。
    tokenizer: 用于将文本转换为token的tokenizer对象.
    max_seq_len: 文本序列的最大长度,默认为512。
    batch_size: 批处理的大小,默认为4。
    shuffle: 是否对数据集进行随机打乱,默认为False。
    """
    ## 判断当前设备目标是否为Ascend(华为的昇腾处理器)。如果是,则is_ascend为True。
    is_ascend = mindspore.get_context('device_target') == 'Ascend'
    def tokenize(text):
        # 定义了一个内部函数tokenize,用于将文本转换为tokens。
        # 根据is_ascend的值来决定是否启用填充策略padding。函数返回token的input_ids和attention_mask。
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['attention_mask']

    if shuffle:
        ## shuffle参数为True,则对数据集进行打乱 
        dataset = dataset.shuffle(batch_size)

    # map dataset
    ## 用map操作对数据集中的每个文本进行tokenization处理,将文本列text映射为input_ids和attention_mask。
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])
    ## 将标签列label的数据类型转换为MindSpore的int32类型,并重命名为labels。
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    #
    # batch dataset
    ## 根据设备类型将数据集分批处理。如果是在Ascend设备上,直接使用batch操作;否则,使用padded_batch操作来确保每个批次中的序列长度一致,不足部分使用pad token填充。
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})

    return dataset


imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']
print("imdb_train data_size: ", imdb_train.get_dataset_size())
print("imdb_test data_size: ", imdb_test.get_dataset_size())

# tokenizer
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')
print("openai-gpt GPTTokenizer: ", gpt_tokenizer)

# add sepcial token: <PAD>
special_tokens_dict = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

## 训练集切分、处理
# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])
dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

### 预训练模型加载
# set bert config and define parameters for training
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
metric = Accuracy()

# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_train, metrics=metric,
                  epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],
                  jit=False)

### 开始训练
trainer.run(tgt_columns="labels")

### 开始评估
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1933707.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

最优控制问题中的折扣因子

本文探讨了在线性二次型调节器&#xff08;LQR&#xff09;中引入折扣因子的重要性和方法。通过引入折扣因子&#xff0c;性能指标在无穷时间上的积分得以收敛&#xff0c;同时反映了现实问题中未来成本重要性递减的现象&#xff08;强化学习重要概念&#xff09;。详细推导了带…

AI智能名片在Web 3.0技术栈中的应用与前景研究

摘要&#xff1a;在Web 3.0的浪潮中&#xff0c;AI智能名片作为一种创新的数字工具&#xff0c;正逐步渗透到商业交流的各个层面。本文深入探讨了AI智能名片在Web 3.0技术栈中的具体应用&#xff0c;详细分析了其背后的技术支撑、应用场景、优势以及面临的挑战。通过案例分析、…

Postgresql主键自增的方法

Postgresql主键自增的方法 一.方法&#xff08;一&#xff09; 使用 serial PRIMARY KEY 插入数据 二.方法&#xff08;二&#xff09; &#x1f388;边走、边悟&#x1f388;迟早会好 一.方法&#xff08;一&#xff09; 使用 serial PRIMARY KEY 建表语句如下&#xf…

E - Count Arithmetic Subsequences (abc362)

题意&#xff1a;给定一个数组A&#xff0c;求长度为k的A的子序列中等差数列的个数。模为998244353。如果两个子序列取自不同的位置&#xff0c;即使它们作为序列是相等的&#xff0c;也是有区别的。 分析&#xff1a;设dp[i][k][d]为以i为结尾公差为d的长度为k的个数。k的范围…

Beyond Compare 4

工具推荐: —该版由zd423基于官方简体中文版便携式制作&#xff0c;完全绿色便携—原生绿色便携化&#xff08;无资源管理器扩展模块、数据保存至根目录&#xff09;—集成专业版永久授权密钥&#xff0c;彻底去主界面首页下方网络资源横幅—完全禁止自动检测升级&#xff0c;…

Java反射和动态代理用法(附10道练习题)

目录 一、什么是反射二、反射的核心接口和类三、测试代码 Bean 类和目录结构Person 类代码目录结构 四、反射的用法1. 获取 Class 对象2. 获取构造方法 Constructor 并使用3. 获取成员变量 Field 并使用4. 获取成员方法 Method 并使用 五、动态代理与反射1. 动态代理三要素&…

算法学习笔记:贪心算法

贪心算法&#xff08;又称贪婪算法&#xff09;是指&#xff0c;在对问题求解时&#xff0c;总是做出当前看来是最好的选择&#xff0c;就能得到问题的答案。 虽然贪心算法不是对所有问题都能得到整体最优解&#xff0c;但对范围相当广的许多问题它能产生整体最优解。在一些情况…

电机线电流与转差率曲线[进行中...]

1.电机T型等效电路模型 1.1 Python代码 - 考虑转差率为负 import numpy as np import matplotlib.pyplot as plt # 设置已知参数值 rm 11.421 lm 553.9e-3 r2 7.553 l2 42.90e-3 freq_in 50# 设置频率值范围和步长 s np.linspace(-0.05, 0.05, 1000) im 380/(rm(lml2)…

PyQt弹出式抽屉窗口

代码 from enum import Enum import sys from PyQt5.Qt import *class PopupOrientation(Enum):LEFT 0TOP 1RIGHT 2BOTTOM 3class PopupMaskDialogBase(QDialog):"""带有蒙版的对话框基类"""def __init__(self, parentNone):super().__init…

FedProto:跨异构客户端的联邦原型学习(论文阅读)

题目&#xff1a;FedProto: Federated Prototype Learning across Heterogeneous Clients 网址&#xff1a;http://arxiv.org/abs/2105.00243 摘要 在联邦学习(FL)中&#xff0c;当客户端知识在梯度空间中聚集时&#xff0c;客户端间的异构性通常会影响优化的收敛和泛化性能。…

携手AI人才 共赢智算未来丨誉天人工智能AI大模型首期班火热报名中

在数字化浪潮汹涌澎湃的今天&#xff0c;人工智能已成为推动社会进步与产业升级的关键力量。 回顾人工智能历史&#xff0c;自1956年诞生以来&#xff0c;历经三次发展热潮和两次低谷期。五十年代符号主义和逻辑推理的出现标志着人工智能的诞生&#xff0c;引发第一次发展浪潮&…

自动驾驶AVM环视算法–全景和标定全功能算法实现和exe测试demo

参考&#xff1a;全景和标定全功能算法实现和exe测试demo-金书世界 1、测试环境 opencv310vs2022 2、使用的编程语言 c和c 3、测试的demo的获取 更新&#xff1a;测试的exe程序&#xff0c;无需解压码就可以体验算法测试效果 百度网盘&#xff1a; 链接&#xff1a;http…

开发一个自己的chrom插件

开发一个自己的chrom插件 一、创建一个文件夹 二、配置文件manifest.json 创建名字为&#xff1a;manifest.json的配置文件&#xff0c;模板如下&#xff1a; {"manifest_version": 3,"name": "Hello World Extension","version": …

Go 语言 UUID 库 google/uuid 源码解析:UUID version7 的实现

google/uuid 库地址 建议阅读内容 在阅读此篇文章之前&#xff0c;建议先了解 UUIDv1 的构成、UUIDv4 的 API 以及掌握位运算。 了解 UUIDv1 的构成可以参考Go 语言 UUID 库 google/uuid 源码解析&#xff1a;UUID version1 的实现 或 RFC 9562。 了解 UUIDv4 的 API 可以看…

【数据结构】非线性表----树详解

树是一种非线性结构&#xff0c;它是由**n&#xff08;n>0&#xff09;**个有限结点组成一个具有层次关系的集合。具有层次关系则说明它的结构不再是线性表那样一对一&#xff0c;而是一对多的关系&#xff1b;随着层数的增加&#xff0c;每一层的元素个数也在不断变化&…

算法——双指针(day3)

611.有效三角形的个数 611. 有效三角形的个数 - 力扣&#xff08;LeetCode&#xff09; 题目解析&#xff1a; 三角形的判定很简单&#xff0c;任意两边之和大于第三边即可。按照正常情况&#xff0c;我们得判断3次才可以确认是否构成三角形。 因为c在本来就是最大的情况下与…

安全测试必学神器 --BurpSuite 安装及使用实操

BurpSuite是一款功能强大的集成化安全测试工具&#xff0c;专门用于攻击和测试Web应用程序的安全性。适合安全测试、渗透测试和开发人员使用。本篇文章基于BurpSuite安装及常用实操做详解&#xff0c;如果你是一名安全测试初学者&#xff0c;会大有收获&#xff01; 一、BurpS…

C++ Qt 登录界面 Login

效果: 核心代码: #include "simpleapp.h" #include "ui_simpleapp.h" #include <QMessageBox>SimpleApp::SimpleApp(QWidget *parent): QMainWindow(parent), ui(new Ui::SimpleApp) {ui->setupUi(this); }SimpleApp::~SimpleApp() {delete ui; …

ROS、pix4、gazebo、qgc仿真ubuntu20.04

一、ubuntu、ros安装教程比较多&#xff0c;此文章不做详细讲解。该文章基于ubuntu20.04系统。 pix4参考地址&#xff1a;https://docs.px4.io/main/zh/index.html 二、安装pix4 1. git clone https://github.com/PX4/PX4-Autopilot.git --recursive 2. bash ./PX4-Autopilot…

MQTT服务端EMQX开源版安装和客户端MQTTX介绍

一、EMQX是什么 EMQX 是一款开源的大规模分布式 MQTT 消息服务器&#xff0c;功能丰富&#xff0c;专为物联网和实时通信应用而设计。EMQX 5.0 单集群支持 MQTT 并发连接数高达 1 亿条&#xff0c;单服务器的传输与处理吞吐量可达每秒百万级 MQTT 消息&#xff0c;同时保证毫秒…