最实战的GLM4微调入门:从文本分类开始

news2024/10/6 14:26:26

GLM4是清华智谱团队最近开源的大语言模型。

以GLM4作为基座大模型,通过指令微调的方式做高精度文本分类,是学习LLM微调的入门任务。

在这里插入图片描述

使用的9B模型,显存要求相对较高,需要40GB左右。

在本文中,我们会使用 GLM4-9b-Chat 模型在 复旦中文新闻 数据集上做指令微调训练,同时使用SwanLab监控训练过程、评估模型效果。

  • 代码:完整代码直接看本文第5节
  • 实验日志过程:GLM4-Fintune - SwanLab
  • 模型:Modelscope
  • 数据集:zh_cls_fudan_news
  • SwanLab:https://swanlab.cn

相关文章:Qwen2指令微调

知识点:什么是指令微调?

大模型指令微调(Instruction Tuning)是一种针对大型预训练语言模型的微调技术,其核心目的是增强模型理解和执行特定指令的能力,使模型能够根据用户提供的自然语言指令准确、恰当地生成相应的输出或执行相关任务。

指令微调特别关注于提升模型在遵循指令方面的一致性和准确性,从而拓宽模型在各种应用场景中的泛化能力和实用性。

在实际应用中,我的理解是,指令微调更多把LLM看作一个更智能、更强大的传统NLP模型(比如Bert),来实现更高精度的文本预测任务。所以这类任务的应用场景覆盖了以往NLP模型的场景,甚至很多团队拿它来标注互联网数据

下面是实战正片:

1.环境安装

本案例基于Python>=3.8,请在您的计算机上安装好Python,并且有一张英伟达显卡(显存要求并不高,大概10GB左右就可以跑)。

我们需要安装以下这几个Python库,在这之前,请确保你的环境内已安装了pytorch以及CUDA:

swanlab
modelscope
transformers
datasets
peft
accelerate
pandas
tiktoken

一键安装命令:

pip install swanlab modelscope transformers datasets peft pandas accelerate tiktoken

本案例测试于modelscope1.14.0、transformers4.41.2、datasets2.18.0、peft0.11.1、accelerate0.30.1、swanlab0.3.10、tiktokn==0.7.0,更多环境细节可以查看这里

2.准备数据集

本案例使用的是zh_cls_fudan-news数据集,该数据集主要被用于训练文本分类模型。

zh_cls_fudan-news由几千条数据,每条数据包含text、category、output三列:

  • text 是训练语料,内容是书籍或新闻的文本内容
  • category 是text的多个备选类型组成的列表
  • output 则是text唯一真实的类型

在这里插入图片描述

数据集例子如下:

"""
[PROMPT]Text: 第四届全国大企业足球赛复赛结束新华社郑州5月3日电(实习生田兆运)上海大隆机器厂队昨天在洛阳进行的第四届牡丹杯全国大企业足球赛复赛中,以5:4力克成都冶金实验厂队,进入前四名。沪蓉之战,双方势均力敌,90分钟不分胜负。最后,双方互射点球,沪队才以一球优势取胜。复赛的其它3场比赛,青海山川机床铸造厂队3:0击败东道主洛阳矿山机器厂队,青岛铸造机械厂队3:1战胜石家庄第一印染厂队,武汉肉联厂队1:0险胜天津市第二冶金机械厂队。在今天进行的决定九至十二名的两场比赛中,包钢无缝钢管厂队和河南平顶山矿务局一矿队分别击败河南平顶山锦纶帘子布厂队和江苏盐城无线电总厂队。4日将进行两场半决赛,由青海山川机床铸造厂队和青岛铸造机械厂队分别与武汉肉联厂队和上海大隆机器厂队交锋。本届比赛将于6日结束。(完)
Category: Sports, Politics
Output:[OUTPUT]Sports
"""

我们的训练任务,便是希望微调后的大模型能够根据Text和Category组成的提示词,预测出正确的Output。


我们将数据集下载到本地目录下。下载方式是前往zh_cls_fudan-news - 魔搭社区 ,将train.jsonltest.jsonl下载到本地根目录下即可:

在这里插入图片描述

3. 加载模型

这里我们使用modelscope下载GLM4-9b-Chat模型(modelscope在国内,所以下载不用担心速度和稳定性问题),然后把它加载到Transformers中进行训练:

from modelscope import snapshot_download, AutoTokenizer
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq

# 在modelscope上下载GLM模型到本地目录下
model_dir = snapshot_download("ZhipuAI/glm-4-9b-chat", cache_dir="./", revision="master")

# Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./ZhipuAI/glm-4-9b-chat/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./ZhipuAI/glm-4-9b-chat/", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)

4. 配置训练可视化工具

我们使用SwanLab来监控整个训练过程,并评估最终的模型效果。

这里直接使用SwanLab和Transformers的集成来实现:

from swanlab.integration.huggingface import SwanLabCallback

swanlab_callback = SwanLabCallback(...)

trainer = Trainer(
    ...
    callbacks=[swanlab_callback],
)

如果你是第一次使用SwanLab,那么还需要去https://swanlab.cn上注册一个账号,在用户设置页面复制你的API Key,然后在训练开始时粘贴进去即可:

在这里插入图片描述

5. 完整代码

开始训练时的目录结构:

|--- train.py
|--- train.jsonl
|--- test.jsonl

train.py:

import json
import pandas as pd
import torch
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from swanlab.integration.huggingface import SwanLabCallback
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import os
import swanlab


def dataset_jsonl_transfer(origin_path, new_path):
    """
    将原始数据集转换为大模型微调所需数据格式的新数据集
    """
    messages = []

    # 读取旧的JSONL文件
    with open(origin_path, "r") as file:
        for line in file:
            # 解析每一行的json数据
            data = json.loads(line)
            context = data["text"]
            catagory = data["category"]
            label = data["output"]
            message = {
                "instruction": "你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型",
                "input": f"文本:{context},类型选型:{catagory}",
                "output": label,
            }
            messages.append(message)

    # 保存重构后的JSONL文件
    with open(new_path, "w", encoding="utf-8") as file:
        for message in messages:
            file.write(json.dumps(message, ensure_ascii=False) + "\n")
            
            
def process_func(example):
    """
    将数据集进行预处理
    """
    MAX_LENGTH = 384 
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(
        f"<|system|>\n你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型<|endoftext|>\n<|user|>\n{example['input']}<|endoftext|>\n<|assistant|>\n",
        add_special_tokens=False,
    )
    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = (
        instruction["attention_mask"] + response["attention_mask"] + [1]
    )
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}   


def predict(messages, model, tokenizer):
    device = "cuda"
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=512
    )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    print(response)
     
    return response
    
# 在modelscope上下载GLM模型到本地目录下
model_dir = snapshot_download("ZhipuAI/glm-4-9b-chat", cache_dir="./", revision="master")

# Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./ZhipuAI/glm-4-9b-chat/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./ZhipuAI/glm-4-9b-chat/", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
model.enable_input_require_grads()  # 开启梯度检查点时,要执行该方法

# 加载、处理数据集和测试集
train_dataset_path = "train.jsonl"
test_dataset_path = "test.jsonl"

train_jsonl_new_path = "new_train.jsonl"
test_jsonl_new_path = "new_test.jsonl"

if not os.path.exists(train_jsonl_new_path):
    dataset_jsonl_transfer(train_dataset_path, train_jsonl_new_path)
if not os.path.exists(test_jsonl_new_path):
    dataset_jsonl_transfer(test_dataset_path, test_jsonl_new_path)

# 得到训练集
train_df = pd.read_json(train_jsonl_new_path, lines=True)
train_ds = Dataset.from_pandas(train_df)
train_dataset = train_ds.map(process_func, remove_columns=train_ds.column_names)

config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["query_key_value", "dense", "dense_h_to_4h", "activation_func", "dense_4h_to_h"],
    inference_mode=False,  # 训练模式
    r=8,  # Lora 秩
    lora_alpha=32,  # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1,  # Dropout 比例
)

model = get_peft_model(model, config)

args = TrainingArguments(
    output_dir="./output/GLM4-9b",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=2,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True,
    report_to="none",
)

swanlab_callback = SwanLabCallback(
    project="GLM4-fintune",
    experiment_name="GLM4-9B-Chat",
    description="使用智谱GLM4-9B-Chat模型在zh_cls_fudan-news数据集上微调。",
    config={
        "model": "ZhipuAI/glm-4-9b-chat",
        "dataset": "huangjintao/zh_cls_fudan-news",
    },
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    callbacks=[swanlab_callback],
)

trainer.train()

# 用测试集的前10条,测试模型
test_df = pd.read_json(test_jsonl_new_path, lines=True)[:10]

test_text_list = []
for index, row in test_df.iterrows():
    instruction = row['instruction']
    input_value = row['input']
    
    messages = [
        {"role": "system", "content": f"{instruction}"},
        {"role": "user", "content": f"{input_value}"}
    ]

    response = predict(messages, model, tokenizer)
    messages.append({"role": "assistant", "content": f"{response}"})
    result_text = f"{messages[0]}\n\n{messages[1]}\n\n{messages[2]}"
    test_text_list.append(swanlab.Text(result_text, caption=response))
    
swanlab.log({"Prediction": test_text_list})
swanlab.finish()

看到下面的进度条即代表训练开始,这些loss、grad_norm等信息会到一定的step时打印出来:

在这里插入图片描述

6.训练结果演示

在SwanLab上查看最终的训练结果:

可以看到在2个epoch之后,微调后的glm4的loss降低到了不错的水平——当然对于大模型来说,真正的效果评估还得看主观效果。

在这里插入图片描述

可以看到在一些测试样例上,微调后的glm4能够给出准确的文本类型:

在这里插入图片描述

至此,你已经完成了glm4指令微调的训练!

7. 模型推理

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

def predict(messages, model, tokenizer):
    device = "cuda"

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return response


# 加载原下载路径的tokenizer和model
tokenizer = AutoTokenizer.from_pretrained("./ZhipuAI/glm-4-9b-chat/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./ZhipuAI/glm-4-9b-chat/", device_map="auto", torch_dtype=torch.bfloat16)

# 加载训练好的Lora模型,将下面的checkpointXXX替换为实际的checkpoint文件名名称
model = PeftModel.from_pretrained(model, model_id="./output/GLM4-9b/checkpoint-XXX")

test_texts = {
    'instruction': "你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型",
    'input': "文本:航空动力学报JOURNAL OF AEROSPACE POWER1998年 第4期 No.4 1998科技期刊管路系统敷设的并行工程模型研究*陈志英* * 马 枚北京航空航天大学【摘要】 提出了一种应用于并行工程模型转换研究的标号法,该法是将现行串行设计过程(As-is)转换为并行设计过程(To-be)。本文应用该法将发动机外部管路系统敷设过程模型进行了串并行转换,应用并行工程过程重构的手段,得到了管路敷设并行过程模型。"
}

instruction = test_texts['instruction']
input_value = test_texts['input']

messages = [
    {"role": "system", "content": f"{instruction}"},
    {"role": "user", "content": f"{input_value}"}
]

response = predict(messages, model, tokenizer)
print(response)

相关链接

  • 代码:完整代码直接看本文第5节
  • 实验日志过程:GLM4-Fintune - SwanLab
  • 模型:Modelscope
  • 数据集:zh_cls_fudan_news
  • SwanLab:https://swanlab.cn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1842081.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

嵌入式开发二十:定时器之基本定时器

定时器是微控制器中的关键外设&#xff0c;用于精确控制时间和事件。通过配置时钟源、预分频器、计数周期和比较值&#xff0c;可以实现各种时间控制任务&#xff0c;如定时中断、PWM生成和时间测量。理解定时器的工作原理和配置方法是嵌入式系统开发中的基本技能。 STM32F407 …

转让神州开头的无区域科技公司需要多少钱

您好&#xff0c;我公司现有2家无区域神州名称的公司转让。所谓无区域名称是公司名称中不带有行政区划、及行业特点的公司名称&#xff0c;都是需要在工商总,局核准名称的&#xff0c;对于民营企业来说也比较喜欢这种名称名称很大气&#xff0c;现在重核更严格了&#xff0c;所…

期货止损口诀需牢记

实战操作难免错&#xff0c;心中不必一团火&#xff1b; 出错认输是常事&#xff0c;亏损不止闯大祸。 止损纪律要定死&#xff0c;价格不能差豪丝&#xff1b; 触及止损要出局&#xff0c;管它价格怎放肆。 强势多空价放宽&#xff0c;价格波动要空间&#xff1b; 大势不改…

OpenAI 前首席科学家 Ilya Sutskever 成立新人工智能公司

OpenAI 联合创始人之一 Ilya Sutskever 在正式离开 OpenAI 一个月后&#xff0c;成立了一家新公司 Safe Superintelligence Inc. (SSI)。Sutskever 是 OpenAI 的长期首席科学家&#xff0c;他与前 Y Combinator 合伙人 Daniel Gross 以及前 OpenAI 工程师 Daniel Levy 共同创立…

[SAP ABAP] MESSAGE消息处理

常用的MESSAGE命令的字符 信息类型描述EError 出现错误消息&#xff0c;应用程序在当前点暂停 WWarning 出现警告消息&#xff0c;用户必须按Enter键才能继续应用程序 IInformation 将打开一个弹出窗口&#xff0c;其中包含消息文本&#xff0c;用户必须按Enter键才能继续 SSu…

watcher学习小结

架构 主要是watcher-api&#xff0c;watcher-applier&#xff0c;watcher-decision-engine watcher-applier watcher-decision-engine 将DecisionEngineManager和DecisionEngineSchedulingService封装到oslo_service&#xff0c;然后调service的launch_service&#xff0c;实…

渗透测试基础(四) MS08-067 漏洞攻击

1. 漏洞介绍 漏洞描述 Microsoft Windows Server服务RPC请求缓冲区溢出漏洞Windows的Server服务在处理特质RPC请求时存在缓冲区溢出漏洞&#xff0c;远程攻击者可以通过发送恶意的RPC请求触发这个溢出&#xff0c;导致完全入侵用户系统&#xff0c;以SYSTEM权限执行任意指令。…

为什么3D渲染让客户无法抗拒?7个重要原因

客户通常对工程、建筑、复杂的室内外设计知之甚少&#xff0c;展示草图只会让他们感到难以理解。不过&#xff0c;现代设计师和建筑师不再需要为此烦恼。 通过使用逼真且沉浸式的3D渲染&#xff0c;他们可以让能够轻松地向客户传达信息和沟通想法。它对赢得客户至关重要。接下…

PyCharm新手入门

前言 在之前《Python集成开发工具的选择》一文中介绍了python初学者可以使用Jupyter Notebook&#xff0c;Jupyter Notebook简单易用&#xff0c;可以用来练习代码编写&#xff0c;但是实际生产开发环境使用这个工具是远远不够用的&#xff0c;因为实际软件开发中需要软件调试…

API接口对接的步骤流程?有哪些注意事项?

API接口对接自动化的实现方法&#xff1f;如何调试API接口发信&#xff1f; 在现代软件开发中&#xff0c;API接口对接已成为各个系统和应用之间进行通信和数据交换的关键技术。AokSend将详细介绍API接口对接的步骤流程&#xff0c;帮助开发者更好地理解和实现这一过程。 API…

在webstorm配置nodejs(从零开始)

在webstorm配置nodejs之前&#xff0c;需要先下载node.js和webStorm。 按下winr&#xff0c;输入cmd打开命令行 输入node -v和npm -v会出现相应的版本&#xff0c;如果报错则需要去下载node.js。 打开webStorm&#xff0c;File—settings 搜索node 选择node.exe安装位置 重启…

大模型日报|8 篇必读的大模型论文

大家好&#xff0c;今日必读的大模型论文来啦&#xff01; 1.Pandora&#xff1a;自回归-扩散混合通用世界模型 世界模型模拟世界在不同行动下的未来状态&#xff0c;它们有助于创建交互式内容&#xff0c;并为有依据的长远推理提供基础。然而&#xff0c;目前的基础模型并不…

【stm32单片机应用】基于I2C协议的OLED显示(利用U82G库)

一、U8g2库 &#xff08;一&#xff09;U8g2简介 U8g2 是一个用于单色和彩色显示的嵌入式图形库&#xff0c;特别适用于单色OLED、LCD显示屏的驱动。它是对早期U8g库的扩展和改进&#xff0c;提供了更多功能和更广泛的硬件支持。U8g2作为一款强大而灵活的嵌入式图形库&#x…

为什么你不能下载哨兵遥感影像?Sentinel-1 和 Sentinel-2(解决)

​ 点击下方全系列课程学习 点击学习—>ArcGIS全系列实战视频教程——9个单一课程组合系列直播回放 点击学习——>遥感影像综合处理4大遥感软件ArcGISENVIErdaseCognition 今天的文章来介绍一下如何下载欧空局哨兵数据&#xff0c;哨兵数据是目前我们可以免费下载的全球…

海康威视-下载的录像视频浏览器播放问题

目录 1、播放异常比对 2、视频编码检查 2.1、正常视频解析 2.2、海康视频解析 2.3、比对工具 3、转码 3.1、maven依赖 3.2、实现代码 4、验证 在前面的文章&#xff08;海康威视-按时间下载录像文件_海康威视 sdk 下载录像 大小0-CSDN博客&#xff09;中&#xff0c;通…

吊打Unity的角色动画重定向专业版工具FPS手臂武器动画动物动画角色动作微调烘焙20240620

今天发现一款关注已久的Unity插件上架商店了&#xff0c;可以将动画从一个通用/人形角色重新定位到另一个通用角色。 吊打Unity的角色动画重定向专业版工具FPS手臂武器动画动物动画角色动作微调烘焙202406201103 Unity 中任何通用角色的终极解决方案。它没有 Humanoid 系统的限…

分析师:是什么导致山寨币在本轮周期表现不佳?

在加密货币领域&#xff0c;山寨币的过度分散化问题逐渐凸显&#xff0c;成为本轮周期内其表现疲软的核心因素。经过深入研究&#xff0c;我发现这种分散化对加密货币市场的整体健康造成了严重威胁。然而&#xff0c;令人遗憾的是&#xff0c;目前看来&#xff0c;我们尚未找到…

ECharts 雷达图案例001-自定义节点动画

ECharts 雷达图案例001-自定义节点动画 引言 在数据可视化的领域中&#xff0c;ECharts 提供了一种强大的工具来展示多维数据。本文将介绍如何使用 ECharts 创建一个自定义节点样式的雷达图&#xff0c;让数据展示更加生动和个性化。 效果预览 通过自定义节点样式&#xff…

数据结构_二叉树

目录 一、树型结构 二、二叉树 2.1 概念 2.2 特殊的二叉树 2.3 二叉树的性质 2.4 二叉树的存储 2.5 遍历二叉树 2.6 操作二叉树 总结 一、树型结构 树是一种非线性的数据结构&#xff0c;它是由 n(n>0) 个有限结点组成一个具有层次关系的集合&#xff0c;一棵 n 个…

CatBoost算法详解

CatBoost算法详解 CatBoost&#xff08;Categorical Boosting&#xff09;是由Yandex开发的一种基于梯度提升决策树&#xff08;GBDT&#xff09;的机器学习算法&#xff0c;特别擅长处理包含类别特征的数据集。它不仅在精度和速度上表现出色&#xff0c;还对类别特征有天然的…