Gorilla
Gorilla出自2023年5月的论文《Gorilla: Large Language Model Connected with Massive APIs》,针对LLM无法准确地生成API调用时的参数,构建API使用数据集后基于Llama微调了一个模型。
数据集构建
API数据集APIBench的构建过程如下:
-
先从HuggingFace、TensorFlow Hub、Torch Hub过滤得到1645个模型API调用,将这些模型的API接口转换成为有如下字段的json对象:{domain, framework, functionality, api_name, api_call, api_arguments, environment_requirements, example_code, performance, and description.}。
-
按照Self-Instruct思路用GPT-4来构建合成指令数据。为三个模型hub各创建6个Instruction-API样例,总共人工构建18个样例。先让GPT-4在不使用任何API名称或相关提示的情况下生成API使用指令。接下来从6个Instruction-API样例中采样3个例子、与API文档作为上下文让LLM针对指令生成api(这一步可用开源模型完成)。 对于数据集中的1645个API,每一个都生成10对Instruction-API数据。
Gorilla训练
将构建的数据集转换成user-agent chat-style conversation格式,即每一条数据集都是一轮用户与agent的对话。接着对LLaMA-7B按标准的指令微调得到Gorilla模型。在训练时,会分为带检索和不带检索器两种,带检索器的训练会在prompt中额外加入检索到的API信息:“Use this API documentation for reference: <retrieved_API_doc_JSON>”。作者认为带检索器的训练有三点好处:a) 使LLM适应在推理时API文档的改变;b)提高in-context learning的性能;c)减少幻觉错误。但是作者的试验表明带检索器的训练并不总是有助于提高模型表现,有时候甚至会有损于模型性能,主要原因在于检索器给出的内容并不总是有用的,所以作者建议如果有好的检索器可用,则使用带检索器的微调,否则不带检索器的微调是更好的选择。
API校验
论文中的API数据集是基于模型hub构建的,所以对于一个请求有多个正确答案,比如对于一个图片分类问题,有很多模型都可以回答此问题。Gorilla作者用构建的数据集来评估API的功能相等性,用AST树匹配策略来判断LLM调用数据集的哪个API。如果API不属于数据集中任何API的子树,说明LLM出现了幻觉。
Gorilla openfunction and openfunction2
在2023年6月OpenAI宣布支持function call功能之后,Gorilla团队陆续开源了其openfunction 和openfunction v2模型,在其两篇博客(1, 2)中介绍了其模型是如何训练的。
gorilla-openfunctions-v0
模型是基于7B的LLaMA-v2-chat模型训练的,gorilla-openfunctions-v1
是基于7B的 LLaMA-v2预训练模型训练的。 gorilla-openfunctions-v2
是基于6.91B的Deepseek-Coder-7B-Instruct-v1.5
训练的。
在gorilla-openfunction-v2的huggingface仓库列出的三个模型的不同区别。
下图示意了function call的各种功能(来自blog)
gorilla-openfunctions-v0
和gorilla-openfunctions-v1
的数据集和在本地部署时的prompt是一样的。其数据集一共14189个instruction-API对,API文档来源有三个:Python包、RapidAPI、云厂商的命令行工具。
def get_prompt(user_query: str, functions: list = []) -> str:
"""
Generates a conversation prompt based on the user's query and a list of functions.
Parameters:
- user_query (str): The user's query.
- functions (list): A list of functions to include in the prompt.
Returns:
- str: The formatted conversation prompt.
"""
if len(functions) == 0:
return f"USER: <<question>> {user_query}\nASSISTANT: "
functions_string = json.dumps(functions)
return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "
gorilla-openfunctions-v2
收集了65283条 question-function-answer对,数据来源有:Python包 (19,353), Java 仓库 (16,586), Javascript 仓库 (4,245), public-API (6,009), 云厂商的命令行工具 (19,090) 。并进行了如下图片示意的数据增强。
gorilla-openfunctions-v2
在本地部署时prompt组织方式如下。
def get_prompt(user_query: str, functions: list = []) -> str:
"""
Generates a conversation prompt based on the user's query and a list of functions.
Parameters:
- user_query (str): The user's query.
- functions (list): A list of functions to include in the prompt.
Returns:
- str: The formatted conversation prompt.
"""
system = "You are an AI programming assistant, utilizing the Gorilla LLM model, developed by Gorilla LLM, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer."
if len(functions) == 0:
return f"{system}\n### Instruction: <<question>> {user_query}\n### Response: "
functions_string = json.dumps(functions)
return f"{system}\n### Instruction: <<function>>{functions_string}\n<<question>>{user_query}\n### Response: "
Berkeley Function-Calling Leaderboard
gorilla团队提供了一个榜单来评估LLM的call function能力,更多介绍可参见其blog。
参考资料
- Gorilla: arxiv, github, website
- Gorilla openfunction模型: gorilla-openfunction-v0 huggingface, gorilla-openfunction-v1 huggingface, gorilla-openfunction-v2 huggingface
- Berkeley Function-Calling Leaderboard, Berkeley Function-Calling Leaderboard 介绍blog