Bert中文文本分类

news2024/12/29 7:35:58

这是一个经典的文本分类问题,使用google的预训练模型BERT中文版bert-base-chinese来做中文文本分类。可以先在Huggingface上下载预训练模型备用。https://huggingface.co/google-bert/bert-base-chinese/tree/main

我使用的训练环境是

pip install torch==2.0.0;
pip install transformers==4.30.2;
pip install gensim==4.3.3;
pip install huggingface-hub==0.15.1;
pip install modelscope==1.20.1;

一、准备训练数据

1.1 准备中文文本分类任务的训练数据

这里Demo数据如下:

各银行信用卡挂失费迥异 北京银行收费最高    0
莫泰酒店流拍 大摩叫价或降至6亿美元 4
乌兹别克斯坦议会立法院主席获连任   6
德媒披露鲁能引援关键人物 是他力荐德甲亚洲强人    7
辉立证券给予广汽集团持有评级 2
图文-业余希望赛海南站第二轮 球场的菠萝蜜  7
陆毅鲍蕾:近乎完美的爱情(组图)(2)    9
7000亿美元救市方案将成期市毒药  0
保诚启动210亿美元配股交易以融资收购AIG部门   2

分类class类别文件:

finance
realty
stocks
education
science
society
politics
sports
game
entertainment

1.2 数据读取和截断,使满足BERT模型输入

读取训练数据,对文本进行处理,如截取过长的文本、补齐较短的文本,加上起始标示、对文本进行编码、添加掩码、转为tensor等操作。

import os
from config import parsers
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import torch

from transformers import AutoTokenizer, AutoModelForMaskedLM

def read_data(file):
    # 读取文件
    all_data = open(file, "r", encoding="utf-8").read().split("\n")
    # 得到所有文本、所有标签、句子的最大长度
    texts, labels, max_length = [], [], []
    for data in all_data:
        if data:
            text, label = data.split("\t")
            max_length.append(len(text))
            texts.append(text)
            labels.append(label)
    # 根据不同的数据集返回不同的内容
    if os.path.split(file)[1] == "train.txt":
        max_len = max(max_length)
        return texts, labels, max_len
    return texts, labels,


class MyDataset(Dataset):
    def __init__(self, texts, labels, max_length):
        self.all_text = texts
        self.all_label = labels
        self.max_len = max_length
        self.tokenizer = BertTokenizer.from_pretrained(parsers().bert_pred)
#         self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

    def __getitem__(self, index):
        # 取出一条数据并截断长度
        text = self.all_text[index][:self.max_len]
        label = self.all_label[index]

        # 分词
        text_id = self.tokenizer.tokenize(text)
        # 加上起始标志
        text_id = ["[CLS]"] + text_id

        # 编码
        token_id = self.tokenizer.convert_tokens_to_ids(text_id)
        # 掩码  -》
        mask = [1] * len(token_id) + [0] * (self.max_len + 2 - len(token_id))
        # 编码后  -》长度一致
        token_ids = token_id + [0] * (self.max_len + 2 - len(token_id))
        # str -》 int
        label = int(label)

        # 转化成tensor
        token_ids = torch.tensor(token_ids)
        mask = torch.tensor(mask)
        label = torch.tensor(label)

        return (token_ids, mask), label

    def __len__(self):
        # 得到文本的长度
        return len(self.all_text)

将文本处理后,就可以使用torch.utils.data中自带的DataLoader模块来加载训练数据了。

二、微调BERT模型

我们是微调BERT模型,需要获取BERT最后一个隐藏层的输出作为输入到下一个全连接层。

至于选择BERT模型的哪个输出作为linear层的输入,可以通过实验尝试,或者遵循常理。

pooler_output:这是通过将最后一层的隐藏状态的第一个token(通常是[CLS] token)通过一个线性层和激活函数得到的输出,常用于分类任务。
last_hidden_state:这是模型所有层的最后一个隐藏状态的输出,包含了整个序列的上下文信息,适用于序列级别的任务。

简单调用下BERT模型,打印出来最后一层看下:

import torch
import time
import torch.nn as nn
from transformers import BertTokenizer
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLM

def process_text(text, bert_pred):
    tokenizer = BertTokenizer.from_pretrained(bert_pred)
    token_id = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokenizer.tokenize(text))
    mask = [1] * len(token_id) + [0] * (38 + 2 - len(token_id))
    token_ids = token_id + [0] * (38 + 2 - len(token_id))
    token_ids = torch.tensor(token_ids).unsqueeze(0)
    mask = torch.tensor(mask).unsqueeze(0)
    x = torch.stack([token_ids, mask])
    return x

device = "cpu"
bert = BertModel.from_pretrained('./bert-base-chinese/')
texts = ["沈腾和马丽的电影《独行月球》挺好看"]
for text in texts:
    x = process_text(text, './bert-base-chinese/')
    input_ids, attention_mask = x[0].to(device), x[1].to(device)
    hidden_out = bert(input_ids, attention_mask=attention_mask,
                               output_hidden_states=False) 
    print(hidden_out)

 输出结果:

2.1 文本分类任务,选择使用pooler_output作为线性层的输入。

import torch.nn as nn
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
from config import parsers
import torch

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.args = parsers()
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"  
        self.bert = BertModel.from_pretrained(self.args.bert_pred) 
        # bert 模型进行微调
        for param in self.bert.parameters():
            param.requires_grad = True
        # 一个全连接层
        self.linear = nn.Linear(self.args.num_filters, self.args.class_num)

    def forward(self, x):
        input_ids, attention_mask = x[0].to(self.device), x[1].to(self.device)
        hidden_out = self.bert(input_ids, attention_mask=attention_mask,
                               output_hidden_states=False)  # 是否输出所有encoder层的结果
        # shape (batch_size, hidden_size)  pooler_output -->  hidden_out[0]
        pred = self.linear(hidden_out.pooler_output)
        # 返回预测结果
        return pred

2.2 优化器使用Adam、损失函数使用交叉熵损失函数

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = MyModel().to(device)
opt = AdamW(model.parameters(), lr=args.learn_rate)
loss_fn = nn.CrossEntropyLoss()

三、训练模型

3.1 参数配置

def parsers():
    parser = argparse.ArgumentParser(description="Bert model of argparse")
    parser.add_argument("tx_date",nargs='?') #可选输入参数,计算日期
    parser.add_argument("--train_file", type=str, default=os.path.join("./data_all", "train.txt"))
    parser.add_argument("--dev_file", type=str, default=os.path.join("./data_all", "dev.txt"))
    parser.add_argument("--test_file", type=str, default=os.path.join("./data_all", "test.txt"))
    parser.add_argument("--classification", type=str, default=os.path.join("./data_all", "class.txt"))
    parser.add_argument("--bert_pred", type=str, default="./bert-base-chinese")
    parser.add_argument("--class_num", type=int, default=12)
    parser.add_argument("--max_len", type=int, default=38)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--learn_rate", type=float, default=1e-5)
    parser.add_argument("--num_filters", type=int, default=768)
    parser.add_argument("--save_model_best", type=str, default=os.path.join("model", "all_best_model.pth"))
    parser.add_argument("--save_model_last", type=str, default=os.path.join("model", "all_last_model.pth"))
    args = parser.parse_args()
    return args

3.2 模型训练

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
from sklearn.metrics import accuracy_score
import time


if __name__ == "__main__":
    start = time.time()
    args = parsers()
    
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print("device:", device)
    train_text, train_label, max_len = read_data(args.train_file)
    dev_text, dev_label = read_data(args.dev_file)
    args.max_len = max_len

    train_dataset = MyDataset(train_text, train_label, args.max_len)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    dev_dataset = MyDataset(dev_text, dev_label, args.max_len)
    dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size, shuffle=False)

    model = MyModel().to(device)
    opt = AdamW(model.parameters(), lr=args.learn_rate)
    loss_fn = nn.CrossEntropyLoss()

    acc_max = float("-inf")
    for epoch in range(args.epochs):
        loss_sum, count = 0, 0
        model.train()
        for batch_index, (batch_text, batch_label) in enumerate(train_dataloader):
            batch_label = batch_label.to(device)
            pred = model(batch_text)

            loss = loss_fn(pred, batch_label)
            opt.zero_grad()
            loss.backward()
            opt.step()
            loss_sum += loss
            count += 1

            # 打印内容
            if len(train_dataloader) - batch_index <= len(train_dataloader) % 1000 and count == len(train_dataloader) % 1000:
                msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"
                print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))
                loss_sum, count = 0.0, 0

            if batch_index % 1000 == 999:
                msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"
                print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))
                loss_sum, count = 0.0, 0

        model.eval()
        all_pred, all_true = [], []
        with torch.no_grad():
            for batch_text, batch_label in dev_dataloader:
                batch_label = batch_label.to(device)
                pred = model(batch_text)

                pred = torch.argmax(pred, dim=1).cpu().numpy().tolist()
                label = batch_label.cpu().numpy().tolist()

                all_pred.extend(pred)
                all_true.extend(label)

        acc = accuracy_score(all_pred, all_true)
        print(f"dev acc:{acc:.4f}")
        if acc > acc_max:
            print(acc, acc_max)
            acc_max = acc
            torch.save(model.state_dict(), args.save_model_best)
            print(f"以保存最佳模型")

    torch.save(model.state_dict(), args.save_model_last)

    end = time.time()
    print(f"运行时间:{(end-start)/60%60:.4f} min")

模型保存为:

-rw-rw-r--  1 gaoToby gaoToby 391M Dec 24 14:02 all_best_model.pth
-rw-rw-r--  1 gaoToby gaoToby 391M Dec 24 14:02 all_last_model.pth

四、模型推理预测

准备预测文本文件,加载模型,进行文本的类别预测。


def text_class_name(pred):
    result = torch.argmax(pred, dim=1)
    print(torch.argmax(pred, dim=1).cpu().numpy().tolist())
    result = result.cpu().numpy().tolist()
    classification = open(args.classification, "r", encoding="utf-8").read().split("\n")
    classification_dict = dict(zip(range(len(classification)), classification))
    print(f"文本:{text}\t预测的类别为:{classification_dict[result[0]]}")
    
    
if __name__ == "__main__":
    start = time.time()
    args = parsers()
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model = load_model(device, args.save_model_best)

    texts = ["沈腾和马丽的新电影《独行月球》好看", "最近金融环境不太好,投资需谨慎"]
    print("模型预测结果:")
    for text in texts:
        x = process_text(text, args.bert_pred)
        with torch.no_grad():
            pred = model(x)
        text_class_name(pred)
    end = time.time()
    print(f"耗时为:{end - start} s")

以上,基本流程完成。当然模型还需要调优来改进预测效果的。

代码是实际跑通的,我训练和预测均使用的是GPU。如果是使用GPU做模型训练,再使用CPU做推理预测的情况,推理预测加载模型的时候注意修改下:

 myModel.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

Done

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2267346.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

shardingsphere分库分表项目实践5-自己用java写一个sql解析器+完整项目源码

前1节我们介绍了 shardingsphere 分表分库的sql解析与重写&#xff1a; shardingsphere分库分表项目实践4-sql解析&重写-CSDN博客 那么shardingsphere sql 解析底层究竟是怎么实现的呢&#xff0c;其实它直接用了著名的开源软件 antlr . antlr 介绍&#xff1a; ANTLR&a…

10分钟掌握项目管理核心工具:WBS、甘特图、关键路径法全解析

一、引言 在项目管理的广阔天地里&#xff0c;犹如一场精心编排的交响乐演奏&#xff0c;每个乐器、每个音符都需精准配合才能奏响美妙乐章。而 WBS&#xff08;工作分解结构&#xff09;、甘特图、关键路径法无疑是这场交响乐中的关键乐章&#xff0c;它们从不同维度为项目管…

【LLM】OpenAI 的DAY12汇总和o3介绍

note o3 体现出的编程和数学能力&#xff0c;不仅达到了 AGI 的门槛&#xff0c;甚至摸到了 ASI&#xff08;超级人工智能&#xff09;的边。 Day 1&#xff1a;o1完全版&#xff0c;开场即巅峰 12天发布会的开场即是“炸场级”更新——o1完全版。相比此前的预览版本&#x…

使用Kubernetes部署MySQL+WordPress

目录 前提条件 部署MySQL和WordPress 编写yaml文件 应用yaml文件 存在问题及解决方案 创建PV(持久化卷) 创建一个PVC(持久化卷声明) 部署添加PVC 查看PV对应的主机存储 删除资源 查看资源 删除deployment和service 查看主机数据 删除PVC和PV 删除主机数据 前提条…

RabbitMQ中的异步Confirm模式:提升消息可靠性的利器

在现代分布式系统中&#xff0c;消息队列&#xff08;Message Queue&#xff09;扮演着至关重要的角色&#xff0c;它能够解耦系统组件、提高系统的可扩展性和可靠性。RabbitMQ作为一款广泛使用的消息队列中间件&#xff0c;提供了多种机制来确保消息的可靠传递。其中&#xff…

sentinel限流+其他

quick-start | Sentinel sentinel 作用 限流 熔断降级 1&#xff0c;限制什么 QPS 并发线程数 2&#xff0c;限制什么 资源&#xff0c;什么资源 服务&#xff0c;方法&#xff0c;接口&#xff0c;或者一段代码 3&#xff0c;实现方式 配置规则 注解 其他 Java常见5种限流…

Ubuntu 中安装 RabbitMQ 教程

简介 RabbitMq作为一款消息队列产品&#xff0c;它由Erlang语言开发&#xff0c;实现AMQP&#xff08;高级消息队列协议&#xff09;的开源消息中间件。 应用场景 异步处理 场景说明&#xff1a;用户注册后&#xff0c;注册信息写入数据库&#xff0c;再发邮件、短信通知。 …

Spark生态圈

Spark 主要用于替代Hadoop中的 MapReduce 计算模型。存储依然可以使用 HDFS&#xff0c;但是中间结果可以存放在内存中&#xff1b;调度可以使用 Spark 内置的&#xff0c;也可以使用更成熟的调度系统 YARN 等。 Spark有完善的生态圈&#xff1a; Spark Core&#xff1a;实现了…

AT24C02学习笔记

看手册&#xff1a; AT24Cxx xx代表能写入xxK bit(xx K)/8 byte 内部写周期很关键&#xff0c;代表每一次页写或字节写结束后时间要大于5ms&#xff08;延时5ms确保完成写周期&#xff09;&#xff0c;否则时序会出错。 页写&#xff1a;型不同号每一页可能写入不同大小的…

119.【C语言】数据结构之快速排序(调用库函数)

目录 1.C语言快速排序的库函数 1.使用qsort函数前先包含头文件 2.qsort的四个参数 3.qsort函数使用 对int类型的数据排序 运行结果 对char类型的数据排序 运行结果 对浮点型数据排序 运行结果 2.题外话:函数名的本质 1.C语言快速排序的库函数 cplusplus网的介绍 ht…

五模型对比!Transformer-GRU、Transformer、CNN-GRU、GRU、CNN五模型多变量时间序列预测

目录 预测效果基本介绍程序设计参考资料 预测效果 基本介绍 光伏功率预测&#xff01;五模型对比&#xff01;Transformer-GRU、Transformer、CNN-GRU、GRU、CNN五模型多变量时间序列预测(Matlab2023b 多输入单输出) 1.程序已经调试好&#xff0c;替换数据集后&#xff0c;仅运…

利用Dockerfile构建自定义镜像

当一个系统开发完成&#xff0c;需要将系统打包为一个镜像文件&#xff0c;让docker能够运行该镜像&#xff0c;成为一个可以被访问的容器。 上述操作可以通过自定义镜像的方式来实现&#xff0c;本文章基于VMware虚拟机中安装的Centos7操作系统来完成。前面的操作步骤&#x…

喜报 | 擎创科技入围上海市优秀信创解决方案

近日&#xff0c;由上海市经信委组织的“2024年上海市优秀信创解决方案”征集遴选活动圆满落幕&#xff0c;擎创科技凭借实践经验优秀的《擎创夏洛克智能预警与应急处置解决方案》成功入选“2024年上海市优秀信创解决方案”名单。 为激发创新活力&#xff0c;发挥标杆作用&…

基于aspose.words组件的word bytes转pdf bytes,去除水印和解决linux中文乱码问题

详情见 https://preferdoor.top/archives/ji-yu-aspose.wordszu-jian-de-word-byteszhuan-pdf-bytes

快速排序学习优化

首先&#xff0c;上图。 ‘’’ cpp int partSort(int *a ,int left,int right) {int keyi left; //做左侧基准while(left<right){while(left<right && a[right]>a[keyi]){right--;}while(left<right && a[left]<a[keyi]){left;}swap(a[left…

搭建vue项目

一、环境准备 1、安装node node官网&#xff1a;https://nodejs.org/zh-cn 1.1、打开官网&#xff0c;选择“下载”。 1.2、选择版本号&#xff0c;选择系统&#xff0c;根据需要自行选择&#xff0c;上面是命令安装方式&#xff0c;下载是下载安装包。 1.3、检查node安装…

华为管理变革之道:管理制度创新

目录 华为崛起两大因素&#xff1a;管理制度创新和组织文化。 管理是科学&#xff0c;150年来管理史上最伟大的创新是流程 为什么要变革&#xff1f; 向世界标杆学习&#xff0c;是变革第一方法论 体系之一&#xff1a;华为的DSTE战略管理体系&#xff08;解决&#xff1a…

ASP-CMS漏洞

打开aspcms靶场 账号&#xff1a;admin 密码&#xff1a;123456 去保存抓包 在slideTextStatus1后面写上%25><%25eval(request(chr(65)))%25><%25 我们在去访问这个文件config/AspCms_Config.asp再去蚁剑连接&#xff0c;连接成功

pyqt和pycharm环境搭建

安装 python安装&#xff1a; https://www.python.org/downloads/release/python-3913/ python3.9.13 64位(记得勾选Path环境变量) pycharm安装&#xff1a; https://www.jetbrains.com/pycharm/download/?sectionwindows community免费版 换源&#xff1a; pip config se…

微服务-1 认识微服务

目录​​​​​​​ 1 认识微服务 1.1 单体架构 1.2 微服务 1.3 SpringCloud 2 服务拆分原则 2.1 什么时候拆 2.2 怎么拆 2.3 服务调用 3. 服务注册与发现 3.1 注册中心原理 3.2 Nacos注册中心 3.3 服务注册 3.3.1 添加依赖 3.3.2 配置Nacos 3.3.3 启动服务实例 …