AI大模型应用(3)开源框架Vanna: 利用RAG方法做Text2SQL任务
-
RAG(Retrieval-Augmented Generation,如下图所示)检索增强生成
,即大模型LLM在回答问题时,会先从大量的文档中检索出相关信息,然后基于这些检索出的信息进行回答或生成文本,从而可以提高回答的质量,而不是任由LLM来发挥。 -
随着大模型能力逐步强大、场景越来越丰富,从Text到sql或者从Chat到sql的方案也十分火热。
- 这个方案主要是利用大模型将自然语言转化为可以执行的Sql语句,进行数据分析,并根据结果实现报告生成或者可视化展示。
- 下面给出几个开源框架示例:
- Chat2db:https://github.com/chat2db/chat2db
- DB-GPT: https://github.com/eosphoros-ai/DB-GPT
- SQL Chat: https://github.com/sqlchat/sqlchat
- Vanna: https://github.com/vanna-ai/vanna
-
今天,我们了解下基于检索增强(RAG)的sql生成框架:Vanna。
-
GitHub:https://github.com/vanna-ai/vanna
-
文档:https://vanna.ai/docs/
-
1 Vanna概述
1.1 Vanna的工作原理
-
vanna是基于检索增强(RAG)的sql生成框架,具体的执行逻辑如下:
- 先用向量数据库将待查询数据库的建表语句、文档、常用SQL及其自然语言查询问题存储起来。
- 用户发起查询请求时,会先从向量数据库中检索出相关的建表语句、文档、SQL问答对放入到prompt里(
DDL和文档作为上下文、SQL问答对作为few-shot样例
) LLM根据prompt生成查询SQL并执行
,框架会进一步将查询结果使用plotly可视化出来或用LLM生成后续问题。如果用户反馈LLM生成的结果是正确的,可以将这一问答对存储到向量数据库,可以使得以后的生成结果更准确
。
-
Vanna 的工作过程可以概括为:
-
在用户的数据上
训练 RAG模型
,然后提出问题,这些问题将返回 SQL 查询,这些查询可以设置为在用户的数据库上自动运行。 -
这里的
"训练"是指:根据数据结构构建向量库
。- 用户可以使用 DDL 语句、文档或样例 SQL 查询对 Vanna 进行训练,让它掌握数据库的结构、业务术语和查询模式。
- Vanna 会将训练数据转化为向量嵌入,存储在向量数据库中,并建立元数据索引,以便于后续检索。
-
-
Vanna的优缺点:
- Vanna最大的优点就是:
用户可以选择在成功执行的查询上“自动训练”,或让界面提示用户对结果提供反馈,使未来的结果更加准确
。 - 缺点也很明显:
- 生成的 SQL 查询可能不完全准确,需要人工干预来修正。
- 复杂查询生成能力有限,这也是Text2SQL场景的挑战了,尤其是涉及到多表查询。
- Vanna最大的优点就是:
1.2 Vanna的快速上手
1.2.1 利用OPENAI提供的API_KEY以及ChromaDB向量数据库
"""
Vanna的源码安装:
github地址:https://github.com/vanna-ai/vanna
拉取vanna的源码,使用下面命令安装pyproject.toml中定义的依赖项:
pip install .
我这里使用国内的代理,使用代理网站提供的url和key,使用起来和原生的区别不大。
代理网站如下:
https://api.zetatechs.com/
"""
import os
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.flask import VannaFlaskApp
api_url = str(os.getenv('OPENAI_URL'))
api_key = str(os.getenv('OPENAI_API_KEY'))
class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)
"""
我这里由于使用的代理网站提供的url和api_key,需要在vanna.openai.openai_chat.py中修改:
修改前:
if "api_key" in config:
self.client = OpenAI(api_key=config["api_key"])
修改后:
if "api_key" in config:
self.client = OpenAI(api_key=config["api_key"], base_url=config['base_url'])
"""
vn = MyVanna(config={'api_key': api_key, 'model': 'gpt-3.5-turbo', 'base_url': api_url})
# 链接本地的Mysql数据库
vn.connect_to_mysql(host='localhost', dbname='test', user='root', password='root', port=3306)
# 训练Vanna,构建知识库
vn.train(ddl="""
CREATE TABLE `goods` (
`id` int(10) unsigned NOT NULL AUTO_INCREMENT,
`name` varchar(150) NOT NULL,
`cate_name` varchar(40) NOT NULL,
`brand_name` varchar(40) NOT NULL,
`price` decimal(10,3) NOT NULL DEFAULT '0.000',
`is_show` bit(1) NOT NULL DEFAULT b'1',
`is_saleoff` bit(1) NOT NULL DEFAULT b'0',
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=22 DEFAULT CHARSET=utf8;
""")
vn.train(
documentation="""
goods表中的字段cate_name为电脑类型,包括:笔记本、游戏本、超极本、平板电脑、台式机、服务器/工作站、笔记本配件
goods表中的字段brand_name为品牌名字,包括:华硕、联想、索尼、戴尔、苹果等
goods表中的字段name为电子产品具体型号,例如:ipad air 9.7英寸平板电脑
"""
)
vn.train(question="华硕品牌的笔记本的平均价格是多少?"
, sql="SELECT AVG(price) AS avg_price FROM goods WHERE brand_name = '华硕' AND cate_name = '笔记本';")
# 访问地址: http://localhost:8084
VannaFlaskApp(vn).run()
在TrainingData中就会出现代码中训练数据:
我们就可以利用自然语言进行查询了:
如果生成的SQL是正确的,我们就可以点击下面的按钮,将此条Question-SQL pair
添加到知识库中:
右上角的Open Debugger
中可以看到如下信息:
- 包括生成SQL的提示词,大模型的回复,以及提取到最终执行的SQL
1.2.2 利用fastchat部署本地模型
"""
# fastchat的安装
pip install "fschat[model_worker,webui]"
# 1、启动controller
python -m fastchat.serve.controller --host 0.0.0.0 --port 21001
# 2、启动worker(这里采用小模型,可以CPU运行)
python -m fastchat.serve.model_worker --model-names qwen2-1.5b --model-path D:\python\models\qwen\Qwen2-1.5B-Instruct
--host 0.0.0.0 --device cpu
# 3、启动openai_api_server
# 兼容OpenAI的RESTful API
python -m fastchat.serve.openai_api_server --controller-address http://127.0.0.1:21001 --host 0.0.0.0 --port 48000
# 4、测试OpenAI的RESTful API
```python
from openai import OpenAI
api_key = 'none'
api_url = 'http://127.0.0.1:48000/v1'
client = OpenAI(base_url=api_url, api_key=api_key)
completion = client.chat.completions.create(
model="qwen2-1.5b",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "你是谁?"}
]
)
print(completion)
```
"""
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.flask import VannaFlaskApp
api_url = 'http://127.0.0.1:48000/v1'
api_key = 'none'
class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)
vn = MyVanna(config={'api_key': api_key, 'model': 'qwen2-1.5b', 'base_url': api_url})
# 链接数据库
vn.connect_to_mysql(host='localhost', dbname='test', user='root', password='root', port=3306)
# 训练Vanna,构建知识库
vn.train(ddl="""
CREATE TABLE `goods` (
`id` int(10) unsigned NOT NULL AUTO_INCREMENT,
`name` varchar(150) NOT NULL,
`cate_name` varchar(40) NOT NULL,
`brand_name` varchar(40) NOT NULL,
`price` decimal(10,3) NOT NULL DEFAULT '0.000',
`is_show` bit(1) NOT NULL DEFAULT b'1',
`is_saleoff` bit(1) NOT NULL DEFAULT b'0',
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=22 DEFAULT CHARSET=utf8;
""")
vn.train(
documentation="""
goods表中的字段cate_name为电脑类型,包括:笔记本、游戏本、超极本、平板电脑、台式机、服务器/工作站、笔记本配件
goods表中的字段brand_name为品牌名字,包括:华硕、联想、索尼、戴尔、苹果等
goods表中的字段name为电子产品具体型号,例如:ipad air 9.7英寸平板电脑
"""
)
vn.train(question="华硕品牌的笔记本的平均价格是多少?"
, sql="SELECT AVG(price) AS avg_price FROM goods WHERE brand_name = '华硕' AND cate_name = '笔记本';")
# 访问地址: http://localhost:8084
VannaFlaskApp(vn).run()
2 Vanna源码分析
我们可以利用下面代码,查看源码:
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
api_url = 'http://127.0.0.1:48000/v1'
api_key = 'none'
class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)
vn = MyVanna(config={'api_key': api_key, 'model': 'qwen2-1.5b', 'base_url': api_url})
# 链接数据库
vn.connect_to_mysql(host='localhost', dbname='test', user='root', password='root', port=3306)
# 训练Vanna,构建知识库
vn.train(ddl="""
CREATE TABLE `goods` (
`id` int(10) unsigned NOT NULL AUTO_INCREMENT,
`name` varchar(150) NOT NULL,
`cate_name` varchar(40) NOT NULL,
`brand_name` varchar(40) NOT NULL,
`price` decimal(10,3) NOT NULL DEFAULT '0.000',
`is_show` bit(1) NOT NULL DEFAULT b'1',
`is_saleoff` bit(1) NOT NULL DEFAULT b'0',
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=22 DEFAULT CHARSET=utf8;
""")
vn.train(
documentation="""
goods表中的字段cate_name为电脑类型,包括:笔记本、游戏本、超极本、平板电脑、台式机、服务器/工作站、笔记本配件
goods表中的字段brand_name为品牌名字,包括:华硕、联想、索尼、戴尔、苹果等
goods表中的字段name为电子产品具体型号,例如:ipad air 9.7英寸平板电脑
"""
)
vn.train(question="华硕品牌的笔记本的平均价格是多少?"
, sql="SELECT AVG(price) AS avg_price FROM goods WHERE brand_name = '华硕' AND cate_name = '笔记本';")
# 用户提问
sql, df, fig = vn.ask("华硕品牌的笔记本的最低价格、最高价格分别是多少?")
print('=======================================')
print('the final sql = \n', sql)
print('the final df = \n', df)
Vanna的核心代码就是src.vanna.base.base.py
文件下的train函数和ask函数
2.1 train函数
- train函数就是根据documentation、sql以及ddl构建知识库
# src.vanna.base.base.py
def train(
self,
question: str = None,
sql: str = None,
ddl: str = None,
documentation: str = None,
plan: TrainingPlan = None,
) -> str:
if question and not sql:
raise ValidationError("Please also provide a SQL query")
if documentation:
print("Adding documentation....")
return self.add_documentation(documentation)
if sql:
if question is None:
question = self.generate_question(sql)
print("Question generated with sql:", question, "\nAdding SQL...")
return self.add_question_sql(question=question, sql=sql)
if ddl:
print("Adding ddl:", ddl)
return self.add_ddl(ddl)
......
- 函数中add_documentation、add_ddl以及add_question_sql均为抽象函数,不同的向量数据库有不同的实现方式
- 比如,这里使用的chromadb的实现方式如下:
# src.vanna.chromadb.chromadb_vector.py
class ChromaDB_VectorStore(VannaBase):
def __init__(self, config=None):
VannaBase.__init__(self, config=config)
if config is None:
config = {}
......
# 创建三个集合,分别存储:文档、DDL、以及sql 三种经过embedding的知识库信息
self.documentation_collection = self.chroma_client.get_or_create_collection(
name="documentation",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)
self.ddl_collection = self.chroma_client.get_or_create_collection(
name="ddl",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)
self.sql_collection = self.chroma_client.get_or_create_collection(
name="sql",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)
def generate_embedding(self, data: str, **kwargs) -> List[float]:
embedding = self.embedding_function([data])
if len(embedding) == 1:
return embedding[0]
return embedding
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
question_sql_json = json.dumps(
{
"question": question,
"sql": sql,
},
ensure_ascii=False,
)
id = deterministic_uuid(question_sql_json) + "-sql"
self.sql_collection.add(
documents=question_sql_json,
embeddings=self.generate_embedding(question_sql_json),
ids=id,
)
return id
def add_ddl(self, ddl: str, **kwargs) -> str:
id = deterministic_uuid(ddl) + "-ddl"
self.ddl_collection.add(
documents=ddl,
embeddings=self.generate_embedding(ddl),
ids=id,
)
return id
def add_documentation(self, documentation: str, **kwargs) -> str:
id = deterministic_uuid(documentation) + "-doc"
self.documentation_collection.add(
documents=documentation,
embeddings=self.generate_embedding(documentation),
ids=id,
)
return id
......
2.2 ask函数
def ask(
self,
question: Union[str, None] = None,
print_results: bool = True,
auto_train: bool = True,
visualize: bool = True, # if False, will not generate plotly code
allow_llm_to_see_data: bool = False,
) -> Union[
Tuple[
Union[str, None],
Union[pd.DataFrame, None],
Union[plotly.graph_objs.Figure, None],
],
None,
]:
"""
**Example:**
```python
vn.ask("What are the top 10 customers by sales?")
```
Ask Vanna.AI a question and get the SQL query that answers it.
Args:
question (str): The question to ask.
提出的问题
print_results (bool): Whether to print the results of the SQL query.
是否打印SQL查询的结果, 默认为True
auto_train (bool): Whether to automatically train Vanna.AI on the question and SQL query.
是否自动使用问题和SQL查询对Vanna.AI进行训练, 默认为True
visualize (bool): Whether to generate plotly code and display the plotly figure.
是否生成plotly代码并显示plotly图表, 默认为True
Returns:
Tuple[str, pd.DataFrame, plotly.graph_objs.Figure]:
The SQL query, the results of the SQL query, and the plotly figure.
包含SQL查询语句、SQL查询的结果(以pandas DataFrame形式)以及plotly图表对象的元组
"""
if question is None:
question = input("Enter a question: ")
try:
# 1、根据用户的question产生SQL
sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
except Exception as e:
print(e)
return None, None, None
......
try:
# 2、在相应的数据库中执行SQL语句,获取结果
# 这里支持不同的数据库的查询语句,如:sqlite、mysql、oracle、hive、clickhouse等
df = self.run_sql(sql)
if print_results:
try:
display = __import__(
"IPython.display", fromList=["display"]
).display
display(df)
except Exception as e:
print(df)
if len(df) > 0 and auto_train:
self.add_question_sql(question=question, sql=sql)
# 3、对查询的结果进行可视化
# Only generate plotly code if visualize is True
if visualize:
try:
plotly_code = self.generate_plotly_code(
question=question,
sql=sql,
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
)
fig = self.get_plotly_figure(plotly_code=plotly_code, df=df)
if print_results:
try:
display = __import__(
"IPython.display", fromlist=["display"]
).display
Image = __import__(
"IPython.display", fromlist=["Image"]
).Image
img_bytes = fig.to_image(format="png", scale=2)
display(Image(img_bytes))
except Exception as e:
fig.show()
except Exception as e:
# Print stack trace
traceback.print_exc()
print("Couldn't run plotly code: ", e)
if print_results:
return None
else:
return sql, df, None
else:
return sql, df, None
except Exception as e:
print("Couldn't run sql: ", e)
if print_results:
return None
else:
return sql, None, None
return sql, df, fig
- 用户发起查询请求时,会先从向量数据库中检索出相关的建表语句、文档、SQL问答对放入到prompt里(
DDL和文档作为上下文、SQL问答对作为few-shot样例
)
def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> str:
"""
Example:
```python
vn.generate_sql("What are the top 10 customers by sales?")
```
Uses the LLM to generate a SQL query that answers a question. It runs the following methods:
该函数使用大语言模型(LLM)生成一个能够回答特定问题的SQL查询。它按顺序执行以下方法:
1、获取与输入问题相似的SQL查询
2、获取与问题相关的数据定义语言(DDL)语句
3、获取与问题相关的文档
4、生成用于提交给LLM的SQL查询prompt
5、将提示提交给LLM并获取生成的SQL查询
Args:
question (str): The question to generate a SQL query for.
allow_llm_to_see_data (bool):
Whether to allow the LLM to see the data (for the purposes of introspecting the data to generate the final SQL).
是否允许大型语言模型(LLM)查看数据,以便更好地理解数据结构并生成相应的SQL查询
Returns:
str: The SQL query that answers the question.
"""
if self.config is not None:
initial_prompt = self.config.get("initial_prompt", None)
else:
initial_prompt = None
# 1、获取与输入问题相似的SQL查询,默认最多返回10条数据
question_sql_list = self.get_similar_question_sql(question, **kwargs)
# 2、获取与问题相关的DDL语句,默认最多返回10条数据
ddl_list = self.get_related_ddl(question, **kwargs)
# 3、获取与问题相关的文档
doc_list = self.get_related_documentation(question, **kwargs)
# 4、生成用于提交给LLM的SQL查询prompt
prompt = self.get_sql_prompt(
initial_prompt=initial_prompt,
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
doc_list=doc_list,
**kwargs,
)
self.log(title="SQL Prompt", message=prompt)
......
函数中get_similar_question_sql均为抽象函数,不同的向量数据库有不同的实现方式
比如,这里使用的chromadb的实现方式如下:
# src.vanna.chromadb.chromadb_vector.py
class ChromaDB_VectorStore(VannaBase):
......
def get_similar_question_sql(self, question: str, **kwargs) -> list:
return ChromaDB_VectorStore._extract_documents(
self.sql_collection.query(
query_texts=[question],
n_results=self.n_results_sql,# 默认为10
)
)
def get_related_ddl(self, question: str, **kwargs) -> list:
return ChromaDB_VectorStore._extract_documents(
self.ddl_collection.query(
query_texts=[question],
n_results=self.n_results_ddl,# 默认为10
)
)
def get_related_documentation(self, question: str, **kwargs) -> list:
return ChromaDB_VectorStore._extract_documents(
self.documentation_collection.query(
query_texts=[question],
n_results=self.n_results_documentation,# 默认为10
)
)
我们看下,最终的prompt:
[
[{'role': 'system', 'content': """
You are a SQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions.
===Tables
CREATE TABLE `goods` (
`id` int(10) unsigned NOT NULL AUTO_INCREMENT,
`name` varchar(150) NOT NULL,
`cate_name` varchar(40) NOT NULL,
`brand_name` varchar(40) NOT NULL,
`price` decimal(10,3) NOT NULL DEFAULT '0.000',
`is_show` bit(1) NOT NULL DEFAULT b'1',
`is_saleoff` bit(1) NOT NULL DEFAULT b'0',
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=22 DEFAULT CHARSET=utf8;
===Additional Context
goods表中的字段cate_name为电脑类型,包括:笔记本、游戏本、超极本、平板电脑、台式机、服务器/工作站、笔记本配件
goods表中的字段brand_name为品牌名字,包括:华硕、联想、索尼、戴尔、苹果等
goods表中的字段name为电子产品具体型号,例如:ipad air 9.7英寸平板电脑
===Response Guidelines
1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question.
2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql
3. If the provided context is insufficient, please explain why it can't be generated.
4. Please use the most relevant table(s).
5. If the question has been asked and answered before, please repeat the answer exactly as it was given before.
6. Ensure that the output SQL is SQL-compliant and executable, and free of syntax errors.
"""
, {"role":"user","content":"华硕品牌的笔记本的平均价格是多少?"}
, {"role":"assistant","content":"SELECT AVG(price) AS avg_price\nFROM goods\nWHERE brand_name = '华硕' AND cate_name = '笔记本';"}
, {"role":"user","content":"华硕品牌的笔记本的最低价格、最高价格分别是多少?"}
]
LLM根据prompt生成查询SQL
def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> str:
......
# 5、将提示提交给LLM并获取生成的SQL查询
# submit_prompt为抽象方法,这里实现方法在OpenAI_Chat中
llm_response = self.submit_prompt(prompt, **kwargs)
self.log(title="LLM Response", message=llm_response)
if 'intermediate_sql' in llm_response:
......
# 6、提取最终的SQL
return self.extract_sql(llm_response)
- 然后会执行查询SQL(支持的数据库如下所示,可以参考源码,这里不再赘述),框架会进一步将查询结果使用plotly可视化出来或用LLM生成后续问题。