Bert模型实现中文新闻文本分类

news2024/12/28 19:11:34

        Bert基于Transformer架构是解决自然语言处理的深度学习模型,常使用在文本分类、情感分析、词性标注等场合。

        本文将使用Bert模型对中文文本进行分类,其中训练集数据18W条,验证集数据1W条,包含10个类别的文本数据,数据可以自己从Kaggel上下载。

        

中文新闻标题类别标签类别名
锌价难续去年辉煌0金融
金科西府 名墅天成1房地产
同步A股首秀:港股缩量回调2经济
状元心经:考前一周重点是回顾和整理3教育
一年网事扫荡10年纷扰开心网李鬼之争和平落幕4科技
60年铁树开花形状似玉米芯(组图)5社会
发改委治理涉企收费每年为企业减负超百亿6国际
布拉特:放球员一条生路吧 FIFA能消化俱乐部的攻击7体育
体验2D巅峰 倚天屠龙记十大创新概览8游戏
Rain入伍前最后开唱 本周六“雨”润京城(图)9娱乐

分类模型的结构比较简单,示意图如下:

Dataset是我们用的数据集的库,是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示。其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数。

DataLoader是PyTorch提供的一个数据加载器,它可以将数据分成小批次进行加载,并自动完成数据的批量加载、随机洗牌、并发预取等操作。在神经网络的训练过程中,我们通常需要处理大量的数据。如果一次性将所有数据加载到内存中,不仅会消耗大量的内存资源,还可能导致程序运行缓慢甚至崩溃。因此,我们需要一种机制来将数据分成小批次进行加载,而DataLoader正是为了满足这一需求而诞生的。

#首先导入需要用到的数据包

from transformers import BertModel, BertTokenizer
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader
from torch import optim
import os

class BertClassifier(nn.Module):
    def __init__(self, bert_model, output_size):
        super(BertClassifier, self).__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(bert_model.config.hidden_size, output_size)

    def forward(self, input_ids, attention_mask):
        # 获取BERT模型的CLS输出
        text_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)  
        #得到线性层的结果
        logits=self.classifier(text_output.pooler_output)
        return logits


#读取数据
class data_load(Dataset):
    def __init__(self,path):
        self.data=list()
        file=open(path,'r',encoding='utf-8')
        for line in file:
            text,label=line.strip().split('\t')
            self.data.append((text,int(label)))
        file.close()
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        return self.data[index]


#用于dataloader,对于每个小批量的数据,进行分词和填充
def collate_fn(batch,tokenizer):
    texts=[text[0] for text in batch]
    labels=[text[1] for text in batch]
    labels=torch.tensor(labels,dtype=torch.long)
    tokens=tokenizer(
                texts,
                add_special_tokens=True,
                max_length=512,
                padding=True,
                truncation=True,
                return_tensors='pt',
                )
    return tokens['input_ids'],tokens['attention_mask'],labels

if __name__=="__main__":
    dataset=data_load('./train.txt')

    print(len(dataset))
    
    output :180000
            
 
    #加载模型,生成分词器
    tokenizer=BertTokenizer.from_pretrained('bert-base-chinese')
    bert_model = BertModel.from_pretrained('bert-base-chinese')

    #dataset:要加载的数据集对象,必须是实现了len()和getitem()方法的对象
    data_loader=DataLoader(dataset,
                           batch_size=128,
                           shuffle=True,
                           collate_fn=lambda x:collate_fn(x,tokenizer))
    
    # 指定机器
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    #打印分词器支持的最大长度,输入的中文数据不能超过512
    #如果进行长文本分类,需要进行文本截断或分块处理
    # print(tokenizer.model_max_length)
    

    #定义bertclassifier模型为10分类
    model=BertClassifier(bert_model,output_size=10).to(device)
    model.train()
    #优化器
    optimizer=optim.AdamW(model.parameters(),lr=5e-5)
    #交叉熵损失误差
    criterion=nn.CrossEntropyLoss()
    #存放模型
    os.makedirs('output_models',exist_ok=True)
    epoch_n=10
    for epoch in range(1,epoch_n+1):
        for batch_index,data in enumerate(data_loader):
            input_ids=data[0].to(device)
            attention_mask=data[1].to(device)
            label=data[2].to(device)
            #清空梯度
            optimizer.zero_grad()
            #前向传播
            output=model(input_ids,attention_mask)
            loss=criterion(output,label)
            loss.backward() #计算梯度
            optimizer.step()  #更新模型参数
            
            #计算正确率,用于观察模型结果
            predict=torch.argmax(output,dim=1)
            correct=(predict==label).sum().item()
            acc=correct/output.size(0)
            print(f"Epoch {epoch}/{epoch_n}") #迭代轮数
            print(f"Batch {batch_index+1}/{len(data_loader)}") 
            print(f"Loss: {loss.item():.4f}") #损失
            print((f"Acc {correct}/{output.size(0)}=={acc:.3f}")) #正确率
            #每一次迭代都保存一次模型结果
            model_name=f'./output_models/chinese_news_classify{epoch}.pth'
            print("saved model: %s" % (model_name))
            torch.save(model.state_dict(),model_name)

可以看到随着训练的进行,模型的准确率越来越高。由于数据量和机器内存原因,训练的时间比较长,就没有全部跑完。

Epoch 1/10
Batch 59/1407
Loss: 0.4286
Acc 113/128==0.883
saved model: ./output_models/chinese_news_classify1.pth
Epoch 1/10
Batch 60/1407
Loss: 0.4399
Acc 114/128==0.891
saved model: ./output_models/chinese_news_classify1.pth
Epoch 1/10
Batch 61/1407
Loss: 0.5028
Acc 109/128==0.852
saved model: ./output_models/chinese_news_classify1.pth
Epoch 1/10
Batch 62/1407
Loss: 0.3180
Acc 120/128==0.938

使用训练好的模型预测中文文本

from kaggel_chinese_text import  BertClassifier
from transformers import BertModel, BertTokenizer
import torch

test_text='铁血铸辉煌 天骄3公会战唤起新激情'
bert_model = BertModel.from_pretrained('bert-base-chinese')
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model=BertClassifier(bert_model,10)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load('./output_models/chinese_news_classify1.pth',map_location=device))
model.to(torch.device(device))
model.eval()
inputs = tokenizer.encode_plus(
    test_text,
    add_special_tokens=True,
    max_length=128,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)
input_ids = inputs['input_ids']
# print("shape of inut_ids:",input_ids.shape)
attention_mask = inputs['attention_mask']
with torch.no_grad():
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    outputs = model(input_ids,attention_mask)
    _, predicted = torch.max(outputs, 1)
print(predicted.item())

#能正确预测文本属于游戏类型
output: 8

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1842948.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大润发超市购物卡怎么用?

收到大润发超市的礼品卡以后,我才发现,最近的大润发也得十来公里 为了100块的大润发打车也太不划算了 叫外送也不在配送范围内 最后没办法,在收卡云上出掉了,还好最近价格不错,也不亏,收卡云的到账速度也…

leetcode:557. 反转字符串中的单词 III(python3解法)

难度:简单 给定一个字符串 s ,你需要反转字符串中每个单词的字符顺序,同时仍保留空格和单词的初始顺序。 示例 1: 输入:s "Lets take LeetCode contest" 输出:"steL ekat edoCteeL tsetnoc…

使用飞书多维表格实现推送邮件

一、为什么用飞书? 在当今竞争激烈的商业环境中,选择一款高效、智能的办公工具至关重要。了解飞书的朋友应该都知道,飞书的集成能力是很强大的,能够与各种主流的办公软件无缝衔接,实现数据交互,提升工作效…

恒创科技:云主机上的数据安全如何保证?(实用性技巧分享)

云主机上的数据安全如何保证?答案很简单,虽很多用户却不能完全做到,但我们可以了解一些安全措施予以防范。以下是云主机数据保护的几个实用技巧,希望对您有所帮助! 1.避免将敏感信息存储在云中 网络上的许多建议听起来都像这样:“…

业余时间做跨境电商实现经济自由,我是怎么做的?

在知乎问答上翻阅大家非常感兴趣的问题,解答一些疑惑的同时,发现大家对跨境电商还是很感兴趣的,类似“小白如何入局跨境电商?2024跨境电商平台,哪些值得做?现在电商哪个平台好做?”等的这些主观问题&#…

ubuntu访问windows共享文件夹

方法: Ubuntu访问Windows共享文件夹的方法-CSDN博客 基于交换机的PC端网络通信_服务器交换机pc端-CSDN博客 补充说明: 在这里面输入: smb://192.168.0.30/WindowsShareToLinux

虚拟机Ping不通主机

1.问题描述 虚拟机IP: 192.168.3.133 主机ip:192.168.3.137 虚拟机Ping不通主机 主机可以ping通虚拟机 2.解决方案 设置桥接模式 控制面板找到网络和Internet设置 3.问题解决

Leetcode - 周赛401

目录 一,3178. 找出 K 秒后拿着球的孩子 二,3179. K 秒后第 N 个元素的值 三,3180. 执行操作可获得的最大总奖励 I 四,3181. 执行操作可获得的最大总奖励 II 一,3178. 找出 K 秒后拿着球的孩子 本题可以直接模拟&a…

CesiumJS整合ThreeJS插件封装

最近做项目有一个三维需求使用CesiumJS比较难以实现,发现THREEJS中效果比较合适,于是准备将THREEJS整合到CesiumJS中 为实现效果所需我们找到官方Integrating Cesium with Three.js博客,于是根据该博客提供的思路去实现整合 文章目录 一、创…

VMware虚拟机三种网络模式设置 - NAT(网络地址转换模式)

一、前言 在前一篇《Bridged(桥接模式)》中,我详细介绍了虚拟机网络模式设置中的桥接模式。今天详细讲解一下NAT(网络地址转换模式)。 在虚拟机(VM)中,NAT(Network Addre…

微信小程序navigateTo异常(APP-SERVICE-SDK:Unknown URL)

背景 在开发小程序时,可能会用到banner,通过banner跳转至各种子页面。但是因为小程序自身的因素,有些是不允许的,比如通过banner跳转一个http/https链接。如果使用 wx.navigateTo完成跳转时,就会发生异常。 navigate…

Latex添加参考文献的两种方案

Latex添加参考文献的两种方案 方案1:一般插入法方案2:使用BibTex 方案1:一般插入法 此方案在latex结尾直接插入参考文献,一般从IEEE官网下载的模板好像默认都是这样的!下面为参考格式: 这种方案比较容易操…

产品心理学:曝光效应

曝光效应(the exposure effect or the mere exposure effect):又谓多看效应、(简单、单纯)暴露效应、(纯粹)接触效应等等。 它是一种心理现象,指的是我们会偏好自己熟悉的事物&#…

JVM中的垃圾回收机制

文章目录 什么是垃圾为什么需要垃圾回收早期垃圾回收Java的垃圾回收机制垃圾回收主要关注的区域垃圾判定算法引用计数算法可达性分析算法 垃圾收集算法标记清除算法复制算法标记整理算法分代收集思想增量收集算法分区算法 什么是垃圾 垃圾回收(Garbage Collection&…

2024-06月 | 维信金科 | 风控数据岗位推荐,高收入岗位来袭!

今日推荐岗位:策略分析经理/分析专家、贷前、中策略分析、风控模型分析。 风控部门是金融业务的核心部门,而从事风控行业的人即称之为风险管理者。是大脑,是最最最重要的部门之一。今日推荐岗位的核心技能分布如下: 简历发送方式…

磁盘未格式化:深度解析、恢复策略与预防措施

一、磁盘未格式化的定义与现象 在计算机存储领域,磁盘未格式化通常指的是磁盘分区或整个磁盘的文件系统信息出现丢失或损坏的情况,导致操作系统无法正确读取和识别磁盘上的数据。当尝试访问这样的磁盘时,系统往往会弹出一个警告框&#xff0…

001 Spring介绍

文章目录 特点1.方便解耦,简化开发2.AOP编程的支持3.声明式事务的支持4.方便程序的测试5.方便集成各种优秀框架6.降低Java EE API的使用难度7.Java源码是经典学习范例 好处什么是耦合和内聚耦合性,也叫耦合度,是对模块间关联程度的度量内聚标…

蓝鹏测控公司全长直线度算法项目多部门现场组织验收

关键字:全场直线度算法,直线度测量仪,直线度检测,直线度测量设备, 6月18日上午,蓝鹏测控公司全长直线度算法项目顺利通过多部门现场验收。该项目由公司技术部、开发部、生产部等多个部门共同参与,旨在提高直线度测量精度,满足高精度制造领域需…

ppt转换word文档怎么操作?6个软件让你自己轻松转换文件

ppt转换word文档怎么操作?6个软件让你自己轻松转换文件 将PPT文件转换为Word文档是一项常见的任务,可以通过多种软件和在线工具来实现。以下是六款常用的软件和工具,它们可以帮助您轻松地将PPT文件转换为Word文档: 1.迅捷PDF转换…

Windows上PyTorch3D安装踩坑记录

直入正题,打开命令行,直接通过 pip 安装 PyTorch3D : (python11) F:\study\2021-07\python>pip install pytorch3d Looking in indexes: http://mirrors.aliyun.com/pypi/simple/ ERROR: Could not find a version that satisfies the requirement p…