使用LlamaFactory进行模型微调

news2024/10/12 1:40:58

使用LlamaFactory进行模型微调

简介

论文地址:https://arxiv.org/pdf/2403.13372

仓库地址:https://github.com/hiyouga/LLaMA-Factory/tree/main

名词解释

1. 预训练 (Pre-training, PT)

预训练是指模型在大规模无监督数据集上进行初步训练的过程。数据通常由互联网上的文本构成,没有明确的标签。模型通过预测下一个词来学习语言的语法、语义和常识知识。

  • 目标:通过大量的文本数据学习词与词之间的关系,获得广泛的语言理解能力。
  • 训练方式:无监督学习。模型根据已知上下文预测下一个词,并通过这个预测结果不断调整自身的参数。
  • 结果:预训练后的模型具备了基本的语言生成和理解能力,但缺乏完成特定任务的精细调整。

2. 指令微调 (Supervised Fine-tuning, SFT)

指令微调是在预训练的基础上,使用有监督数据对模型进行精细调整的过程。这里的数据是特定任务的数据集,通常包含了人类标注的输入-输出对,指导模型如何在特定指令下工作。

  • 目标:提升模型在特定任务上的表现,让模型更好地理解和执行特定指令。
  • 训练方式:有监督学习。模型根据给定的输入和预期的输出调整参数,以便能够在类似场景下更准确地执行任务。
  • 结果:模型在特定任务(如回答问题、翻译、摘要等)上具有更好的性能和准确性。

3. 基于人工反馈的对齐 (Reinforcement Learning with Human Feedback, RLHF)

RLHF 是通过引入人工反馈,进一步调整模型行为,使其生成的内容更加符合人类期望的一个阶段。具体来说,人类会对模型的输出进行评估,模型根据这些反馈进行强化学习调整。

  • 目标:让模型的输出更符合人类偏好或社会价值观,并减少不当或有害的输出。
  • 训练方式:强化学习。模型会生成多个候选答案,并由人类评估这些答案的质量。基于这些评价,模型会通过强化学习算法(如Proximal Policy Optimization, PPO)调整生成策略。
  • 结果:模型不仅能够理解指令,还能产生与人类偏好对齐的答案。RLHF 在模型的安全性和伦理方面也有很大的作用。

背景

开源大模型如LLaMA,Qwen,Baichuan等主要都是使用通用数据进行训练而来,其对于不同下游的使用场景和垂直领域的效果有待进一步提升,衍生出了微调训练相关的需求,包含预训练(pt),指令微调(sft),基于人工反馈的对齐(rlhf)等全链路。但大模型训练对于显存和算力的要求较高,同时也需要下游开发者对大模型本身的技术有一定了解,具有一定的门槛。

LLaMA-Factory整合主流的各种高效训练微调技术,适配市场主流开源模型,形成一个功能丰富,适配性好的训练框架。项目提供了多个高层次抽象的调用接口,包含多阶段训练,推理测试,benchmark评测,API Server等,使开发者开箱即用。同时提供了基于gradio的网页版工作台,方便初学者可以迅速上手操作,开发出自己的第一个模型。

环境准备

colab

使用google的colab可以快速且免费地使用LlamaFactory进行模型微调。

https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
在这里插入图片描述

本地机器

  1. 本地运行需要必要的cuda和pytorch,如果你在容器里运行,可以直接使用nvidia的pytorch镜像:nvcr.io/nvidia/pytorch:24.05-py3
  2. 安装LlamaFactory
export PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
export USE_MODELSCOPE_HUB=1
export MODELSCOPE_CACHE=/data/modelscope
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git

cd LLaMA-Factory
pip install -e ".[torch,metrics,modelscope]"

llamafactory-cli webui

数据集

LLamaFactory内置了一些常用的数据集,可以直接在现有的数据集上修改,也可以自己扩充新的数据集,详细参考:https://github.com/hiyouga/LLaMA-Factory/blob/main/data/README_zh.md

这里直接用身份验证数据集,修改模型的身份去验证微调效果:

import json

cd /content/LLaMA-Factory/

NAME = "测试AI"
AUTHOR = "测试科技"

with open("data/identity.json", "r", encoding="utf-8") as f:
  dataset = json.load(f)

for sample in dataset:
  sample["output"] = sample["output"].replace("{{"+ "name" + "}}", NAME).replace("{{"+ "author" + "}}", AUTHOR)

with open("data/identity.json", "w", encoding="utf-8") as f:
  json.dump(dataset, f, indent=2, ensure_ascii=False)

目前LLamaFactory支持两种数据集:Alpaca 格式、Sharegpt 格式。相较于 alpaca 格式的数据集,sharegpt 格式支持更多的角色种类,例如 human、gpt、observation、function 等等。

Alpaca格式

指令微调 (Supervised Fine-tuning, SFT) 格式

在指令监督微调时,instruction 列对应的内容会与 input 列对应的内容拼接后作为人类指令,即人类指令为 instruction\ninput。而 output 列对应的内容为模型回答。

如果指定,system 列对应的内容将被作为系统提示词。

history 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮对话的指令和回答。注意在指令监督微调时,历史消息中的回答内容也会被用于模型学习

[
  {
    "instruction": "人类指令(必填)",
    "input": "人类输入(选填)",
    "output": "模型回答(必填)",
    "system": "系统提示词(选填)",
    "history": [
      ["第一轮指令(选填)", "第一轮回答(选填)"],
      ["第二轮指令(选填)", "第二轮回答(选填)"]
    ]
  }
]

预训练 (Pre-training, PT) 格式

在预训练时,只有 text 列中的内容会用于模型学习。

[
  {"text": "document"},
  {"text": "document"}
]

偏好数据集

偏好数据集用于奖励模型训练、DPO 训练、ORPO 训练和 SimPO 训练。

它需要在 chosen 列中提供更优的回答,并在 rejected 列中提供更差的回答。

[
  {
    "instruction": "人类指令(必填)",
    "input": "人类输入(选填)",
    "chosen": "优质回答(必填)",
    "rejected": "劣质回答(必填)"
  }
]

Sharegpt格式

指令微调 (Supervised Fine-tuning, SFT) 格式

相比 alpaca 格式的数据集,sharegpt 格式支持更多的角色种类,例如 human、gpt、observation、function 等等。它们构成一个对象列表呈现在 conversations 列中。

注意其中 human 和 observation 必须出现在奇数位置,gpt 和 function 必须出现在偶数位置。

[
  {
    "conversations": [
      {
        "from": "human",
        "value": "人类指令"
      },
      {
        "from": "function_call",
        "value": "工具参数"
      },
      {
        "from": "observation",
        "value": "工具结果"
      },
      {
        "from": "gpt",
        "value": "模型回答"
      }
    ],
    "system": "系统提示词(选填)",
    "tools": "工具描述(选填)"
  }
]

偏好数据集

[
  {
    "conversations": [
      {
        "from": "human",
        "value": "人类指令"
      },
      {
        "from": "gpt",
        "value": "模型回答"
      },
      {
        "from": "human",
        "value": "人类指令"
      }
    ],
    "chosen": {
      "from": "gpt",
      "value": "优质回答"
    },
    "rejected": {
      "from": "gpt",
      "value": "劣质回答"
    }
  }
]

微调

开始微调

可以通过gradio页面直接微调

# 启动gradio,使用7860端口访问
llamafactory-cli webui

选择微调模型和数据集

在这里插入图片描述

开始微调
在这里插入图片描述

微调完成后的文件存在项目目录的saves文件夹下。

判断微调效果

可以通过看损失曲线是否收敛来判断本次训练的效果,如果收敛不好,则需要考虑调小学习率、减少梯度累积、增加训练轮数等方式。以下是一些损失较大的图例(需要修改参数重新微调)。

在这里插入图片描述
在这里插入图片描述

验证微调效果

在这里插入图片描述

微调参数解释

梯度累积 (gradient_accumulation_steps)

模型会在每次处理一个小批次后,不立即更新权重,而是先将每次的小批次梯度累积起来,直到累积到设定的次数为止(例如,4或8),然后再一次性执行权重更新。这样做的主要原因是为了减少显存占用。如果你的显存较为充裕,想要快速更新权重,可以选择较小的值,比如 4,否则可以选择较大的值,比如 8

学习率 (learning_rate)

更大的学习率意味着模型每次更新的幅度更大,收敛速度会较快。然而,过大的学习率可能会导致模型在损失函数的表面上跳动过大,甚至错过全局最优解,导致模型发散或收敛到局部最优解。

5e-5:适合在模型刚开始训练时使用,或者用于一些简单任务和数据集较小的场景。较大的学习率可以帮助模型快速进入一个较好的区域,减少训练时间。

5e-6:适合在微调阶段使用,特别是当你在较大的预训练模型上进行微调时,小的学习率更为合适。因为大模型通常已经学习到了大量的知识,微调阶段只需要进行小幅度的调整,过大的学习率反而可能破坏模型原有的良好性能。

训练轮次 (training_epochs)

在深度学习中,单次迭代(即单个epoch)通常不足以让模型学到足够的信息。模型通过多轮次的训练,不断调整权重,逐渐优化自身,降低误差并提高在验证集上的性能。每次新的epoch开始,模型会使用更新的权重重新开始遍历数据,从而逐步学习数据中的模式和特征。

如果训练轮次设置得太少,模型可能无法充分学习数据中的模式,表现为欠拟合(underfitting),即训练不足,模型在训练集和验证集上的表现都较差。

果训练轮次设置得太多,模型可能过度拟合(overfitting)训练集,即模型对训练数据的表现非常好,但在验证集或测试集上的表现可能会变差,因为模型“记住”了训练数据中的噪声,而不是学习到泛化的模式。

合并LoRA 权重

使用页面的Export功能进行LoRA的权重合并
在这里插入图片描述

合并完成后,可以用以下代码进行测试

import transformers
import torch

model_id='/data/finetuned/llama3-8b-instruct-merged'
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device="cuda",
)

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "你是谁?"},
]

prompt = pipeline.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
)

terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

print("欢迎使用聊天模型!输入 'exit' 退出对话。\n")
while True:
  input_text = input("你: ")

  if input_text.lower() == "exit":
    print("对话结束。")
    break
  outputs = pipeline(
    prompt,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
  )
  print(outputs[0]["generated_text"][len(prompt):])

在这里插入图片描述

问题解决

KeyError: ‘WEBP’

使用图形化界面微调模型时,页面会展示训练损失图,可能会报这个错误,这个错误表明在使用 matplotlib 保存图片为 WEBP 格式时,出现了 KeyError: 'WEBP',即 PIL 不支持 WEBP 格式。这通常是由于 Pillow 库没有正确安装 WEBP 格式的支持引起的。

apt-get update && apt-get install libwebp-dev
pip uninstall pillow
pip install --no-cache-dir pillow

使用python验证是否解决,期待返回True

from PIL import features
print(features.check('webp'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2206517.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Kafka-初识

一、Kafka是什么&#xff1f; Kafka是一个高度可扩展、弹性、容错和安全的分布式流处理平台&#xff0c;由服务器和客户端组成&#xff0c;通过高性能TCP网络协议进行通信。它可以像消息队列一样生产和消费数据。可以部署在裸机硬件、虚拟机和容器上&#xff0c;也可以部署在本…

使用3080ti运行blip2的

使用3080ti运行blip2的案例 注意&#xff01;blip2很吃显存&#xff0c;需要大于80GB显存的卡。我最后安装的所有包的版本信息&#xff08;python 3.9 &#xff09;以供参考&#xff1a; 首先&#xff0c;我在运行blip2的demo的时候显存用了80G以上&#xff0c;所以大家卡的显存…

VS Code最新版本Retome远程ssh不兼容旧服务器问题

✨✨欢迎来到T_X_Parallel的博客&#xff01;&#xff01;       &#x1f6f0;️博客主页&#xff1a;T_X_Parallel       &#x1f6f0;️欢迎关注&#xff1a;&#x1f44d;点赞&#x1f64c;收藏✍️留言 目录 问题&#xff1a;无法正常使用vscode-remote插件远…

深度优先搜索 - 岛屿最大面积

题目描述 给定一个由 0 和 1 组成的非空二维数组 grid &#xff0c;用来表示海洋岛屿地图。 一个 岛屿 是由一些相邻的 1 (代表土地) 构成的组合&#xff0c;这里的「相邻」要求两个 1 必须在水平或者竖直方向上相邻。你可以假设 grid 的四个边缘都被 0&#xff08;代表水&…

从零开始搭建UVM平台(十二)-加入sequence机制

书接上回&#xff1a; 从零开始搭建UVM平台&#xff08;一&#xff09;-只有uvm_driver的验证平台 从零开始搭建UVM平台&#xff08;二&#xff09;-加入factory机制 从零开始搭建UVM平台&#xff08;三&#xff09;-加入objection机制 从零开始搭建UVM平台&#xff08;四&…

邮件系统国产化改造: 保障信息安全、提升效率的最佳选择

在当前数字化转型的大背景下&#xff0c;我国政府提出了构建网络强国和数字强国的宏伟蓝图。这一战略的实施&#xff0c;不仅为数字政府的建设提供了坚实的基础&#xff0c;也为政府和企业的数字化升级指明了方向。在这一进程中&#xff0c;邮件系统的国产化改造就显得尤为重要…

功能安全测试安全渗透测试,一文讲清楚

本文我们将以围绕系统安全质量提升为目标&#xff0c;讲述在功能安全测试&安全渗透测试上实践过程。 希望通过此篇文章&#xff0c;帮助大家更深入、透彻地了解安全测试。 安全渗透测试实践 安全前置扫描主要是识别白盒漏洞、黑盒漏洞问题&#xff0c;针对JSRC类问题&am…

pycharm里debug时如何看到数据的维度

使用表达式计算&#xff08;Evaluate Expression&#xff09; 调试时&#xff0c;使用 PyCharm 的 “Evaluate Expression” 功能可以动态查看或修改数据。具体步骤如下&#xff1a; 在调试模式中按 Alt F8&#xff08;Windows&#xff09;或 Option F8&#xff08;Mac&…

ARC学习(4)基本编程模型认识(四)----寄存器以及异常数据读取

笔者来聊一下ARC寄存器的获取 在介绍了ARC编程模型的知识点之后,来看一些具体的编程操作,比如如何获取寄存器,如何编写汇编语言实现特定功能? 1、获取寄存器 可以使用内联汇编来实现寄存器的获取,具体格式如下: _Asm:汇编宏标识符,指示内联汇编代码_Save_all_regs:…

第十二章 RabbitMQ之失败消息处理策略

目录 一、引言 二、RepublishMessageRecoverer 实现 2.1. 实现步骤 2.2. 实现代码 2.2.1. 异常交换机队列回收期配置类 2.2.2. 常规交换机队列配置类 2.2.3. 消费者代码 2.2.4. 消费者yml配置 2.2.5. 生产者代码 2.2.6. 生产者yml配置 2.2.7. 运行效果 一、引言 …

【瑞萨RA8D1 CPK开发板】串口的使用和STDOUT输出重定向

串口 本次串口的使用关于时钟导致串口的波特率不对&#xff0c;坑了我很久的时间 使能时钟 串口发现一个问题就是&#xff0c;只能使用下边的时钟配置&#xff0c;修改时钟源和分频系数都会导致串口波特率不正常&#xff0c;这种问题出现在mdkrasc的使用场景之下&#xff1b…

bclinux安装minio和mc及从服务器上下载文件

下载MinIO服务器二进制文件 访问MinIO的官方网站或使用wget、curl等工具直接从MinIO的官方GitHub存储库下载最新版本的MinIO服务器二进制文件。例如&#xff0c;使用以下命令&#xff1a; 下载命令&#xff1a;wget https://dl.min.io/server/minio/release/linux-amd64/ 授…

Hadoop三大组件的工作原理

Hadoop三大组件的工作原理 一、引言 Hadoop是一个开源的分布式计算框架&#xff0c;在大数据处理领域具有举足轻重的地位。其核心组件包括HDFS&#xff08;分布式文件系统&#xff09;、MapReduce&#xff08;分布式计算框架&#xff09;和YARN&#xff08;资源管理系统&…

Vue3 ECharts看板

获取 ECharts - 入门篇 - 使用手册 - Apache ECharts npm install echarts <template><div id"main" style"height:400px;"></div> </template><script lang"ts" setup> import { ref, onMounted } from "…

AcWing 905:区间选点 ← 贪心算法

【题目来源】https://www.acwing.com/problem/content/907/【题目描述】 给定 N 个闭区间 [ai,bi]&#xff0c;请你在数轴上选择尽量少的点&#xff0c;使得每个区间内至少包含一个选出的点。 输出选择的点的最小数量。 位于区间端点上的点也算作区间内。【输入格式】 第一行包…

【论文阅读笔记】End-to-End Object Detection with Transformers

代码地址&#xff1a;https://github.com/facebookresearch/detr 论文小结 本文是Transformer结构应用于目标检测&#xff08;OD&#xff09;任务的开山之作。方法名DETE&#xff0c;取自Detection Transformer。   作为2020年的论文&#xff0c;其表现精度在当时也不算高的…

Linux:信号保存与处理

使用kill -l命令查看信号&#xff1a; 信号量和信号确实一点关系没有 信号是操作系统发出的进程与进程之间的通知于中断&#xff0c;是进程之间时间异步通知的一种方式 先了解同步通信&#xff1a;同步通信是一种比特同步通信技术&#xff0c;要求发收双方具有同频同相的同步…

学以致用 SAP HCM 顾问excel函数实战系列

EXCEL函数&#xff1a;在上学的时候&#xff0c;对word、excel、PPT感觉都很简单&#xff0c;稀里糊涂的学&#xff0c;稀里糊涂的忘&#xff0c;然后走向工作岗位的时候&#xff0c;突然发现这三大宝剑无比锋利&#xff0c;可惜自己太菜&#xff0c;曾经努力学习&#xff0c;但…

前端 | Uncaught (in promise) undefined

前端 | Uncaught (in promise) undefined 最近开发运行前端项目时&#xff0c;经常预计控制台报错 &#xff0c;如下图&#xff1a; 这里我总结下&#xff0c;这种报错的场景和原因&#xff0c;并通过实际代码案例帮助小伙伴更好理解下 。 文章目录 前端 | Uncaught (in promi…

数据丢失的终极克星来了!EasyRecovery17数据恢复软件

数据丢失的终极克星来了&#xff01; 各位亲爱的朋友们&#xff0c;你们有没有经历过那种“哎呀妈呀&#xff0c;重要文件找不到了&#xff01;”的绝望时刻&#xff1f;别急&#xff0c;今天我要向你们安利一款神器——EasyRecovery17数据恢复软件&#xff0c;简直是我们这些“…