社区供稿 | RLHF 实践中的框架使用与一些坑 (TRL, LMFlow)

news2024/11/25 1:07:26

1 前言

之前看见文章总结了常见的一些 RLHF 框架的经验, 但是似乎没看见 Hugging Face 自己维护的 TRL 库的相关文章, 正好最近调 TRL 比较多, 就想写一个文章分享一下使用过程中踩到的坑,另外也介绍一下我们的全流程框架 LMFlow 。

43d4af0fefa2c2c04e98fe42f34c8b28.png

LMFlow 框架示意图。

我们主要用一个具体的例子展示如何在两个框架下做RLHF,并且记录下训练过程中我们踩到的主要的坑。这个例子包括完整的SFT,奖励建模和 RLHF, 其中RLHF包括通过 RAFT 算法(Reward rAnked FineTuning)或者TRL-PPO 对齐模型两个部分。为了方便用户,我们已经在 Hugging Face repo 中提供了一个基于 GPT-Neo-2.7B 的奖励模型,因此也可以先跳过奖励建模。

这个例子是基于仅适用于非商业用途的许可的 LLaMA 构建的, 为了使用LLaMA-7B 模型, 大家需要填写前面的 request form。测试的环境是 8 X A100 (40G)。

1.1 环境准备

LMFlow 的安装包中也包含了 TRL, 所以我们只需要按照官方的示例安装 LMFlow 即可。

git clone https://github.com/OptimalScale/LMFlow.git
cd LMFlow
conda create -n lmflow python=3.9 -y
conda activate lmflow
conda install mpi4py
pip install -e .

以上安装自动会把依赖的 PyTorch 等包也一起安装, 除此之外, 我们额外手动安装一下 matplotlib 这个包

1.2 数据集描述

我们使用 Dahoas/full-hh-rlhf 数据集作为例子,其中每个数据集样本包括一个提示和来自助手的两个回应。特别地,标记为 "chosen" 的回应相对于标记为 "rejected" 的回应更被人类所喜欢。数据集包括 112K 个训练样本和 12.5K 个测试样本。以下是数据集的一个示例样本:

" Human: What kind of noises did dinosaurs make? Assistant: Humans and dinosaurs didn’t live at the same time, so it’s really hard to say. The best place to find out what noises dinosaurs made would be Human: yes they did Assistant: to guess, and that would probably require lots of reading and a certain amount of imagination, so we’re not really prepared to do that. Human: you cant read Assistant: 

Chosen response: "You can read?"

Rejected response: "there’s a lot of stuff humans don’t know"

为了便于训练,我们在字符开头添加 ``###'' 来重新构建提示,以便模型知道要回复。新样本的格式将是:

"###Human: What kind of noises did dinosaurs make? ###Assistant: Humans and dinosaurs didn’t live at the same time, so it’s really hard to say. The best place to find out what noises dinosaurs made would be ###Human: yes they did ###Assistant: to guess, and that would probably require lots of reading and a certain amount of imagination, so we’re not really prepared to do that. ###Human: you cant read ###Assistant: 

Chosen response: "You can read?"

Rejected response: "there’s a lot of stuff humans don’t know"

我们在目录 ./data/hh_rlhf 中准备了所有需要使用的所有数据集,需要通过在 LMFlow目录下中运行以下命令获得:

cd data && ./download.sh hh_rlhf && cd -

2 RLHF 之前的准备

在这一节, 我们需要先完成SFT模型以及 reward model 的训练, 这部分我们使用LMFlow 完成。

2.1 SFT

这是数据集 /home/usrname/LMFlow/data/hh_rlhf/sft/hh_rlhf_sft.json 的一个示例。我们只使用首选回应,因此我们得到 112K 个训练样本。

{"type": "text_only", "instances": [{"text": "###Human: Should you buy a case to protect your cell phone?###Assistant: It depends on your circumstances.  If you carry your phone in a pocket or a purse then you probably want a case.  But if you only need a phone for quick interactions, a case may actually cause more harm than good.  What do you need the phone for?  Are you a parent, or do you work from home?###Human: What harm could it do?###Assistant: A phone case can damage the screen, for one thing.  It can also get you in trouble if you have your phone turned off for some reason.  Then you will turn it back on and it won’t do anything.  If you can afford to replace it, then you need a case to protect it.  The problem is that most people aren’t able to afford to replace their phones all the time.###Human: Thanks for letting me know.###Assistant: You’re welcome."}, {"text": "###Human: I'm trying to learn about the salam witch trials###Assistant: If you’re looking for books about witchcraft trials, I can recommend some books for you. But you probably mean you’re looking for more historical information about the Salem witch trials in 1692, and specifically about the salam witch trials in 1692?###Human: What are some good books to learn about the salam witch trials###Assistant: What would you like to learn?  If you’re interested in history, one of the best books is The Witch Persecutions in Colonial America: A History.  If you’re interested in witchcraft as a cultural phenomenon, you might enjoy two excellent books: Religion and the Decline of Magic: Studies in Popular Beliefs in Sixteenth- and Seventeenth-Century England by Keith Thomas and Magic, Witchcraft, and the Otherworld: An Anthropology of Superstition by Jack Goody.  If you’re interested in history specifically as it relates to religion, you might enjoy The Popish Plot, or Prelates' Plot: A History of the Popish Plot in England, by K. J. Everett."}]}

你可以编辑 /scripts/run_finetune.sh 并修改参数。我们在这里用 GPT-Neo-2.7B 作为一个例子, 你应当把它换成你获得的 llama-7b 模型的地址。

  • --model_name_or_path: EleutherAI/gpt-neo-2.7B

  • --dataset_path: ${project_dir}/data/hh_rlhf/sft

  • --output_dir: the path you want to store the sft model

  • --num_train_epochs: 1

  • --learning_rate: 2e-5

  • --per_device_train_batch_size: 根据你的GPU资源调整。

  • exp_id: hh_rlhf_llama_sft

你可以编辑 /scripts/run_finetune.sh 并修改参数。我们在这里用 GPT-Neo-2.7B 作为一个例子。

然后,我们可以运行以下命令来执行 SFT。

./scripts/run_finetune.sh

你还可以通过以下命令使用 lora 训练,但还需要通过编辑 run_finetune_with_lora.sh 设置 model_name_or_path 和 dataset。

./scripts/run_finetune_with_lora.sh

下面这个损失图像示例中我们设了 epoch 为4, 但是提前停止并使用一个epoch结束的模型作为SFT模型, 此外我们的logging step 设置为了20, 所以整体看起来会比较平滑

1e95ba15734d0155186744920c142de2.png

SFT 模型训练曲线, 这个例子截取了1.6个epoch 的训练曲线。

在我的例子中, 得到的SFT模型存储在 /home/usrname/LMFlow/output_models/hh_rlhf_llama_sft/checkpoint-1271

2.2 Reward Modeling

我们首先按照 InstructGPT 论文的过程:https://arxiv.org/abs/2203.02155 使用 HH-RLHF 数据集训练一个奖励模型,其中包括:

  • 监督微调 (SFT);

  • 通过比较数据集进行奖励建模。

由于PPO有较大的内存压力, 后续实验证明在这个例子的设置里, TRL 的实现无法同时载入7B的RM与7B的训练模型, 因此我们选择使用GPT-Neo-2.7B作为我们的RM。其中监督微调与2.1节是完全一致的, 我们就略去了。

奖励建模涉及根据给定提示对回应进行排序,通常需要人类标注。幸运的是,在这个例子中,我们已经为每个提示提供了两个回应,其中一个优于另一个。为了创建一个新数据集,我们修改原始数据集,包括两个key: "positive" 和 "negative",其中前者是首选回应。

这是数据集 LMFlow/data/hh_rlhf/rm/hh_rlhf_rm_training.json 的一个示例,其中包含两个样本, 完整的数据集包含了Dahoas/full-hh-rlhf 数据集112K 的训练集和12.5K的测试集。

{"instances": [
    {"positive": "###Human: Are space heaters safe?###Assistant: What do you mean? Can you give me more details?###Human: I mean, can they be a fire hazard?###Assistant: I don’t know about that.  Can you tell me more about the situation? Are you using the heater on a fireplace?  Do you have any other heater in the room?  What materials are in the room?  Do you feel any heat from the heater?  How are you touching the heater?###Human: Can you just tell me if there are any situations where they could be a fire hazard?###Assistant: The answer is “it depends”.  It would depend on the specific heater you are using, how you are using it, what kind of fuel it uses, what materials you have in the room, and what the surroundings are like.  You’d also have to make sure you’re using it correctly.  Is it switched off?  Is it safe to touch?  It’s not the sort of thing I could just answer by looking at the text you provided.###Human: Thank you. That information is helpful.###Assistant: Glad to be helpful.  I’m always happy to assist.", "negative": "###Human: Are space heaters safe?###Assistant: What do you mean? Can you give me more details?###Human: I mean, can they be a fire hazard?###Assistant: I don’t know about that.  Can you tell me more about the situation? Are you using the heater on a fireplace?  Do you have any other heater in the room?  What materials are in the room?  Do you feel any heat from the heater?  How are you touching the heater?###Human: Can you just tell me if there are any situations where they could be a fire hazard?###Assistant: The answer is “it depends”.  It would depend on the specific heater you are using, how you are using it, what kind of fuel it uses, what materials you have in the room, and what the surroundings are like.  You’d also have to make sure you’re using it correctly.  Is it switched off?  Is it safe to touch?  It’s not the sort of thing I could just answer by looking at the text you provided.###Human: Thank you. That information is helpful.###Assistant: You’re welcome. I’m glad to help. Is there anything else you want to know?"}, 
    {"positive": "###Human: how do i teach my kids to make forts at home?###Assistant: Do you mean you want to teach them to build a simple structure they can hide in? Or something more elaborate?  I don’t know exactly what you mean by a “fort”, but maybe this gives you some ideas?", "negative": "###Human: how do i teach my kids to make forts at home?###Assistant: There are many ways to make forts!  The simplest might be to just pile up all the furniture in one room.  This works if you have a single room in your house, although it might still be a bit challenging to build a taller and sturdier fort this way.  The best way to build forts is to start with something easy, and then have fun exploring how you can improve it."}
]}

你可以编辑 ./scripts/run_reward_modeling.sh 脚本以从上一步的 sft 模型开始训练,并更新 "dataset_path" 以改变使用所需的数据集。我们默认使用LoRA进行训练, 因为我们在多个模型上进行了对比, LoRA由于优化上的优势能节省大量的时间, 并且full training 得到的模型并不能提供明显更有优势的RM, 我们用到的参数如下

  • --model_name_or_path: /home/usrname/LMFlow/output_models/hh_rlhf_rm_sft_gptneo_2_7B/checkpoint-1659

  • --dataset_path: ${project_dir}/data/hh_rlhf/rm/hh_rlhf_rm_training.json

  • --output_dir: the path you want to store the reward model

  • --num_train_epochs: 1

  • --learning_rate: 3e-5

  • --per_device_train_batch_size: adjust according to your GPU memory source.

  • --eval_steps: 400

  • --validation_split_percentage: 10

其中我们会自动使用数据集最后的百分之十样本对RM测试, 注意这里使用的数据集是原数据集中的training set + test set, 所以最后的一部分数据集并没有被模型见到过。在这个例子里, validation_split_percentage不应设大于15, 否则会有一部分SFT中用到的样本被使用进测试集 这些数据集的处理都实现在 /examples/run_reward_modeling.py 中, 如果你想使用你自己的数据集进行训练RM, 可以在这里根据你的需求进行修改。最后, 我们使用下面的代码进行训练

./scripts/run_reward_modeling.sh

下面是GPT-Neo-2.7B 与 LLaMA-7B 模型训练过程中的 evaluation loss 与 evaluation accuracy 图。

a2e12058253472d92b8a0105c5022d08.png

奖励模型训练中的evaluation曲线。

我们得到的一些RM 示例

ModelEval AccuracyRemarks
LLaMA-7B79.52%-
LLaMA-7B71.64%RM from LLaMA without SFT
GPT-NEO-2.7B69.24%-
GPT-NEO-1.3B65.58%Only trained on 10000 samples

可以看到一般来说, 更大的模型的准确率也要更高, 但是因为TRL-PPO会爆OOM的问题 (根据一个同学的反馈, 7B+7B 训练 trlx 的实现也一样是会爆OOM), 我们选择使用2.7B的模型。值得注意的是, 即使是LLaMA-7B模型的准确率也只能达到80%左右, 并且得到的RM很可能无法检测到一些我们所不希望有的pattern (例如重复)并仍然给一个比较高的reward。总而言之, 现在这种做分类得到的奖励模型, 仍然是有很大缺陷的。

最后, 因为我们得到的模型是low-rank 的 LoRA adapter, 我们需要使用*./examples/merge_lora.py* 来获得最终的RM模型。

3 RAFT Alignment

原始论文: RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment

3.1 Algorithms Overview

RAFT想法的起源如下, 之前有很多研究都发现了如果训练RM的数据集直接做SFT, 效果不如先去训练RM, 再用RL进行reward learning。一个解释是后者能够有更多的数据进行训练, 但我们注意到前向产生数据本身并不仅仅是PPO专属的。此外, 当时我们花了很多的时间去调PPO, 发现PPO进行训练有容易OOM, 不稳定, 模型效果不确定的一些问题 (我们会在下一节记录中间踩的各种坑), 另外就是我们很多实验发现在垂直领域SFT可以稳定地给模型带来很大的性能提升, 一个自然的想法就是, reward learning 是否可以使用SFT。

具体而言, 我们每轮希望最终获取 b 个新样本进行训练,

  • 为此我们从prompt集合中选取 b x k 个prompt 并输入给当前的模型获得对应的输出;

  • 之后我们给b x k 个样本计算奖励;

  • 我们选取奖励最高的比例为1/k的样本进行SFT训练;

    • ''top'': 第一种方法是全部样本排序选取;

    • ''local'': 第二种方法是每个prompt 重复k 次, 并从这k个样本中选取最高奖励的样本;

    • 第一种会高效一些, 但是在一些场景 (例如这个例子里的实验) 下跨prompt的对比没有意义, 局部的排序会更加合理一些。

  • 新的一轮开始。

这里我们只使用了模型输出的一小部分数据进行训练, 这对forward 运算是坏的, 而对backward 运算是好的。我们观察到, 在我们基于deepspeed的实现下, forward 的batch size 可以开到 backward 的五倍左右, 所以我们认为一次推理的代价应该相对会小一些。

3.2 例子

我们使用之前得到的LLaMA-7B-SFT模型进行训练来作为一个例子, 我们希望记录一个具体的实验过程来说明其中的一些坑, 所以下面会有很多冗余和失败的尝试。

数据准备

我们的训练prompt集合就是Dahoas/full-hh-rlhf训练集中的112K样本去掉回复, 例如:

"###Human: Should you buy a case to protect your cell phone?###Assistant: It depends on your circumstances.  If you carry your phone in a pocket or a purse then you probably want a case.  But if you only need a phone for quick interactions, a case may actually cause more harm than good.  What do you need the phone for?  Are you a parent, or do you work from home?###Human: What harm could it do?###Assistant: A phone case can damage the screen, for one thing.  It can also get you in trouble if you have your phone turned off for some reason.  Then you will turn it back on and it won’t do anything.  If you can afford to replace it, then you need a case to protect it.  The problem is that most people aren’t able to afford to replace their phones all the time.###Human: Thanks for letting me know.###Assistant:"

我们额外从测试集里抽出2K用以测试。然而当我们使用这个prompt 集合进行 TRL-PPO的训练的时候 (所以后面为了fair comparison我们重做了实验, 泪目), 我们发现代码能够跑得起来, 但是在第二个epoch总是会爆OOM。Debug 良久之后发现原因是有一些prompt长度很长, 加上我们生成文本也比较长, TRL-PPO需要的memory和路径长度正相关, 因此我们只使用 token 数 < 256 的prompt, 最终得到82147个prompts。

测试LLaMA-7B-SFT

我们首先测试了SFT模型, 发现模型针对一个对话历史会回复多轮的自问自答, 为此我们将生成的回复用``###Human'' 进行截断:

def _clean_text(self, text):
    split_text = [x for x in text.split("###Human") if x]
    return split_text[0].strip().strip("#")

在LMFlow中, 使用的RM在*/LMFlow/examples/raft_align.py* 被指定, 如果你使用的奖励模型是按第二节的方法训练出, 你只给定它所在的本地地址或者 Hugging Face repo id:

reward_model_or_path: Optional[str] = field(
    default="weqweasdas/hh_rlhf_rm",
    metadata={
        "help": (
            "reward model name (huggingface) or its path"
        ),
    },
)

但是如果你的RM是一般性的, 例如 Hugging Face 上的一些分类器, 你可能还需要略微修改``get_reward_function'' 函数。

3.2.1 第一次训练

我们在LMFlow目录下, 使用如下的命令和参数进行训练:

./scripts/run_raft_align.sh
  • --model_name_or_path: /home/usrname/output_models/hh_rlhf_llama-sft (the model get from sft step, adjusted according your setup)

  • --dataset_path:${project_dir}/data/hh_rlhf/rlhf/rlhf_prompt

  • --output_dir: /home/usrname/output_models/hh_rlhf_raft_align

  • --num_train_epochs: 4

  • --learning_rate: 2e-5

  • --per_device_train_batch_size: adjust according to your GPU memory source.

  • --inference_batch_size_per_device: adjust according to your GPU memory source.

  • --num_raft_iteration 20

  • --top_reward_percentage 0.125; (也就是1/8)

  • --raft_batch_size 1024 (每轮最终有1024个样本用来训练)

  • --output_min_length 126

实验运行地很顺利,训练奖励从约2.7提高到3.4,在我们的训练中, 我们监测了模型输出的一些多样性指标,我们注意到部分指标(例如distinct-2)在训练中显著下降,从0.39降至0.22。虽然有一些研究说明alignment tax 导致RLHF 模型的指标往往会变差 (作为human preference 上变好的代价), 但是这样大幅度的下降仍然是不同寻常的。为此, 我们检查了每个迭代时我们生成的样本,并发现如同SFT的测试, 在第一次迭代中,初始检查点的响应中偶尔会包含# (3%左右的样本),而我们的奖励函数无法检测到随机的#,这意味着包含#的响应也可能具有很高的奖励并被选入训练集。随后,情况变得越来越糟糕,最终有一半的响应包含嘈杂的#符号。

3.2.2 第二次训练

为了解决上述问题, 我们修改了代码并检测每个样本的回复是否含有冗余的#, 如果是, 则手动修改为一个低奖励。同时, 在当前的实现中, 我们会输出每一轮用以SFT的数据集用以监测整个训练过程。修改代码之后, 我们得到了如下的奖励曲线 (注意我们在测试的时候会使用比较低的temperature, 所以测试的奖励要高一些):

4124a779861809e90bd35340a163422f.png

RAFT的训练奖励曲线图, 横坐标表示一次 1) 数据生成 + 2) reward计算与样本排序 + 3) 一轮SFT。

其中横坐标代表的是一个raft的迭代, 包括 1) 数据生成 2) 数据排序 3) 以及在选出的数据集上进行一轮SFT。在我们的例子中, 每一轮会生成8192个样本, 并有1024个样本被使用去SFT。我们可以看到在训练的开始, 用以训练的数据集中的样本 (黄线)比我们模型自身的奖励要高得多, 而在这个小数据集上SFT之后, 模型的奖励开始上升 (绿线和蓝线), 而这反过来也改善了收集到的训练数据 (黄线也在上升)。在 8 x A100 (40G) 上进行如上训练大约需要三个小时。

最终获得的模型在奖励和多样性度量方面都表现良好,我们建议有兴趣的读者参考原始论文了解详细信息。然而,这更像是我们旅程的起点, 我们在最后一部分的讨论里对结果进行进一步的讨论, 在此之前, 我们先记录一下如何使用TRL-PPO进行实验。

4 TRL-PPO Alignment

LMFlow 安装过程中也会把TRL安装所以我们可以直接开始实验,在三个月之前想跑起来TRL需要手动修复几个小bug, 这几天拉了最新版本试验了一下似乎都已经修复了。

数据准备

我们首先修改 TRL-PPO 提供的script里的数据集准备, 注意我们将 TRL-PPO 的script 放在 LMFlow/examples中, 否则你需要稍微修改一下下面数据集的位置:

def build_dataset(config, tokenizer, dataset_name="./data/hh_rlhf/rlhf/rlhf_prompt/prompt.json"):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """

    ds = load_dataset("json", data_files=dataset_name, split="train")['instances'][0]
    texts = [sample['text'] for sample in ds]
    from datasets import Dataset
    ds = Dataset.from_dict({
        "text":texts,
    })
    
    
    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["text"])[:]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds = ds.filter(lambda x: len(x["input_ids"]) <= 256)
    ds.set_format(type="torch")
    print(len(ds))
    return ds

注意这里我们筛选了prompt 数据集, 只保留长度为256个token以内的, 否则过长的文本会导致OOM的错误。

超参数调整

PPO比较依赖于超参数, 不过我几个实验调下来的感觉是TRL默认的参数效果已经很不错了, 即使仔细调整学习率等等也很难获得很大的提升, 需要改的超参数包括:

  • batch_size: 1024/n_gpu, 在我们的设置下为128;

  • mini_batch_size: 一个有意思的发现是PPO的更新batch size 通常要比SFT小不少, 导致它会慢得多, 但不太确定是因为代码实现问题还是PPO本身需要的中间变量比较多的原因;

  • gradient_accumulation_steps: 1

除此之外, 比较关键的在于KL的权重的设置, 我最开始的想法就是简单的去搜, 结果从0.1, 0.05, 0.01 跑了好几轮都不能收敛 (reward 上升一阵后突然垮掉, 或者没有明显的reward 上升)。最后我的选择是先将KL的系数设为0, 然后去修改TRL的ppo_trainer 中的compute_rewards 函数, 打印出这个情况下的KL估计:

def compute_rewards(
        self,
        scores: torch.FloatTensor,
        logprobs: torch.FloatTensor,
        ref_logprobs: torch.FloatTensor,
        masks: torch.LongTensor,
    ):
        """
        Compute per token rewards from scores and KL-penalty.

        Args:
            scores (`torch.FloatTensor`):
                Scores from the reward model, shape (`batch_size`)
            logprobs (`torch.FloatTensor`):
                Log probabilities of the model, shape (`batch_size`, `response_length`)
            ref_logprobs (`torch.FloatTensor`):
                Log probabilities of the reference model, shape (`batch_size`, `response_length`)
        """
        cnt = 0
        rewards, non_score_rewards = [], []
        for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
            # compute KL penalty (from difference in logprobs)
            kl = logprob - ref_logprob
            non_score_reward = -self.kl_ctl.value * kl
            non_score_rewards.append(non_score_reward)
            reward = non_score_reward.clone()
            last_non_masked_index = mask.nonzero()[-1]

            # reward is preference model score + KL penalty
            reward[last_non_masked_index] += score
            rewards.append(reward)
            if cnt < 20:
                print(torch.sum(kl))
                cnt += 1
        return torch.stack(rewards), torch.stack(non_score_rewards)

最终发现在reward曲线的后期, KL偏移最高能达到五六百之多, 最后决定设一个比较小的KL=0.001 (和paper [1] 一致)。在一些实验里我们有发现一个比较小的学习率在perplexity指标上会明显好一些。而值得注意的是[1]中设置的学习率要小得多, 文章中汇报的最大KL偏移也只有一两百左右, 我有尝试过5-e6的学习率, 结论是训练变得缓慢了很多 (需要一天多的时间进行训练), 但是并没有对KL偏移有明显改善,由于时间所限, 没有尝试更低的学习率了, 暂时不确定是超参数的设置问题还是TRL-PPO和 [1] 中实现的差异。我建议始终采样一些样本查看它们的KL估计以监测训练是否正常。

此外, 模型有时候回复会过短, 在ppo_trainer中有如下检查会报错, 一个办法是直接注释掉这个报错, 一个办法是对样本进行检测, 丢弃掉回复太短的样本, 两个方法我都试过似乎效果差不多。

def batched_forward_pass(
    ......
    
    if len(logprobs[j, start:end]) < 2:
    	raise ValueError("Responses are too short. Make sure they are at least 4 tokens long.")
    
    ......

需要指出的是, 由于我们需要估计KL, 在TRL-PPO中, 我们不能随意调整生成的设置, 否则将很可能影响KL的估计:

generation_kwargs = {
    # "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": 100_000,
}

例如, 为了解决上面的回复太短的问题, 我们有尝试设置最短输出长度来强制模型输出更长的回复, 但是设置之后, 我们发现接近一半的KL估计都变为了负数。

训练

在PPO的训练中也会有模型自问自答生成多轮回复的问题, 并且在这个情况下是训不出来的, 所以我们也相应的去截断整个输出, 需要注意的是我们需要对应截断返回来的response_tensors:

output_min_length = 64
output_max_length = 128
output_length_sampler = LengthSampler(output_min_length, output_max_length)
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 1}

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    with torch.no_grad():
        response_tensors = ppo_trainer.generate(
            query_tensors, 
            batch_size=1, ## adjust according to your memory source 
            return_prompt=False, 
            length_sampler=output_length_sampler, 
            **generation_kwargs)

    full_responses = tokenizer.batch_decode(response_tensors)
    clean_texts = [clean_text(tmp_text) for tmp_text in full_responses]
    clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
    lengths = [len(clean_tensor) for clean_tensor in clean_response_tensors]

    response_tensors = [response_tensors[i][:np.max([lengths[i]-2, 1])] for i in range(len(response_tensors))]

    batch["response"] = clean_texts

    texts_for_rewards = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts_for_rewards, **sent_kwargs)
    rewards = [output[0]["score"] for output in pipe_outputs]

在进行多番调参之后, 得到的PPO模型有一些奇怪的pattern, 首先PPO模型也会在输出里掺入大量随机的#, 因此需要和RAFT的训练一样加入一个检测来丢弃掉这些样本或者手动给予一个比较负面的奖励, 加入之后, PPO模型输出随机#的现象得到了缓解, 结果PPO开始复读 ``:) '' 这样一个颜表情了, 我试着再次惩罚这样一种在回复中加入大量 :) 的行为, 于是PPO开始复读 ;) 了。。。好在后面两个问题不算太严重,比例比较低,还能接受,由于DRL本身是比较黑箱的方法, 我们不太能直接得知模型倾向于生成这些颜表情的原因, 但我们猜测可能是RM对这类颜表情比较喜好, 使得PPO 利用了这种RM的缺陷。

TRL-PPO默认会使用一个随机的生成长度, 我们尝试了固定128输出长度和随机从[64, 128] 中抽取输出长度两种方式, 发现在其他设置合适的情况下都能学到比较好的reward, 但是后者似乎对于避免输出重复有一定帮助,最终得到的模型输出观感要更好一些。

PPO主要在调参上需要花费比较多的时间, 当参数合适时, 一次训练大概需要8~12个小时。

5 讨论

我们在下面展示一些随机抽样的例子,可以看到不管是 PPO 和 RAFT 都明显改变了模型回复的风格。整体而言, RAFT-aligned 模型通常倾向于用更多的细节回复,PPO 模型会更加礼貌而积极一些, 而 SFT 模型似乎不够 helpful, 很多时候没有按照指示给予建议。同时, 我们也观察到 PPO 会偶尔输出一些无意义的符号, RAFT 的回复有时候冗余的词有一些多。

我们认为这是因为奖励模型无法完全刻画一个回复的质量, 而 PPO 和 RAFT 都在某种程度上利用了奖励模型的这种不完美来获得高奖励。显然, 这只是 RLHF 探索的起始点, 我们还有许多改进的空间。为了进一步提高模型性能,例如, 我们可以改进奖励模型(例如使用 LLaMA-7B-RM), 我们也可以尝试一些更先进的生成策略来提升生成文本的质量 (例如 contrastive search, 见 https://zhuanlan.zhihu.com/p/629920420)。同时,请查看我们的 LMFlow 框架,以获取更多 LLMs 的乐趣:

OptimalScale/LMFlow: An Extensible Toolkit for Finetuning and Inference of Large Foundation Models. Large Model for All. (github.com)
https://github.com/OptimalScale/LMFlow

(以下图片由表格转换而来,为了显示方便,Prompt 中的 ### 替换成了换行,并以粗体呈现)

a490c93cf607b3b61deda0e3cf7e8144.png

5e2e006f15305c83ff71ac2b96460d7f.png

9ddae43a764aaeb5974e292c79289f50.png

4164f870eac7541d390e022f06190a00.png

527fd1fc4d36af7a44d6060cf8ee9529.png

[1] Training a helpful and harmless 326 assistant with reinforcement learning from human feedback



经作者授权,由 Hugging Face 账号在微信公众号平台标记原创发布,如需转载,请在本文下方留言。

作者知乎账号:「尊师重教章北海」,欢迎大家友好交流讨论。如果你有好的文章希望通过我们的平台分享给更多人,请通过这个链接与我们联系: 
https://huggingface.link/tougao

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/648358.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

架构思维的六要素

很多人都在私信问我随着ChatGPT等技术的兴起&#xff0c;自己的饭碗会不会没了。我的观点是&#xff1a;ChatGPT能做的工作十几年前早就没了。 十几年前还看到过招聘广告上写招程序员的&#xff0c;现在都是工程师起步&#xff0c;工程师是要有架构思维的。 像十几年前的黄金时…

手写SpringBoot启动器主要步骤

这里写目录标题 背景过程2.1自启动实现原理2.2手动实现SpringBoot自启动2.2.1宏观2.2.1微观2.2.1.1三个服务之间调用2.2.1.2自定义注解2.2.1.1业务组装2.2.1.3启动类 升华自定义注解&#xff1a;手动装配组件&#xff1a;简化启动过程&#xff1a;自动化注入依赖&#xff1a;简…

马原第三章复习1.

唯物史观和唯心史观过去出过小题 社会意识出小题 社会存在一般出大题 124 社会存在和社会意识(往往出大题) 社会历史性的基本问题 两种根本对立的历史史观 唯心主义: 至多考察人的思想动机 没有考虑思想背后的物质动因和经济根源 把社会史观看成人的思想,,不懂得…

关于酒店宾馆电气火灾隐患的预防与整改措施介绍 安科瑞 许敏

摘要&#xff1a;本文分析了酒店、宾馆电气火灾隐患的特点及产生的主要原因&#xff0c;并依此提出了消除火灾隐患的整改措施。 关键词:酒店宾馆&#xff1b;火灾&#xff1b;隐患&#xff1b;预防&#xff1b;整改&#xff1b;措施 1前言 随着旅居服务业的快速发展&#xf…

CVPR 2023 | 香港理工提出GrowSP:3D场景的无监督语义分割

点击下方卡片&#xff0c;关注“CVer”公众号 AI/CV重磅干货&#xff0c;第一时间送达 点击进入—>【目标检测和Transformer】交流群 GrowSP: Unsupervised Semantic Segmentation of 3D Point Clouds 论文链接&#xff1a;https://arxiv.org/abs/2305.16404 代码&#xff1…

2023/6/11总结

CSS Less嵌套 子元素的选择器可以直接写在父元素里面。 如果不是它的后代元素&#xff0c;比如你想写伪类选择器、交集选择器&#xff0c;需要在前面加&号。 Less运算&#xff1a; 加减乘除都可以&#xff0c;运算符必须用空格隔开。如果俩个元素都有单位&#xff0…

binfmt_misc

一&#xff1a;binfmt_misc是什么 binfmt_misc是内核中的一个功能&#xff0c;它能将非本机的二进制文件与特定的解析器自动匹配起来&#xff0c;进行二进制解析。 例如&#xff0c;在x86上解析arm64架构的二进制。 通过binfmt_misc可以注册解析器来处理指定二进制文件格式的请…

Qt|QDialog的创建及使用

文章目录 创建一个新的类继承QDialog设置标题去掉问号&#xff0c;只保留关闭使窗口在屏幕中心显示设置窗口大小QDialog没有任务栏窗口图标问题将窗口永远置于上层可见 不会被遮盖阻塞除当前窗口之外的所有窗口添加closeEvent hideEvent同理调用dialog类接收dialog返回状态&…

华为OD机试真题 JavaScript 实现【跳房子II】【2023 B卷 100分】,附详细解题思路

一、题目描述 跳房子&#xff0c;也叫跳飞机&#xff0c;是一种世界性的儿童游戏。 游戏参与者需要分多个回合按顺序跳到第1格直到房子的最后一格&#xff0c;然后获得一次选房子的机会&#xff0c;直到所有房子都被选完&#xff0c;房子最多的人获胜。 跳房子的过程中&…

XGBoost超参数调优指南

本文将详细解释XGBoost中十个最常用超参数的介绍&#xff0c;功能和值范围&#xff0c;及如何使用Optuna进行超参数调优。 对于XGBoost来说&#xff0c;默认的超参数是可以正常运行的&#xff0c;但是如果你想获得最佳的效果&#xff0c;那么就需要自行调整一些超参数来匹配你的…

jupyter lab升级或者安装插件后编译失败

错误回显 报错提示&#xff1a;please run ‘jupyter lab build’ on the server for full output&#xff0c;那么就进入prompt执行一下jupyter lab build 继续接着报错 If you dont already have a jupyter_config.py file, you can create one by adding a blank file of th…

线下实体衰落,真是电商惹祸?实是贪婪以及服务理念落后所致

网上时不时就有人指责电商的兴起导致了线下实体衰落&#xff0c;然而如果各位比较了线下实体与电商的差异&#xff0c;就会明白导致如此结果完全是咎由自取&#xff0c;因为线下实体太贪婪以及服务理念落后于时代。 笔者最近就购买了某款国产手机&#xff0c;在该国产手机品牌的…

集显独显并存,ubuntu安装显卡驱动的坑

一、安装和启动黑屏卡死 1、怎么办&#xff1f;显示器先接集显&#xff0c;完成驱动安装。 &#xff08;1&#xff09;屏蔽nouveau驱动 只要是安装过NVIDIA显卡驱动的&#xff0c;nouveau一般都被禁止了。可以通过命令&#xff1a; lsmod | grep nouveau 查看。如果没有任…

实现设备的延时控制

1. 引言 当搭建IoT管理后台后&#xff0c;APP、设备、云端三端就可以实现交互&#xff1b;当点击APP中的控制按钮&#xff0c;其控制指令就可以经过云端转发到设备执行&#xff0c;当设备执行后将设备的状态上报到云端&#xff0c;APP通过轮训可以取到设备此时的状态&#xff0…

Spring Boot 优雅集成 Spring Security 5.7(安全框架)

Spring Boot 集成 Spring Security &#xff08;安全框架&#xff09; 本章节将介绍 Spring Boot 集成 Spring Security 5.7&#xff08;安全框架&#xff09;。 &#x1f916; Spring Boot 2.x 实践案例&#xff08;代码仓库&#xff09; 介绍 Spring Security 是一个能够为基…

为行业变革注入新动能,行易道入选“高工智能汽车智驾榜单”

6月8日到9日&#xff0c;2023高工智能汽车开发者大会在上海成功举行。与会期间&#xff0c;北京行易道科技有限公司&#xff08;以下简称“行易道”&#xff09;营销副总裁袁泽雁带来了以“车载毫米波雷达进入成像时代”主题演讲&#xff0c;为大家分享了4D毫米波雷达如何以“新…

Java网络开发(Tomcat异步分页+增删改查)——从同步到异步 从jsp 到 js + axios + vue 实现 数据分页显示 数据增删改查

目录 引出一些固定的东西1.固定的响应格式2.name 变成 v-model 进行双向绑定3.下拉框选中--:value"type.id" v-model"companyDb.typeId"4.vue导包固定写法5.script固定写法6.axios的get请求7.axios的post请求---let params new URLSearchParams()8.前端美…

MYSQL 在优化器缺陷在次验证,与MYSQL 熄火了 还是 成熟了??

开头还是介绍一下群&#xff0c;如果感兴趣polardb ,mongodb ,mysql ,postgresql ,redis 等有问题&#xff0c;有需求都可以加群群内有各大数据库行业大咖&#xff0c;CTO&#xff0c;可以解决你的问题。加群请联系 liuaustin3 &#xff0c;在新加的朋友会分到2群&#xff08;共…

聊聊我做 NeRF-3D重建性能优化经历

我们新推出大淘宝技术年度特刊《长期主义&#xff0c;往往从一些小事开始——工程师成长总结专题》&#xff0c;专题收录多位工程师真诚的心路历程与经验思考&#xff0c;覆盖终端、服务端、数据算法、技术质量等7大技术领域&#xff0c;欢迎一起沟通交流。 本文为此系列第四篇…

飞桨携手第二届GitLink开源夏令营,邀你参与顶尖开源项目!

想参与顶尖开源项目开发&#xff1f; 想熟悉开源社区参与流程&#xff1f; 想获得资深导师指导和丰厚现金奖励&#xff1f; 机会来啦&#xff01; 2016年9月&#xff0c;飞桨框架正式开源&#xff0c;其兼备易用性、高效性、灵活性和可扩展性等特点。如今&#xff0c;百度飞桨在…