11. DPO 微调示例:根据人类偏好优化LLM大语言模型

news2024/11/13 15:09:12

在部署大模型之后,我们必然要和微调打交道。现在大模型的微调有非常多的方法,过去的文章中提到的微调方法通常依赖于问题和答案对,标注成本较高。

2023 年所提出的 Direct Preference Optimization(DPO)为我们提供了一种无需标准标注答案的高效微调方法。DPO 依赖于人类对文本的偏好对(preference pairs),也就是说,数据集中只包含人类对两段文本中哪段更好的判断,而不是具体的正确答案。

在本文中,我们将利用 DPO 来微调一个模型让其按照偏好进行输出。这篇文章也为生成式人工智能导论课程中 HW6: LLM Values Alignment 提供中文引导。

代码文件下载 | 作业PDF

安装和导入一些必要的库

pip install bitsandbytes==0.43.1 datasets==2.19.0 peft==0.10.0 trl==0.8.6 accelerate==0.29.3
import os
import re
import json

import torch
import pandas as pd
from tqdm.auto import tqdm

from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig, GenerationConfig
from trl import DPOTrainer

可能的问题:Keras 3 与 Transformers 不兼容

在导入时,你可能会看到以下报错:

RuntimeError: Failed to import trl.trainer.dpo_trainer because of the following error (look up to see its traceback):
Failed to import transformers.trainer because of the following error (look up to see its traceback):
Failed to import transformers.integrations.integration_utils because of the following error (look up to see its traceback):
Failed to import transformers.modeling_tf_utils because of the following error (look up to see its traceback):
Your currently installed version of Keras is Keras 3, but this is not yet supported in Transformers. Please install the backwards-compatible tf-keras package with pip install tf-keras.

transformers 库建议安装兼容的 tf-keras 包来解决这个兼容性问题。你可以通过以下命令安装:

pip install tf-keras

现在问题应该得到了解决。

加载数据集

我们将使用预先提供的数据集,包括带标签的偏好数据和测试提示数据。

这个数据集来自于生成式人工智能导论的HW6,处理的问题是:是否应该将动漫真人化?两个回答分别对应支持和不支持(由GPT生成),在后面的代码中你将选择支持的占比。

git clone https://github.com/Baiiiiiiiiii/GenAI_hw6_dataset.git
with open("./GenAI_hw6_dataset/labelled_data.json", 'r') as jsonfile:
    full_data = json.load(jsonfile)

with open("./GenAI_hw6_dataset/test_prompt.json", 'r') as jsonfile:
    test_data = json.load(jsonfile)

直观理解数据集:

full_data

image-20240919114655048

使用 HFD 下载模型

我们这里使用多线程的方法进行快速下载。

如果直接运行以下命令报错,根据 a. 使用 HFD 加快 Hugging Face 模型和数据集的下载 进行前置安装。

当然,你也可以取消我注释的部分,使用官方的命令进行安装,但是会很慢。

安装工具

sudo apt-get update
sudo apt-get install git wget curl aria2 git-lfs
git lfs install

下载 hfd 并修改权限

wget https://hf-mirror.com/hfd/hfd.sh
chmod a+x hfd.sh

多线程下载模型

export HF_ENDPOINT=https://hf-mirror.com
./hfd.sh 'MediaTek-Research/Breeze-7B-Instruct-v0_1' --tool aria2c -x 16

下载

加载模型

将使用MediaTek-Research/Breeze-7B-Instruct-v0_1模型进行微调。

model = AutoModelForCausalLM.from_pretrained(
    'MediaTek-Research/Breeze-7B-Instruct-v0_1',
    device_map='auto',
    trust_remote_code=True,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4'
    )
)

这里,我们采用了4位量化(4-bit quantization)来减少模型的内存占用,加快推理速度。

查看未经过微调的模型原始输出

在进行微调之前,我们首先查看一下原始模型的输出效果。首先,加载分词器:

tokenizer = AutoTokenizer.from_pretrained('MediaTek-Research/Breeze-7B-Instruct-v0_1')
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

定义一个数据处理函数,将数据格式化为模型可以接受的输入,我们这里的 prompt 延续原来的繁体(因为Breeze-7B-Instruct-v0_1更多使用繁体中文进行训练,你并不需要修改它):

def data_formulate(data):
    messages = [
        {"role": "system", "content": '回覆請少於20字'},
        {"role": "user", "content": data['prompt']},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt

接下来,生成原始模型的响应:

original_model_response = []
for data in tqdm(test_data):
    id = data['id']
    print(f'Question {id}:\n'+data['prompt'])
    inputs = tokenizer(data_formulate(data), return_tensors="pt").to('cuda')
    generation_config=GenerationConfig(
            do_sample=False,
            max_new_tokens = 200,
            pad_token_id = tokenizer.pad_token_id
    )
    output = model.generate(**inputs, generation_config=generation_config)
    output = tokenizer.batch_decode(output, skip_special_tokens=True)[0].split('[/INST] ')[1]
    original_model_response.append(output)
    print('Response from original model:\n'+output+'\n')

这段代码将遍历测试数据集,生成并打印每个问题的原始模型响应。

image-20240919113918391

设置参数

你只需要修改这个模块,不需要改变其他的,除非你真的知道自己在做什么。

support_ratio 将反映你的偏好:

  • 0 表示完全不支持(反对)真人化
  • 1 表示完全支持真人化
  • 0.1 表示 10% 支持真人化
num_epoch = 1
data_size = 50
support_ratio = 0.1

准备训练数据

这里,我们将数据集分为支持(support)和反对(oppose)两部分,构建一个包含偏好对的训练数据集(是的,这里就是 DPO)。

# 选择部分数据用于训练
training_data = full_data[:data_size]

# 定义 support 数据集的大小
support_data_size = int(data_size * support_ratio)

# 为训练数据集准备数据
prompt_list = [data_formulate(data) for data in training_data]
chosen_list = [data['support'] for data in training_data[:support_data_size]] + [data['oppose'] for data in training_data[support_data_size:]]
rejected_list = [data['oppose'] for data in training_data[:support_data_size]] + [data['support'] for data in training_data[support_data_size:]]
position_list = ['support' for _ in range(support_data_size)] + ['oppose' for _ in range(data_size - support_data_size)]

# 创建训练数据集
train_dataset = Dataset.from_dict({'prompt': prompt_list, 'position': position_list, 'chosen': chosen_list, 'rejected': rejected_list})
pd.DataFrame(train_dataset).rename(columns={"chosen": "preferred", "rejected": "non-preferred"})

总共有 50 笔训练数据,当 support 设置为 0.1 时,前 50*0.1=5 笔训练资料的偏好将倾向于支持真人化,后 50-4=45 笔资料反对真人化。

image-20240919114949791

训练

现在,我们进入训练阶段。首先,设置训练参数:

training_args = TrainingArguments(
    output_dir='./',
    per_device_train_batch_size=1,
    num_train_epochs=num_epoch,
    gradient_accumulation_steps=8,
    gradient_checkpointing=False,
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    logging_steps = 1,
    warmup_ratio = 0.1,
    report_to = 'none'
)

接下来,配置PEFT(Parameter-Efficient Fine-Tuning):

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)

然后,初始化DPO训练器:

dpo_trainer = DPOTrainer(
    model,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)

开始训练:

dpo_trainer.train()

image-20240919115410184

查看微调后的模型输出

训练完成后,我们需要查看微调后的模型效果。以下是生成训练后模型响应的代码:

trained_model_response = []
for data in tqdm(test_data):
    id = data['id']
    print(f'Question {id}:\n'+data['prompt'])
    inputs = tokenizer(data_formulate(data), return_tensors="pt").to('cuda')
    generation_config=GenerationConfig(
            do_sample=False,
            max_new_tokens = 200,
            pad_token_id = tokenizer.pad_token_id
    )
    output = model.generate(**inputs, generation_config=generation_config)
    output = tokenizer.batch_decode(output, skip_special_tokens=True)[0].split('[/INST] ')[1]
    trained_model_response.append(output)
    print('Response from trained model:\n'+output+'\n')

这段代码与之前生成原始模型响应的代码类似,但这次生成的是经过微调后的模型响应:

image-20240919115643310

观察输出结果

最后,我们对比微调前后的模型响应,观察DPO方法带来的效果提升:

model_response = []
print(f'num_epoch: {num_epoch}\ndata_size: {data_size}\nsupport_ratio: {support_ratio}')
print()
for data in test_data:
    id = data['id']
    ref_output = original_model_response[id-1]
    output = trained_model_response[id-1]
    print(f'Question {id}:\n'+data['prompt'])
    print('Response from original model:\n'+ref_output)
    print('Response from trained model:\n'+output)
    print()
    model_response.append({'id':data['id'], 'prompt':data['prompt'], 'response_from_original_model':ref_output, 'response_from_trained_model':output})

image-20240919115708299

拓展

在使用 GPT 的时候你应该也见到过其同时生成两个回答让我们选择更倾向于哪个,这个和 Google 验证码有着异曲同工之妙。

进一步

12. Inseq 特征归因:可视化解释 LLM 的输出
李宏毅2024生成式人工智能导论 中文镜像版指导与作业

推荐阅读

Direct Preference Optimization: Your Language Model is Secretly a Reward Model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2152098.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

卡牌抽卡机小程序:市场发展下的创新

今年以来,卡牌成为了行业中的黑马,在国内迅速流行,成为消费者的心头好。小小的卡牌创下了百亿的市场规模,发展前景巨大! 不过,随着卡牌市场的不断增长,市场发展也需要进行创新。线上抽卡机小程…

Yocto - 使用Yocto开发嵌入式Linux系统_02 认识 Yocto 项目

Meeting the Yocto Project 本章向你介绍 Yocto 项目。这里讨论的项目主要概念将贯穿全书。此外,我们还将简要讨论 Yocto 项目的历史、OpenEmbedded、Poky、BitBake、元数据和版本模式。系好安全带,欢迎加入我们的行列! This chapter introdu…

信息安全数学基础(19)同余式的基本概念及一次同余式

一、同余式概念 同余式是数论中的一个基本概念,用于描述两个数在除以某个数时所得的余数相同的情况。具体地,设m是一个正整数,a和b是两个整数,如果a和b除以m的余数相同,则称a和b模m同余,记作a≡b(mod m)。反…

C语言 | Leetcode C语言题解之第421题数组中两个数的最大异或值

题目: 题解: const int HIGH_BIT 30;struct Trie {// 左子树指向表示 0 的子节点struct Trie* left;// 右子树指向表示 1 的子节点struct Trie* right; };struct Trie* createTrie() {struct Trie* ret malloc(sizeof(struct Trie));ret->left re…

SpringBoot 数据库表结构文档生成

官方地址&#xff1a;https://github.com/pingfangushi/screw screw 螺丝钉&#xff0c;支持以下数据库 MySQL MariaDB TIDB Oracle SqlServer PostgreSQL Cache DB&#xff08;2016&#xff09; 生产文档支持 html word markdown 开始 添加依赖 <!-- 螺丝钉 --><…

CompletableFuture-详解使用及源码解析

背景 上一篇文章我们看了FutureTask&#xff0c;分析了他的问题&#xff0c;异步编程并不方便。 问题1&#xff1a; FutureTask获取执行结果前&#xff0c;主线程需要通过get()方法一直阻塞等待子线程执行完成call方法&#xff0c;才可以拿到返回结果问题2&#xff1a;如果不…

电竞显示器哪个牌子好

电竞显示器哪个好&#xff1f;你想成为电竞选手吗&#xff1f;显示器很关键&#xff0c;下面我就列举7款市面流行的电竞显示器给大家看看&#xff0c;总有一款适合你。 1.电竞显示器哪个好 - 蚂蚁电竞 ANT255VF电竞显示器 一、产品概述 蚂蚁电竞 ANT255VF电竞显示器是一款专为…

2024/9/21 leetcode 21.合并两个有序链表 2.两数相加

目录 21.合并两个有序链表 题目描述 题目链接 解题思路与代码 2.两数相加 题目描述 题目链接 解题思路与代码 --------------------------------------------------------------------------- 21.合并两个有序链表 题目描述 将两个升序链表合并为一个新的 升序 链表并返…

ChatCADChatCAD+:Towards a Universal and Reliable Interactive CAD using LLMs

ChatCAD&#xff08;论文链接&#xff1a;[2302.07257] ChatCAD: Interactive Computer-Aided Diagnosis on Medical Image using Large Language Models (arxiv.org)&#xff09; 网络流程图&#xff1a; 辅助阅读&#xff1a; 基于大型语言模型的医学图像交互式计算机辅助诊…

【运维自动化-作业平台】如何使用全局变量之字符串类型?

使用变量是脚本很常见的处理场景&#xff0c;作业平台中主要有全局变量和魔法变量两类&#xff0c;全局变量又区分了字符串、命名空间、主机列表、密文、数组5种类型。字符串类型变量 最简单、使用频率最高的全局变量类型&#xff0c;可以跨主机、跨步骤使用。目前在作业平台中…

uniApp微信小程序扫描普通二维码跳转到小程序指定页面操作方法

这篇文章主要给大家介绍了关于微信小程序扫描普通二维码跳转到小程序指定页面操作的相关资料,需要的朋友可以参考下 1、首先我们需要在微信公众平台的开发管理——>开发设置&#xff0c;找到&#xff08;扫普通链接二维码打开小程序&#xff09;&#xff0c;点击添加,根据提…

vue3-05-Element-plus中表单校验:校验对象中的对象的属性,校验对象中的数组中的对象的属性,校验嵌套对象

目录 一、校验对象中的普通属性二、校验对象中对象的属性三、校验对象中的数组中的对象的属性 这两天写vue3项目&#xff0c;用了element-plus库&#xff0c;到了表单规则验证的环节&#xff0c;我发现我只会校验对象中的普通属性&#xff0c;如果校验嵌套对象&#xff0c;我就…

ML 系列:多元线性回归 (MLR)(04)

图 1.多元线性回归与简单线性回归 一、说明 线性回归从一维推广到多维&#xff0c;这与单变量线性回归有很多不同&#xff0c;情况更加复杂&#xff0c;而在梯度优化也需要改成向量梯度&#xff0c;同时&#xff0c;数据预处理也成了必要步骤。 二、综述 多元线性回归是简单线性…

C++:分苹果【排列组合】

描述 把M个同样的苹果放到N个同样的盘子里&#xff0c;允许有的盘子空着不放&#xff0c;问共有多少种不同的分法&#xff1f;&#xff08;用K表示&#xff09;&#xff0c;5&#xff0c;1&#xff0c;1和1&#xff0c;5&#xff0c;1是同一种分法。 输入描述 两个整数M和N&…

C语言 | Leetcode C语言题解之第420题强密码检验器

题目&#xff1a; 题解&#xff1a; #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))int strongPasswordChecker(char * password) {int n strlen(password);bool has_lower false, has_upper false, has_digit false;for …

YOLOv9改进系列,YOLOv9主干网络替换为RepViT (CVPR 2024,清华提出,独家首发),助力涨点

摘要 轻量级视觉变换器(ViTs)在资源受限的移动设备上表现出优越的性能和较低的延迟,相比之下轻量级卷积神经网络(CNNs)稍显逊色。研究人员发现了许多轻量级 ViTs 和轻量级 CNNs 之间的结构联系。然而,它们在块结构、宏观和微观设计上的显著架构差异尚未得到充分研究。在…

【重磅发布】大模型在金融领域的价值、治理和生态进阶之路白皮书

引言 金融行业天然具备数据和信息密集型的特点,在数字化成熟度方面处于领先地位。此外,金融行业的数字化投入持续稳步增长,汇集了大量具备数字化技能的人才。这些优势使得金融行业在AI技术的应用和创新方面具备独特的条件,能够在推动技术革新和提升行业效率方面起到示范作…

NLP(二)-文本表示

One-hot One-hot&#xff08;独热&#xff09;编码是一种最简单的文本表示方式。如果有一个大小为V的词表&#xff0c;对于第i个词$w_i$&#xff0c;可以用一个长度为V的向量来表示&#xff0c;其中第i个元素为1&#xff0c;其它为0.例如&#xff1a; 减肥&#xff1a;[1, 0,…

59.【C语言】内存函数(memmove函数)

目录 2.memove函数 *简单使用 部分翻译 *模拟实现 方案1 方案2 1.有重叠 dest在src左侧 dest在src右侧 2.无重叠 代码 2.memove函数 *简单使用 memove:memory move cplusplus的介绍 点我跳转 对比第59篇的memcpy函数 对比memmcpy函数的介绍如下区别: 部分翻译 m…

【Verilog学习日常】—牛客网刷题—Verilog快速入门—VL59

根据RTL图编写Verilog程序 描述 根据以下RTL图&#xff0c;使用 Verilog HDL语言编写代码&#xff0c;实现相同的功能&#xff0c;并编写testbench验证功能。 输入描述&#xff1a; clk&#xff1a;系统时钟信号 rst_n&#xff1a;复位信号&#xff0c;低电平有效 data_in…