昇思25天学习打卡营第13天|BERT

news2025/1/23 17:43:57

一、简介:

BERT全称是来自变换器的双向编码器表征量(Bidirectional Encoder Representations from Transformers),它是Google于2018年末开发并发布的一种新型语言模型。与BERT模型相似的预训练语言模型例如问答、命名实体识别、自然语言推理、文本分类等在许多自然语言处理任务中发挥着重要作用。模型是基于Transformer中的Encoder并加上双向的结构,因此一定要熟练掌握Transformer的Encoder的结构。

BERT模型的主要创新点都在pre-train方法上,即用了Masked Language Model和Next Sentence Prediction两种方法分别捕捉词语和句子级别的representation。

在用Masked Language Model方法训练BERT的时候,随机把语料库中15%的单词做Mask操作。对于这15%的单词做Mask操作分为三种情况:80%的单词直接用[Mask]替换、10%的单词直接替换成另一个新的单词、10%的单词保持不变。

因为涉及到Question Answering (QA) 和 Natural Language Inference (NLI)之类的任务,增加了Next Sentence Prediction预训练任务,目的是让模型理解两个句子之间的联系。与Masked Language Model任务相比,Next Sentence Prediction更简单些,训练的输入是句子A和B,B有一半的几率是A的下一句,输入这两个句子,BERT模型预测B是不是A的下一句。

BERT预训练之后,会保存它的Embedding table和12层Transformer权重(BERT-BASE)或24层Transformer权重(BERT-LARGE)。使用预训练好的BERT模型可以对下游任务进行Fine-tuning,比如:文本分类、相似度判断、阅读理解等。

二、环境准备:

在导入包之前,首先咱们要确定已经在实验的环境中安装了mindspore和mindnlp:mindspore的安装可以参考昇思25天学习打卡营第1天|快速入门-CSDN博客,mindnlp的安装则直接运行下面的命令即可:

pip install mindnlp==0.3.1

安装完成之后,导入我们下面训练需要的包:

import os
import time

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, context

from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy

三、数据集:

1、数据集下载:

这里使用的是来自于百度飞桨的一份已标注的、经过分词预处理的机器人聊天数据集。数据由两列组成,以制表符('\t')分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。

label--text_a

0--谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?

1--我有事等会儿就回来和你聊

2--我见到你很高兴谢谢你帮我

wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
tar xvf emotion_detection.tar.gz

2、数据预处理:

新建 process_dataset 函数用于数据加载和数据预处理:

import numpy as np

def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'

    column_names = ["label", "text_a"]
    
    dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
    # transforms
    type_cast_op = transforms.TypeCast(mindspore.int32)
    def tokenize_and_pad(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text)
        return tokenized['input_ids'], tokenized['attention_mask']
    # map dataset
    dataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                         'attention_mask': (None, 0)})

    return dataset

昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理,也就是说数据预处理阶段需要采用固定的数据形状。

定义bert的分词器:

from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
tokenizer.pad_token_id

将数据集划分为训练集、开发集和测试集:

class SentimentDataset:
    """Sentiment Dataset"""

    def __init__(self, path):
        self.path = path
        self._labels, self._text_a = [], []
        self._load()

    def _load(self):
        with open(self.path, "r", encoding="utf-8") as f:
            dataset = f.read()
        lines = dataset.split("\n")
        for line in lines[1:-1]:
            label, text_a = line.split("\t")
            self._labels.append(int(label))
            self._text_a.append(text_a)

    def __getitem__(self, index):
        return self._labels[index], self._text_a[index]

    def __len__(self):
        return len(self._labels)

dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)

dataset_train.get_col_names()
print(next(dataset_train.create_tuple_iterator()))

四、模型构建:

通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。

from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision

# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)

metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])

%%time
# start training
trainer.run(tgt_columns="labels")

五、模型验证:

将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

六、模型推理:

遍历推理数据集,将结果与标签进行统一展示。

dataset_infer = SentimentDataset("data/infer.tsv")

def predict(text, label=None):
    label_map = {0: "消极", 1: "中性", 2: "积极"}

    text_tokenized = Tensor([tokenizer(text).input_ids])
    logits = model(text_tokenized)
    predict_label = logits[0].asnumpy().argmax()
    info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"
    if label is not None:
        info += f" , label: '{label_map[label]}'"
    print(info)

from mindspore import Tensor

for label, text in dataset_infer:
    predict(text, label)

输入一句话测试一下(doge):

predict("家人们咱就是说一整个无语住了 绝绝子叠buff")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1883745.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

fastzdp_login的第一次构建

概述 为了方便能够快捷的实现fastapi实现登录相关功能代码开发,决定开发一个开源的fastapi组件库,想了很多个名字,在检查pypi的时候发现都被占用了,所以最终决定使用fastzdp_login这个名字。 fast代表的时fastapi。zdp代表的是张…

CesiumJS【Basic】- #041 绘制纹理线(Entity方式)- 需要自定义着色器

文章目录 绘制纹理线(Entity方式)- 需要自定义着色器1 目标2 代码2.1 main.ts3 资源文件绘制纹理线(Entity方式)- 需要自定义着色器 1 目标 使用Entity方式绘制纹理线 2 代码 2.1 main.ts import * as Cesium from cesium;const viewer = new Cesium.Viewer

工程技术类SCI,低分快刊首选期刊,无版面费!

1、期刊概况 【期刊简介】IF:1.0-2.0,JCR2区,中科院4区; 【检索情况】SCIE在检 【版面类型】正刊,仅少量版面; 【出刊频率】年刊 2、征稿范围 本刊主要是发表有关能源转型和可再生能源需求相关的研究文…

windows USB 设备驱动开发-Host端和Device端

Windows 中的 USB 宿主端驱动程序 下图显示了适用于 Windows 的 USB 驱动程序堆栈的体系结构框图。 此图显示了适用于 USB 2.0 和 USB 3.0 的单独 USB 驱动程序堆栈。 当设备连接到 xHCI 控制器时,Windows 加载 USB 3.0 驱动程序堆栈。 Windows 为连接到 EHCI、OHC…

MySQL进阶:存储过程和函数

存储过程和函数 1. 简介2. 创建存储过程使用MySQL工作台创建存储过程 3. 删除存储过程4. 参数带默认值的参数参数验证输出参数 5. 变量6. 函数7. 其他约定 1. 简介 存储过程三大作用: 储存和管理SQL代码(置于数据库中,与应用层分离&#xf…

Redis集群-主从复制、哨兵

●主从复制:主从复制是高可用Redis的基础,哨兵和集群都是在主从复制基础上实现高可用的。 主从复制主要实现了数据的多机备份,以及对于读操作的负载均衡和简单的故障恢复。缺陷:故障 恢复无法自动化;写操作无法负载均…

小白学习手册:轻松理解MQ消息队列

目录 # 开篇 RabbitMQ介绍 通讯概念 1. 初始MQ及类型 2. MQ的架构 2.1 RabbitMQ的结构和概念 2.2 RabbitMQ消息流示意图 3. MQ下载使用 3.1 Docker下载MQ参考 3.2 进入RabbitMQ # 开篇 MessagesQueue 是一个抽象概念,用于描述消息队列系统的一般特性和功能…

Ubuntu(通用)—网络加固—ufw+防DNS污染+ARP绑定

1. ufw sudo ufw default deny incoming sudo ufw deny in from any to any # sudo ufw allow from any to any port 5353 protocol udp sudo ufw enable # 启动开机自启 # sudo ufw reload 更改后的操作2. 防ARP欺骗 华为云教程 arp -d删除dns记录arp -a显示arp表 ipconfi…

windows USB设备驱动开发-双角色驱动

在USB的通讯协议中,规定发起连接的一方为主机(Host),接受连接的一方为设备,这可以用U盘插入电脑举个例子,当U盘插入电脑后,电脑这边主动发起查询和枚举,U盘被动响应查询和数据存取。 USB 双角色驱动程序堆…

《RepViT Revisiting Mobile CNN From ViT Perspective》

期刊:CVPR 年份:2024 代码:http://https: //github.com/THU-MIG/RepViT 摘要 最近,与轻量级卷积神经网络(CNN)相比,轻量级视觉Transformer(ViTs)在资源受限的移动设备上表现出了更高的性能和更低的延迟。研究人员已…

大聪明教你学Java | 深入浅出聊 RocketMQ

前言 🍊作者简介: 不肯过江东丶,一个来自二线城市的程序员,致力于用“猥琐”办法解决繁琐问题,让复杂的问题变得通俗易懂。 🍊支持作者: 点赞👍、关注💖、留言&#x1f4…

技术派Spring事件监听机制及原理

Spring事件监听机制是Spring框架中的一种重要技术,允许组件之间进行松耦合通信。通过使用事件监听机制,应用程序的各个组件可以在其他组件不直接引用的情况下,相互发送和接受消息。 需求 在技术派中有这样一个需求,当发布文章或…

旋转变压器软件解码simulink仿真

1.介绍 旋转变压器是一种精密的位置、速度检测装置,尤其适用于高温、严寒、潮湿、高速、振动等环境恶劣、旋转编码器无法正常工作的场合。旋转变压器在使用时并不能直接提供角度或位置信息,需要特殊的激励信号和解调、计算措施,才能将旋转变压…

【RT摩拳擦掌】基于RT106L/S语音识别的百度云控制系统

【RT摩拳擦掌】基于RT106L/S语音识别的百度云控制系统 一 文档简介二 平台构建2.1 使用平台2.2 百度智能云2.2.1 物联网核心套件2.2.2 在线语音合成 2.3 playback语音数据准备与烧录2.4 开机语音准备与添加2.5 唤醒词识别词命令准备与添加 三 代码准备3.1 sln-local/2-iot 代码…

2025第13届常州国际工业装备博览会招商全面启动

常州智造 装备中国|2025第13届常州国际工业装备博览会招商全面启动 2025第13届常州国际工业装备博览会将于2025年4月11-13日在常州西太湖国际博览中心盛大举行!目前,各项筹备工作正稳步推进。 60000平米的超大规模、800多家国内外工业装备制造名企将云集…

Unity Shader 软粒子

Unity Shader 软粒子 前言项目Shader连连看项目渲染管线设置 鸣谢 前言 当场景有点单调的时候,就需要一些粒子点缀,此时软粒子就可以发挥作用了。 使用软粒子与未使用软粒子对比图 项目 Shader连连看 这里插播一点,可以用Vertex Color与…

KUKA机器人不同运行方式

KUKA机器人有以下四种运行方式: 1、手动慢速运行(T1) 2、手动快速运行(T2) 3、自动运行(AUT) 4、外部自动运行(AUT EXT) 将示教器上的钥匙向右旋转,就会…

【数据结构之B树的讲解】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出…

专题五:Spring源码之初始化容器上下文

上一篇我们通过如下一段基础代码作为切入点,最终找到核心的处理是refresh方法,从今天开始正式进入refresh方法的解读。 public class Main {public static void main(String[] args) {ApplicationContext context new ClassPathXmlApplicationContext(…

Study--Oracle-05-Oracler体系结构

一、oracle 体系概览 Oracle数据库的体系结构通常包括以下主要组件: 1、实例(Instance):运行数据库的软件环境,包括内存结构(SGA)和进程结构(Background Processes and User Proces…