基于trl复现DeepSeek-R1的GRPO训练过程

news2025/4/21 10:53:09

1. 引入

huggingface开发了强化学习训练Transformer的库trl(参考3),借助这个trl,可以用来做GRPO的强化学习训练。魔搭ModelScope社区的文章(参考2)给出了基于Qwen基座模型Qwen2.5-0.5B-Instruct,复现DeepSeek-R1训练过程的描述,参考1还给出了源码。
本文就对这个代码进行复现,并记录重点步骤与过程,以期更好的理解GRPO训练过程、数据格式、训练前后的不同效果。

2. 关键步骤

本文仅讲解关键步骤,完整步骤/代码可看参考2与参考1。

2.1 数据处理

(1)原始数据
数据集选用的是OpenAI的gsm8k(参考4),数据是多组question和answer组成的,原始数据格式如下所示:
在这里插入图片描述
可以将answer看过有两部分组成:推理过程,标准答案(数字)。两部分用####隔开。

(2)处理方法:加上prompt

# 系统提示词:规定回复格式
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
# 最终每条数据格式
d = { # type: ignore
	'prompt': [
		{'role': 'system', 'content': SYSTEM_PROMPT},
		{'role': 'user', 'content': x['question']}
	],
	'answer': extract_answer(x['answer'])# 标准答案(数字)
}

最终处理后数据格式入上面代码中的d所示,处理完后,一共7473组数据。如下是1条处理完成后的数据示例:

{
    "prompt": [
        {
            "role": "system",
            "content": "\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n"
        },
        {
            "role": "user",
            "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
        }
    ],
    "answer": "72"
}

可从prompt中看到,最终模型需要回复的内容中有两个字段,一个是推理过程(<reasoning>字段),另一个是最终答案(<answer>字段)。

2.2 自定义Reward函数

强化学习中比较重要的是reward函数,这里定义了5个reward函数:

(1)correctness_reward_func
如果模型输出答案中的<answer>字段(数字)等于标准答案,则得2分,否则0分。

(2)int_reward_func
如果模型输出答案中的<answer>字段是纯数字类型,则得0.5分,否则得0分。

(3)strict_format_reward_func
如果模型输出答案中的,既有<reasoning>字段也有<answer>字段,则得0.5分,否则0分。
strict表示这里做正则匹配要求是更严格的,必须以<reasoning>开头,以<answer>结束。回复内容收尾必须是这两个字段。具体正则为:pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"。

(4)soft_format_reward_func
soft表示这里的格式要求匹配正则不是那么严格。具体正则为pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"

(5)xmlcount_reward_func
根据如下逻辑对<reasoning>字段<answer>字段做计数后打分,这也是一种对输出格式的打分。

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1: # 只出现1次<reasoning>
        count += 0.125
    if text.count("\n</reasoning>\n") == 1: # 只出现1次</reasoning>
        count += 0.125
    if text.count("\n<answer>\n") == 1: # 只出现1次<answer>
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1: # 只出现1次</answer>
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

2.3 GRPO训练

训练时,定义好超参数,按如下配置即可训练

trainer = GRPOTrainer(
    model=model, # 模型基座
    processing_class=tokenizer, # 分词器
    reward_funcs=[ # reward函数
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args, # 超参数
    train_dataset=dataset, # 数据
)
trainer.train()

如下为训练中间过程的部分输出(能看出用1张A800训练一次耗时1:41:31):

-------------------- Question:
Billy wants to watch something fun on YouTube but doesn't know what to watch.  He has the website generate 15 suggestions but, after watching each in one, he doesn't like any of them.  Billy's very picky so he does this a total of 5 times before he finally finds a video he thinks is worth watching.  He then picks the 5th show suggested on the final suggestion list.  What number of videos does Billy watch?
Answer:
65
Response:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Extracted:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-------------------- Question:
Xander read 20% of his 500-page book in one hour.  The next night he read another 20% of the book.  On the third night, he read 30% of his book.  How many pages does he have left to read?
Answer:
150
Response:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Extracted:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 4.918502993675464e-06, 'rewards/xmlcount_reward_func': 0.0, 'rewards/soft_format_reward_func': 0.0, 'rewards/strict_format_reward_func': 0.0, 'rewards/int_reward_func': 0.0, 'rewards/correctness_reward_func': 0.0, 'reward': 0.0, 'reward_std': 0.0, 'completion_length': 200.0, 'kl': nan, 'epoch': 0.17}
 17%|███████████████████                                                                                           | 324/1868 [17:38<1:22:29,  3.21s/it]-------------------- Question:
Jordan gave Danielle two dozen roses and a box of chocolates as a birthday day gift.  Later that day, after Jordan left, Danielle traded the box of chocolates for another dozen roses.  Overnight, half of the roses wilted, and  Danielle decided to throw the wilted flowers away.  On the second day, another half of the remaining flowers wilted, and she threw the wilted ones away.  How many unwilted flowers remained?
Answer:
9
Response:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Extracted:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
...
...
-------------------- Question:
It's payday but Jebb has to pay 10% for the tax. If his pay is $650, how much is his take-home pay?
Answer:
585
Response:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Extracted:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-------------------- Question:
In a car racing competition, Skye drove a 6-kilometer track. For the first 3 kilometers, his speed was 150 kilometers per hour. For the next 2 kilometers, his speed was 50 kilometers per hour more. For the remaining 1 kilometer, his speed was twice as fast as his speed on the first 3 kilometers. What is Skye's average speed for the entire race?
Answer:
180
Response:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Extracted:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0, 'rewards/xmlcount_reward_func': 0.0, 'rewards/soft_format_reward_func': 0.0, 'rewards/strict_format_reward_func': 0.0, 'rewards/int_reward_func': 0.0, 'rewards/correctness_reward_func': 0.0, 'reward': 0.0, 'reward_std': 0.0, 'completion_length': 200.0, 'kl': nan, 'epoch': 1.0}
{'train_runtime': 6091.2875, 'train_samples_per_second': 1.227, 'train_steps_per_second': 0.307, 'train_loss': 0.0, 'epoch': 1.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1868/1868 [1:41:31<00:00,  3.26s/it]

2.4 基本验证

对GRPO训练前后的大模型输入一个问题:

prompt = "Xiao Ming bought 4 apples, ate 1, and gave 1 to his sister. How many apples were left?"

(1)GRPO训练前模型的输出:

To determine how many kilometers Natalia rode in total, we need to calculate her distances for each day and then sum them up.

1. **Monday**: Natalia rode 40 kilometers.
2. **Tuesday**: Natalia rode 50 kilometers.
3. **Wednesday**: She rode 50% fewer kilometers than on Tuesday. This means she rode \(50 - (0.5 \times 50) = 50 - 25 = 25\) kilometers.
4. **Thursday**: She rode as many kilometers as the sum of the kilometers from Monday and Wednesday. So, she rode \(40 + 25 = 65\) kilometers.

Now, let's add up the distances from all four days:

- Monday: 40 kilometers
- Tuesday: 50 kilometers
- Wednesday: 25 kilometers
- Thursday: 65 kilometers

Total distance = 40 + 50 + 25 + 65 = 180 kilometers

Therefore, the total number of kilometers Natalia rode is \(\boxed{180}\).

(2)GRPO训练后模型的输出:

<reasoning> 
Initially, Xiao Ming had 4 apples. After eating 1 apple, he was left with 4 - 1 = 3 apples. Then, he gave 1 apple to his sister, leaving him with 3 - 1 = 2 apples. 
</reasoning>
<answer> 
2 
</answer>

从输出内容,输出格式,最终结果可以看到GRPO训练前后模型的不同表现。

3. 总结

本文讲解了借助trl对Qwen2.5-0.5B-Instruct做GRPO的强化学习训练的过程,包括数据、数据处理、reward函数定义、训练前后的模型输出差异。

4. 参考:

  1. https://modelscope.cn/notebook/share/ipynb/c4d8363a/Qwen-GRPO.ipynb
  2. https://mp.weixin.qq.com/s/EkFRLMwHMdLvyra-ql-1QQ
  3. https://github.com/huggingface/trl
  4. https://modelscope.cn/datasets/modelscope/gsm8k/dataPeview

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2308618.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

常用的 pip 命令

pip 是 Python 的包管理工具&#xff0c;可用于安装、卸载、更新和管理 Python 包。以下是一些常用的 pip 命令&#xff1a; 1. 安装包 安装最新版本的包 pip install package_namepackage_name 是你要安装的 Python 包的名称&#xff0c;例如 pip install requests 可以安装…

基于C#的CANoe CLR Adapter开发指南

一、引言 CANoe 是一款广泛应用于汽车电子开发和测试的工具&#xff0c;它支持多种编程接口&#xff0c;方便开发者进行自定义扩展。CANoe CLR Adapter 允许我们使用 C# 语言与 CANoe 进行交互&#xff0c;充分利用 C# 的强大功能和丰富的类库。本文将详细介绍如何基于 C# 进行…

eMMC安全简介

1. 引言 术语“信息安全”涵盖多种不同的设计特性。一般而言&#xff0c; 信息安全是指通过实践防止信息遭受未经授权的访问、使用、披露、中断、篡改、检查、记录或销毁。 信息安全的三大核心目标为 机密性&#xff08;Confidentiality&#xff09;、完整性&#xff08;Integr…

从零开始用react + tailwindcss + express + mongodb实现一个聊天程序(六) 导航栏 和 个人信息设置

1.导航栏&#xff08;navbar&#xff09; 在components下面 创建NavBar.jsx import { MessageSquare,Settings,User,LogOut} from "lucide-react" import {Link} from "react-router-dom" import { useAuthStore } from "../store/useAuthStore&qu…

数据库MySQL,在终端输入后,提示不是内部命令等

【解决问题】mysql提示不是内部或外部命令&#xff0c;也不是可运行的程序 一般这种问题是因为没有在系统变量里面添加MySQL的可执行路径 以下是添加可执行路径的方法&#xff1a; 第一步&#xff1a;winR输入services.msc 然后找到MySQL&#xff0c;右击属性并复制MySQL的可执…

C语言生成二维码

1. 效果 2. 需要的代码&#xff08;QRCode&#xff09; qrcode.cqrcode.h 代码 3. 代码 #include <stdio.h> #include "qrcode.h"int main() {//拓展编码SetConsoleOutputCP(437);QRCode qrcode;uint8_t qrcodeBytes[qrcode_getBufferSize(3)];qrcode_initT…

[Web 安全] PHP 反序列化漏洞 —— POP 链构造思路

关注这个专栏的其他相关笔记&#xff1a;[Web 安全] 反序列化漏洞 - 学习笔记-CSDN博客 0x01&#xff1a;什么是 POP 链&#xff1f; POP 链&#xff08;Payload On Purpose Chain&#xff09;是一种利用 PHP 中的魔法方法进行多次跳转以获取敏感数据的技术。它通常出现在 CTF…

深度探索推理新境界:DeepSeek-R1如何用“自学”让AI更聪明?

今天我们要聊从1月初火到现在的AI模型——DeepSeek-R1。它就像一个“自学成材的学霸”&#xff0c;不用老师手把手教&#xff0c;就能在数学、编程、逻辑推理等领域大显身手&#xff01;仔细阅读了深度求索发表的R1论文&#xff0c;发现它不仅揭秘了它的成长秘籍&#xff0c;还…

2025春新生培训数据结构(树,图)

教学目标&#xff1a; 1&#xff0c;清楚什么是树和图&#xff0c;了解基本概念&#xff0c;并且理解其应用场景 2&#xff0c;掌握一种建图&#xff08;树&#xff09;方法 3&#xff0c;掌握图的dfs和树的前中后序遍历 例题与习题 2025NENU新生培训&#xff08;树&#…

keil主题(vscode风格)

#修改global.prop文件&#xff0c;重新打开keil即可 # Keil uVision Global Properties File # This file is used to customize the appearance of the editor# Editor Font editor.font.nameConsolas editor.font.size10 editor.font.style0# Editor Colors editor.backgro…

使用Hydra进行AI项目的动态配置管理

引言:机器学习中的超参数调优挑战 在机器学习领域,超参数调优是决定模型性能的关键环节。不同的模型架构,如神经网络中的层数、节点数,决策树中的最大深度、最小样本分割数等;以及各种训练相关的超参数,像学习率、优化器类型、批量大小等,其取值的选择对最终模型的效果…

深入了解 K-Means 聚类算法:原理与应用

引言 在数据科学和机器学习的世界中&#xff0c;聚类是一项非常重要的技术&#xff0c;它帮助我们根据数据的相似性将数据划分为不同的组或簇。聚类算法在许多领域中得到了广泛的应用&#xff0c;如图像处理、市场细分、基因研究等。K-Means 聚类算法作为最常见的无监督学习算…

永磁同步电机无速度算法--反电动势观测器

一、原理介绍 在众多无位置传感器控制方法中&#xff0c;低通滤波反电势观测器结构简单&#xff0c;参数整定容易&#xff0c;易于编程实现。但是该方法估计出的反电势会产生相位滞后&#xff0c;需要在估计永磁同步电机转子位置时进行了相位补偿。 二、仿真模型 在MATLAB/si…

【Linux】命令行参数 | 环境变量(四)

目录 前言&#xff1a; 一、命令行参数&#xff1a; 1.main函数参数 2.为什么有它&#xff1f; 二、环境变量&#xff1a; 1.main函数第三个参数 2.查看shell本身环境变量 3.PATH环境变量 4.修改PATH环境变量配置文件 5.HOME环境变量 6.SHELL环境变量 7.PWD环境变…

java高级(IO流多线程)

file 递归 字符集 编码 乱码gbk&#xff0c;a我m&#xff0c;utf-8 缓冲流 冒泡排序 //冒泡排序 public static void bubbleSort(int[] arr) {int n arr.length;for (int i 0; i < n - 1; i) { // 外层循环控制排序轮数for (int j 0; j < n -i - 1; j) { // 内层循环…

深度剖析数据分析职业成长阶梯

一、数据分析岗位剖析 目前&#xff0c;数据分析领域主要有以下几类岗位&#xff1a;业务数据分析师、商业数据分析师、数据运营、数据产品经理、数据工程师、数据科学家等&#xff0c;按照工作侧重点不同&#xff0c;本文将上述岗位分为偏业务和偏技术两大类&#xff0c;并对…

Web3.py 入门笔记

Web3.py 学习笔记 &#x1f4da; 1. Web3.py 简介 &#x1f31f; Web3.py 是一个 Python 库&#xff0c;用于与以太坊区块链进行交互。它就像是连接 Python 程序和以太坊网络的桥梁。 官方文档 1.1 主要功能 查询区块链数据&#xff08;余额、交易等&#xff09;发送交易与…

NFC拉起微信小程序申请URL scheme 汇总

NFC拉起微信小程序&#xff0c;需要在微信小程序开发里边申请 URL scheme &#xff0c;审核通过后才可以使用NFC标签碰一碰拉起微信小程序 有不少人被难住了&#xff0c;从微信小程序开发社区汇总了以下信息&#xff0c;供大家参考 第一&#xff0c;NFC标签打开小程序 https://…

《Python实战进阶》No 8:部署 Flask/Django 应用到云平台(以Aliyun为例)

第8集&#xff1a;部署 Flask/Django 应用到云平台&#xff08;以Aliyun为例&#xff09; 2025年3月1日更新 增加了 Ubuntu服务器安装Python详细教程链接。 引言 在现代 Web 开发中&#xff0c;开发一个功能强大的应用只是第一步。为了让用户能够访问你的应用&#xff0c;你需…

【JAVA SE基础】抽象类和接口

目录 一、前言 二、抽象类 2.1 抽象类的概念 2.2 抽象类语法 2.3 抽象类特性 2.4 抽象类的作用 三、接口 3.1 什么是接口 3.2 语法规则 3.3 接口使用 3.4 接口特性 3.5 实现多接口 3.6 接口间的继承 四、Object类 4.1 获取对象信息&#xff08; toString() &…