0. 引言
《Toolformer: Language Models Can Teach Themselves to Use Tools》 论文主要探讨了语言模型(LMs)在解决新任务时的能力和局限性,并提出了一个名为 Toolformer 的新方法。该方法通过简单 API 接口将外部工具与 LMs 相结合,实现了 LMs 在自监督学习中的自我训练。实验结果表明,Toolformer 在保留了核心的语言建模能力的基础上,能显著提高零样本下的下游任务表现,且在许多情况下甚至可以与更大规模的模型相媲美。
1. 数据处理
数据处理希望达到以下的结果:
其中,颜色框出的部分是 API 调用:工具+参数+答案。那么这些数据是如何生成的呢?论文中实际上是通过 few-shot 的方式,给出一些例子让其他大模型根据例子去生成满足这些格式的数据样本。
其中,工具和对应的 Prompts:
- 问题回答
Your task is to add calls to a Question
Answering API to a piece of text.
The questions should help you get
information required to complete the
text. You can call the API by writing
"[QA(question)]" where "question" is the
question you want to ask. Here are some
examples of API calls:
Input: Joe Biden was born in Scranton,
Pennsylvania.
Output: Joe Biden was born in [QA("Where
was Joe Biden born?")] Scranton,
[QA("In which state is Scranton?")]
Pennsylvania.
Input: Coca-Cola, or Coke, is a
carbonated soft drink manufactured by
the Coca-Cola Company.
Output: Coca-Cola, or [QA("What other
name is Coca-Cola known by?")] Coke, is
a carbonated soft drink manufactured by
[QA("Who manufactures Coca-Cola?")] the
Coca-Cola Company.
Input: x
Output:
- 计算器
Your task is to add calls to a
Calculator API to a piece of text.
The calls should help you get
information required to complete the
text. You can call the API by writing
"[Calculator(expression)]" where
"expression" is the expression to be
computed. Here are some examples of API
calls:
Input: The number in the next term is 18
+ 12 x 3 = 54.
Output: The number in the next term is
18 + 12 x 3 = [Calculator(18 + 12 * 3)]
54.
Input: The population is 658,893 people.
This is 11.4% of the national average of
5,763,868 people.
Output: The population is 658,893 people.
This is 11.4% of the national average of
[Calculator(658,893 / 11.4%)] 5,763,868
people.
Input: A total of 252 qualifying matches
were played, and 723 goals were scored
(an average of 2.87 per match). This is
three times less than the 2169 goals
last year.
Output: A total of 252 qualifying
matches were played, and 723 goals were
scored (an average of [Calculator(723
/ 252)] 2.87 per match). This is twenty
goals more than the [Calculator(723 -
20)] 703 goals last year.
Input: I went to Paris in 1994 and
stayed there until 2011, so in total,
it was 17 years.
Output: I went to Paris in 1994 and
stayed there until 2011, so in total, it
was [Calculator(2011 - 1994)] 17 years.
Input: From this, we have 4 * 30 minutes
= 120 minutes.
Output: From this, we have 4 * 30
minutes = [Calculator(4 * 30)] 120
minutes.
Input: x
Output:
- 维基百科搜索
Your task is to complete a given piece
of text. You can use a Wikipedia Search
API to look up information. You can do
so by writing "[WikiSearch(term)]" where
"term" is the search term you want to
look up. Here are some examples of API
calls:
Input: The colors on the flag of Ghana
have the following meanings: red is for
the blood of martyrs, green for forests,
and gold for mineral wealth.
Output: The colors on the flag of Ghana
have the following meanings: red is for
[WikiSearch("Ghana flag red meaning")]
the blood of martyrs, green for forests,
and gold for mineral wealth.
Input: But what are the risks during
production of nanomaterials? Some nanomaterials may give rise to various
kinds of lung damage.
Output: But what are the risks
during production of nanomaterials?
[WikiSearch("nanomaterial production
risks")] Some nanomaterials may give
rise to various kinds of lung damage.
Input: Metformin is the first-line drug
for patients with type 2 diabetes and
obesity.
Output: Metformin is the first-line drug
for [WikiSearch("Metformin first-line
drug")] patients with type 2 diabetes
and obesity.
Input: x
Output:
- 机器翻译
Your task is to complete a given piece
of text by using a Machine Translation
API.
You can do so by writing "[MT(text)]"
where text is the text to be translated
into English.
Here are some examples:
Input: He has published one book: O
homem suprimido (“The Supressed Man”)
Output: He has published one book: O
homem suprimido [MT(O homem suprimido)]
(“The Supressed Man”)
Input: In Morris de Jonge’s Jeschuah,
der klassische jüdische Mann, there is a
description of a Jewish writer
Output: In Morris de Jonge’s Jeschuah,
der klassische jüdische Mann [MT(der
klassische jüdische Mann)], there is a
description of a Jewish writer
Input: 南 京 高 淳 县 住 房 和 城 乡 建 设 局 城 市 新
区 设 计 a plane of reference Gaochun is
one of seven districts of the provincial
capital Nanjing
Output: [MT(南京高淳县住房和城乡建设局 城市新
区 设 计)] a plane of reference Gaochun is
one of seven districts of the provincial
capital Nanjing
Input: x
Output:
- 日历
Your task is to add calls to a Calendar
API to a piece of text. The API calls
should help you get information required
to complete the text. You can call the
API by writing "[Calendar()]" Here are
some examples of API calls:
Input: Today is the first Friday of the
year.
Output: Today is the first [Calendar()]
Friday of the year.
Input: The president of the United
States is Joe Biden.
Output: The president of the United
States is [Calendar()] Joe Biden.
Input: The current day of the week is
Wednesday.
Output: The current day of the week is
[Calendar()] Wednesday.
Input: The number of days from now until
Christmas is 30.
Output: The number of days from now
until Christmas is [Calendar()] 30.
Input: The store is never open on the
weekend, so today it is closed.
Output: The store is never open on the
weekend, so today [Calendar()] it is
closed.
Input: x
Output:
2. 整体流程
API 调用表示为:
c
=
(
a
c
,
i
c
)
c = (a_c, i_c)
c=(ac,ic)
其中:
(1) c c c 表示 API 调用;
(2) a c a_c ac 是 API 的名字;
(3) i c i_c ic 是对应的输入。
根据有没有 API 输出,可以分为:
e
(
c
)
=
<
A
P
I
>
a
c
(
i
c
)
<
/
A
P
I
>
e(c) = < API > a_c(i_c) < / API >
e(c)=<API>ac(ic)</API>
e
(
c
,
r
)
=
<
A
P
I
>
a
c
(
i
c
)
→
r
<
/
A
P
I
>
e(c,r) = < API > a_c(i_c) → r < / API >
e(c,r)=<API>ac(ic)→r</API>
其中,“<API>”,“</API>” 和 “→”是特殊的 token,用于标识中间内容是需要调用第三方工具的,实际中分别对应 “[”,、“]” 和 “->”。
2.1 API调用采样
根据前一章节的数据处理,让大模型自动生成多个 API 调用,如下:
对于模型
M
M
M,
z
n
+
1
z_{n+1}
zn+1作为
z
1
,
.
.
.
,
z
n
z_1,...,z_n
z1,...,zn后一个序列的概率:
p
M
(
z
n
+
1
∣
z
1
,
.
.
.
,
z
n
)
p_M(z_{n+1}|z_1,...,z_n)
pM(zn+1∣z1,...,zn)
那么,对于每一个
i
∈
1
,
.
.
.
,
n
i\in{1,...,n}
i∈1,...,n,
i
i
i 插入 API 调用的概率:
p
i
=
p
M
(
<
A
P
I
>
∣
P
(
x
)
,
x
1
:
i
−
1
)
p_i = p_M(< API > | P(\mathbf{x}),x_{1:i-1})
pi=pM(<API>∣P(x),x1:i−1)
当设置一个概率阈值
τ
\tau
τ 后和采样上限
k
k
k 后,则可以获取插入的位置:
I
=
{
i
∣
p
i
>
τ
s
}
I= \{ i|p_i > \tau_s \}
I={i∣pi>τs}
如果满足的插入位置超过
k
k
k 之后,只取概率最大的前
k
k
k 个。
2.2 执行 API 调用
下一步执行所有 API 调用来获得相应的结果。可能涉及到调用另一个神经网络处理、执行 Python 脚本或使用检索系统在大型语料库上执行搜索。
2.3 过滤 API 调用结果
首先定义权重交叉熵 loss:
L
i
(
z
)
=
−
∑
j
=
i
n
ω
j
−
i
⋅
l
o
g
p
M
(
x
j
∣
z
,
x
1
:
j
−
1
)
L_i(\mathbf{z})=-\sum_{j=i}^n\omega_{j-i} \cdot logp_M(x_j|\mathbf{z},x_{1:j-1})
Li(z)=−j=i∑nωj−i⋅logpM(xj∣z,x1:j−1)
现考虑两种 loss:
L
i
+
=
L
i
(
e
(
c
i
,
r
i
)
)
L_i^+=L_i(e(c_i,r_i))
Li+=Li(e(ci,ri))
L
i
−
=
m
i
n
(
L
i
(
ϵ
)
,
L
i
(
e
(
c
i
,
ϵ
)
)
)
L_i^-=min(L_i(\epsilon),L_i(e(c_i,\epsilon)))
Li−=min(Li(ϵ),Li(e(ci,ϵ)))
其中:
(1) ϵ \epsilon ϵ 表示空序列;
(2) L i ( ϵ ) L_i(\epsilon) Li(ϵ) 表示没有插入 API 调用;
(3) L i ( e ( c i , ϵ ) ) L_i(e(c_i,\epsilon)) Li(e(ci,ϵ)) 表示 API 调用输出为空;
(4) L i ( e ( c i , r i ) ) L_i(e(c_i,r_i)) Li(e(ci,ri)) 表示存在API 调用输出。
给出一个过滤的阈值
τ
f
\tau_f
τf,当满足:
L
i
−
−
L
i
+
≥
τ
f
L_i^- - L_i^+ \ge \tau_f
Li−−Li+≥τf
则保留当前的 API 调用,也就是说只有当加入 API 调用后有返回值并且 loss 相比之前降低到到一定的程度,才认为这个 API 调用是有效果的。
例如,下表列出在一定的
τ
f
\tau_f
τf 阈值下的 API 调用是否保留:
2.4 模型微调
过滤完 API 调用之后,构建出新的序列:
x
∗
=
x
1
:
i
−
1
,
e
(
c
i
,
r
i
)
,
x
i
:
n
\mathbf{x}^*=x_{1:i-1},e(c_i,r_i),x_{i:n}
x∗=x1:i−1,e(ci,ri),xi:n
采用这种数据集并使用标准的语言模型去微调。
2.5 推理
进行微调后的模型生成文本时,当生成 “→” 的 token 时,表明它接下来期望有一个API 调用的响应。这时中断解码过程,调用适当的 API 来返回响应。在插入响应和 token 后继续解码过程。
3. 三方工具
调用的三方工具上面介绍过,包含:
- 问题回答
- 计算器
- 维基百科搜索
- 机器翻译
- 日历
工具输入输出示例:
4. 实验结论
4.2 实验设置
数据: 使用 CCNet 的一个子集作为模型训练数据集;
微调: 使用128的批量大小和1×10^-5的学习率对模型进行微调;
对比模型:
- GPT-J:一个没有任何微调的常规 GPT-J 模型。
- GPT-J+CC:在没有 API 调用的 CCNet 子集上微调 GPT-J
- Toolformer:在具有 API 调用数据集上微调 GPT-J
- Toolformer(disabled)禁用 API 调用的 Toolformer 模型。
4.3 下游任务
评估模型在多种下游任务上的表现,包括 LAMA、数学数据集、问答数据集、多语言问答和时态数据集。使用零样本设置,即模型在没有特定任务的上下文示例的情况下接受任务指令。
Toolformer 在这些任务上的表现显著优于 GPT-J 模型,并且在允许 API 调用时性能得到进一步提升,并且在语言建模能力上没有明显损失。
5. 参考
[1] https://arxiv.org/pdf/2302.04761
欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;
欢迎关注知乎/CSDN:SmallerFL
也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤