使用 LlamaFactory 结合开源大语言模型实现文本分类:从数据集构建到 LoRA 微调与推理评估

news2024/12/12 19:37:05

文章目录

    • 背景介绍
      • 文本分类数据集
      • Lora 微调
      • 模型部署与推理
        • 期待模型的输出结果
    • 文本分类评估代码

背景介绍

本文将一步一步地,介绍如何使用llamafactory框架利用开源大语言模型完成文本分类的实验,以 LoRA微调 qwen/Qwen2.5-7B-Instruct 为例。

文本分类数据集

按照 alpaca 样式构建数据集,并在将其添加到 LLaMA-Factory/data/dataset_info.json 文件中。如此方便直接根据自定义数据集的名字,获取到数据集的数据。

[
  {
    "instruction": "",
    "input": "请将以下文本分类到一个最符合的类别中。以下是类别及其定义:\n\n要求}}\nreason: \nlabel:",
    "output": "reason: 该文本主要讨论的是xxx。因此,该文本最符合“社会管理”这一类别。\n\nlabel: 社会管理"
  },
  ...
]

Lora 微调

llamafactory 框架支持网页端训练,但本文选择在终端使用命令行微调模型。

模型微调训练的参数较多,将模型训练的参数都存储在 yaml 文件中。

qwen_train_cls.yaml 的文件内容如下:

### model
model_name_or_path: qwen/Qwen2.5-7B-Instruct

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all

### dataset
# dataset_dir: data
dataset_dir: LLaMA-Factory/data/ 填写相应路径
dataset: 数据集名 
template: qwen
cutoff_len: 2048
# max_samples: 1000 若数据集较大,可随机筛选一部分数据微调模型
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: output/qwen2.5-7B/cls_epoch2 训练的LoRA权重输出路径
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

使用下述命令启动模型训练:

nohup llamafactory-cli train qwen_train_cls.yaml > qwen_train_cls.log 2>&1 &

命令分解介绍:
nohup, 全称为 “no hangup”(不要挂起)。它的作用是让命令在退出终端后仍然运行,防止因关闭终端或会话中断导致进程被终止。
默认情况下,nohup 会将输出重定向到 nohup.out 文件,但这里已经显式指定了输出位置。
llamafactory-cli train qwen_train_cls.yaml 运行 llamafactory-cli 工具,用于执行训练任务。
train 是子命令,表示进行训练。
qwen_train_cls.yaml 是一个配置文件,包含训练所需的超参数、数据路径、模型结构等。
qwen_train_cls.log
将标准输出 (stdout) 重定向到 qwen_train_cls.log 文件中。
即运行过程中的正常日志信息会被记录到这个文件。
2>&1: 将标准错误输出 (stderr) 重定向到标准输出 (stdout)。
这样,所有错误信息也会被写入到 qwen_train_cls.log 文件中。
&: 表示将整个命令放到后台运行。终端会立即返回,您可以继续进行其他操作,而不用等待命令完成。

模型部署与推理

模型训练完成后得到 Lora 权重。相关微调模型部署与推理,请浏览下述两篇文章,相比llamafactory原本的模型推理速度更快。

  • 基于 LLamafactory 的异步API高效调用实现与速度对比.https://blog.csdn.net/sjxgghg/article/details/144176645
  • 基于 LlamaFactory 的 LoRA 微调模型支持 vllm 批量推理的实现

目前llamafactory已经支持 vllm_infer 推理,这个PR是笔者提交的:

  • llamafactory vllm.https://github.com/hiyouga/LLaMA-Factory/blob/main/scripts/vllm_infer.py
期待模型的输出结果

下述是使用 llamafactory 推理出的数据格式,建议大家在做推理评估时,也做成这个样式,方便统一评估。

{
	"prompt": "请将以下文本分类到一个最符合的类别中。以下是类别及其定义:...", 
	"predict": "\nreason: 该文本主要讨论了改革创新发展、行政区划调整、行政管理体制等方面的内容,涉及到体制机制的改革与完善,旨在推动高质量发展和提升生活品质。这些内容与社会管理和经济管理密切相关,但更侧重于行政管理和社会治理的改革,因此更符合“社会管理”这一类别。\n\nlabel: 社会管理", 		 
	"label": "reason: 该文本主要讨论的是改革创新、行政区划调整、体制机制障碍的破除以及行政管理体制等与政府治理和社会管理相关的内容,强调了与高质量发展和生活品质的关系。这些内容显示出对社会管理和行政管理的关注,尤其是在推动城乡一体化和适应高质量发展要求方面。因此,该文本最符合“社会管理”这一类别。\n\nlabel: 社会管理"
}

文本分类评估代码

import os
import re
import json


from sklearn.metrics import classification_report, confusion_matrix

# 文本类别
CLASS_NAME = [
    "产业相关",
    ...
    "法律法规与行政事务",
    "其他",
]


def load_jsonl(file_path):
    """
    加载指定路径的 JSON 文件并返回解析后的数据。

    :param file_path: JSON 文件的路径
    :return: 解析后的数据(通常是字典或列表)
    :raises FileNotFoundError: 如果文件未找到
    :raises json.JSONDecodeError: 如果 JSON 格式不正确
    """
    data = []
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                tmp = json.loads(line)
                data.append(tmp)
    except FileNotFoundError as e:
        print(f"文件未找到:{file_path}")
        raise e
    except json.JSONDecodeError as e:
        print(f"JSON 格式错误:{e}")
        raise e
    return data


def parser_label(text: str):
    pattern = r"label[::\s\.\d\*]*([^\s^\*]+)"
    matches = re.findall(pattern, text, re.DOTALL)
    if len(matches) == 1:
        return matches[0]
    return None


def trans2num(item):
    predict = parser_label(item["predict"])
    label = parser_label(item["label"])

    predict_idx = -1
    label_idx = -1
    for idx, cls_name in enumerate(CLASS_NAME):
        if predict == cls_name:
            predict_idx = idx

        if label == cls_name:
            label_idx = idx

    return predict_idx, label_idx

def cls_eval(input_file):
    data = load_jsonl(file_path=input_file)
    predicts = []
    labels = []

    for item in data:
        predict, label = trans2num(item)
        if label == -1:
            continue

        predicts.append(predict)
        labels.append(label)

    return classification_report(predicts, labels, output_dict=False)

本文使用了大模型生成式预测文本类别,我没有使用结构化输出的方式,大家可以使用结构化的json格式输出,这样在提取大模型预测结果的时候会方便很多。

大家按照自己模型的输出结果,修改parser_label 函数,这个函数用于从大模型的输出结果提取label。

cls_eval("xxx/generated_predictions.jsonl")

就会得到下述的输出结果:

-1 代表模型预测的类别不在给定的类别中。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2256872.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AI大模型赋能医学诊疗与药学服务——课题基金申请辅导项目成功举办

2024年11月23日,北京整合医学学会在线上成功举办了“AI大模型赋能医学诊疗与药学服务——课题基金申请辅导项目”。此次会议吸引了来自全国各地的医学、药学及人工智能领域的专家学者和科研人员积极参与,共同探讨AI大模型在医学诊疗与药学服务中的应用&a…

Java8 CompletableFuture异步编程

文章目录 CompletableFuturede介绍CompletableFuturede使用场景常用异步编程实现方案- Thread- ExecutorService- CountDownLatch- CyclicBarrier- ForkJoinPool- CompletableFuture各种实现方案总结 CompletableFuturede结构结构梳理- Future接口- CompletionStage接口常用方法…

el-thee懒加载删除某条数据 ,el-thee懒加载重置,el-thee刷新某个节点

一、懒加载的tree已经全部展开&#xff0c;外部点击删除的时候不需要重新展开点击获取下一层数据 <template> <el-treeref"tree":data"treeData":props"defaultProps"render-after-expandhighlight-currentlazy:expand-on-click-node&q…

计算机网络-IPSec VPN工作原理

一、IPSec VPN工作原理 昨天我们大致了解了IPSec是什么&#xff0c;今天来学习下它的工作原理。 IPsec的基本工作流程如下&#xff1a; 通过IKE协商第一阶段协商出IKE SA。 使用IKE SA加密IKE协商第二阶段的报文&#xff0c;即IPsec SA。 使用IPsec SA加密数据。 IPsec基本工作…

国际荐酒师Peter助力第六届地博会,推动地理标志产品国际化发展

国际荐酒师Peter Lisicky助力第六届知交会暨地博会&#xff0c;推动地理标志产品国际化发展 第六届粤港澳大湾区知识产权交易博览会暨国际地理标志产品交易博览会于2024年12月9日至11日在中新广州知识城盛大举行&#xff0c;吸引了全球众多行业专家、企业代表及相关机构齐聚一…

Android显示系统(05)- OpenGL ES - Shader绘制三角形(使用glsl文件)

Android显示系统&#xff08;02&#xff09;- OpenGL ES - 概述 Android显示系统&#xff08;03&#xff09;- OpenGL ES - GLSurfaceView的使用 Android显示系统&#xff08;04&#xff09;- OpenGL ES - Shader绘制三角形 Android显示系统&#xff08;05&#xff09;- OpenGL…

【Golang】Go语言编程思想(六):Channel,第一节,介绍Channel

Channel 下面的几个例子将会展示如何定义一个 channel&#xff1a; func chanDemo() {var c chan int // chan int 的含义是, c 是一个 channel, 里面的内容是 int// 上面的声明语句将会创建一个 nil channel, c nil, 它的作用将在 select 当// 中体现 }创建一个非 nil 的 c…

怎么获取Java高并发经验与系统设计技能?

如何获得高并发经验&#xff1f; 这是系统邀请我回答的一个问题&#xff0c;由此也引发了我的一些思考&#xff1a;为什么人人都想要获得高并发经验&#xff1b;想拥有高并发系统设计技能&#xff1f; 其原因LZ认为主要有以下三点&#xff1a; 涨薪&#xff1a;有高并发系统设…

Spherical Harmonics (SH)球谐函数的原理及应用【3DGS】

Spherical Harmonics &#xff08;SH&#xff09;球谐函数的原理及应用【3DGS】 前言球谐函数&#xff08;Spherical Harmonics, SH&#xff09;球谐函数不同阶的表达式以及有什么不同&#xff1f;具体介绍球谐函数基函数球谐函数 前言 高斯泼溅Gaussian Splatting (GS) GS 模…

Java版-图论-拓扑排序与有向无环图

拓扑排序 拓扑排序说明 对一个有向无环图(Directed Acyclic Graph简称DAG)G进行拓扑排序,是将G中所有顶点排成一个线性序列,使得图中任意一对顶点u和v,若边<u,v>∈E(G),则u在线性序列中出现在v之前。通常,这样的线性序列称为满足拓扑次序(Topological Order)的序列…

如何在 Odoo18 视图中添加关联数据看板按钮 | 免费开源ERP实施诀窍

文 / 开源智造 Odoo亚太金牌服务 引言 关联数据看板按钮乃是 Odoo 当中的一项强效功能&#xff0c;它容许用户顺遂地访问相关记录&#xff0c;或者直接从模型的表单视图施行特定操作。它们为用户给予了对重要信息的疾速访问途径&#xff0c;并简化了工作流程&#xff0c;由此…

TCP客户端服务器端通信(线程池版)

1、什么是监听套接字&#xff0c;和UDP相比&#xff0c;TCP为什么文件描述符变多了&#xff1f; 在网络编程中&#xff0c;TCP和UDP是两种常见的传输协议&#xff0c;它们之间最大的不同之一在于连接的管理方式。为了更好地理解这个区别&#xff0c;我们可以用一个生动的比喻来…

【Linux】通过crond服务设置定时执行shell脚本,实际执行时间却延迟了8小时

一、问题描述 通过使用crond服务设置定时任务&#xff0c;在每天凌晨的2:00执行脚本&#xff0c;但检查结果时发现&#xff0c;实际执行时间却在上午10点。 检查shell脚本执行结果发现&#xff0c;实际执行脚本时间在上午10:00&#xff0c;延迟了8小时。 检查系统时间&#xf…

Git基础笔记

目录 1.Git 常用命令 2.Git 分支操作 3.远程仓库操作 Git 概述 Git 是一个免费的、开源的 分布式版本控制系统 &#xff0c;可以快速高效地处理从小型到大型的各种 项目 1.Git 常用命令 1.设置用户签名 git config --global user.name 用户名 2.设置用户签名 git config…

PADS系列:绘制RTL8306原理图的过程

大家好&#xff0c;我是山羊君Goat。 在所有相关的元件都被创建到了原理图库之后&#xff0c;就可以正式开始原理图的绘制了。不过绘制过程中也是会按照一定的顺序来进行的&#xff0c;这样可以达到事半功倍的效果。 首先就是主芯片的放置&#xff0c;这里有三个主芯片&#x…

GCP Case:MountKirk Games

游戏后端 根据游戏活动动态放大或缩小。 连接到托管的nos0l数据库服务。 运行定制的linux发行版。 游戏分析平台 根据游戏活动来扩大或缩小规模直接处理来自游戏服务器的传入数据。 处理由于移动网络缓慢而迟到的数据。 通过sql查询来访问至少10tb的历史数据 处理由用户…

OpenCV相机标定与3D重建(10)眼标定函数calibrateHandEye()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 计算手眼标定&#xff1a; g T c _{}^{g}\textrm{T}_c g​Tc​ cv::calibrateHandEye 是 OpenCV 中用于手眼标定的函数。该函数通过已知的机器人…

【CSS in Depth 2 精译_072】第 12 章 CSS 排版与间距概述 + 12.1 间距设置(上):究竟该用 em 还是 px

当前内容所在位置&#xff08;可进入专栏查看其他译好的章节内容&#xff09; 第四部分 视觉增强技术 ✔️【第 12 章 CSS 排版与间距】 ✔️ 12.1 间距设置 ✔️ 12.1.1 使用 em 还是 px ✔️12.1.2 对行高的深入思考12.1.3 行内元素的间距设置 文章目录 第 12 章 排版与间距…

数据结构代码归纳

1.线性表 线性表的顺序表示 定义与初始化 typedef struct SqList{ElemType data[MaxSize];//ElemType *data 开动态数组 int length; }Sqlist; void InitList(SqList &L){L.length0;//若静态数组//若动态数组 //L.data(ElemType*)malloc(sizeof(ElemType)*MaxSize); }…

数据结构 (36)各种排序方法的综合比较

一、常见排序方法分类 插入排序类 直接插入排序&#xff1a;通过构建有序序列&#xff0c;对于未排序数据&#xff0c;在已排序序列中从后向前扫描&#xff0c;找到相应位置并插入。希尔排序&#xff1a;是插入排序的一种改进版本&#xff0c;先将整个待排序的记录序列分割成为…