Qwen2大模型微调入门实战(完整代码)

news2024/11/25 16:22:30

Qwen2是通义千问团队的开源大语言模型,由阿里云通义实验室研发。以Qwen2作为基座大模型,通过指令微调的方式实现高准确率的文本分类,是学习大语言模型微调的入门任务。

在这里插入图片描述

指令微调是一种通过在由(指令,输出)对组成的数据集上进一步训练LLMs的过程。 其中,指令代表模型的人类指令,输出代表遵循指令的期望输出。 这个过程有助于弥合LLMs的下一个词预测目标与用户让LLMs遵循人类指令的目标之间的差距。

在这个任务中我们会使用Qwen2-1.5b-Instruct模型在zh_cls_fudan_news数据集上进行指令微调任务,同时使用SwanLab进行监控和可视化。

  • 代码:完整代码直接看本文第5节
  • 实验日志过程:Qwen2-1.5B-Fintune - SwanLab
  • 模型:Modelscope
  • 数据集:zh_cls_fudan_news
  • SwanLab:https://swanlab.cn

本教程参考了这篇文章。

1.环境安装

本案例基于Python>=3.8,请在您的计算机上安装好Python,并且有一张英伟达显卡(显存要求并不高,大概10GB左右就可以跑)。

我们需要安装以下这几个Python库,在这之前,请确保你的环境内已安装了pytorch以及CUDA:

swanlab
modelscope
transformers
datasets
peft
accelerat
pandas

一键安装命令:

pip install swanlab modelscope transformers datasets peft pandas

本案例测试于modelscope1.14.0、transformers4.41.2、datasets2.18.0、peft0.11.1、accelerate0.30.1、swanlab0.3.9

2.准备数据集

本案例使用的是zh_cls_fudan-news数据集,该数据集主要被用于训练文本分类模型。

zh_cls_fudan-news由几千条数据,每条数据包含text、category、output三列:

  • text 是训练语料,内容是书籍或新闻的文本内容
  • category 是text的多个备选类型组成的列表
  • output 则是text唯一真实的类型

在这里插入图片描述

数据集例子如下:

"""
[PROMPT]Text: 第四届全国大企业足球赛复赛结束新华社郑州5月3日电(实习生田兆运)上海大隆机器厂队昨天在洛阳进行的第四届牡丹杯全国大企业足球赛复赛中,以5:4力克成都冶金实验厂队,进入前四名。沪蓉之战,双方势均力敌,90分钟不分胜负。最后,双方互射点球,沪队才以一球优势取胜。复赛的其它3场比赛,青海山川机床铸造厂队3:0击败东道主洛阳矿山机器厂队,青岛铸造机械厂队3:1战胜石家庄第一印染厂队,武汉肉联厂队1:0险胜天津市第二冶金机械厂队。在今天进行的决定九至十二名的两场比赛中,包钢无缝钢管厂队和河南平顶山矿务局一矿队分别击败河南平顶山锦纶帘子布厂队和江苏盐城无线电总厂队。4日将进行两场半决赛,由青海山川机床铸造厂队和青岛铸造机械厂队分别与武汉肉联厂队和上海大隆机器厂队交锋。本届比赛将于6日结束。(完)
Category: Sports, Politics
Output:[OUTPUT]Sports
"""

我们的训练任务,便是希望微调后的大模型能够根据Text和Category组成的提示词,预测出正确的Output。


我们将数据集下载到本地目录下。下载方式是前往zh_cls_fudan-news - 魔搭社区 ,将train.jsonltest.jsonl下载到本地根目录下即可:

在这里插入图片描述

3. 加载模型

这里我们使用modelscope下载Qwen2-1.5B-Instruct模型(modelscope在国内,所以下载不用担心速度和稳定性问题),然后把它加载到Transformers中进行训练:

from modelscope import snapshot_download, AutoTokenizer
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq

# 在modelscope上下载Qwen模型到本地目录下
model_dir = snapshot_download("qwen/Qwen2-1.5B-Instruct", cache_dir="./", revision="master")

# Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./qwen/Qwen2-1___5B-Instruct/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./qwen/Qwen2-1___5B-Instruct/", device_map="auto", torch_dtype=torch.bfloat16)

4. 配置训练可视化工具

我们使用SwanLab来监控整个训练过程,并评估最终的模型效果。

这里直接使用SwanLab和Transformers的集成来实现:

from swanlab.integration.huggingface import SwanLabCallback


swanlab_callback = SwanLabCallback(...)

trainer = Trainer(
    ...
    callbacks=[swanlab_callback],
)

如果你是第一次使用SwanLab,那么还需要去https://swanlab.cn上注册一个账号,在用户设置页面复制你的API Key,然后在训练开始时粘贴进去即可:

在这里插入图片描述

5. 完整代码

开始训练时的目录结构:

|--- train.py
|--- train.jsonl
|--- test.jsonl

train.py:

import json
import pandas as pd
import torch
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from swanlab.integration.huggingface import SwanLabCallback
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import os
import swanlab


def dataset_jsonl_transfer(origin_path, new_path):
    """
    将原始数据集转换为大模型微调所需数据格式的新数据集
    """
    messages = []

    # 读取旧的JSONL文件
    with open(origin_path, "r") as file:
        for line in file:
            # 解析每一行的json数据
            data = json.loads(line)
            context = data["text"]
            catagory = data["category"]
            label = data["output"]
            message = {
                "instruction": "你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型",
                "input": f"文本:{context},类型选型:{catagory}",
                "output": label,
            }
            messages.append(message)

    # 保存重构后的JSONL文件
    with open(new_path, "w", encoding="utf-8") as file:
        for message in messages:
            file.write(json.dumps(message, ensure_ascii=False) + "\n")
            
            
def process_func(example):
    """
    将数据集进行预处理
    """
    MAX_LENGTH = 384 
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(
        f"<|im_start|>system\n你是一个文本分类领域的专家,你会接收到一段文本和几个潜在的分类选项,请输出文本内容的正确类型<|im_end|>\n<|im_start|>user\n{example['input']}<|im_end|>\n<|im_start|>assistant\n",
        add_special_tokens=False,
    )
    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = (
        instruction["attention_mask"] + response["attention_mask"] + [1]
    )
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}   


def predict(messages, model, tokenizer):
    device = "cuda"
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=512
    )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    print(response)
     
    return response
    
# 在modelscope上下载Qwen模型到本地目录下
model_dir = snapshot_download("qwen/Qwen2-1.5B-Instruct", cache_dir="./", revision="master")

# Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./qwen/Qwen2-1___5B-Instruct/", use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./qwen/Qwen2-1___5B-Instruct/", device_map="auto", torch_dtype=torch.bfloat16)
model.enable_input_require_grads()  # 开启梯度检查点时,要执行该方法

# 加载、处理数据集和测试集
train_dataset_path = "train.jsonl"
test_dataset_path = "test.jsonl"

train_jsonl_new_path = "new_train.jsonl"
test_jsonl_new_path = "new_test.jsonl"

if not os.path.exists(train_jsonl_new_path):
    dataset_jsonl_transfer(train_dataset_path, train_jsonl_new_path)
if not os.path.exists(test_jsonl_new_path):
    dataset_jsonl_transfer(test_dataset_path, test_jsonl_new_path)

# 得到训练集
train_df = pd.read_json(train_jsonl_new_path, lines=True)
train_ds = Dataset.from_pandas(train_df)
train_dataset = train_ds.map(process_func, remove_columns=train_ds.column_names)

config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False,  # 训练模式
    r=8,  # Lora 秩
    lora_alpha=32,  # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1,  # Dropout 比例
)

model = get_peft_model(model, config)

args = TrainingArguments(
    output_dir="./output/Qwen1.5",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=2,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True,
    report_to="none",
)

swanlab_callback = SwanLabCallback(
    project="Qwen2-fintune",
    experiment_name="Qwen2-1.5B-Instruct",
    description="使用通义千问Qwen2-1.5B-Instruct模型在zh_cls_fudan-news数据集上微调。",
    config={
        "model": "qwen/Qwen2-1.5B-Instruct",
        "dataset": "huangjintao/zh_cls_fudan-news",
    }
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    callbacks=[swanlab_callback],
)

trainer.train()

# 用测试集的前10条,测试模型
test_df = pd.read_json(test_jsonl_new_path, lines=True)[:10]

test_text_list = []
for index, row in test_df.iterrows():
    instruction = row['instruction']
    input_value = row['input']
    
    messages = [
        {"role": "system", "content": f"{instruction}"},
        {"role": "user", "content": f"{input_value}"}
    ]

    response = predict(messages, model, tokenizer)
    messages.append({"role": "assistant", "content": f"{response}"})
    result_text = f"{messages[0]}\n\n{messages[1]}\n\n{messages[2]}"
    test_text_list.append(swanlab.Text(result_text, caption=response))
    
swanlab.log({"Prediction": test_text_list})
swanlab.finish()

看到下面的进度条即代表训练开始:

在这里插入图片描述

6.训练结果演示

在SwanLab上查看最终的训练结果:

可以看到在2个epoch之后,微调后的qwen2的loss降低到了不错的水平——当然对于大模型来说,真正的效果评估还得看主观效果。

在这里插入图片描述

可以看到在一些测试样例上,微调后的qwen2能够给出准确的文本类型:

在这里插入图片描述

至此,你已经完成了qwen2指令微调的训练!

相关链接

  • 代码:完整代码直接看本文第5节
  • 实验日志过程:Qwen2-1.5B-Fintune - SwanLab
  • 模型:Modelscope
  • 数据集:zh_cls_fudan_news
  • SwanLab:https://swanlab.cn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1805932.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解读下/etc/network/interfaces配置文件

/etc/network/interfaces 是一个常见的网络配置文件&#xff0c;通常在 Debian 及其衍生版本的 Linux 发行版中使用。该文件用于配置网络接口和网络连接参数&#xff0c;允许用户手动设置网络连接的属性&#xff0c;包括 IP 地址、子网掩码、网关、DNS 服务器等。 以下是一个可…

LeetCode热题100—链表(二)

19.删除链表的倒数第N个节点 题目 给你一个链表&#xff0c;删除链表的倒数第 n 个结点&#xff0c;并且返回链表的头结点。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], n 2 输出&#xff1a;[1,2,3,5] 示例 2&#xff1a; 输入&#xff1a;head [1], n 1 …

人工智能程序员应该有什么职业素养?

人工智能程序员应该有什么职业素养&#xff1f; 面向企业需求去学习AI必备技能实战能力实战能力提升策略 面向企业需求去学习 如果想要应聘AI相关的岗位&#xff0c;就需要知道HR和管理层在招聘时需要考察些什么&#xff0c;面向招聘的需求去学习就能具备AI程序员该有的职业素…

Golang发送邮件如何验证身份?有哪些限制?

Golang发送邮件需要哪些库&#xff1f;怎么设置邮件发送的参数&#xff1f; 对于开发者而言&#xff0c;使用Golang发送邮件是一种常见需求。然而&#xff0c;在发送邮件的过程中&#xff0c;验证身份是一个至关重要的环节&#xff0c;它确保了邮件的可靠性和安全性。A将探讨G…

linux业务代码性能优化点

planning优化的一些改动----------> 减少值传递&#xff0c;多用引用来传递 <---------- // ----------> 减少值传递&#xff0c;多用引用来传递 <---------- // 例1&#xff1a; class A{}; std::vector<A> v; // for(auto elem : v) {} // 不建议&#xff…

FreeRTOS实时系统 在任务中增加数组等相关操作 导致单片机起不来或者挂掉

在调试串口任务中增加如下代码&#xff0c;发现可以用keil进行仿真&#xff0c;但是烧录程序后&#xff0c;调试串口没有打印&#xff0c;状态灯也不闪烁&#xff0c;单片机完全起不来 博主就纳了闷了&#xff0c;究竟是什么原因&#xff0c;这段代码可是公司永流传的老代码了&…

【小白专用24.6.8】c#异步方法 async task调用及 await运行机制

await是C#中用于等待异步操作完成的关键字。它通常用于异步方法内部&#xff0c;使得在等待异步操作期间&#xff0c;线程可以继续执行其他操作&#xff0c;从而保持程序的响应性。 在使用await时&#xff0c;需要注意以下几点&#xff1a; 1. async修饰符&#xff1a; 使用…

Mac环境下,简单反编译APK

一、下载jadx包 https://github.com/skylot/jadx/releases/tag/v1.4.7 下载里面的这个&#xff1a;下载后&#xff0c;找个干净的目录解压&#xff0c;我是放在Downloads下面 二、安装及启动 下载和解压 jadx&#xff1a; 下载 jadx-1.4.7.zip 压缩包。将其解压到你希望的目…

C++ 史上首次超越 C,跃至榜二

TIOBE 公布了 2024 年 6 月的编程语言排行榜。 C在本月的TIOBE指数中成功超越了C&#xff0c;成为新的第二名。它是一种被广泛应用于嵌入式系统、游戏开发和金融交易软件等领域的编程语言。这次的排名是C在TIOBE指数中的历史最高位&#xff0c;同时也是C语言的历史最低位。 T…

排序-读取数据流并实时返回中位数

目录 一、问题描述 二、解题思路 1.顺序表排序法 2.使用大根堆、小根堆 三、代码实现 1.顺序表排序法实现 2.大根堆、小根堆法实现 四、刷题链接 一、问题描述 二、解题思路 1.顺序表排序法 &#xff08;1&#xff09;每次读取一个数就对列表排一次序&#xff0c;对排…

使用Python批量处理Excel的内容

正文共&#xff1a;1500 字 10 图&#xff0c;预估阅读时间&#xff1a;1 分钟 在前面的文章中&#xff08;如何使用Python提取Excel中固定单元格的内容&#xff09;&#xff0c;我们介绍了如何安装Python环境和PyCharm工具&#xff0c;还利用搭好的环境简单测试了一下ChatGPT提…

CPP入门:CPP的内存管理模式

一.new和delete操作自定义类型 1.1C语言的内存管理 在传统的C语言中&#xff0c;malloc和free是无法调用类的构造和析构函数的 #include <iostream> using namespace std; class kuzi { public:kuzi(int a, int b):_a(a), _b(b){cout << "我构造啦" &…

卷积的计算过程

卷积的计算过程 flyfish 包括手动计算&#xff0c;可视化使用torch.nn.Conv2d实现 示例 import torch import torch.nn as nn# 定义输入图像 input_image torch.tensor([[1, 2, 3, 0, 1],[0, 1, 2, 3, 4],[2, 3, 0, 1, 2],[1, 2, 3, 4, 0],[0, 1, 2, 3, 4] ], dtypetorch.f…

20240609如何查询淘宝的历史价格

20240609如何查询淘宝的历史价格 2024/6/9 18:39 百度&#xff1a;淘宝历史价格 淘宝历史价格查询网站 https://zhuanlan.zhihu.com/p/670972171 30秒学会淘宝商品历史价格查询&#xff01; https://item.taobao.com/item.htm?id693104421622&pidmm_29415502_2422500430_1…

AI菜鸟向前飞 — LangChain系列之十七 - 剖析AgentExecutor

AgentExecutor 顾名思义&#xff0c;Agent执行器&#xff0c;本篇先简单看看LangChain是如何实现的。 先回顾 AI菜鸟向前飞 — LangChain系列之十四 - Agent系列&#xff1a;从现象看机制&#xff08;上篇&#xff09; AI菜鸟向前飞 — LangChain系列之十五 - Agent系列&#…

sqlilabs靶场安装

05-sqllabs靶场安装 1 安装 1 把靶场sqli-labs-master.zip上传到 /opt/lampp/htdocs 目录下 2 解压缩 unzip sqli-labs-master.zip3 数据库配置 找到配置文件,修改数据库配置信息 用户名密码&#xff0c;修改为你lampp下mysql的用户名密码&#xff0c;root/123456host:la…

GraphQL(6):认证与中间件

下面用简单来讲述GraphQL的认证示例 1 实现代码 在代码中添加过滤器&#xff1a; 完整代码如下&#xff1a; const express require(express); const {buildSchema} require(graphql); const grapqlHTTP require(express-graphql).graphqlHTTP; // 定义schema&#xff0c;…

计算机网络:数据链路层 - 扩展的以太网

计算机网络&#xff1a;数据链路层 - 扩展的以太网 集线器交换机自学习算法单点故障 集线器 这是以前常见的总线型以太网&#xff0c;他最初使用粗铜轴电缆作为传输媒体&#xff0c;后来演进到使用价格相对便宜的细铜轴电缆。 后来&#xff0c;以太网发展出来了一种使用大规模…

启动xv6遇坑记录

我是在VMware上的Ubuntu22.04.4搭建的&#xff0c;启动xv6遇到超多bug&#xff0c;搞了好几天&#xff0c;所以记录一下。 目录 git push的时候报错 make qemu缺少包 运行make qemu时卡住 可能有影响的主机设置 git push的时候报错 remote: Support for password authent…

kafka-生产者事务-数据传递语义事务介绍事务消息发送(SpringBoot整合Kafka)

文章目录 1、kafka数据传递语义2、kafka生产者事务3、事务消息发送3.1、application.yml配置3.2、创建生产者监听器3.3、创建生产者拦截器3.4、发送消息测试3.5、使用Java代码创建主题分区副本3.6、屏蔽 kafka debug 日志 logback.xml3.7、引入spring-kafka依赖3.8、控制台日志…