微调您的Embedding模型以最大限度地提高RAG管道中的相关性检索

news2024/11/16 18:11:09

英文原文地址:https://betterprogramming.pub/fine-tuning-your-embedding-model-to-maximize-relevance-retrieval-in-rag-pipeline-2ea3fa231149

微调您的Embedding模型以最大限度地提高RAG管道中的相关性检索

微调嵌入前后的 NVIDIA SEC 10-K 文件分析

2023 年 9 月 13 日

让我们继续之前的文章,使用 GPT-4 训练数据微调 GPT-3.5 RAG 管道。这一次,让我们深入微调 RAG(检索增强生成)管道的另一端——嵌入模型。

通过微调我们的嵌入模型,我们增强了系统检索最相关文档的能力,确保我们的 RAG 管道发挥最佳性能。

text-embedding-ada-002我们在 LlamaIndex 博客系列中的大多数 RAG 管道中一直使用 OpenAI 的嵌入模型。然而,OpenAI 不提供微调功能text-embedding-ada-002,因此让我们在本文中探索微调开源嵌入模型。

BAAI/bge-small-en

目前HuggingFace 的 MTEB(大规模文本嵌入基准)排行榜上当前排名第一的嵌入模型是bge-large-en;它是由北京人工智能研究院(BAAI)开发的。它是一个预训练的 Transformer 模型,可用于各种自然语言处理任务,例如文本分类、问答、文本生成等。该模型在海量文本和代码数据集上进行训练,并经过微调大规模文本嵌入基准(MTEB)。

在本文中,我们将使用 的bge-large-en兄弟姐妹之一,bge-small-en一个具有竞争性能的 384 维小型模型,非常适合在 Google Colab 中运行。

微调嵌入模型与微调 LLM

从上一篇关于微调的文章gpt-3.5-turbo中,我们对 LLM 微调所涉及的步骤有了深入的了解。与LLM微调相比,微调的实施bge-small-en有一些相似之处和不同之处。

相似点

  • 两种类型的微调都遵循相同的方法,即生成用于训练和评估的数据集,微调模型,最后评估基本模型和微调模型之间的性能。
  • 使用LLM自动生成训练和评估数据集。

不同点

  • 数据集内容在LLM微调和Embedding模型微调之间有所不同。用于LLM微调的数据集包含LLM生成的问题。在微调过程中,包括问题、答案、系统prompt等在内的一系列数据将以JSON行(jsonl)文件的形式传递给要进行微调的模型。

不同的是,用于Embedding模型微调的数据集包含以下三组:

  1. queriesnode_id映射和LLM生成的问题的集合。
  2. corpusnode_id映射和相应节点中的文本的集合。
  3. relevant_docs:查询的node_id和语料库 node_id之间的交叉引用映射的集合。给定一个查询,它告诉Embedding模型要查找哪个文本节点/语料库。
  • 由于我们使用开源Embedding模型bge-small-en ,微调的前提就是要先把它下载到您的本地环境。以Google Colab为例,经过微调的模型将被下载到笔记本的根目录中。
  • 评估方法在微调Embedding模型和微调LLM之间有所不同,我们可以使用Ragas框架来衡量精准度和答案相关性。然而,当使用Embedding模型微调时,我们无法测量答案的正确性,因为我们只能为我们的问题检索相关节点。相反,我们使用一个称为“命中率”的简单度量,这意味着对于每个(query, relevant_doc)对,我们用查询检索top-k文档,如果结果包含relevant_doc,则它被认为是“命中”的。该指标可用于专有Embeddings,如OpenAI的Embedding模型和开源Embedding模型。对于开源Embedding模型,我们还可以使用来自sentence_transformersInformationRetrievalEvaluator进行评估,因为它提供了一套更全面的指标。

微调Embedding模型似乎涉及到很多问题。幸运的是,LlamaIndex(我个人感觉LlamaIndex目前的发展可能会在RAG方面打败LangChain)在最近的0.8.21版本中引入以下关键类/函数,使得微调Embedding模型变得超级简单:

  • SentenceTransformersFinetuneEngine
  • generate_qa_embedding_pairs
  • EmbeddingQAFinetuneDataset

这些类和函数为我们抽象了底层的详细集成逻辑,使开发人员能够非常直观地调用它。

微调方法

为了可视化微调BAAI/big-small-en所涉及的主要任务,让我们看看下图:

img

如图中的数值所示,主要任务包括:

  1. 通过调用 EmbeddingQAFinetuneDataset函数generate_qa_embedding_pairs,自动生成评估和训练数据集的数据。
  2. 通过传入基本模型和训练数据集来构造SentenceTransformersFinetuneEngine,然后调用其finetune函数来训练基本模型。
  3. 创建经过微调的模型。
  4. 调用向量存储索引检索器检索相关节点并评估基本模型的命中率。
  5. 调用InformationRetrievalEvaluator来评估基本模型。
  6. 调用向量存储索引检索器检索相关节点并评估微调模型的命中率。
  7. 调用InformationRetrievalEvaluator来评估经过微调的模型。

基于LlamaIndex的微调Embeddings指南(文末有链接),我们将在我们的用例中微调bge-small-en模型。

实现细节

Step 1: 生成数据集

让我们使用LLM来自动生成训练和评估的数据集。

  • Load corpus

在我们的用例中NVIDIA的SEC 10-K文件(代码中和文末都有链接)是一个169页的PDF文档(你可以用你自己的中文PDF),所以我们需要在生成数据集时将文档分成两部分——一部分用于训练数据集,另一部分用于evalals数据集。

使用单独的数据集进行训练和评估被认为是一种很好的ML实践。可以调用load_corpus函数来收集训练数据集(前90页)或eval数据集(其余页面)的节点。下面是load_corpus的代码片段:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
!curl https://d18rn0p25nwr6d.cloudfront.net/CIK-0001045810/4e9abe7b-fdc7-4cd2-8487-dc3a99f30e98.pdf --output nvidia-sec-10k-2022.pdf

def load_corpus(docs, for_training=False, verbose=False):
    parser = SimpleNodeParser.from_defaults()
    if for_training:
        nodes = parser.get_nodes_from_documents(docs[:90], show_progress=verbose)
    else:
        nodes = parser.get_nodes_from_documents(docs[91:], show_progress=verbose)

    if verbose:
        print(f'Parsed {len(nodes)} nodes')

    return nodes

SEC_FILE = ['nvidia-sec-10k-2022.pdf']

print(f"Loading files {SEC_FILE}")

reader = SimpleDirectoryReader(input_files=SEC_FILE)
docs = reader.load_data()
print(f'Loaded {len(docs)} docs')

train_nodes = load_corpus(docs, for_training=True, verbose=True)
val_nodes = load_corpus(docs, for_training=False, verbose=True)

请记住,在LlamaIndex中,节点和页面并不完全匹配。对于一个169页的文档,结果显示它为训练数据集解析了97个节点,为evals数据集解析了91个节点。这两个数据集的节点数量足够接近。让我们继续。

img

  • 生成合成查询和数据集

现在,让我们生成训练和评估的数据集。请注意,我们这里没有传递LLM (gpt-3.5-turbo-0613),只有OpenAI API密钥。这是因为LlamaIndex的默认LLM是gpt-3.5-turbo-0613;如果没有定义LLM,只要提供OpenAI API密钥,则默认为它。

generate_qa_embedding_pairs是一个生成数据集的方便函数。基于上面load_corpus函数返回的节点,它为每个节点生成问题(默认为每个节点两个问题,可以自定义),然后用所有三组数据构建数据集:queriescorpusrelevant_docs(queriescorpus之间的映射对应的node_id)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from llama_index.finetuning import (
    generate_qa_embedding_pairs,
    EmbeddingQAFinetuneDataset,
)
from llama_index.llms import OpenAI

os.environ["OPENAI_API_KEY"] = "sk-############"
openai.api_key = os.environ["OPENAI_API_KEY"]

train_dataset = generate_qa_embedding_pairs(train_nodes)
val_dataset = generate_qa_embedding_pairs(val_nodes)

train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")

train_dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")
val_dataset = EmbeddingQAFinetuneDataset.from_json("val_dataset.json")

下面是样本训练数据集的样子。注意queriescorpus在截图中是折叠的,因为每个都有超过100个数据对:

img

Step 2: 微调Embedding模型

SentenceTransformersFinetuneEngine就是为这个任务设计的。在底层,它执行多个子任务:

  • 通过构建SentenceTransformer加载预训练模型,传入BAAI/big-small-en模型id。
  • 定义数据加载器。它加载我们的训练数据集,将其解析为查询语料库relevant_docs。然后循环查询,将relevant_docs中的node_idcorpus中的文本节点进行映射,构造InputExample,其列表依次传递到创建DataLoader中.
  • 定义loss(损失函数)。它使用sentence_transformers multiplenegativerankingloss来训练检索设置的Embeddings。
  • 定义评估器。它设置了一个带有eval数据集的评估器来监控Embedding模型在训练期间的表现。
  • 运行训练。它插入上面定义的数据加载器、损失函数和评估器来运行训练。

LlamaIndex将微调Embedding模型的所有详细子任务封装在一个SentenceTransformersFinetuneEngine中,我们所需要做的就是调用它的finetune函数。下面,您可以看到展示LlamaIndex的代码片段:

1
2
3
4
5
6
7
8
9
10
11
12
from llama_index.finetuning import SentenceTransformersFinetuneEngine

finetune_engine = SentenceTransformersFinetuneEngine(
    train_dataset,
    model_id="BAAI/bge-small-en",
    model_output_path="test_model",
    val_dataset=val_dataset,
)

finetune_engine.finetune()

embed_model = finetune_engine.get_finetuned_model()

Step 3: 评估微调后的模型

如上所述,我们使用两种不同的评估方法:

  • 命中率:对每个query / relevant_doc对进行简单的top-k检索。如果搜索结果包含relevant_doc,那么它就是一个“命中”。这可以用于专有的Embeddings,例如OpenAI的Embedding模型和开源Embedding模型。请参阅下面代码片段中的evaluate函数。

  • InformationRetrievalEvaluator:一个更全面的用于评估开源Embeddings的度量套件。请参阅下面代码片段中的evaluate_st函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from llama_index.embeddings import OpenAIEmbedding
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.schema import TextNode
from tqdm.notebook import tqdm
import pandas as pd

# function for hit rate evals
def evaluate(
    dataset,
    embed_model,
    top_k=5,
    verbose=False,
):
    corpus = dataset.corpus
    queries = dataset.queries
    relevant_docs = dataset.relevant_docs

    service_context = ServiceContext.from_defaults(embed_model=embed_model)
    nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]
    index = VectorStoreIndex(nodes, service_context=service_context, show_progress=True)
    retriever = index.as_retriever(similarity_top_k=top_k)

    eval_results = []
    for query_id, query in tqdm(queries.items()):
        retrieved_nodes = retriever.retrieve(query)
        retrieved_ids = [node.node.node_id for node in retrieved_nodes]
        expected_id = relevant_docs[query_id][0]
        is_hit = expected_id in retrieved_ids  # assume 1 relevant doc

        eval_result = {
            "is_hit": is_hit,
            "retrieved": retrieved_ids,
            "expected": expected_id,
            "query": query_id,
        }
        eval_results.append(eval_result)
    return eval_results


from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers import SentenceTransformer

def evaluate_st(
    dataset,
    model_id,
    name,
):
    corpus = dataset.corpus
    queries = dataset.queries
    relevant_docs = dataset.relevant_docs

    evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs, name=name)
    model = SentenceTransformer(model_id)
    return evaluator(model, output_path="results/")
  • 评测OpenAI

现在,让我们评估一下OpenAI的Embedding模型text-embedding-ada-002。代码如下:

1
2
3
4
5
6
ada = OpenAIEmbedding()
ada_val_results = evaluate(val_dataset, ada)

df_ada = pd.DataFrame(ada_val_results)

hit_rate_ada = df_ada['is_hit'].mean()

结果:

img

  • 评测BAAI/bge-small-en
1
2
3
4
5
6
7
8
bge = "local:BAAI/bge-small-en"
bge_val_results = evaluate(val_dataset, bge)

df_bge = pd.DataFrame(bge_val_results)

hit_rate_bge = df_bge['is_hit'].mean()

evaluate_st(val_dataset, "BAAI/bge-small-en", name='bge')

结果:

img

  • 评估微调后的model
1
2
3
4
5
6
7
8
finetuned = "local:test_model"
val_results_finetuned = evaluate(val_dataset, finetuned)

df_finetuned = pd.DataFrame(val_results_finetuned)

hit_rate_finetuned = df_finetuned['is_hit'].mean()

evaluate_st(val_dataset, "test_model", name='finetuned')

查看结果:

img

  • Summary of results

把评测结果放在一起,让我们仔细看看。

命中率:我们的微调模型比其基本模型bge-small-en的性能提高了1.29%。与OpenAI的Embedding模型相比,我们的微调模型的性能仅低了4.85%。

img

InformationRetrievalEvaluator结果:经过微调的模型比其基本模型的性能提高了5.81%。与基本模型相比,微调模型对这30多个指标列中的每一个都有更好的数字。

img

总结

在本文中,我们探讨了微调RAG管道的Embedding模型所涉及的步骤。我们使用开源的sentence_transformers模型BAAI/big-small-en作为我们的基本Embedding模型,介绍了如何生成用于训练和评估的数据集,如何对其进行微调,以及如何评估基本模型和微调模型之间的性能差异。

评估结果表明,微调Embedding模型的性能比基本模型提高了1-6%,与OpenAI的Embedding模型相比,微调模型的性能损失仅为4.85%。这种性能提升可能因数据集的质量和数量而异。

我们还简要探讨了LlamaIndex的最新版本,该版本对任何Embedding模型的线性适配器进行了微调,从而提高了性能并避免了在RAG管道中重新嵌入文档。

引用

  • LlamaIndex的Finetune Embeddings指南:Finetune Embeddings - LlamaIndex 🦙 0.9.26
  • NVIDIA的SEC10-K文件的PDF:https://d18rn0p25nwr6d.cloudfront.net/CIK-0001045810/4e9abe7b-fdc7-4cd2-8487-dc3a99f30e98.pdf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1385559.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C#灵活控制多线程的状态(开始暂停继续取消)

ManualResetEvent类 ManualResetEvent是一个同步基元,用于在多线程环境中协调线程的执行。它提供了两种状态:终止状态和非终止状态。 在终止状态下,ManualResetEvent允许线程继续执行。而在非终止状态下,ManualResetEvent会阻塞线…

智能助手的巅峰对决:ChatGPT对阵文心一言

在人工智能的世界里,ChatGPT与文心一言都是备受瞩目的明星产品。它们凭借先进的技术和强大的性能,吸引了大量用户的关注。但究竟哪一个在智能回复、语言准确性、知识库丰富度等方面更胜一筹呢?下面就让我们一探究竟。 首先来谈谈智能回复能力…

SwiftUI之深入解析高级布局的实战教程

一、自定义动画 首先实现一个圆形布局的视图容器 WheelLayout: struct ContentView: View {let colors: [Color] [.yellow, .orange, .red, .pink, .purple, .blue, .cyan, .green]var body: some View {WheelLayout(radius: 130.0, rotation: .zero) {ForEach(0.…

强化学习应用(三):基于Q-learning的物流配送路径规划研究(提供Python代码)

一、Q-learning算法简介 Q-learning是一种强化学习算法,用于解决基于马尔可夫决策过程(MDP)的问题。它通过学习一个值函数来指导智能体在环境中做出决策,以最大化累积奖励。 Q-learning算法的核心思想是使用一个Q值函数来估计每…

纳米量级晶圆表面微观检测技术

持续更新 背景:晶圆表面形状偏差分为:宏观几何误差,中间几何误差,微观几何误差,跟别用表面形状误差,表面波纹度,表面粗度来描述。 主要技术:微分剪切干涩显微技术,五步…

Dubbo分层设计之Transport层

前言 Dubbo 框架采用分层设计,最底下的 Serialize 层负责把对象序列化为字节序列,再经过 Transport 层网络传输到对端。一次 RPC 调用,在 Dubbo 看来其实就是一段请求报文和一段响应报文的传输过程。 理解Transport Transport 层即网络传输…

计算机毕业设计----SSH在线水果商城平台含管理系统

项目介绍 本项目分为前后台,分为普通用户与管理员两个角色,前台为普通用户登录,后台为管理员登录; 管理员角色包含以下功能: 管理员登录,修改密码,类别管理,水果管理,订单管理,网站论坛管理,网站公告管理等功能。 …

抖音小店2024年创业新趋势,新手找项目,不要再错过这次的机会了

大家好,我是电商花花。 现在的抖音小店完全是电商创业中的一个优秀代名词和最轻便的创业项目,更是以独特的直播达人带货的优势将店铺激发出来。 今天给大家介绍下抖音小店的运作方式,并分析互联网创业的机遇,并提供相关的再做点…

Unity中URP下 SimpleLit框架

文章目录 前言一、整体框架1、该Shader是用于低端设备的2、包含一个Properties3、只有一个SubShader4、如果SubShader错误,返回洋葱紫5、调用自定义ShaderGUI面板 二、SubShader中1、Tags2、Pass 三、我们看一下ForwardLit的Pass1、混合模式、深度写入、面皮剔除、透…

ZooKeeper 简介

1、概念介绍 ZooKeeper 是一个开放源码的分布式应用程序协调服务,为分布式应用提供一致性服务的软件,由雅虎创建,是 Google Chubby 的开源实现,是 Apache 的子项目,之前是 Hadoop 项目的一部分,使用 Java …

提高执行力,关键在于管理者做到这四个字

执行力,对于个人而言,它就是办事的效能;而对于领导来说,它是管理的能力。 老板命令员工去买复印纸,员工第一次买回了一沓复印纸,第二次买了三摞复印纸,却仍然没有得到老板的满意。员工之所以跑…

Halcon滤波器 laplace 算子

Halcon滤波器 laplace 算子 使用laplace 算子对图像进行二次求导,会在边缘产生零点,因此该算子常常与zero_crossing算子配合使用。求出这些零点,也就得到了图像的边缘。同时,由于laplace算子对孤立像素的响应要比对边缘或线的响应…

element upload 自定义上传 报错Cannot set properties of null (setting ‘status‘)

element upload 自定义上传 报错Cannot set properties of null (setting ‘status’) 问题展示 原因分析 自定义上传方式 fileList 显示一切正常&#xff0c;状态也是成功 文件url通过URL.createObjectURL(file.raw) 进行添加 以下为配置代码 <el-uploadclass"uplo…

【K12】Python写串联电阻问题的求解思路解析

问题源代码 方法&#xff1a;calculate_circuit_parameter 构造题目&#xff1a; 模板&#xff1a; 已知电阻R1为 10Ω&#xff0c;电阻R2为 5Ω&#xff0c;电压表示数为2.5V&#xff0c;求电源电压U&#xff1f; 给合上面题目&#xff0c;利用Python程序&#xff0c;可以任…

【ScienceAI Weekly】DeepMind拆分的AI药企达成30亿美元新协议;网传字节跳动在美招聘生物/化学/物理人才

AI for Science 的新成果、新动态、新视角—— 由 DeepMind 拆分的 AI 药企首次达成制药合作&#xff0c;价值 30 亿美元微软协助科研人员发现 3,200 万种新电池材料网传 TikTok 在美国各地招募计算生物学、量子化学、分子动力学和物理方面的人才科大讯飞拟分拆医疗业务在港交…

遥感卫星影像现拍,哪里想看拍哪里!

我们为大家分享了查看实时卫星影像的方法。 虽然这个网站的卫星影像10分钟一更新&#xff0c;让世界尽收眼底&#xff0c;但分辨率却非常有限。 如果项目中需要更高清的卫星影像&#xff0c;且对时效性又有较高的要求&#xff0c;那么可以考虑用卫星专门拍摄。 光学遥感卫星…

为什么有人说PMP是水证,它的含金量到底怎么样?

在我国大陆&#xff0c;有好多证书被商业化得太重了&#xff0c;甚至演变成了个人或一些公司摇钱的工具。所以有些证书受人吹捧它崛起的快&#xff0c;但是活不长&#xff0c;甚至“夭折”&#xff0c;比如以前微软系列的证书&#xff1b; 而PMP认证从国外引进大陆这么多年了&…

【昕宝爸爸小模块】守护线程、普通线程、两者之间的区别

➡️博客首页 https://blog.csdn.net/Java_Yangxiaoyuan 欢迎优秀的你&#x1f44d;点赞、&#x1f5c2;️收藏、加❤️关注哦。 本文章CSDN首发&#xff0c;欢迎转载&#xff0c;要注明出处哦&#xff01; 先感谢优秀的你能认真的看完本文&…

打造完美跨境商城源码,助你轻松进军国际市场

随着全球化的深入&#xff0c;跨境电商已成为各国企业拓展国际市场的重要途径之一。根据最新数据显示&#xff0c;跨境电商市场规模逐年扩大&#xff0c;预计未来几年将保持较高增长率。因此&#xff0c;拥有一套完善的跨境商城源码成为企业进军国际市场的关键。 跨境商城源码…

Java--ListUtil工具类,实现将一个大列表,拆分成指定长度的子列表

文章目录 前言实现代码执行结果 前言 在项目中有时会出现列表很大&#xff0c;无法一次性批量操作&#xff0c;我们需要将列表分成指定大小的几个子列表&#xff0c;一份一份进行操作&#xff0c;本文提供这样的工具类实现这个需求。 实现代码 以下为ListUtil工具类代码实现…