【大模型】微调实战—使用 ORPO 微调 Llama 3

news2024/9/20 18:38:33

ORPO 是一种新颖微调(fine-tuning)技术,它将传统的监督微调(supervised fine-tuning)和偏好对齐(preference alignment)阶段合并为一个过程。这减少了训练所需的计算资源和时间。此外,实证结果表明,ORPO 在各种模型规模和基准测试(benchmarks)上优于其他对齐方法。
在本文中,我们将使用 ORPO 和 TRL 库对新的 Llama 3 8B 模型进行微调。

ORPO

指令微调(instruction tuning)和偏好对齐(preference alignment)是使LLM适应特定任务的基本技术。传统上,这涉及一个多阶段的过程:1/ 在指令上进行监督微调(Supervised Fine-Tuning, SFT),以使模型适应目标领域,然后 2/ 使用偏好对齐方法,如基于人类反馈的强化学习(Reinforcement Learning with Human Feedback, RLHF)或直接偏好优化(Direct Preference Optimization, DPO),以增加生成首选响应而非被拒绝响应的可能性。
在这里插入图片描述

然而,研究人员发现了这种方法的局限性。虽然 SFT 有效地使模型适应所需的领域,但它无意中增加了在首选答案的同时生成不需要的答案的可能性。这就是为什么偏好调整阶段对于扩大首选输出和拒绝输出的可能性之间的差距是必要的。
ORPO 由 Hong 和 Lee (2024) 提出,通过将指令调整和偏好对齐结合到一个单一的整体训练过程中,为这个问题提供了一个优雅的解决方案。 ORPO 修改了标准语言建模目标,将负对数似然损失与优势比 (OR) 项相结合。这种 OR 损失对被拒绝的响应进行弱惩罚,同时对首选响应进行强烈奖励,从而使模型能够同时学习目标任务并与人类偏好保持一致。
在这里插入图片描述
ORPO 已在主要微调库中实现,如 TRL、Axolotl 和 LLaMA-Factory。在下一节中,我们将了解如何与 TRL 一起使用。

使用 ORPO 微调 Llama 3

Llama 3 是Meta开发的最新大型语言模型(LLM)家族。该模型在一个包含15万亿个标记的数据集上进行了训练(相比之下,Llama 2 的训练数据集为2万亿个标记)。目前已经发布了两种模型尺寸:一个是拥有70B参数的模型,另一个是较小的8B参数模型。70B参数的模型已经展示了令人印象深刻的性能,在MMLU基准测试中得分为82,在HumanEval基准测试中得分为81.7。
Llama 3 模型还将上下文长度增加到了8,192个标记(相比之下,Llama 2 为4,096个标记),并且有可能通过RoPE扩展到32k。此外,这些模型使用了一种新的分词器,具有128K标记的词汇量,从而减少了编码文本所需的标记数量15%。这种词汇量的增加也解释了参数从70亿增加到80亿。
ORPO 需要一个偏好数据集,包括提示、选择的答案和拒绝的答案。在此示例中,我们将使用 mlabonne/orpo-dpo-mix-40k ,它是以下高质量 DPO 数据集的组合:

  • argilla/distilabel-capybara-dpo-7k-binarized: highly scored chosen answers >=5 (2,882 samples)
  • argilla/distilabel-intel-orca-dpo-pairs: highly scored chosen answers>=9, not in GSM8K (2,299 samples)
  • argilla/ultrafeedback-binarized-preferences-cleaned: highly scoredchosen answers >=5 (22,799 samples)
  • argilla/distilabel-math-preference-dpo: highly scored chosen answers>=9 (2,181 samples)
  • unalignment/toxic-dpo-v0.2 (541 samples)
  • M4-ai/prm_dpo_pairs_cleaned (7,958 samples)
  • jondurbin/truthy-dpo-v0.1 (1,016 samples)

首先安装所需的库:

pip install -U transformers datasets accelerate peft trl bitsandbytes wandb

安装完成后,我们可以导入必要的库并登录W&B(可选)

import gc
import os

import torch
import wandb
from datasets import load_dataset
# from google.colab import userdata
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format

# wb_token = userdata.get('wandb')
# wandb.login(key=wb_token)

如果您有最新的 GPU,还应该能够使用 Flash Attention 库将默认的 eager Attention 实现替换为更高效的实现。

if torch.cuda.get_device_capability()[0] >= 8:
    #!pip install -qqq flash-attn
    attn_implementation = "flash_attention_2"
    torch_dtype = torch.bfloat16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16

接下来,我们将借助bitsandbytes 以 4 位精度加载 Llama 3 8B 模型。然后,我们使用 QLoRA 的 PEFT 设置 LoRA 配置。我还使用方便的 setup_chat_format() 函数来修改模型和标记生成器以支持 ChatML。它会自动应用此聊天模板,添加特殊标记,并调整模型嵌入层的大小以匹配新的词汇表大小。
请注意,您需要提交访问 meta-llama/Meta-Llama-3-8B 的请求并登录您的 Hugging Face 帐户。或者,您可以加载模型的非门控副本,例如 NousResearch/Meta–Llama-3-8B。(我选择手动从NousResearch/Meta–Llama-3-8B下载)

# Model
base_model = "meta-llama/Meta-Llama-3-8B"
new_model = "OrpoLlama-3-8B"

# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

现在模型已准备好进行训练,我们可以处理数据集了。我们加载 mlabonne/orpo-dpo-mix-40k 并使用 apply_chat_template() 函数将“chosen”和“rejected”列转换为 ChatML 格式。请注意,我仅使用 1,00 个样本,而不是整个数据集,因为运行时间太长。(我选择手动下载)

dataset_name = "mlabonne/orpo-dpo-mix-40k"
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=42).select(range(100))

def format_chat_template(row):
    row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
    row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
    return row

dataset = dataset.map(
    format_chat_template,
    num_proc= os.cpu_count(),
)
dataset = dataset.train_test_split(test_size=0.01)

首先,我们需要设置一些超参数: * learning_rate :与传统的 SFT 甚至 DPO 相比,ORPO 使用非常低的学习率。 8e-6这个值来自原始论文,大致对应于SFT学习率1e-5和DPO学习率5e-6。我建议将其增加到 1e-6 左右以进行真正的微调。 * beta :即论文中的 𝜆 参数,默认值为0.1。原始论文的附录显示了如何通过消融研究选择它。 * 其他参数,如 max_length 和批量大小设置为使用尽可能多的可用 VRAM(此配置中约为 20 GB)。理想情况下,我们会训练模型 3-5 个 epoch,但这里我们坚持使用 1 个 epoch。
最后,我们可以使用 ORPOTrainer 来训练模型,它充当包装器。

orpo_args = ORPOConfig(
    learning_rate=8e-6,
    beta=0.1,
    lr_scheduler_type="linear",
    max_length=1024,
    max_prompt_length=512,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",
    num_train_epochs=1,
    evaluation_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    report_to="wandb",
    output_dir="./results/",
)

trainer = ORPOTrainer(
    model=model,
    args=orpo_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(new_model)

中间需要选择是否使用W&B,不会使用,我选择不使用
在这里插入图片描述
完成了 Llama 3 的快速微调:mlabonne/OrpoLlama-3-8B
在这里插入图片描述

生成目录:
在这里插入图片描述

合并完整模型到本地:

# Flush memory
del trainer, model
gc.collect()
torch.cuda.empty_cache()

# Reload tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model, tokenizer = setup_chat_format(model, tokenizer)

# Merge adapter with base model
model = PeftModel.from_pretrained(model, new_model)
model = model.merge_and_unload()

# Save the merged model and tokenizer to local directory
local_save_directory = "new_model"
model.save_pretrained(local_save_directory)
tokenizer.save_pretrained(local_save_directory)

得到和初始模型一样结构的微调模型;
在这里插入图片描述
完整教程:https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html
本文使用代码对原代码改了一部分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1916583.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【计算机毕业设计】012基于微信小程序的科创微应用平台

🙊作者简介:拥有多年开发工作经验,分享技术代码帮助学生学习,独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。🌹赠送计算机毕业设计600个选题excel文件,帮助大学选题。赠送开题报告模板&#xff…

华为ensp实现防火墙的区域管理与用户认证

实验环境 基于该总公司内网,实现图片所在要求 后文配置请以本图为准 接口配置与网卡配置 1、创建vlan 2、防火墙g0/0/0与云页面登录 登录admin,密码Admin123,自行更改新密码 更改g0/0/0口ip,敲下命令service-manage all permit 网卡配置…

彩虹小插画:成都亚恒丰创教育科技有限公司

彩虹小插画:色彩斑斓的梦幻世界 在繁忙的生活节奏中,总有一抹温柔的色彩能悄然触动心弦,那就是彩虹小插画带来的梦幻与宁静。彩虹,这一自然界的奇迹,被艺术家们巧妙地融入小巧精致的插画之中,不仅捕捉了瞬…

3D线上展示技术如何应用到汽车营销中?有哪些优势?

传统的汽车销售主要是通过实体店面展示汽车,但这样的展示方式成本高昂,而且还有空间限制。近年来,随着互联网的不断发展,线上看车逐渐成为当下年轻消费群体的看车新选择,并且线上看车正在从2D平面转向3D立体体验。 一、…

three完全开源扩展案例01-三角形渐变

演示地址 import * as THREE from three import { OrbitControls } from three/examples/jsm/controls/OrbitControls.jsconst box document.getElementById(box)const scene new THREE.Scene()const camera new THREE.PerspectiveCamera(75, box.clientWidth / box.client…

VirtualBox NAT网络模式

设置网络模式 右键网络设置 查看此时IP SSH连接 端口转发设置 ssh连接 samba文件共享 虚拟机上samba服务启动运行了,但由于windows无法连接虚拟机IP,即samba访问的入口堵了,无法像访问本地磁盘一样通过samba通道访问虚拟机 替代方案——多…

自定义在线活动报名表单小程序源码系统 源代码+搭建部署教程 可二次定制开发

系统概述 在数字化时代,线上活动成为连接用户与组织的重要桥梁。为了高效地管理活动报名流程,一款灵活、易用的在线活动报名表单小程序显得尤为重要。本文旨在为开发者提供一套全面的解决方案,包括自定义在线活动报名表单小程序的源代码分析…

YOLOv10改进 | 损失函数篇 | SlideLoss、FocalLoss、VFLoss分类损失函数助力细节涨点(全网最全)

一、本文介绍 本文给大家带来的是分类损失 SlideLoss、VFLoss、FocalLoss损失函数,我们之前看那的那些IoU都是边界框回归损失,和本文的修改内容并不冲突,所以大家可以知道损失函数分为两种一种是分类损失另一种是边界框回归损失,…

推荐算法——MRR

定义: MRR计算的是第一个正确答案的排名的倒数,并对所有查询取平均值。它衡量了模型在排序结果中快速找到正确答案的能力。 其中: Q 是查询的总数。ranki​ 是第 i 个查询中第一个正确答案的排名(位置)。如果第一个正…

jdk中自带的并发类

1、seamplore 信号量 countDownLaunch:等待所有线程都完成,主线程在执行 CyclicBarrirer 内存屏障 exchanger 线程之间交换数据 phaser 阶段协同器 阻塞队列

C语言 | Leetcode C语言题解之第227题基本计算题II

题目&#xff1a; 题解&#xff1a; int calculate(char* s) {int n strlen(s);int stk[n], top 0;char preSign ;int num 0;for (int i 0; i < n; i) {if (isdigit(s[i])) {num num * 10 (int)(s[i] - 0);}if (!isdigit(s[i]) && s[i] ! || i n - 1) {s…

Apache Dubbo与Nacos整合过程

Dubbo服务发现 Dubbo 提供的是一种 Client-Based 的服务发现机制&#xff0c;依赖第三方注册中心组件来协调服务发现过程&#xff0c;支持常用的注册中心如 Nacos、Consul、Zookeeper 等。 以下是 Dubbo 服务发现机制的基本工作原理图&#xff1a; 服务发现包含提供者、消费者…

快速测试electron环境是否安装成功

快速测试electron环境是否安装成功 测试代码正确运行的效果运行错误的效果v22.4.1 版本无法使用v20.15.1版本无法使用v18.20.4 版本无法使用 终极解决办法 测试代码 1.npx create-electron-app my-electron-app 2.cd my-electron-app 3.npm start 正确运行的效果 环境没问题…

如何给ubuntu虚拟机扩容

虚拟机设置 鼠标点击硬盘&#xff0c;弹出对话框后&#xff0c;点击扩展&#xff0c;输入扩展后的硬盘大小&#xff0c;我这里扩展到100G 安装工具 sudo apt-get install gparted 重新分区

边框插画:成都亚恒丰创教育科技有限公司

边框插画&#xff1a;艺术与生活的精致边界 在视觉艺术的广阔天地里&#xff0c;边框插画以其独特的魅力和细腻的表达方式&#xff0c;成为连接艺术与生活的一道精致边界。成都亚恒丰创教育科技有限公司它不仅仅是图像的外框装饰&#xff0c;更是情感、故事与创意的延伸&#…

Vue使用Echarts(入门级)

最终效果&#xff1a; npm install echarts --save // 先安装echarts<template><!-- 创建一个dom区域用于挂载echarts图表 --><div id"chart" style"width: 600px;height:500px;"/> </template> <script> import * as ech…

nginx安装配置视频频服务器-windows

编译安装nginx 1、安装perl 安装地址: https://strawberryperl.com&#xff0c;选择msi安装程序即可 2、安装sed for windows 下载地址&#xff1a;https://sourceforge.net/projects/gnuwin32/files/sed/&#xff0c;执行安装程序结束后&#xff0c;将安装包bin目录配置到…

如何在 C 语言中进行选择排序?

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01; &#x1f4d9;C 语言百万年薪修炼课程 通俗易懂&#xff0c;深入浅出&#xff0c;匠心打磨&#xff0c;死磕细节&#xff0c;6年迭代&#xff0c;看过的人都说好。 文章目…

阿里云操作系统智能助手OS Copilot实验测评报告

简介&#xff1a;作为一名学生&#xff0c;阿里云操作系统智能助手OS Copilot对学生的帮助主要体现在提高学习效率、简化操作流程和优化系统管理等方面。通过其丰富的功能&#xff0c;从系统信息的快速获取到复杂的系统运维管理&#xff0c;OS Copilot都能为学生提供极大的便利…

计算机毕业设计Python深度学习游戏推荐系统 Django PySpark游戏可视化 游戏数据分析 游戏爬虫 Scrapy 机器学习 人工智能 大数据毕设

本论文的主要研究内容如下&#xff1a; 了解基于Spark的TapTap游戏数据分析系统的基本架构&#xff0c;掌握系统的开发方法&#xff0c;包括系统开发基本流程、开发环境的搭建、测试与运行等。 主要功能如下&#xff1a; &#xff08;1&#xff09;用户管理模块&#xff1a…