欺诈文本分类检测(十七):支持分类原因训练

news2024/9/28 16:38:58

1. 引言

前文数据校正与增强进行了数据增强,本文将使用增强后的数据对模型进行进一步训练,以便得到能同时预测出分类标签、欺诈者、分类原因多个信息的模型。

为此,我们需要对整个训练过程进行调整,包括:

  1. 交叉训练逻辑封装
  2. 数据序列化的改造
  3. 评测方法改造

2. 交叉训练封装

首先,我们将前文 交叉训练验证的代码封装为一个脚本trainer_cross.py,方便复用。内容如下:

import glob
import gc
import numpy as np
from datasets import Dataset, concatenate_datasets
from sklearn.model_selection import KFold
from trainer import *

def find_last_checkpoint(output_dir):
    checkpoint_dirs = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
    last_checkpoint_dir = max(checkpoint_dirs, key=os.path.getctime)
    return last_checkpoint_dir

def load_model_v2(model_path, checkpoint_path='', device='cuda'):
    # 加载模型
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device)
    # 加载lora权重
    if checkpoint_path: 
        model = PeftModel.from_pretrained(model, model_id=checkpoint_path).to(device)
    # 将基础模型的参数设置为不可训练
    for param in model.base_model.parameters():
        param.requires_grad = False
    
    # 将 LoRA 插入模块的参数设置为可训练
    for name, param in model.named_parameters():
        if 'lora' in name:
            param.requires_grad = True
    return model

def build_trainer_v2(model, tokenizer, train_args, train_dataset, eval_dataset):
    # 开启梯度检查点时,要执行该方法
    if train_args.gradient_checkpointing:
        model.enable_input_require_grads()
    return Trainer(
        model=model,
        args=train_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],  # 早停回调
    )

def train_kfold(model_path, output_base_path, datasets, build_args_func, fold_num=5, device='cuda', last_checkpoint_path=''):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    kf = KFold(n_splits=fold_num, shuffle=True)
    results = []
    
    for fold, (train_index, val_index) in enumerate(kf.split(np.arange(len(datasets)))):
        print(f"fold={fold} start, train_index={train_index}, val_index={val_index}")
        train_dataset = datasets.select(train_index)
        eval_dataset = datasets.select(val_index)
        print(f"train data: {len(train_dataset)}, eval: {len(eval_dataset)}")
    
        output_path = f'{output_base_path}_{fold}'
        train_args, lora_config = build_args_func(output_path)
        if last_checkpoint_path:
            model = load_model_v2(model_path, last_checkpoint_path, device)
            print(f"fold={fold}, load model from checkpoint: {last_checkpoint_path}")
        else:
            model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to(device)
            model = get_peft_model(model, lora_config)
    
        model.print_trainable_parameters()
        trainer = build_trainer_v2(model, tokenizer, train_args, train_dataset, eval_dataset)
        train_result = trainer.train()
        print(f"fold={fold}, result = {train_result}")
        results.append(train_result)
        
        last_checkpoint_path = find_last_checkpoint(output_path)
        
    return results

其中,各个方法的作用释义如下:

  • find_last_checkpoint:用于从一个目录下查找最新的checkpoint。
  • load_model_v2:加载模型和微调的checkpoint,并将lora权重设置为可训练,非lora权重设置为不可训练。
  • build_trainer_v2:构造训练器
  • train_kfold:封装K折交叉训练验证的主循环逻辑,循环的每个批次为不同的数据集

train_kfold是此脚本最终对外公开的方法,它开放了如下参数以便灵活调整训练过程:

  • model_path:基座模型路径;
  • output_base_path:输出模型的基础路径,K折交叉训练会以此路径为基础,来构造每一折数据的输出路径;
  • datasets:经过预处理后的数据集;
  • build_args_func:构造训练参数的方法,根据output_path来构造训练参数和Lora参数;
  • fold_num: 数据集要分割的折数;
  • device: 训练的GPU设备;
  • last_checkpoint_path: 最近一次训练的checkpoint路径,当接着上一次的训练结果继续训练时传此参数。

3. 数据加载改造

当输出数据改变后,模型的预期输出不再仅仅是一个分类标签,还需要包括欺诈者和分类原因。因此,我们加载数据和数据序列化的方式需要作相应调整。

改造数据预处理函数,扩展with_reason参数,参数值定义:

  • true:表示预期结果除了is_fraud字段外,还包含fraud_speaker和reason字段。
  • false:表示预期结果不包含fraud_speaker和reason字段。

代码如下(有变化的仅仅是if with_reason的判断分支)。

def preprocess(item, tokenizer, with_reason=False, max_length=2048):
    system_message = "You are a helpful assistant."
    user_message = item['instruction'] + '\n' + item['input']
    if with_reason: 
        output = {"is_fraud":item["label"], "fraud_speaker":item["fraud_speaker"], "reason":item["reason"]}
    else:
        output = {"is_fraud":item["label"]}
        
    assistant_message = json.dumps(output, ensure_ascii=False)
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n", add_special_tokens=False)  
    response = tokenizer(assistant_message, add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  
    # -100是一个特殊的标记,用于指示指令部分的token不应参与损失计算
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]  
    
    # 对输入长度做一个限制保护,超出截断
    return {
        "input_ids": input_ids[:max_length],
        "attention_mask": attention_mask[:max_length],
        "labels": labels[:max_length]
    }

相应对外的load_dataset方法也扩展with_reason参数,目的兼容之前的单独分类标签训练,支持带原因和不带原因两种加载数据的模式。

def load_one_dataset(data_path, tokenizer, with_reason:bool):
    df = load_jsonl(data_path)
    ds = Dataset.from_pandas(df)
    return ds.map(
        lambda x: preprocess(x, tokenizer, with_reason=with_reason),
        remove_columns=ds.column_names)

def load_dataset(train_path, eval_path, tokenizer, with_reason=False):
    train_dataset = load_one_dataset(train_path, tokenizer, with_reason)
    eval_dataset = load_one_dataset(eval_path, tokenizer, with_reason)
    return train_dataset, eval_dataset

4. 开始训练

4.1 初始化

初始化改为引入新封装的脚本trainer_cross.py

%run trainer_cross.py

数据路径、模型路径、设备定义基本和之前保持一致。

traindata_path = '/data2/anti_fraud/dataset/train0902.jsonl'
evaldata_path = '/data2/anti_fraud/dataset/eval0902.jsonl'
model_path = '/data2/anti_fraud/models/modelscope/hub/Qwen/Qwen2-1___5B-Instruct'
output_base_path = '/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0913'
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = 'cuda'

加载数据集,使用concatenate_datasets方法将训练集和验证集合并为一个数据集。

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
train_dataset, eval_dataset = load_dataset(traindata_path, evaldata_path, tokenizer, with_reason=True, lazy=False)
datasets = concatenate_datasets([train_dataset, eval_dataset])

在这里插入图片描述

4.2 训练

定义参数构造的方法,用于构造训练参数和Lora参数,具体参数值保持与之前相同。

def build_arguments(output_path):
    train_args = build_train_arguments(output_path)
    train_args.eval_strategy='epoch'
    train_args.save_strategy='epoch'
    train_args.num_train_epochs = 2
    train_args.per_device_train_batch_size = 8
    
    lora_config = build_loraconfig()
    lora_config.lora_dropout = 0.2  
    lora_config.r = 16
    lora_config.lora_alpha = 32
    return train_args, lora_config

调用train_kfold方法开始训练:

results = train_kfold(model_path, output_base_path, datasets, build_args_func=build_arguments, fold_num=5, last_checkpoint_path=last_checkpoint_path)

总共进行了5折数据10轮训练,每折数据进行了两轮训练,相应的训练损失和验证损失数据如下:

EpochTraining LossValidation Loss
10.7801000.825167
20.6967000.813522
30.7854000.738886
40.6662000.731676
50.6794000.619393
60.5589000.610776
70.5821000.503672
80.4297000.490893
90.4833000.394778
100.3080000.372799
4.3 评测

根据验证损失数据,基于前文支持分类原因评测改造的脚本,采用微调效果最好的最后一轮checkpoint进行评测。

%run evaluate_v2.py
testdata_path = '/data2/anti_fraud/dataset/test0902.jsonl'
checkpoint_path = '/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0913_4/checkpoint-5454'
evaluate_v2(model_path, checkpoint_path, testdata_path, device, debug=True)

三个字段的评测指标分别如下:

字段指标
is_fraudprecision: 0.9422, recall: 0.9434, accuracy: 0.9419
fraud_speakeraccuracy: 0.9175
reasonprecision: 0.3596, recall: 0.3708, f1-score: 0.3571

经过训练后,三个字段的指标都有不同程度的提高,分别为:

  • is_fraud: precision从0.6232提升到0.9422,表明模型在欺诈文本分类任务上的精确率有明显提高,这能减少欺诈文本误报的次数;
  • fraud_speaker: accuracy从0.6327提升到0.9175,表明模型能有效的识别哪些说话者可能涉及欺诈;
  • reason: 召回率从0.2324提升到0.3708,f1-score从0.2638提升到0.3571

可以看到,is_fraud和fraud_speaker两个字段的准确率提升是比较明显的,而reason字段的召回率也有一定程度的提升,但分数没有那么高。

猜想原因可能在于训练不充分,因为从上面训练的损失数据中能看到一个现象:整个10轮训练下来,不论是训练损失还是验证损失都还在持续下降,这说明训练还未完成。

5. 再次训练

调整输出目录,并定义最近一次训练结构的checkpoint路径:

output_base_path = '/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0924'
last_checkpoint_path = '/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0913_4/checkpoint-5454'

将折子数调整为10, 其它都和上面相同,基于指定的checkpoint继续训练:

results = train_kfold(model_path, output_base_path, datasets, build_args_func=build_arguments, fold_num=10, last_checkpoint_path=last_checkpoint_path)

总共进行了10折数据20轮训练,每折数据进行了两轮训练,相应的训练损失和验证损失数据如下:

EpochTraining LossValidation Loss
10.4393000.365142
20.2041000.351331
30.3275000.268683
40.2024000.246474
50.2539000.192165
60.1332000.165728
70.16770.1145
80.15550.1024
90.14310.08527
100.13290.07053
110.12310.06223
120.11390.0571
130.10660.05086
140.10080.04274
150.094840.04154
160.0709000.038615
170.0697000.033552
180.0688000.029461
190.0593000.026729
200.0682000.023861

从训练结果来看,损失数据一直在不断的下降,这里用最后第20轮的checkpoint进行评测(代码省略),三个字段的评测指标分别如下:

字段指标
is_fraud0.9372, recall: 0.9414, accuracy: 0.9383
fraud_speakeraccuracy: 0.9152
reasonprecision: 0.3664, recall: 0.3705, f1-score: 0.3601

结果是有些失望的,虽然损失在不断下降,但各项评测指标基本都没有什么改善。

原因猜测可能有以下几个:

  1. 数据量不足,2-3万条训练数据相对于文本生成任务来说还是太少,可能还不足以明显改善生成文本的相似度。
  2. Lora参数矩阵(r=16)较小,不足以储存足够的信息特征。

补充:后续尝试过将Lora矩阵的秩r调到64,并训练了10折20轮,但测评结果依然与上面的结果相似,没有明显改善。

由于分类原因的相似度指标并不是我们必需的,这里受精力和数据量的限制,暂时没继续往下训练,但欺诈者字段的准确率已经达到预期,带欺诈者和分类原因的response格式已经能够正常生成,如下所示:

{"is_fraud": true, "fraud_speaker": "小灿", "reason": "小灿要求吴某某登录一个网站,并根据其指示进行国际黄金的操作,这种行为很可能是典型的投资诈骗手段。"}
{"is_fraud": true, "fraud_speaker": "李小龙", "reason": "李小龙要求支付一笔费用以帮助其亲戚获得释放,但没有提供任何具体的细节或证明其关系。使用 relative/convicted等模糊词汇来联系律师,并要求立即支付费用,这种行为具有明显的诈骗特征。"}
{"is_fraud": true, "fraud_speaker": "朱立", "reason": "朱立提供的投资方案中,明显存在通过吸引投资者大量投入资金来获取高额回报的行为。这种高回报承诺通常是不现实且具有欺骗性的,极有可能构成经济诈骗。具体表现为:- 98元、598元和998元的投资计划承诺回报远远超过正常市场回报率,这违反了普遍的投资原则和常识。\- 投资98元的方案回报5万到10万元,这样的回报率过高,容易让人怀疑其真实性。\- 同样,其他投资方案也承诺极高的回报,如598元投资回报50万元,998元投资回报200万元,这些回报率远超正常市场水平,极易诱导投资者上当。\n"}

小结:本文通过对数据加载的改造和交叉训练过程的封装,完成了一次针对带分类原因的欺诈文本分类任务的训练,并通过评测方法的改造实现了对不同类型字段的结果评测。从损失数据和评测结果来看,要改善生成文本的精确率和召回率,可能还需要更多更丰富的数据,后续腾出时间再研究。

参考阅读

  • Lora单卡二次调优
  • 交叉训练验证
  • 数据校正与增强
  • 支持分类原因评测改造

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2174104.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

苹果端侧AI布局深度分析

苹果 - 国际巨头的端侧 AI布局 深度分析 1.1.1 苹果AI&#xff1a;模型侧&#xff1a;MM1 3月&#xff0c;苹果发布多模态大模型MM1&#xff0c;拥有高达300亿参数。MM1融合密集模型与MoE变体&#xff0c;涵盖300亿、70亿、30亿参数版。MM1预训练指标领先&#xff0c;在多个多…

ubuntu 安装k8s

#关闭 Swap 内存&#xff0c;配置完成建议重启一下 nano /etc/fstab #注释下面相似的一行 #/swapfile none swap sw 0 0 #重启 reboot#部属k8s apt update && apt install -y apt-transport-https 下载 gpg 密钥 curl https://mi…

基于SpringBoot+Vue的高校实习管理系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码 精品专栏&#xff1a;Java精选实战项目…

一次眼睛受损然后恢复的过程

由于多年没有社交比较宅,多年长期盯着电脑和手机,没有保护好眼睛 之前早上醒来有一段时间我眼睛老是分泌各种乱起八遭的东西,导致我眼睛看不清, 2023年3月有天的早上,我又不小心眼睛揉出血了,出门上班路上的时候才知道有这个问题,第二天早上就挂了去了眼科,医生给我开了如下的药…

单细胞转录组|scATAC-seq 数据整合

引言 本文在此展示了如何将多个源自人类外周血单核细胞的单细胞染色质数据集进行整合。其中一个数据集是通过10x Genomics的多组学技术获得的&#xff0c;它涵盖了每个细胞的DNA可及性和基因表达数据。另一个数据集则是通过10x Genomics的单细胞ATAC测序(scATAC-seq)技术得到的…

Mybatis-Plus自动填充时间的配置类

引言&#xff1a;在现代软件开发中&#xff0c;数据库操作是不可或缺的一部分。为了确保数据的准确性和完整性&#xff0c;我们常常需要在数据库记录中添加时间戳&#xff0c;例如创建时间和更新时间。MyBatis-Plus作为一个流行的持久层框架&#xff0c;提供了灵活的机制来实现…

官方权威解读|CNAS-CL01计量溯源性部分解读,供CNAS软件测试实验室参考

CNAS-CL01《检测和校准实验室能力认可准则》是软件测试实验室申请CNAS资质&#xff0c;建设符合CNAS要求的实验室质量管理体系时必须要参考的一部强制性准则。CNAS-CL01一共分为五大核心部分&#xff1a;通用要求、结构要求、资源要求、过程要求和管理体系要求。前面的文章中我…

【零散技术】微信支付保姆教程(一)

序言:时间是我们最宝贵的财富,珍惜手上的每个时分 微信支付十余年&#xff0c;早已成为国内必不可少的支付工具。但是开发对接中 繁杂的各类参数与文档&#xff0c;以及各种证书的申请&#xff0c;着实也成了不少开发者的噩梦&#xff0c;那么今天我们来看看&#xff0c;如何申…

3-3 AUTOSAR RTE 对SR Port的作用

返回总目录->返回总目录<- 一、前言 RTE作为SWC和BSW之间的通信机构,支持Sender-Receiver方式实现ECU内及ECU间的通信。 对于Sender-Receiver Port支持三种模式: 显式访问:若运行实体采用显示模式的S/R通信方式,数据读写是即时的;隐式访问:当多个运行实体需要读取…

小阿轩yx-案例:代码管理系统简介与部署

小阿轩yx-案例&#xff1a;代码管理系统简介与部署 前言 开发一个项目时&#xff0c;如果只有几十行代码或几百行代码时维护还算简单&#xff0c;但是代码数量达到一定程度或两三个人共同开发一个项目时&#xff0c;就很容易会出现代码混乱、冲突、排错难等问题。代码编写完成…

vue3中< keep-alive >页面实现缓存及遇到的问题

vue3中< keep-alive >页面实现缓存及遇到的问题 实现原理&#xff1a;keep-alive 是 Vue 的内置组件&#xff0c;当它包裹动态组件时&#xff0c;会缓存不活动的组件实例&#xff0c;而不是销毁它们。实现不同路由是否缓存只需要设置对应路由参数keepAlive为true&#xf…

Excel里的 $ 是什么意思,绝对引用用法详解来了

大家好&#xff0c;这里是效率办公指南&#xff01; &#x1f511; 在Excel中&#xff0c;$符号是一个功能强大的工具&#xff0c;它用于实现单元格引用的绝对引用和混合引用。了解它的用法对于编写公式和处理数据至关重要。今天&#xff0c;我们将详细介绍$符号的用法和一些实…

【C++】设计用户级缓冲区

目录 缓冲区功能分析 缓冲区空间分配策略分析 数据设计和函数介绍 完整代码 接口介绍 个人主页&#xff1a;东洛的克莱斯韦克-CSDN博客 缓冲区功能分析 1.可以向缓冲区写入数据 2.可用从缓冲区读取数据 3.可用窥探数据——把数据拷贝给上层&#xff0c;但缓冲区数据不减少 …

巧用枚举消除条件判断

shigen坚持更新文章的博客写手&#xff0c;记录成长&#xff0c;分享认知&#xff0c;留住感动。个人IP&#xff1a;shigen 在上一篇的文章结合HashMap与Java 8的Function和Optional消除ifelse判断中有讲到如何结合HashMap与Java 8的Function和Optional消除ifelse判断&#xff…

校园二手交易平台的小程序+ssm(lw+演示+源码+运行)

摘 要 随着社会的发展&#xff0c;社会的方方面面都在利用信息化时代的优势。互联网的优势和普及使得各种系统的开发成为必需。 本文以实际运用为开发背景&#xff0c;运用软件工程原理和开发方法&#xff0c;它主要是采用java语言技术和mysql数据库来完成对系统的设计。整个…

Transformer: Attention is all you need

Transformer于2017年提出&#xff0c;最开始应用于NLP领域&#xff0c;随着Transformer的快速发展&#xff0c;在视觉领域中也越来越多的论文或应用用到了Transformer&#xff0c;这里记录一下自己学习的一些知识点。 PDF&#xff1a; 《Attention Is All You Need》 Code: att…

【HTML5】html5开篇基础(3)

1.❤️❤️前言~&#x1f973;&#x1f389;&#x1f389;&#x1f389; Hello, Hello~ 亲爱的朋友们&#x1f44b;&#x1f44b;&#xff0c;这里是E绵绵呀✍️✍️。 如果你喜欢这篇文章&#xff0c;请别吝啬你的点赞❤️❤️和收藏&#x1f4d6;&#x1f4d6;。如果你对我的…

带您了解《人工智能机器视觉应用工程师》

人工智能机器视觉应用是指利用人工智能技术和机器视觉技术相结合&#xff0c;使机器能够像人类一样通过视觉感知和理解环境&#xff0c;从而实现各种应用。随着人工智能技术的不断发展&#xff0c;机器视觉应用在各个领域得到了广泛应用。 在工业制造领域&#xff0c;人工智能机…

Unity3D入门(三) : Android和Unity3D交互 - Android调用Unity

1. 前言 上篇文章&#xff0c;我们讲了如何在Unity3D中过渡地切换相机视角。这篇文章&#xff0c;我们来讲一下Unity3D怎么与Android交互。 1.1 unity和Android的三种通信方式 Unity官方提供的接口 : 有一个弊端&#xff0c;它有一个传输内容量的一个限制&#xff0c;传输内…

Java---异常及处理

一.异常 1.概念 程序的非正常执行。高级语言都有异常处理机制&#xff08;C&#xff0c;Java&#xff09; 2.一般处理异常的方法 Scanner sc new Scanner(System.in);System.out.println("请输入一个数字:");String s sc.nextLine();if (s.matches("[0-9]&qu…