LlamaIndex 使用 RouterOutputAgentWorkflow

news2024/9/21 12:23:12

LlamaIndex 中提供了一个 RouterOutputAgentWorkflow 功能,可以集成多个 QueryTool,根据用户的输入判断使用那个 QueryEngine,在做查询的时候,可以从不同的数据源进行查询,例如确定的数据从数据库查询,如果是语义查询可以从向量数据库进行查询。本文将实现两个搜索引擎,根据不同 Query 使用不同 QueryEngine。

安装 MySQL 依赖

pip install mysql-connector-python  

搜索引擎

定义搜索引擎,初始两个数据源

  • 使用 MySQL 作为数据库的数据源
  • 使用 VectorIndex 作为语义搜索数据源
from pathlib import Path
from llama_index.core.tools import QueryEngineTool
from llama_index.core import VectorStoreIndex
import llm
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.core import Settings
from llama_index.core import SQLDatabase

from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select
Settings.llm = llm.get_ollama("mistral-nemo")
Settings.embed_model = llm.get_ollama_embbeding()

engine = create_engine(
    'mysql+mysqlconnector://root:123456@localhost:13306/db_llama', 
    echo=True  
)

def init_db():
    # 初始化数据库
    metadata_obj = MetaData()

    table_name = "city_stats"
    city_stats_table = Table(
        table_name,
        metadata_obj,
        Column("city_name", String(16), primary_key=True),
        Column("population", Integer, ),
        Column("state", String(16), nullable=False),
    )

    metadata_obj.create_all(engine)

    sql_database = SQLDatabase(engine, include_tables=["city_stats"])
    from sqlalchemy import insert
    rows = [
        {"city_name": "New York City", "population": 8336000, "state": "New York"},
        {"city_name": "Los Angeles", "population": 3822000, "state": "California"},
        {"city_name": "Chicago", "population": 2665000, "state": "Illinois"},
        {"city_name": "Houston", "population": 2303000, "state": "Texas"},
        {"city_name": "Miami", "population": 449514, "state": "Florida"},
        {"city_name": "Seattle", "population": 749256, "state": "Washington"},
    ]
    for row in rows:
        stmt = insert(city_stats_table).values(**row)
        with engine.begin() as connection:
            cursor = connection.execute(stmt)

from llama_index.core.query_engine import NLSQLTableQueryEngine

sql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["city_stats"]
)

    
def get_doc_index()-> VectorStoreIndex:
    '''
    解析 words
    '''
    # 创建 OllamaEmbedding 实例,用于指定嵌入模型和服务的基本 URL
    ollama_embedding = llm.get_ollama_embbeding()

    # 读取 "./data" 目录中的数据并加载为文档对象
    documents = SimpleDirectoryReader(input_files=[Path(__file__).parent / "data" / "LA.pdf"]).load_data()


    # 从文档中创建 VectorStoreIndex,并使用 OllamaEmbedding 作为嵌入模型
    vector_index = VectorStoreIndex.from_documents(documents, embed_model=ollama_embedding, 
                                                   transformations=[SentenceSplitter(chunk_size=1000, chunk_overlap=20)],)
    vector_index.set_index_id("vector_index")  # 设置索引 ID
    vector_index.storage_context.persist("./storage")  # 将索引持久化到 "./storage"
    return vector_index

llama_index_query_engine = get_doc_index().as_query_engine()


sql_tool = QueryEngineTool.from_defaults(
    query_engine=sql_query_engine,
    description=(
        "Useful for translating a natural language query into a SQL query over"
        " a table containing: city_stats, containing the population/state of"
        " each city located in the USA."
    ),
    name="sql_tool"
)

llama_cloud_tool = QueryEngineTool.from_defaults(
    query_engine=llama_index_query_engine,
    description=(
        f"Useful for answering semantic questions about certain cities in the US."
    ),
    name="llama_cloud_tool"
)


创建工作流

下图中显示了工作流的节点,绿色背景节点是工作流的动作,例如大模型返回 ToolEvent,ToolEvent 节点执行并返回结果。
在这里插入图片描述
工作流定义代码:

from typing import Dict, List, Any, Optional

from llama_index.core.tools import BaseTool
from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection, LLM
from llama_index.core.workflow import (
    Workflow,
    Event,
    StartEvent,
    StopEvent,
    step,
    Context
)
from llama_index.core.base.response.schema import Response
from llama_index.core.tools import FunctionTool
from llama_index.utils.workflow import draw_all_possible_flows
from llm import get_ollama

from docs import enable_trace

enable_trace()

class InputEvent(Event):
    """Input event."""

class GatherToolsEvent(Event):
    """Gather Tools Event"""

    tool_calls: Any

class ToolCallEvent(Event):
    """Tool Call event"""

    tool_call: ToolSelection

class ToolCallEventResult(Event):
    """Tool call event result."""

    msg: ChatMessage

class RouterOutputAgentWorkflow(Workflow):
    """Custom router output agent workflow."""

    def __init__(self,
        tools: List[BaseTool],
        timeout: Optional[float] = 10.0,
        disable_validation: bool = False,
        verbose: bool = False,
        llm: Optional[LLM] = None,
        chat_history: Optional[List[ChatMessage]] = None,
    ):
        """Constructor."""

        super().__init__(timeout=timeout, disable_validation=disable_validation, verbose=verbose)

        self.tools: List[BaseTool] = tools
        self.tools_dict: Optional[Dict[str, BaseTool]] = {tool.metadata.name: tool for tool in self.tools}
        self.llm: LLM = llm
        self.chat_history: List[ChatMessage] = chat_history or []
    

    def reset(self) -> None:
        """Resets Chat History"""

        self.chat_history = []

    @step()
    async def prepare_chat(self, ev: StartEvent) -> InputEvent:
        message = ev.get("message")
        if message is None:
            raise ValueError("'message' field is required.")
        
        # add msg to chat history
        chat_history = self.chat_history
        chat_history.append(ChatMessage(role="user", content=message))
        return InputEvent()

    @step()
    async def chat(self, ev: InputEvent) -> GatherToolsEvent | StopEvent:
        """Appends msg to chat history, then gets tool calls."""

        # Put msg into LLM with tools included
        chat_res = await self.llm.achat_with_tools(
            self.tools,
            chat_history=self.chat_history,
            verbose=self._verbose,
            allow_parallel_tool_calls=True
        )
        tool_calls = self.llm.get_tool_calls_from_response(chat_res, error_on_no_tool_call=False)
        
        ai_message = chat_res.message
        self.chat_history.append(ai_message)
        if self._verbose:
            print(f"Chat message: {ai_message.content}")

        # no tool calls, return chat message.
        if not tool_calls:
            return StopEvent(result=ai_message.content)

        return GatherToolsEvent(tool_calls=tool_calls)

    @step(pass_context=True)
    async def dispatch_calls(self, ctx: Context, ev: GatherToolsEvent) -> ToolCallEvent:
        """Dispatches calls."""

        tool_calls = ev.tool_calls
        await ctx.set("num_tool_calls", len(tool_calls))

        # trigger tool call events
        for tool_call in tool_calls:
            ctx.send_event(ToolCallEvent(tool_call=tool_call))
        
        return None
    
    @step()
    async def call_tool(self, ev: ToolCallEvent) -> ToolCallEventResult:
        """Calls tool."""

        tool_call = ev.tool_call

        # get tool ID and function call
        id_ = tool_call.tool_id

        if self._verbose:
            print(f"Calling function {tool_call.tool_name} with msg {tool_call.tool_kwargs}")

        # call function and put result into a chat message
        tool = self.tools_dict[tool_call.tool_name]
        output = await tool.acall(**tool_call.tool_kwargs)
        msg = ChatMessage(
            name=tool_call.tool_name,
            content=str(output),
            role="tool",
            additional_kwargs={
                "tool_call_id": id_,
                "name": tool_call.tool_name
            }
        )

        return ToolCallEventResult(msg=msg)
    
    @step(pass_context=True)
    async def gather(self, ctx: Context, ev: ToolCallEventResult) -> StopEvent | None:
        """Gathers tool calls."""
        # wait for all tool call events to finish.
        tool_events = ctx.collect_events(ev, [ToolCallEventResult] * await ctx.get("num_tool_calls"))
        if not tool_events:
            return None
        
        for tool_event in tool_events:
            # append tool call chat messages to history
            self.chat_history.append(tool_event.msg)
        
        # # after all tool calls finish, pass input event back, restart agent loop
        return InputEvent()

from muti_agent import sql_tool, llama_cloud_tool
wf = RouterOutputAgentWorkflow(tools=[sql_tool, llama_cloud_tool], verbose=True, timeout=120, llm=get_ollama("mistral-nemo"))

async def main():
    result = await wf.run(message="Which city has the highest population?")
    print("RSULT ===============", result)


# if __name__ == "__main__":
#     import asyncio

#     asyncio.run(main())


import gradio as gr

async def random_response(message, history):
    wf.reset()
    result = await wf.run(message=message)
    print("RSULT ===============", result)
    return result

demo = gr.ChatInterface(random_response, clear_btn=None, title="Qwen2")


demo.launch()

输入问题是 “What are five popular travel spots in Los Angeles?”,自动路由到 VectorIndex 进行查询。
在这里插入图片描述
输入问题为 “which city has the most population” 时,调用数据库进行搜索。
在这里插入图片描述

总结

LlamaIndex 中搜索引擎自动路由,根据用户的输入型自动选择所需的搜索引擎,这里有一个需要注意的点,模型需要支持 Function Call。如果 Ollama 本地模型进行推理,不是所有的本地模型都支持Function Call,Llama3.1 和 mistral-nemo 是支持 Function Call 的,可以使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2108369.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024年装电脑,就认准这几个型号,能避坑!

前言 小伙伴是否都会觉得,自己又不懂电脑,跑电脑城去装机又怕被坑。这时候只能找熟人给装机,至少……熟人应该不会坑自己吧?! 这不,小白电脑技术的抖音评论区上就有这么一条评论: 这哥们找一熟…

最新HTML5中的视频和音频讲解

第6章 HTML5中的视频和音频 H5新增video,audio,播放视频和音频,统称为多媒体元素。 6.1 多媒体元素基本属性 video用于电影文件和其他视频流的播放。 audio用于音乐文件和其他音频流的播放。 video的属性 src:文件路径,本地或者网络上。…

Android经典实战之SurfaceView原理和实践

本文首发于公众号“AntDream”,欢迎微信搜索“AntDream”或扫描文章底部二维码关注,和我一起每天进步一点点 SurfaceView 是一个非常强大但也相对复杂的 UI 组件,特别适用于对性能要求较高的绘制任务,如视频播放、游戏等。 1. Su…

Java 方法的定义

目录 1.Java的方法类似于其他语言的函数,是一段用来完成特定功能的代码片段。 2.方法包含一个方法头和方法体,下面是一个方法的所有部分: (1)修饰符:可选。告诉编译器如何调用该方法,定义了该…

Java笔试面试题AI答之JDBC(2)

文章目录 7. 列出Java应该遵循的JDBC最佳实践?8. Statement与PreparedStatement的区别,什么是SQL注入,如何防止SQL注入Statement与PreparedStatement的区别什么是SQL注入如何防止SQL注入 9. JDBC如何连接数据库?1. 加载JDBC驱动程序2. 建立数…

Python复杂网络社区检测:并行谱聚类算法设计与多种算法应用实战研究

原文链接:https://tecdat.cn/?p37574 分析师:Leiyun Liao 在当今的网络科学领域,复杂网络中的社区检测成为了一个至关重要的研究课题。随着信息技术的飞速发展,各种大规模网络不断涌现,如社交网络、生物网络等。准确地…

chapter12-异常(Exception)——(作业)——day15

目录 457-异常课后作业 458-异常课后作业2 457-异常课后作业 package chapter12.exception.homework;/*** author LuHan* version 1.0*/ public class Homework01 {public static void main(String[] args) {try {if(args.length!2){throw new ArrayIndexOutOfBoundsException…

立创商城9.9免邮活动开始啦!

从9月2日起,立创商城推出免邮活动,每月在领券中心>精选专区领取免邮券,即可享受满9.9元使用免邮券服务。 未注册的用户,可扫描下方二维码注册哦~

2024高教社杯数学建模国赛ABCDE题选题建议+初步分析

提示&#xff1a;DS C君认为的难度&#xff1a;C<B<A&#xff0c;开放度&#xff1a;A<C<B 。 D、E题推荐选E题&#xff0c;后续会直接更新E论文和思路&#xff0c;不在这里进行选题分析&#xff0c;以下为A、B、C题选题建议及初步分析 A题&#xff1a;“板凳龙”…

AI技术的新篇章:GPT Next、Gemini 2、GPT-6 和千代理人探索虚拟世界

在AI技术飞速发展的今天&#xff0c;许多令人兴奋的突破正逐渐进入公众视野。最近的新闻显示&#xff0c;诸如OpenAI的GPT Next、Google的Gemini 2.0、GPT-6以及模拟虚拟世界中的1000个AI代理人等前沿项目&#xff0c;标志着人工智能领域即将进入一个全新阶段。本文将深入探讨这…

多线程的简单了解——多客户端链接

在前面的学习中发现我们的聊天室功能只能有一个客户端接入服务端中&#xff0c;第二个客户端想要接入服务端中必须要等待第一个客户端输入结束才能接入。 这很明显不符合实际应用的开发&#xff0c;现在我们就来学习Java中一个重要的知识&#xff0c;多线程来解决这个问题。我们…

内存管理篇-22 高端内存和低端内存的分界线

这节课讲的主是为了区分低端内存和高端内存的是如何区分的&#xff1f;内核空间的划分是可以配置的。为了查看现象&#xff0c;通过qemu设置物理内存为不同情况。 结论&#xff1a;线性映射区的大小&#xff0c;和page_offset(内核起始地址0x80000000还是0xc0000000)和物理内存…

oracle startup失败,ORA-01078: failure in processing system parameters

SQL> startup ORA-01078: failure in processing system parameters LRM-00109: could not open parameter file /data/oracle/product/11.2.0/db_1/dbs/initorc1.ora 出错的原因可能是&#xff1a;文件名字不正确&#xff0c;文件权限不对&#xff0c;文件不存在&#x…

铁打的程序员轻易“不哭”-我的大模型创业近2年来的感悟

楔子 2022年11月&#xff0c;GPT-3发布那一刻&#xff0c;我被AI的强大能力所震撼&#xff0c;意识到“超级个体”时代的来临。自那时起&#xff0c;我开始全心投入创业&#xff0c;经历了许多苦乐交织的时光。 2023年6月&#xff0c;我尝试将AI应用于智能营销导购&#xff0…

143.布隆过滤器原理以及Go使用示例

文章目录 1. 是什么&#xff1f;2. 干什么&#xff1f;3. 为什么&#xff1f;4. 有什么问题&#xff1f;5. Go使用布隆过滤器单机版(Golang)分布式版(Java) 1. 是什么&#xff1f; 它是一个二进制bit数组&#xff0c;初始为 0 采用位存储数据结构&#xff0c;节省存储空间 1…

学学vue-1

vue 0 安装 装node.js&#xff0c;以及cnpm&#xff08;npm超时或者被屏蔽&#xff0c;安装cnpm国内镜像&#xff09; 查看安装版本&#xff08;是否安装成功&#xff09; node -v 安装成功之后也会安装npm npm -v cnpm镜像 npm install -g cnpm --registryhttp://registry.np…

spring如何解决bean的循环依赖

通过三级缓存解决循环依赖问题。 其中一级缓存用于存储完整的bean&#xff1b;二级缓存用于存储已经完成aop动态代理的bean&#xff0c;防止重复创建动态代理&#xff1b;三级缓存存储未实现aop动态代理和为实现依赖注入的bean。getBean()时先从一级缓存取&#xff0c;没有取二…

s3c2440---PWM使用之蜂鸣器驱动移植

一、蜂鸣器驱动介绍 1.1.什么是蜂鸣器 蜂鸣器是一种简单的声响发生器&#xff0c;常用于电子产品中作为警示或提醒作用。其基本原理是通过交替改变直流电的电压方向来产生声音&#xff0c;一般使用交替电流产生声音会比较稳定。 1.2.蜂鸣器的类别 1.有源蜂鸣器 1&…

2024 数学建模高教社杯 国赛(A题)| “板凳龙”舞龙队 | 建模秘籍文章代码思路大全

铛铛&#xff01;小秘籍来咯&#xff01; 小秘籍团队独辟蹊径&#xff0c;运用等距螺线&#xff0c;多目标规划等强大工具&#xff0c;构建了这一题的详细解答哦&#xff01; 为大家量身打造创新解决方案。小秘籍团队&#xff0c;始终引领着建模问题求解的风潮。 抓紧小秘籍&am…

嵌入式S3C2440:控制LED灯

发光二极管接口&#xff08;左端&#xff09;应为低电平 以LED1为例 LED1的接口为GPB5 void led_init(void) {//配置GPB5功能为输出GPBCON & ~(0x3 << 10);GPBCON | (0x1 << 10); //使GPB5输出高电平(关灯)GPBDAT | (1 << 5); }void led_on(void) {GPB…