昇思25天学习打卡营第9天|基于 MindSpore 实现 BERT 对话情绪识别

news2024/11/16 15:45:18

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)

环境配置

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip show mindspore

基于 MindSpore 实现 BERT 对话情绪识别

模型简介

BERT全称是来自变换器的双向编码器表征量(Bidirectional Encoder Representations from Transformers),它是Google于2018年末开发并发布的一种新型语言模型。与BERT模型相似的预训练语言模型例如问答、命名实体识别、自然语言推理、文本分类等在许多自然语言处理任务中发挥着重要作用。模型是基于Transformer中的Encoder并加上双向的结构,因此一定要熟练掌握Transformer的Encoder的结构。

BERT模型的主要创新点都在pre-train方法上,即用了Masked Language Model和Next Sentence Prediction两种方法分别捕捉词语和句子级别的representation。

在用Masked Language Model方法训练BERT的时候,随机把语料库中15%的单词做Mask操作。对于这15%的单词做Mask操作分为三种情况:80%的单词直接用[Mask]替换、10%的单词直接替换成另一个新的单词、10%的单词保持不变。

因为涉及到Question Answering (QA) 和 Natural Language Inference (NLI)之类的任务,增加了Next Sentence Prediction预训练任务,目的是让模型理解两个句子之间的联系。与Masked Language Model任务相比,Next Sentence Prediction更简单些,训练的输入是句子A和B,B有一半的几率是A的下一句,输入这两个句子,BERT模型预测B是不是A的下一句。

BERT预训练之后,会保存它的Embedding table和12层Transformer权重(BERT-BASE)或24层Transformer权重(BERT-LARGE)。使用预训练好的BERT模型可以对下游任务进行Fine-tuning,比如:文本分类、相似度判断、阅读理解等。

对话情绪识别(Emotion Detection,简称EmoTect),专注于识别智能对话场景中用户的情绪,针对智能对话场景中的用户文本,自动判断该文本的情绪类别并给出相应的置信度,情绪类型分为积极、消极、中性。 对话情绪识别适用于聊天、客服等多个场景,能够帮助企业更好地把握对话质量、改善产品的用户交互体验,也能分析客服服务质量、降低人工质检成本。

下面以一个文本情感分类任务为例子来说明BERT模型的整个应用过程。

import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, context

from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
# prepare dataset
class SentimentDataset:
    """Sentiment Dataset"""

    def __init__(self, path):
        self.path = path
        self._labels, self._text_a = [], []
        self._load()

    def _load(self):
        with open(self.path, "r", encoding="utf-8") as f:
            dataset = f.read()
        lines = dataset.split("\n")
        for line in lines[1:-1]:
            label, text_a = line.split("\t")
            self._labels.append(int(label))
            self._text_a.append(text_a)

    def __getitem__(self, index):
        return self._labels[index], self._text_a[index]

    def __len__(self):
        return len(self._labels)

数据集

这里提供一份已标注的、经过分词预处理的机器人聊天数据集,来自于百度飞桨团队。数据由两列组成,以制表符('\t')分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。

label--text_a

0--谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?

1--我有事等会儿就回来和你聊

2--我见到你很高兴谢谢你帮我

这部分主要包括数据集读取,数据格式转换,数据 Tokenize 处理和 pad 操作。

# download dataset
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
!tar xvf emotion_detection.tar.gz

数据加载和数据预处理

新建 process_dataset 函数用于数据加载和数据预处理,具体内容可见下面代码注释。

import numpy as np

def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'

    column_names = ["label", "text_a"]
    
    dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
    # transforms
    type_cast_op = transforms.TypeCast(mindspore.int32)
    def tokenize_and_pad(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text)
        return tokenized['input_ids'], tokenized['attention_mask']
    # map dataset
    dataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                         'attention_mask': (None, 0)})

    return dataset

昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理:

from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
tokenizer.pad_token_id
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)
dataset_train.get_col_names()
print(next(dataset_train.create_tuple_iterator()))

下载并解压了一个已标注的数据集,该数据集包含了情绪分类的标签和相应的文本。

数据集被加载并转换为适合模型训练的格式,包括文本的Tokenization和Padding操作。

模型构建

通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。

from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision

# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])
%%time
# start training
trainer.run(tgt_columns="labels")

针对BERT模型进行微调(Fine-tuning)。

使用`BertForSequenceClassification.from_pretrained`从预训练模型中加载BERT模型,在这里,使用了一个已经预训练好的中文BERT模型`'bert-base-chinese'`。

使用`auto_mixed_precision`函数将模型转换为自动混合精度(Auto Mixed Precision),可以在训练过程中提高计算效率。

使用Adam优化器,并设置了学习率为2e-5,用于模型的参数更新。

定义了评价指标`metric`,在模型训练过程中用于评估模型的性能。

通过CheckpointCallbackBestModelCallback设置回调函数,CheckpointCallback 用于在训练期间保存模型的检查点,BestModelCallback 用于保存表现最好的模型并支持自动加载。

创建Traine`对象trainer,将模型、训练数据集、验证数据集、优化器、评价指标和回调函数传递给它,训练的epochs=5。通过调用trainer.run()方法,对BERT模型进行训练。tgt_columns="labels"指定训练目标列,即模型试图预测的标签。

%%time魔术命令计时,可以查看训练的运行时间。

模型验证

将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

使用Evaluator类对模型进行验证,评估在验证集上的表现。

模型推理

遍历推理数据集,将结果与标签进行统一展示。

dataset_infer = SentimentDataset("data/infer.tsv")
def predict(text, label=None):
    label_map = {0: "消极", 1: "中性", 2: "积极"}

    text_tokenized = Tensor([tokenizer(text).input_ids])
    logits = model(text_tokenized)
    predict_label = logits[0].asnumpy().argmax()
    info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"
    if label is not None:
        info += f" , label: '{label_map[label]}'"
    print(info)
from mindspore import Tensor

for label, text in dataset_infer:
    predict(text, label)

定义函数predict,用于对新的文本进行情绪分类。

遍历推理数据集,使用模型对每个文本进行预测,并与真实标签进行比较。

自定义推理数据集

自己输入推理数据,展示模型的泛化能力。

predict("家人们咱就是说一整个无语住了 绝绝子叠buff")

输入任意文本,输出预测的情绪类别。

应用实践案例使用MindSpore实现一个基于BERT的情绪识别模型,有助于理解BERT模型的构建和Fine-tuning过程。实现过程中,用MindSpore的Trainer、Evaluator类对模型进行训练、验证,使用MindSpore的AMP(自动混合精度)加速训练过程。使用MindSpore的CheckpointCallback、BestModelCallback来保存模型。最后用predict推理,对文本进行情绪分类。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1868361.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

valgrind调试c/c++内存问题:非法地址访问_内存泄漏_越界访问

1.valgrind命令 调试内存问题: valgrind --leak-checkfull 更新详细的显示: valgrind --leak-checkfull --show-leak-kindsall valgrind提示信息汇总 内存泄漏 lost in loss record 丢失记录 , 内存泄漏实例[[#2.内存泄漏–不完全释放内存|实例链接]]段错误 Process termina…

如何解决Oracle中PL Developer过期

如果长时间不使用PL Deveploer,再次打开有可能会出现以下页面: 上方页面说明此软件已经过期,有两种方法可以解决上述问题,第一种: 操作注册表: WinR 输入指令“regedit”打开注册表,出现下方页…

Camera开发-相机输出常用数据格式

作者简介: 一个平凡而乐于分享的小比特,中南民族大学通信工程专业研究生在读,研究方向无线联邦学习 擅长领域:驱动开发,嵌入式软件开发,BSP开发 作者主页:一个平凡而乐于分享的小比特的个人主页…

STM32HAL库--定时器篇(速记版)

STM32F429 有14个定时器,其中包括 2 个基本定时器(TIM6 和 TIM7)、 10 个通用定时器(TIM2~TIM5,TIM9~TIM14)、 2 个高级控制定时器(TIM1 和 TIM8)。 由上表知道:除了 TIM…

具备生成自签名文档证书能力的印章管理软件_电子骑缝章软件

最新版的e-章宝具体生成自签名文档证书的能力,这种证书可用内部文档发布的签名,文档一旦用证书签名并发布,具有不可抵赖性,阅读者也能确认所发布的文档是否是发布者发布的(即中途有没有被他人恶意修改过)&a…

成熟ICT测试系统与LabVIEW定制开发的比较

ICT(In-Circuit Test)测试系统是电子制造行业中用于电路板(PCB)组件检测的重要工具。市场上有许多成熟的ICT测试系统,如Keysight、Teradyne、SPEA等公司提供的商用解决方案。此外,LabVIEW作为一种强大的图形…

如何在ArcGIS Pro中提取行政区划

我们在《2024版有审图号的SHP行政区划》一文中,为你分享过全国省市县级的行政区划。 现在再为你分享一下,如何在ArcGIS Pro中提取目标范围行政区划的方法,你还可在以文末查看领取该行政区划数据的方法。 直接选择 在菜单栏上点击一下选择下…

如何使用AIGC降重工具轻松提升论文原创性?

论文查重和降重是确保学术成果原创性及学术诚信的关键步骤,直接影响我们的学业成果和毕业资格。传统的论文查重方法主要包括使用查重软件和个人自查,而论文降重通常涉及改写、使用同义词替换、内容的扩展和深化,以及正确的引用和注释等方式来…

定时推送邮件如何与自动化工作流程相结合?

定时推送邮件如何设置?怎么优化推送邮件的发送频率? 在现代商业环境中,自动化工作流程和定时推送邮件是提高效率和优化运营的重要工具。AoKSend将探讨如何将这两者结合起来,以实现更高效的工作流程和更好的客户沟通。 定时推送邮…

Windows应急响应靶机 - Web3

一、靶机介绍 应急响应靶机训练-Web3 前景需要:小苕在省护值守中,在灵机一动情况下把设备停掉了,甲方问:为什么要停设备?小苕说:我第六感告诉我,这机器可能被黑了。 这是他的服务器&#xff…

车载系统类 UI 风格品质非凡

车载系统类 UI 风格品质非凡

Visio文件编辑查看工具:Visio Viewer for Mac 激活版

Visio Viewer 软件通过该软件,用户可以在没有Visio软件的情况下查看使用Visio创建的绘图和图表,方便用户对复杂信息的可视化、分析和交流。Visio Viewer 2007是一个功能强大的软件,它可以帮助IT和商务专业人员轻松地可视化、分析和交流复杂信…

堆箱子00

题目链接 堆箱子 题目描述 注意点 将箱子堆起来时,下面箱子的宽度、高度和深度必须大于上面的箱子 解答思路 初始想到深度优先遍历,最后超时了参照题解使用动态规划,先将盒子从小到大进行排序,dp[i]存储的是到第i个箱子时堆箱…

Redis Stream Redisson Stream

目录 一、Redis Stream1.1 场景1:多个客户端可以同时接收到消息1.1.1 XADD - 向stream添加Entry(发消息 )1.1.2 XREAD - 从stream中读取Entry(收消息)1.1.3 XRANGE - 从stream指定区间读取Entry(收消息&…

ubuntu24 安装 docker

更新 apt-get sudo apt-get update 安装软件包 sudo apt-get install apt-transport-https ca-certificates curl software-properties-common 添加Docker的官方GPG密钥 curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - 添加 Docker 仓库 …

不锈钢氩弧焊丝ER316L

说明:TG316L 是超低碳的不锈钢焊丝。熔敷金属耐蚀、耐热、抗裂性能优良。防腐蚀性能良好。 用途:用于石油化工、化肥设备等。也可用于要求焊接后不进行热处理的高Cr钢的焊接。

Dynamic-Link库 (动态链接库)

一、动态链接库(DLL)的基本概念 二、动态链接库的优势 三、动态链接库的实现方法 四、动态链接库的版本冲突问题(DLL地狱) 五、动态链接库与静态链接库的区别 一、动态链接库(DLL)的基本概念 定义&…

数据结构(Java):ArrayList的应用

1、引言 上一篇博客,已经为大家讲解了集合类ArrayList。 这篇博客,就来帮助大家学会使用ArrayList。 2、题1: 删除字符(热身题) 题目:给出str1和str2两个字符串,删除str1中出现的所有的str2…

kafka进阶核心原理详解:案例解析(第11天)

系列文章目录 kafka高级(重点) kafka核心概念汇总 kafka的数据位移offset Kafka的基准/压力测试 Kafka的分片副本机制 kafka如何保证数据不丢失 kafka的消息存储及查询机制 生产者数据分发策略 消费者负载均衡机制 kafka的监控工具:kafka-eagle…

基于Java的多元化智能选课系统-计算机毕业设计源码040909

摘 要 多元化智能选课系统使用Java语言的Springboot框架,采用MVVM模式进行开发,数据方面主要采用的是微软的Mysql关系型数据库来作为数据存储媒介,配合前台技术完成系统的开发。 论文主要论述了如何使用JAVA语言开发一个多元化智能选课系统&a…