大模型rlhf 相关博客

news2024/9/27 15:25:25

想学习第一篇博客:

https://huggingface.co/blog/zh/rlhf

RLHF 技术分解

RLHF 是一项涉及多个模型和不同训练阶段的复杂概念,这里我们按三个步骤分解:

  1. 预训练一个语言模型 (LM) ;
  2. 聚合问答数据并训练一个奖励模型 (Reward Model,RM) ;
  3. 用强化学习 (RL) 方式微调 LM。

  细化:

  1. 没啥好说的

  2.关于训练奖励数值方面,这里需要人工对 LM 生成的回答进行排名。起初我们可能会认为应该直接对文本标注分数来训练 RM,但是由于标注者的价值观不同导致这些分数未经过校准并且充满噪音。通过排名可以比较多个模型的输出并构建更好的规范数据集。这个过程中一个有趣的产物是目前成功的 RLHF 系统使用了和生成模型具有 不同 大小的 LM (例如 OpenAI 使用了 175B 的 LM 和 6B 的 RM,Anthropic 使用的 LM 和 RM 从 10B 到 52B 大小不等,DeepMind 使用了 70B 的 Chinchilla 模型分别作为 LM 和 RM) 。一种直觉是,偏好模型和生成模型需要具有类似的能力来理解提供给它们的文本。看来lm和rm可以不同模型.

   3. 用强化学习微调

  这是最核心的部分.长期以来出于工程和算法原因,人们认为用强化学习训练 LM 是不可能的。而目前多个组织找到的可行方案是使用策略梯度强化学习 (Policy Gradient RL) 算法、近端策略优化 (Proximal Policy Optimization,PPO) 微调初始 LM 的部分或全部参数。

  让我们首先将微调任务表述为 RL 问题。首先,该 策略 (policy) 是一个接受提示并返回一系列文本 (或文本的概率分布) 的 LM。这个策略的 行动空间 (action space) 是 LM 的词表对应的所有词元 (一般在 50k 数量级) ,观察空间 (observation space) 是可能的输入词元序列,也比较大 (词汇量 ^ 输入标记的数量) 。奖励函数 是偏好模型和策略转变约束 (Policy shift constraint) 的结合。

  PPO 算法确定的奖励函数具体计算如下:

  将提示 x 输入初始 LM 和当前微调的 LM,分别得到了输出文本 y1y2,将来自当前策略的文本y2传递给 RM 得到一个标量的奖励 $r_θ$

将两个模型的生成文本进行比较计算差异的惩罚项,在来自 OpenAI、Anthropic 和 DeepMind 的多篇论文中设计为输出词分布序列之间的 Kullback–Leibler (KL) divergence 散度的缩放,即 rθλrKL 。

   这一项被用于惩罚 RL 策略在每个训练批次中生成大幅偏离初始模型,以确保模型输出合理连贯的文本。如果去掉这一惩罚项可能导致模型在优化中生成乱码文本来愚弄奖励模型提供高奖励值。此外,OpenAI 在 InstructGPT 上实验了在 PPO 添加新的预训练梯度,可以预见到奖励函数的公式会随着 RLHF 研究的进展而继续进化。我感觉加入梯度是要让训练可控,保证逐渐收敛,而不是跳跃.

  也就是新旧模型参数变化不要太大.

  最后根据 PPO 算法,我们按当前批次数据的奖励指标进行优化 (来自 PPO 算法 on-policy 的特性) 。PPO 算法是一种信赖域优化 (Trust Region Optimization,TRO) 算法,它使用梯度约束确保更新步骤不会破坏学习过程的稳定性。DeepMind 对 Gopher 使用了类似的奖励设置,但是使用 A2C (synchronous advantage actor-critic) 算法来优化梯度。

最后我们再来理解一下第三步的图片流程:

  图片里面的J函数就是rθλrKL​. 我理解交替更新两个语言模型的参数.所以这里面写上对于tuned lm 做了参数freeze

训练 RM 需要的奖励标签规模大概是 50k 左右,所以并不那么昂贵 (当然远超了学术实验室的预算) 。

还有一些其他实现.我们继续读blogs.

第二篇:

https://huggingface.co/blog/zh/stackllama

现实中,我看到的rlhf代码跟上面博客里面的不同.因为我们可能无法同时开2个模型进行训练,并且再加上评估模型就更慢了. 还有我们数据集都是成对和打分的. 是现有数据集而不是上一篇博客数据集都是lm给出的. 所以我们还需要看其他的实现, 如何落地. 目前用过的就是trl库包.

这个博客里面的数据集是:https://huggingface.co/datasets/lvwerra/stack-exchange-paired

高效训练策略

即使是最小 LLaMA 模型的训练,都需要大量内存。估算一下: 以 bf16 半精度,每个参数用 2 个字节 (以 fp32 精度四字节的标准),训练时需要 8 个字节 (例如 Adam 优化器,参见 Tramsformers 的 性能文档)。可见 7B 参数量的模型将用 (2+8)* 7B = 70 GB 的内存,并且还可能需要更多用于计算诸如注意力分数的中间值。所以很难在一张 80GB 显存的 A100 上训练。或许你可以使用一些技巧,比如用更高效的半精度训练的优化器来压缩内存,但溢出是迟早的。

另外的可能是 参数高效的微调(Parameter-Efficient Fine-Tuning, PEFT) 技术,比如 peft 库,它可以对使用 8-bit 加载的模型做 低秩优化(Low-Rank Adaptation,LoRA)。

训练时需要 8 个字节 (例如 Adam 优化器). 

我们来证明这个结论:

论文:8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION: 第三页:

 因为每一次梯度传导, 上一层的梯度算完, 乘以当前层的矩阵即可得到. 所以梯度是不用每一个都存储的. 也就是g0到gt不用存.

所以对于Momentum算法. 我们只需要存mt, mt 就是所有参数的梯度. 所以对于32bit的状态, 我们会用4bit 来存储 . 同理Adam 存mt 和rt即可. 所以是8bite.

1GB=1000*1000*1000bite=10^9/4 float      所以4GB=10^9 float =1B model. 所以momentum 4GB可以训练1B模型.  Adam 8GB可以训练1B模型.

所以我们常用的llama7B  Adam 训练需要 56GB显存来做全参数训练.

 后续我会看trl 库的代码实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1073976.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构和算法(10):B-树

B-树:大数据 现代电子计算机发展速度空前,就存储能力而言,情况似乎也是如此:如今容量以TB计的硬盘也不过数百元,内存的常规容量也已达到GB量级。 然而从实际应用的需求来看,问题规模的膨胀却远远快于存储能…

10.9作业

设计一个Per类&#xff0c;类中包含私有成员:姓名、年龄、指针成员身高、体重&#xff0c;再设计一个Stu类&#xff0c;类中包含私有成员:成绩、Per类对象p1&#xff0c;设计这两个类的构造函数、析构函数和拷贝构造函数。 #include <iostream>using namespace std;clas…

Java实现哈希表

1.哈希表定义 哈希表&#xff08;hash table&#xff0c;也叫散列表&#xff09;&#xff0c;是根据关键码值&#xff08;key value&#xff09;而直接进行访问的数据结构。也就是说&#xff0c;它通过把关键码值映射到表中一个位置来访问记录&#xff0c;以加快查找的速度。这…

【深度学习实验】卷积神经网络(七):实现深度残差神经网络ResNet

目录 一、实验介绍 二、实验环境 1. 配置虚拟环境 2. 库版本介绍 三、实验内容 0. 导入必要的工具包 1. Residual&#xff08;残差连接&#xff09; __init__&#xff08;初始化&#xff09; forward&#xff08;前向传播&#xff09; 2. resnet_block&#xff08;残…

9+代谢+分型,基于代谢通路对肝癌进行分型从而开展实验。

今天给同学们分享一篇代谢分型的生信文章“Bulk and single-cell transcriptome profiling reveal extracellular matrix mechanical regulation of lipid metabolism reprograming through YAP/TEAD4/ACADL axis in hepatocellular carcinoma”&#xff0c;这篇文章于2023年04…

【Linux 下 MySQL5.7 中文编码设置】

前言 原本要使用 Sqoop 把我 MySQL 的数据导入到 HBase 中&#xff0c;习惯了使用 windows 下的 MySQL 8.0 版本&#xff0c;但是用 Sqoop 从windows 传到 linux 下有点复杂&#xff0c;就索性用我自己之前没用过的 linux 下的 MySQL 5.7&#xff0c;结果果然一堆问题&#xff…

爱国者的润学日记-十月

首先需要科学的准备面试和润。如何进行科学的准备工作呢&#xff1f; 高效的按照面试考察内容进行针对性训练&#xff0c;按 Machine-learning-interview准备保证处于专注的心态&#xff0c;如今互联网娱乐发达&#xff0c;之前即使比赛时我也是一边比赛一边看视频。之后准备面…

MySQL:读写分离-amoeba(7)

环境介绍 mysql主服务器 192.168.254.1 mysql从服务器&#xff08;1&#xff09;192.168.254.2 mysql从服务器&#xff08;2&#xff09;192.168.254.3 amoeba代理服务器 192.168.254.4 测试服务器 192.168.254.5 此技术搭配主从复制&#xff0c;我的主服务器和从服务器都…

TS类中属性的封装

我们在如下的代码中&#xff0c;我们在类中设置属性&#xff0c;创建的对象可以随意修改自身的属性&#xff0c;对象中的属性可以任意被修改导致对象中的数据非常不安全。 // 创建一个Person类 class Person {name: string;age: number;constructor(name: string, age: number…

通道剪枝channel pruning

1、相关定义 过参数化&#xff1a;主要是指在训练阶段&#xff0c;在数学上需要进行大量的微分求解&#xff0c;去捕捉数据中微小的变化信息&#xff0c;一旦完成迭代式的训练之后&#xff0c;网络模型在推理的时候就不需要这么多参数。剪枝算法&#xff1a;核心思想就是减少网…

【【萌新的SOC学习之小水文系列】】

萌新的SOC学习之小水文系列 SD卡读写TXT文本实验 SD 卡共有 9 个引脚线&#xff0c;可工作在 SDIO 模式或者 SPI 模式。在 SDIO 模式下&#xff0c;共用到 CLK、CMD、DAT[3:0]六根信号线&#xff1b;在 SPI 模式下&#xff0c;共用到 CS&#xff08;SDIO_DAT[3]&#xff09;、…

栅形状的影响及可靠性的优化

栅形状的影响 VD-MOSFET单元结构采用平面栅极拓扑结构&#xff0c;栅极电极位于半导体的平坦上表面。虽然在这种结构中&#xff0c;在平面结处会发生电场增强&#xff0c;但在栅极电极处不会发生电场增强&#xff0c;因为栅极电极的边缘与高度掺杂的N源区重叠。栅极电极的边缘被…

新能源+低代码:百数服务商新领域,跨行业结合所碰撞出的新火花

新能源行业的兴起主要是在最近几年&#xff0c;特别是“双碳”目标提出后&#xff0c;中国的新能源行业迎来了快速发展的阶段。在政策支持和资本加持下&#xff0c;各种新能源和绿色发展基金设立&#xff0c;以新能源为主体的新型电力系统也得到了深化改革&#xff0c;大力推动…

Qt中QTimer定时器的用法

Qt中提供了两种定时器的方式一种是使用Qt中的事件处理函数&#xff0c;另一种就是Qt中的定时器类QTimer。 使用QTimer类&#xff0c;需要创建一个QTimer类对象&#xff0c;然后调用其start()方法开启定时器&#xff0c;此后QTimer对象就会周期性的发出timeout()信号。 1.QTimer…

十五、异常(6)

本章概要 Try-With-Resources 用法 揭示细节 异常匹配 Try-With-Resources 用法 在考虑所有可能失败的方法时&#xff0c;找出放置所有 try-catch-finally 块的位置变得令人生畏。确保没有任何故障路径&#xff0c;使系统远离不稳定状态&#xff0c;这非常具有挑战性。 Inp…

Unity ToLua热更框架使用教程(1)

从本篇开始将为大家讲解ToLua在unity当中的使用教程。 Tolua的框架叫LuaFramework&#xff0c;首先附上下载链接&#xff1a; https://github.com/jarjin/LuaFramework_UGUI_V2 这个地址的是UGUI的。 下载完之后导入项目&#xff0c;首先&#xff0c;我们要先让这个项目跑起…

老卫带你学---Datagrip连接clickhouse

Datagrip连接clickhouse Datagrip是一个DB可视化特别方便的软件&#xff0c;因为一些业务需要采用clickhouse&#xff0c;然而在download相关driver的时候出现各种问题&#xff0c;于是整理一下方案 1.需要下载clickhouse-jdbc的jar包&#xff0c;可以直接在sonatype上去下载…

C# 人像卡通化

效果 项目 代码 using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; using OpenCvSharp; using System; using System.Collections.Generic; using System.Drawing; using System.Linq; using System.Threading.Tasks; using System.Windows.Forms;nam…

图像分割-Segment Anything实践

一、模型介绍 Segment Anything 模型是一种新的图像分割模型&#xff0c;它可以在不需要大量标注数据的情况下&#xff0c;对图像中的任何物体进行分割。这种方法可以帮助计算机视觉领域的研究人员和开发人员更轻松地训练模型&#xff0c;从而提高计算机视觉应用程序的性能。该…

超前预告 | 云原生?大模型?这届乌镇双态IT大会亮点有点多

石道旁的水面&#xff0c;轻轻泛着微光&#xff0c;几片墨绿缓缓飘下&#xff0c;荡起柔和的波纹&#xff0c;向对岸游去。这儿不似北方秋阳如火的躁动&#xff0c;这儿的秋色是安静的&#xff0c;里便是江南水乡乌镇…… 2023年&#xff0c;第六届双态IT乌镇用户大会将于10月…