人工智能咨询培训老师叶梓 转载标明出处
尽管大模型在数学、科学或编程任务上表现出优异的推理精细化能力,但它们在没有外部反馈的情况下,很难识别何时何地需要精细化。为了解决这一问题,来自Meta的FAIR团队和佐治亚理工学院的研究者们提出了一种新的方法——全球和局部精细化(Global and Local Refinements),旨在提升LLMs在没有外部反馈的情况下的自我精细化能力。
该方法通过分解问题——决定何时精细化、在哪里精细化以及如何精细化——来利用奖励模型(RMs)。具体来说研究者们首先使用结果基础奖励模型(Outcome-based Reward Models, ORMs)来判断最终答案的正确性,以决定何时需要精细化。然后,提出了一种新的步骤精细化ORM(Stepwise ORMs, SORMs),专门用来评估中间步骤的正确性,从而更准确地检测出不正确的推理步骤。
方法
改善大模型(LLMs)在推理任务中的自我精细化能力过程被分解为三个主要阶段:
-
判断草稿的正确性:研究者们首先使用ORM来预测草稿D是否正确。ORM被训练来预测草稿最终答案的正确性概率,这有助于确定何时需要对答案进行精细化。
-
定位错误的起始步骤:接下来,研究者们需要确定从哪个步骤开始精细化,即识别出第一个错误的步骤。为此,他们提出使用(S)ORM来定位错误,这使得精细化任务变得更加具体和有针对性。
-
修正初始草稿:最后研究者们需要学习如何修正初始草稿中的错误。这需要精细化模型来决定如何修正错误并继续进行。
为了在步骤级别识别错误,研究者们利用ORM的中间预测。ORM在这一步Si的预测可以被解释为策略π的价值函数,它预测具有前缀Pi的解决方案最终答案的正确性概率。这个预测依赖于数据生成策略π,因此ORM的预测可以被看作是在预测π的价值函数。
ORM预测的概率可以写成,这里Pi=(S1,...,Si)是所有先前步骤的前缀,而is_correct(A)是采样自π的完整解决方案A具有正确最终答案的事件。这可以被理解为预测策略π在状态S(由问题Q和中间步骤Sj组成)的平均回报。
尽管ORM可以近似作为价值函数,但它在识别错误步骤时可能会过于悲观,因为它预期数据生成策略π会失败。例如,如果π在涉及除法的问题上总是失败,ORM会在学生采取第一步之前就对这个问题赋予低成功率。
为了解决这个问题,研究者们提出了学习步骤精细化ORM(SORM)。SORM旨在直接近似推理任务的最优价值函数V*,这个函数对应于能够从任何逻辑上有效的中间状态Sj成功解决推理任务的最优策略。SORM通过拒绝采样学生策略π*来近似最优策略,这涉及到在模型生成的解决方案中对每个步骤Si进行采样,并检查最终答案的正确性。
研究者们设计了一个三步训练管道,用于合成精细化数据,并对ORM和SORM进行训练,如图1所示。首先,研究者们通过RL微调基础模型以产生一个强大的学生策略π。然后,他们通过在训练数据上采样π来生成ORM/SORM训练数据。最后,他们通过将不正确的rollout与正确的rollout配对来生成全局和局部的精细化数据。例如,SORM训练样本包括步骤前缀(S1, ..., Si),前缀的二元正确性标签li,以及从Pi验证正确性的验证rollout集合T1, ..., TK。
第1步:微调学生模型
首先,研究者们使用专家迭代(Expert Iteration, EI)方法微调模型,以产生基础检查点。这一步骤涉及对每个问题采样学生模型K=96次,并过滤掉最终答案不正确的样本。然后对剩余样本进行去重,构建新的微调数据集R1。接着,将R1与任何可用的监督微调(Supervised Fine-Tuning, SFT)数据结合,生成D1,并再次微调预训练模型。这个过程一直重复,直到连续微调的maj@1得分收敛。微调数据集Di是第i步生成的rollouts与之前生成的训练数据的联合(D0为空集或SFT)。
第2步:训练ORM/SORM
ORM训练数据是通过采样RL微调后的学生策略πK次来生成的。每个中间步骤Si被标记为正确或不正确,这取决于最终答案的正确性。
对于SORM,研究者们通过在模型生成的解决方案中采样近似最优策略π*,并检查最终答案的正确性来生成训练数据。这一过程使用拒绝采样方法,即从步骤Si开始,采样学生策略π进行K次rollout,生成验证轨迹T1, ..., TK,并根据正确的最终答案标记li。如果从Si开始可以找到正确的最终答案,则将Si标记为正面样本。实践中,每步采样K=8次rollout,每次生成最多300个token。否则,将Si标记为负面样本。然后,SORM的训练方式与ORM相同,预测解决方案中每个步骤后的正确标签。
为了提高通过拒绝采样近似最优策略的准确性,研究者们采取了几个后处理步骤:
-
如果步骤Si有正面标签li,则将之前所有步骤的标签都设置为正面。
-
强制执行验证rollout的一致性约束,确保每个中间结果在后续步骤中被使用。
-
平衡训练数据中每个前缀长度的正面和负面标签数量,避免模型倾向于在解决方案的开始预测正面标签,在结束时预测负面标签。
第3步:训练精细化模型
为了训练局部精细化模型,需要一个形式为(Q, AD, AR, E)的数据集,其中Q是问题,AD是初始草稿,E标记了AD中第一个错误的位置,指示在哪里进行精细化,AR是带有正确最终答案的精细化结果。在实践中,E通过在草稿中错误步骤Si前添加"[BAD]"标记来传达给局部精细化模型。在测试时,需要一个模型预测p(E|Q, AD)来定位草稿中的错误。由于SORM被明确训练来预测AD中每个步骤的正确性,因此可以通过在所有步骤上推断SORM并返回第一个预测正确性低于阈值T的步骤的索引来生成E。
图2展示了数学文字问题上的局部和全局精细化示例。左侧示例中,局部精细化在处理学生难以执行的分数除法时表现不佳。尽管所有先前步骤都是有效的,但局部精细化模型被迫再次尝试困难的操作或完全选择错误的操作。相比之下,全局精细化模型可能尝试用全新的方法解决问题。右侧示例中,模型接近最终答案,只是在最后做了一个简单的错误,局部精细化能够修正这个简单错误,而全局精细化则需要从头开始。
全局精细化数据集的构建与ORM训练数据集类似,通过将不正确的rollout与正确的rollout配对来完成。为了保持与局部精细化相似的格式,在不正确的rollout的开头放置一个[BAD]标记。然后将局部和全局精细化数据集合并,训练一个模型,使其能够进行全局和局部精细化。
评估
研究者们构建了一个测试集,用于评估ORM和SORM以及精细化模型的性能。这个测试集是通过在任务τ的测试问题Q上贪婪采样学生模型来生成的,得到的是问题和初始草稿的组合,即(Q, AD),其中Q是问题,AD是初步解答。对于两个基准测试,这种测试集都被称为(Q, D)测试集。
为了生成中间步骤的标签,研究者们采用了与SORM训练数据相同的生成过程。他们通过比较ORM和SORM的预测与这些真实标签,来评估这些模型在测试集上的表现。在评估全局精细化性能时,研究者们贪婪地推断每个(Q, AD)样本上的精细化模型,并比较得到的全局精细化AGR与真实情况。而在评估局部精细化模型时,他们首先使用ORM或SORM标注每个(Q, AD)对中第一个错误的地点,形成一个(Q, AD, E)三元组,然后利用这个三元组来贪婪采样局部精细化模型。
为了获得最佳结果,研究者们提出对每个草稿AD同时采样一个全局精细化AGR和一个局部精细化ALR,并使用ORM作为重新排名器来选择最佳解决方案。这种策略基于一个观察:全局和局部精细化各自解决了学生最初失败的问题的一个互补的、部分不重叠的子集。因此,将两种精细化与草稿结合起来,显著扩展了研究者们能够解决的问题集合。另外使用ORM对精细化进行重新排名,可以与学生模型π生成草稿的“三选一”基线进行更清晰的比较。图3提供了评估流程的图解。
基于过程的局部精细化依赖于在解决方案的轨迹中定位推理错误,这种方法的缺点是它对学生模型进行精细化的能力是不可知的。作为替代,研究者们考虑了基于价值的精细化,它依赖于反馈来识别解决方案中模型最有可能成功的步骤。
研究者们在GSM8K和SVAMP数学文字问题基准上评估了他们的精细化流程。他们微调了Llama-2 7B和13B模型,以产生包括ORM、SORM和精细化模型在内的所有下游模型。每个模型尺寸的评估都是独立的,不使用不同尺寸模型的数据或反馈。通过贪婪采样得到的maj@1模型得分用于评估模型性能。
SORM在评估中间答案方面优于ORM:在GSM8K上,SORM将ORM的中间步骤准确率提高了最多8%,从73%提高到81%(见表1)。这证实了ORM在估计中间步骤正确性方面做得相当不错,但仍有改进空间,特别是在像GSM8K这样困难的任务上对较小模型进行评估时。这种标签准确率的差异也转化为精细化最终准确度的差异,这对于ORM/SORM可靠地识别错误位置至关重要。
表1显示了7B/13B ORM和SORM在测试集标签上的步骤级准确率。测试集的正样本标签占比为45%-55%,SORM在更难的GSM8K基准上比ORM具有更好的步骤级准确率,但在SVAMP上的步骤级准确率相当。
表2显示了7B/13B ORM和SORM在测试集标签上的最终答案准确率。ORM在预测最终答案正确性方面比SORM更准确。
研究者们使用ORM作为重排器来决定何时接受精细化AR。在执行局部精细化时,他们还可以使用ORM和SORM来识别AD中的第一个错误位置。对于ORM,通过标记第一个步骤Si,使得ORM(Si) ≤ T = 0.5,其中T是阈值超参数。SORM以类似的方式识别第一个错误。
在两个基准测试GSM8K和SVAMP上,全局和局部精细化模型都在努力了解何时进行精细化。全局和局部精细化在整体模型准确率上都显示出很少的改进。GSM8K 7B全局精细化甚至降低了整体准确率,其他模型的改进最多为1%。局部精细化更有可能提高整体准确率,这可能是由于存在“[BAD]”标记,指示第一个错误的位置(和存在)。
表3显示了在不正确的模型答案上的精细化准确率。使用SORM进行局部精细化可以修复模型之前失败的高达41%的问题。
图5显示了GSM8K和SVAMP上的精细化准确率。所有精细化模型都在努力识别不需要精细化的正确草稿。仅在精细化不正确草稿时,才看到显著的改进。
全局和局部精细化解决部分不相交的、互补的问题集:为了更好地理解全局和局部精细化的比较,研究者们检查了它们正确解决的问题的重叠部分。表3的最后两行显示,当结合使用时,全局和局部精细化可以修复来自13B学生13B的41%不正确的GSM8K草稿。单独使用全局精细化和使用SORM的局部精细化只能解决28%的问题。然而,当对同一问题采取两种类型的精细化中的最佳方案时,我们在所有基准和模型尺寸的组合上显著提高了性能。这表明局部精细化能够解决大量全局精细化无法解决的问题,反之亦然。
研究者们可以使用ORM作为重排器,以选择全局和局部精细化之间的适当选项。此外,他们还可以考虑将初始草稿作为第三种可能的选项,以决定是否要进行精细化。图6显示了对每个问题的草稿、全局和局部精细化进行重排的结果。由于我们实际上是在进行三次采样,我们包括了EI学生的最佳三个(Bo3)样本作为基线。我们还报告了如果我们有一个完美的重排器,始终能够选择正确解决方案的总体准确率。
重排草稿+精细化比草稿准确率平均提高了8%。在GSM8K上,使用ORM对精细化进行重排比Bo3基线提高了最多9%,使用完美的重排器可以提高多达13%。在SVAMP上,重排的Bo3是一个更具竞争力的基线,本身比草稿准确率有了很大的提高。当使用神谕重排器时,可以看到更大的改进,13B精细化器在GSM8K上比Bo3提高了11%。
实验表明,结合使用全局和局部精细化,并将ORM作为重新排序器,显著优于单独使用任一策略,或者最佳的三个样本基线。
论文链接:https://arxiv.org/abs/2402.10963