GRPO训练下的参考模型选择

news2025/4/12 7:18:59

一、普通全量微调模型

核心机制:模型克隆
  1. 深拷贝创建

    • 通过create_reference_model(model)对当前模型进行完全复制(包括所有层和参数)。
    • 示例代码:
      import copy
      def create_reference_model(model):
          ref_model = copy.deepcopy(model)
          ref_model.requires_grad_(False)  # 冻结参数
          ref_model.eval()                 # 评估模式
          return ref_model
      
    • 技术细节:深拷贝会递归复制所有子模块,确保参考模型与原始模型完全独立。
  2. 参数冻结与评估模式

    • requires_grad_(False):关闭梯度计算,防止反向传播影响参考模型。
    • eval():关闭Dropout和BatchNorm等训练专用层,保证输出稳定性。
  3. 内存占用分析

    • 原始模型参数量为N时,总内存占用≈2N。
    • 典型场景:7B参数的模型需要约14GB显存(假设FP32精度)。
  4. 同步机制(可选)

    • 启用sync_ref_model后,通过回调函数周期性将参考模型参数替换为当前模型:
      class SyncRefModelCallback:
          def on_step_end(self, args, state, control, **kwargs):
              with torch.no_grad():
                  for ref_param, model_param in zip(ref_model.parameters(), model.parameters()):
                      ref_param.copy_(model_param.detach())
      
    • 应用场景:允许参考模型跟随训练进度,实现动态策略约束。

二、PEFT微调模型

核心机制:动态适配器切换
  1. PEFT架构特性

    • 典型实现(如LoRA):在原始模型基础上添加低秩适配器矩阵。
    • 参数分布:基础模型参数冻结(占比≈95%),仅训练适配器(占比≈5%)。
  2. 禁用适配器原理

    • 上下文管理器disable_adapter()的工作流程:
      class LoraModel:
          def disable_adapter(self):
              original_forward = self.layer.forward
              self.layer.forward = self.original_forward  # 恢复原始前向传播
      
    • 技术效果:前向计算时绕过所有适配器层,等同于原始模型。
  3. 内存优化原理

    • 不需要存储额外模型实例,节省≈N显存。
    • 示例对比:7B模型PEFT微调时,显存占用从14GB降至≈7.5GB。
  4. 梯度计算隔离

    • 即使禁用适配器,反向传播时仍只会更新适配器参数。
    • 实现方式:通过PyTorch的torch.no_grad()上下文管理器:
      with model.disable_adapter():
          with torch.no_grad():  # 确保不计算参考模型梯度
              outputs = model(inputs)
      

三、DeepSpeed ZeRO-3模式

核心机制:权重重加载
  1. ZeRO-3分片原理

    • 参数分布:模型参数被划分到多个GPU,单个设备只保留部分参数。
    • 示例:8 GPU训练时,每个GPU存储约1/8的参数和优化器状态。
  2. 无法深拷贝的根本原因

    • 分片后的参数无法通过常规方式访问完整副本。
    • 尝试复制会引发错误:RuntimeError: Cannot access full parameter outside of forward/backward
  3. 重加载实现细节

    • 从磁盘或缓存重新初始化模型:
      model_id = "qwen/Qwen1.5-7B"
      ref_model = AutoModelForCausalLM.from_pretrained(
          model_id,
          device_map="auto",
          torch_dtype=torch.bfloat16
      )
      
    • 优化技巧:使用accelerate库的disk_offload功能减少内存压力。
  4. 分布式一致性保证

    • 通过DeepSpeed的broadcast_parameters()确保所有GPU加载相同初始权重。
    • 关键代码:
      deepspeed.utils.broadcast_parameters(ref_model.state_dict())
      

四、KL散度计算流程

无论采用何种参考模型机制,最终目标都是计算:
D K L ( π θ ∣ ∣ π r e f ) = E x ∼ π θ [ log ⁡ π θ ( x ) − log ⁡ π r e f ( x ) ] D_{KL}(\pi_{\theta} || \pi_{ref}) = \mathbb{E}_{x \sim \pi_{\theta}}[\log \pi_{\theta}(x) - \log \pi_{ref}(x)] DKL(πθ∣∣πref)=Exπθ[logπθ(x)logπref(x)]

  1. 计算步骤

    def compute_kl_divergence(model, ref_model, inputs):
        with torch.no_grad():
            ref_logits = ref_model(**inputs).logits
        current_logits = model(**inputs).logits
        
        kl = F.kl_div(
            F.log_softmax(current_logits, dim=-1),
            F.softmax(ref_logits.detach(), dim=-1),
            reduction='batchmean'
        )
        return kl
    
  2. 各机制下的实现差异

    • 普通微调:直接调用ref_model计算
    • PEFT:在disable_adapter()上下文中用同一模型计算
    • ZeRO-3:使用独立加载的ref_model计算

五、选型建议

微调类型适用场景显存开销计算效率
普通全量微调单卡/多卡非ZeRO环境
PEFT微调低显存设备(如消费级GPU)
DeepSpeed ZeRO-3超大模型训练(如>20B参数)最低较低

典型决策流程:

是否需要训练超大模型(>20B)?
├─ 是 → 采用DeepSpeed ZeRO-3
└─ 否 → 显存是否充足(如A100 80G)?
         ├─ 是 → 普通全量微调
         └─ 否 → 使用PEFT微调

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2328281.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

英菲克(INPHIC)A9无线蓝牙鼠标 链接电脑的方式

英菲克(INPHIC)A9鼠标链接至电脑时,要长按住“模式切换MODE”按钮5秒左右的时间,此时模式指示灯变成蓝色,并且闪烁。 这时使用电脑的蓝牙设置中,“添加设备”,会出现BT4.0 Mouse提示&#xff0…

lua表table和JSON字符串互转

--print("local ssxc{\n"..string.gsub(str,":","").."\n}") Utils {} ---------------------------------------------------------------------------------- -- Lua-Table 与 string 转换 local function value2string(value, isA…

【每日一个知识点】分布式数据湖与实时计算

在现代数据架构中,分布式数据湖(Distributed Data Lake) 结合 实时计算(Real-time Computing) 已成为大数据处理的核心模式。数据湖用于存储海量的结构化和非结构化数据,而实时计算则确保数据能够被迅速处理…

c语言数据结构--------拓扑排序和逆拓扑排序(Kahn算法和DFS算法实现)

#include <stdio.h> #include <string.h> #include <stdbool.h> #include <stdlib.h>//使用卡恩算法(Kahn)和深度优先算法(DFS)实现//拓扑排序和逆拓扑排序//拓扑排序和逆拓扑排序顶点顺序相反//图&#xff0c;邻接矩阵存储 #define MaxVertexNum 100 …

谷粒微服务高级篇学习笔记整理---nginx搭建正反向代理

正向与反向代理 **正向代理:**客户端向代理服务器发请求并指定目标服务器,代理向目标转交请求并将获得的内容转给客户端。 反向代理:用户直接访问反向代理服务器就可以获得目标服务器的资源。反向代理服务器统一了访问入口。 给首页配置反向代理 修改windows的hosts文件配…

2.pycharm保姆级安装教程

一、pycharm安装 1.官网上下载好好软&#xff0c;双击打开 2.下一步 3.修改路径地址 (默认也可以) 4.打勾 5.安装 不用重启电脑 二、添加解释器 1.双击软件&#xff0c;打开 2.projects – new project 3.指定项目名字&#xff0c;项目保存地址&#xff0c;解释器 4.右击 – …

【SQL】取消sql某一列的唯一值key值的方法

在插入数据到sql时&#xff0c;遇到了这个问题&#xff1a; Duplicate entry ‘XXX’ for key 起因是&#xff1a; 我之前设计表的时候&#xff0c;手动给product_title 这个列加了一个key&#xff0c; key 是这个字段的唯一键约束&#xff0c;就不能重复在这一列存入重复的数…

数据库--SQL

SQL&#xff1a;Structured Query Language&#xff0c;结构化查询语言 SQL是用于管理关系型数据库并对其中的数据进行一系列操作&#xff08;包括数据插入、查询、修改删除&#xff09;的一种语言 分类&#xff1a;数据定义语言DDL、数据操纵语言DML、数据控制语言DCL、事务处…

SQL语句(一)—— DDL

目录 一、SQL 基础知识 &#xff08;一&#xff09;SQL 通用语法 &#xff08;二&#xff09;SQL 分类 二、DDL —— 数据库操作 1、查询所有数据库 2、查询当前数据库 3、创建数据库 4、删除数据库 5、切换数据库 三、DDL —— 表操作 &#xff08;一&#xff09;查…

Husky目标跟踪

1.0设备清单 幻影峡谷、适配器 摄像头及数据线、显卡欺骗器 外接屏幕、键盘鼠标 Husky底盘、便携显示屏、键盘鼠标 移动电源 1.1连线 插排——移动电源幻影峡谷——适配器——插排摄像头——幻影峡谷&#xff08;摄像头固定在机械臂前方的底盘上&#xff09;键盘鼠标显示器…

Python----机器学习(线性回归:自求导的方法实现)

一、线性回归方程 目标&#xff1a; 线性回归的目标是找到最佳的系数来使模型与观察到的数据尽可能拟合。 应用&#xff1a; 预测&#xff1a;给定自变量的值&#xff0c;预测因变量的值。 回归分析&#xff1a;确定自变量对因变量的影响程度 线性回归是统计学和机器学习中最简…

Springcache+xxljob实现定时刷新缓存

目录 SpringCache详解 SpringCache概述 核心原理 接口抽象与多态 AOP动态代理 核心注解以及使用 公共属性 cacheNames KeyGenerator&#xff1a;key生成器 key condition&#xff1a;缓存的条件&#xff0c;对入参进行判断 注解 xxl-job详解 SpringcacheRedis实现…

vue2拖拉拽做个模拟公式工具

1. 成图 2. 介绍 就是简单拖拉拽来做个规则运算器&#xff0c;具体运算规则、校验规则自己加。 3. 代码 HTML代码 <template><div class"red-cont"><div class"red-top"><divclass"red-top-left"><div class&quo…

Windows查重工具,强烈推荐大家收藏!

我大家在用电脑的时候&#xff0c;是不是发现用得越久&#xff0c;电脑里的软件和文件就越多&#xff1f; 今天我给大家带来的这两款重复文件查找神器&#xff0c;简直就是电脑里的“清洁小能手”&#xff0c;能帮你把那些重复的文件和文件夹找出来。 Easy DupLicate Finder 重…

使用python完成手写数字识别

入门图像识别的第一个案例,看到好多小伙伴分享,也把自己当初的思路捋捋,写成一篇博客,作为记录和分享,也欢迎各位交流讨论。 实现思路 数据集:MNIST(包含60,000个训练样本和10,000个测试样本) 深度学习框架:Keras(基于TensorFlow) 模型架构:卷积神经网络(CNN) 实…

OpenLayers:如何控制Overlay的层级?

我最近在使用Overlay的时候遇到了一个问题&#xff0c;我向地图中添加了两种不同的Overlay&#xff08;下图中的蓝色标牌和粉色标牌&#xff09;&#xff0c;我希望粉色标牌可以显示在最上层&#xff0c;可偏偏蓝色标牌却将其遮挡住了。于是我对Overlay的层级开始起了兴趣&…

《Golang高性能网络编程:构建低延迟服务器应用》

在本文中&#xff0c;我们将深入探讨Golang高性能网络编程&#xff0c;帮助您构建低延迟服务器应用。我们将介绍Golang的网络编程特性、优化技巧和实际案例&#xff0c;让您更好地理解和应用Golang在网络编程领域的优势。 高性能网络编程简介 什么是Golang高性能网络编程 高性能…

数据结构C语言练习(设计循环队列)

一、循环队列简介 循环队列是一种线性数据结构&#xff0c;基于 FIFO&#xff08;先进先出&#xff09;原则&#xff0c;将队尾连接到队首形成循环。其核心优势是能复用队列之前用过的空间&#xff0c;避免普通队列 “假溢出” 问题。实现时&#xff0c;通常申请 k1 大小的数组…

vscode代码片段的设置与使用

在 Visual Studio Code (VS Code) 中&#xff0c;可以通过自定义**代码片段&#xff08;Snippets&#xff09;**快速插入常用代码模板。以下是详细设置步骤&#xff1a; 步骤 1&#xff1a;打开代码片段设置 按下快捷键 Ctrl Shift P&#xff08;Windows/Linux&#xff09;或…

uniapp -- 列表垂直方向拖拽drag组件

背景 需要在小程序中实现拖拽排序功能,所以就用到了m-drag拖拽组件,在开发的过程中,发现该组件在特殊的场景下会有些问题,并对其进行了拓展。 效果 组件代码 <template><!-- 创建一个垂直滚动视图,类名为m-drag --><scroll