LangChain之关于RetrievalQA input_variables 的定义与使用

news2024/10/7 0:26:50

最近在使用LangChain来做一个LLMs和KBs结合的小Demo玩玩,也就是RAG(Retrieval Augmented Generation)。
这部分的内容其实在LangChain的官网已经给出了流程图。在这里插入图片描述
我这里就直接偷懒了,准备对Webui的项目进行复刻练习,那么接下来就是照着葫芦画瓢就行。
那么我卡在了Retrieve这一步。先放有疑惑地方的代码:

if web_content:
            prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
                                如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
                                已知网络检索内容:{web_content}""" + """
                                已知内容:
                                {context}
                                问题:
                                {question}"""
        else:
            prompt_template = """基于以下已知信息,请简洁并专业地回答用户的问题。
                如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。

                已知内容:
                {context}

                问题:
                {question}"""
        prompt = PromptTemplate(template=prompt_template,
                                input_variables=["context", "question"])
        ......

        knowledge_chain = RetrievalQA.from_llm(
            llm=self.llm,
            retriever=vector_store.as_retriever(
                search_kwargs={"k": self.top_k}),
            prompt=prompt)
        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
            input_variables=["page_content"], template="{page_content}")

        knowledge_chain.return_source_documents = True

        result = knowledge_chain({"query": query})
        return result

我对prompt_templateknowledge_chain.combine_documents_chain.document_prompt result = knowledge_chain({"query": query})这三个地方的input_key不明白为啥一定要这样设置。虽然我也看了LangChain的API文档。但是我并未得到详细的答案,那么只能一行行看源码是到底怎么设置的了。

注意:由于LangChain是一层层封装的,那么result = knowledge_chain({"query": query})可以认为是最外层,那么我们先看最外层。

result = knowledge_chain({“query”: query})

其实这部分是直接与用户的输入问题做对接的,我们只需要定位到RetrievalQA这个类就可以了,下面是RetrievalQA这个类的实现:

class RetrievalQA(BaseRetrievalQA):
    """Chain for question-answering against an index.
    Example:
        .. code-block:: python
            from langchain.llms import OpenAI
            from langchain.chains import RetrievalQA
            from langchain.vectorstores import FAISS
            from langchain.schema.vectorstore import VectorStoreRetriever
            retriever = VectorStoreRetriever(vectorstore=FAISS(...))
            retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)

    """

    retriever: BaseRetriever = Field(exclude=True)

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        return self.retriever.get_relevant_documents(
            question, callbacks=run_manager.get_child()
        )

    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        return await self.retriever.aget_relevant_documents(
            question, callbacks=run_manager.get_child()
        )

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "retrieval_qa"

可以看到其继承了BaseRetrievalQA这个父类,同时对_get_docs这个抽象方法进行了实现。

这里要扩展的说一下,_get_docs这个方法就是利用向量相似性,在vector Base中选择与embedding之后的query最近似的Document结果。然后作为RetrievalQA的上下文。具体只需要看BaseRetrievalQA这个方法的_call和就可以了。
接下来我们只需要看BaseRetrievalQA这个类的属性就可以了。

class BaseRetrievalQA(Chain):
    """Base class for question-answering chains."""

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain to use to combine the documents."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Return the source documents or not."""
    ……
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run get_relevant_text and llm on input query.

        If chain has 'return_source_documents' as 'True', returns
        the retrieved documents as well under the key 'source_documents'.

        Example:
        .. code-block:: python

        res = indexqa({'query': 'This is my query'})
        answer, docs = res['result'], res['source_documents']
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(question, run_manager=_run_manager)
        else:
            docs = self._get_docs(question)  # type: ignore[call-arg]
        answer = self.combine_documents_chain.run(
            input_documents=docs, question=question, callbacks=_run_manager.get_child()
        )

        if self.return_source_documents:
            return {self.output_key: answer, "source_documents": docs}
        else:
            return {self.output_key: answer}

可以看到其有input_key这个属性,默认值是"query"。到这里我们就可以看到result = knowledge_chain({"query": query})是调用的BaseRetrievalQA_call,这里的question = inputs[self.input_key]就是其体现。

knowledge_chain.combine_documents_chain.document_prompt

这个地方一开始我很奇怪,为什么会重新定义呢?
我们可以先定位到,combine_documents_chain这个参数的位置,其是StuffDocumentsChain的方法。

@classmethod
def from_llm(
    cls,
    llm: BaseLanguageModel,
    prompt: Optional[PromptTemplate] = None,
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> BaseRetrievalQA:
    """Initialize from LLM."""
    _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
    llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks)
    document_prompt = PromptTemplate(
        input_variables=["page_content"], template="Context:\n{page_content}"
    )
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name="context",
        document_prompt=document_prompt,
        callbacks=callbacks,
    )

    return cls(
        combine_documents_chain=combine_documents_chain,
        callbacks=callbacks,
        **kwargs,
    )

可以看到原始的document_prompt中PromptTemplate的template是“Context:\n{page_content}”。因为这个项目是针对中文的,所以需要将英文的Context去掉。

扩展

  1. 这里PromptTemplate(input_variables=[“page_content”], template=“Context:\n{page_content}”)的input_variablestemplate为什么要这样定义呢?其实是根据Document这个数据对象来定义使用的,我们可以看到其数据格式为:Document(page_content=‘……’, metadata={‘source’: ‘……’, ‘row’: ……})
    那么input_variables的输入就是Document的page_content。
  2. StuffDocumentsChain中有一个参数是document_variable_name。那么这个类是这样定义的This chain takes a list of documents and first combines them into a single string. It does this by formatting each document into a string with the document_prompt and then joining them together with document_separator. It then adds that new string to the inputs with the variable name set by document_variable_name. Those inputs are then passed to the llm_chain. 这个document_variable_name简单来说就是在document_prompt中的占位符,用于在Chain中的使用。
    因此我们上文prompt_template变量中的“已知内容: {context}”,用的就是context这个变量。因此在prompt_template中换成其他的占位符都不能正常使用这个Chain。

prompt_template

在上面的拓展中其实已经对prompt_template做了部分的讲解,那么这个字符串还剩下“问题:{question}”这个地方没有说通
还是回归源码:

return cls(
        combine_documents_chain=combine_documents_chain,
        callbacks=callbacks,
        **kwargs,
    )

我们可以在from_llm函数中看到其返回值是到了_call,那么剩下的我们来看这个函数:


......
uestion = inputs[self.input_key]
accepts_run_manager = (
    "run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:
    docs = self._get_docs(question, run_manager=_run_manager)
else:
    docs = self._get_docs(question)  # type: ignore[call-arg]
answer = self.combine_documents_chain.run(
    input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
......

这里是在run这个函数中传入了一个字典值,这个字典值有三个参数。

注意:

  1. 这三个参数就是kwargs,也就是_validate_inputs的参数input;
  2. 此时已经是在Chain这个基本类了)
def run(
        self,
        *args: Any,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
       """Convenience method for executing chain.

        The main difference between this method and `Chain.__call__` is that this
        method expects inputs to be passed directly in as positional arguments or
        keyword arguments, whereas `Chain.__call__` expects a single input dictionary
        with all the inputs"""

接下来调用__call__:

def __call__(
        self,
        inputs: Union[Dict[str, Any], Any],
        return_only_outputs: bool = False,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        include_run_info: bool = False,
    ) -> Dict[str, Any]:
        """Execute the chain.

        Args:
            inputs: Dictionary of inputs, or single input if chain expects
                only one param. Should contain all inputs specified in
                `Chain.input_keys` except for inputs that will be set by the chain's
                memory.
            return_only_outputs: Whether to return only outputs in the
                response. If True, only new keys generated by this chain will be
                returned. If False, both input keys and new keys generated by this
                chain will be returned. Defaults to False.
            callbacks: Callbacks to use for this chain run. These will be called in
                addition to callbacks passed to the chain during construction, but only
                these runtime callbacks will propagate to calls to other objects.
            tags: List of string tags to pass to all callbacks. These will be passed in
                addition to tags passed to the chain during construction, but only
                these runtime tags will propagate to calls to other objects.
            metadata: Optional metadata associated with the chain. Defaults to None
            include_run_info: Whether to include run info in the response. Defaults
                to False.

        Returns:
            A dict of named outputs. Should contain all outputs specified in
                `Chain.output_keys`.
        """
        inputs = self.prep_inputs(inputs)
        ......

这里的prep_inputs会调用_validate_inputs函数

def _validate_inputs(self,inputs: Dict[str, Any]) -> None:
    """Check that all inputs are present."""
    missing_keys = set(self.input_keys).difference(inputs)
    if missing_keys:
        raise ValueError(f"Missing some input keys: {missing_keys}")

这里的input_keys通过调试,看到的就是有多个输入,分别是"input_documents"和"question"
这里的"input_documents"是来自于BaseCombineDocumentsChain

class BaseCombineDocumentsChain(Chain, ABC):
    """Base interface for chains combining documents.

    Subclasses of this chain deal with combining documents in a variety of
    ways. This base class exists to add some uniformity in the interface these types
    of chains should expose. Namely, they expect an input key related to the documents
    to use (default `input_documents`), and then also expose a method to calculate
    the length of a prompt from documents (useful for outside callers to use to
    determine whether it's safe to pass a list of documents into this chain or whether
    that will longer than the context length).
    """

    input_key: str = "input_documents"  #: :meta private:
    output_key: str = "output_text"  #: :meta private:

那为什么有两个呢,“question”来自于哪里?
StuffDocumentsChain继承BaseCombineDocumentsChain,其input_key是这样定义的:

  @property
  def input_keys(self) -> List[str]:
      extra_keys = [
          k for k in self.llm_chain.input_keys if k != self.document_variable_name
      ]
      return super().input_keys + extra_keys

原来是重写了input_keys函数,其是对llm_chain的input_keys进行遍历。

那么llm_chain的input_keys是用其prompt的input_variables。(这里的input_variables是PromptTemplate中的[“context”, “question”])

	@property
	def input_keys(self) -> List[str]:
	   """Will be whatever keys the prompt expects.
	   :meta private:
	   """
	   return self.prompt.input_variables

至此,我们StuffDocumentsChain的input_keys有两个变量了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1183270.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Cloud - 手写 Gateway 源码,实现自定义局部 FilterFactory

目录 一、FilterFactory 分析 1.1、前置知识 1.2、分析源码 1.2.1、整体分析 1.2.2、源码分析 1.3、手写源码 1.3.1、基础框架 1.3.2、实现自定义局部过滤器 1.3.3、加参数的自定义局部过滤器器 一、FilterFactory 分析 1.1、前置知识 前面的学习我们知道&#xff0c…

云服务器搭建flink集群

文章目录 1.集群配置2.修改集群配置3. 访问Web UI4. 提交作业方式5.Yarn部署模式配置5.1 会话模式部署(Session Mode)5.2 单作业模式(Per-job Mode)5.3 应用模式部署(推荐)5.3.1 上传HDFS提交(推荐) 5.4 历…

SpringCloudAlibaba——Sentinel

Sentinel也就是我们之前的Hystrix,而且比Hystrix功能更加的强大。Sentinel是分布式系统的流量防卫兵,以流量为切入点,从流量控制、流量路由、熔断降级等多个维度保护服务的稳定性。 Sentinel采用的是懒加载,这个接口被访问一次&a…

爬取Elastic Stack采集的Nginx内容

以下是一个简单的Go语言爬虫程序,用于爬取Elastic Stack采集的Nginx内容。请注意,这只是一个基本的示例,实际使用时可能需要根据具体情况进行修改和扩展。 package mainimport ("fmt""net/http""io/ioutil" )…

高效接口重试机制的实现

实现一个高效的接口重试机制对于保证系统的稳定性和可靠性至关重要。在面对网络不稳定、服务端故障或者高负载的情况下,接口重试机制能够确保请求的成功执行,同时也需要保证在重试过程中不会造成额外的负担或不必要的延迟。本文将为您介绍高效接口重试机…

工业相机基本知识理解:靶面尺寸、像元尺寸、分辨率

1、靶面尺寸:由Sensor对角线长度表示,单位英寸,这里的1英寸16mm 2、像元尺寸:单个感光元件的大小,一般都是正方形,边长单位um 3、分辨率: Sensor长边像元数 Sensor短边像元数,俗称像…

220v插座led指示灯维修

由于220v是交流电,有反向电压的情况,而led反向通电的时候电阻无穷大,所以分压也无穷大,220v一导通就击穿,即使加了很大的电阻也没用,串联电阻只能作用于二极管正向的时候。 目前有两种方案: 方…

EM@解三角形@正弦定理@余弦定理

文章目录 abstract解三角形基本原理不唯一性 正弦定理直角三角形中的情形推广锐角三角形钝角情形 小结:正弦定理 余弦定理直角三角形中的情形非直角情形小结:余弦定理公式的角余弦形式 abstract 解直角三角形问题正弦定理和余弦定理的推导 对于非直角情形,都是直角情形的推广同…

Springboot项目的多数据源配置

spring boot项目配置多个数据源很常见! 话不多说,上代码。 首先先在system账号下创建了一个用户test1,并授予权限 create user test1 identified by 123456; grant connect,resource to test1; 接下来登录test1用户,创建一个表student …

使用表单登录方法模拟登录通信人家园,要求发送登录请求后打印出来的用户名下的用户组类别

目标网站:https://www.txrjy.com/forum.php 一、进入网页,右键“检查” 二、输入用户名和密码,点击“登录”,点击“Network”,上划加载项找到蓝色框中的内容 三、点击第一个加载项,找到URL 四、相关代码: …

数据结构-单链表-力扣题

移除链表元素 题目链接:力扣(LeetCode) 思路:和前面学的单链表的中间删除数据一样,使要被删除节点的前一个节点指向下要被删除节点的下一个节点,然后把要被删除的节点free掉。 具体实现过程:先…

15 Linux 按键

一、Linux 按键驱动原理 其实案件驱动和 LED 驱动很相似,只不过区别在于,一个是读取GPIO高低电平,一个是从GPIO输出高低电平。 在驱动程序中使用一个整形变量来表示按键值,应用程序通过 read 函数来读取按键值,判断按键…

【Qt之绘制兔纸】

效果 代码 class drawRabbit: public QWidget { public:drawRabbit(QWidget *parent nullptr) : QWidget(parent) {}private:void paintEvent(QPaintEvent *event) {QPainter painter(this);painter.setRenderHint(QPainter::Antialiasing, true);// 绘制兔子的耳朵painter.s…

【代码随想录】算法训练营 第十五天 第六章 二叉树 Part 2

102. 二叉树的层序遍历 层序遍历,就是一层一层地遍历二叉树,最常见的就是从上到下,从左到右来遍历,遍历的方法依然有两种,第一种是借助队列,第二种则是递归,都算是很简单、很容易理解的方法&am…

新登录接口独立版变现宝升级版知识付费小程序-多领域素材资源知识变现营销系统

源码简介: 资源入口 点击进入 源码亲测无bug,含前后端源码,非线传,修复最新登录接口 梦想贩卖机升级版,变现宝吸取了资源变现类产品的很多优点,摒弃了那些无关紧要的东西,使本产品在运营和变现…

VMware部署CentOS7

一、创建虚拟机 1、点击新建虚拟机 2、选择自定义 下一步 3、点击下一步 4、选择稍后安装操作系统 5、选择linux 下一步 6、选择要安装的centos 版本 这里选择centos7 7、自定义虚拟机名称 设置虚拟机运行空间 8、配置处理器,使用默认 1个处理器 1核 9、修改虚拟机…

用友U8 Cloud 反序列化RCE漏洞复现

0x01 产品简介 用友U8 Cloud是用友推出的新一代云ERP,主要聚焦成长型、创新型企业,提供企业级云ERP整体解决方案。 0x02 漏洞概述 用友U8 Cloud存在多处(FileManageServlet和LoginVideoServlet)反序列化漏洞,系统未将…

Vue组件的存放目录问题

注意: .vue文件 本质无区别 1.组件分类 .vue文件分为2类,都是 .vue文件(本质无区别) 页面组件 (配置路由规则时使用的组件) 复用组件(多个组件中都使用到的组件) 2.存放目录 分…

bootstrap3简单玩法

Bootstrap v3 Bootstrap v3 是一个流行的前端框架,它提供了一系列的模板、组件和工具,可以帮助开发者快速地构建响应式的网站和应用程序。 以下是 Bootstrap v3 的一些常见应用: 响应式布局:Bootstrap v3 提供了一个易于使用的网…

Failed to connect to github.com port 443:connection timed out

解决办法: 步骤1: 在这里插入图片描述 步骤2: -步骤3 :在git终端中执行如下命令: git config --global http.proxy http:ip:port git config --global https.proxy http:ip:port git config --global http.proxy htt…