使用向量检索和rerank 在RAG数据集上实验评估hit_rate和mrr

news2024/11/19 20:23:44

文章目录

    • 背景
    • 简介
    • 代码实现
      • 自定义检索器
      • 向量检索实验
      • 向量检索和rerank 实验
    • 代码开源

背景

在前面部分 大模型生成RAG评估数据集并计算hit_rate 和 mrr 介绍了使用大模型生成RAG评估数据集与评估;

在 上文 使用到了BM25 关键词检索器。接下来,想利用向量检索器测试一下在RAG评估数据集上的 hit_rate 和 mrr;

简介

使用 向量检索 和 rerank 在给定RAG评估数据集上的实验计算 hit_rate 和 mrr;

对比了使用 rerank 和 不使用 rerank的实验结果;

步骤:

  1. 基于RAG评估数据集,构建nodes节点;
  2. 构建 CustomRetriever 自定义的检索器,在检索器中实现 向量检索和 rerank;
  3. 实验评估;

代码实现

from typing import List

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.evaluation import RetrieverEvaluator
from llama_index.core.indices.postprocessor import SentenceTransformerRerank
from llama_index.core.indices.vector_store import VectorIndexRetriever
from llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.core.settings import Settings
from llama_index.legacy.embeddings import HuggingFaceEmbedding
# from llama_index.legacy.schema import NodeWithScore, QueryBundle
from llama_index.core.schema import NodeWithScore, QueryBundle, QueryType, Node
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset

利用数据集中的数据,构建nodes
pg_eval_dataset.json的下载地址: https://www.modelscope.cn/datasets/jieshenai/paul_graham_essay_rag/files

qa_dataset = EmbeddingQAFinetuneDataset.from_json("pg_eval_dataset.json")

nodes = []
for key, value in qa_dataset.corpus.items():
    nodes.append(Node(id_=key, text=value))

m3e 向量编码模型
若想使用其他的编码模型,直接进行修改即可,modelscope和huggingface的编码模型都行;

from modelscope import snapshot_download
model_dir = snapshot_download('AI-ModelScope/m3e-base')
Settings.embed_model = HuggingFaceEmbedding(model_dir)
Settings.llm = None

由于huggingface被墙了,笔者使用的是 modelscope平台,model_dir 为编码模型在本地的绝对路径

自定义检索器

tok_k: 表示召回的节点数量,可自定义设置;

top_k = 10

定义向量检索器,还实现了rerank;

class CustomRetriever(BaseRetriever):
    """Custom retriever that performs both Vector search and Knowledge Graph search"""

    def __init__(self, vector_retriever: VectorIndexRetriever, reranker=None) -> None:
        """Init params."""

        super().__init__()
        self._vector_retriever = vector_retriever
        self.reranker = reranker

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        # print(query_bundle, isinstance(QueryBundle))

        retrieved_nodes = self._vector_retriever.retrieve(query_bundle)

        if self.reranker != 'None':
            retrieved_nodes = self.reranker.postprocess_nodes(retrieved_nodes, query_bundle)
        else:
            retrieved_nodes = retrieved_nodes[:top_k]

        return retrieved_nodes
    
    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes given query.

        Implemented by the user.

        """
        return self._retrieve(query_bundle)

    async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
        if isinstance(str_or_query_bundle, str):
            str_or_query_bundle = QueryBundle(str_or_query_bundle)
        return await self._aretrieve(str_or_query_bundle)

eval_results包含每个query的 hit_rate 和 mrr,display_results 计算平均;

import pandas as pd
def display_results(eval_results):
    """
    	计算平均 hit_rate 和 mrr
    """

    metric_dicts = []
    for eval_result in eval_results:
        metric_dict = eval_result.metric_vals_dict
        metric_dicts.append(metric_dict)

    full_df = pd.DataFrame(metric_dicts)

    hit_rate = full_df["hit_rate"].mean()
    mrr = full_df["mrr"].mean()

    metric_df = pd.DataFrame(
        {"hit_rate": [hit_rate], "mrr": [mrr]}
    )
    return metric_df

向量检索实验

index = VectorStoreIndex(nodes)
vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
retriever_evaluator = RetrieverEvaluator.from_metric_names(
    ["mrr", "hit_rate"], retriever=vector_retriever
)
eval_results = await retriever_evaluator.aevaluate_dataset(qa_dataset)
display_results(eval_results)

在这里插入图片描述

向量检索和rerank 实验

bge_reranker_base = SentenceTransformerRerank(
    model=snapshot_download("Xorbits/bge-reranker-base"),
    top_n=top_k)

retriever = CustomRetriever(
    vector_retriever=vector_retriever,
    reranker=bge_reranker_base)

retriever_evaluator = RetrieverEvaluator.from_metric_names(
    ["mrr", "hit_rate"], retriever=retriever
)
eval_results = await retriever_evaluator.aevaluate_dataset(qa_dataset)
display_results(eval_results)

在这里插入图片描述
若想使用其他的rerank模型,更换Xorbits/bge-reranker-base

若使用modelscope平台的rerank模型,直接修改模型名即可;
若使用huggingface 平台的rerank模型,自行修改代码;

上述对比了,在向量检索下,对比了添加rerank和不添加rerank的实验结果;
如上图所示,相比只有向量检索的实验,加了rerank mrr 反而还下降了,这是一个比较反常的实验结果;

这个并不能说明rerank没有用,笔者在其他的RAG数据集测试时,rerank确实能提升mrr;本例子这里的情况大家忽略即可。
在本实验这里仅仅是给读者展示如何使用rerank;这也说明了rerank模型,也并不都能提升所有的mrr;

代码开源

本项目的完整代码,已发布到modelscope平台上;
点击下述链接查看代码:
https://www.modelscope.cn/datasets/jieshenai/paul_graham_essay_rag/file/view/master/vector_rerank_eval.ipynb?status=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1573015.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Filter

概念:Filter表示过滤器,是JavaWeb三大组件(Servlet,Filter,Listener)之一 过滤器可以把对资源的请求拦截下来,从而实现一些特殊的功能。 过滤器一般完成一些通用的操作,比如:权限控…

元宇宙虚拟空间的场景渲染(五)

前言 该文章主要讲元宇宙虚拟空间的场景渲染,基本核心技术点,不多说,直接引入正题。 场景渲染 下面第二个图中的代码是一个循环渲染逻辑,首先getDelta 获取2次时间的时间间隔,requestAnimationFrame请求我们的一个动…

mbti,INTJ型人格的心理问题分析

什么是INTJ型人格? INTJ来自mbti职业性格测试,16种人格类型之一,INTJ分别代表内向,直觉,理智,独立,而INTJ型人格是一种以冷静和理性著称的人格,这种人格的人总给人一种比较理智&…

【【萌新的学习之Numpy数组的使用】】

萌新的学习之Numpy数组的使用 先记录一下之前的关于函数的设计 通过创造类的形式 复习完毕之后介绍numpy数组的使用 #整数型数组遇到除法 (即便是除以整数) 不同维度的数组之间 从外形上的本质区别 一维数组用1层中括号 二维数组用2层中括号 三维数…

golang设计模式图解——命令模式

设计模式 GoF提出的设计模式有23个,包括: (1)创建型(Creational)模式:如何创建对象; (2)结构型(Structural )模式:如何实现类或对象的组合; (3&a…

docker 部署 dujiaoka 独角数卡自动售货系统 支持 X86 和 ARM 架构

前言 很早就想部署一套自己的发卡自动售货系统,研究了很久发现独角数卡相对更加成熟好用,可是折腾技术三年多最怕的就是php和Laravel之类的语言和框架,各种权限,守护之类配置麻烦,加上如果跑在docker里更加头疼&#…

Android 的网络加载

发起网络请求的过程 当用户在应用程序中输入网址或关键字时,应用程序会发起网络请求。这个过程大致如下: 应用程序将请求发送到服务器,服务器返回响应数据。应用程序接收到响应数据后,将其转换为应用程序可识别的数据格式。应用…

Flutter 解决NestedScrollView与TabBar双列表滚动位置同步问题

文章目录 前言一、需要实现的效果如下二、flutter实现代码如下:总结 前言 最近写flutter项目,遇到NestedScrollView与TabBar双列表滚动位置同步问题,下面是解决方案,希望帮助到大家。 一、需要实现的效果如下 1、UI图&#xff1…

程序汪5万接的公交车板打车小程序,开发周期40天(发布版

本项目来自程序汪背后的私活小团队,开发了一个打车小程序,给粉丝分享一下解决方案,本项目前端工作量比较大,希望给想接私活的朋友一些经验参考 视频版本 在 B站【我是程序汪】 目录 一、项目构成 二、开发人员 三、项目背景 四…

Redis数据库的入门学习

关系型数据库和非关系型数据库的区别: 简介 Redis数据库和MySql数据库的区别:Redis数据库是基于内存的key-value结构的数据库。本质上是内存存储。 而MySql数据库是通过数据文件的方式存在磁盘当中,本质上是磁盘存储。且MySql当中是通过二维…

2-3多交换机静态流表控制原理与实现

实现目标环境下的静态流表设置: 1 单个ovs上实现多个主机hosts之间的通信 2多ovs上多主机之间的通信 1 单个ovs上实现多个主机hosts之间的通信 使用函数定义的方式创建一个如下的拓扑,并使用静态流表 from mininet.net import Mininet from mininet.n…

ENSP防火墙,解决不兼容及报错等问题,windows命令行修改网卡配置,配置cloud及防火墙连接,web连接防火墙

解决不兼容和报错等问题 原因1:VirtualBox版本太低(5.1.x)或太高(6.x.x)和eNSP不兼容 卸载virtualbox,下载virtualbox 5.2.28,安装稳定版本的virtualbox 删除原有程序:c:\用户\***\.…

Qt报错:C1083 无法打开包括文件: No such file or directory

我用的是VS2019 添加了一个继承自QTextEdit 的新类QMsgTextEdit, 就出现了这样的报错: 我双击ui_TalkWindow.h, 打开这个文件后, 发现: 我就试着打开qmsgtextedit.h,发现: 于是,我就在当前ui_TalkWindow.h文件的目…

数据可视化-地图可视化-Python

师从黑马程序员 基础地图使用 基础地图演示 视觉映射器 具体颜色对应的代码可以在http://www.ab173.com/中查询RGB颜色查询对照表 from pyecharts.charts import Map from pyecharts.options import VisualMapOpts#准备地图对象 mapMap() #准备数据 data[("北京",…

数据结构学习——栈和队列

1.栈 1.1栈的概念及结构 栈:一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端 称为栈顶,另一端称为栈底。栈中的数据元素遵守后进先出LIFO(Last In First Out)的原则。 …

YOLOv5实战记录05 Pyside6可视化界面

个人打卡,慎看。 指路大佬:【手把手带你实战YOLOv5-入门篇】YOLOv5 Pyside6可视化界面_哔哩哔哩_bilibili 零、虚拟环境迁移路径后pip报错解决 yolov5-master文件夹我换位置后,无法pip install了。解决如下: activate.bat中修改…

关系型数据库与非关系型数据库、Redis数据库

相比于其他的内存/缓存数据库,redis可以方便的实现持久化的功能(保存至磁盘中) 一、关系数据库与非关系型数据库 1.1 关系型数据库 一个结构化的数据库,创建在关系模型基础上一般面向于记录 SQL语句 (标准数据查询语言) 就是一种…

11-pyspark的RDD的变换与动作算子总结

目录 前言 变换算子动作算子 前言 一般来说,RDD包括两个操作算子: 变换(Transformations):变换算子的特点是懒执行,变换操作并不会立刻执行,而是需要等到有动作(Actions)…

大语言模型落地的关键技术:RAG

1、什么是RAG? RAG 是检索增强生成(Retrieval-Augmented Generation)的简称,是当前最火热的大语言模型应用落地的关键技术,主要用于提高语言模型的效果和准确性。它结合了两种主要的NLP方法:检索&#xff…

【智能排班系统】AOP实现操作日志自动记录

文章目录 操作日志介绍自动保存操作日志基本实现思路定义注解枚举业务类型枚举操作人员类型枚举 AOP具体实现方法上添加注解 日志增删改查日志表sql实体类ServiceControllerVo 操作日志介绍 操作日志是对系统或应用程序中所有用户操作、系统事件、后台任务等进行详细记录的文本…