bert新闻标题分类

news2024/11/18 11:46:16

使用 bert 完成文本分类任务,数据有 20w,来自https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch/tree/master/THUCNews

 下载即可:

模型使用 bert-base-chinese 下载参考:bert预训练模型下载-CSDN博客

实现了新闻分类,小编在这做个笔记,整个流程也就是对 bert 模型的应用,写了注释,方便学习查看,把代码放这里记录一下:

import os
import torch
from transformers import (
    get_linear_schedule_with_warmup,BertTokenizer,
    AdamW,
    AutoModelForSequenceClassification,
    AutoConfig
)
from torch.utils.data import DataLoader, dataset
import time
import numpy as np
from sklearn import metrics
from datetime import timedelta


data_dir = 'THUCNews/data'
# 在代码开始部分添加全局变量和函数
global_batch_size = 4  # 初始化为固定的batch_size
max_batch_size = 32  # 根据实际情况设置最大允许的batch_size


def get_optimal_batch_size():
    global global_batch_size
    # 这里是检查GPU可用内存并尝试增大批次大小的逻辑
    # 具体实现可能需要根据您的设备和任务进行调整
    # 以下仅为模拟示例,实际操作时请替换为正确的方法
    free_memory = torch.cuda.memory_allocated() / (1024**3)  # 获取当前GPU空闲显存(单位:GB)
    optimal_bs = min(max_batch_size, int(free_memory * 0.8))  # 按80%的空闲显存分配批次大小
    if optimal_bs > global_batch_size:
        global_batch_size = optimal_bs
    return global_batch_size


def read_file(path):
    with open(path, 'r', encoding="UTF-8") as file:
        docus = file.readlines()
        newDocus = []
        for data in docus:
            newDocus.append(data)
    return newDocus


class Label_Dataset(dataset.Dataset):  # 建立自定义数据集
    def __init__(self, data):
        self.data = data

    def __len__(self):  # 返回数据长度
        return len(self.data)

    def __getitem__(self, ind):
        onetext = self.data[ind]
        content, label = onetext.split('\t')
        label = torch.LongTensor([int(label)])
        return content, label


# 读取数据内容
trainContent = read_file(os.path.join(data_dir, "train.txt"))
testContent = read_file(os.path.join(data_dir, "test.txt"))
# 封成数据类型
traindataset = Label_Dataset(trainContent)
testdataset = Label_Dataset(testContent)
# 封装成数据加载器
testdataloder = DataLoader(testdataset, batch_size=1, shuffle=False)
batch_size = 1
traindataloder = DataLoader(traindataset, batch_size=get_optimal_batch_size(), shuffle=True)
# 加载器类别名称
class_list = [x.strip() for x in open(
    os.path.join(data_dir, "class.txt")).readlines()]

# 模型名称
pretrained_weights = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(pretrained_weights)
config = AutoConfig.from_pretrained(pretrained_weights, num_labels=len(class_list))
# 单独指定config,在config中指定分类个数
# 因为是分类任务,用 AutoModelForSequenceClassification
nlp_classif = AutoModelForSequenceClassification.from_pretrained(pretrained_weights,
                                                                 config=config)
# 指定机器
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # 可能 gpu 显存不够
device = torch.device("cpu")
nlp_classif = nlp_classif.to(device)


time_start = time.time() #开始时间
epochs = 2
gradient_accumulation_steps = 1
max_grad_norm =0.1  #梯度剪辑的阀值

require_improvement = 1000                 # 若超过1000batch效果还没提升,则提前结束训练
savedir = './myfinetun-bert_chinese/'
os.makedirs(savedir, exist_ok=True)


def get_time_dif(start_time):
    """获取已使用时间"""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))


def train(model, traindataloder, testdataloder):
    """
    开始训练
    :param model:
    :param traindataloder:
    :param testdataloder:
    :return:
    """
    start_time = time.time()
    # 在训练模式下,模型会启用如dropout和batch normalization这样的正则化技术。
    model.train()
    # 获取模型中所有可训练参数及其名称。这样可以方便地对不同类型的参数应用不同的优化策略
    param_optimizer = list(model.named_parameters())
    # 不需要权重衰减(weight decay/L2正则化)的参数名称部分。通常包括偏置项(bias)和LayerNorm层中的偏差与权重参数
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # 创建了两个参数组,分别对需要和不需要权重衰减的参数应用不同的权重衰减率。第一组设置了0.01的权重衰减,第二组不进行权重衰减
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
    # :使用AdamW优化器初始化模型参数。AdamW是对Adam优化器的一个改进版本,它确保了权重衰减在梯度更新之前被正确应用。
    # 这里的lr表示学习率,设置为5e-5;eps是Adam算法中的一个稳定系数,设置为1e-8
    optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)

    # 创建一个线性学习率调度器,并带有预热阶段(warmup)。这里没有设置预热步数(num_warmup_steps),
    # 意味着没有预热阶段;num_training_steps 设置为整个训练过程中迭代的总步数,
    # 即训练数据加载器循环次数乘以轮数(epochs)。随着训练的进行,学习率将按照预先设定的方式逐渐降低,从而有助于模型收敛并防止过拟合。
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0, num_training_steps=len(traindataloder) * epochs)


    total_batch = 0  # 记录进行到多少batch
    dev_best_loss = float('inf')
    last_improve = 0  # 记录上次验证集loss下降的batch数
    flag = False  # 记录是否很久没有效果提升

    for epoch in range(epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, epochs))
        # sku_name代表文本序列,labels代表对应的类别标签。
        for i, (sku_name, labels) in enumerate(traindataloder):
            model.train()
            # 使用BERT分词器对文本序列进行编码,并根据需要补全至最大长度,同时将结果转换为PyTorch张量格式
            ids = tokenizer.batch_encode_plus(sku_name,
                                               #                max_length=model.config.max_position_embeddings,  #模型的配置文件中就是512,当有超过这个长度的会报错
                                               pad_to_max_length=True, return_tensors='pt')#没有return_tensors会返回list!!!!
            # 清零优化器中的梯度累计信息,准备进行新的反向传播过程
            optimizer.zero_grad()
            # 将标签数据从CPU转移到指定设备(如GPU)上,并去除可能存在的额外维度
            labels = labels.squeeze().to(device)
            # 将编码后的输入ID、标签和注意力掩码传入模型进行前向传播计算,得到损失和其他输出
            outputs = model(ids["input_ids"].to(device), labels=labels,
                            attention_mask=ids["attention_mask"].to(device))
            # 从模型返回的结果中提取损失值和logits(未归一化的预测概率)
            loss, logits = outputs[:2]
            # 如果设置了梯度累积步骤大于1,则需要将损失除以这个值,这样在多个小批次上累积更新梯度
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            # 计算梯度并反向传播到模型参数
            loss.backward()
            # 每经过gradient_accumulation_steps次迭代后
            if (i + 1) % gradient_accumulation_steps == 0:
                # 对模型所有参数的梯度进行裁剪,防止梯度过大导致训练不稳定
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()  # 应用梯度更新模型参数。
            scheduler.step()  # 更新学习率调度器,根据当前训练步数调整学习率
            model.zero_grad()  # 再次清零梯度,为下一批次的训练做准备

            '''评估模型在训练集和验证集上的性能,并根据验证集上的表现做出相应的决策:'''
            if total_batch % 100 == 0:
                # 每多少轮输出在训练集和验证集上的效果
                truelabel = labels.data.cpu()  # 真实类别
                predic = torch.argmax(logits,axis=1).data.cpu()  # 预测类别
                #                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(truelabel, predic)  # 比较
                dev_acc, dev_loss = evaluate(model, testdataloder)  # 计算验证集上的准确率和损失值
                if dev_loss < dev_best_loss:
                    '''比较当前验证集损失与历史最优验证集损失(dev_best_loss),如果当前损失更低,
                    则更新最优损失并保存模型至预设路径(savedir),同时记录最后一次改善的批次索引(last_improve)'''
                    dev_best_loss = dev_loss
                    model.save_pretrained(savedir)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                # 输出当前迭代次数、训练损失、训练准确率、验证损失、验证准确率以及已用时间
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                model.train()
            total_batch += 1  # 增加累计批次计数器
            '''如果自上次验证集损失下降以来已经过去了超过指定数量的批次(这里是1000批次),
            并且在这期间验证集损失未再降低,则自动停止训练,打印提示信息,并跳出循环。'''
            if total_batch - last_improve > require_improvement:
                # 验证集loss超过1000batch没下降,结束训练
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break


def evaluate(model, testdataloder):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for sku_name, labels in testdataloder:
            ids = tokenizer.batch_encode_plus( sku_name,
                                               #                max_length=model.config.max_position_embeddings,  #模型的配置文件中就是512,当有超过这个长度的会报错
                                               pad_to_max_length=True,return_tensors='pt')#没有return_tensors会返回list!!!!

            labels = labels.squeeze().to(device)
            outputs = model(ids["input_ids"].to(device), labels=labels,
                            attention_mask =ids["attention_mask"].to(device) )

            loss, logits = outputs[:2]
            loss_total += loss
            labels = labels.data.cpu().numpy()
            predic = torch.argmax(logits, axis=1).data.cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)
    acc = metrics.accuracy_score(labels_all, predict_all)
    return acc, loss_total / len(testdataloder)


train(nlp_classif, traindataloder, testdataloder)

代码输出结果会生成一个文件夹:myfinetun-bert_chinese 里面存放的是模型,最后会生成一个 best 模型,我这里没跑完哈,所以结果不全

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1417754.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

无人机在三维空间中的转动问题

前提 这篇博客是对最近一个有关无人机拍摄图像项目中所学到的新知识的一个总结&#xff0c;比较杂乱&#xff0c;没有固定的写作顺序。 无人机坐标系旋转问题 上图是无人机坐标系&#xff0c;绕x轴是翻滚(Roll)&#xff0c;绕y轴是俯仰(Pitch)&#xff0c;绕z轴是偏航(Yaw)。…

微信小程序(二十三)获取页面栈及当前页面实例

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.页面栈的定义 2.获取当前页面实例 页面栈 当我们从A页面跳到B页面再跳到C页面时&#xff0c;页面栈则是由三个页面的实例组成的数组&#xff0c;A在下标为0的数组中&#xff0c;C在下标为2的数组中 当然&#…

【JaveWeb教程】(30)SpringBootWeb案例之《智能学习辅助系统》的详细实现步骤与代码示例(3)员工管理的实现

目录 SpringBootWeb案例033. 员工管理3.1 分页查询3.1.1 基础分页3.1.1.1 需求分析3.1.1.2 接口文档3.1.1.3 思路分析3.1.1.4 功能开发3.1.1.5 功能测试3.1.1.6 前后端联调 3.1.2 分页插件3.1.2.1 介绍3.1.2.2 代码实现3.1.2.3 测试 3.2 分页查询(带条件)3.2.1 需求3.2.2 思路分…

[答疑]张学友和Neal Ford的区别

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 albert 2024-1-22 16:43 如果base合适&#xff0c;不出差也可以演讲。比如base北京&#xff0c;每周在北京讲3到4场&#xff0c;一年150到200,十五二十年就累积够了。 我看新闻&…

【Sql Server】新手一分钟看懂在已有表基础上增加字段和说明

欢迎来到《小5讲堂》&#xff0c;大家好&#xff0c;我是全栈小5。 这是《Sql Server》系列文章&#xff0c;每篇文章将以博主理解的角度展开讲解&#xff0c; 特别是针对知识点的概念进行叙说&#xff0c;大部分文章将会对这些概念进行实际例子验证&#xff0c;以此达到加深对…

【时间安排】

最近刚刚回到家&#xff0c;到家就是会有各种事情干扰&#xff0c;心里变乱人变懒的&#xff0c;而要做的事情也要继续&#xff0c;写论文&#xff0c;改简历&#xff0c;学习新技能。。 明天后天两天写论文改简历 周一&#xff08;早上去城市书房&#xff0c;可能吵一点戴个耳…

linux 下scrcpy 手机投屏到电脑,QT+ffmpeg 获取视频流,处理等等

linux 下scrcpy 手机投屏到电脑,QT+ffmpeg 获取视频流,处理 1 安装 scrcpy 地址 https://github.com/Genymobile/scrcpy 转到 relese 下载 我这里下载的是linux系统 v2.3.1 版本 scrcpy-2.3.1.tar.gz 下载 scrcpy-server v2.3.1 版本 scrcpy-server-v2.3.1 解压scrcpy-2.3…

uni-app 开发

一、uni-app 简介 uni-app 是一个使用 Vue.js 开发所有前端应用的框架 。开发者编写一套代码&#xff0c;可发布到 iOS 、 Android 、 H5 、以及各种小程序&#xff08;微信 / 支付宝 / 百度 / 头条 /QQ/ 钉钉 / 淘宝&#xff09;、快应用等多个平台。 详细的 uni-app 官方…

mysql注入联合查询

环境搭建 下载复现漏洞的包 下载小皮面板 将下载好的文件解压在小皮面板的phpstudy_pro\WWW路径下 将这个文件phpstudy_pro\WWW\sqli-labs-php7-master\sql-connections\db-creds.inc 中的密码更改为小皮面板中的密码 选择php版本 在小皮中启动nginx和数据库 使用环回地址访…

python爬虫demo——爬取历史平均房价

简单爬取历史房价 需求 爬取的网站汇聚数据的城市房价 https://fangjia.gotohui.com/ 功能 选择城市 https://fangjia.gotohui.com/fjdata-3 需要爬取年份的数据&#xff0c;等等 https://fangjia.gotohui.com/years/3/2018/ 使用bs4模块 使用bs4模块快速定义需要爬取的…

c#之构值类型和引用类型

值类型:(整数/bool/struct/char/小数) 引用类型:(string/ 数组 / 自定义的类 / 内置的类) 值类型只需要一段单独的内存,用于存储实际的数据 引用类型需要两段内存(第一段存储实际的数据,他总是位于 堆中第二段是一个引用,指向数据在堆中的存放位置) 当使用引用类型赋值的时…

AV Foundation 视频播放中的可视拖拽进度条

引言 在视频播放软件中&#xff0c;通过拖拽进度条来调整播放进度几乎已成为不可或缺的功能。这一功能使用户能够精确指定视频播放的时间点。近年来&#xff0c;视频播放器在原有的拖拽进度条基础上进行了更加人性化的性能提升&#xff0c;引入了可视化拖拽条。这一创新为用户…

2023年CSDN年终总结:长风破浪会有时,风物长宜放眼量

目录 0 回首20231 打造垂类专栏2 个人技术成长3 首发SCI期刊4 生活中的美好5 新年新flag 0 回首2023 这是去年flag的完成情况&#xff0c;很惊喜地发现全部顺利完成了。 CSDN坚持垂类写作&#xff0c;完结机器学习和ROS机器人专栏&#xff0c;开启深度学习新篇章 粉丝数希望突…

TS学习笔记十二:项目配置

本节介绍ts项目配置相关内容&#xff0c;包括项目配置文件tsconfig.json的说明及编译选项的内容介绍。 讲解视频 TS学习笔记二十五&#xff1a;TS项目配置 B站视频 TS学习笔记二十五&#xff1a;TS项目配置 西瓜视频 https://www.ixigua.com/7327847796814709288 一、tsconf…

GUN/Linux时间同步服务之ntp配置管理

风险告知 本人及本篇博文不为任何人及任何行为的任何风险承担责任&#xff0c;图解仅供参考&#xff0c;请悉知&#xff01;相关配置操作是在一个全新的演示环境下进行的&#xff0c;演示环境中没有任何有价值的数据&#xff0c;但这并不代表摆在你面前的环境也是如此。生产环境…

超声波清洗机可以洗哪些东西?性价比比较高的超声波清洗机推荐

超声波清洗机现在作为一个相对来说比较方便快捷的清洁工具&#xff0c;其应用范围也是非常广泛的。无论是生活中的小物件&#xff0c;像眼镜还是耳钉这些&#xff0c;还是工业生产中的大型设备&#xff0c;超声波清洗机都能发挥出其独特的清洗效果&#xff0c;能够非常的省事的…

Docker容器引擎(4)

目录 一.搭建本地私有仓库 运行 registry 容器&#xff1a; Docker容器的重启策略如下&#xff1a; 为镜像打标签&#xff1a; 上传到私有仓库&#xff1a; 列出私有仓库的所有镜像&#xff1a; 列出私有仓库的 centos 镜像有哪些tag&#xff1a; 二.Docker--harbor私有…

【LeetCode: Z 字形变换 + 模拟】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

(28)Linux 信号保存 信号处理 不可重入函数

首先介绍几个新的概念&#xff1a; 信号递达(Delivery)&#xff1a;实际执行信号的处理动作。信号未决(Pending)&#xff1a;信号从产生到递达之间的状态。信号阻塞(Block)&#xff1a;被阻塞的信号产生时将保持在未决状态&#xff0c;直达解除对该信号的阻塞&#xff0c;才执…

Mac本上快速搭建redis服务指南

文章目录 前言1. 查看可用版本2.安装指定版本的redis3.添加redis到PATH3.1 按照执行brew install命令后输出的提示信息执行如下命令将redis添加到PATH3.2 执行命令要添加的redis环境信息生效: 4. 增加密码4.1 在文件中找到requirepass所在位置4.2 去掉注释并将requirepass值替换…