论文题目:STEP-DPO: STEP-WISE PREFERENCE OPTIMIZATION FOR LONG-CHAIN REASONING OF LLMS
翻译为中文就是:“LLMs长链推理的逐步偏好优化”
论文由港中文贾佳亚团队推出,基于推理步骤的大模型优化策略,能够像老师教学生一样优化大模型。
以Qwen2-72B-Instruct模型作为基础模型进行微调优化后,其数学成绩超越了GPT-4、Gemini1.5-Pro、Claude3-Opus等闭源模型。
论文链接:https://arxiv.org/pdf/2406.18629
代码仓库:https://github.com/dvlab-research/Step-DPO
1. 介绍
大语言模型(LLMs)在数学推理上具有重大挑战,这是由于数学需要精确的推理链。然而,直接偏好优化(DPO)对长链数学推理的益处有限,因为采用DPO的模型难以识别错误答案中的详细错误。
所以作者提出了Step-DPO方法,它将整个答案划分多个步骤作答(Step1, Step2, Step3, ...),大大提高的模型的精度。
在MATH数据集上,在Qwen2-7B-Instruct上准确率从53.0% 提升到58.6%,GSM8K数据集,准确率从85.5%提升到87.9% 。使用 Qwen2-72B-Instruct模型,在MATH和GSM8K上分别取得 70.8% 和 94.0%的准确率。
1.1 像教育学生一样训练大模型
数学推理被认为是大语言模型(LLMs)中一种关键的长链推理能力。由于需要广泛的思维链(CoT),这项任务尤其具有挑战性,其中可能包括许多推理步骤,这些步骤中的任何错误都可能导致最终得不到正确答案。
(1)首先,最常用的方法就是监督微调(SFT),使用各种数据增强对齐来微调模型。然而,当SFT数据达到一定数量时,模型经常出现幻觉,性能也随之趋于饱和。一个潜在的原因是,随着首选输出的概率增加,不理想输出的概率也会增加。这种现象使得模型在长链推理中更容易出错。
(2)最近,直接偏好优化(DPO)(Rafailov et al., 2024)被提出用于使用偏好对数据进行对齐(每个偏好对都包含一个输入提示、偏好输出及非偏好输出),因其简单性而广受欢迎。尽管DPO在Chat聊天任务中很有效,但它对长链(long-chain)数学任务效果不明显。如下图2所示。
(3)于是作者提出了Step-DPO,基于推理步骤的直接偏好优化。Step-DPO 逐步检查每个步骤的答案是否正确,这使得模型能够轻松定位错误Step,以进行有效的优化,显著增强了长链推理。
2. STEP-DPO 公式
2.1 DPO
我们先看到DPO的优化目标函数:
其中, 是输入提示 , 分别表示正确的回答、错误的回答, 是偏好对数据集。 表示 sigmoid 函数, 与 分别表示当前要优化的微调模型 以及训练过程中保存不变的参照模型, 是一个超参数用来控制距离。
2.2 Step-DPO
我们再看到Step-DPO,它不再像DPO从整体对比答案,而是将每个推理步骤视为一个基本单元,对比单个推理步骤,更精细地提升模型的推理能力。目标优化公式:
回答 可以分解为多个步骤 , 表示输入提示。Step-DPO 优化目标就是最大化正确的下一个推理步骤 的概率,最小化错误步骤 的概率,如图3所示。
3. 分布式数据构建
Step-DPO 的训练数据集是怎样的呢?每个数据样本中应该包含下面4项:
1)prompt ;
2)初始推理步骤 ;
3)首选推理步骤 ;
4)不需要(错误)的推理步骤
如下图所示:
(1)错误收集
首先,我们收集数学问题问答的数据集 ,x 是数学问题, 是真实答案。
然后,使用初始(参照)模型 来得到每个数学问题 x 的答案。
在进行模型推理之前,添加思维链(CoT)前缀作为提示,比如:“Let‘s think step by step. Step 1:”,以确保模型的推理结果被结构化为多个推理步骤。
模型推理完成之后可得到每个数学问题x的推理结果y,然后选择与真实答案 不一致的那些结果,汇总得到数据集 :
(2)错误步骤定位
假设每个错误的推理结果都被明确地表示为 推理步骤序列 ,随后需要人工或利用GPT-4验证每个推理步骤的正确性,直到找到第一个错误步骤 ,选择 作为错误的推理步骤 。这样得到一个包含错误步骤的数据集 :
(3)步骤修正
为了获得 中每个样本的相应正确推理步骤,需要通过用 提示x 和前面的正确推理步骤 通过模型 来采样多个输出 ,该过程被表述为:
随后,保留那些最终答案与实际情况相匹配的输出。我们选择 中的第一个推理步骤作为 ,从而得到最终的数据集D:
数据样本示例如 Figure 5 所示。
4. 实验结果
Step-DPO 可以在SFT模型或现有的开源 Instruct 模型上进行微调,仅通过 10K 数据以及数百个训练步数,即可去得大幅度数学能力提升。
其中 Qwen2-72B-Instruct + Step-DPO 取得了 70.8% 和 94.0% 准确率在 MATH 和 GSM8K 数据集上。
在难度较高的包含数学竞赛题 Odyssey-MATH 榜单上也有显著提升。
突出了 Step-DPO 强大泛化能力,模型更加鲁棒,减少幻觉的产生。
如下三个例子:
1. 假设h(x)=f-1(x),如果h(2)=10,h(10)=1,h(1)=2,求f(f(10))
2. t的平方根大于2且小于3.5,满足这一条件的整数t有多少个?
下面比较难的数学竞赛题也能做对
3. 在所有非增函数f:{1,2,…,10}→{1,2,…,10}中,有些函数有固定点,另一些没有,这两种函数的数量相差多少?
参考:
https://github.com/dvlab-research/Step-DPO
贾佳亚团队新作:10k数据让大模型数学能力超GPT-4