在工具调用过程中,如何将中间状态返回到的stream,可以使用from langchain_core.callbacks import dispatch_custom_event方法实现。示例如下
from langchain_core.callbacks import dispatch_custom_event
from langchain_core.tools import tool
@tool
def query_and_summary(question: str):
"""
根据用户提问,查询相关文档,并对命中的文档片段进行总结
"""
# 模拟结构化数据
docs = {"《听懂掌声》": "听懂掌声", "《成功学》": "自信&成功"}
dispatch_custom_event(name="docs", data=docs)
# 模拟流式数据
for data in summary():
dispatch_custom_event(name="summary", data=data)
def summary():
summary = "“听懂掌声”这个词语是来自于一个培训视频,在视频里面做培训的人叫做枭哥"
for data in summary:
yield data
- graph定义
这里我们定义了一个模拟查询文档并返回总结的工具,graph调用时需通过astream_events方法调用。当执行dispatch_custom_event函数时,stream会收到一个on_custom_event事件,事件的name是我们在函数中传入的name
from typing_extensions import Annotated, TypedDict
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import tool
from langgraph.graph.message import AnyMessage, add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, InjectedState, tools_condition
from langgraph.checkpoint.memory import MemorySaver
class State(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]
class Assistant:
def __init__(self, runnable: Runnable):
self.runnable = runnable
def __call__(self, state: State, config: RunnableConfig):
result = self.runnable.invoke(state)
return {"messages": result}
assistant_system_prompt = "你是成功学助手,收到用户提问请调用query_and_summary工具"
assistant_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
assistant_system_prompt,
),
("placeholder", "{messages}"),
],
)
assistant_tool = [query_and_summary]
assistant_runnable = assistant_prompt | llm.bind_tools(assistant_tool)
assistant = Assistant(assistant_runnable)
# node定义
builder = StateGraph(State)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(assistant_tool))
# edge定义
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", END)
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)
- 调用
import uuid
thread_id = str(uuid.uuid4())
config = {"configurable": {"thread_id": thread_id}}
async for chunk in graph.astream_events({"messages": ("user", "什么是成功学")}, config=config, version="v2"):
if chunk["event"] == "on_custom_event" and chunk["name"] == "docs":
print(chunk["data"])
if chunk["event"] == "on_custom_event" and chunk["name"] == "summary":
print("recv:" + chunk["data"])
{'《听懂掌声》': '听懂掌声', '《成功学》': '自信&成功'}
recv:“
recv:听
recv:懂
recv:掌
recv:声
recv:”
recv:这
recv:个
recv:词
recv:语
recv:是
...