[大模型]GLM4-9B-chat Lora 微调

news2025/1/16 21:10:48

本节我们简要介绍如何基于 transformers、peft 等框架,对 LLaMA3-8B-Instruct 模型进行 Lora 微调。Lora 是一种高效微调方法,深入了解其原理可参见博客:知乎|深入浅出 Lora。

这个教程会在同目录下给大家提供一个 nodebook 文件,来让大家更好的学习。

环境准备

在 Autodl 平台中租赁一个 3090 等 24G 显存的显卡机器,如下图所示镜像选择 PyTorch-->2.1.0-->3.10(ubuntu22.04)-->12.1
接下来打开刚刚租用服务器的 JupyterLab,并且打开其中的终端开始环境配置、模型下载和运行演示。

在这里插入图片描述

环境配置

在完成基本环境配置和本地模型部署的情况下,你还需要安装一些第三方库,可以使用以下命令:

python -m pip install --upgrade pip
# 更换 pypi 源加速库的安装
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

pip install modelscope==1.9.5
pip install "transformers>=4.40.0"
pip install streamlit==1.24.0
pip install sentencepiece==0.1.99
pip install accelerate==0.29.3
pip install datasets==2.19.0
pip install peft==0.10.0
pip install tiktoken==0.7.0

MAX_JOBS=8 pip install flash-attn --no-build-isolation

注意:flash-attn 安装会比较慢,大概需要十几分钟。

考虑到部分同学配置环境可能会遇到一些问题,我们在 AutoDL 平台准备了 GLM-4 的环境镜像,该镜像适用于本教程需要 GLM-4 的部署环境。点击下方链接并直接创建 AutoDL 示例即可。(vLLM 对 torch 版本要求较高,且越高的版本对模型的支持更全,效果更好,所以新建一个全新的镜像。) https://www.codewithgpu.com/i/datawhalechina/self-llm/GLM-4

在本节教程里,我们将微调数据集放置在根目录 /dataset。

模型下载

使用 modelscope 中的 snapshot_download 函数下载模型,第一个参数为模型名称,参数 cache_dir 为模型的下载路径。

在 /root/autodl-tmp 路径下新建 model_download.py 文件并在其中输入以下内容,粘贴代码后记得保存文件,如下图所示。并运行 python /root/autodl-tmp/model_download.py 执行下载。

import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import os

model_dir = snapshot_download('ZhipuAI/glm-4-9b-chat', cache_dir='/root/autodl-tmp/glm-4-9b-chat', revision='master')

指令集构建

LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:

{
  "instruction": "回答以下用户问题,仅输出答案。",
  "input": "1+1等于几?",
  "output": "2"
}

其中,instruction 是用户指令,告知模型其需要完成的任务;input 是用户输入,是完成用户指令所必须的输入内容;output 是模型应该给出的输出。

即我们的核心训练目标是让模型具有理解并遵循用户指令的能力。因此,在指令集构建时,我们应针对我们的目标任务,针对性构建任务指令集。例如,在本节我们使用由笔者合作开源的 Chat-甄嬛 项目作为示例,我们的目标是构建一个能够模拟甄嬛对话风格的个性化 LLM,因此我们构造的指令形如:

{
  "instruction": "你是谁?",
  "input": "",
  "output": "家父是大理寺少卿甄远道。"
}

我们所构造的全部指令数据集在根目录下。

数据格式化

Lora 训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 Pytorch 模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 labels,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:

def process_func(example):
    MAX_LENGTH = 384
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer((f"[gMASK]<sop><|system|>\n假设你是皇帝身边的女人--甄嬛。<|user|>\n"
                            f"{example['instruction']+example['input']}<|assistant|>\n"
                            ), 
                            add_special_tokens=False)
    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]  
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

GLM4-9B-chat 采用的Prompt Template格式如下:

[gMASK]<sop><|system|> 
假设你是皇帝身边的女人--甄嬛。<|user|> 
小姐,别的秀女都在求中选,唯有咱们小姐想被撂牌子,菩萨一定记得真真儿的——<|assistant|> 
嘘——都说许愿说破是不灵的。<|endoftext|>

加载 tokenizer 和半精度模型

模型以半精度形式加载,如果你的显卡比较新的话,可以用torch.bfolat形式加载。对于自定义的模型一定要指定trust_remote_code参数为True

tokenizer = AutoTokenizer.from_pretrained('/root/autodl-tmp/glm-4-9b-chat/ZhipuAI/glm-4-9b-chat', use_fast=False, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained('/root/autodl-tmp/glm-4-9b-chat/ZhipuAI/glm-4-9b-chat', device_map="auto",torch_dtype=torch.bfloat16, trust_remote_code=True)

定义 LoraConfig

LoraConfig这个类中可以设置很多参数,但主要的参数没多少,简单讲一讲,感兴趣的同学可以直接看源码。

  • task_type:模型类型
  • target_modules:需要训练的模型层的名字,主要就是attention部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。
  • rlora的秩,具体可以看Lora原理
  • lora_alphaLora alaph,具体作用参见 Lora 原理

Lora的缩放是啥嘞?当然不是r(秩),这个缩放就是lora_alpha/r, 在这个LoraConfig中缩放就是 4 倍。

config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],  # 现存问题只微调部分演示即可
    inference_mode=False, # 训练模式
    r=8, # Lora 秩
    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1# Dropout 比例
)

自定义 TrainingArguments 参数

TrainingArguments这个类的源码也介绍了每个参数的具体作用,当然大家可以来自行探索,这里就简单说几个常用的。

  • output_dir:模型的输出路径
  • per_device_train_batch_size:顾名思义 batch_size
  • gradient_accumulation_steps: 梯度累加,如果你的显存比较小,那可以把 batch_size 设置小一点,梯度累加增大一些。
  • logging_steps:多少步,输出一次log
  • num_train_epochs:顾名思义 epoch
  • gradient_checkpointing:梯度检查,这个一旦开启,模型就必须执行model.enable_input_require_grads(),这个原理大家可以自行探索,这里就不细说了。
args = TrainingArguments(
    output_dir="./output/GLM4",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=50,
    num_train_epochs=2,
    save_steps=100,
    learning_rate=1e-5,
    save_on_each_node=True,
    gradient_checkpointing=True
)

使用 Trainer 训练

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

保存 lora 权重

lora_path='./GLM4'
trainer.model.save_pretrained(lora_path)
tokenizer.save_pretrained(lora_path)

加载 lora 权重推理

训练好了之后可以使用如下方式加载lora权重进行推理:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel

mode_path = '/root/autodl-tmp/glm-4-9b-chat/ZhipuAI/glm-4-9b-chat'
lora_path = './GLM4_lora'

# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)

# 加载模型
model = AutoModelForCausalLM.from_pretrained(mode_path, device_map="auto",torch_dtype=torch.bfloat16, trust_remote_code=True).eval()

# 加载lora权重
model = PeftModel.from_pretrained(model, model_id=lora_path)

prompt = "你是谁?"
inputs = tokenizer.apply_chat_template([{"role": "user", "content": "假设你是皇帝身边的女人--甄嬛。"},{"role": "user", "content": prompt}],
                                       add_generation_prompt=True,
                                       tokenize=True,
                                       return_tensors="pt",
                                       return_dict=True
                                       ).to('cuda')


gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
with torch.no_grad():
    outputs = model.generate(**inputs, **gen_kwargs)
    outputs = outputs[:, inputs['input_ids'].shape[1]:]
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

完整脚本参考:
https://github.com/datawhalechina/self-llm/blob/master/GLM-4/05-GLM-4-9B-chat%20Lora%20%E5%BE%AE%E8%B0%83.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1810296.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python科研做图系列之时序图的绘制——对比折线图

参考知乎 折线图 我需要从两个不同的excel都读取第一列作为时间列,第二列作为编码列。 在同一张图上画出两条时间序列的折线图 横坐标是分钟,纵坐标是编码 帮我画的好看一些,记得解决中文乱码问题 英文版折线图 ,先搞个英文版,导师要求中文的话,再换成中文版 impor…

新技术前沿-2023-大模型学习根据个人数据集微调一个Transformer模型

参考如何根据自己的数据集微调一个 Transformer 模型 我们将通过NLP中最常见的文本分类任务来学习如何在自己的数据集上利用迁移学习(transfer learning)微调一个预训练的Transformer模型——DistilBERT。DistilBERT是BERT的一个衍生版本&#xff0c;它的优点在它的性能与BERT相…

区间预测 | Matlab实现GRU-ABKDE门控循环单元自适应带宽核密度估计多变量回归区间预测

区间预测 | Matlab实现GRU-ABKDE门控循环单元自适应带宽核密度估计多变量回归区间预测 目录 区间预测 | Matlab实现GRU-ABKDE门控循环单元自适应带宽核密度估计多变量回归区间预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现GRU-ABKDE门控循环单元自适应…

沐风老师3DMAX一键多孔结构建模插件Porous使用方法

​3DMAX一键多孔结构建模插件Porous使用教程 3dMax是大家熟知的3D建模软件之一&#xff0c;其功能非常的强大&#xff0c;在科研绘图领域有着非常广泛的应用&#xff0c;但是由于科研绘图的图形&#xff08;模型&#xff09;一般都属于异形结构&#xff0c;手工绘制建模&#x…

Docker | 入门:原理探究

Docker | 入门&#xff1a;原理探究 Run 的运行流程 Docker 底层原理 Docker 是怎么工作的&#xff1f; Docker 是一个 Client-Server 结构的系统&#xff0c;Docker 的守护进程运行在主机上&#xff0c;通过 Socket 从客户端访问。DockerServer 接受到 Docker-Client 的指令…

linux centos如何安装python3版本但不能影响默认python2版本

在CentOS上安装Python3而不影响系统默认的Python2,具有如何安装呢? 1. 更新系统包 首先,确保系统包是最新的: sudo yum update -y2. 安装EPEL存储库 EPEL(Extra Packages for Enterprise Linux)存储库包含许多额外的软件包,包括Python3: sudo yum install epel-rel…

Android Studio Jellyfish版本修改project使用特定jdk版本的步骤

android studio总是把这些东西改来改去让人十分恼火&#xff0c;IDE本身改来改去就让人无法上手就立即工作&#xff0c;很多时间浪费在IDE和gradle的配置和奇奇怪怪现象的斗智斗勇上&#xff0c;搞Android是真的有点浪费生命。一入此坑深不见底 jellyfish版安卓studio已经无法通…

[Vue3:axios]:实现登录跳转页面展示列表(查看教师所承担课程的学生选课情况)

文章目录 一&#xff1a;前置操作项目结构&#xff1a; 二&#xff1a;登录页面主要流程说明运行截图前端代码Login.vue 三&#xff1a;列表页面交互逻辑&#xff1a;涉及页面Page02.vue &#xff08;登录成功跳转学生选课页面&#xff09;运行截图 一&#xff1a;前置操作 ht…

【Linux】进程间通信之命名管道

&#x1f466;个人主页&#xff1a;Weraphael ✍&#x1f3fb;作者简介&#xff1a;目前正在学习c和算法 ✈️专栏&#xff1a;Linux &#x1f40b; 希望大家多多支持&#xff0c;咱一起进步&#xff01;&#x1f601; 如果文章有啥瑕疵&#xff0c;希望大佬指点一二 如果文章对…

【PB案例学习笔记】-18制作一个IP地址编辑框

写在前面 这是PB案例学习笔记系列文章的第18篇&#xff0c;该系列文章适合具有一定PB基础的读者。 通过一个个由浅入深的编程实战案例学习&#xff0c;提高编程技巧&#xff0c;以保证小伙伴们能应付公司的各种开发需求。 文章中设计到的源码&#xff0c;小凡都上传到了gite…

46-4 等级保护 - 网络安全等级保护概述

一、网络安全等级保护概述 原文:没有网络安全就没有国家安全 二、网络安全法 - 安全立法 中华人民共和国主席令 第五十三号 《中华人民共和国网络安全法》已于2016年11月7日由中华人民共和国第十二届全国人民代表大会常务委员会第二十四次会议通过,并自2017年6月1日起正式…

LabVIEW图像采集处理项目中相机选择与应用

在LabVIEW图像采集处理项目中&#xff0c;选择合适的相机是确保项目成功的关键。本文将详细探讨相机选择时需要关注的参数、黑白相机与彩色相机的区别及其适用场合&#xff0c;帮助工程师和开发者做出明智的选择。 相机选择时需要关注的参数 1. 分辨率 定义&#xff1a;分辨率…

C++ - Clion安装Qt msvc2017版本教程,基础环境配置clion+ Qt5.12.12 msvc2017 + VS2019

背景&#xff1a;平时代码开发使用clion&#xff0c;但使用项目要制定mscv2017版本Qt。先装过mingw版本Qt无法运行&#xff0c;但msvc版本依赖装有Visual Studio&#xff0c;本地装的又是2019版。就出现了这个大坑&#xff0c;需要配置好clion Qt msvc2017 VS2019。 文章目录 …

立创EDA专业版设置位号居中并调整字体大小

选择某一个器件位号&#xff0c;右键->查找&#xff1a; 选择查找全部&#xff1a; 下面会显示查找结果&#xff1a; 查看&#xff0c;所有的位号都被选中了&#xff1a; 然后布局->属性位置&#xff1a; 属性位置选择中间&#xff1a; 然后位号就居中了 调整字体大小&a…

串口通信技术基础

1.0 串口通信基础 数据通信的两种常用形式&#xff1a; 1&#xff1a;并行通信 和 串行通信 并行方式&#xff1a;数据的各位使用多条数据线同时发送或同时接收 特点&#xff1a;传送速度快&#xff0c;但因需要多根传输线&#xff0c;曾经在近距离、高速率通信中使用 串行方式…

01——生产监控平台——WPF

生产监控平台—— 一、介绍 VS2022 .net core(net6版本&#xff09; 1、文件夹&#xff1a;MVVM /静态资源&#xff08;图片、字体等&#xff09; 、用户空间、资源字典等。 2、图片资源库&#xff1a; https://www.iconfont.cn/ ; 1.资源字典Dictionary 1、…

验证码识别接口、多种样式验证码识别接口、中英文验证码识别接口

验证码识别接口、多种样式验证码识别接口、中英文验证码识别接口 本文提供一个基于OCR和机器学习的验证码识别接口&#xff0c;能够识别较复杂的中文、英文验证码&#xff0c;在OCR的基础上针对验证码进行算法优化。本接口是收费的&#xff08;最低0.5分1次调用&#xff0c;试…

【最新鸿蒙应用开发】——总结ArkUI生命周期

鸿蒙ArkUI相关的生命周期都有哪些? 1. UIAbility生命周期 onCreate、onWindowStageCreate、onForeground、onBackground、onWindowStageDestroy、onDestroy。 onCreate&#xff1a;Create状态为在应用加载过程中&#xff0c;UIAbility实例创建完成时触发&#xff0c;系统会调…

python tushare股票量化数据处理:笔记

1、安装python和tushare及相关库 matplotlib pyplot pandas pandas_datareader >>> import matplotlib.pyplot as plt >>> import pandas as pd >>> import datetime as dt >>> import pandas_datareader.data as web 失败的尝试yf…

计蒜客:C10 第四部分:深度优先搜索基础 引爆炸弹

【C代码】 #include<bits/stdc.h> using namespace std; int n,m,ans0; char maze[501][501]; bool vis[501][501]; void dfs(int x,int y){vi…