[NLP] LLM---<训练中文LLama2(五)>对SFT后的LLama2进行DPO训练

news2025/1/12 12:11:47

当前关于LLM的共识

大型语言模型(LLM)使 NLP 中微调模型的过程变得更加复杂。最初,当 ChatGPT 等模型首次出现时,最主要的方法是先训练奖励模型,然后优化 LLM 策略。从人类反馈中强化学习(RLHF)极大地推动了NLP的发展,并将NLP中许多长期面临的挑战抛在了一边。基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。

然而,它也给 NLP 引入了一些 RL 相关的复杂性: 既要构建一个好的奖励函数,并训练一个模型用以估计每个状态的价值 (value); 又要注意最终生成的 LLM 不能与原始模型相差太远,如果太远的话会使得模型容易产生乱码而非有意义的文本。该过程非常复杂,涉及到许多复杂的组件,而这些组件本身在训练过程中又是动态变化的,因此把它们料理好并不容易。

现在主流的LLM,比如chatglm、chinese-alpaca,主要进行了三步操作:

Step1:知识学习,CLM,大规模语料库上的预训练,本步的模型拥有续写的功能

Step2:知识表达,指令微调,在指令数据上进行微调,本步骤可以使用Lora等节省显存的方式,本模型可以听懂人类指令并进行回答的功能

Step3:偏好学习,RLHF或本文所提的DPO,可以让模型的输出更符合人类偏好,通俗说就是同样一句话,得调教的让模型输出人类喜欢的表达方式,好比高情商的人说话让人舒服

第二步,还是多多少少学习了一点知识,第三步则几乎不学知识,只学表达方式了。

RLHF太耗时耗力了,得提前训练好RewardModel,然后PPO阶段,得加载4个模型,2个推理,2个训练,实在是太不友好了。

下图是SFT+RLHF的过程,对应上文的Step2和Step3,主要包括指令微调模型、训练奖励模型和PPO优化。

现在大多数目前开源的LLM模型都只做了前2步:预训练和指令微调。

而其中原因就是第3步人类反馈强化学习(RLHF)实现起来很困难:

1.需要人类反馈数据(很难收集)
2.奖励模型训练(很难训练)
3. PPO强化学习微调(不仅很耗资源,而且也很难训练)

但是能不能不要最后一步呢,一般来说还是有RLHF比较好,有主要有以下几个原因:

  1. 提高安全性和可控性;
  2. 改进交互性;
  3. 克服数据集偏差;
  4. 提供个性化体验;
  5. 符合道德规范;
  6. 持续优化和改进。

RLHF使得ChatGPT这样的大型对话模型既具备强大能力,又能够接受人类价值观的指导,生成更智能、安全、有益的对话回复。这是未来可信赖和可解释AI的重要发展方向。

所以这一步还是非常重要。那如何解决人类反馈强化学习(RLHF)训练这个难题呢?

DPO (Differentiable Policy Optimization) 算法

Rafailov、Sharma、Mitchell 等人最近发表了一篇论文 Direct Preference Optimization,论文提出将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。

DPO 是为实现对 LLM 的精确控制而引入的一种方法。从人类反馈强化学习(RLHF)的基础是训练奖励模型,然后使用近端策略优化(PPO)使语言模型的输出与人类的偏好相一致。这种方法虽然有效,但既复杂又不稳定。DPO 将受限奖励最大化问题视为人类偏好数据的分类问题。这种方法稳定、高效、计算量小。它无需进行奖励模型拟合、大量采样和超参数调整。

DPO(Direct Preference Optimization)是一种直接偏好优化算法,它与PPO(Proximal Policy Optimization)优化的目标相同。主要思路是:

1.定义policy模型(策略模型)和reference模型(参考模型),Policy模型是需要训练的对话生成模型,reference模型是给定的预训练模型或人工构建的模型。

2.对于给定prompt,计算两模型对正样本和负样本的概率,正样本是人类选择的回复,负样本是被拒绝的回复。

3.通过两个模型概率的差值构建DPO损失函数,惩罚policy模型对正样本概率的下降和负样本概率的上升。通过最小化DPO损失进行模型训练。

相比之下DPO就很友好,只需要加载2个模型,其中一个推理,另外一个训练,直接在偏好数据上进行训练即可:

DPO 拒绝有害问题 实战部分

数据集

数据集其实就是标准的RLHF奖励模型的训练集,下载地址在这

Anthropic/hh-rlhf · Datasets at Hugging Face

dikw/hh_rlhf_cn · Datasets at Hugging Face

其样式就是:一个context,一个选择的正样本,一个拒绝的负样本。希望这些样本能够让LLM 尽可能生成用户选择的无害回复,而不要生成有害的回复。

微调代码
下方这段代码实现了基于DPO (Differentiable Policy Optimization) 的对话模型微调。主要步骤包括:

  1. 加载预训练语言模型(这里使用llama-2-7b)并准备量化训练,采用int4量化的+少量lora 参数。
  2. 定义参考模型(int4量化的模型),也使用同样的预训练模型。
  3. 加载Helpful/Harmless数据集,并转换成所需格式。
  4. 定义DPO训练参数,包括batch size,学习率等。
  5. 定义DPO训练器,传入policy模型,参考模型,训练参数等。
  6. 进行DPO微调训练。
  7. 保存微调后的模型,只保存量lora 参数。

关键点:

1. 使用DPO损失函数实现安全性约束的模型训练。不需要额外在训练一个奖励模型。
2. 这也导致整个训练过程只需要策略模型和参考模型 2个LLM模型,不需要额外的显存去加载奖励模型。
3. 整个训练过程策略模型和参考模型可以进行4int的模型量化 + 少量的lora 参数

综上,这段代码对预训练语言模型进行DPO微调,以实现安全可控的对话生成

#!/usr/bin/env python
# coding: utf-8

from typing import Dict

import torch
from datasets import Dataset, load_dataset
from trl import DPOTrainer
import bitsandbytes as bnb

from transformers import TrainingArguments
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)

output_dir1 = "./dpo_output_dir1"
output_dir2 = "./dpo_output_dir2"

base_model = "/home/work/llama-2-7b"

###准备训练数据
dataset = load_dataset("json", data_files="./dpo_dataset/harmless_base_cn_train.jsonl")
train_val = dataset["train"].train_test_split(
    test_size=2000, shuffle=True, seed=42
)
train_data = train_val["train"]
val_data = train_val["test"]


def extract_anthropic_prompt(prompt_and_response):
    final = ""
    for sample in prompt_and_response:
        final += sample["role"] + "\n" + sample["text"]
    final += "\n"
    return final


def get_hh(dataset, split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
      \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = dataset
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["context"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"]["role"] + "\n" + sample["chosen"]["text"],
            "rejected": sample["rejected"]["role"] + "\n" + sample["rejected"]["text"],
        }

    return dataset.map(split_prompt_and_responses)


train_dataset = get_hh(train_data, "train", sanity_check=True)
eval_dataset = get_hh(val_data, "test", sanity_check=True)


def find_all_linear_names(model):
    # cls = bnb.nn.Linear8bitLt
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainables%: {100 * trainable_params / all_param}"
    )


tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training

bnb_4bit_compute_dtype = "float16"
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_4bit_quant_type = "nf4"
use_nested_quant = False

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

model = AutoModelForCausalLM.from_pretrained(base_model,
                                             trust_remote_code=True,
                                             quantization_config=bnb_config,
                                             device_map="auto")
model.config.use_cache = False
model = prepare_model_for_kbit_training(model)

modules = find_all_linear_names(model)
config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=modules,
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
print_trainable_parameters(model)

###定义参考模型
model_ref = AutoModelForCausalLM.from_pretrained(base_model,
                                                 trust_remote_code=True,
                                                 quantization_config=bnb_config,
                                                 device_map="auto")
###定义dpo训练参数
training_args = TrainingArguments(
    per_device_train_batch_size=1,
    max_steps=100,
    remove_unused_columns=False,
    gradient_accumulation_steps=2,
    learning_rate=3e-4,
    evaluation_strategy="steps",
    output_dir="./test",
)

###定义dpo训练器
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
###训练
dpo_trainer.train()
###模型保存
dpo_trainer.save_model(output_dir1)

dpo_trainer.model.save_pretrained(output_dir2)
tokenizer.save_pretrained(output_dir2)

训练过程

其中看出加载了2遍int4量化的模型到显存中,需要训练的策略模型只有一部分lora参数,而参考模型就是原始模型本身.

模型保存

保存下来的参数也就是lora参数,这部分lora 参数就学会了如何拒绝回答有害问题。

至此,我们就学会了如何利用使用DPO +Qlora 实现在完成RLHF的实战。

使用场景

核心原则:偏好数据集中的good/bad response都是和SFT model的训练数据同分布的,也可以说模型是可以生成good/bad response的。

场景1

已有一个SFT model,为了让它更好,对它的output进行偏好标注,然后使用DPO进行训练,这是最正常的使用场景,但是偏好数据集确实避免不了的

场景2

场景1的改进版本,偏好标注不由人来做,而是让gpt4或者一个reward model来标注好坏,至于reward model怎么来,就各凭本事吧

场景3

没有SFT model只有偏好数据集,那就先在偏好数据即中的 (�,��) 进行训练,然后在进行DPO的训练。先SFT就是为了符合上文的核心原则

DPO: Direct Preference Optimization 论文解读及代码实践 - 知乎 (zhihu.com)GitHub - mzbac/llama2-fine-tune: Scripts for fine-tuning Llama2 via SFT and DPO.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1017020.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[字符串和内存函数]错误信息报告函数strerror详解

strerror介绍 strerror是一个C库函数&#xff0c;用于将错误代码转换为对的错误信息字符串。它接受一个整数参数errno&#xff0c;返回一个指向错误信息字符串的指针。 errno是一个全局变量&#xff0c;可以直接使用。它在C语言中用于表示发生错误时的错误码。它是一个整数&…

笔记1.4 计算机网络性能

1. 速率 速率即数据率&#xff08;data rate&#xff09;或称数据传输速率或比特率 单位时间&#xff08;秒&#xff09;传输信息&#xff08;比特&#xff09;量 计算机网络中最重要的一个性能指标 单位&#xff1a;bps、kbps、Mbps k 10^3、M 10^6、G 10^9 速率往往…

autosar 诊断入门

AUTOSAR (汽车开放系统架构) 是一个国际汽车行业的开放和标准化的软件架构。它的主要目标是为了创建一种独立于硬件的软件架构&#xff0c;以提高汽车电子系统的模块化和可重用性。 AUTOSAR架构主要分为两个部分&#xff1a;AUTOSAR Runtime Environment (RTE) 和 AUTOSAR Soft…

HTTP各版本差异

HTTP1.0 无法复用连接 HTTP1.0为每个请求单独新开一个TCP连接 #mermaid-svg-9N3exXRS4VvT4bWF {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-9N3exXRS4VvT4bWF .error-icon{fill:#552222;}#mermaid-svg-9N3exXRS…

集成电路运算放大器[23-9-16]

目录 1、结构组成&#xff1a;差分放大电路、电压放大电路、功率放大电路。 2、同相放大器&#xff1a; 3、反相放大器&#xff1a;一个正电压放大并变为负电压。 4、差分放大电路&#xff1a;输入两个不同的电压&#xff0c;两者的差值乘以放大系数得到输 出电压。 1、结构组…

2.策略模式

UML图 代码 main.cpp #include "Strategy.h" #include "Context.h"void test() {Context* pContext nullptr;/* StrategyA */pContext new Context(new StrategyA());pContext->contextInterface();/* StrategyB */pContext new Context(new Strat…

MybatisPlus(5)

前言&#x1f36d; ❤️❤️❤️SSM专栏更新中&#xff0c;各位大佬觉得写得不错&#xff0c;支持一下&#xff0c;感谢了&#xff01;❤️❤️❤️ Spring Spring MVC MyBatis_冷兮雪的博客-CSDN博客 上篇讲了增删的操作&#xff0c;这篇讲修改操作中的一个问题以及它对应的…

002-第一代硬件系统架构确立及产品选型

第一代硬件系统架构确立及产品选型 文章目录 第一代硬件系统架构确立及产品选型项目介绍摘要硬件架构硬件结构选型及设计单片机选型上位机选型扯点别的 关键字&#xff1a; Qt、 Qml、 信号采集机、 数据处理、 上位机 项目介绍 欢迎来到我们的 QML & C 项目&#xff…

Centos8下载安装JDK8

安装JDK8 一、下载 官网&#xff1a;https://www.oracle.com/java/technologies/downloads/#java8-linux 二、存放到opt目录下 三、解压 tar -zxvf jdk-8u361-linux-x64.tar.gz -C /usr/local如果下载的是tar格式&#xff0c;则安装如下命令解压 tar -zvf jdk-8u361-linu…

Flash的学习

Flash的学习 1 概述 2 特性 STM32 的内部FLASH 包含主存储器、系统存储器以及选项字节区域。 2.1 主存储器 主存储器分为256 页&#xff0c;每页大小为2KB&#xff0c;共512KB。这个分页的概念&#xff0c;实质就是FLASH 存储器 的扇区&#xff0c;与其它FLASH 一样&…

【C++】动态内存管理(79分钟写的文章哪里看不懂了,快来学)

动态内存管理目录&#xff1a; 一、C/C内存分布 在学习了C/C内存区域的划分后&#xff0c;我们来做几道题巩固一下&#xff1a; 1. 选择题&#xff1a;选项 : A.栈 B.堆 C.数据段(静态区) D.代码段(常量区)globalVar在哪里&#xff1f;____ staticGlobalVar在哪里&#x…

【2023知乎评论爬虫】我用Python爬虫爬了2386条知乎评论!

文章目录 一、爬取目标二、展示爬取结果三、爬虫代码讲解3.1 分析知乎页面3.2 爬虫代码 四、同步视频五、完整源码 您好&#xff0c;我是 马哥python说&#xff0c;一枚10年程序猿。 一、爬取目标 前些天我分享过一篇微博的爬虫&#xff1a;https://blog.csdn.net/solo_msk/a…

Spring 的注入

目录 一、注入&#xff08;Injection&#xff09; 1、什么是注入 &#xff08;1&#xff09;为什么需要注入 &#xff08;2&#xff09;如何进行注入 2、Spring 注入原理分析&#xff08;简易版&#xff09; 二、Set 注入详解 1、JDK 内置类型 &#xff08;1&#xff09…

【修复版】2023新版塔罗 算八字测运易理风水 取名 源码平台 搭建教程

全线修复&#xff0c;欢迎你拿其他家的来比&#xff0c;叫他们发测试连接去测试对比&#xff0c;眼睛骗不了人 很多说自家的是修复版&#xff0c;能像我这样把修复的列出来嘛&#xff0c;顾客朋友也可以直接去问他&#xff1a;你到底修复了个啥&#xff1f;&#xff1f;&#…

深入探索智能问答:从检索到生成的技术之旅

目录 一、智能问答概述1. **语义理解**2. **知识库和数据库**3. **上下文感知**4. **动态学习和自适应** 二、发展历程1. **基于规则的系统**2. **统计方法的兴起**3. **深度学习和神经网络的突破**4. **预训练模型** 三、智能问答系统的主要类型四、基于知识库的问答系统五、基…

如何使用记事本制作一个简陋的小网页(1)

1、创建一个记事本 2、将记事本的尾缀进行修改&#xff0c;变为html html本质上是一种标签并不是一种语言&#xff0c;其最主要的功能就是对页面设置和页面的文本样式进行修改和修缮。 3、右键被修改了尾缀的文本&#xff0c;选择使用记事本的打开方式。 打开完毕后&#xf…

什么是JavaScript的事件驱动编程(event-driven programming)?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 事件&#xff08;Event&#xff09;⭐ 事件监听器&#xff08;Event Listener&#xff09;⭐ 回调函数&#xff08;Callback Function&#xff09;⭐ 非阻塞和异步⭐ 事件循环&#xff08;Event Loop&#xff09;⭐ 触发事件&#xff08;…

「C++程序设计 (面向对象进阶)」学习笔记・二

0、引言 本专栏的系列文章是在学习 北京邮电大学 崔毅东 老师的《C程序设计 (面向对象进阶)》课程过程中整理的。欢迎前往专栏了解更多相关内容~ &#x1f600; 有关于现代 C 的基本介绍&#xff0c;请前往《现代C基本介绍》&#xff01; &#x1f514; 先决条件 本专栏的系列…

【测试开发】答疑篇 · 什么是软件测试

【测试开发】答疑篇 文章目录 【测试开发】答疑篇1. 生活中的测试2. 什么是软件测试3. 为什么要有测试/没有测试行不行4. 软件测试和软件开发的区别5. 软件测试和软件调试之间的区别6. 软件测试的岗位7. 优秀测试人员具备的素质 【测试开发】答疑篇 软件不一定是桌面应用&#…

公众号迁移线上公证如何办?

公众号账号迁移的作用是什么&#xff1f;只能变更主体吗&#xff1f;微信公众平台的帐号迁移功能可将原公众号的粉丝、文章素材、违规记录、留言功能、名称等迁移至新的公众号。通过迁移可以实现公众号的公司主体变更、粉丝转移、开通留言功能、服务号转为订阅号等作用。因此不…