llama-factory 系列教程 (六),linux shell 脚本自动实现批量大模型的训练、部署与评估

news2024/9/22 13:33:17

背景

最近在做大模型微调训练的评估,每次都要手动训练大模型,手动评估。
发现这样太浪费时间了,于是就尝试着使用linux shell 脚本,利用 for 循环自动实现大模型的训练、部署与评估。

实验:在不同的文本分类数据集尺寸上微调大模型

在这次实验中,我们分别使用了100、500、1000和2000条数据对大模型进行了微调。我们的目标是评估不同大小的数据集对大模型表现的影响。

项目开源地址:https://github.com/JieShenAI/csdn/blob/main/24/07/few_shot_sft/readme.md

实验方法

为了高效地完成微调任务,我们使用了Linux shell脚本 自动化运行。具体步骤如下:

  1. 数据准备:将不同大小的数据集准备好。
  2. 批量微调:利用Linux shell脚本批量化地微调大模型,自动保存微调后的模型权重。
  3. 自动评估:微调完成后,脚本会自动调用评估程序,对模型在测试集上的表现进行评估。

这种方法极大地提高了工作效率。若不使用自动化脚本,我们需要手动逐个训练模型,然后手动运行评估程序,这不仅耗时,而且容易出错。

优势

  • 时间节省:利用自动化脚本,我们可以在夜间让计算机自行完成微调和评估工作,第二天早上起床后即可查看结果。
  • 减少人工干预:整个过程无需过多人工干预,减少了人工的时间与精力。

通过这种方式,我们能够得出不同大小数据集对大模型表现的影响,为进一步的研究提供了宝贵的数据支持。

项目文件介绍

  • build_llm_data.ipynb
    从训练集中随机筛选并转换为Alpaca样式的数据集格式
    在大模型的微调过程中,从训练集中随机抽取不同规模的数据样本,以便进行模型的测试和优化。本文从训练集中随机筛选100、500、1000和2000条数据,并将这些数据转换为Alpaca样式的微调数据集格式,最后将筛选后的数据保存在data文件夹下。
    本文在文本分类数据集上进行模型训练。
    下述是转化为大模型微调的数据集样例:
    [
      {
        "instruction": "You are a document classifier. When given the text, you classify the text into one of the following categories:\n\n\"Human Necessities\"\n\"Performing Operations; Transporting\"\n\"Chemistry; Metallurgy\"\n\"Textiles; Paper\"\n\"Fixed Constructions\"\n\"Mechanical Engineering; Lightning; Heating; Weapons; Blasting\"\n\"Physics\"\n\"Electricity\"\n\"General tagging of new or cross-sectional technology\"\n\"Unknown\"\n\nYour output should only contain one of the categories and no explanation or any other text.",
        "input": "Classify the document:\nan image sensor device may include a dual - gated charge storage region within a substrate . the dual - gated charge storage region includes first and second diodes within a common charge generating region . this charge generating region is configured to receive light incident on a surface of the image sensor device . the first and second diodes include respective first conductivity type regions responsive to first and second gate signals , respectively . these first and second gate signals are active during non - overlapping time intervals .",
        "output": "Electricity"
      },
      ...
    ]
    
  • train.sh
    在开始训练之前,需要在 LLaMA-Factory/data/dataset_info.json 文件中注册 data 目录下的数据集。接下来,从 LLaMA-Factory 的可视化界面获取 LoRA 微调的命令行。train.sh 脚本实现了批量化训练,并在训练完成后保存 LoRA 的权重。
    # 对所有切分后的数据集进行训练
    cd LLaMA-Factory
    data_files=(llm_train_100 llm_train_500 llm_train_1000 llm_train_2000)
    echo ${data_files[@]}
    
    for data_file in ${data_files[@]}; do
        echo ${data_file}
        llamafactory-cli train \
            --stage sft \
            --do_train True \
            --model_name_or_path ZhipuAI/glm-4-9b-chat \
            --preprocessing_num_workers 16 \
            --finetuning_type lora \
            --template glm4 \
            --flash_attn auto \
            --dataset_dir data \
            --dataset ${data_file} \
            --cutoff_len 1024 \
            --learning_rate 5e-05 \
            --num_train_epochs 3.0 \
            --max_samples 100000 \
            --per_device_train_batch_size 2 \
            --gradient_accumulation_steps 4 \
            --lr_scheduler_type cosine \
            --max_grad_norm 1.0 \
            --logging_steps 5 \
            --save_steps 100 \
            --warmup_steps 0 \
            --optim adamw_torch \
            --packing False \
            --report_to none \
            --output_dir saves/GLM-4-9B-Chat/lora/240731-${data_file} \
            --fp16 True \
            --plot_loss True \
            --ddp_timeout 180000000 \
            --include_num_input_tokens_seen True \
            --lora_rank 8 \
            --lora_alpha 16 \
            --lora_dropout 0 \
            --lora_target all
    done
    
    # nohup bash train.sh > train.log 2>&1 &
    
  • eval.sh
    在训练完成后,使用 VLLM 部署训练完成的 LoRA 模型,并将其部署成 API 接口,便于通过 infer_eval.py 进行评估。eval.sh 脚本实现了对训练模型的批量部署与评估,自动化地逐个部署和推理。在评估完成一个大模型后,脚本会杀死正在部署的进程,开始部署下一个大模型,并进行新的评估。
    # conda activate llm
    cd LLaMA-Factory
    
    # kw_arr=(llm_train_100 llm_train_500 llm_train_1000 llm_train_2000)
    kw_arr=(llm_train_100 llm_train_500 llm_train_1000)
    
    
    for kw in "${kw_arr[@]}"; do
        echo $kw
        CUDA_VISIBLE_DEVICES=0 API_PORT=8000 llamafactory-cli api \
            --model_name_or_path /home/jie/.cache/modelscope/hub/ZhipuAI/glm-4-9b-chat \
            --adapter_name_or_path ./saves/GLM-4-9B-Chat/lora/240731-${kw} \
            --template glm4 \
            --finetuning_type lora \
            --infer_backend vllm \
            --vllm_enforce_eager &
            
        # 模型预测推理脚本,便于后续评估
        python ../infer_eval.py ${kw} > ../logs/${kw}.log 2>&1
        # 杀掉服务进程
        pkill -f llamafactory
        echo "Stopped llamafactory"
    done
    
    # nohup bash eval.sh > eval.log 2>&1 &
    
  • infer_eval.py
    利用在线部署的大模型,结合 LangChain 工具,在测试集上逐个进行评估。
    import os
    import json
    import random
    import logging
    import argparse
    import pickle
    import evaluate
    from tqdm import tqdm
    from datasets import load_dataset
    from dataclasses import dataclass, field
    from langchain_openai import ChatOpenAI
    from langchain_core.messages import HumanMessage, SystemMessage
    from langchain_core.output_parsers import StrOutputParser
    
    
    os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
    os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'
    
    
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.FileHandler('../eval.log')],
        level=logging.INFO
    )
    
    
    @dataclass
    class EvalData:
        name : str
        in_cnt : int = 0
        not_in_cnt : int = 0
        preds : list = field(default_factory=list)
        labels : list = field(default_factory=list)
        not_in_texts : list = field(default_factory=list)
        eval : dict = field(default_factory=dict)
    
    def save_obj(obj, name):  
        """  
        将对象保存到文件  
        :param obj: 要保存的对象  
        :param name: 文件的名称(包括路径)  
        """  
        with open(name, 'wb') as f:  
            pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
    
    
    def load_obj(name):  
        """  
        从文件加载对象  
        :param name: 文件的名称(包括路径)  
        :return: 反序列化后的对象  
        """  
        with open(name, 'rb') as f:  
            return pickle.load(f)
    
    
    LABELS_DICT = {
        0: "Human Necessities",
        1: "Performing Operations; Transporting",
        2: "Chemistry; Metallurgy",
        3: "Textiles; Paper",
        4: "Fixed Constructions",
        5: "Mechanical Engineering; Lightning; Heating; Weapons; Blasting",
        6: "Physics",
        7: "Electricity",
        8: "General tagging of new or cross-sectional technology",
    }
    
    
    LABELS_NAME = [
        LABELS_DICT[i]
        for i in range(9)
    ]
    
    LABELS_2_IDS = {
        v : k
        for k, v in LABELS_DICT.items()
    }
    
    
    def compute_metrics(pred, label):
        res = {}
        accuracy = evaluate.load("accuracy")
        res.update(accuracy.compute(
                predictions=pred, 
                references=label
            ))
    
        precision = evaluate.load("precision")
        res.update(precision.compute(
                predictions=pred, 
                references=label,
                average="macro"
            ))
    
        recall = evaluate.load("recall")
        res.update(recall.compute(
                predictions=pred, 
                references=label,
                average="macro"
            ))
    
        f1 = evaluate.load("f1")
        res.update(f1.compute(
                predictions=pred, 
                references=label,
                average="macro"
            ))
        return res
    
    
    def eval(kw):
        eval_data = EvalData(name=kw)
        model = ChatOpenAI(
            api_key="0",
            base_url="http://localhost:8000/v1",
            temperature=0
        )
    
        valid_dataset = load_dataset(
            "json",
            data_files="../data/llm_valid.json"
        )["train"]
        # labels = valid_dataset["output"][:50]
        labels = valid_dataset["output"]
        
        eval_data.labels = labels
        
        parser = StrOutputParser()
        preds = []
        cnt = 0
        for item in tqdm(valid_dataset):
            cnt += 1
            messages = [
                SystemMessage(content=item['instruction']),
                HumanMessage(content=item['input']),
            ]
            chain = model | parser
            pred = chain.invoke(messages).strip()
            preds.append(pred)
            # if cnt == 50:
            #     break
        
        eval_data.preds = preds
    
        not_in_texts = []
        in_cnt = 0
        not_in_cnt = 0
    
        for pred in preds:
            if pred in LABELS_NAME:
                in_cnt += 1
            else:
                not_in_cnt += 1
                not_in_texts.append(pred)
        
        eval_data.in_cnt = in_cnt
        eval_data.not_in_cnt = not_in_cnt
        eval_data.not_in_texts = not_in_texts
        
        pred_num = [
            LABELS_2_IDS[pred] if pred in LABELS_NAME else random.choice(range(9))
            for pred in preds
        ]
        label_num = [
            LABELS_2_IDS[label]
            for label in labels
        ]
        
        eval_data.eval = compute_metrics(pred=pred_num, label=label_num)
        
        logging.info(f"in_cnt: {in_cnt}, not_in_cnt: {not_in_cnt}")
        logging.info(f"eval: {eval_data.eval}")
        
        # 推理结果保存
        save_obj(
                eval_data,
                f"../objs/{kw}.pkl"
            )
    
    
    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description="输入大模型名,开始推理")
        parser.add_argument("kw", help="目前部署的大模型名字")
        args = parser.parse_args()
        logging.info(args.kw)
        eval(args.kw)
    
  • see_result.ipynb
    导入保存到objs文件夹中的预测结果,并进行结果的渲染
    最后结果如下图所示,数据集量越大效果越好:

result.png

各位读者在看完,训练脚本 train.sh, 部署和推理脚本 eval.sh,应该已经明白本项目大致流程。

一言以蔽之,就是在shell脚本中,使用 for 循环实现训练、部署、评估流程。

若大家想复现本文实验,本项目已经在Github开源,项目开源地址:https://github.com/JieShenAI/csdn/blob/main/24/07/few_shot_sft/readme.md

本文主要是为大家展示,使用linux shell 脚本,自动化处理的流程,故在项目的具体细节没有过多的解释。

应该与Bert文本分类进行对比,就可以明显看出大模型的few-shot能力,有读者感兴趣可以实现一下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1964670.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

记录两道关于编码解码的问题

环境&#xff1a;php环境即可&#xff0c;也可使用phpstudy。 参考文章: 深入理解浏览器解析机制和XSS向量编码-CSDN博客(很重要) HTML 字符编码&#xff08;自我复习&#xff09;-CSDN博客 例题1&#xff1a; <?php header("X-XSS-Protection: 0"); $xss …

Shell编程——简介和基础语法(1)

文章目录 Shell简介什么是ShellShell环境第一个Shell脚本Shell脚本的运行方法 Shell基础语法Shell变量Shell传递参数Shell字符串Shell字符串截取Shell数组Shell运算符 Shell简介 什么是Shell Shell是一种程序设计语言。作为命令语言&#xff0c;它交互式解释和执行用户输入的命…

【算法】一文带你搞懂0-1背包问题!(实战篇)

在【算法】一文带你搞懂0-1背包问题&#xff01;&#xff08;理论篇&#xff09;中&#xff0c;讲解了纯0-1背包问题及其原理&#xff0c;本篇文章中进入LeetCode中0-1背包问题应用的实战环节&#xff0c;主要难点其实在于看出是0-1背包问题、转换成0-1背包问题。 416. 分割等和…

动态卷积(轻量级卷积)替代多头自注意力

动态卷积&#xff0c;它比自注意力更简单、更有效。我们仅基于当前时间步长预测单独的卷积核&#xff0c;以确定上下文元素的重要性。这种方法所需的操作数量随输入长度呈线性增长&#xff0c;而自注意力是二次的。在大规模机器翻译、语言建模和抽象摘要上的实验表明&#xff0…

【论文阅读笔记 + 思考 + 总结】MoMask: Generative Masked Modeling of 3D Human Motions

创新点&#xff1a; VQ-VAE &#x1f449; Residual VQ-VAE&#xff0c;对每个 motion sequence 输出一组 base motion tokens 和 v 组 residual motion tokensbidirectional 的 Masked transformer 用来生成 base motion tokensResidual Transformer 对 residual motion toke…

机器学习 | 分类算法原理——似然函数

Hi&#xff0c;大家好&#xff0c;我是半亩花海。接着上次的逻辑回归继续更新《白话机器学习的数学》这本书的学习笔记&#xff0c;在此分享似然函数这一分类算法原理。本章的分类算法原理基于《基于图像大小进行分类》项目&#xff0c;欢迎大家交流学习&#xff01; 目录 一、…

个性化你的生产力工具:待办事项App定制指南

国内外主流的10款待办事项软件对比&#xff1a;PingCode、Worktile、滴答清单、番茄ToDo、Teambition、Todoist、Microsoft To Do、TickTick、Any.do、Trello。 在寻找合适的待办事项软件时&#xff0c;你是否感到选择众多、难以决断&#xff1f;一个好的待办事项工具可以大大提…

stl-algorithm【1】

#include《algorithm》 交换两数swap&#xff08;x&#xff0c;y&#xff09; 不只可以交换两个“数”&#xff08;数据类型&#xff09; 翻转【借助迭代器】reverse(it1,it2) 仍是左闭右开

国产开源夜莺部署

使用二进制方式部署夜莺 - 快猫星云 (flashcat.cloud) # install mysql yum -y install mariadb* systemctl enable mariadb systemctl restart mariadb mysql -e "SET PASSWORD FOR rootlocalhost PASSWORD(1234);"# install redis yum install -y redis systemctl…

navicat 17 下载安装

百度网盘 通过网盘分享的文件&#xff1a;Navicat17 链接: https://pan.baidu.com/s/1nFFQzWhjxRUM_X6bVlWNGw?pwd8888 提取码: 8888 1.双击运行安装包 2.点击下一步 2.勾选我同意&#xff0c;点击下一步 3.自定义安装路径&#xff0c;点击下一步 4.注意勾选桌面快捷方式&a…

编程新手指南:从入门到精通

编程小白如何成为大神&#xff1f;大学新生的最佳入门攻略 编程已成为当代大学生的必备技能&#xff0c;但面对众多编程语言和学习资源&#xff0c;新生们常常感到迷茫。如何选择适合自己的编程语言&#xff1f;如何制定有效的学习计划&#xff1f;如何避免常见的学习陷阱&…

基于YOLOv8的高压输电线路异物检测系统

基于YOLOv8的高压输电线路异物检测系统 (价格88) 包含 【“鸟窝”&#xff0c;“风筝”&#xff0c;“气球”&#xff0c;“垃圾”】 4个类 通过PYQT构建UI界面&#xff0c;包含图片检测&#xff0c;视频检测&#xff0c;摄像头实时检测。 &#xff08;该系统可以根据数…

众人帮蚂蚁帮任务平台修复版源码,含搭建教程。

全修复运营版本的任务平台&#xff0c;支持垂直领域细分&#xff0c;定向导流&#xff0c;带有排行榜功能&#xff0c;任务发布上传审核&#xff0c;用户信用等级&#xff0c;充值接口等等均完美可用。支付对接Z支付免签接口&#xff0c;环境配置及安装教程都已经打包。 搭建环…

ARM学习(31)编译器对overlay方式的支持

ARM学习&#xff08;31&#xff09;编译器对overlay方式的支持 1、overlay介绍 overlay&#xff1a;重叠得意思&#xff0c;就是可以重复利用得空间&#xff0c;一般在内存上使用这种空间。比如以Windows操作系统为例&#xff0c;其存储空间&#xff08;ROM/FLASH&#xff09;…

springboot垂钓服务系统-计算机毕业设计源码17434

摘要 本文旨在针对垂钓爱好者的需求&#xff0c;基于微信小程序平台&#xff0c;设计并实现一套垂钓服务系统。首先&#xff0c;通过对用户需求进行调研和分析&#xff0c;确定了系统的基本功能模块&#xff0c;包括垂钓点信息展示、用户预约和支付、钓具租赁信息等。接着&…

WebView加载数据的几种方式

之前客户端加载H5时遇到了一些问题&#xff0c;我为了方便解决问题&#xff0c;所以将对应场景复刻到了Demo中&#xff0c;从之前的网络加载模拟为了本地加载Html的方式&#xff0c;但是没想到无意被一个基础知识点卡了一些时间&#xff0c;翻看往昔笔记发现未曾记录这种基础场…

【MATLAB源码】机器视觉与图像识别技术(7)续---BP神经网络

系列文章目录在最后面&#xff0c;各位同仁感兴趣可以看看&#xff01; BP神经网络 第一节、BP网络定义第二节、BP网络结构及其特点第三节、信息传播方式 信息的正向传播&#xff1a;实质是计算网络的输出误差的反向传播&#xff1a;实质是学习过程第四节、 BP网络的算法流程…

python:plotly 网页交互式数据可视化工具

pip install plotly plotly-5.22.0-py3-none-any.whl pip install plotly_express 包含&#xff1a;GDP数据、餐厅的订单流水数据、鸢尾花 Iris数据集 等等 pip show plotly Name: plotly Version: 5.22.0 Summary: An open-source, interactive data visualization librar…

每日OJ_牛客HJ60 查找组成一个偶数最接近的两个素数

目录 牛客HJ60 查找组成一个偶数最接近的两个素数 解析代码 牛客HJ60 查找组成一个偶数最接近的两个素数 查找组成一个偶数最接近的两个素数_牛客题霸_牛客网 解析代码 首先需要判断素数&#xff0c;素数表示除过1和本身&#xff0c;不能被其它数整除。通过循环遍历来判断一…

飞致云开源社区月度动态报告(2024年7月)

自2023年6月起&#xff0c;中国领先的开源软件公司FIT2CLOUD飞致云以月度为单位发布《飞致云开源社区月度动态报告》&#xff0c;旨在向广大社区用户同步飞致云旗下系列开源软件的发展情况&#xff0c;以及当月主要的产品新版本发布、社区运营成果等相关信息。 飞致云开源大屏…