vllm 部署GLM4模型进行 Zero-Shot 文本分类实验,让大模型给出分类原因,准确率可提高6%

news2024/9/26 3:20:31

文章目录

    • 简介
    • 数据集
    • 实验设置
    • 数据集转换
    • 模型推理
    • 评估

简介

本文记录了使用 vllm 部署 GLM4-9B-Chat 模型进行 Zero-Shot 文本分类的实验过程与结果。通过对 AG_News 数据集的测试,研究发现大模型在直接进行分类时的准确率为 77%。然而,让模型给出分类原因描述(reason)后,准确率显著提升至 83%,提升幅度达 6%。这一结果验证了引入 reasoning 机制的有效性。文中详细介绍了实验数据、提示词设计、模型推理方法及评估手段。

复现自这篇论文:Text Classification via Large Language Models. https://arxiv.org/abs/2305.08377 让大模型使用reason。

该项目的文件结构如下所示:

├── cls_vllm.log
├── cls_vllm.py
├── data
│   ├── basic_llm.csv
│   └── reason_llm.csv
├── data_processon.ipynb
├── eval.ipynb
├── output
│   ├── basic_vllm.pkl
│   └── reason_vllm.pkl
├── settings.py
└── utils.py

数据集

现在要找一个数据集做实验,进入 https://paperswithcode.com/。
找到 文本分类,看目前的 SOTA 是在哪些数据集上做的,文本分类. https://paperswithcode.com/task/text-classification

在这里插入图片描述

实验使用了 AG_News 数据集。若您对数据集操作技巧感兴趣,可以参考这篇文章:

datasets库一些基本方法:filter、map、select等. https://blog.csdn.net/sjxgghg/article/details/141384131

实验设置

settings.py 文件中,我们定义了一些实验中使用的提示词:

LABEL_NAMES = ['World', 'Sports', 'Business', 'Science | Technology']

BASIC_CLS_PROMPT = """
你是文本分类专家,请你给下述文本分类,把它分到下述类别中:
* World
* Sports
* Business
* Science | Technology

text是待分类的文本。请你一步一步思考,在label中给出最终的分类结果:
text: {text}
label: 
"""

REASON_CLS_PROMPT = """
你是文本分类专家,请你给下述文本分类,把它分到下述类别中:
* World
* Sports
* Business
* Science | Technology

text是待分类的文本。请你一步一步思考,首先在reason中说明你的判断理由,然后在label中给出最终的分类结果:
text: {text}
reason: 
label: 
""".lstrip()

data_files = [
    "data/basic_llm.csv",
    "data/reason_llm.csv"
]

output_dirs = [
    "output/basic_vllm.pkl",
    "output/reason_vllm.pkl"
]

这两个数据文件用于存储不同提示词的大模型推理数据:

  • data/basic_llm.csv
  • data/reason_llm.csv

数据集转换

为了让模型能够执行文本分类任务,我们需要对原始数据集进行转换,添加提示词。

原始的数据集样式,要经过提示词转换后,才能让模型做文本分类。

代码如下:

data_processon.ipynb

from datasets import load_dataset

from settings import LABEL_NAMES, BASIC_CLS_PROMPT, REASON_CLS_PROMPT, data_files

import os
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

# 加载 AG_News 数据集的测试集,只使用test的数据去预测
ds = load_dataset("fancyzhx/ag_news")

# 转换为 basic 提示词格式
def trans2llm(item):
    item["text"] = BASIC_CLS_PROMPT.format(text=item["text"])
    return item
ds["test"].map(trans2llm).to_csv(data_files[0], index=False)

# 转换为 reason 提示词格式
def trans2llm(item):
    item["text"] = REASON_CLS_PROMPT.format(text=item["text"])
    return item
ds["test"].map(trans2llm).to_csv(data_files[1], index=False)

上述代码实现的功能就是把数据集的文本,放入到提示词的{text} 里面。

模型推理

本文使用 ZhipuAI/glm-4-9b-chat. https://www.modelscope.cn/models/zhipuai/glm-4-9b-chat 智谱9B的chat模型,进行VLLM推理。

为了简化模型调用,我们编写了一些实用工具:

utils.py

import pickle
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from modelscope import snapshot_download


def save_obj(obj, name):
    """
    将对象保存到文件
    :param obj: 要保存的对象
    :param name: 文件的名称(包括路径)
    """
    with open(name, "wb") as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def load_obj(name):
    """
    从文件加载对象
    :param name: 文件的名称(包括路径)
    :return: 反序列化后的对象
    """
    with open(name, "rb") as f:
        return pickle.load(f)
    


def glm4_vllm(prompts, output_dir, temperature=0, max_tokens=1024):
    # GLM-4-9B-Chat-1M
    max_model_len, tp_size = 131072, 1
    model_dir = snapshot_download('ZhipuAI/glm-4-9b-chat')

    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    llm = LLM(
        model=model_dir,
        tensor_parallel_size=tp_size,
        max_model_len=max_model_len,
        trust_remote_code=True,
        enforce_eager=True,
    )
    stop_token_ids = [151329, 151336, 151338]
    sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens, stop_token_ids=stop_token_ids)

    inputs = tokenizer.apply_chat_template(prompts, tokenize=False, add_generation_prompt=True)
    outputs = llm.generate(prompts=inputs, sampling_params=sampling_params)

    save_obj(outputs, output_dir)

glm4_vllm :

  • 参考自 https://www.modelscope.cn/models/zhipuai/glm-4-9b-chat

    给大家封装好了,以后有任务,直接调用函数

save_obj:

  • 把python对象,序列化保存到本地;

    在本项目中,用来保存 vllm 推理的结果;

模型推理代码
cls_vllm.py

from datasets import load_dataset

from utils import glm4_vllm
from settings import data_files, output_dirs


# basic 预测
basic_dataset = load_dataset(
    "csv",
    data_files=data_files[0],
    split="train",
)
prompts = []
for item in basic_dataset:
    prompts.append([{"role": "user", "content": item["text"]}])
glm4_vllm(prompts, output_dirs[0])


# reason 预测,添加了原因说明
reason_dataset = load_dataset(
    "csv",
    data_files=data_files[1],
    split="train",
)
prompts = []
for item in reason_dataset:
    prompts.append([{"role": "user", "content": item["text"]}])
glm4_vllm(prompts, output_dirs[1])


# nohup python cls_vllm.py > cls_vllm.log 2>&1 &

在推理过程中,我们使用了 glm4_vllm 函数进行模型推理,并将结果保存到指定路径。

output_dirs: 最终推理完成的结果输出路径;

评估

在获得模型推理结果后,我们需要对其进行评估,以衡量分类的准确性。

eval.ipynb

from settings import LABEL_NAMES
from utils import load_obj

from datasets import load_dataset
from settings import data_files, output_dirs

import os
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

ds = load_dataset("fancyzhx/ag_news")
def eval(raw_dataset, vllm_predict):
    
    right = 0 # 预测正确的数量
    multi_label = 0 # 预测多标签的数量
    
    for data, output in zip(raw_dataset, vllm_predict):
        true_label = LABEL_NAMES[data['label']]
        
        output_text = output.outputs[0].text
        pred_label = output_text.split("label")[-1]
        
        tmp_pred = []
        for label in LABEL_NAMES:
            if label in pred_label:
                tmp_pred.append(label)
        
        if len(tmp_pred) > 1:
            multi_label += 1
        
        if " ".join(tmp_pred) == true_label:
            right += 1
    
    return right, multi_label

我们分别对 basic 和 reason 预测结果进行了评估。

basic 预测结果的评估 :

dataset = load_dataset(
    'csv', 
    data_files=data_files[0], 
    split='train'
    )
output = load_obj(output_dirs[0])

eval(dataset, output)

输出结果:

(5845, 143)

加了reason 预测结果评估:

dataset = load_dataset(
    'csv', 
    data_files=data_files[1], 
    split='train'
    )
output = load_obj(output_dirs[1])

eval(dataset, output)

输出结果:

(6293, 14)

评估结果如下:

  • basic: 直接分类准确率为 77%(5845/7600),误分类为多标签的样本有 143 个。
  • reason: 在输出原因后分类准确率提高至 83%(6293/7600),多标签误分类样本减少至 14 个。

误分类多标签: 这是单分类问题,大模型应该只输出一个类别,但是它输出了多个类别;

可以发现,让大模型输出reason,不仅分类准确率提升了5%,而且在误分类多标签的数量也有所下降。
原先误分类多标签有143条数据,使用reason后,多标签误分类的数量降低到了14条。

这些结果表明,让模型输出 reason的过程,确实能够有效提升分类准确性,还能减少误分类多个标签。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2077520.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

地理空间数据分析技巧:优化的热点分析与异常值分析的应用

热点分析作为一种常用的空间统计方法,能够帮助我们识别地理空间中的热点和冷点区域,即那些高值或低值集中出现的地方。而优化的热点分析进一步简化了这一过程,使用户无需手动调整参数即可获得可靠的结果。此外,异常值分析则专注于…

wooyu漏洞库YYDS!!!入门之道:重现乌云漏洞库

目录 wooyun乌云漏洞库搭建 1、搭建资料 文件结构分析: ​编辑2、搭建过程 2.1、搭建wooyun网站 2.2、配置数据库 2.2.1、修改数据库配置文件conn.php 2.2.2、创建wooyun数据库,并解压数据库文件 2.2.3、连接数据库(数据库默认连接密…

信号与系统——定义与分类(1)

一、信号与系统 信号:信号是信息的表现形式或传送载体,例如电磁波。信号可以用一个函数 yx (t) 来表示。 系统:是指若干相互关联的事物组合而成,具有特定功能的整体。换句话说就是,系统就是对输入信号进行加工和处理&#xff0c…

Nginx: 反向代理和动静分离概述

反向代理 反向代理服务器介于用户和真实服务器之间,提供请求和响应的中转服务对于用户而言,访问反向代理服务器就是访问真实服务器反向代理可以有效降低服务器的负载消耗,提升效率 1 )反向代理的模型 现在我们有一个用户和真实服…

新版cubemx生成CMake工程浮点数打印问题记录

问题现象 解决方案 set(CMAKE_C_LINK_FLAGS "${CMAKE_C_LINK_FLAGS} --specsnano.specs -u _printf_float")参考Cortex-M4权威指南 重新下载和测试

[Algorithm][综合训练][打怪][判断是不是平衡二叉树][最大子矩阵]详细讲解

目录 1.打怪1.题目链接2.算法原理详解 && 代码实现 2.判断是不是平衡二叉树1.题目链接2.算法原理详解 && 代码实现 3.最大子矩阵1.题目链接2.算法原理详解 && 代码实现 1.打怪 1.题目链接 打怪 2.算法原理详解 && 代码实现 自己的版本&…

C++ 设计模式——代理模式

C 设计模式——代理模式 C 设计模式——代理模式1. 主要组成成分2. 逐步构建代理模式2.1 抽象主题类定义2.2 真实主题类实现2.3 代理类实现2.4 主函数 3. 代理模式 UML 图代理模式 UML 图解析 4. 代理模式的优点5. 代理模式的缺点6. 代理模式的分类7. 代理模式和装饰者模式比较…

MybatisPlus:实现分页效果并解决错误:cant found IPage for args

我们在做开发使用mybatisplus 做分页查询的时候遇到了个问题: 继承 IPage拦截没有作用会默认分页,这个时候报了cant found IPage for args 错误~~~ 我们分析了下,其实这个问题很简单,是因为没有给默认值赋值,因为查询…

日撸Java三百行(day35:图的m着色问题)

目录 一、问题描述 二、思路分析 三、代码实现 总结 一、问题描述 在高中学习排列组合的时候,有一个非常经典的问题,就是涂色问题,即用m种颜色给n块区域涂色,要求每块区域只能涂同一种颜色且相邻区域的颜色不能相同&#xff…

pyinstaller将python程序打包成exe文件

将python代码打包成exe文件可以在不安装python环境的情况下直接运行python代码,譬如自己在自己的电脑上写好了代码,想发给其他人使用,可以用下述方法将python程序打包成exe文件,其他人直接执行exe文件即可使用该程序。 1.安装pyi…

二叉搜索树:数据结构之美

目录 引言基础知识 定义性质操作详解 插入节点删除节点查找节点遍历 前序遍历中序遍历后序遍历高级主题 平衡问题AVL树简介应用案例总结 引言 二叉搜索树(Binary Search Tree, BST)是一种特殊的二叉树,它的每个节点具有以下性质:左子树上的所有节点的键…

Python数据采集与网络爬虫技术实训室解决方案

在大数据与人工智能时代,数据采集与分析已成为企业决策、市场洞察、产品创新等领域不可或缺的一环。而Python,作为一门高效、易学的编程语言,凭借其强大的库支持和广泛的应用场景,在数据采集与网络爬虫领域展现出了非凡的潜力。唯…

Mysql重要参数

1、是否开启慢SQL日志 show VARIABLES like slow_query_log%; 2、慢SQL日志保存位置 show VARIABLES like slow_query_log_file%; 3、慢SQL的阈值,超过则是慢SQL,单位秒,默认10s show VARIABLES like long_query_time%;

小阿轩yx-Kubernetes存储入门

小阿轩yx-Kubernetes存储入门 前言 数据是一个企业的发展核心,它涉及到数据存储和数据交换的内容。在生产环境中尤为重要的一部分在 Kubernetes 中另一个重要的概念就是数据持久化 Volume。 Volume 的概念 对多数项目而言 数据文件的存储是非常常见的 在 Kube…

计算机的错误计算(七十四 )

摘要 回复网友的疑问:用错数解释计算机的错误计算(六十四)中的错误计算原因。 计算机的错误计算(六十四)到(六十九),以及(七十一)与(七十三&…

攻防世界 1000次点击

做题笔记。 下载解压 查壳。 32位ida打开。 查找字符串。 winmain函数写的,程序运行如下: 一开始思路是想着分析找到关键代码然后去od进行调试。 后来,额,不想看代码了。吐了。 尝试去字符串搜索flag样式,确实一发现…

高效恢复,无忧存储:2024年数据恢复工具大搜罗

不知道你是否了解过电子存储设备,我们的设备往往都存储在一个小小的芯片里,它为我们提供了数据携带的便捷性,当然也为我们带来了数据意外丢失的风险。为了我们的数据安全,我们来探讨一下有什么数据恢复工具能为我们的资料保驾护航…

Ruo-Yi 前后端分离如何不使用注解@DataSource的方式而是使用Mybatis插件技术实现多数据源的切换【可以根据配置文件进行开启/关闭】

Ruo-Yi 前后端分离如何不使用注解DataSource的方式而是使用Mybatis插件技术实现多数据源的切换【可以根据配置文件进行开启/关闭】 1、首先 配置文件: # 数据源配置 spring:datasource:type: com.alibaba.druid.pool.DruidDataSourcedriverClassName: com.mysql.c…

ZooKeeper--基于Kubernetes部署ZooKeeper

ZooKeeper 服务 服务类型: 无头服务(clusterIP: None),这是 StatefulSet(有状态集)必需的配置。 端口: 2181 (客户端): 用于客户端连接。 2888 (跟随者): 用于 ZooKeeper 服务器之间的连接。 3888 (领导者): 用于领导者…

邮政快递批量查询解决方案:提升业务运营效率

邮政快递批量查询:固乔快递查询助手的高效体验 在电商行业日益繁荣的今天,快递物流成为了连接商家与消费者的关键纽带。而对于需要处理大量订单的电商企业或个人而言,如何高效、准确地查询和跟踪快递物流信息显得尤为重要。幸运的是&#xf…