欺诈文本分类检测(十三):交叉训练验证

news2024/12/23 2:03:07

1. 引言

交叉验证主要讨论的是数据集的划分问题。

通常情况下,我们会采用均匀随机抽样的方式将数据集划分成3个部分——训练集、验证集和测试集,这三个集合不能有交集,常见的比例是8:1:1(如同前文我们所作的划分)。这三个数据集的用途分别是:

  • 训练集:用来训练模型,去学习模型的权重和偏置这些参数,这些参数可称为学习参数。
  • 验证集:用于在训练过程中选择超参数,比如批量大小、学习率、迭代次数等,它并不参与梯度下降,也不参与学习参数的确定。
  • 测试集:用于训练完成后评价最终的模型时使用,它既不参与学习参数的确定,也不参数超参数的选择,而仅仅使用于模型的评价。

注:千万不能在训练过程中使用测试集,不论是用于训练还是用于超参数的选择,这会将测试数据无意中提前透露给模型,相当于作弊,使得模型测试时准确率虚高。

而交叉验证与上述不同的地方在于:在手动划分时只分出训练集和测试集,到真正训练时才从训练集中动态抽取一定比例作为验证集,并且在多轮训练中会循环提取不同的训练集和验证集。数据集划分大概如下图:
在这里插入图片描述

  • 第一轮训练时,将训练集平均分成5份,选1份作为验证集,其余4份作为训练集。
  • 第二轮训练时,取另外的1份作为验证集,剩余4份作为训练集。
  • ……
  • 如此循环,直到每份数据都参与过训练和验证。

这样做的好处在于:模型能更充分的利用数据,更全面的学习到数据的整体特征,减少过拟合风险。

2. 训练过程

2.1 初始化

这一部分同前文训练的预设一样,基本没有什么改变。

%run trainer.py
traindata_path = '/data2/anti_fraud/dataset/train0819.jsonl'
evaldata_path = '/data2/anti_fraud/dataset/eval0819.jsonl'
model_path = '/data2/anti_fraud/models/modelscope/hub/Qwen/Qwen2-1___5B-Instruct'
output_path = '/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0903_1'

声明要使用的GPU设备。

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = 'cuda'

加载模型和tokenizer分词器。

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
train_dataset, eval_dataset = load_dataset(traindata_path, evaldata_path, tokenizer)

在这里插入图片描述

2.2 数据处理

这一部分主要是将前文构造训练/测试数据集所构造的训练集和验证集合并,采用sklearn库中的KFold重新按折子进行数据集分割。

import glob
import gc
import numpy as np
from datasets import Dataset, concatenate_datasets
from sklearn.model_selection import KFold

拼接训练集和验证集作为一个数据集。

datasets = concatenate_datasets([train_dataset, eval_dataset])
len(datasets)
21135

创建KFold对象用于按折子划分数据集。

  • n_splits=5:表示将数据集划分为5份。
  • shuffle=True:表示调用kf.split划分数据集前先将顺序打乱。

KFold是由sklearn库提供的k折交叉验证方法,它通过将数据集分成k个相同大小的子集(称为折),每次迭代数据集时,使用其中一个作为验证集,其余4个作为训练集,并重复这个过程k次。

kf = KFold(n_splits=5, shuffle=True)
kf
KFold(n_splits=5, random_state=None, shuffle=True)

用kfold划分数据集时,实际拿到的是数据在数据集中的索引顺序,如下面示例的效果。

indexes = kf.split(np.arange(len(datasets)))
train_indexes, val_indexes = next(indexes)
train_indexes, val_indexes, len(train_indexes), len(val_indexes)
(array([    0,     2,     3, ..., 21129, 21131, 21134]),
 array([    1,     9,    12, ..., 21130, 21132, 21133]),
 16908,
 4227)

如上所示,训练集的数量16908和验证集的数量4227比例基本是4:1。

2.3 超参数定义

定义超参构造函数,包括训练参数和Lora微调参数。这里相对于之前作的调整在于:

  • 修改评估和保存模型的策略,由每100step改为每个epoch保存一次,原因是前者保存的checkpoint有太多冗余,节省一些磁盘空间。
  • 将num_train_epochs调整为2,表示每个折子的数据集训练2遍,k=5时数据总共会训练10遍。

注:当per_device_train_batch_size=16时训练过程中会意外发生OOM,所以临时将批次大小per_device_train_batch_size改为8.

def build_arguments(output_path):
    train_args = build_train_arguments(output_path)
    train_args.eval_strategy='epoch'
    train_args.save_strategy='epoch'
    train_args.num_train_epochs = 2
    train_args.per_device_train_batch_size = 8
    
    lora_config = build_loraconfig()
    lora_config.lora_dropout = 0.2   
    lora_config.r = 16
    lora_config.lora_alpha = 32
    return train_args, lora_config

Lora配置和前文最后一次训练的配置相同,秩采用16,dropout采用0.2.

2.4 重新定义模型加载

由于训练过程中需要迭代更换不同的训练集和验证集组合,而更换数据集就需要重新创建训练器,传入新的模型实例,相当于从头开始训练。

为了实现后一次训练能在前一次训练结果的基础上继续训练,就需要找到前一次训练的最新checkpoint。所以定义一个find_last_checkpoint方法,用于从一个目录中查找最新的checkpoint。

# 确定最后的checkpoint目录
def find_last_checkpoint(output_dir):
    checkpoint_dirs = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
    last_checkpoint_dir = max(checkpoint_dirs, key=os.path.getctime)
    return last_checkpoint_dir

find_last_checkpoint("/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0830_1")
'/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0830_1/checkpoint-3522'
  • glob.glob 函数可以在指定目录下查找所有匹配 checkpoint-* 模式的子目录
  • os.path.getctime 返回文件的创建时间(或最近修改时间)
  • max 函数根据这些时间找出最后创建的目录,也就是最新的checkpoint。

定义一个新的加载模型的方法,用于从基座模型和指定的checkpoint中加载最新训练的模型,并根据训练目标来设置参数的require_grad属性。

def load_model_with_checkpoint(model_path, checkpoint_path='', device='cuda'):
    # 加载模型
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device)
    # 加载lora权重
    if checkpoint_path: 
        model = PeftModel.from_pretrained(model, model_id=checkpoint_path).to(device)
    
    # 将基础模型的参数设置为不可训练
    for param in model.base_model.parameters():
        param.requires_grad = False
    
    # 将 LoRA 插入模块的参数设置为可训练
    for name, param in model.named_parameters():
        if 'lora' in name:
            param.requires_grad = True
    return model

如上代码逻辑所示,将来自lora的参数都设置为需要梯度requires_grad = True,其余原始基座模型的参数设置不可训练requires_grad = False

2.5 构建训练过程

在这个训练过程中,除了第一次训练是从0初始化的微调秩矩阵,后面几次训练则都是从指定checkpoint来初始化微调秩,这导致了原先定义的build_trainer方法不通用。所以定义一个新的训练器构建方法,将加载微调参数的逻辑移到外面。

def build_trainer_v2(model, tokenizer, train_args, train_dataset, eval_dataset):
    # 开启梯度检查点时,要执行该方法
    if train_args.gradient_checkpointing:
        model.enable_input_require_grads()
    return Trainer(
        model=model,
        args=train_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],  # 早停回调
    )

下面定义交叉训练的主循环。

results = []
last_checkpoint_path = ''

for fold, (train_index, val_index) in enumerate(kf.split(np.arange(len(datasets)))):
    train_dataset = datasets.select(train_index)
    eval_dataset = datasets.select(val_index)

    output_path = f'/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0903_{fold}'
    train_args, lora_config = build_arguments(output_path)
    # 第一次训练和后面几次训练所采用的模型加载方法不同
    if last_checkpoint_path:
        model = load_model_with_checkpoint(model_path, last_checkpoint_path, device)
    else:
        model = load_model(model_path, device)
        model = get_peft_model(model, load_config)

    model.print_trainable_parameters()
    trainer = build_trainer_v2(model, tokenizer, train_args, train_dataset, eval_dataset)
    train_result = trainer.train()
    results.append(train_result)
    
    last_checkpoint_path = find_last_checkpoint(output_path)

代码逻辑说明:

  • kf.split函数划分了5份数据索引,以这5份数据索引进行5次迭代。
  • 使用datasets.select基于索引在每次迭代时选择不同的数据作为训练集和验证集。
  • 为了避免前次迭代训练的结果被下次迭代的结果给覆盖,每次迭代训练通过fold来拼接不同的输出目录output_path。
  • 如果存在last_checkpoint_path,则从checkpoint来加载模型,如果不存在,则使用get_peft_model向模型中插入一个新的Lora微调秩。
  • 使用新的build_trainer_v2方法来构建训练器并开始训练。
  • 每次迭代完都找出此次训练中最新的checkpoint,作为下次训练的起点。
2.6 开始训练

运行上面的主循环开始训练。

最终可以收集到5次迭代训练的损失数据如下,每次迭代跑2轮数据集,共跑了10轮数据集。

EpochTraining LossValidation Loss
10.02330.02189
20.01380.01614
30.0088000.011420
40.0046000.013666
50.0032000.004718
60.0030000.004082
70.0072000.001999
80.0000000.000814
90.0049000.002273
100.0102000.002139

对比前面欺诈文本分类微调(七)—— lora单卡二次调优训练进行到2300步左右(大概两遍数据)就开始过拟合(主要现象是验证损失到0.0161就不再下降反而开始升高)。K折交叉训练直到第4次迭代(大概八遍数据)过后才达到损失最低点,第5次迭代才出现了略微的过拟合(相比于第4次),过拟合的现象得到了极大的缓解,验证损失也降到了一个更低的值0.000814,这说明数据相比之前训练来说得到了更充分的使用。

3. 评估测试

由于交叉训练中验证集和训练集都参与了模型学习参数的更新,所以用验证集进行评估已经没有意义。我们直接用测试集进行最后的评估。

第一轮迭代结果的评测:

%run evaluate.py
checkpoint_path='/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0903_0/checkpoint-4226'
testdata_path = '/data2/anti_fraud/dataset/test0819.jsonl'
evaluate(model_path, checkpoint_path, testdata_path, device, batch=True, debug=True)
progress: 100%|██████████| 2349/2349 [03:19<00:00, 11.75it/s]

tn:1135, fp:32, fn:128, tp:1054
precision: 0.9705340699815838, recall: 0.8917089678510999

第三轮迭代结果的评测:

%run evaluate.py
checkpoint_path='/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0903_2/checkpoint-4226'
testdata_path = '/data2/anti_fraud/dataset/test0819.jsonl'
evaluate(model_path, checkpoint_path, testdata_path, device, batch=True, debug=True)
progress: 100%|██████████| 2349/2349 [03:21<00:00, 11.64it/s]

tn:1133, fp:34, fn:64, tp:1118
precision: 0.9704861111111112, recall: 0.9458544839255499

第四次迭代结果的评测:

%run evaluate.py
checkpoint_path='/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0903_3/checkpoint-4226'
testdata_path = '/data2/anti_fraud/dataset/test0819.jsonl'
evaluate(model_path, checkpoint_path, testdata_path, device, batch=True, debug=True)
progress: 100%|██████████| 2349/2349 [03:21<00:00, 11.66it/s]

tn:1128, fp:39, fn:64, tp:1118
precision: 0.9662921348314607, recall: 0.9458544839255499

第五次迭代结果的评测:

%run evaluate.py
checkpoint_path='/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0903_4/checkpoint-4226'
testdata_path = '/data2/anti_fraud/dataset/test0819.jsonl'
evaluate(model_path, checkpoint_path, testdata_path, device, batch=True, debug=True)
progress: 100%|██████████| 2349/2349 [03:22<00:00, 11.58it/s]

tn:1124, fp:43, fn:50, tp:1132
precision: 0.963404255319149, recall: 0.9576988155668359

与之前单卡训练和多卡微调的结果相比,精确率有一点点下降(0.9953->0.9634),但召回率却有了一个比较大的提升(0.9129->0.9576),这个测评结果的数据变化与上面损失结果的数据变化基本是一致的。

小结:本文通过引入K折交叉验证方法,循环选择不同的训练集和验证集进行多次迭代训练,将损失降到了一个更低的值,也在很大程度上缓解了[前面每次训练]过程中都出现的过拟合现象。最终在从未见过的测试数据集上进行评测时,召回率指标也有了一个较大的提升。从这个结果来看,K折交叉验证这种方法确实能让模型对数据学习的更充分,有助于模型泛化能力的提升。

相关阅读

  • 欺诈文本分类检测(五):构建训练/测试集
  • 欺诈文本分类微调(七): lora单卡二次调优
  • 欺诈文本分类检测(十一):LLamaFactory多卡微调
  • 交叉验证方法汇总

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2112466.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

吉利星座03组卫星发射

北京时间2024年9月6日2时30分&#xff0c;在太原卫星发射中心使用长征六号运载火箭&#xff0c;成功将民营“中国星链”——吉利未来出行星座第三个轨道面、吉利星座03组卫星发射升空&#xff0c;10颗卫星顺利进入预定轨道&#xff0c;发射任务获得圆满成功。此次任务是长征系列…

飞思相机存储卡格式化数据如何恢复?提供全面指南

在数字摄影时代&#xff0c;‌飞思相机以其卓越的成像质量和专业的性能&#xff0c;‌赢得了众多摄影师的青睐。‌然而&#xff0c;‌即使是专业的设备也难免遭遇数据丢失的困境&#xff0c;‌尤其是当存储卡不幸被格式化时。‌面对这一突如其来的灾难&#xff0c;‌许多摄影师…

qt QGraphicsScene场景坐标和场景内GraphicsItem局部坐标的相互转换

为了更清晰地解释场景坐标与局部坐标之间的转换过程&#xff0c;我们可以通过一个简单的实例来演示如何赋值场景坐标&#xff0c;并将其转换为图形项的局部坐标。 实例步骤 假设我们有一个场景 QGraphicsScene 和一个矩形图形项 QGraphicsRectItem&#xff0c;矩形的大小为 1…

Redis进阶(六):缓存

1.缓存 速度快的设备可以作为速度慢的设备的缓存 缓存能够有意义&#xff1a;二八定律&#xff0c;20%的数据可以应对80%的请求 通常使用redis作为数据库的缓存&#xff08;mysql&#xff09; 数据库是非常重要的组件&#xff0c;mysql速度比较慢 因为mysql等数据库&#x…

【 C++ 】类和对象的学习(三)

前言&#xff1a; &#x1f618;我的主页&#xff1a;OMGmyhair-CSDN博客 目录 一、初始化列表 二、类型转换 三、static成员 四、友元 五、内部类 六、匿名对象 一、初始化列表 当我们之前在写构造函数时&#xff0c;我们通常在构造函数内对成员变量进行赋值。但其实还…

系统架构师考试学习笔记第三篇——架构设计高级知识(19)嵌入式系统架构设计理论与实践

本章考点&#xff1a; 第19课时主要学习嵌入式系统架构设计的理论和工作中的实践。根据新版考试大纲&#xff0c;本课时知识点会涉及案例分析题&#xff08;25分&#xff09;。在历年考试中&#xff0c;案例题对该部分内容都有固定考查&#xff0c;综合知识选择题目中有固定分值…

北大港中文腾讯提出ViewCrafter:一张图像就可以制作影视特效和游戏画面!

北大和港中文联合腾讯人工智能实验室提出了 ViewCrafter&#xff0c;这是一种利用视频扩散模型的先验从单个或稀疏图像合成一般场景的高保真新视图的新方法。 可以简单理解为将复杂的图像转换成新角度的图像版本。首先&#xff0c;它会使用特殊的算法来读取一张或几张图像&…

SpringBoot项目-实现简单的CRUD功能和分页查询

背景 本博文主要是创建了一个新的SpringBoot项目&#xff0c;实现基本的增删改查&#xff0c;分页查询&#xff0c;带条件的分页查询功能。是方便初学者学习后端项目的一个比较清晰明了的实践代码&#xff0c;读者可根据博文&#xff0c;从自己动手创建一个新的SpringBoot项目…

Scratch教师节 —— 感恩教师节

小虎鲸Scratch资源站-免费Scratch作品源码,素材,教程分享平台! Scratch教师节动画作品——感恩教师节 在这个特别的日子里&#xff0c;我们迎来了教师节。为了表达对老师们的感激之情&#xff0c;Scratch平台上的小朋友们用创意与热情制作了精彩的动画作品——“感恩教师节”。…

在国产芯片上实现YOLOv5/v8图像AI识别-【4.3】RK3588使用yolov8+bytetrack实现跟踪更多内容见视频

本专栏主要是提供一种国产化图像识别的解决方案&#xff0c;专栏中实现了YOLOv5/v8在国产化芯片上的使用部署&#xff0c;并可以实现网页端实时查看。根据自己的具体需求可以直接产品化部署使用。 B站配套视频&#xff1a;https://www.bilibili.com/video/BV1or421T74f 背景…

【Canvas与艺术】四叶花

【成图】 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html; charsetutf-8"/> <head><title>386.四叶花</title><style type"text/css">.c…

GDB watch starti i files

watch break starti 在程序的最初开始运行的位置处断下来 ​​ i files 查看程序及加载的 so 的 sections ​​

【Ubuntu】环境准备

平时不怎么接触运维。linux的东西都快忘完了&#xff0c;正好最近腾讯云优惠&#xff0c;38元一年&#xff0c;优惠拉满&#xff0c;拿下一个玩一玩&#xff0c;可以当小程序的服务器&#xff0c;记录一些常用的操作&#xff0c;省的每次用的时候都想不起来 1.有一个linux系统…

对接后端download接口报未知异常错误

你一定遇到过这种情况&#xff0c;在一个项目中下载功能明明好好的&#xff0c;下载接口调用方法与前端调用方法封装的好好的&#xff0c;可是换了一个接口&#xff0c;竟然搞罢工了&#xff0c;类似下面这样的&#xff0c;你会不会无从下手&#xff0c;不知道该怎么办呢&#…

2.C_数据结构_线性表

线性表的描述 线性表就是若干数据的一个线性序列。 数学表达式&#xff1a; L&#xff1a;表名 a0~an-1&#xff1a;数据元素 n&#xff1a;表长&#xff0c;n>0是为非空表 二元描述形式&#xff1a; D&#xff1a;数据元素D用 ai 表示&#xff0c;这个 i 范围是0~n-1 …

【C++从练气到飞升】21---再谈哈希算法:位图 | 布隆过滤器 | 哈希切分

&#x1f388;个人主页&#xff1a;库库的里昂 ✨收录专栏&#xff1a;C从练气到飞升 &#x1f389;鸟欲高飞先振翅&#xff0c;人求上进先读书&#x1f389; 目录 ⛳️推荐 一、位图 1.1 一道面试题 1.2 位图的概念 1.3 位图的模拟实现 1.4 位图的应用 1.4.1 给定100亿…

双项第一!鼎捷强势领跑PLM市场

近日&#xff0c;国际数据公司IDC发布了《中国PLM市场分析及厂商份额&#xff0c;2023&#xff1a;创新左移》 报告数据显示鼎捷PLM2023年收入增长率39.5%&#xff0c;收入增速市场第一 鼎捷在多个细分行业市场中保持领先&#xff0c;在装备制造PLM领域市场份额达到7.9%市占率…

基于 rt-thread的I2C操作EEPROM(AT24C02)

一、AT24C02 The AT24C01A/02/04/08A/16A provides 1024/2048/4096/8192/16384 bits of serial electrically erasable and programmable read-only memory (EEPROM) organized as 128/256/512/1024/2048 words of 8 bits each.AT24C01A/02/04/08A/16A提供1024/2048/4096/8192…

Redis进阶(三)--Redis高性能底层原理

文章目录 第三章、Redis高性能底层原理一、持久化1、RDB&#xff08;1&#xff09;给哪些内存数据做快照?&#xff08;2&#xff09;RDB文件的生成是否会阻塞主线程&#xff08;3&#xff09;bgsave执的行流程&#xff08;4&#xff09;RDB文件&#xff08;5&#xff09;RDB的…

ios免签H5

1、windows下载mobileconfig文件制作工具&#xff0c;可在csdn搜索iPhone_Mobileconfig_Tool下载安装&#xff1b;IOS 从APP Store 下载Apple Configurator 2 2、用申请的域名SSL证书给mobieconfig文件签名&#xff0c;最好下载Apache证书&#xff0c;里面包含 AE86211.crt…