Contents
- Introduction
- Method
- Experiments
- References
Introduction
- CoT 推理可以有效提升 LLM 推理能力,但 few-shot prompting 无法发挥 CoT 的全部潜力,训练能够生成中间推理步骤 (i.e., rationale) 的 LLM 又需要大量人工标注 rationale,为此作者提出 STaR (Self-Taught Reasoner),可以仅利用少量含有 rationale 的标注数据和大量不含 rationale 的标注数据,迭代式地生成大量含有 rationale 的数据集并基于此训练能够生成 rationale 的 LLM,有效提升 LLM 的复杂推理能力
Method
- Rationale Generation Bootstrapping (STaR Without Rationalization). 给定预训练 LLM
M
M
M 和 small prompt set
P
=
{
(
x
i
p
,
r
i
p
,
y
i
p
)
}
i
=
1
P
\mathcal{P}=\{(x_{i}^{p},r_{i}^{p},y_{i}^{p})\}_{i=1}^{P}
P={(xip,rip,yip)}i=1P (e.g.
P
=
10
P = 10
P=10),其中
x
x
x 为问题,
r
r
r 为中间推理步骤,
y
y
y 为问题回答,可以利用 few-shot prompting 为一个更大的数据集
D
=
{
(
x
i
,
y
i
)
}
i
=
1
D
\mathcal D=\{(x_i,y_i)\}_{i=1}^D
D={(xi,yi)}i=1D 生成中间推理步骤
r
^
i
\hat r_i
r^i 和答案
y
^
i
\hat y_i
y^i,这样就得到了含有中间推理步骤的大规模数据集。此外,作者只保留其中
y
^
i
=
y
i
\hat y_i=y_i
y^i=yi 的样本,因为这些样本对应的中间推理步骤质量总体来说会更高一些,由此得到 filtered dataset,在此数据集上微调
M
M
M 得到可以直接生成中间推理步骤的 LLM. 上述步骤为 1 个循环,STaR 会重复上述循环多次,每次都用上一轮循环中得到的最新的生成中间推理步骤的 LLM
M
n
−
1
M_{n-1}
Mn−1 为
D
\mathcal D
D 生成中间推理步骤得到 filtered dataset,然后在该数据集上基于预训练 LLM
M
M
M 重新训练得到新的生成中间推理步骤的 LLM
M
n
M_n
Mn;上述优化过程可以被近似看作 policy gradient,其中
J
(
M
,
X
,
Y
)
J(M,X,Y)
J(M,X,Y) 为 total expected reward across the dataset
- Rationalization. 上述步骤还有一个缺点,就是如果
D
\mathcal D
D 中某些难样本始终无法生成正确答案,那么这些样本将永远无法加入 filtered dataset,无法被有效学习;为此,作者给生成错误答案的样本 prompt 中加入提示正确答案的 hint 来引导模型生成中间推理步骤和最终答案
- STaR.
Experiments
- Symbolic Reasoning: Results on Arithmetic.
- Natural Language Reasoning: Commonsense Question Answering.
- Mathematical Reasoning in Language: Grade School Math.
References
- Zelikman, Eric, et al. “Star: Bootstrapping reasoning with reasoning.” Advances in Neural Information Processing Systems 35 (2022): 15476-15488.