Langchain-Chatchat项目:5.1-ChatGLM3-6B工具调用

news2024/11/25 0:52:24

  在语义、数学、推理、代码、知识等不同角度的数据集上测评显示,ChatGLM3-6B-Base 具有在10B以下的基础模型中最强的性能。ChatGLM3-6B采用了全新设计的Prompt格式,除正常的多轮对话外。同时原生支持工具调用(Function Call)、代码执行(Code Interpreter)和Agent任务等复杂场景。本文主要通过天气查询例子介绍了在tool_registry.py中注册新的工具来增强模型能力。

  可以直接调用LangChain自带的工具(比如,ArXiv),也可以调用自定义的工具。LangChain自带的部分工具[2],如下所示:


一.自定义天气查询工具
1.Weather类
  可以参考Tool/Weather.py以及Tool/Weather.yaml文件,继承BaseTool类,重载_run()方法,如下所示:

class Weather(BaseTool):  # 天气查询工具
    name = "weather"
    description = "Use for searching weather at a specific location"

    def __init__(self):
        super().__init__()

    def get_weather(self, location):
        api_key = os.environ["SENIVERSE_KEY"]
        url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
        response = requests.get(url)
        if response.status_code == 200:
            data = response.json()
            weather = {
                "temperature": data["results"][0]["now"]["temperature"],
                "description": data["results"][0]["now"]["text"],
            }
            return weather
        else:
            raise Exception(
                f"Failed to retrieve weather: {response.status_code}")

    def _run(self, para: str) -> str:
        return self.get_weather(para)

2.weather.yaml文件
  weather.yaml文件内容,如下所示:

name: weather
description: Search the current weather of a city
parameters:
  type: object
  properties:
    city:
      type: string
      description: City name
  required:
    - city

二.自定义天气查询工具调用
  自定义天气查询工具调用,在main.py中导入Weather工具。如下所示:

run_tool([Weather()], llm, [
    "今天北京天气怎么样?",
    "What's the weather like in Shanghai today",
])

  其中,run_tool()函数实现如下所示:

def run_tool(tools, llm, prompt_chain: List[str]):
    loaded_tolls = []  # 用于存储加载的工具
    for tool in tools:  # 逐个加载工具
        if isinstance(tool, str):
            loaded_tolls.append(load_tools([tool], llm=llm)[0])  # load_tools返回的是一个列表
        else:
            loaded_tolls.append(tool)  # 如果是自定义的工具,直接添加到列表中
    agent = initialize_agent(  # 初始化agent
        loaded_tolls, llm,
        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,  # agent类型:使用结构化聊天的agent
        verbose=True,
        handle_parsing_errors=True
    )
    for prompt in prompt_chain:  # 逐个输入prompt
        agent.run(prompt)

1.load_tools()函数
  根据工具名字加载相应的工具,如下所示:

def load_tools(
    tool_names: List[str],
    llm: Optional[BaseLanguageModel] = None,
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> List[BaseTool]:

2.initialize_agent()函数
  根据工具列表和LLM加载一个agent executor,如下所示:

def initialize_agent(
    tools: Sequence[BaseTool],
    llm: BaseLanguageModel,
    agent: Optional[AgentType] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    agent_path: Optional[str] = None,
    agent_kwargs: Optional[dict] = None,
    *,
    tags: Optional[Sequence[str]] = None,
    **kwargs: Any,
) -> AgentExecutor:

  其中,agent默认为AgentType.ZERO_SHOT_REACT_DESCRIPTION。本文中使用为AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,一种为聊天模型优化的zero-shot react agent,该agent能够调用具有多个输入的工具。
3.run()函数
  执行链的便捷方法,这个方法与Chain.__call__之间的主要区别在于,这个方法期望将输入直接作为位置参数或关键字参数传递,而Chain.__call__期望一个包含所有输入的单一输入字典。如下所示:

def run(
    self,
    *args: Any,
    callbacks: Callbacks = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Any:


4.结果分析
  结果输出,如下所示:

> Entering new AgentExecutor chain...
======
======

Action: 
``
{"action": "weather", "action_input": "北京"}
``
Observation: {'temperature': '20', 'description': '晴'}
Thought:======
======

Action: 
``
{"action": "Final Answer", "action_input": "根据查询结果,北京今天的天气是晴,气温为20℃。"}
``

> Finished chain.


> Entering new AgentExecutor chain...
======
======

Action: 
``
{"action": "weather", "action_input": "Shanghai"}
``
Observation: {'temperature': '20', 'description': '晴'}
Thought:======
======

Action: 
``
{"action": "Final Answer", "action_input": "根据最新的天气数据,今天上海的天气情况是晴朗的,气温为20℃。"}
``

> Finished chain.

  刚开始的时候没有找到识别实体city的地方,后面调试ChatGLM3/langchain_demo/ChatGLM3.py->_call()时发现了一个巨长的prompt,这不就是zero-prompt(AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION)吗?顺便吐槽下LangChain的代码真的不好调试。



三.注册工具增强LLM能力
1.注册工具
  可以通过在tool_registry.py中注册新的工具来增强模型的能力。只需要使用@register_tool装饰函数即可完成注册。对于工具声明,函数名称即为工具的名称,函数docstring即为工具的说明;对于工具的参数,使用Annotated[typ: type, description: str, required: bool]标注参数的类型、描述和是否必须。将get_weather()函数进行注册,如下所示:

@register_tool
def get_weather(  # 工具函数
        city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
    """
    Get the current weather for `city_name`
    """

    if not isinstance(city_name, str):  # 参数类型检查
        raise TypeError("City name must be a string")

    key_selection = {  # 选择的键
        "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
    }
    import requests
    try:
        resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
        resp.raise_for_status()
        resp = resp.json()
        ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
    except:
        import traceback
        ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()

    return str(ret)

  具体工具注册实现方式@register_tool装饰函数,如下所示:

def register_tool(func: callable):  # 注册工具
    tool_name = func.__name__  # 工具名
    tool_description = inspect.getdoc(func).strip()  # 工具描述
    python_params = inspect.signature(func).parameters  # 工具参数
    tool_params = []  # 工具参数描述
    for name, param in python_params.items():  # 遍历参数
        annotation = param.annotation  # 参数注解
        if annotation is inspect.Parameter.empty:
            raise TypeError(f"Parameter `{name}` missing type annotation")  # 参数缺少注解
        if get_origin(annotation) != Annotated:  # 参数注解不是Annotated
            raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")  # 参数注解必须是Annotated

        typ, (description, required) = annotation.__origin__, annotation.__metadata__  # 参数类型, 参数描述, 是否必须
        typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__  # 参数类型名
        if not isinstance(description, str):  # 参数描述必须是字符串
            raise TypeError(f"Description for `{name}` must be a string")
        if not isinstance(required, bool):  # 是否必须必须是布尔值
            raise TypeError(f"Required for `{name}` must be a bool")

        tool_params.append({  # 添加参数描述
            "name": name,
            "description": description,
            "type": typ,
            "required": required
        })
    tool_def = {  # 工具定义
        "name": tool_name,
        "description": tool_description,
        "params": tool_params
    }

    print("[registered tool] " + pformat(tool_def))  # 打印工具定义
    _TOOL_HOOKS[tool_name] = func  # 注册工具
    _TOOL_DESCRIPTIONS[tool_name] = tool_def  # 添加工具定义

    return func

2.调用工具
  参考文件ChatGLM3/tool_using/openai_api_demo.py,如下所示:

def main():
    messages = [  # 对话信息
        system_info,
        {
            "role": "user",
            "content": "帮我查询北京的天气怎么样",
        }
    ]
    response = openai.ChatCompletion.create(  # 调用OpenAI API
        model="chatglm3",
        messages=messages,
        temperature=0,
        return_function_call=True
    )
    function_call = json.loads(response.choices[0].message.content)  # 获取函数调用信息
    logger.info(f"Function Call Response: {function_call}")  # 打印函数调用信息

    tool_response = dispatch_tool(function_call["name"], function_call["parameters"])  # 调用函数
    logger.info(f"Tool Call Response: {tool_response}")  # 打印函数调用结果

    messages = response.choices[0].history  # 获取历史对话信息
    messages.append(
        {
            "role": "observation",
            "content": tool_response,  # 调用函数返回结果
        }
    )

    response = openai.ChatCompletion.create(  # 调用OpenAI API
        model="chatglm3",
        messages=messages,
        temperature=0,
    )
    logger.info(response.choices[0].message.content)  # 打印对话结果

参考文献:
[1]https://github.com/THUDM/ChatGLM3/tree/main
[2]https://python.langchain.com/docs/integrations/tools

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1178241.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何用思维导图拆书

思维导图是一种非常有用的工具,可以被广泛应用于不同领域的人群。 阅读是我们获取知识和提升自己的重要途径,而做好读书笔记有助于加深对书中内容的理解和记忆。其中,使用思维导图作为读书笔记的工具,不仅能够帮助我们更好地整理…

lombok依赖介绍(帮助我们消除冗长代码,如get,set方法)

前言 lombok 是一个 Java 工具库,通过注解的方式,简化 Java 开发。要想使用 lombok 中的注解,我们需要先引入依赖,推荐看idea必装插件EditStarters(快速引入依赖),lombok是⼀款在编译期⽣成代码…

k8s提交spark应用消费kafka数据写入elasticsearch7

一、k8s集群环境 k8s 1.23版本,三个节点,容器运行时使用docker。 spark版本时3.3.3 k8s部署单节点的zookeeper、kafka、elasticsearch7 二、spark源码 https://download.csdn.net/download/TT1024167802/88509398 命令行提交方式 /opt/module/spark…

技术分享 | 被测项目需求你理解到位了么?

需求分析是开始测试工作的第一步,产品会先产出一个需求文档,然后会组织需求宣讲,在需求宣讲中分析需求中是否存在问题,然后宣讲结束后,通过需求文档分析测试点并且预估排期。所以对于需求的理解非常重要。 需求文档 …

壹[1],QT自定义控件创建(QtDesigner)

1,环境 Qt 5.14.2 VS2022 原因:厌烦了控件提升的繁琐设置,且看不到界面预览显示。 2,QT制作自定义控件 2.1,New/其他项目/Qt4 设计师自定义控件 2.2,设置项目名称 2.3,设置 2.4,设…

YOLOv7改进策略:一种新颖的可扩张残差(DWR)注意力模块,增强多尺度感受野特征,助力小目标检测

💡💡💡本文全网首发独家改进:一种新颖的可扩张残差(DWR)注意力模块,加强不同尺度特征提取能力,创新十足,独家首发适合科研 推荐指数:五星 DWR | 亲测在多个数据集能够实现涨点,多尺度特性在小目标检测表现也十分出色。 💡💡💡Yolov5/Yolov7魔术师,独…

【Java 进阶篇】Cookie 使用详解

欢迎阅读本篇博客,我们将深入研究 Java 中的 Cookie,从入门到精通,包括 Cookie 的基本概念、原理、使用方法以及一些高级技巧。无论你是新手还是有经验的开发者,希望这篇博客对你有所帮助。 第一部分:Cookie 是什么&a…

网络原理---封装和分用

文章目录 什么是封装和分用?封装应用层传输层网络层数据链路层物理层 分用物理层数据链路层网络层传输层应用层 什么是封装和分用? 我们前面讲过协议会分层,每一层都有各自的功能。而在数据传输的过程中,得按照顺序把每一层协议都…

数仓分层能减少重复计算,为啥能减少?如何减少?这篇文章包懂!

很多时候,看一些数据领域的文章,说到为什么做数据仓库、数据仓库要分层,我们经常会看到一些结论:因为有ABCD…等等理由,比如降低开发成本、减少重复计算等等好处 然后,多数人就记住了ABCD。但是&#xff0…

VScode连接Xshell 并解决【过程试图写入的管道不存在】报错

一.下载vscode 国内镜像: https://vscode.cdn.azure.cn/stable/6c3e3dba23e8fadc360aed75ce363ba185c49794/VSCodeUserSetup-x64-1.81.1.exe二.打开vscode在扩展搜索SSH并安装 三.添加主机 按F1选择添加新的ssh主机 按格式输入后在左边会出现电视的图标 之后输入…

十一、K8S之持久化存储

持久化存储 一、概念 在K8S中,数据持久化可以让容器在重新调度、重启或者迁移时保留其数据,并且确保数据的可靠性和持久性。 持久化存储通常用于程序的状态数据、数据库文件、日志文件等需要在容器生命周期之外的数据,它可以通过各种存储解…

项目管理之如何监控项目健康状态

项目管理是一个复杂且关键的过程,涉及到多个关键因素,包括项目名称、项目管理委员会成员、项目经理、项目生命周期的各个阶段以及资源泳道等。如何有效地监控项目的健康状态是确保项目成功的重要环节。本文将详细介绍项目管理全景图及其在风险识别中的应…

【差旅游记】公乌素遇到的那些司机师傅

哈喽,大家好,我是雷工! 出差人出差在外,城际间靠各种公共交通工具,但到了目的地的城镇,最后一公里往往少不了打车,或出租车,或摩的三轮车。 不同于公共交通,像飞机火车高…

【C++类和对象中:解锁面向对象编程的奇妙世界】

【本节目标】 1. 类的6个默认成员函数 2. 构造函数 3. 析构函数 4. 拷贝构造函数 5. 赋值运算符重载 6. const成员函数 7. 取地址及const取地址操作符重载 1.类的6个默认成员函数 如果一个类中什么成员都没有,简称为空类。 空类中真的什么都没有吗&#xf…

ConvNets 与 Vision Transformers:数学深入探讨

一、说明 我目睹了关于 Vision Transformer 的争论,讨论它们如何与 CNN 一样好或更好。我想知道我们是否也同样争论菠萝比西瓜好!或者马比海豚更好?其中许多讨论往往缺乏具体性,有时可能会歪曲上下文。 作为背景,在快速…

计算机基础知识44

overflow溢出属性 visible:默认值,内容不会被修剪,会呈现在元素框之外。hidden:内容会被修剪,并且其余内容是不可见的。scroll:内容会被修剪,但是浏览器会显示滚动条以便查看其余的内容。auto: …

GEE错误——XXX is not a function,如何解决这个问题?

错误: 这里的时错误原始的代码链接: https://code.earthengine.google.com/4bf0975a41e14d0c40e01925c6f3cf2a 这里主要的问题时这个单一影像不存在: ImageCollection (Error) ImageCollection.load: ImageCollection asset LANDSAT/LC08/C01/T1_SR/LC08_221077_201704…

self.register_buffer方法使用解析(pytorch)

self.register_buffer就是pytorch框架用来保存不更新参数的方法。 列子如下: self.register_buffer("position_emb", torch.randn((5, 3)))第一个参数position_emb传入一个字符串,表示这组参数的名字,第二个就是tensor形式的参数…

微信Wxid转换微信号

微信号在申请的时候,系统随机分配了一个微信原始ID,该ID号以wxid_开头,后面是随机的字符串 分配的原始ID是目前是不可以直接用来加好友的,需要转换成微信号才能加好友, 经过逆向分析通过PC端找到了该接口并且可以成功用…

Langchain知识点(下)

背景: 这部分给主要介绍Langchain的agent部分,前面已经章节已经介绍了思维和思路作为一种数据资产是这一次LLM数据化的核心。也介绍了各种的chain,那么既然有了chain可以把专家思路和专家思维固化并且可被方便的共享和利用;那为什…