RLHF中的PPO算法原理及其实现

news2024/11/23 8:27:12

RLHF中的PPO算法原理及其实现

ChatGPT是基于InstructGPT实现的多轮对话生成式大模型。ChatGPT主要涉及到的技术包括:

  • 指令微调(Instruction-tuning);
  • 因果推断(Causal Language Modeling);
  • 人类对齐(Human Alignment)

博主在之前的文章中已经介绍过关于指令微调以及相关Prompting技术的原理(可以详见:Prompt-Tuning——深度解读一种新的微调范式)以及关于GPT等因果语言模型的相关介绍:【预训练语言模型】GPT: Improving Language Understanding by Generative Pre-Training。那么除了如何训练一个基本的生成式模型外,大模型还需要关注于如何让生成式大模型更加符合人类价值观

在之前的文章InstructGPT原理讲解及ChatGPT类开源项目中已经介绍了ChatGPT以及最近开源的一些类ChatGPT模型是如何实现对齐的,这里我们也详细介绍一下InstructGPT中进行人类对齐的核心算法——RLHF(人类对齐的强化学习)PPO算法。

本篇文章主要参考下面两个参考资料:
【1】强化学习极简入门:通俗理解MDP、DP MC TC和Q学习、策略梯度、PPO
【2】基于DeepSpeed训练ChatGPT


一、RLHF PPO 算法原理

PPO算法是一种具体的Actor-Critic算法实现,比如在对话机器人中,输入的prompt是state,输出的response是action,想要得到的策略就是怎么从prompt生成action能够得到最大的reward,也就是拟合人类的偏好。

PPO算法涉及到两个策略:

  • 近端策略优化惩罚(PPO-penalty);
  • 近端策略优化裁剪PPO-clip。

重要性采样

因为在Actor-Critic训练时,策略函数参数进行优化后,上一轮策略采样的动作-状态序列就不能用了,因此需要进行重要性采样,来避免每次更新策略函数后的重复采样问题。当不能在分布p中采样数据,而只能从另外一个分布q中去采样数据时(q可以是任何分布)。

重要性采样的原理:在这里插入图片描述

KL散度约束:

重要性采样中,p和q分布不能查得太远,所以需要有KL散度施加约束。

Advantage:

Actor-Critic算法中,需要定义advantage,最简单的就是定义Reward-baseline,也可以定义为。其中 V π ( s ) V_{\pi}(s) Vπ(s)可以理解为当前状态 s s s下所有动作执行后得到的奖励的期望,而 Q π ( s , a ) Q_{\pi}(s, a) Qπ(s,a)表示当前状态 s s s下指定某一个动作 a a a得到的奖励。所以如果 A π ( s , a ) > 0 A_{\pi}(s, a)>0 Aπ(s,a)>0,则说明当前动作 a a a所获的奖励是大于整体期望的,所以应该极大化这个动作的概率。

总的来说,Advantage旨在通过正负值来告诉策略什么动作可以是可以得到正反馈,避免仅有Reward作为绝对值时所带来的高方差问题。

Advantage+重要性采样:

Advantage可以认为是重要性采样中的 f ( x ) f(x) f(x)。由于其在优化过程中参数是在变的,所以需要进行重要性采样,因此优化目标变为:

J θ ′ = E s t , a t ∼ π θ ′ [ p θ ( a t ∣ s t ) p θ ′ ( a t , s t ) A θ ′ ( s t , a t ) ] J^{\theta'}=\mathbb{E}_{s_t, a_t}\sim\pi_{\theta'}\bigg[\frac{p_{\theta}(a_t|s_t)}{p_{\theta'}(a_t, s_t)}A^{\theta'}(s_t, a_t)\bigg] Jθ=Est,atπθ[pθ(at,st)pθ(atst)Aθ(st,at)]

近端策略优化惩罚(PPO-penalty)

PPO算法之近端策略优化惩罚的原理如下图所示:
在这里插入图片描述

近端策略优化裁剪PPO-clip

优化目标改为下面:
在这里插入图片描述

公式的理解:
在这里插入图片描述

所以说,clip本质上也是约束两个分布不要差的太远,其相比KL散度来说,KL散度是在两个分布的输出logits上进行约束,而clip方法则是直接在概率比上做约束。


二、RLHF PPO算法实现

(1)首先初始化RLHF类和PPOTrainer

rlhf_engine = DeepSpeedRLHFEngine(
        actor_model_name_or_path=args.actor_model_name_or_path,
        critic_model_name_or_path=args.critic_model_name_or_path,
        tokenizer=tokenizer,
        num_total_iters=num_total_iters,
        args=args)
ppo_trainer =  DeepSpeedPPOTrainer
trainer = ppo_trainer(rlhf_engine, args)

初始化过程中,加载4个模型,包括Actor、SFT、Critic和Reward:

代码中的self.ref其实就是SFT模型

class DeepSpeedRLHFEngine():
    def __init__(self, actor_model_name_or_path, critic_model_name_or_path,
                 tokenizer, args, num_total_iters):
        self.args = args
        self.num_total_iters = num_total_iters
        self.tokenizer = tokenizer
        # 用训练好的SFT模型初始化Actor模型
        self.actor = self._init_actor(
            actor_model_name_or_path=actor_model_name_or_path)
        # 用训练好的SFT模型初始化SFT模型
        self.ref = self._init_ref(
            actor_model_name_or_path=actor_model_name_or_path)
        self.actor_ema = None
        if self.args.enable_ema:
            self.actor_ema = self._init_ema(
                actor_model_name_or_path=actor_model_name_or_path)
        # 用训练好的RW初始化Critic模型
        self.critic = self._init_critic(
            critic_model_name_or_path=critic_model_name_or_path)
        # 用训练好的RW初始化reward模型
        self.reward = self._init_reward(
            critic_model_name_or_path=critic_model_name_or_path)
        if self.args.critic_gradient_checkpointing:
            self.critic.gradient_checkpointing_enable()

(2)加载用于RLHF的训练数据

prompt_train_dataloader, num_total_iters = create_datasets(
        args=args, tokenizer=tokenizer, train_phase=3)

(3)RLHF整体训练过程如下
具体的流程可以详看代码注释,总的来说,主要流程包括:

  • 遍历每个epoch,每个epoch里遍历每个batch;
  • 对于每个batch,先采样一堆经验数据;
  • 根据经验数据,训练Actor和Critic模型
# 训练的总Epoch数
for epoch in range(args.num_train_epochs):
    # 遍历每一个Batch
    for step, (batch_prompt) in enumerate(prompt_train_dataloader):
        batch_prompt = to_device(batch_prompt, device)
        prompts = batch_prompt['prompt'] # prompt
        length = prompts.size(-1)
    	# 进行采样,并加入到经验池,详见(3.1)
        out = trainer.generate_experience(prompts)
        exp_dataset = exp_mini_dataset.add(out)

        if exp_dataset is not None:
            inner_iter = 0
            critic_loss, actor_loss = 0, 0
            average_reward = 0

            if args.actor_gradient_checkpointing:
                rlhf_engine.actor.gradient_checkpointing_enable()
        	# 从经验池中进行学习Epoch轮
            for ppo_ep in range(args.ppo_epochs):
                for i, (exp_data) in enumerate(exp_dataset):
                    # 得到actor和critic loss,详见(3.2)
                    actor_loss, critic_loss = trainer.train_rlhf(exp_data)
                    critic_loss += actor_loss.item()
                    actor_loss += critic_loss.item()
                    average_reward += exp_data["rewards"].mean()

                    inner_iter += 1
                    if args.enable_ema:
                        moving_average(rlhf_engine.actor,
                                       rlhf_engine.actor_ema,
                                       zero_stage=args.actor_zero_stage)
                # 每一轮结束后打乱经验池
                random.shuffle(exp_dataset)
            average_reward = get_all_reduce_mean(average_reward).item()
        if args.actor_gradient_checkpointing:
            rlhf_engine.actor.gradient_checkpointing_disable()

这个训练过程主要包括两个核心步骤:

  • 采样Experience数据;
  • 根据采样的数据训练Actor和Critic模型。

下面详细分析一下这两个核心步骤,理解了这两个核心步骤也就差不多理解了RLHF PPO算法了。

Experience采样

图来自这里。

实现细节详见代码及注释:

def generate_experience(self, prompts):
    self.eval() # 开启eval模式
    # 输入instruct prompt,由Actor生成seq,上图中红色步骤(1),seq由instruct和response组成
    seq = self._generate_sequence(prompts)
    self.train() # 恢复训练模型
    pad_token_id = self.tokenizer.pad_token_id
    attention_mask = seq.not_equal(pad_token_id).long()
    with torch.no_grad():
        # 将seq喂入actor中得到action_logits,上图中棕色步骤(2)
        output = self.actor_model(seq, attention_mask=attention_mask)
        # 将seq喂入SFT中得到sft_logits,上图中黑色步骤(5)
        output_ref = self.ref_model(seq, attention_mask=attention_mask)
        # 将seq喂入reward模型中打分,得到r(x,  y),上图绿色步骤(4)
        reward_score = self.reward_model.forward_value(
            seq, attention_mask,
            prompt_length=self.prompt_length)['chosen_end_scores'].detach(
            )
        # 将seq喂入critic,获得critic的value,上图蓝色步骤(3)
        values = self.critic_model.forward_value(
            seq, attention_mask, return_value_only=True).detach()[:, :-1]

    logits = output.logits
    logits_ref = output_ref.logits
	# 获得经验数据
    return {
        'prompts': prompts,
        'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]),
        'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], seq[:, 1:]),
        'value': values,
        'rewards': reward_score,
        'input_ids': seq,
        "attention_mask": attention_mask
    }

获得Advantage,并更新Actor和Critic参数

在这里插入图片描述

def train_rlhf(self, inputs):
	# 当前RLHF轮次最初采样的经验池中采样一批数据
    prompts = inputs['prompts'] # instruct prompt
    log_probs = inputs['logprobs'] # actor模型生成response对应的action_logist
    ref_log_probs = inputs['ref_logprobs'] # SFT模型生成response对应的sft_logits
    reward_score = inputs['rewards'] # reward模型预测的奖励r(x, y)
    values = inputs['value'] # critic模型预测的奖励
    attention_mask = inputs['attention_mask']
    seq = inputs['input_ids']

    start = prompts.size()[-1] - 1
    action_mask = attention_mask[:, 1:]
	### 根据经验数据,接下来计算相应的reward和advantage
    old_values = values # RLHF训练之前的critic给的value值
    with torch.no_grad():
        # 根据SFT的sft_logits和Actor的action_logist,计算KL散度;
        # 并根据KL散度与reward模型预测的奖励r(x, y),获得最终奖励
        # 上图中红色步骤(1)
        old_rewards = self.compute_rewards(prompts, log_probs,
                                           ref_log_probs, reward_score,
                                           action_mask)
        """
    	    def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
                        action_mask):
                # self.kl_ctl = 0.02
                kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
                rewards = kl_divergence_estimate
                start = prompts.shape[1] - 1
                ends = start + action_mask[:, start:].sum(1)
                reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                                          self.clip_reward_value)
                batch_size = log_probs.shape[0]
                for j in range(batch_size):
                    rewards[j, start:ends[j]][-1] += reward_clip[j]

		"""
        # 由critic或的的value与前面根据KL散度和r(x, y)得到的reward,从而计算得到advantage
        # 上图蓝色步骤(2)
        advantages, returns = self.get_advantages_and_returns(
            old_values, old_rewards, start)
    	"""
        def get_advantages_and_returns(self, values, rewards, start):
            # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
            lastgaelam = 0
            advantages_reversed = []
            length = rewards.size()[-1]
            for t in reversed(range(start, length)):
                nextvalues = values[:, t + 1] if t < length - 1 else 0.0
                delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
                lastgaelam = delta + self.gamma * self.lam * lastgaelam
                advantages_reversed.append(lastgaelam)
            advantages = torch.stack(advantages_reversed[::-1], dim=1)
            returns = advantages + values[:, start:]
            return advantages.detach(), returns

        """
	### 根据经验数据以及得到的advatage,下面开始获得一系列的loss
    batch = {'input_ids': seq, "attention_mask": attention_mask}
    # 将这一批经验数据的seq(instruct prompt+response)再一次喂入actor得到logits
    # 因为现在是在更新actor和critic,而经验数据所采用的actor和critic早已经是之前的了,所以
    # 现在正在更新的actor和critic与当时进行经验采样时的actor、critic的参数已经有差异了;
    # 所以需要重新获得当前最新的actor输出的logits
    # 上图中棕色步骤(3)
    actor_prob = self.actor_model(**batch, use_cache=False).logits
    actor_log_prob = gather_log_probs(actor_prob[:, :-1, :],
                                      inputs['input_ids'][:, 1:])
    # 根据新的actor logits以及经验数据中的logits,以及advantage,计算actor loss
    # 上图中绿色步骤(4)
    actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
                                    log_probs[:, start:], advantages,
                                    action_mask[:, start:])
    """
    def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
        ## policy gradient loss
        log_ratio = (logprobs - old_logprobs) * mask
        ratio = torch.exp(log_ratio)
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                             1.0 + self.cliprange)
        pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
        return pg_loss
    """
    # 更新actor模型参数
    self.actor_model.backward(actor_loss)
    self.actor_model.step()
    # 经验数据中的seq(instruct prompt+response)再一次喂入critic得到value
    # 同理,由于当前的critic和当初进行经验数据采样时的critic相差很远;所以需要重新获得value
    # 上图中黑色步骤(5)
    value = self.critic_model.forward_value(**batch,
                                            return_value_only=True,
                                            use_cache=False)[:, :-1]
    # 根据最新的critic的value,经验数据的old_value,以及advatage,计算得到critic loss
    critic_loss = self.critic_loss_fn(value[:, start:], old_values[:,
                                                                   start:],
                                      returns, action_mask[:, start:])
    """
	def critic_loss_fn(self, values, old_values, returns, mask):
        ## value loss
        values_clipped = torch.clamp(
            values,
            old_values - self.cliprange_value,
            old_values + self.cliprange_value,
        )
        vf_loss1 = (values - returns)**2
        vf_loss2 = (values_clipped - returns)**2
        vf_loss = 0.5 * torch.sum(
            torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
        return vf_loss
    """
    # 更新critic参数
    self.critic_model.backward(critic_loss)
    self.critic_model.step()
    return actor_loss, critic_loss

博主会不断更新关于大模型方面更多技术,相关文章请见:

【1】详谈大模型训练和推理优化技术
【2】Prompt-Tuning——深度解读一种新的微调范式
【3】InstructGPT原理讲解及ChatGPT类开源项目
【4】基于DeepSpeed训练ChatGPT
【5】【HuggingFace轻松上手】基于Wikipedia的知识增强预训练
【6】Pytorch单机多卡GPU的实现(原理概述、基本框架、常见报错)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/555845.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从零开始Vue3+Element Plus后台管理系统(十五)——多语言国际化vue I18n

i18n国际化的内容比较多&#xff0c;写文章的时间也用得比较长&#xff0c;从上周五开始到本周一&#xff0c;断断续续完成了。 虽然实际工作中很多项目都不需要国际化&#xff0c;但是了解国际化的用法还是很有必要的。 i18n Vue I18n 是 Vue.js 的国际化插件。它可以轻松地…

PFC-FLAC3D Coupling Examples

目录 PFC-FLAC3D Coupling Examples Punch Indentation of a Bonded Material Sleeved Triaxial Test of a Bonded Material 命令流 结果 PFC-FLAC3D Coupling Examples Punch Indentation of a Bonded Material 这个例子展示了一个粘合颗粒模型&#xff08;BPM&#xff0…

项目经历该如何写?

大家好&#xff0c;我是帅地。 这不春招来了吗&#xff0c;帮训练营的帅友们修改了很多简&#xff0c;其中问题最多的就是项目经历 专业技能这块了&#xff0c;特别是项目经历这块&#xff0c;很多人写了一大堆描述功能描述&#xff0c;但是自己具体干了什么却没怎么写&#…

研发工程师玩转Kubernetes——使用Deployment进行多副本维护

多副本维护是指&#xff0c;对一组在任何时候都处于运行状态的 Pod 副本的稳定集合进行维护。说的直白点&#xff0c;就是保证某种的Pod数量会被自动维持——增加了该类Pod会自动删除多余的&#xff0c;减少了该类Pod会自动新增以弥补&#xff0c;以保证Pod数量不变。 Kubernet…

day37_Tomcat_Maven

今日内容 一、Maven 二、Tomcat 一、Maven 1.1 引言 项目管理问题 项目中jar包资源越来越多&#xff0c;jar包的管理越来越沉重。 繁琐 要为每个项目手动导入所需的jar&#xff0c;需要搜集全部jar 复杂 项目中的jar如果需要版本升级&#xff0c;就需要再重新搜集jar 冗余 相…

基于Spring-动态调整线程池阻塞队列长度

最近在做一个动态线程池的组件&#xff0c;遇到了关于阻塞队列长度刷新的问题,所以记录下来&#xff0c;很有意思 我们都知道常用线程池分为二类&#xff0c;Spring-ThreadPoolTaskExecutor和JDK-ThreadPoolExecutor的&#xff0c;当然了Spring也是基于JDK做一步封装&#xff0…

​数据库原理及应用上机(实验四 SQL连接查询)

✨作者&#xff1a;命运之光 ✨专栏&#xff1a;数据库原理及应用上机实验 目录 ✨一、实验目的和要求 ✨二、实验内容及步骤 ✨三&#xff0e;实验结果 ✨四、实验总结 &#x1f353;&#x1f353;前言&#xff1a; 数据库原理及应用上机实验报告的一个简单整理后期还会不…

Zerto 10.0 发布 - 勒索软件防护、灾难恢复和多云移动性的统一解决方案

Zerto 10.0 发布 - 勒索软件防护、灾难恢复和多云移动性的统一解决方案 请访问原文链接&#xff1a;https://sysin.org/blog/zerto-10/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&#xff1a;sysin.org 携手 ZERTO 提升勒索软件保护与灾难恢复水…

Python异常处理

1. 异常概述 在程序运行过程中&#xff0c;经常会遇到各种错误&#xff0c;这些错误称为“异常”。这些异常有的是由于开发者一时疏忽将关键字敲错导致的&#xff0c;这类错误多数产生的是SyntaxError:invalid syntax&#xff08;无效的语法&#xff09;&#xff0c;这将直接导…

JVM笔记

Java中对象一定分配在堆空间上吗&#xff1f;判断一个对象是否还活着GCgc频繁 Java中对象一定分配在堆空间上吗&#xff1f; 逃逸分析&#xff1a;分析对象动态作用域&#xff0c;当一个对象在方法中被定义后&#xff0c;它可能被外部方法所引用&#xff0c;例如作为调用参数传…

Redis6.2.5安装布隆过滤器BloomFilter

最近学习需要用到布隆过滤器&#xff0c;所以去RedisLabsModules下载RedisBloom插件&#xff0c;简单介绍一下安装的过程&#xff0c;首先需要先安装好Redis&#xff0c;建议使用Redis6以上版本&#xff0c;Redis安装教程查看https://smilenicky.blog.csdn.net/article/details…

什么是客户自助服务门户及其搭建方法

随着信息技术的快速发展&#xff0c;越来越多的企业开始转向以客户为中心的服务模式&#xff0c;而客户自助服务门户&#xff08;Customer Self-Service Portal&#xff09;则成为了重要的服务方式。它可以让客户在不需要人工干预的情况下&#xff0c;自行解决问题&#xff0c;…

chatgpt赋能Python-python_ai建模

用Python构建AI模型&#xff1a;一步步解析 随着人工智能技术的发展和普及&#xff0c;越来越多的企业开始寻找高效可靠的AI建模技术来提高业务水平和竞争力。Python作为一种强大的编程语言和开发工具&#xff0c;在AI建模领域也扮演着重要的角色。本文将介绍Python AI建模的基…

chatgpt赋能Python-python_ai下载

Python AI 下载&#xff1a;实现自动化数据处理的利器 介绍 Python作为一种脚本语言&#xff0c;凭借其简洁灵活的语法、强大的库支持和生态系统&#xff0c;成为了今天最流行的编程语言之一。在人工智能领域&#xff0c;Python也是最常用的语言之一&#xff0c;因为它的开发…

18-04 数据库分布式架构

分布式ID UUID 优点&#xff1a; 使用简单无需引入额外组件 缺点 无序&#xff0c;无法实现范围查询插入操作比自增ID性能差不少&#xff08;大概四倍&#xff09;建议用自增ID&#xff08;表的主键&#xff09; UUID&#xff08;唯一标识&#xff09; Redis Incr指令优点…

Hudi系列23:常见问题

文章目录 一. 存储一直看不到数据二. 数据有重复三. NoSuchMethodError3.1 问题描述3.2 解决方案3.2.1 查看源码3.2.2 avro版本问题3.2.3 hudi-flink1.14-bundle jar包的问题 四. Merge On Read 写只有 log 文件4.1 问题描述4.2 解决方案1(测试未通过)4.2 解决方案2(测试通过:)…

[LitCTF 2023]Flag点击就送!(cookie伪造)

随便输一个名字 尝试admin 但是我们在cookie里找到了一些东西 session&#xff1a;"eyJuYW1lIjoiYWRtaW4ifQ.ZGs1vw.7ikpuOhUtXxyB2UV-FH7UGIZkaE" 想到session伪造 先说一下session的作用&#xff1a; 由于http协议是一个无状态的协议&#xff0c;也就是说同一个用…

chatgpt赋能Python-pythonseries访问元素

Python Series: 访问元素 在Python中&#xff0c;我们可以使用列表&#xff08;List&#xff09;、元组&#xff08;Tuple&#xff09;和字典&#xff08;Dictionary&#xff09;等可迭代对象存储和处理数据。在处理这些可迭代对象时&#xff0c;我们经常需要对它们的元素进行…

【WSN覆盖】基于麻雀搜索算法的二维混合无线传感器网络覆盖优化 WSN覆盖空洞修复【Matlab代码#24】

文章目录 【可更换其他算法&#xff0c;获取资源请见文章第6节&#xff1a;资源获取】1. SSA算法2. WSN节点感知模型3. 混合WSN覆盖优化4. 部分代码展示5. 仿真结果展示6. 资源获取 【可更换其他算法&#xff0c;获取资源请见文章第6节&#xff1a;资源获取】 1. SSA算法 网上…

数字逻辑(计科专业)

半加器 用与非门实现 全加器 编码器 编码就是将信息装换成独特的代码或信号输出的电路 普通编码器&#xff1a;任何时候只允许输入一个有效编码信号&#xff0c;否则输出就会发生混乱。 优先编码器&#xff1a;允许同时输入两个以上的有效编码信号。当同时输入几个有效编码信…