基于BERT对中文邮件内容分类

news2025/2/23 20:14:52

用BERT做中文邮件内容分类

    • 项目背景与意义
    • 项目思路
    • 数据集介绍
    • 环境配置
    • 数据加载与预处理
    • 自定义数据集
    • 模型训练
      • 加载BERT预训练模型
      • 开始训练
    • 预测效果

项目背景与意义

本文是《用BERT做中文邮件内容分类》系列的第二篇,该系列项目持续更新中。系列的起源是《使用PaddleNLP识别垃圾邮件》项目,旨在解决企业面临的垃圾邮件问题,通过深度学习方法探索多语言垃圾邮件的内容、标题提取与分类识别。

在本篇文章中,我们使用PaddleNLP的BERT预训练模型,根据提取的中文邮件内容判断邮件是否为垃圾邮件。该项目的思路在于基于前一篇项目的中文邮件内容提取,在98.5%的垃圾邮件分类器基线上,通过BERT的finetune进一步提升性能。
在这里插入图片描述

项目思路

在《使用PaddleNLP识别垃圾邮件(一)》项目的基础上,我们使用BERT进行finetune,力求在LSTM的98.5%的基线上进一步提升准确率。同时,文章中详细介绍了BERT模型的原理和PaddleNLP对BERT模型的应用,读者可以参考项目PaddleNLP2.0:BERT模型的应用进行更深入的了解。

本项目参考了陆平老师的项目应用BERT模型做短文本情绪分类(PaddleNLP 2.0),但由于PaddleNLP版本迭代的原因,进行了相应的调整和说明。

数据集介绍

我们使用了TREC 2006 Spam Track Public Corpora,这是一个公开的垃圾邮件语料库,包括英文数据集(trec06p)和中文数据集(trec06c)。在本项目中,我们仅使用了TREC 2006提供的中文数据集进行演示。数据集来源于真实邮件,保留了邮件的原有格式和内容。

除了TREC 2006外,还有TREC 2005和TREC 2007的英文垃圾邮件数据集,但本项目仅使用了TREC 2006提供的中文数据集。数据集文件目录形式如下:

trec06c
│
├── data
│   │   000
│   │   001
│   │   ...
│   └───215
├── delay
│   │   index
└── full
    │   index

邮件内容样本示例:

负责人您好我是深圳金海实业有限公司...
GG非常好的朋友H在计划马上的西藏自助游...

环境配置

本项目基于Paddle 2.0编写,如果你的环境不是本版本,请先参考官网安装Paddle 2.0。以下是环境配置代码:

# 导入相关的模块
import re
import jieba
import os
import random
import paddle
import paddlenlp as ppnlp
from paddlenlp.data import Stack, Pad, Tuple
import paddle.nn.functional as F
import paddle.nn as nn
from visualdl import LogWriter
import numpy as np
from functools import partial

数据加载与预处理

项目中使用了PaddleNLP的BertTokenizer进行数据处理,该tokenizer可以将原始输入文本转化成模型可接受的输入数据格式。以下是数据加载与预处理的代码:

# 解压数据集
!tar xvf data/data89631/trec06c.tgz

# 去掉非中文字符
def clean_str(string):
    string = re.sub(r"[^\u4e00-\u9fff]", " ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip()

# 从指定路径读取邮件文件内容信息
def get_data_in_a_file(original_path, save_path='all_email.txt'):
    email = ''
    f = open(original_path, 'r', encoding='gb2312', errors='ignore')
    for line in f:
        line = line.strip().strip('\n')
        line = clean_str(line)
        email += line
    f.close()
    return email[-200:]

# 读取标签文件信息
f = open('trec06c/full/index', 'r')
for line in f:
    str_list = line.split(" ")
    if str_list[0] == 'spam':
        label = '0'
    elif str_list[0] == 'ham':
        label = '1'
    text = get_data_in_a_file('trec06c/full/' + str(str_list[1].split("\n")[0]))
    with open("all_email.txt", "a+") as f:
        f.write(text + '\t' + label + '\n')

自定义数据集

在项目中,我们需要自定义数据集,并使其数据格式与使用ppnlp.datasets.ChnSentiCorp.get_datasets加载后完全一致。以下是自定义数据集的代码:

class SelfDefinedDataset(paddle.io.Dataset):
    def __init__(self, data):
        super(SelfDefinedDataset, self).__init__()
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

    def get_labels(self):
        return ["0", "1"]

def txt_to_list(file_name):
    res_list = []
    for line in open(file_name):
        res_list.append(line.strip().split('\t'))
    return res_list

trainlst = txt_to_list('train_list.txt')
devlst = txt_to_list('eval_list.txt')
testlst = txt_to_list('test_list.txt')

train_ds, dev_ds, test_ds = SelfDefinedDataset.get_datasets([trainlst, devlst, testlst])

模型训练

加载BERT预训练模型

项目中使用了PaddleNLP提供的BertForSequenceClassification模型进行文本分类的Fine-tune。由于垃圾邮件识别是二分类问题,所以设置num_classes为2。

以下是加载BERT预训练模型的代码:

# 加载预训练模型
model = ppnlp.transformers.BertForSequenceClassification.from_pretrained("bert-base-chinese", num_classes=2)

开始训练

为了监控训练过程,引入了VisualDL记录训练log信息。以下是开始训练的代码:

# 设置训练超参数
learning_rate = 1e-5
epochs = 10
warmup_proption = 0.1
weight_decay = 0.01

num_training_steps = len(train_loader) * epochs
num_warmup_steps = int(warmup_proption * num_training_steps)

def get_lr_factor(current_step):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    else:
        return max(0.0,
                    float(num_training_steps - current_step) /
                    float(max(1, num_training_steps - num_warmup_steps)))

# 学习率调度器
lr_scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate, lr_lambda=lambda current_step: get_lr_factor(current_step))

# 优化器
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ])

# 损失函数
criterion = paddle.nn.loss.CrossEntropyLoss()

# 评估函数
metric = paddle.metric.Accuracy()

# 训练过程
global_step = 0
with LogWriter(logdir="./log") as writer:
    for epoch in range(1, epochs + 1):    
        for step, batch in enumerate(train_loader, start=1):
            input_ids, segment_ids, labels = batch
            logits = model(input_ids, segment_ids)
            loss = criterion(logits, labels)
            probs = F.softmax(logits, axis=1)
            correct = metric.compute(probs, labels)
            metric.update(correct)
            acc = metric.accumulate()

            global_step += 1
            if global_step % 50 == 0:
                print("global step %d, epoch: %d, batch: %d, loss: %.5f, acc: %.5f" % (global_step, epoch, step, loss, acc))
                writer.add_scalar(tag="train/loss", step=global_step, value=loss)
                writer.add_scalar(tag="train/acc", step=global_step, value=acc)
            
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_gradients()
        
        eval_loss, eval_acc = evaluate(model, criterion, metric, dev_loader)
        writer.add_scalar(tag="eval/loss", step=epoch, value=eval_loss)
        writer.add_scalar(tag="eval/acc", step=epoch, value=eval_acc)

可以看到,在第2个epoch后验证集准确率已经达到99.4%以上,在第3个epoch就能达到99.6%以上。

预测效果

完成模型训练后,我们可以使用训练好的模型对测试集进行预测。以下是预测效果的代码:

data = ['您好我公司有多余的发票可以向外代开,国税,地税,运输,广告,海关缴款书如果贵公司,厂,有需要请来电洽谈,咨询联系电话,罗先生谢谢顺祝商祺']
label_map = {0: '垃圾邮件', 1: '正常邮件'}

predictions = predict(model, data, tokenizer, label_map, batch_size=32)
for idx, text in enumerate(data):
    print('预测内容: {} \n邮件标签: {}'.format(text, predictions[idx]))

预测效果良好,一个验证集准确率高达99.6%以上、基于BERT的中文邮件内容分类顺利完成!

以上是本文的全部内容,希望对读者理解如何使用BERT进行中文邮件内容分类有所帮助。欢迎交流指导!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1401725.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【前端设计】card

欢迎来到前端设计专栏&#xff0c;本专栏收藏了一些好看且实用的前端作品&#xff0c;使用简单的html、css语法打造创意有趣的作品&#xff0c;为网站加入更多高级创意的元素。 html <!DOCTYPE html> <html lang"en"> <head><meta charset&quo…

SQL注入实战:http报文包讲解、http头注入

一&#xff1a;http报文包讲解 HTTP(超文本传输协议)是今天所有web应用程序使用的通信协议。最初HTTP只是一个为获取基于文本的静态资源而开发的简单协议&#xff0c;后来人们以各种形式扩展和利用它.使其能够支持如今常见的复杂分布式应用程序。HTTP使用一种用于消息的模型:客…

SpringBoot项目中集成Kaptcha

1.Kaptcha简介 Kaptcha是一个流行的Java库&#xff0c;用于生成验证码&#xff08;CAPTCHA&#xff09;图片。CAPTCHA是“Completely Automated Public Turing test to tell Computers and Humans Apart”的缩写&#xff0c;通常用于在线表单验证以防止机器人或自动化工具的滥用…

Meta 标签的力量:如何利用它们提高网站的可见性(上)

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

IOT pwn

已经过了填坑的黄金时期 环境搭建 交叉编译工具链 很多开源项目需要交叉编译到特定架构上&#xff0c;因此需要安装对应的交叉编译工具链。 sudo apt install gcc-arm-linux-gnueabi g-arm-linux-gnueabi -y sudo apt install gcc-aarch64-linux-gnu g-aarch64-linux-gnu -…

Dubbo的几个序列化方式

欢迎订阅专栏&#xff0c;会分享Dubbo里面相关的技术实现 这篇文章就不详细的介绍每种序列化方式的实现细节&#xff0c;大家可以自行去问度娘&#xff0c;我也会找一些资料。需要注意的是&#xff0c;这个先后顺序不表示性能优越 ObjectInput、ObjectOutput 这两是Dubbo序列…

看书标记【R语言数据分析项目精解:理论、方法、实战 9】

看书标记——R语言 Chapter 9 文本挖掘——点评数据展示策略9.1项目背景、目标和方案9.1.1项目背景9.1.2项目目标9.1.3项目方案1.建立评论文本质量量化指标2.建立用户相似度模型3.对用户评论进行情感性分析 9.2项目技术理论简介9.2.1评论文本质量量化指标模型1.主题覆盖量2.评论…

电脑可以连接wifi,甚至可以qq聊天,但就是不能用浏览器上网,一直显示未检测出入户网线的解决方案

今天回到家&#xff0c;准备办公却发现电脑可以连接wifi&#xff0c;甚至可以qq聊天&#xff0c;但就是不能用浏览器上网&#xff0c;一直显示未检测出入户网线的解决方案&#xff0c;小白也可以看懂 以下有几种解决方案&#xff0c;不妨都试试&#xff0c;估计可以解决95%的相…

lv14 内核定时器 11

一、时钟中断 硬件有一个时钟装置&#xff0c;该装置每隔一定时间发出一个时钟中断&#xff08;称为一次时钟嘀嗒-tick&#xff09;&#xff0c;对应的中断处理程序就将全局变量jiffies_64加1 jiffies_64 是一个全局64位整型, jiffies全局变量为其低32位的全局变量&#xff0…

web架构师编辑器内容-图层拖动排序功能的开发

新的学习方法 用手写简单方法实现一个功能然后用比较成熟的第三方解决方案即能学习原理又能学习第三方库的使用 从两个DEMO开始 Vue Draggable Next: Vue Draggable NextReact Sortable HOC: React Sortable HOC 列表排序的三个阶段 拖动开始&#xff08;dragstart&#x…

Spring-AOP入门案例

文章目录 Spring-AOP入门案例概念:通知(Advice)切入点(Pointcut )切面&#xff08;Aspect&#xff09; 目标对象(target)代理对象(Proxy)顾问&#xff08;Advisor)连接点(JoinPoint) 简单需求&#xff1a;在接口执行前输出当前系统时间Demo原始未添加aop前1 项目包结构2 创建相…

springCloud的ribbon和feign

ribbon方式调用 就是将原来的具体地址&#xff0c;改为了通过服务名去调用。注册中心中有多个服务&#xff0c;相同服务名&#xff0c;就会算作可以调用的服务。 首先得有一个注册中心&#xff0c;然后各种服务注册进去&#xff0c;然后利用ribbon或者feign去调用。 ribbon是直…

imgaug库图像增强指南(34):揭秘【iaa.Clouds】——打造梦幻般的云朵效果

引言 在深度学习和计算机视觉的世界里&#xff0c;数据是模型训练的基石&#xff0c;其质量与数量直接影响着模型的性能。然而&#xff0c;获取大量高质量的标注数据往往需要耗费大量的时间和资源。正因如此&#xff0c;数据增强技术应运而生&#xff0c;成为了解决这一问题的…

【数据结构与算法】归并排序详解:归并排序算法,归并排序非递归实现

一、归并排序 归并排序是一种经典的排序算法&#xff0c;它使用了分治法的思想。下面是归并排序的算法思想&#xff1a; 递归地将数组划分成较小的子数组&#xff0c;直到每个子数组的长度为1或者0。将相邻的子数组合并&#xff0c;形成更大的已排序的数组&#xff0c;直到最…

Python Timer定时器:控制函数在特定时间执行

Thread类有一个Timer子类&#xff0c;该子类可用于控制指定函数在特定时间内执行一次。例如如下程序&#xff1a; from threading import Timerdef hello():print("hello, world") # 指定10秒后执行hello函数 t Timer(10.0, hello) t.start() 上面程序使用 Timer …

MySQL-B-tree和B+tree区别

B-tree&#xff08;平衡树&#xff09;和Btree&#xff08;平衡树的一种变种&#xff09;是两种常见的树状数据结构&#xff0c;用于构建索引以提高数据库的查询性能。它们在一些方面有相似之处&#xff0c;但也有一些关键的区别。以下是B-tree和Btree的主要区别&#xff1a; …

Macos数据库管理软件:Navicat Premium for Mac 16.3.5中文版

Navicat Premium 16 for Mac是一款强大的数据库管理和开发工具&#xff0c;支持多种数据库系统&#xff0c;如MySQL、Oracle、SQL Server等。它提供了直观的用户界面和丰富的功能&#xff0c;使用户能够轻松地创建、管理和维护数据库。 软件下载&#xff1a;Navicat Premium fo…

【Unity学习笔记】Unity TestRunner使用

转载请注明出处&#xff1a;&#x1f517;https://blog.csdn.net/weixin_44013533/article/details/135733479 作者&#xff1a;CSDN|Ringleader| 参考&#xff1a; Input testingGetting started with Unity Test FrameworkHowToRunUnityUnitTest如果对Unity的newInputSystem感…

不同开发语言在进程、线程和协程的设计差异

不同开发语言在进程、线程和协程的设计差异 1. 进程、线程和协程上的差异1.1 进程、线程、协程的定义1.2 进程、线程、协程的差异1.3 进程、线程、协程的内存成本1.4 进程、线程、协程的切换成本 2. 线程、协程之间的通信和协作方式2.1 python如何实现线程通信&#xff1f;2.2 …

Leetcode3006. 找出数组中的美丽下标 I

Every day a Leetcode 题目来源&#xff1a;3006. 找出数组中的美丽下标 I 解法1&#xff1a;暴力 用 C 标准库中的字符串查找函数&#xff0c;找到 a 和 b 分别在 s 中的起始下标&#xff0c;将符合要求的下标插入答案。 代码&#xff1a; /** lc appleetcode.cn id3006…