【复现DeepSeek-R1之Open R1实战】系列7:GRPO原理介绍、训练流程和源码深度解析

news2025/2/24 18:43:30

【复现DeepSeek-R1之Open R1实战】系列博文链接:
【复现DeepSeek-R1之Open R1实战】系列1:跑通SFT(一步步操作,手把手教学)
【复现DeepSeek-R1之Open R1实战】系列2:没有卡也能训模型!Colab跑OpenR1(附源码)
【复现DeepSeek-R1之Open R1实战】系列3:基础知识介绍
【复现DeepSeek-R1之Open R1实战】系列4:跑通GRPO!
【复现DeepSeek-R1之Open R1实战】系列5:SFT源码逐行深度解析
【复现DeepSeek-R1之Open R1实战】系列6:GRPO源码结构解析
【复现DeepSeek-R1之Open R1实战】系列7:GRPO原理介绍、训练流程和源码深度解析
【复现DeepSeek-R1之Open R1实战】系列8:混合精度训练、DeepSpeed、vLLM和LightEval介绍
【复现DeepSeek-R1之Open R1实战】系列9:有趣的现象——GRPO训练过程Loss从0开始慢慢变大

目录

    • 4.6 GRPO训练过程
      • 4.6.1 GRPO原理
      • 4.6.2 设置参考模型
      • 4.6.3 从训练集中抽取问题
      • 4.6.4 旧策略模型生成G个输出
      • 4.6.5 对每个输出用奖励模型 RM 打分
      • 4.6.6 根据目标函数做梯度更新

4.6 GRPO训练过程

我们挑一些重点部分的代码来分析。

4.6.1 GRPO原理

在分析源码之前,我们再回顾一下GRPO的原理。

强化学习的介绍可以参考该博文:【DeepSeek-R1背后的技术】系列三:强化学习(Reinforcement Learning, RL)。

核心思想如下图所示:

GRPO

核心动机:在许多实际应用中,奖励只有在序列末端才给一个分数(称之为 Result/Oucome Supervision),或在每一步给一些局部分数(Process Supervision)。不管怎么样,这个奖励本身往往是离散且比较稀疏的,要让价值网络去学习每个token的价值,可能并不划算。而如果我们在同一个问题 q 上采样多份输出 o1, o2, … , oG,经过奖励模型Reward Model之后得到对应到奖励 r1, r2, … , rG,对它们进行奖励对比,就能更好地推断哪些输出更好。由此,就能对每个输出的所有 token 做相对评分,无须明确地学到一个价值函数。

在数理推理、数学解题等场景,这个技巧尤其管用,因为常常会基于同一个题目 q 生成多个候选输出,有对有错,或者优劣程度不同。那就把它们的奖励进行一个分组内的比较,以获取相对差异,然后把相对优势视为更新策略的依据。

关键点1:分组采样与相对奖励

GRPO 中,“分组”非常关键:我们会在一个问题 q 上,采样 GRPO 份输出 o1, o2, … , oG,然后把这组输出一起送进奖励模型(或规则),得到奖励分 r = {r1, r2, … , rG},先对r做归一化(减去均值除以标准差),从而得出分组内的相对水平,这样就形成了相对奖励 r’i,最后我们把这个相对奖励赋给该输出对应的所有 token 的优势函数。简单来说:多生成几份答案,一起比较,再根据排名或分数差更新,能更直接、简洁地反映同一问题下的优劣关系,而不需要用一个显式的价值网络去学习所有中间时刻的估计。

关键点2:无需价值网络的高效策略优化

因为不再需要在每个 token 上拟合一个价值函数,我们就能大幅节省内存,因为不必再维护和 Actor 同样大的 Critic 模型。这不仅是存储层面的解放,也是训练过程中的显著加速。当然,GRPO 也会引入一些新的代价:我们要为每个问题采样一组输出(不止一条),意味着推理时要多花点算力去生成候选答案。这种方法和“自洽性采样(Self-consistency)”思路也有点类似。

具体流程如下:

伪代码

分组相对奖励A’i,t的计算方法:

我们先把每个oi的奖励ri做归一化 r’i = ( ri - mean( r ) ) / std( r ),然后令A’i,t = r’i,也就是说,输出oi的所有 token 共享同一个分数r’i。它们的好坏相对于该分组内的平均水平来衡量,而不依赖外部价值网络去“拆分”或“插值”。这样我们就得到了一个无价值网络的优势函数,核心思路就是基于相互间的比较与排序。

如果用的是过程监督(process supervision),即在推理过程中的每个关键步骤都打分,那么就会略有不同。那时每个步骤都有一个局部奖励,就可以把它依时间序列累加或折算成与 token 对应的优势。

过程监督VS结果监督:过程奖励与末端奖励的对比

  • 结果监督(Outcome Supervision):只有输出序列结束才打一个奖励,如回答对/错、得分多少。GRPO 则把这个 r rr 同样分配给序列里每个 token。
  • 过程监督(Process Supervision):对中间推理步骤也有打分(比如计算正确一步就+1,错误一步就-1)。那就得收集多个时刻的奖励,然后累加到每个 token 或步骤上,再做分组相对化。

那么问题来了,batch内如何分组?在实际操作中,我们往往会在一个 batch 中包含若干个问题 q ,对每个问题生成 G 个答案。也就是说 batch 大小 = B,每个问题生成 G 个候选,那么一次前向推理要生成 B ∗ G 条候选。然后,每个候选都送进奖励模型得到分数ri。这样做推理开销不小,如果 G 较大,会显著地增加生成次数,但换来的好处是,我们不再需要价值网络了。

延伸:迭代式强化学习——奖励模型的更新与回放机制

在实际用 GRPO 的时候,如果奖励模型 RM 也是学习得来的,那么当策略模型变强时,RM 所得到的训练样本分布会越来越“难”,这时 RM 自身也需要更新。这样就会出现迭代强化学习流程:先用当前 RM 来指导一轮策略更新,然后再用新策略生成的数据来更新 RM。为了避免灾难性遗忘,可以保留一部分旧数据(回放机制 replay buffer),让 RM 每次都在新旧数据上共同训练,这样 RM 不会完全忘记之前的问题特征。

接下来,我们就按照上面的流程,详细解读源码。

4.6.2 设置参考模型

        # Reference model
        if is_deepspeed_zero3_enabled():
            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
        elif not is_peft_model(model):
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)
        else:
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None

这段代码片段展示了如何根据不同的条件创建或配置一个参考模型(ref_model),主要用于深度学习中的模型训练和评估。以下是详细的解析:

  1. DeepSpeed ZeRO-3 启用时

    • is_deepspeed_zero3_enabled():检查是否启用了 DeepSpeed 的 ZeRO-3 零冗余优化器。
    • 如果启用,则从预训练模型中加载一个因果语言模型(Causal Language Model)作为参考模型,并使用 model_init_kwargs 中的参数进行初始化。
  2. PEFT 模型未启用时

    • is_peft_model(model):检查当前模型是否是 PEFT(Parameter-Efficient Fine-Tuning)模型。
    • 如果不是 PEFT 模型,则调用 create_reference_model(model) 创建一个基于初始模型的参考模型。
  3. PEFT 模型启用时

    • 如果是 PEFT 模型,则不需要创建参考模型,因为可以通过禁用适配器(adapter)来恢复到初始模型状态,因此将 self.ref_model 设为 None

4.6.3 从训练集中抽取问题

采样器是RepeatRandomSampler类,主要是通过_prepare_inputs函数准备输入数据的。

def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
class RepeatRandomSampler(Sampler):
    """
    Sampler that repeats the indices of a dataset N times.

    Args:
        data_source (`Sized`):
            Dataset to sample from.
        repeat_count (`int`):
            Number of times to repeat each index.
        seed (`Optional[int]`):
            Random seed for reproducibility (only affects this sampler).

    Example:
    ```python
    >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
    >>> list(sampler)
    [2, 2, 0, 0, 3, 3, 1, 1]
    ```
    """

    def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
        self.data_source = data_source
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
        self.seed = seed
        self.generator = torch.Generator()  # Create a local random generator
        if seed is not None:
            self.generator.manual_seed(seed)

    def __iter__(self):
        indexes = [
            idx
            for idx in torch.randperm(self.num_samples, generator=self.generator).tolist()
            for _ in range(self.repeat_count)
        ]
        return iter(indexes)

    def __len__(self):
        return self.num_samples * self.repeat_count


    def _get_train_sampler(self) -> Sampler:
        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
        # preventing discrepancies in group formation.
        return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)

    def _get_eval_sampler(self, eval_dataset) -> Sampler:
        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
        # preventing discrepancies in group formation.
        return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)

4.6.4 旧策略模型生成G个输出

同样在_prepare_inputs函数函数里,通过self.llm.generate()函数生成了G个输出,并做了一系列后处理操作:

# Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
            # First, have main process load weights if needed
            if self.state.global_step != self._last_loaded_step:
                self._move_model_to_vllm()
                self._last_loaded_step = self.state.global_step

            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(prompts_text)
            if self.accelerator.is_main_process:
                outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
                completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
            else:
                completion_ids = [None] * len(all_prompts_text)
            # Broadcast the completions from the main process to all processes, ensuring each process receives its
            # corresponding slice.
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
            completion_ids = completion_ids[process_slice]

            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        else:
            # Regular generation path
            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                prompt_completion_ids = unwrapped_model.generate(
                    prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
                )

            # Compute prompt length and extract completion ids
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]

4.6.5 对每个输出用奖励模型 RM 打分

  1. Reward Model初始化
        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs

        # Reward weights
        if args.reward_weights is not None:
            if len(args.reward_weights) != len(reward_funcs):
                raise ValueError(
                    f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
                    f"functions ({len(reward_funcs)})"
                )
            self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
        else:
            self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
  1. 计算奖励分数

计算过程是在_prepare_inputs函数里实现的,主要功能模块如代码注释所示,整个计算过程和我们在上一节原理介绍里能一一对应上:


        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
                reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
        # completions may be distributed across processes
        rewards_per_func = gather(rewards_per_func)

        # Apply weights to each reward function's output and sum
        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)

        # Compute grouped-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # Slice to keep only the local part of the data
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        advantages = advantages[process_slice]

4.6.6 根据目标函数做梯度更新

在compute_loss函数里,根据每个生成的优势分数advantages计算对应的损失,并加上KL正则。

梯度更新在Trainer的主函数train()里实现了,可以参考前一篇博文介绍:【复现DeepSeek-R1之Open R1实战】系列5:SFT源码逐行深度解析。

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
        # Compute the per-token log probabilities for the model

        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Compute the KL divergence between the model and the reference model
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

        # x - x.detach() allows for preserving gradients from x
        advantages = inputs["advantages"]
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2304529.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Qt】可爱的窗口关闭确认弹窗实现

文章目录 ​​​实现思路界面构建交互逻辑实现颜色渐变处理圆形部件绘制 代码在主窗口的构造函数中创建弹窗实例ExitConfirmDialog 类代码ColorCircleWidget 类代码 今天在Qt实现了这样一个可互动的窗口(上图由于录屏工具限制没有录制到鼠标) ​​​实现…

计算机毕业设计SpringBoot+Vue.jst网上购物商城系统(源码+LW文档+PPT+讲解)

温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 作者简介:Java领…

自制操作系统前置知识汇编学习

今天要做什么? 为了更好的理解书中内容,需要学习下进制分析和汇编。 汇编语言其实应该叫叫机器指令符号化语言,目前的汇编语言是学习操作系统的基础。 一:触发器 电路触发器的锁存命令默认是断开的,是控制电路触发器…

Unity制作游戏——前期准备:Unity2023和VS2022下载和安装配置——附安装包

1.Unity2023的下载和安装配置 (1)Unity官网下载地址(国际如果进不去,进国内的官网,下面以国内官网流程为例子) unity中国官网:Unity中国官网 - 实时内容开发平台 | 3D、2D、VR & AR可视化 …

深度学习(5)-卷积神经网络

我们将深入理解卷积神经网络的原理,以及它为什么在计算机视觉任务上如此成功。我们先来看一个简单的卷积神经网络示例,它用干对 MNIST数字进行分类。这个任务在第2章用密集连接网络做过,当时的测试精度约为 97.8%。虽然这个卷积神经网络很简单…

【HarmonyOS Next】拒绝权限二次申请授权处理

【HarmonyOS Next】拒绝权限二次申请授权处理 一、问题背景: 在鸿蒙系统中,对于用户权限的申请,会有三种用户选择方式: 1.单次使用允许 2.使用应用期间(长时)允许 3.不允许 当用户选择不允许后&#xff0…

跟着李沐老师学习深度学习(十四)

注意力机制(Attention) 引入 心理学角度 动物需要在复杂环境下有效关注值得注意的点心理学框架:人类根据随意线索和不随意线索选择注意力 注意力机制 之前所涉及到的卷积、全连接、池化层都只考虑不随意线索而注意力机制则显示的考虑随意…

基于YOLO11深度学习的半导体芯片缺陷检测系统【python源码+Pyqt5界面+数据集+训练代码】

《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】3.【手势识别系统开发】4.【人脸面部活体检测系统开发】5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】7.【…

Spring Boot3.x集成Flowable7.x(一)Spring Boot集成与设计、部署、发起、完成简单流程

一、Flowable简介 Flowable 是一个轻量级、开源的业务流程管理(BPM)和工作流引擎,旨在帮助开发者和企业实现业务流程的自动化。它支持 BPMN 2.0 标准,适用于各种规模的企业和项目。Flowable 的核心功能包括流程定义、流程执行、任…

网络安全-openssl工具

OpenSSl是一个开源项目,包括密码库和SSL/TLS工具集。它已是在安全领域的事实标准,并且拥有比较长的历史,现在几乎所有的服务器软件和很多客户端都在使用openssl,其中基于命令行的工具是进行加密、证书管理以及测试最常用到的软件。…

【Web开发】PythonAnyWhere免费部署Django项目

PythonAnyWhere免费部署Django项目 文章目录 PythonAnyWhere免费部署Django项目将项目上传到GitHub从GitHub下载Django项目创建Web应用配置静态文件将项目上传到GitHub 打开项目,输入以下命令,生成Django项目依赖包。pip list --format=freeze > requirements.txt打开Git …

视频的分片上传

分片上传需求分析: 项目中很多地方需要上传视频,如果视频很大,上传到服务器需要很多时间 ,这个时候体验就会很差。所以需要前端实现分片上传的功能。 要实现分片上传,需要对视频进行分割,分割成不同的大小…

Moonshot AI 新突破:MoBA 为大语言模型长文本处理提效论文速读

前言 在自然语言处理领域,随着大语言模型(LLMs)不断拓展其阅读、理解和生成文本的能力,如何高效处理长文本成为一项关键挑战。近日,Moonshot AI Research 联合清华大学、浙江大学的研究人员提出了一种创新方法 —— 混…

Deepseek首页实现 HTML

人工智能与未来:机遇与挑战 引言 在过去的几十年里,人工智能(AI)技术取得了突飞猛进的发展。从语音助手到自动驾驶汽车,AI 正在深刻地改变我们的生活方式、工作方式以及社会结构。然而,随着 AI 技术的普及…

VS2022配置FFMPEG库基础教程

1 简介 1.1 起源与发展历程 FFmpeg诞生于2000年,由法国工程师Fabrice Bellard主导开发,其名称源自"Fast Forward MPEG",初期定位为多媒体编解码工具。2004年后由Michael Niedermayer接任维护,逐步发展成为包含音视频采…

kafka基本知识

什么是 Kafka? Apache Kafka 是一个开源的分布式流处理平台,最初由 LinkedIn 开发,后来成为 Apache 软件基金会的一部分。Kafka 主要用于构建实时数据管道和流处理应用程序。它能够高效地处理大量的数据流,广泛应用于日志收集、数…

类型系统下的语言分类与类型系统基础

类型系统是一种根据计算值的种类对程序语法进行分类的方式,目的是自动检查是否有可能导致错误的行为。 —Benjamin.C.Pierce,《类型与编程语言》(2002) 每当谈到编程语言时,人们常常会提到“静态类型”和“动态类型”。…

有没有使用wxpython开发的类似于visio或drawio的开源项目(AI生成)

有没有使用wxpython开发的类似于visio或drawio的开源项目 是的,有一些使用wxPython开发的类似于Microsoft Visio或draw.io(现为diagrams.net)的开源项目。wxPython 是一个跨平台的GUI工具包,它允许Python开发者创建桌面应用程序&…

【MySQL 一 数据库基础】深入解析 MySQL 的索引(3)

索引 索引操作 自动创建 当我们为一张表加主键约束(Primary key),外键约束(Foreign Key),唯一约束(Unique)时,MySQL会为对应的的列自动创建一个索引;如果表不指定任何约束时,MySQL会自动为每一列生成一个索引并用ROW_I…

【C++】优先级队列宝藏岛

> 🍃 本系列为初阶C的内容,如果感兴趣,欢迎订阅🚩 > 🎊个人主页:[小编的个人主页])小编的个人主页 > 🎀 🎉欢迎大家点赞👍收藏⭐文章 > ✌️ 🤞 &#x1…