Open-R1 项目代码文件的详细剖析

news2025/2/20 13:41:27

目录

1. configs.py

功能概述

关键代码与细节

2. evaluate.py

功能概述

关键代码与细节

3. generate.py

功能概述

关键代码与细节

4. grpo.py

功能概述

关键代码与细节

5. rewards.py

功能概述

关键代码与细节

6. sft.py

功能概述

关键代码与细节

安装

训练模型

评估模型

复现DeepSeek的评估结果

MATH-500

GPQA Diamond

数据生成流程


技术实现与细节

以下是对提供的代码文件的详细剖析,结合代码内容和项目背景,分析其功能、实现细节和应用场景。

1. configs.py

功能概述

configs.py 文件定义了两种配置类:GRPOConfigSFTConfig,分别用于 GRPO(Group Relative Policy Optimization)训练和 SFT(Supervised Fine-Tuning)训练。这些配置类继承自 trl(Transformers Reinforcement Learning)库中的基础配置类,并添加了一些额外的参数。

关键代码与细节
  • GRPOConfig 和 SFTConfig

    @dataclass
    class GRPOConfig(trl.GRPOConfig):
        benchmarks: list[str] = field(
            default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
        )
        callbacks: list[str] = field(
            default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
        )
        system_prompt: Optional[str] = field(
            default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
        )
        hub_model_revision: Optional[str] = field(
            default="main", metadata={"help": "The Hub model branch to push the model to."}
        )
        overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
        push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    • 继承关系GRPOConfigSFTConfig 继承自 trl.GRPOConfigtrl.SFTConfig,扩展了这些类的功能。

    • 新增参数

      • benchmarks:训练后运行的基准测试列表。

      • callbacks:训练过程中运行的回调函数列表。

      • system_prompt:用于基准测试的系统提示。

      • hub_model_revision:推送模型到 Hugging Face Hub 的分支。

      • overwrite_hub_revisionpush_to_hub_revision:控制是否覆盖或推送模型版本。

  • 应用场景

    • 这些配置类用于定义训练和评估的参数,支持用户自定义训练流程中的各种设置,如基准测试、回调函数和模型版本管理。

2. evaluate.py

功能概述

evaluate.py 文件定义了自定义的评估任务,用于在 LightEval 框架中评估模型的性能。这些任务包括数学推理、问答等。

关键代码与细节
  • 评估指标

    latex_gold_metric = multilingual_extractive_match_metric(
        language=Language.ENGLISH,
        fallback_mode="first_match",
        precision=5,
        gold_extraction_target=(LatexExtractionConfig(),),
        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
        aggregation_function=max,
    )
    • multilingual_extractive_match_metric:一个多语言的提取匹配指标,用于评估模型生成的内容是否与参考答案匹配。

    • gold_extraction_targetpred_extraction_target:定义了从参考答案和模型生成内容中提取信息的配置。

  • 提示函数

    def prompt_fn(line, task_name: str = None):
        return Doc(
            task_name=task_name,
            query=line["problem"],
            choices=[line["solution"]],
            gold_index=0,
        )
    • prompt_fn:生成评估任务的提示,用于数学推理任务。

    • aime_prompt_fngpqa_prompt_fn:分别为 AIME 和 GPQA 任务生成提示。

  • 任务定义

    aime24 = LightevalTaskConfig(
        name="aime24",
        suite=["custom"],
        prompt_function=aime_prompt_fn,
        hf_repo="HuggingFaceH4/aime_2024",
        hf_subset="default",
        hf_avail_splits=["train"],
        evaluation_splits=["train"],
        few_shots_split=None,
        few_shots_select=None,
        generation_size=32768,
        metric=[expr_gold_metric],
        version=1,
    )
    • LightevalTaskConfig:定义了一个评估任务的配置,包括任务名称、提示函数、数据集、评估指标等。

    • TASKS_TABLE:将所有定义的任务存储在一个列表中,便于管理和运行。

  • 应用场景

    • 该文件用于定义和运行模型的评估任务,支持多种数学推理和问答任务,帮助用户评估模型在不同领域的性能。

3. generate.py

功能概述

generate.py 文件定义了一个用于生成数据的管道,使用 distilabel 工具从模型中生成合成数据。

关键代码与细节
  • 构建管道

    def build_distilabel_pipeline(
        model: str,
        base_url: str = "http://localhost:8000/v1",
        prompt_column: Optional[str] = None,
        prompt_template: str = "{{ instruction }}",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_new_tokens: int = 8192,
        num_generations: int = 1,
        input_batch_size: int = 64,
        client_replicas: int = 1,
        timeout: int = 900,
        retries: int = 0,
    ) -> Pipeline:
        ...
    • build_distilabel_pipeline:构建一个 distilabel 管道,用于生成数据。

    • 参数

      • model:用于生成数据的模型名称。

      • base_url:模型服务器的 URL。

      • prompt_columnprompt_template:定义提示的列和模板。

      • temperaturetop_p:生成的温度和核采样参数。

      • max_new_tokens:生成的最大新 token 数量。

      • num_generations:每个输入生成的样本数量。

      • input_batch_size:输入的批量大小。

      • client_replicas:客户端副本数量,用于并行处理。

      • timeoutretries:请求超时和重试次数。

  • 主函数

    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
        ...
        args = parser.parse_args()
        ...
        pipeline = build_distilabel_pipeline(
            model=args.model,
            base_url=args.vllm_server_url,
            prompt_template=args.prompt_template,
            prompt_column=args.prompt_column,
            temperature=args.temperature,
            top_p=args.top_p,
            max_new_tokens=args.max_new_tokens,
            num_generations=args.num_generations,
            input_batch_size=args.input_batch_size,
            client_replicas=args.client_replicas,
            timeout=args.timeout,
            retries=args.retries,
        )
        ...
        distiset = pipeline.run(
            dataset=dataset,
            dataset_batch_size=args.input_batch_size * 1000,
            use_cache=False,
        )
        ...
    • 命令行参数:通过 argparse 解析命令行参数,支持用户自定义生成数据的配置。

    • 数据加载:使用 datasets 加载数据集。

    • 管道运行:运行生成管道,生成合成数据并保存到 Hugging Face Hub。

  • 应用场景

    • 该文件用于生成合成数据,支持用户自定义生成配置,适用于模型训练和数据增强。

4. grpo.py

功能概述

grpo.py 文件实现了 GRPO(Group Relative Policy Optimization)训练流程,用于优化模型的策略。

关键代码与细节
  • GRPOScriptArguments

    @dataclass
    class GRPOScriptArguments(ScriptArguments):
        reward_funcs: list[str] = field(
            default_factory=lambda: ["accuracy", "format"],
            metadata={
                "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'"
            },
        )
        cosine_min_value_wrong: float = field(
            default=0.0,
            metadata={"help": "Minimum reward for wrong answers"},
        )
        ...
    • reward_funcs:定义奖励函数列表,支持多种奖励函数,如准确率、格式、推理步骤、余弦缩放和重复惩罚。

    • cosine_min_value_wrong 等参数:定义余弦缩放奖励的参数。

  • 主函数

    def main(script_args, training_args, model_args):
        ...
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
        ...
        reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
        ...
        trainer = GRPOTrainer(
            model=model_args.model_name_or_path,
            reward_funcs=reward_funcs,
            args=training_args,
            train_dataset=dataset[script_args.dataset_train_split],
            eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
            peft_config=get_peft_config(model_args),
            callbacks=get_callbacks(training_args, model_args),
        )
        ...
    • 数据加载:使用 datasets 加载训练和评估数据集。

    • 奖励函数:根据用户指定的奖励函数,加载相应的函数。

    • GRPOTrainer:初始化 GRPO 训练器,设置模型、奖励函数、训练参数等。

    • 训练循环:运行训练循环,支持从断点恢复训练。

  • 应用场景

    • 该文件用于 GRPO 训练,支持多种奖励函数和训练配置,适用于优化模型的策略。

5. rewards.py

功能概述

rewards.py 文件定义了多种奖励函数,用于在 GRPO 训练中评估模型生成的内容。

关键代码与细节
  • 奖励函数

    def accuracy_reward(completions, solution, **kwargs):
        ...
        reward = float(verify(answer_parsed, gold_parsed))
        ...
    • accuracy_reward:检查模型生成的内容是否与参考答案一致,返回 1 或 0。

    • format_reward:检查生成内容是否符合特定格式。

    • reasoning_steps_reward:检查生成内容是否包含清晰的推理步骤。

    • cosine_scaled_reward:基于生成内容长度的余弦缩放奖励。

    • repetition_penalty_reward:基于重复 n-gram 的惩罚奖励。

  • 应用场景

    • 这些奖励函数用于 GRPO 训练,帮助模型生成更准确、更符合格式、更具推理性和更少重复的内容。

6. sft.py

功能概述

sft.py 文件实现了 SFT(Supervised Fine-Tuning)训练流程,用于对模型进行有监督微调。

关键代码与细节
  • 主函数

    def main(script_args, training_args, model_args):
        ...
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
        ...
        trainer = SFTTrainer(
            model=model_args.model_name_or_path,
            args=training_args,
            train_dataset=dataset[script_args.dataset_train_split],
            eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
            processing_class=tokenizer,
            peft_config=get_peft_config(model_args),
            callbacks=get_callbacks(training_args, model_args),
        )
        ...
    • 数据加载:使用 datasets 加载训练和评估数据集。

    • SFTTrainer:初始化 SFT 训练器,设置模型、训练参数、分词器等。

    • 训练循环:运行训练循环,支持从断点恢复训练。

  • 应用场景

    • 该文件用于 SFT 训练,支持多种训练配置和回调函数,适用于对模型进行有监督微调。

功能模块

  • 模型训练

    • SFT(Supervised Fine-Tuning):对预训练模型进行微调,使其更好地适应特定任务。例如,在指令微调中,将小样本数据集用于微调,使模型生成更符合人类常识的对话内容。

    • GRPO(Group-Relative Policy Optimization):使用 GRPO 方法对模型进行 RL(强化学习)培训。该方法基于代理与环境之间的交互,通过最大化累积奖励信号来训练策略模型。

  • 模型评估

    • 使用 lighteval 对模型进行评估,lighteval 是一种轻量级的评估工具,支持多种评估任务。例如,在 AIME 2024、MATH-500 和 GPQA Diamond 等任务上对模型进行测试,得到准确率等评估指标,以评估模型的性能。

  • 数据生成

    • 从 smol 蒸馏 R1 模型生成数据:使用轻量级的蒸馏 R1 模型生成数据。该模块通过 Distilabel 来生成合成数据,为模型训练提供更多样化的数据。

    • 从 DeepSeek-R1 生成数据:使用更大的 DeepSeek-R1 模型生成数据。这需要更多的计算资源,但可以生成更高质量的合成数据,以支持更复杂的模型训练和测试。

安装

[!CAUTION]
相关库依赖于CUDA 12.4。如果您看到与段错误相关的错误,请使用nvcc --version仔细检查您的系统正在运行的CUDA版本。

要运行这个项目中的代码,首先,使用例如uv创建一个Python虚拟环境。
要安装uv,请参考UV安装指南。

uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --upgrade pip --link-mode=copy

接下来,安装vLLM:

uv pip install vllm==0.7.1 --link-mode=copy

这也会安装PyTorch v2.5.1,使用这个版本非常重要,因为vLLM的二进制文件是针对该版本编译的。然后,您可以通过pip install -e .[LIST OF MODES]安装特定用例的其余依赖项。对于大多数贡献者,我们建议:

GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]" --link-mode=copy

接下来,按如下方式登录您的Hugging Face和Weights and Biases账户:

 

训练模型

我们支持使用数据并行分布式训练(DDP)或DeepSpeed(ZeRO-2和ZeRO-3)来训练模型。例如,要在从DeepSeek-R1提炼的带有推理痕迹的数据集(如Bespoke-Stratos-17k)上运行监督微调(SFT),请运行以下命令:

# 通过命令行进行训练
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
    --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
    --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
    --learning_rate 2.0e-5 \
    --num_train_epochs 1 \
    --packing \
    --max_seq_length 4096 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --bf16 \
    --output_dir data/Qwen2.5-1.5B-Open-R1-Distill

# 通过YAML配置文件进行训练
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
    --config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml

目前,支持以下任务:

评估模型

make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=data NUM_GPUS=8

要使用张量并行:

make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=tensor NUM_GPUS=8

复现DeepSeek的评估结果

MATH-500

我们能够在约1 - 3个标准差范围内复现DeepSeek在MATH-500基准测试上报告的结果:

模型MATH-500(🤗 LightEval)MATH-500(DeepSeek报告值)
DeepSeek-R1-Distill-Qwen-1.5B81.283.9
DeepSeek-R1-Distill-Qwen-7B91.892.8
DeepSeek-R1-Distill-Qwen-14B94.293.9
DeepSeek-R1-Distill-Qwen-32B95.094.3
DeepSeek-R1-Distill-Llama-8B85.489.1
DeepSeek-R1-Distill-Llama-70B93.494.5

要复现这些结果,请使用以下命令:

NUM_GPUS=1 # 对于32B和70B模型,设置为8
MODEL=deepseek-ai/{model_name}
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilisation=0.8,tensor_parallel_size=$NUM_GPUS"
OUTPUT_DIR=data/evals/$MODEL

lighteval vllm $MODEL_ARGS "custom|math_500|0|0" \
    --custom-tasks src/open_r1/evaluate.py \
    --use-chat-template \
    --output-dir $OUTPUT_DIR

GPQA Diamond

lighteval vllm $MODEL_ARGS "custom|gpqa:diamond|0|0" \
    --custom-tasks src/open_r1/evaluate.py \
    --use-chat-template \
    --output-dir $OUTPUT_DIR
python scripts/run_benchmarks.py --model-id={model_id}  --benchmarks gpqa
数据生成流程
  • 小模型蒸馏数据生成

    • 使用轻量级蒸馏 R1 模型生成数据。通过 Distilabel 工具,从预定义的提示模板和数据集出发,生成合成数据。

    • 例如,使用 DeepSeek-R1 的蒸馏 Qwen-7B 模型生成数学推理数据,将数据保存到远程数据集中,并可通过华为 MindSpore 加载该数据集以用于训练。

  • 大模型数据生成

    • 使用更大的 DeepSeek-R1 模型生成数据,需要更多的计算资源。通过 Slurm 脚本(如 slurm/generate.slurm)在集群上运行生成任务,可以高效地生成大规模合成数据。

    • 生成过程中,可以通过设置温度(如 0.6)、提示列(如 “problem”)等参数来控制生成数据的质量和多样性。

更多可参照GitHub - huggingface/open-r1: Fully open reproduction of DeepSeek-R1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2297874.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android RenderEffect对Bitmap高斯模糊(毛玻璃),Kotlin(1)

Android RenderEffect对Bitmap高斯模糊(毛玻璃),Kotlin(1) import android.graphics.Bitmap import android.graphics.BitmapFactory import android.graphics.HardwareRenderer import android.graphics.PixelFormat import android.graphic…

区块链+隐私计算:长安链多方计算合约标准协议(CMMPC-1)发布

建设背景 长安链与隐私计算的深度融合是构建分布式数据与价值流通网络的关键基石,可以在有效连接多元参与主体的同时确保数据的分布式、可追溯、可计算,以及隐私性与安全性。在长安链与隐私计算的融合实践中,开源社区提炼并抽象出多方计算场…

#渗透测试#批量漏洞挖掘#Crocus系统—Download 文件读取

免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停…

LabVIEW用户界面设计原则

在LabVIEW开发中,用户界面(UI)设计不仅仅是为了美观,它直接关系到用户的操作效率和体验。一个直观、简洁、易于使用的界面能够大大提升软件的可用性,尤其是在复杂的实验或工业应用中。设计良好的UI能够减少操作错误&am…

MySQL8.0 innodb Cluster 高可用集群部署(MySQL、MySQL Shell、MySQL Router安装)

简介 MySQL InnoDB集群(Cluster)提供了一个集成的,本地的,HA解决方案。Mysq Innodb Cluster是利用组复制的 pxos 协议,保障数据一致性,组复制支持单主模式和多主模式。 InnoDB Cluster组件: …

Effective Objective-C 2.0 读书笔记——内存管理(上)

Effective Objective-C 2.0 读书笔记——内存管理(上) 文章目录 Effective Objective-C 2.0 读书笔记——内存管理(上)引用计数属性存取方法中的内存管理autorelease保留环 ARCARC必须遵循的方法命名原则ARC 的自动优化&#xff1…

软件测试覆盖率详解

🍅 点击文末小卡片 ,免费获取软件测试全套资料,资料在手,涨薪更快 一、覆盖率概念 覆盖率是用来度量测试完整性的一个手段,是测试技术有效性的一个度量。分为:白盒覆盖、灰盒覆盖和黑盒覆盖;测…

控制玉米株高基因 PHR1 的基因克隆

https://zwxb.chinacrops.org/CN/10.3724/SP.J.1006.2024.33011

windows10本地的JMeter+Influxdb+Grafana压测性能测试,【亲测,避坑】

一、环境,以下软件需要解压、安装到电脑上。 windows10 apache-jmeter-5.6.3 jdk-17.0.13 influxdb2-2.7.11 grafana-enterprise-11.5.1二、配置Influxdb,安装完默认连接http://localhost:8086/。打开连接,配置如下。 开启Influxdb&#xf…

如何在 Java 后端接口中提取请求头中的 Cookie 和 Token

个人名片 🎓作者简介:java领域优质创作者 🌐个人主页:码农阿豪 📞工作室:新空间代码工作室(提供各种软件服务) 💌个人邮箱:[2435024119qq.com] 📱个人微信&a…

【Python网络爬虫】爬取网站图片实战

【Python网络爬虫】爬取网站图片实战 Scrapying Images on Website in Action By Jackson@ML *声明:本文简要介绍如何利用Python爬取网站数据图片,仅供学习交流。如涉及敏感图片或者违禁事项,请注意规避;笔者不承担相关责任。 1. 创建Python项目 1) 获取和安装最新版…

SAP ABAP VA05增强

SE18 输入增强的BADI名称:BADI_SDOC_WRAPPER 进入后,点击Interface。 进入后,点击显示对象清单。 双击增强类,下面有之前做好的增强类,没有的可以自己创建一个。 IF_BADI_SDOC_WRAPPER~ADAPT_RESULT_COMP 代码 METHOD if_badi_sdoc_wrapper~adapt_result_comp."…

八大排序——简单选择排序

目录 1.1基本操作: 1.2动态图: 1.3代码: 代码解释 1. main 方法 2. selectSort 方法 示例运行过程 初始数组 每轮排序后的数组 最终排序结果 代码总结 1.1基本操作: 选择排序(select sorting)也…

【清晰教程】本地部署DeepSeek-r1模型

【清晰教程】通过Docker为本地DeepSeek-r1部署WebUI界面-CSDN博客 目录 Ollama 安装Ollama DeepSeek-r1模型 安装DeepSeek-r1模型 Ollama Ollama 是一个开源工具,专注于简化大型语言模型(LLMs)的本地部署和管理。它允许用户在本地计算机…

【matlab优化算法-17期】基于DBO算法的微电网多目标优化调度

基于蜣螂DBO算法的微电网多目标优化调度 一、前言 微电网作为智能电网的重要组成部分,其优化调度对于降低能耗、减少环境污染具有重要意义。本文介绍了一个基于Dung Beetle Optimizer(DBO)算法的微电网多目标优化调度项目,旨在通…

如何使用qt开发一个xml发票浏览器,实现按发票样式显示

使用Qt开发一个按发票样式显示的XML发票浏览器,如下图所示样式: 一、需求: 1、按税务发票样式显示。 2、拖入即可显示。 3、正确解析xml文件。 二、实现 可以按照以下步骤进行: 1. 创建Qt项目 打开Qt Creator,创…

解析 JavaScript 面试题:`index | 0` 确保数组索引为整数

文章目录 一、JavaScript 中的数字类型二、按位或运算符 | 的作用(一)对于整数(二)对于小数(三)对于非数字值 三、用于数组索引的意义 在 JavaScript 面试中,常常会涉及到一些看似简单却蕴含着深…

46 map与set

目录 一、序列式容器和关联式容器 二、set系列的使用 (一)set和mutilset参考文档链接 (二)set类模板介绍 1、set类声明 2、set的构造和迭代器 3、set的增删查 (三)multiset类模板 1、multiset和se…

RAGFlow和Dify对比

‌ RAGFlow和Dify都是基于大语言模型(LLM)的应用开发平台,具有相似的功能和应用场景,但它们在技术架构、部署要求和用户体验上存在一些差异。‌‌ RAGFlow和Dify对比 2025-02-13 22.08 RAGFlow‌ ‌技术栈‌:RAGFlow…

Dart 3.5语法 14-16

017自定代码段让变量有默认值 List下标访问和2种for循环遍历_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV1RZ421p7BL?spm_id_from333.788.videopod.episodes&vd_source68aea1c1d33b45ca3285a52d4ef7365f&p42原作者链接,此为修订补充版本 014main…