基于Contiue来阅读open-r1中的GRPO训练代码

news2025/4/7 19:22:02

原创 快乐王子HP 快乐王子AI说 2025年04月03日 23:54 广东

前面安装了vscode[1]同时也安装了Coninue的相关插件[2],现在想用它们来阅读一下open-r1项目的代码[3]。

首先,从启动训练开始(以GRPO为例子)

第一步,使用TRL的vLLM后端

CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

第二步,启动GRPO

CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 ACCELERATE_LOG_LEVEL=info \     accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes 7 \     src/open_r1/grpo.py --config recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml

查看vllm的服务启动帮助文档

usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE] [--host HOST] [--port PORT] [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE]                       [--max_model_len MAX_MODEL_LEN] [--enable_prefix_caching ENABLE_PREFIX_CACHING]

关于zero2.yaml文件

(https://github.com/huggingface/open-r1/blob/main/recipes/accelerate_configs/zero2.yaml)

0

    1.核心配置:    - 使用 DeepSpeed 的 Zero Stage 2 优化 (zero_stage: 2)    - 混合精度训练采用 bf16 (mixed_precision: bf16)    - 单机 8 GPU 训练 (num_machines: 1, num_processes: 8)2.Zero Stage 2 特点:    - 优化器状态分区,减少内存占用    - 没有启用参数或优化器卸载 (offload_optimizer_device: none, offload_param_device: none)    - 比 Stage 3 内存效率稍低,但通信开销更小3.硬件配置:    - 纯 GPU 训练 (use_cpu: false)    - 不涉及 TPU (tpu_* 相关配置均为 false)    - 适合具有 8 个 GPU 的单个节点4.使用场景:    - 中等规模模型训练    - 当 GPU 内存足够容纳模型参数和激活值时    - 需要比 Zero Stage 1 更高的内存效率,但不想承受 Stage 3 的通信开销5.性能考虑:    - bf16 混合精度可以在支持它的硬件上提供良好的训练速度和内存效率    - 8 个 GPU 的配置适合大多数单节点服务器这个配置文件适合在单个多 GPU 节点上训练中等规模模型,在内存效率和通信开销之间取得平衡。

    recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml文件分析

    (https://github.com/huggingface/open-r1/blob/main/recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml)

    1.模型架构:  - 基于1.5B参数的蒸馏版Qwen模型  - 使用Flash Attention 2优化注意力计算  - bfloat16混合精度训练2.训练策略:  - 采用GRPO(可能是一种强化学习优化算法)训练方法  - 结合三种奖励函数:准确性、格式正确性和标签计数  - 使用vLLM加速推理过程3.数据处理:  - 专门设计的复杂对话模板  - 数学领域专用数据集(OpenR1-Math-220k)  - 要求模型以和标签分步输出4.资源利用:  - 梯度检查点和梯度累积优化显存使用  - 适中的batch size(16)和上下文长度(512/2048)5.监控与部署:  - 完整的训练日志记录(W&B)  - 模型自动推送至HuggingFace Hub  - 严格的模型保存策略

    grpo.py文件

    (https://github.com/huggingface/open-r1/blob/main/src/open_r1/grpo.py)

    ```mermaidgraph TD    A[开始] --> B[设置随机种子]    B --> C[配置日志系统]    C --> D[检查检查点]    D --> E[初始化WandB]    E --> F[加载数据集]    F --> G[加载tokenizer]    G --> H[获取奖励函数]    H --> I[格式化对话数据]    I --> J[初始化模型参数]    J --> K[创建GRPOTrainer]    K --> L{是否有检查点?}    L -- 是 --> M[从检查点恢复训练]    L -- 否 --> N[开始新训练]    M --> O[训练模型]    N --> O    O --> P[保存模型和指标]    P --> Q{是否评估?}    Q -- 是 --> R[执行评估]    Q -- 否 --> S    R --> S[保存评估结果]    S --> T{是否推送至Hub?}    T -- 是 --> U[推送模型]    T -- 否 --> V[结束]    U --> V```

    rewards.py

    (https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py)

    0

    结合医学场景来探索

    0

      def medical_accuracy_reward(response: str, golden_answer: str) -> float:    """评估医学准确性,需要与标准医学答案对比"""    # 这里可以集成医学知识库或NLP模型进行专业评估    medical_terms_score = calculate_medical_terms_match(response, golden_answer)    treatment_score = evaluate_treatment_correctness(response, golden_answer)    return 0.6 * medical_terms_score + 0.4 * treatment_scoredef safety_reward(response: str) -> float:    """安全性评估:检查是否有危险建议"""    dangerous_keywords = ["自行停药", "未经医生", "高剂量", "随意服用"]    for keyword in dangerous_keywords:        if keyword in response:            return 0.0  # 发现危险建议直接0分    return 1.0def citation_reward(response: str) -> float:    """参考文献引用评估"""    citation_formats = ["[1]", "(Smith et al., 2020)", "根据最新指南"]    return 1.0 if any(fmt in response for fmt in citation_formats) else 0.5def patient_language_reward(response: str) -> float:    """患者友好语言评估"""    complex_terms = ["病理学", "分子机制", "流行病学"]    simplified_explanations = ["简单说", "通俗理解", "换句话说"]        complex_count = sum(term in response for term in complex_terms)    simple_count = sum(term in response for term in simplified_explanations)        if complex_count == 0:         return 1.0    return simple_count / (complex_count + 1)  # 确保至少解释了部分复杂术语def empathy_reward(response: str) -> float:    """同理心评估"""    empathy_keywords = ["理解您", "不用担心", "建议咨询", "我们会帮助"]    return min(1.0, 0.2 * sum(kw in response for kw in empathy_keywords))

      0

      参考:

      [1]vscode安装:https://mp.weixin.qq.com/s/FvqSUrJFFXSVxFpZ6Q2-jg

      [2]vscode上安装Coninue的相关插件:

      https://mp.weixin.qq.com/s/cD-BHkCWQxfeedL3eboaBA

      [3]open-r1项目:https://mp.weixin.qq.com/s/BDDUe1RyIVutucUVA9Yuzg,https://github.com/huggingface/open-r1]

      本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2328283.html

      如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

      相关文章

      51c嵌入式~单片机~合集7~※

      我自己的原文哦~ https://blog.51cto.com/whaosoft/13692314 一、芯片工作的心脏--晶振 在振荡器中采用一个特殊的元件——石英晶体,它可以产生频率高度稳定的交流信号,这种采用石英晶体的振荡器称为晶体振荡器,简称晶振。 制作方法 …

      英菲克(INPHIC)A9无线蓝牙鼠标 链接电脑的方式

      英菲克(INPHIC)A9鼠标链接至电脑时,要长按住“模式切换MODE”按钮5秒左右的时间,此时模式指示灯变成蓝色,并且闪烁。 这时使用电脑的蓝牙设置中,“添加设备”,会出现BT4.0 Mouse提示&#xff0…

      lua表table和JSON字符串互转

      --print("local ssxc{\n"..string.gsub(str,":","").."\n}") Utils {} ---------------------------------------------------------------------------------- -- Lua-Table 与 string 转换 local function value2string(value, isA…

      【每日一个知识点】分布式数据湖与实时计算

      在现代数据架构中,分布式数据湖(Distributed Data Lake) 结合 实时计算(Real-time Computing) 已成为大数据处理的核心模式。数据湖用于存储海量的结构化和非结构化数据,而实时计算则确保数据能够被迅速处理…

      c语言数据结构--------拓扑排序和逆拓扑排序(Kahn算法和DFS算法实现)

      #include <stdio.h> #include <string.h> #include <stdbool.h> #include <stdlib.h>//使用卡恩算法(Kahn)和深度优先算法(DFS)实现//拓扑排序和逆拓扑排序//拓扑排序和逆拓扑排序顶点顺序相反//图&#xff0c;邻接矩阵存储 #define MaxVertexNum 100 …

      谷粒微服务高级篇学习笔记整理---nginx搭建正反向代理

      正向与反向代理 **正向代理:**客户端向代理服务器发请求并指定目标服务器,代理向目标转交请求并将获得的内容转给客户端。 反向代理:用户直接访问反向代理服务器就可以获得目标服务器的资源。反向代理服务器统一了访问入口。 给首页配置反向代理 修改windows的hosts文件配…

      2.pycharm保姆级安装教程

      一、pycharm安装 1.官网上下载好好软&#xff0c;双击打开 2.下一步 3.修改路径地址 (默认也可以) 4.打勾 5.安装 不用重启电脑 二、添加解释器 1.双击软件&#xff0c;打开 2.projects – new project 3.指定项目名字&#xff0c;项目保存地址&#xff0c;解释器 4.右击 – …

      【SQL】取消sql某一列的唯一值key值的方法

      在插入数据到sql时&#xff0c;遇到了这个问题&#xff1a; Duplicate entry ‘XXX’ for key 起因是&#xff1a; 我之前设计表的时候&#xff0c;手动给product_title 这个列加了一个key&#xff0c; key 是这个字段的唯一键约束&#xff0c;就不能重复在这一列存入重复的数…

      数据库--SQL

      SQL&#xff1a;Structured Query Language&#xff0c;结构化查询语言 SQL是用于管理关系型数据库并对其中的数据进行一系列操作&#xff08;包括数据插入、查询、修改删除&#xff09;的一种语言 分类&#xff1a;数据定义语言DDL、数据操纵语言DML、数据控制语言DCL、事务处…

      SQL语句(一)—— DDL

      目录 一、SQL 基础知识 &#xff08;一&#xff09;SQL 通用语法 &#xff08;二&#xff09;SQL 分类 二、DDL —— 数据库操作 1、查询所有数据库 2、查询当前数据库 3、创建数据库 4、删除数据库 5、切换数据库 三、DDL —— 表操作 &#xff08;一&#xff09;查…

      Husky目标跟踪

      1.0设备清单 幻影峡谷、适配器 摄像头及数据线、显卡欺骗器 外接屏幕、键盘鼠标 Husky底盘、便携显示屏、键盘鼠标 移动电源 1.1连线 插排——移动电源幻影峡谷——适配器——插排摄像头——幻影峡谷&#xff08;摄像头固定在机械臂前方的底盘上&#xff09;键盘鼠标显示器…

      Python----机器学习(线性回归:自求导的方法实现)

      一、线性回归方程 目标&#xff1a; 线性回归的目标是找到最佳的系数来使模型与观察到的数据尽可能拟合。 应用&#xff1a; 预测&#xff1a;给定自变量的值&#xff0c;预测因变量的值。 回归分析&#xff1a;确定自变量对因变量的影响程度 线性回归是统计学和机器学习中最简…

      Springcache+xxljob实现定时刷新缓存

      目录 SpringCache详解 SpringCache概述 核心原理 接口抽象与多态 AOP动态代理 核心注解以及使用 公共属性 cacheNames KeyGenerator&#xff1a;key生成器 key condition&#xff1a;缓存的条件&#xff0c;对入参进行判断 注解 xxl-job详解 SpringcacheRedis实现…

      vue2拖拉拽做个模拟公式工具

      1. 成图 2. 介绍 就是简单拖拉拽来做个规则运算器&#xff0c;具体运算规则、校验规则自己加。 3. 代码 HTML代码 <template><div class"red-cont"><div class"red-top"><divclass"red-top-left"><div class&quo…

      Windows查重工具,强烈推荐大家收藏!

      我大家在用电脑的时候&#xff0c;是不是发现用得越久&#xff0c;电脑里的软件和文件就越多&#xff1f; 今天我给大家带来的这两款重复文件查找神器&#xff0c;简直就是电脑里的“清洁小能手”&#xff0c;能帮你把那些重复的文件和文件夹找出来。 Easy DupLicate Finder 重…

      使用python完成手写数字识别

      入门图像识别的第一个案例,看到好多小伙伴分享,也把自己当初的思路捋捋,写成一篇博客,作为记录和分享,也欢迎各位交流讨论。 实现思路 数据集:MNIST(包含60,000个训练样本和10,000个测试样本) 深度学习框架:Keras(基于TensorFlow) 模型架构:卷积神经网络(CNN) 实…

      OpenLayers:如何控制Overlay的层级?

      我最近在使用Overlay的时候遇到了一个问题&#xff0c;我向地图中添加了两种不同的Overlay&#xff08;下图中的蓝色标牌和粉色标牌&#xff09;&#xff0c;我希望粉色标牌可以显示在最上层&#xff0c;可偏偏蓝色标牌却将其遮挡住了。于是我对Overlay的层级开始起了兴趣&…

      《Golang高性能网络编程:构建低延迟服务器应用》

      在本文中&#xff0c;我们将深入探讨Golang高性能网络编程&#xff0c;帮助您构建低延迟服务器应用。我们将介绍Golang的网络编程特性、优化技巧和实际案例&#xff0c;让您更好地理解和应用Golang在网络编程领域的优势。 高性能网络编程简介 什么是Golang高性能网络编程 高性能…

      数据结构C语言练习(设计循环队列)

      一、循环队列简介 循环队列是一种线性数据结构&#xff0c;基于 FIFO&#xff08;先进先出&#xff09;原则&#xff0c;将队尾连接到队首形成循环。其核心优势是能复用队列之前用过的空间&#xff0c;避免普通队列 “假溢出” 问题。实现时&#xff0c;通常申请 k1 大小的数组…

      vscode代码片段的设置与使用

      在 Visual Studio Code (VS Code) 中&#xff0c;可以通过自定义**代码片段&#xff08;Snippets&#xff09;**快速插入常用代码模板。以下是详细设置步骤&#xff1a; 步骤 1&#xff1a;打开代码片段设置 按下快捷键 Ctrl Shift P&#xff08;Windows/Linux&#xff09;或…