百度ERNIE 3.0——中文情感分析实战

news2024/11/24 17:44:10

目录

  • 前言
  • 一、百度ERNIE 3.0
  • 二、使用ERNIE 3.0中文预训练模型进行句子级别的情感分析
    • 2-1、环境
    • 2-2、数据集加载
    • 2-3、加载预训练模型和分词器
    • 2-4、基于预训练模型的数据处理
    • 2-5、数据训练和评估
    • 2-6、模型验证
    • 2-7、情感分析结果的预测以及保存
  • 三、自定义个人案例
    • 3-1、如何自定义数据集
  • 总结


前言

ERNIE(Enhanced Representation through kNowledge IntEgration)是百度研发的一种基于深度学习的预训练语言模型。它通过大规模的无监督学习从大量文本数据中学习语义和知识表示。

一、百度ERNIE 3.0

百度与鹏城自然语言处理联合实验室重磅发布鹏城-百度·文心(模型版本号:ERNIE 3.0 Titan),该模型是全球首个知识增强的千亿AI大模型,也是目前为止全球最大的中文单体模型。

基于业界领先的鹏城实验室算力系统“鹏城云脑Ⅱ”和百度飞桨深度学习平台强强练手,鹏城-百度·文心模型参数规模超越GPT-3达到2600亿,致力于解决传统AI模型泛化性差、强依赖于昂贵的人工标注数据、落地成本高等应用难题,降低AI开发与应用门槛。目前该模型在60多项任务取得最好效果,并大幅刷新小样本学习任务基准。

鹏城-百度·文心基于百度知识增强大模型ERNIE 3.0全新升级,模型参数规模达到2600亿,相对GPT-3的参数量提升50%。

在算法框架上,该模型沿袭了ERNIE 3.0的海量无监督文本与大规模知识图谱的平行预训练算法,模型结构上使用兼顾语言理解与语言生成的统一预训练框架。为提升模型语言理解与生成能力,研究团队进一步设计了可控和可信学习算法。

在训练上,结合百度飞桨自适应大规模分布式训练技术和“鹏城云脑Ⅱ”算力系统,解决了超大模型训练中多个公认的技术难题。在应用上,首创大模型在线蒸馏技术,大幅降低了大模型落地成本。

以下为百度-文心模型结构图
在这里插入图片描述

二、使用ERNIE 3.0中文预训练模型进行句子级别的情感分析

2-1、环境

pip install paddle -i https://mirror.baidu.com/pypi/simple
pip install paddlenlp -i https://mirror.baidu.com/pypi/simple

2-2、数据集加载

chnsenticorp: ChnSentiCorp是中文句子级情感分类数据集,包含酒店、笔记本电脑和书籍的网购评论

import os
import paddle
import paddlenlp

#加载中文评论情感分析语料数据集ChnSentiCorp
from paddlenlp.datasets import load_dataset

# 分割训练集验证集和测试集
train_ds, dev_ds, test_ds = load_dataset("chnsenticorp", splits=["train", "dev", "test"])
print("训练集样例:", train_ds[0])

输出
训练集样例: {‘text’: ‘选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般’, ‘label’: 1, ‘qid’: ‘’}

2-3、加载预训练模型和分词器

加载预训练模型和分词器: PaddleNLP中Auto模块(包括AutoModel, AutoTokenizer及各种下游任务类)提供了方便易用的接口,无需指定模型类别,即可调用不同网络结构的预训练模型。PaddleNLP的预训练模型可以很容易地通过from_pretrained()方法加载,Transformer预训练模型汇总包含了40多个主流预训练模型,500多个模型权重。

AutoModelForSequenceClassification可用于句子级情感分析和目标级情感分析任务,通过预训练模型获取输入文本的表示,之后将文本表示进行分类。PaddleNLP已经实现了ERNIE 3.0预训练模型,可以通过一行代码实现ERNIE 3.0预训练模型和分词器的加载。

from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "ernie-3.0-medium-zh"
# 预训练模型加载
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_classes=len(train_ds.label_list))
# 分词器加载
tokenizer = AutoTokenizer.from_pretrained(model_name)

2-4、基于预训练模型的数据处理

Dataset中通常为原始数据,需要经过一定的数据处理并进行采样组batch。

  • 通过Dataset的map函数,使用分词器将数据集从原始文本处理成模型的输入。
  • 定义paddle.io.BatchSampler和collate_fn构建 paddle.io.DataLoader。

实际训练中,根据显存大小调整批大小batch_size和文本最大长度max_seq_length。

import functools
import numpy as np

from paddle.io import DataLoader, BatchSampler
from paddlenlp.data import DataCollatorWithPadding

# 数据预处理函数,利用分词器将文本转化为整数序列
def preprocess_function(examples, tokenizer, max_seq_length, is_test=False):
	# 使用分词器处理训练集,给定最大长度。
    result = tokenizer(text=examples["text"], max_seq_len=max_seq_length)
    if not is_test:
    # 如果不是测试集的话,赋予标签,否则不给标签。
        result["labels"] = examples["label"]
    return result

# 映射到训练集和验证集上。
trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_seq_length=128)
train_ds = train_ds.map(trans_func)
dev_ds = dev_ds.map(trans_func)

# collate_fn函数构造,将不同长度序列充到批中数据的最大长度,再将数据堆叠
collate_fn = DataCollatorWithPadding(tokenizer)

# 定义BatchSampler,选择批大小和是否随机乱序,进行DataLoader
train_batch_sampler = BatchSampler(train_ds, batch_size=32, shuffle=True)
dev_batch_sampler = BatchSampler(dev_ds, batch_size=64, shuffle=False)
train_data_loader = DataLoader(dataset=train_ds, batch_sampler=train_batch_sampler, collate_fn=collate_fn)
dev_data_loader = DataLoader(dataset=dev_ds, batch_sampler=dev_batch_sampler, collate_fn=collate_fn)

2-5、数据训练和评估

数据训练和评估: 定义训练所需的优化器、损失函数、评论指标等,就可以开始进行预训练模型的微调任务。

# Adam优化器、交叉熵损失函数、accuracy评价指标
optimizer = paddle.optimizer.AdamW(learning_rate=2e-5, parameters=model.parameters())
criterion = paddle.nn.loss.CrossEntropyLoss()
metric = paddle.metric.Accuracy()

开始训练

# 开始训练
import time
import paddle.nn.functional as F

from eval import evaluate

epochs = 5 # 训练轮次
ckpt_dir = "ernie_ckpt" #训练过程中保存模型参数的文件夹
best_acc = 0
best_step = 0
global_step = 0 #迭代次数
tic_train = time.time()
for epoch in range(1, epochs + 1):
    for step, batch in enumerate(train_data_loader, start=1):
        input_ids, token_type_ids, labels = batch['input_ids'], batch['token_type_ids'], batch['labels']

        # 计算模型输出、损失函数值、分类概率值、准确率
        logits = model(input_ids, token_type_ids)
        loss = criterion(logits, labels)
        probs = F.softmax(logits, axis=1)
        correct = metric.compute(probs, labels)
        metric.update(correct)
        acc = metric.accumulate()

        # 每迭代10次,打印损失函数值、准确率、计算速度
        global_step += 1
        if global_step % 10 == 0:
            print(
                "global step %d, epoch: %d, batch: %d, loss: %.5f, accu: %.5f, speed: %.2f step/s"
                % (global_step, epoch, step, loss, acc,
                    10 / (time.time() - tic_train)))
            tic_train = time.time()
        
        # 反向梯度回传,更新参数
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        # 每迭代100次,评估当前训练的模型、保存当前模型参数和分词器的词表等
        if global_step % 100 == 0:
            save_dir = ckpt_dir
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            print(global_step, end=' ')
            acc_eval = evaluate(model, criterion, metric, dev_data_loader)
            if acc_eval > best_acc:
                best_acc = acc_eval
                best_step = global_step

                model.save_pretrained(save_dir)
                tokenizer.save_pretrained(save_dir)

2-6、模型验证

from eval import evaluate

# 加载ERNIR 3.0最佳模型参数
params_path = 'ernie_ckpt/model_state.pdparams'
state_dict = paddle.load(params_path)
model.set_dict(state_dict)

# 也可以选择加载预先训练好的模型参数结果查看模型训练结果
# model.set_dict(paddle.load('ernie_ckpt_trained/model_state.pdparams'))

print('ERNIE 3.0-Medium 在ChnSentiCorp的dev集表现', end=' ')
eval_acc = evaluate(model, criterion, metric, dev_data_loader)

2-7、情感分析结果的预测以及保存

测试集数据预处理

# 测试集数据预处理,利用分词器将文本转化为整数序列
trans_func_test = functools.partial(preprocess_function, tokenizer=tokenizer, max_seq_length=128, is_test=True)
test_ds_trans = test_ds.map(trans_func_test)

# 进行采样组batch
collate_fn_test = DataCollatorWithPadding(tokenizer)
test_batch_sampler = BatchSampler(test_ds_trans, batch_size=32, shuffle=False)
test_data_loader = DataLoader(dataset=test_ds_trans, batch_sampler=test_batch_sampler, collate_fn=collate_fn_test)

模型预测分类结果

# 模型预测分类结果
import paddle.nn.functional as F

label_map = {0: '负面', 1: '正面'}
results = []
model.eval()
for batch in test_data_loader:
    input_ids, token_type_ids = batch['input_ids'], batch['token_type_ids']
    logits = model(batch['input_ids'], batch['token_type_ids'])
    probs = F.softmax(logits, axis=-1)
    idx = paddle.argmax(probs, axis=1).numpy()
    idx = idx.tolist()
    preds = [label_map[i] for i in idx]
    results.extend(preds)

预测结果写入excel

# 存储ChnSentiCorp预测结果  
test_ds = load_dataset("chnsenticorp", splits=["test"]) 

res_dir = "./results"
if not os.path.exists(res_dir):
    os.makedirs(res_dir)
with open(os.path.join(res_dir, "ChnSentiCorp.tsv"), 'w', encoding="utf8") as f:
    f.write("qid\ttext\tprediction\n")
    for i, pred in enumerate(results):
        f.write(test_ds[i]['qid']+"\t"+test_ds[i]['text']+"\t"+pred+"\n")

三、自定义个人案例

3-1、如何自定义数据集

众所周知:加载入模型的数据集格式为<class ‘paddlenlp.datasets.dataset.MapDataset’>,在导入个人数据集时,需要首先转换一下数据格式。

from paddlenlp.datasets import load_dataset

# 读取数据,将数据集拆解、重组。
def read(data_path):
    with open(data_path, 'r', encoding='utf-8') as f:
        # 跳过列名
        next(f)
        for line in f: 
            words, labels = line.strip('\n').split('\t')
            words = words.split('\002')
            labels = labels.split('\002')
            yield {'tokens': words, 'labels': labels}

# data_path为read()方法的参数,构建成 MapDataset格式,MapDataset 在绝大多数时候都可以满足要求。
map_ds = load_dataset(read, data_path='train.txt',lazy=False)

# 之后使用train_test_split分割数据集以及使用MapDataset来转换数据格式

# lazy=True参数将数据构建成IterDataset格式,一般只有在数据集过于庞大无法一次性加载进内存的时候我们才考虑使用 IterDataset 。
iter_ds = load_dataset(read, data_path='train.txt',lazy=True)

参考文章:
解析全球最大中文单体模型鹏城-百度·文心技术细节.
源代码地址.
如何自定义数据集.

总结

导入到本地的过程中,发现其中的一个包from eval import evaluate无法导入,盲猜这是paddle的本地类,如果是这样的话,那么只能在百度的平台训练,之后导出模型到本地进行部署了。🤣🤣

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/752395.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据库语句

文章目录 数据库语句SQL语言分类MySQL中6种常见的约束1.DDL1.1 创建新的数据库1.2 创建新的表1.3 删除指定的数据表 2.DML管理表2.1 插入数据2.2 修改&#xff08;更新数据&#xff09;2.3 在数据表中删除指定的数据 3.DQL查询数据记录4.DCL4.1 修改表名和表结构4.2 扩展表结构…

Node.js详解(四):连接MongoDB

文章目录 一、安装MongoDB访问驱动二、连接数据库三、添加数据四、添加多条数据五、修改数据六、查询数据1、查询单条记录2、查询多条记录 七、删除数据八、完整示例代码1、路由 Api 接口&#xff1a;2、运行结果&#xff1a; MongoDB 对许多平台都提供驱动可以访问数据库&…

前端vue入门(纯代码)31_route-link的repalce属性

如果夜里十二点我还回你消息&#xff0c;那么意味着什么&#xff0c;意味着我是真的很喜欢玩手机。 【29.Vue Router--router-link的replace属性】 <router-link>的replace属性 replace属性的作用是&#xff1a;控制路由跳转时操作浏览器历史记录的模式。【当我们从一个…

城市内涝监测设备-内涝监测终端

随着我国城市化发展迅速、全球极端天气现象频发带来的暴雨天气增多&#xff0c;汛期暴雨引发道路低洼处、立交桥底、隧道、涵洞等城市 内涝时有发生&#xff0c;甚至开启城市看海模式&#xff0c;对交通、电力、通讯等造成了严重的影响和破坏&#xff0c;严重时造成人民生命、财…

放弃使用Merge,开心拥抱Rebase!

1. 引言 大家好&#xff0c;我是比特桃。Git 作为现在最流行的版本管理工具&#xff0c;想必大家在开发过程中都会使用。由于 Git 中很多操作默认是采用 Merge 进行的&#xff0c;并且相对也不容易出错&#xff0c;所以很多人都会使用 Merge 来进行合并代码。但Rebase 作为 Gi…

官宣!菁英实习生计划启动,百度大模型团队诚邀你的加入

大模型风起&#xff0c;人才需求涌 在这个充满变革的时代&#xff0c;我们见证了AI的快速发展。从“阿尔法狗”击败世界围棋冠军&#xff0c;到生成式大模型以势不可挡的浪潮席卷全球&#xff0c;掀起人类社会一场眩晕式变革。新技术、新工具、新的生产力正在改变经济活动各环…

小红书运营推广

大家好&#xff0c;我是权知星球&#xff0c;今天给大家分享一下小红手运营推广的一些经验&#xff0c;希望能给大家运营小红书带来一些帮助。 这篇文章虽然是基于小红书的运营写的&#xff0c;但新媒体的东西都是相通的&#xff0c;相信这篇文章对运营其他媒体的同学也会有所…

抓包工具Fiddler:fiddler的介绍及安装

Fiddler简介 Fiddler是比较好用的web代理调试工具之一&#xff0c;它能记录并检查所有客户端与服务端的HTTP/HTTPS请求&#xff0c;能够设置断点&#xff0c;篡改及伪造Request/Response的数据&#xff0c;修改hosts&#xff0c;限制网速&#xff0c;http请求性能统计&#xff…

MyBatis源码分析_Executor组件及3个火枪手(6)

目录 1. 前提 2. Executor执行器 3. 总结 4. 三个火枪手 5. StatementHandler生成Statement 6. ParameterHandler 参数解析 7. BoundSql的数据结构 8. 总结 1. 前提 在Mybatis源码分析_事务管理器 &#xff08;5&#xff09;_chen_yao_kerr的博客-CSDN博客一文中&…

网关微服务简单配置

导入一下网关的基本依赖 <dependencies><dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-gateway</artifactId></dependency><dependency><groupId>com.alibaba.cloud<…

直播商城系统源码的威力:开启直播购物新时代

随着科技的不断进步和人们对互动性购物体验的追求&#xff0c;直播购物正成为电商行业的热门趋势。直播商城系统源码的威力在这一潮流中愈发显现&#xff0c;为商家和消费者提供了无限的机会和便利。 下面是一个简单的示例&#xff0c;展示了如何利用直播商城系统源码创建一个…

深度强化学习:深度解析 MADDPG

深度强化学习:深度解析 MADDPG 学习强化学习,码代码的能力必须要出众,要快速入门强化学习 搞清楚其中真正的原理,读源码是一个最简单的最直接的方式。最近创建了一系列该类型文章,希望对大家有多帮助。 另外,我会将所有的文章及所做的一些简单项目,放在 1.MADDPG 原理…

JS脚本 - 批量给所有指定标签追加Class属性

JS脚本 - 批量给所有指定标签追加Class属性 前言一. 脚本二. 测试运行 前言 公司里我们有个应用引入了UBT埋点&#xff0c;记录了页面上所有的点击操作以及对应的点击按钮。但是我们看下来发现&#xff0c;我们需要给每个按钮加一个唯一标识做区分&#xff0c;并且这个ID是给U…

选读SQL经典实例笔记07_日期处理(下)

1. 一个季度的开始日期和结束日期 1.1. 以yyyyq格式&#xff08;前面4位是年份&#xff0c;最后1位是季度序号&#xff09;给出了年份和季度序号 1.2. DB2 1.2.1. sql select (q_end-2 month) q_start,(q_end1 month)-1 day q_endfrom (select date(substr(cast(yrq as c…

Linux系统编程(信号处理 sigacation函数和sigqueue函数 )

文章目录 前言一、sigaction二、sigqueue函数三、代码示例总结 前言 本篇文章我们来介绍一下sigacation函数和sigqueue函数。 一、sigaction sigaction 是一个用于设置和检查信号处理程序的函数。它允许我们指定信号的处理方式&#xff0c;包括指定一个函数作为信号处理程序…

AsyncImage, BackgroundMaterials, TextSelection, ButtonStyles 的使用

1. AsyncImage 异步加载图片 1.1 实现 /*case empty -> No image is loaded.case success(Image) -> An image succesfully loaded.case failure(Error) -> An image failed to load with an error.*/ /// iOS 15 开始的 API 新特性示例 /// 异步加载图片 struct As…

Ae 效果:CC Plastic

风格化/CC Plastic Stylize/CC Plastic CC Plastic&#xff08;CC 塑料&#xff09;效果用于创建具有塑料质感的图像或视频效果&#xff0c;它模拟了塑料材质的外观特性&#xff0c;包括光照反射、表面凹凸以及光泽效果等。 ◆ ◆ ◆ 效果属性说明 Surface Bump 表面凹凸 通过…

IoT 场景下 TDengine 与老牌时序数据库怎么选?看看这份TSBS报告

上周一&#xff0c;TDengine 正式发布了 IoT 场景下基于 TSBS 的时序数据库&#xff08;Time Series Database&#xff0c;TSDB&#xff09;性能基准测试报告。该报告模拟虚拟货运公司车队中一组卡车的时序数据&#xff0c;预设了五种卡车规模场景&#xff0c;在相同的 AWS 云环…

[Lesson 01] TiDB数据库架构概述

目录 一 章节目标 二 TiDB 体系结构 1 TiDB Server 2.1 TiKV 2.2 TiFlash 3 PD 参考 一 章节目标 理解TiDB数据库整体架构了解TiDB Server TiKV tiFlash 和 PD的主要功能 二 TiDB 体系结构 了解这些体系结构是如何实现TiDB的核心功能的 1 TiDB Server TiDB Serve…

记录--你知道Vue中的Scoped css原理么?

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 追忆Scoped 偶然想起了一次面试&#xff0c;二面整体都聊完了&#xff0c;该做的算法题都做出来了&#xff0c;该背的八股文也背的差不多了&#xff0c;面试官频频点头&#xff0c;似乎对我的基础和项…