LlamaIndex provides a RouterOutputAgentWorkflow that integrates multiple query tools (QueryEngineTool) and decides, based on the user's input, which QueryEngine to use. A query can therefore be answered from different data sources: exact, structured facts can be fetched from a database, while semantic questions can be answered from a vector store. This article builds two query engines and routes each query to the appropriate one.
Install the MySQL dependency
pip install mysql-connector-python
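The examples below also assume that the LlamaIndex core package, the Ollama LLM/embedding integrations, SQLAlchemy, and Gradio are installed; if not, something along these lines should cover it (package names as published on PyPI):
pip install llama-index llama-index-llms-ollama llama-index-embeddings-ollama sqlalchemy gradio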
Query engines
Define the query engines and initialize two data sources:
- MySQL as the database (structured) data source
- a VectorStoreIndex as the semantic-search data source
from pathlib import Path
from llama_index.core.tools import QueryEngineTool
from llama_index.core import VectorStoreIndex
import llm
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.core import Settings
from llama_index.core import SQLDatabase
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select
Settings.llm = llm.get_ollama("mistral-nemo")
Settings.embed_model = llm.get_ollama_embbeding()
engine = create_engine(
    'mysql+mysqlconnector://root:123456@localhost:13306/db_llama',
    echo=True
)
def init_db():
    # Initialize the database: create the city_stats table and insert sample rows
    metadata_obj = MetaData()
    table_name = "city_stats"
    city_stats_table = Table(
        table_name,
        metadata_obj,
        Column("city_name", String(16), primary_key=True),
        Column("population", Integer),
        Column("state", String(16), nullable=False),
    )
    metadata_obj.create_all(engine)

    from sqlalchemy import insert
    rows = [
        {"city_name": "New York City", "population": 8336000, "state": "New York"},
        {"city_name": "Los Angeles", "population": 3822000, "state": "California"},
        {"city_name": "Chicago", "population": 2665000, "state": "Illinois"},
        {"city_name": "Houston", "population": 2303000, "state": "Texas"},
        {"city_name": "Miami", "population": 449514, "state": "Florida"},
        {"city_name": "Seattle", "population": 749256, "state": "Washington"},
    ]
    for row in rows:
        stmt = insert(city_stats_table).values(**row)
        with engine.begin() as connection:
            connection.execute(stmt)
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["city_stats"]
)
def get_doc_index() -> VectorStoreIndex:
    """
    Parse the local documents and build a vector index.
    """
    # Create an OllamaEmbedding instance that specifies the embedding model and service URL
    ollama_embedding = llm.get_ollama_embbeding()
    # Load data/LA.pdf as document objects
    documents = SimpleDirectoryReader(input_files=[Path(__file__).parent / "data" / "LA.pdf"]).load_data()
    # Build a VectorStoreIndex from the documents, using OllamaEmbedding as the embedding model
    vector_index = VectorStoreIndex.from_documents(
        documents,
        embed_model=ollama_embedding,
        transformations=[SentenceSplitter(chunk_size=1000, chunk_overlap=20)],
    )
    vector_index.set_index_id("vector_index")  # set the index ID
    vector_index.storage_context.persist("./storage")  # persist the index to ./storage
    return vector_index
llama_index_query_engine = get_doc_index().as_query_engine()
sql_tool = QueryEngineTool.from_defaults(
    query_engine=sql_query_engine,
    description=(
        "Useful for translating a natural language query into a SQL query over"
        " a table containing: city_stats, containing the population/state of"
        " each city located in the USA."
    ),
    name="sql_tool"
)
llama_cloud_tool = QueryEngineTool.from_defaults(
    query_engine=llama_index_query_engine,
    description=(
        "Useful for answering semantic questions about certain cities in the US."
    ),
    name="llama_cloud_tool"
)
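The code above imports a local helper module named llm that the post does not show. A minimal sketch of what it could look like, assuming the llama-index-llms-ollama and llama-index-embeddings-ollama integrations and an Ollama server on the default port (the embedding model name is only an example, and the misspelled get_ollama_embbeding is kept to match the calls above):
# llm.py: hypothetical helper module, not shown in the original post
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.ollama import OllamaEmbedding

def get_ollama(model: str = "mistral-nemo") -> Ollama:
    # Chat model served by a local Ollama instance
    return Ollama(model=model, base_url="http://localhost:11434", request_timeout=120.0)

def get_ollama_embbeding() -> OllamaEmbedding:
    # Embedding model; "nomic-embed-text" is just an example choice
    return OllamaEmbedding(model_name="nomic-embed-text", base_url="http://localhost:11434")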
Create the workflow
The diagram below shows the workflow's nodes; the nodes with a green background are the workflow's steps, e.g. the LLM returns tool calls as a ToolCallEvent, and the tool-call step executes them and returns the results.
Workflow definition code:
from typing import Dict, List, Any, Optional
from llama_index.core.tools import BaseTool
from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection, LLM
from llama_index.core.workflow import (
Workflow,
Event,
StartEvent,
StopEvent,
step,
Context
)
from llama_index.core.base.response.schema import Response
from llama_index.core.tools import FunctionTool
from llama_index.utils.workflow import draw_all_possible_flows
from llm import get_ollama
from docs import enable_trace
enable_trace()
class InputEvent(Event):
    """Input event."""


class GatherToolsEvent(Event):
    """Gather Tools Event"""

    tool_calls: Any


class ToolCallEvent(Event):
    """Tool Call event"""

    tool_call: ToolSelection


class ToolCallEventResult(Event):
    """Tool call event result."""

    msg: ChatMessage
class RouterOutputAgentWorkflow(Workflow):
    """Custom router output agent workflow."""

    def __init__(self,
        tools: List[BaseTool],
        timeout: Optional[float] = 10.0,
        disable_validation: bool = False,
        verbose: bool = False,
        llm: Optional[LLM] = None,
        chat_history: Optional[List[ChatMessage]] = None,
    ):
        """Constructor."""
        super().__init__(timeout=timeout, disable_validation=disable_validation, verbose=verbose)
        self.tools: List[BaseTool] = tools
        self.tools_dict: Optional[Dict[str, BaseTool]] = {tool.metadata.name: tool for tool in self.tools}
        self.llm: LLM = llm
        self.chat_history: List[ChatMessage] = chat_history or []

    def reset(self) -> None:
        """Resets Chat History"""
        self.chat_history = []
    @step()
    async def prepare_chat(self, ev: StartEvent) -> InputEvent:
        message = ev.get("message")
        if message is None:
            raise ValueError("'message' field is required.")
        # add msg to chat history
        chat_history = self.chat_history
        chat_history.append(ChatMessage(role="user", content=message))
        return InputEvent()
    @step()
    async def chat(self, ev: InputEvent) -> GatherToolsEvent | StopEvent:
        """Appends msg to chat history, then gets tool calls."""
        # Put msg into LLM with tools included
        chat_res = await self.llm.achat_with_tools(
            self.tools,
            chat_history=self.chat_history,
            verbose=self._verbose,
            allow_parallel_tool_calls=True
        )
        tool_calls = self.llm.get_tool_calls_from_response(chat_res, error_on_no_tool_call=False)

        ai_message = chat_res.message
        self.chat_history.append(ai_message)
        if self._verbose:
            print(f"Chat message: {ai_message.content}")

        # no tool calls, return chat message.
        if not tool_calls:
            return StopEvent(result=ai_message.content)

        return GatherToolsEvent(tool_calls=tool_calls)
    @step(pass_context=True)
    async def dispatch_calls(self, ctx: Context, ev: GatherToolsEvent) -> ToolCallEvent:
        """Dispatches calls."""
        tool_calls = ev.tool_calls
        await ctx.set("num_tool_calls", len(tool_calls))

        # trigger tool call events
        for tool_call in tool_calls:
            ctx.send_event(ToolCallEvent(tool_call=tool_call))

        return None
    @step()
    async def call_tool(self, ev: ToolCallEvent) -> ToolCallEventResult:
        """Calls tool."""
        tool_call = ev.tool_call

        # get tool ID and function call
        id_ = tool_call.tool_id

        if self._verbose:
            print(f"Calling function {tool_call.tool_name} with msg {tool_call.tool_kwargs}")

        # call function and put result into a chat message
        tool = self.tools_dict[tool_call.tool_name]
        output = await tool.acall(**tool_call.tool_kwargs)
        msg = ChatMessage(
            name=tool_call.tool_name,
            content=str(output),
            role="tool",
            additional_kwargs={
                "tool_call_id": id_,
                "name": tool_call.tool_name
            }
        )

        return ToolCallEventResult(msg=msg)
    @step(pass_context=True)
    async def gather(self, ctx: Context, ev: ToolCallEventResult) -> InputEvent | None:
        """Gathers tool calls."""
        # wait for all tool call events to finish.
        tool_events = ctx.collect_events(ev, [ToolCallEventResult] * await ctx.get("num_tool_calls"))
        if not tool_events:
            return None

        for tool_event in tool_events:
            # append tool call chat messages to history
            self.chat_history.append(tool_event.msg)

        # after all tool calls finish, pass input event back, restart agent loop
        return InputEvent()
from muti_agent import sql_tool, llama_cloud_tool

wf = RouterOutputAgentWorkflow(tools=[sql_tool, llama_cloud_tool], verbose=True, timeout=120, llm=get_ollama("mistral-nemo"))

async def main():
    result = await wf.run(message="Which city has the highest population?")
    print("RESULT ===============", result)

# if __name__ == "__main__":
#     import asyncio
#     asyncio.run(main())
import gradio as gr

async def random_response(message, history):
    wf.reset()
    result = await wf.run(message=message)
    print("RESULT ===============", result)
    return result

demo = gr.ChatInterface(random_response, clear_btn=None, title="Qwen2")
demo.launch()
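The workflow diagram mentioned in the "Create the workflow" section can be generated with draw_all_possible_flows, which is already imported above; a minimal sketch (it needs the pyvis package, and the output filename here is arbitrary):
# Render every possible path through the workflow's steps to an interactive HTML file
draw_all_possible_flows(RouterOutputAgentWorkflow, filename="router_workflow.html")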
When the input question is "What are five popular travel spots in Los Angeles?", it is automatically routed to the vector index for the query.
When the input question is "which city has the most population", the database is queried instead.
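To double-check the routing against what each engine returns on its own, the two query engines can also be called directly; a small, hypothetical check using the asynchronous aquery API:
import asyncio

async def sanity_check():
    # structured question: goes to the text-to-SQL engine
    sql_res = await sql_query_engine.aquery("Which city has the highest population?")
    print(sql_res)
    # semantic question: goes to the vector index built over LA.pdf
    vec_res = await llama_index_query_engine.aquery("What are five popular travel spots in Los Angeles?")
    print(vec_res)

asyncio.run(sanity_check())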
Summary
LlamaIndex can automatically route queries, choosing the appropriate query engine based on the user's input. One point to watch: the model must support function calling. When running inference with a local Ollama model, not every local model supports function calling; Llama 3.1 and mistral-nemo do, so either can be used here.
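A quick way to see whether the chosen model is reported as supporting tool calls is to inspect the LLM's metadata; note this only reflects what the LlamaIndex integration declares, not how well the model actually uses tools:
from llm import get_ollama  # the helper sketched earlier

llm = get_ollama("mistral-nemo")
# LLMMetadata exposes a flag for tool/function-calling support
print(llm.metadata.is_function_calling_model)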