LLM - 理解 DeepSeek 的 GPRO (分组相对策略优化) 公式与源码 教程(2)

news2025/2/15 14:45:33

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145640762


GPRO,即 Group Relative Policy Optimization分组相对的策略优化,是 PPO(Proximal Policy Optimization, 近端策略优化) 的优化版本,省略优化 评论家模型(Critic Model),用于估计价值(Value Function Model),降低模型训练的资源消耗。

GRPO 目标的工作原理如下:

  1. 为查询生成一组响应。
  2. 根据预定义的标准(例如准确性、格式),计算每个响应的奖励。
  3. 比较组内的反应以计算他们的相对优势。
  4. 更新策略以支持具有更高优势的响应,剪裁(clip)确保的稳定性。
  5. 规范更新以防止模型偏离基线太远。

GRPO 有效的原因:

  • 无需评论:GRPO 依靠群体比较,避免对于单独评估者的需求,从而降低了计算成本。
  • 稳定学习:剪裁(clip) 和 KL 正则化确保模型稳步改进,不会出现剧烈波动。
  • 高效训练:通过关注相对性能,GRPO 非常适合推理等绝对评分困难的任务。

在 DeepSeekMath (2024.4) 中,使用 GPRO 代替 PPO。

GRPO

回顾一下 PPO 模型的公式与框架,PPO 是先训练 奖励模型(RM),通过强化学习策略,将奖励模型的能力,学习到大语言模型中,同时,注意模型的输出符合之前的预期,不要偏离过远(KL Divergence)。即:

  • RM(Reward Model, 奖励模型) m a x r ϕ { E ( x , y w i n , y l o s s ) ∼ D [ l o g   σ ( r ϕ ( x , y w i n ) − r ϕ ( x , y l o s s ) ) ] } \underset{r_{\phi}}{max} \{ {E_{(x,y_{win},y_{loss}) \sim D}}[log \ \sigma(r_{\phi}(x,y_{win}) - r_{\phi}(x,y_{loss}))] \} rϕmax{E(x,ywin,yloss)D[log σ(rϕ(x,ywin)rϕ(x,yloss))]}
  • PPO(Proximal Policy Optimization, 近端策略优化) m a x π θ { E x ∼ D , y ∼ π θ ( y ∣ x ) [ r ϕ ( x , y ) ] − β D K L [ π θ ( y ∣ x ) ∣ ∣ π r e f ( y ∣ x ) ] } \underset{\pi_{\theta}}{max} \{ E_{x \sim D,y \sim \pi_{\theta}(y|x)}[r_{\phi}(x,y)] - \beta D_{KL}[\pi_{\theta}(y|x) || \pi_{ref}(y|x)] \} πθmax{ExD,yπθ(yx)[rϕ(x,y)]βDKL[πθ(yx)∣∣πref(yx)]}
  • KL 散度(KL Divergence)

D K L [ π θ ( y ∣ x ) ∣ ∣ π r e f ( y ∣ x ) ] = π r e f ( y ∣ x )   l o g π r e f ( y ∣ x ) π θ ( y ∣ x ) = π r e f ( y ∣ x ) ( l o g π r e f ( y ∣ x ) − l o g π θ ( y ∣ x ) ) \begin{align} D_{KL}[\pi_{\theta}(y|x) || \pi_{ref}(y|x)] &= \pi_{ref}(y|x) \ log{\frac{\pi_{ref}(y|x)}{\pi_{\theta}(y|x)}} \\ &= \pi_{ref}(y|x) (log \pi_{ref}(y|x) - log\pi_{\theta}(y|x)) \end{align} DKL[πθ(yx)∣∣πref(yx)]=πref(yx) logπθ(yx)πref(yx)=πref(yx)(logπref(yx)logπθ(yx))

其中,Actor 和 Critic 损失函数如下:
a [ i , j ] = r e t u r n s [ i , j ] − v a l u e s [ i , j ] L o s s a c t o r = − 1 M N ∑ i = 1 M ∑ j = 1 N a [ i , j ] × e x p ( l o g _ p r o b [ i , j ] − o l d _ l o g _ p r o b [ i , j ] ) L o s s c r i t i c = 1 2 M N ∑ i = 1 M ∑ j = 1 N ( v a l u e s [ i , j ] − r e t u r n s [ i , j ] ) 2 L o s s = L o s s a c t o r + 0.1 ∗ L o s s c r i t i c \begin{align} a[i,j] &= returns[i,j] - values[i,j] \\ Loss_{actor} &= -\frac{1}{MN} \sum_{i=1}^{M} \sum_{j=1}^{N} a[i,j] \times exp(log\_prob[i,j]-old\_log\_prob[i,j]) \\ Loss_{critic} &= \frac{1}{2MN} \sum_{i=1}^{M} \sum_{j=1}^{N} (values[i,j] - returns[i,j])^{2} \\ Loss & = Loss_{actor} + 0.1*Loss_{critic} \end{align} a[i,j]LossactorLosscriticLoss=returns[i,j]values[i,j]=MN1i=1Mj=1Na[i,j]×exp(log_prob[i,j]old_log_prob[i,j])=2MN1i=1Mj=1N(values[i,j]returns[i,j])2=Lossactor+0.1Losscritic
PPO 的奖励(Reward 计算),一般而言,超参数 β = 0.1 \beta=0.1 β=0.1
r t = r ψ ( q , o ≤ t ) − β l o g ( π θ ( o t ∣ q , o < t ) π r e f ( o t ∣ q , o < t ) ) r_{t} = r_{\psi}(q,o_{\leq t}) - \beta log(\frac{\pi_{\theta}(o_{t}|q,o_{<t})}{\pi_{ref}(o_{t}|q,o_{<t})}) rt=rψ(q,ot)βlog(πref(otq,o<t)πθ(otq,o<t))
在 PPO 中使用的价值函数(Critic Model),通常与策略模型(Policy Model)大小相当,带来内存和计算负担。在强化学习训练中,价值函数作为基线,以减少优势函数计算中的方差。然而,在 大语言模型(LLM) 的场景中,只有最后一个 Token 被奖励模型赋予奖励分数,使训练一个在每个标记处都准确的价值函数,变得复杂。GRPO 无需像 PPO 那样,使用额外的近似价值函数,而是使用同一问题产生的多个采样输出的平均奖励,作为基线。

GRPO 使用基于 组相对(Group Relative) 的优势计算方式,与奖励模型比较特性一致,因为奖励模型通常是在同一问题上不同输出之间的比较数据集上进行训练的。同时,GRPO 没有在 奖励(Reward) 中加入 KL 惩罚,而是直接将训练策略与参考策略之间的KL散度添加到损失函数中,从而避免了在计算优势时增加复杂性。

GPRO 的公式, Q Q Q 表示 Query,即输入的问题,采样出问题 q q q,推理大模型,输出 G G G 个输出 o i o_{i} oi
J G R P O ( θ ) = E [ q ∼ P ( Q ) , { o i } i = 1 G ∼ π θ o l d ( O ∣ q ) ] L o s s = 1 G ∑ i = 1 G ( m i n ( ( π θ ( o i ∣ q ) π θ o l d ( o i ∣ q ) ) A i , c l i p ( π θ ( o i ∣ q ) π θ o l d ( o i ∣ q ) , 1 − ϵ , 1 + ϵ ) A i − β D K L ( π θ ∣ ∣ π r e f ) ) D K L ( π θ ∣ ∣ π r e f ) = π r e f ( o i ∣ q ) π θ ( o i ∣ q ) − l o g π r e f ( o i ∣ q ) π θ ( o i ∣ q ) − 1 A i = r i − m e a n ( { r 1 , r 2 , … , r G } ) s t d ( { r 1 , r 2 , … , r G } ) \begin{align} J_{GRPO}(\theta) &= \mathbb{E}[q \sim P(Q), \{{o_{i}}\}_{i=1}^{G} \sim \pi_{\theta_{old}}(O|q)] \\ Loss &= \frac{1}{G}\sum_{i=1}^{G}(min((\frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old}}(o_{i}|q)})A_{i}, clip(\frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old}}(o_{i}|q)},1-\epsilon,1+\epsilon)A_{i}-\beta \mathbb{D}_{KL}(\pi_{\theta}||\pi_{ref})) \\ \mathbb{D}_{KL}(\pi_{\theta}||\pi_{ref}) &= \frac{\pi_{ref}(o_{i}|q)}{\pi_{\theta}(o_{i}|q)} - log\frac{\pi_{ref}(o_{i}|q)}{\pi_{\theta}(o_{i}|q)} - 1 \\ A_{i} &= \frac{r_{i}-mean(\{r_{1},r_{2},\ldots,r_{G}\})}{std(\{r_{1},r_{2},\ldots,r_{G}\})} \end{align} JGRPO(θ)LossDKL(πθ∣∣πref)Ai=E[qP(Q),{oi}i=1Gπθold(Oq)]=G1i=1G(min((πθold(oiq)πθ(oiq))Ai,clip(πθold(oiq)πθ(oiq),1ϵ,1+ϵ)AiβDKL(πθ∣∣πref))=πθ(oiq)πref(oiq)logπθ(oiq)πref(oiq)1=std({r1,r2,,rG})rimean({r1,r2,,rG})
GRPO 的 KL 散度,使用蒙特卡洛(Monte-Carlo) 近似计算 KL散度(Kullback-Leibler Divergence),结果始终为正数。

参考源码,TRL - GRPO:

# Advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# KL 散度
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# 期望
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
# 联合 loss
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
# mask loss
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

GRPO 的 训练源码:trl/trainer/grpo_trainer.py

GRPO

PPO 的伪码流程:

policy_model = load_model()
ref_model = policy_model.copy()  # 不更新
critic_model = load_reward_model(only_last=False)
reward_model = critic_mode.copy()    # 不更新

for i in steps:
    # 1. 采样阶段
    prompts = sample_prompt()
    # old_log_probs[i][j](from policy_model), old_values[i][j](from critic_model)
    responses, old_log_probs, old_values = respond(policy_model, critic_model, prompts)
    
    # 2. 反馈阶段
    scores = reward_model(prompts, responses)
    # ref_log_probs[i][j](from ref_model)
    ref_log_probs = analyze_responses(ref_model, prompts, responses)  # ref logps
    # rewards[i][j] = scores[i] - (old_log_probs[i][j] - ref_log_prob[i][j])
    rewards = reward_func(scores, old_log_probs, ref_log_probs) # 奖励计算
    # advantages[i][j] = rewards[i][j] - old_values[i][j] 
    advantages = advantage_func(rewards, old_values)  # 奖励(r)-价值(v)=优势(a)
    
    # 3. 学习阶段
    for j in ppo_epochs:  # 多次更新学习,逐渐靠近奖励
        log_probs = analyze_responses(policy_model, prompts, responses)
        values = analyze_responses(critic_model, prompts, responses)
        
        # 更新 actor(policy) 模型,学习更新的差异,advantages[i][j]越大,强化动作
        actor_loss = actor_loss_func(advantages, old_log_probs, log_probs)  
        critic_loss = critic_loss_func(rewards, values)  # 更新 critic 模型
        
        loss = actor_loss + 0.1 * critic_loss   # 更新
        train(loss, policy_model.parameters(), critic_model.parameters())   # 参数

参考 知乎 - 图解大模型RLHF系列之:人人都能看懂的PPO原理与源码解读

KL 散度的实现,如下:

import torch
import torch.nn.functional as F
# 假设我们有两个概率分布 P 和 Q
P = torch.tensor([0.1, 0.2, 0.7])   # 参考的、真实的
Q = torch.tensor([0.2, 0.3, 0.5])   # 模型生成的
# 计算 Q 的对数概率
log_Q = torch.log(Q)
# 使用 PyTorch 的 kl_div 函数计算 KL 散度
kl_divergence = F.kl_div(log_Q, P, reduction='sum')  # 注意先Q后P
print(f"KL Div (PyTorch): {kl_divergence}")
log_P = torch.log(P)
kl_elementwise = P * (log_P - log_Q)
# 对所有元素求和,得到 KL 散度
kl_divergence = torch.sum(kl_elementwise)

参考:

  • 知乎 - GRPO: Group Relative Policy Optimization
  • GitHub - GRPO Trainer
  • Medium - The Math Behind DeepSeek: A Deep Dive into Group Relative Policy Optimization (GRPO)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2298713.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于springboot 以及vue前后端分离架构的求职招聘系统设计与实现

基于springboot 以及vue前后端分离架构的求职招聘系统设计与实现 随着互联网技术的飞速发展&#xff0c;求职招聘行业也在不断发生变革。传统的求职招聘方式往往存在着信息不对称、效率低下、交易成本高等问题&#xff0c;导致企业的招聘成本增加&#xff0c;求职者的体验下降…

Spring Boot整合协同过滤算法,实现个性化推荐

1. 引言 在这篇文章中&#xff0c;我们将展示如何使用 Spring Boot 框架与 协同过滤算法 相结合来构建一个简单的推荐系统。推荐系统广泛应用于电商、电影推荐、社交平台等领域。协同过滤算法通过分析用户行为&#xff0c;找出相似的用户或者物品&#xff0c;从而实现个性化推荐…

自己部署 DeepSeek 助力 Vue 开发:打造丝滑的时间线(Timeline )

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 自己…

光谱相机在天文学领域的应用

天体成分分析 恒星成分研究&#xff1a;恒星的光谱包含了其大气中各种元素的吸收和发射线特征。通过光谱相机精确测量这些谱线&#xff0c;天文学家能确定恒星大气中氢、氦、碳、氮、氧等元素的含量。如对太阳的光谱分析发现&#xff0c;太阳大气中氢元素占比约 71%&#xff0…

深度卷积神经网络实战海洋动物图像识别

本文采用深度卷积神经网络作为核心算法框架&#xff0c;结合PyQt5构建用户界面&#xff0c;使用Python3进行开发。YOLOv11以其高效的特征提取能力&#xff0c;在多个图像分类任务中展现出卓越性能。本研究针对5种海洋动物数据集进行训练和优化&#xff0c;该数据集包含丰富的海…

MySQL-mysql zip安装包配置教程

网上的教程有很多&#xff0c;基本上大同小异。但是安装软件有时就可能因为一个细节安装失败。我也是综合了很多个教程才安装好的&#xff0c;所以本教程可能也不是普遍适合的。 安装环境&#xff1a;win11 1、下载zip安装包&#xff1a; MySQL8.0 For Windows zip包下载地址…

ECP在Successfactors中paylisp越南语乱码问题

导读 pyalisp:ECP中显示工资单有两种方式&#xff0c;一种是PE51&#xff0c;一种是hrform&#xff0c;PE51就是划线的那种&#xff0c; 海外使用的比较多&#xff0c;国内基本没人使用&#xff0c;hrform就是pdf&#xff0c;可以编辑pdf&#xff0c;这个国内相对使用的人 比…

PDF另存为图片的一个方法

说明 有时需要把PDF的每一页另存为图片。用Devexpress可以很方便的完成这个功能。 窗体上放置一个PdfViewer。 然后循环每一页 for (int i 1; i < pdfViewer1.PageCount; i) 调用 chg_pdf_to_bmp函数获得图片并保存 chg_pdf_to_bmp中调用了PdfViewer的CreateBitmap函数…

本地部署DeepSeek集成VSCode创建自己的AI助手

文章目录 安装Ollama和CodeGPT安装Ollama安装CodeGPT 下载并配置DeepSeek模型下载聊天模型&#xff08;deepseek-r1:1.5b&#xff09;下载自动补全模型&#xff08;deepseek-coder:1.3b&#xff09; 使用DeepSeek进行编程辅助配置CodeGPT使用DeepSeek模型开始使用AI助手 ✍️相…

无人机雨季应急救灾技术详解

无人机在雨季应急救灾中发挥着至关重要的作用&#xff0c;其凭借机动灵活、反应迅速、高效安全等特点&#xff0c;为救灾工作提供了强有力的技术支撑。以下是对无人机雨季应急救灾技术的详细解析&#xff1a; 一、无人机在雨季应急救灾中的应用场景 1. 灾情侦查与监测 无人机…

DeepSeek本地化部署【window下安装】【linux下安装】

一、window 本地安装指导 1.1、下载window安装包 https://ollama.com/download/OllamaSetup.exe 1.2、点击下载好的安装包进行安装 检测安装是否成功&#xff1a; C:\Users\admin>ollama -v ollama version is 0.5.7有上面的输出&#xff0c;则证明已经安装成功。 配置…

Ae:常见的光照控件和材质控件

在 After Effects中&#xff0c;几种模拟效果都有类似的光照控件和材质控件&#xff0c;比如&#xff0c;焦散、卡片动画、碎片等。 光照控件和材质控件允许用户模拟不同光源、阴影和高光效果&#xff0c;控制表面反射特性&#xff0c;从而实现真实的光照和反射模拟。适用于材质…

【鸿蒙开发】第三十章 应用稳定性-检测、分析、优化、运维汇总

目录​​​​​​​ 1 概述 2 使用Asan检测内存错误 2.1 背景 2.2 原理概述 2.3 使用约束 2.4 配置参数 2.4.1 在app.json5中配置环境变量 2.4.2 在Run/Debug Configurations中配置环境变量 2.5 Asan使能 方式一 方式二 运行ASan 2.6 ASan异常检测类型 heap-buf…

Linux软件编程:IO编程

IO&#xff08;linux输入输出&#xff09; 1. IO概念&#xff1a; I&#xff1a;输入 O&#xff1a;输出 Linux 一切皆是文件 屏幕 -> /dev/tty 磁盘 -> /dev/sda 网卡 键盘 -> /dev/event 鼠标-> /dev/mice 都是一个文件 2. IO操作的对象&#xff1a; 文件 3. 文…

javaEE2

maven 搭建 前后端交互 HTML servlet 后台和数据库交互 servlet jdbc 未来 servlet-->springmvc jdbc-->mybatis-->mybatisplus/jpa javaee-->spring-->springboot SERVLET tomcat ~Apache 服务 Apache(音译为阿帕奇)是世界上使用排名第一的Web服务器…

2025最新深度学习pytorch完整配置:conda/jupyter/vscode

从今天开始&#xff0c;开始一个新的专栏&#xff0c;更新深度学习相关的内容&#xff0c;从入门到精通&#xff0c;首先的首先是关于环境的配置指南&#xff1a;工欲善其事必先利其器&#xff01; PyTorch 是由 Facebook&#xff08;现 Meta&#xff09;开发的 开源深度学习框…

华为小艺助手接入DeepSeek,升级鸿蒙HarmonyOS NEXT即可体验

小艺助手接入DeepSeek的背景与意义 随着人工智能技术的不断发展&#xff0c;大模型成为推动智能交互升级的关键力量。DeepSeek在自然语言处理等领域具有出色的表现&#xff0c;其模型在语言理解、生成等方面展现出强大的能力。华为小艺助手接入DeepSeek&#xff0c;旨在借助其先…

Git 查看修改记录 二

Git 查看修改记录 二 续接 Git 查看一个文件的修改记录 一 一、修改 A.txt 修改 A.txt number6执行命令 git add . git commit -a -m "修改 number6" # git commit -a -m "修改 number6" 执行 输出如下 # $ git commit -a -m "修改 number6"…

【STM32】增量型旋钮编码器

1.增量型旋钮编码器原理 该编码器有A&#xff0c;B两相&#xff0c;当顺时针旋转时 B相会提前A相90度&#xff08;匀速转的时候&#xff09;&#xff0c;也就是A相上升沿时&#xff0c;B相对应高电平&#xff0c;计数器会1&#xff0c;A相下降沿时&#xff0c;B相为低电平时&…

电动汽车电池监测平台系统设计(论文+源码+图纸)

1总体设计 本次基于单片机的电池监测平台系统设计&#xff0c;其整个系统架构如图2.1所示&#xff0c;其采用STC89C52单片机作为控制器&#xff0c;结合ACS712电流传感器、TLC1543模数转换器、LCD液晶、DS18B20温度传感器构成整个系统&#xff0c;在功能上可以实现电压、电流、…