提升AI模型的准确性与可靠性
©作者|Ninja Geek
来源|神州问学
介绍
检索增强生成(RAG)彻底改变了使用大语言模型和利用外部知识库的方式。它允许模型从文档存储的相关索引数据中获取信息用以增强其生成的内容,使其更加准确和信息丰富。然而,RAG并非完全无缺。它有时会检索出不相关或不正确的信息,这就导致了不准确或带有明显误导性的生成内容。这就是 CRAG (修正型检索增强生成)发挥作用的地方。
CRAG 是一种强大的技术,它通过结合反馈机制来改进检索过程,从而增强了 RAG 的鲁棒性。它确保生成中使用的信息既相关又准确,从而产生更可靠、更值得信赖的生成内容。
了解 CRAG
CRAG 的工作原理是对检索到的文档进行额外的内容审查。它采用一种称作”评估器“的模型来评估每个文档与给定查询的相关性。如果评估器认为文档不相关,CRAG 会启动回退机制(通常利用网络搜索)来搜索其他来源的信息。
CRAG 的优势
1. 提高准确性:通过确保检索信息的相关性,CRAG 显著降低了产生不正确或误导性内容的可能性。
2. 增强可靠性:CRAG 通过验证生成中使用的信息来提供安全网,使模型更加值得信赖和可靠。
3. 扩展知识库:回退机制允许模型利用更广泛的信息源,从而有可能客服原始文档存储的限制。
了解检索评估器架构
我从论文《Corrective Retrieval Augmented》中截取了如下的示意图,在该示意图中描绘了如何构建一个检索评估器来评估检索到的文档与输入的问题的相关性。对置信度进行估计,以此为基础可以触发不同的知识检索操作,如”正确“、”错误“或”模棱两可“。
来源:《Corrective Retrieval Augmented Generation》
该方法被命名为“校正增强生成”(Corrective Retrieval-Augmented Generation,CRAG),旨在自我校正检索结果并增强文档在生成过程中的利用率。
引入了一个轻量级的检索评估器,用于评估给定查询下检索到的文档的整体质量。
该评估器是“检索增强生成”(Retrieval-Augmented Generation,RAG)的关键组成部分,通过审查和评估检索到的文档的相关性和可靠性,帮助生成有价值的信息。
检索评估器量化了信心水平,从而基于评估结果触发不同的知识检索操作,如“正确”、“错误”或“模棱两可”。
对于“错误”和“模棱两可”的情况,可以通过大规模网络搜索,以解决静态和有限语料库的局限性,旨在提供更广泛和多样化的信息集。
最后,在整个检索和利用过程中实施了“分解再重构”算法。
该算法有助于消除检索文档中对RAG无益的冗余内容,优化信息提取过程,并在减少非必要元素的同时最大限度地包含关键信息。
检索评估器逻辑算法
检索评估器涉及到一种算法,该算法确保检索到的信息的细化,优化关键信息的提取并最大限度地减少非必要信息被检索到,从而提高检索到的数据的利用率。
通过下面的算法伪码我们来了解检索评估器的核心思想:
来源:《Corrective Retrieval Augmented Generation》
使用 LlamaIndex 实现CRAG
LlamaIndex 提供了一个便利的框架来实现 CRAG。这是一个简化的示例,目的是让您了解这个过程:
步骤一:安装 OpenAI 的 Python 库并填写 API 秘钥
%pip install llama-index llama-index-indices-managed-llama-cloud llama-index-tools-tavily-research
import nest_asyncio
import os
nest_asyncio.apply()
os.environ["OPENAI_API_KEY"] = "<YOUR_OPENAI_API_KEY>"
步骤二:设计工作流程
from typing import List, Optional, Any
from llama_index.core.schema import NodeWithScore
from llama_index.core.workflow import (
Event,
)
from llama_index.core.workflow import (
StartEvent,
StopEvent,
step,
Workflow,
Context,
)
from llama_index.core import SummaryIndex
from llama_index.core.schema import Document
from llama_index.core.prompts import PromptTemplate
from llama_index.core.llms import LLM
from llama_index.llms.openai import OpenAI
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
from llama_index.tools.tavily_research import TavilyToolSpec
class RetrieveEvent(Event):
"""检索事件(获取被检索的节点)"""
retrieved_nodes: List[NodeWithScore]
class RelevanceEvalEvent(Event):
"""相关性评估事件(获取相关性评估结果)"""
relevant_results: List[str]
class TextExtractEvent(Event):
"""文本提取事件 提取相关文本并进行拼接"""
relevant_text: str
class QueryEvent(Event):
"""查询事件 对给定的相关文本和搜索文本进行查询。"""
relevant_text: str
search_text: str
DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate(
template="""作为评分员,您的任务是评估根据用户问题检索到的文档的相关性。
检索到的文档:
-------------------
{context_str}
用户问题:
--------------
{query_str}
评估标准:
- 考虑文档是否包含与用户问题相关的关键词或主题。
- 评估不应过于严格;主要目的是识别并过滤掉明显不相关的检索结果。
决策:
- 赋予二元评分以指示文档的相关性。
- 如果文档与问题相关,请使用“是”;如果不相关,请使用“否”。
请在下方提供您的二元评分(“yes”或“no”)以指示文档与用户问题的相关性。"""
)
DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate(
template="""您的任务是优化查询,以确保其在检索相关搜索结果时具有较高的有效性。\n
分析给定的输入以把握核心语义意图或含义。\n
原始查询:
\n ------- \n
{query_str}
\n ------- \n
您的目标是重新表述或改进此查询,以提高其搜索性能。确保修订后的查询简洁明了,并直接符合预期的搜索目标。\n
只需回复优化后的查询内容:"""
)
class CorrectiveRAGWorkflow(Workflow):
"""CRAG 工作流"""
def __init__(
self,
index,
tavily_ai_apikey: str,
llm: Optional[LLM] = None,
**kwargs: Any
) -> None:
"""初始化参数"""
super().__init__(**kwargs)
self.index = index
self.tavily_tool = TavilyToolSpec(api_key=tavily_ai_apikey)
self.llm = llm or OpenAI(model="gpt-4o")
@step
async def retrieve(self, ctx: Context, ev: StartEvent) -> Optional[RetrieveEvent]:
"""检索与查询相关的节点"""
query_str = ev.get("query_str")
retriever_kwargs = ev.get("retriever_kwargs", {})
if query_str is None:
return None
retriever: BaseRetriever = self.index.as_retriever(**retriever_kwargs)
result = retriever.retrieve(query_str)
await ctx.set("retrieved_nodes", result)
await ctx.set("query_str", query_str)
return RetrieveEvent(retrieved_nodes=result)
@step
async def eval_relevance(
self, ctx: Context, ev: RetrieveEvent
) -> RelevanceEvalEvent:
"""评估检索到的文档与查询的相关性"""
retrieved_nodes = ev.retrieved_nodes
query_str = await ctx.get("query_str")
relevancy_results = []
for node in retrieved_nodes:
prompt = DEFAULT_RELEVANCY_PROMPT_TEMPLATE.format(context_str=node.text, query_str=query_str)
relevancy = self.llm.complete(prompt)
relevancy_results.append(relevancy.text.lower().strip())
await ctx.set("relevancy_results", relevancy_results)
return RelevanceEvalEvent(relevant_results=relevancy_results)
@step
async def extract_relevant_texts(
self, ctx: Context, ev: RelevanceEvalEvent
) -> TextExtractEvent:
"""从检索到的文档中提取相关文本"""
retrieved_nodes = await ctx.get("retrieved_nodes")
relevancy_results = ev.relevant_results
relevant_texts = [
retrieved_nodes[i].text
for i, result in enumerate(relevancy_results)
if result == "yes"
]
result = "\n".join(relevant_texts)
return TextExtractEvent(relevant_text=result)
@step
async def transform_query_pipeline(
self, ctx: Context, ev: TextExtractEvent
) -> QueryEvent:
"""使用Tavily API搜索转换后的查询"""
relevant_text = ev.relevant_text
relevancy_results = await ctx.get("relevancy_results")
query_str = await ctx.get("query_str")
# 如果发现任何文档不相关,请转换查询字符串以获得更好的搜索结果。
if "no" in relevancy_results:
prompt = DEFAULT_TRANSFORM_QUERY_TEMPLATE.format(query_str=query_str)
result = self.llm.complete(prompt)
transformed_query_str = result.text
# 使用转换后的查询字符串进行搜索并收集结果。
search_results = self.tavily_tool.search(
transformed_query_str, max_results=5
)
search_text = "\n".join([result.text for result in search_results])
else:
search_text = ""
return QueryEvent(relevant_text=relevant_text, search_text=search_text)
@step
async def query_result(self, ctx: Context, ev: QueryEvent) -> StopEvent:
"""获取包含相关文本的结果"""
relevant_text = ev.relevant_text
search_text = ev.search_text
query_str = await ctx.get("query_str")
documents = [Document(text=relevant_text + "\n" + search_text)]
index = SummaryIndex.from_documents(documents)
query_engine = index.as_query_engine()
result = query_engine.query(query_str)
return StopEvent(result=result)
步骤三:创建基于 LlamaCloud 的索引
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
index = LlamaCloudIndex(
name="<索引名称>",
project_name="<项目名称>",
api_key="<LlamaCloud_API_KEY>",
organization_id="<组织ID>",
)
# 设置工作流程集成
workflow = CorrectiveRAGWorkflow(index=index, tavily_ai_apikey="<TAVILY_API_KEY>", verbose=True, timeout=60)
# 可视化工作流
from llama_index.utils.workflow import draw_all_possible_flows
draw_all_possible_flows(CorrectiveRAGWorkflow, filename="crag_workflow.html")
步骤四:执行一个示例查询
from IPython.display import display, Markdown
result = await workflow.run(query_str="How was Llama2 pretrained?") # 这是在所提供的论文中提到的内容。
display(Markdown(str(result)))
# 输出内容
Running step retrieve
Step retrieve produced event RetrieveEvent
Running step eval_relevance
Step eval_relevance produced event RelevanceEvalEvent
Running step extract_relevant_texts
Step extract_relevant_texts produced event TextExtractEvent
Running step transform_query_pipeline
Step transform_query_pipeline produced event QueryEvent
Running step query_result
Step query_result produced event StopEvent
Llama 2 was pretrained using self-supervised learning on 2 trillion tokens of data from publicly available online sources.
result = await workflow.run(query_str="Where does the airline flight UA 1 fly?") # this info is not in the paper
display(Markdown(str(result)))
# 输出内容
Running step retrieve
Step retrieve produced event RetrieveEvent
Running step eval_relevance
Step eval_relevance produced event RelevanceEvalEvent
Running step extract_relevant_texts
Step extract_relevant_texts produced event TextExtractEvent
Running step transform_query_pipeline
Step transform_query_pipeline produced event QueryEvent
Running step query_result
Step query_result produced event StopEvent
The airline flight UA 1 flies from San Francisco, California (SFO) to Singapore (SIN).
结论
CRAG 是一种非常有价值的技术,可提高检索增强生成的可靠性和准确性。通过结合反馈机制和后备策略,CRAG 可确保语言模型根据相关且准确的信息来生成回答。LlamaIndex 提供了一个用户友好的平台来实施 CRAG,使得开发者能够构建更强大、更值得信赖的 AI 应用。
参考资源:
https://github.com/run-llama/ll