StackLLaMA: A hands-on guide to train LLaMA with RLHF

news2025/1/6 18:51:31

Paper name

StackLLaMA: A hands-on guide to train LLaMA with RLHF

Paper Reading Note

Project URL: https://huggingface.co/blog/stackllama
Code URL: https://huggingface.co/docs/trl/index

TL;DR

  • Huggingface 公司开发的 RLHF 训练代码,已集成到 huggingface 的 trl 库中,在 Stack Exchange 数据集对 LLaMA 模型进行了微调。博客详细介绍了 SFT(有监督微调)、RM(奖励/偏好建模)和 RLHF(人类反馈的强化学习)的训练细节,并介绍了一些训练中可能遇到的问题及解决思路

Introduction

背景

  • ChatGPT、GPT-4 和 Claude 等模型是功能强大的语言模型,它们人类反馈强化学习 (RLHF) 的方法进行了微调,以使得它们的行为方式更好地符合我们的期望

本文方案

  • 在这篇博客文章中,我们展示了使用 SFT(有监督微调)、RM(奖励/偏好建模)和 RLHF(人类反馈的强化学习)相结合的方法,训练 LlaMa 模型回答 Stack Exchange (一个问答网站,每个答案有对应的用户点赞数目标注) 上的问题的所有步骤。
    在这里插入图片描述
  • 经过以上微调训练,本文训练了一个 StackLLaMA 模型,开源到了 Hub 上,整个训练流程也开源到了 trl

Dataset/Algorithm/Model/Experiment Detail

实现方式

LLaMA 模型

  • 在进行RLHF时,从一个有能力的模型开始非常重要:RLHF 步骤只是微调模型以使其与本文想要与其交互和期望其响应的方式相一致。因此,本文选择使用最近推出的性能出色的 LLaMA 模型。LLaMA 模型是由 Meta AI 开发的最新大型语言模型,大小从 7B 到 65B 参数不等,并在 1T 到 1.4T 个 token 数据集之间进行了训练,使其性能很强。本文使用 7B 模型作为所有后续步骤的基础

Stack Exchange 数据集

  • 收集人类反馈是一项复杂而昂贵的工作。为了引导这个例子的过程,同时仍然建立一个有用的模型,使用 Stack Exchange 数据集,数据集包括来自 StackExchange 平台的问题及其相应的答案(包括用于代码和许多其他主题的 StackOverflow)。这个数据集信息量很大,回复的答案与赞成票的数量和已接受答案的标签都有

  • 本文使用 A General Language Assistant as a Laboratory for Alignment 中提到的方法来给每个答案进行打分

    • score = round(log2 (1 + upvotes)) (注:这里用 log 的原因是人们一般优先看高赞回答,导致强者恒强,这里希望用 log 稍微拉低高赞回答的分数)
    • 被提问者接受的答案分数再加上 1
    • upvotes 为负的分数设置为 -1
  • 对于 reward model,每个问题需要两个回答用于对比。一些问题有几十个回复,导致有很多个匹配答案对,本文对每个问题最多采样 10 个答案对,以限制每个问题的数据数量。最后通过将 HTML 转换为 markdown 得到格式干净的数据,数据示例和处理脚本在:stack-exchange-paired

高效训练策略

  • 即便训练最小的 LLaMA 模型也需要大量的显存消耗,简单计算
    • 基于 bf16 进行参数存储,每个参数占用 2 bytes,Adam 优化器暂用 8 bytes,所以一个 7B 参数模型会消耗 (2+8)*7B=70GB 左右显存,计算注意力分数等中间值时可能需要更多显存
  • 本文使用 Parameter-Efficient Fine-Tuning (PEFT) 技巧,比如在 8 bit load 的模型上使用 LoRA
    • 以 8 bit 加载模型可显著减少显存占用,因为每个参数只需要一个 byte (例如 7B LlaMa 在显存中占用 7GB)
    • 在这种配置下,一般 1B 的参数需要 1.2~1.4Gb 的显存 (取决于批量大小和序列长度),80GB A100 一般可以训练 50-60B 的模型
  • 同时使用 dp 进行加速
    在这里插入图片描述

Supervised fine-tuning

  • 开始训练奖励模型和通过强化学习调整模型之前,如果模型在我们感兴趣的领域中表现良好,那么这会有所帮助。在本文的情况下,希望它能够回答问题,而对于其他用例,可能希望它能够遵循指令,这种情况下需要进行指令调整。实现这一点最简单的方法是使用来自该领域或任务的文本,继续使用语言建模目标对语言模型进行训练。StackExchange 数据集非常庞大(超过 1000 万条指令),因此可以轻松地在其中的一个子集上训练语言模型。
  • 利用与预训练阶段一样的 causal language modeling objective 损失来仅模型微调。为了有效地使用数据,本文使用了一种叫做“packing”的技术:不是在批次中每个样本都有一个文本,然后填充到模型的最长文本或最大上下文,而是将许多文本连接在一起,用 EOS token 分隔,并切割上下文大小的块来填充批次,无需任何填充。
    在这里插入图片描述
    采用这种方法,训练效率要高得多,因为每个通过模型的 token 都会被训练,而传统的数据读取方法会在损失计算中将填充的 token 排除掉。如果没有太多的数据,并且不希望有偶尔截断一些溢出上下文的 token 这种问题,也可以使用传统的数据加载器。上面描述的数据预处理方法在代码中是 ConstantLengthDataset 实现的
  • 模型使用 LoRA 方式进行训练,因为之后还需要使用不同的 loss 对模型进行训练,这里训练完成之后需要将 LoRA 的模型参数合入到原始模型中

Reward modeling and human preferences

  • 原则上,可以直接使用人类标注来进行 RLHF 微调模型。然而,这将需要在每次优化迭代之后向人类发送一些样本进行评分。由于收敛所需的训练样本数量较大以及人类阅读和标注速度的固有延迟,这是昂贵而缓慢的。一般是训练一个奖励模型 (reward model) 来代替人类标注。奖励模型的目标是模仿人类来评价一段文本。有几种可能的策略来构建奖励模型:最直接的方法是预测人类标注结果(例如评分分数或“好/坏”的二进制值)。在实践中,更好的方法是预测两个答案的排名,其中奖励模型输入为一个给定的 prompt x,以及两个基于 x 输入的回复 (yk, yj),奖励模型来预测哪一个会被人类注释者评价更高。奖励函数的 loss 设计为
    在这里插入图片描述
    其中 r 是模型的输出分时,yj 是两个回复中更好的回复,也即期望奖励模型对于更好的回复的打分需要尽量高,更差的回复的打分需要尽量低。loss 的代码实现如下
class RewardTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        rewards_j = model(input_ids=inputs["input_ids_j"],  attention_mask=inputs["attention_mask_j"])[0]
        rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
        loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
        if return_outputs:
            return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
        return loss
  • 实验配置
    • 训练数据使用了 100000 个候选对,评测使用了 50000 数据
    • batchsize 4,1 epoch
    • Adam,BF16
    • Lora rank 8,alpha 32
    • 8xA100 训练需要几个小时
  • 实验结果:67% 的准确率

Reinforcement Learning from Human Feedback

  • 基于前述的微调后的模型以及奖励模型进行强化学习训练,包含以下步骤
    • 基于 prompt 输入生成回复
    • 使用奖励模型对回复进行评级
    • 使用评级进行 reinforcement learning policy-optimization 更新

在这里插入图片描述

  • 查询和响应提示在被 token 化并传递给模型之前按如下方式模板化,该模板在 SFT,RM 和 RLHF 三个步骤中保持一致

    Question: <Query>
    Answer: <Response>
    
  • 使用 RL 训练语言模型的一个常见问题是,该模型可以通过生成完整的乱码来学习利用奖励模型,这会导致奖励模型分配高奖励。为了平衡这一点,在奖励中增加了一个惩罚:保留了一个没有训练的模型 (即 SFT 后的模型) 作为参考,并通过计算 KL-divergence 来对新模型的生成与参考模型的生成的相似性进行约束
    在这里插入图片描述

  • 整个 RLHF 的代码示例如下

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    question_tensors = batch["input_ids"]
        
    # sample from the policy and generate responses
    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        length_sampler=output_length_sampler,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
    # Log stats to WandB
    ppo_trainer.log_stats(stats, batch, rewards
  • 实验配置
    • 3x8 A100-80GB 需要 20 h 的训练时间

实验结果

奖励模型训练

  • 准确率为 67%,作者的解释是任务比较难,人也不一定能做好

RL 模型训练

  • 训练过程中每个 batch 的 reward

在这里插入图片描述

  • 训练后的模型可以模仿人的回复,虽然不应该相信它关于 LLaMA 问题的建议,但答案看起来连贯,甚至提供了一个谷歌链接(这个直接在官网测试发现回复的字数会多很多)
    在这里插入图片描述

训练过程中的挑战

  • 高的 reward 不一定代表更好的性能
    在这里插入图片描述
    一般来说,在 RL 中希望获得最高的奖励。在 RLHF 中,因为使用了一个不完美的奖励模型,如果有机会,PPO 算法将利用这些不完美。这可能表现为奖励的突然增加,但是当查看策略生成的文本时,它们主要包含字符串 ( ```) 的重复,因为奖励模型发现 stack exchange 上包含代码块的答案通常比没有代码块的排名更高。这个可以通过 KL 惩罚来一定程度缓解

  • KL 在这里的实现不一定是正的值,因为本文采用了 KL 的估计值
    在这里插入图片描述
    可以看出来,当 policy 模型采样的 token 比 SFT 模型的概率低时,估计的 KL 值为负。但平均而言它将是正的,否则将无法从 policy 中正确抽样。然而一些生成策略会强制生成一些 token 或则强行抑制一些 token。例如,当批量生成时,完成的序列会被 pad,这时设置小的长度会导致 EOS token 被抑制。模型可以为那些导致负 KL 的 token 分配非常高或低的概率。由于 PPO 算法针对奖励进行优化,它会追逐这些负惩罚,导致不稳定
    在这里插入图片描述
    生成响应时需要小心,建议在求助于更复杂的生成方法之前始终先使用简单的采样策略

  • ppo 的 loss 有不稳定的现象暂时还没有解决
    在这里插入图片描述

Thoughts

  • 作者认为后续一些可以研究的点
    • 有了训练好的模型后可以与其他模型进行对比评测
    • 有了评测基建后可以尝试在数据集上做修改,比如过滤一些数据或增加一些数据
    • 不同模型架构和尺寸的对比

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/599663.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

产品设计-产品设计五要素

概念介绍 产品设计五要素分别是&#xff1a;战略层、范围层、结构层、框架层、表现层。自上而下的分析可用来分析已有的产品&#xff0c;自下而上分析则可以用来创造新的产品。下面是各个层级所包括的内容&#xff1a; 战略层&#xff1a;产品目标和用户需求&#xff08;做什…

【STL(2)】

STL&#xff08;2&#xff09; 知识点回顾函数对象函数对象理解系统的仿函数仿函数应用 容器适配器stackdeque queuepriority_queue mapmap使用插入访问下标访问的应用&#xff1a;计算文件中单词的个数 知识点回顾 在STL库中存在三个容器适配器&#xff0c;stack - queue - p…

西门子200系列PLC学习课程大纲(课程筹备中)

西门子200系列PLC学习课程大纲如下表所示&#xff0c;共106课&#xff0c;关注我&#xff0c;让你从菜鸟变大神。 第1课西门子200PLC概述S7-200 PLC新特性是什么第2课S7-200 PLC的CPU介绍第3课S7-200 PLC编程软件介绍第4课S7-200 PLC通信方式有哪些第5课S7-200 PLC显示面板介绍…

6.1——我在CSDN的创作纪念日

文章目录 ⭐前言⭐相遇CSDN⭐切换到编程赛道的契机&#x1f496; 好好的美工为什么切换编程赛道&#x1f496; 转换编程赛道的催化剂 ⭐写博客的目的——写给未来的自己&#x1f496; 初衷——为学习铺路&#x1f496; 博客是灯——照亮前行的路&#x1f496; 博客是路——互联…

wenet-基于预训练模型进行增量训练

1867-154075-0014 重中之重 run.sh脚本分析 wenet aishell脚本解析_weixin_43870390的博客-CSDN博客 一、准备工作 第一步&#xff1a;准备训练数据&#xff0c;拷贝到远程服务器 将准备好的数据文件0529_0531_dataset&#xff0c;上传到恒源云上的/hy-tmp/wenet/example…

数据结构与算法10:递归树、Trie树、B+树

目录 【递归树】 【Trie 树】 【B树】 【每日一练&#xff1a;最长公共前缀】 【递归树】 递归的思想是将大问题分解为小问题&#xff0c;然后再将小问题分解为更小的问题&#xff0c;直到问题的数据规模被分解得足够小&#xff0c;不用继续递归分解为止。如果把这个一层…

Effective第三版 中英 | 第2章 创建和销毁对象 | 用私有构造器或者枚举类型强化 Singleton 属性

文章目录 Effective第三版前言第二章 创建和销毁对象用私有构造器或者枚举类型强化 Singleton 属性 Effective第三版 前言 大家好&#xff0c;这里是 Rocky 编程日记 &#xff0c;喜欢后端架构及中间件源码&#xff0c;目前正在阅读 effective-java 书籍。同时也把自己学习该书…

如何在本地配置Github的项目--Python

如何在本地配置Github的项目 0. 引言1. 初步预览2. 配置环境2.1 环境已经给出2.2 环境未曾给出 3. 数据配置4. 依次调试5. 配置完成总结 0. 引言 Github上存在大量的代码。当下载下来后可能会存在疑惑&#xff1a;如何在本地配置对应的项目呢&#xff1f; 为了帮助新手解决这一…

【Android开发基础】购物车代码整理

文章目录 一、数据库设计二、Home界面三、购物车模块四、添加五、源代码 这个月总算忙完了&#xff0c;总算能够抽出时间来&#xff0c;认真写一下博客了。整理一下购物车的代码 一、数据库设计 基于SqLite简单设计一个数据存储逻辑 实体&#xff08;接收数据&#xff09; im…

【数据加密】古典密码Playfair

文章目录 一、引言1、主要任务2、分支3、密码体制分类4、攻击密码系统 二、普莱费厄体制1、构造字母表&#xff0c;设为密钥矩阵2、设立加密方法3、加密解密4、字典集合5、结果 一、引言 1、主要任务 解决信息的保密性和可认证问题&#xff0c;保证信息在生成、传递、处理、保…

Swin-Transformer详解

Swin-Transformer详解 0. 前言1. Swin-Transformer结构简介2. Swin-Transformer结构详解2.1 Patch Partition2.2 Patch Merging2.3 Swin Transformer Block2.3.1 W-MSA2.3.2 SW-MSA 3. 模型配置总结 0. 前言 Swin-Transformer是2021年微软研究院发表在ICCV上的一篇文章&#x…

数据的存储(浮点型)

目录 浮点型存储的规则 1.前面我们已经学过了整形在数据中的存储是以原码&#xff0c;反码&#xff0c;补码的形式在内存中存储的&#xff0c;那么浮点数是以什么样的形式存储的呢&#xff1f; 接下来我们通过一段代码来观察——> int main() {int n 9;float* p (float*…

String AOP的使用

面向切面编程&#xff0c;面向特定方法编程&#xff0c;以方法为对象&#xff0c;在不修改原方法的基础上&#xff0c;对方法进行操作扩展等&#xff0c;底层是通过动态代理实现的 使用开发步骤&#xff1a; 1、创建一个类&#xff0c;加上Aspect声明为一个AOP切面类&#xff…

2023 重新开始

感觉搞 IT 的日子最近都有点不太好过。 早上接到公司电话说今天是一个大日子。 为什么是大日子&#xff0c;相信所有人都是懂的。这次公司将会经历一次非常大的裁员&#xff0c;很不幸也在列表中。不过感觉这个好像也没有什么关系。 因为早就在意料之中的事情&#xff0c;经历…

c语言之结构体(初阶)

目录 1&#xff1a;结构体类型的声明 2&#xff1a;结构体初始化 3&#xff1a;结构体成员访问 4&#xff1a;结构体传参 1&#xff1a;结构体类型的声明 1&#xff1a;为啥要有结构体&#xff0c;因为当我们描述一个复杂对象的时候&#xff0c;可能平时我们的一个类型不能…

常见的五种排序

&#x1f436;博主主页&#xff1a;ᰔᩚ. 一怀明月ꦿ ❤️‍&#x1f525;专栏系列&#xff1a;线性代数&#xff0c;C初学者入门训练&#xff0c;题解C&#xff0c;C的使用文章&#xff0c;「初学」C &#x1f525;座右铭&#xff1a;“不要等到什么都没有了&#xff0c;才下…

批量提取某音视频文案(二)

牙叔教程 简单易懂 之前写过一篇 批量提取某音视频文案 , 在之前的教程中, 我用的是微软的语音转文字功能, 今天我们换个方法, 使用 逗哥配音 的 文案提取 功能 准备工作 下载视频和音频 我在github找到的是这个仓库 https://github.com/Johnserf-Seed/TikTokDownload 注意一…

VLANIF虚接口案例实践

1&#xff09;拓扑 2&#xff09;需求&#xff1a; -所有PC能够ping通自己的网关 -实现vlan间互通&#xff0c;实现所有的PC互通 3&#xff09;配置步骤&#xff1a; 第一步&#xff1a;给pc配置IP地址 第二步&#xff1a;交换机创建vlan,做access和trunk -所有的交换机都配…

传统图形学对nerf的对比与应用落地

作者今年参加了China3DV的盛会&#xff0c;大会的发表、线下讨论、学者、工业界等等的交流着实对于Nerf有了更深的思考&#xff0c;以下是作者的抛砖引玉&#xff0c;如有不当之处敬请指出~ 传统图形学与nerf的简介&#xff1a; 传统图形学&#xff1a;显示表达几何表达方式&…

【CloudCompare教程】010:点云的裁剪功能(分段、裁剪、筛选)

本文讲解CloudCompare点云的裁剪功能(分段、裁剪、筛选)。 文章目录 一、点云的分段二、点云的裁剪三、点云的筛选一、点云的分段 加载案例点云数据,如下图所示: 选中图层点云,点击工具栏中的【分割】工具。 点击【激活线状选择】工具: 在需要裁剪的点云上绘制现状裁剪范…