用通俗易懂的方式讲解:对 embedding 模型进行微调,我的大模型召回效果提升了太多了

news2024/9/29 23:30:54

QA对话目前是大语言模型的一大应用场景,在QA对话中,由于大语言模型信息的滞后性以及不包含业务知识的特点,我们经常需要外挂知识库来协助大模型解决一些问题。

在外挂知识库的过程中,embedding模型的召回效果直接影响到大模型的回答效果,因此,在许多场景下,我们都需要微调我们的embedding模型来提高我们的召回效果。

码字不易,喜欢记得收藏、点赞、关注,如果你希望技术交流&答疑,见文末

下面,我就基于llama-index对BAAI/bge-base-zh-v1.5模型进行微调,关于该模型的介绍,可以参考https://huggingface.co/BAAI/bge-base-zh-v1.5。

平台介绍

对embedding模型进行微调的过程中需要使用GPU加速训练,我这里就使用了Google colab提供的免费T4GPU进行微调测试。

如果大家没办法使用这个,可以使用国内一些公司的GPU云平台,租便宜的GPU就行,微调这个模型所耗费的GPU资源不多。

以下所有训练代码皆是在Jupter-notebook上编写并执行的。

依赖安装

安装一些依赖库,有些依赖需要制定版本,否则存在不兼容的问题。

!pip install langchain==0.0.300 llmx==0.0.15a0 openai==0.28.1 llama_index==0.8.23.post1 pypdf sentence-transformers

训练样本准备

我们当前的使用场景是QA问答场景,因此训练数据的格式最好也是问答的格式。我这里由于没有现成的问答样本(人工整理比较耗时),因此我就摘取了《明朝那些事儿》这个小说里面的部分章节,然后让GPT-3.5针对文章内容进行提问,从而形成问答对。代码如下

import json
import openai
import os

from llama_index import SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
from llama_index.schema import MetadataMode
from llama_index import (
    VectorStoreIndex,
    SimpleDirectoryReader,
    ServiceContext,
    Response
)

def load_corpus(docs, for_training=False, verbose=False):
    parser = SimpleNodeParser.from_defaults()
    if for_training:
        nodes = parser.get_nodes_from_documents(docs[:5], show_progress=verbose)
    else:
        nodes = parser.get_nodes_from_documents(docs[6:], show_progress=verbose)

    if verbose:
        print(f'Parsed {len(nodes)} nodes')

    return nodes

SEC_FILE = ['embedding_test.txt'] 

print(f"Loading files {SEC_FILE}")

reader = SimpleDirectoryReader(input_files=SEC_FILE)
docs = reader.load_data()
print(f'Loaded {len(docs)} docs')

docs_nodes = load_corpus(docs, for_training=True, verbose=True)

len(docs_nodes)

train_nodes = docs_nodes[:75]  
print(f'Loaded {len(train_nodes)} train docs')
val_nodes = docs_nodes[76:] 
print(f'Loaded {len(val_nodes)} val docs')

构造训练集和测试集

使用GPT3.5基于小说内容生成对应的问题,最后生成train_dataset.json作为训练集,val_dataset.json作为验证集。

from llama_index.finetuning import (
    generate_qa_embedding_pairs,
    EmbeddingQAFinetuneDataset,
)
from llama_index.llms import OpenAI

os.environ["OPENAI_API_KEY"] = "sk-************"
openai.api_key = os.environ["OPENAI_API_KEY"]
openai.api_base = "https://************"

prompt="""下方是上下文信息。

---------------------
{context_str}
---------------------

根据提供的上下文信息和没有先验知识的原则,仅基于以下查询生成问题。

你是一名教师/教授。你的任务是为即将到来的测验/考试设置{num_questions_per_chunk}个问题。这些问题应在文档中多样化,且仅限于所提供的上下文信息。
"""

train_dataset = generate_qa_embedding_pairs(train_nodes, qa_generate_prompt_tmpl=prompt)
val_dataset = generate_qa_embedding_pairs(val_nodes, qa_generate_prompt_tmpl=prompt)

train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")

微调Embedding模型

这里的微调都是使用的默认参数,在实际微调过程中,可根据实际情况进行调整。

from llama_index.finetuning import SentenceTransformersFinetuneEngine
train_dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")
val_dataset = EmbeddingQAFinetuneDataset.from_json("val_dataset.json")
finetune_engine = SentenceTransformersFinetuneEngine(
    train_dataset,
    model_id="BAAI/bge-base-zh-v1.5",
    model_output_path="test_model",
    val_dataset=val_dataset,
)
finetune_engine.finetune() 
embed_model = finetune_engine.get_finetuned_model()
embed_model

评估微调后的模型

在评估阶段,我们对比了微调前、后的BAAI/bge-base-zh-v1.5模型以及OPENAI的ada002的Embedding模型的召回效果,代码如下:

from llama_index.embeddings import OpenAIEmbedding
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.schema import TextNode
from tqdm.notebook import tqdm
import pandas as pd
def evaluate(
    dataset,
    embed_model,
    top_k=5,
    verbose=False,
):
    corpus = dataset.corpus
    queries = dataset.queries
    relevant_docs = dataset.relevant_docs

    service_context = ServiceContext.from_defaults(embed_model=embed_model)
    nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]
    index = VectorStoreIndex(nodes, service_context=service_context, show_progress=True)
    retriever = index.as_retriever(similarity_top_k=top_k)

    eval_results = []
    for query_id, query in tqdm(queries.items()):
        retrieved_nodes = retriever.retrieve(query)
        retrieved_ids = [node.node.node_id for node in retrieved_nodes]
        expected_id = relevant_docs[query_id][0]
        is_hit = expected_id in retrieved_ids  

        eval_result = {
            "is_hit": is_hit,
            "retrieved": retrieved_ids,
            "expected": expected_id,
            "query": query_id,
        }
        eval_results.append(eval_result)
    return eval_results

注意,在执行下面的代码前,需要先在当前项目的目录下创建results文件夹,否则会导致程序执行失败。

from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers import SentenceTransformer

def evaluate_st(
    dataset,
    model_id,
    name,
):
    corpus = dataset.corpus
    queries = dataset.queries
    relevant_docs = dataset.relevant_docs

    evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs, name=name)
    model = SentenceTransformer(model_id)
    return evaluator(model, output_path="results/")

OPENAI-ada002

ada = OpenAIEmbedding()
ada_val_results = evaluate(val_dataset, ada)
df_ada = pd.DataFrame(ada_val_results)
hit_rate_ada = df_ada['is_hit'].mean()
hit_rate_ada

ada002模型的最终评测结果为0.9285714285714286

原始BAAI/bge-base-zh-v1.5

bge = "local:BAAI/bge-base-zh-v1.5"
bge_val_results = evaluate(val_dataset, bge)
df_bge = pd.DataFrame(bge_val_results)
hit_rate_bge = df_bge['is_hit'].mean()
hit_rate_bge

原始的bge-base-zh-v1.5模型的评测结果为0.7663744588744589

微调后的BAAI/bge-base-zh-v1.5

finetuned = "local:test_model"
val_results_finetuned = evaluate(val_dataset, finetuned)
df_finetuned = pd.DataFrame(val_results_finetuned)
hit_rate_finetuned = df_finetuned['is_hit'].mean()
hit_rate_finetuned

微调后模型的最终评测结果为0.975。即微调后,我们的embedding模型在当前数据集的召回效果由0.766上升到0.975注意,得分并不是越高越好,需考虑是否过拟合,可以在其他数据集上再评测下。

以上,即是一次简单的微调过程。感谢技术的发展和开源大佬们的贡献,使得人工智能的应用门槛越来越低。

技术交流

技术要学会分享、交流,不建议闭门造车。一个人可以走的很快、一堆人可以走的更远。

本文完整代码、相关资料、技术交流&答疑,均可加我们的交流群获取,群友已超过2000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2060,备注:来自CSDN + 技术交流

在这里插入图片描述

通俗易懂讲解大模型系列

  • 用通俗易懂的方式讲解:大模型 RAG 在 LangChain 中的应用实战

  • 用通俗易懂的方式讲解:一文讲清大模型 RAG 技术全流程

  • 用通俗易懂的方式讲解:如何提升大模型 Agent 的能力?

  • 用通俗易懂的方式讲解:使用 Mistral-7B 和 Langchain 搭建基于PDF文件的聊天机器人

  • 用通俗易懂的方式讲解:ChatGPT 开放的多模态的DALL-E 3功能,好玩到停不下来!

  • 用通俗易懂的方式讲解:结合检索和重排序模型,改善大模型 RAG 效果明显

  • 用通俗易懂的方式讲解:基于扩散模型(Diffusion),文生图 AnyText 的效果太棒了

  • 用通俗易懂的方式讲解:在 CPU 服务器上部署 ChatGLM3-6B 模型

  • 用通俗易懂的方式讲解:ChatGLM3-6B 功能原理解析

  • 用通俗易懂的方式讲解:使用 LangChain 和大模型生成海报文案

  • 用通俗易懂的方式讲解:一个强大的 LLM 微调工具 LLaMA Factory

  • 用通俗易懂的方式讲解:ChatGLM3-6B 部署指南

  • 用通俗易懂的方式讲解:LangChain Agent 原理解析

  • 用通俗易懂的方式讲解:HugggingFace 推理 API、推理端点和推理空间使用详解

  • 用通俗易懂的方式讲解:使用 LangChain 封装自定义的 LLM,太棒了

  • 用通俗易懂的方式讲解:使用 FastChat 部署 LLM 的体验太爽了

  • 用通俗易懂的方式讲解:基于 Langchain 和 ChatChat 部署本地知识库问答系统

  • 用通俗易懂的方式讲解:使用 Docker 部署大模型的训练环境

  • 用通俗易懂的方式讲解:在 Ubuntu 22 上安装 CUDA、Nvidia 显卡驱动、PyTorch等大模型基础环境

  • 用通俗易懂的方式讲解:Llama2 部署讲解及试用方式

  • 用通俗易懂的方式讲解:LangChain 知识库检索常见问题及解决方案

  • 用通俗易懂的方式讲解:基于 LangChain 和 ChatGLM2 打造自有知识库问答系统

  • 用通俗易懂的方式讲解:代码大模型盘点及优劣分析

  • 用通俗易懂的方式讲解:Prompt 提示词在开发中的使用

  • 用通俗易懂的方式讲解:万字长文带你入门大模型

参考资料

1.https://github.com/wenqiglantz/nvidia-sec-finetuning/tree/main/embedding-finetuning

2.https://colab.research.google.com/github/wenqiglantz/nvidia-sec-finetuning/blob/main/embedding-finetuning/finetune_embedding_nvidia_sec.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1378960.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用Sqoop将数据从Hadoop导出到关系型数据库

当将数据从Hadoop导出到关系型数据库时,Apache Sqoop是一个非常有用的工具。Sqoop可以轻松地将大数据存储中的数据导出到常见的关系型数据库,如MySQL、Oracle、SQL Server等。本文将深入介绍如何使用Sqoop进行数据导出,并提供详细的示例代码&…

Android Studio 实现网易新闻App (简单方便易懂)

🍅文章末尾有获取完整项目源码方式🍅 目录 前言 一、任务介绍 1.1 背景 1.2目的和意义 二、 实现介绍 视频演示 2.1 启动页实现 2.2 注册页面实现 2.3 登陆页面实现 2.4 首页实现 2.5 详情页面实现 三、获取源码 前言 随着移动互联网的持续发…

力扣120. 三角形最小路径和(Java 动态规划)

Problem: 120. 三角形最小路径和 文章目录 题目描述思路解题方法复杂度Code 题目描述 思路 Problem:64. 最小路径和 本题目可以看作是在上述题目的基础上改编而来,具体的思路: 1.记录一个int类型的大小的 n 乘 n n乘n n乘n的数组(其中 n n n为…

第九讲 单片机驱动彩色液晶屏 控制RA8889软件:显存操作

单片机驱动TFT彩色液晶屏系列讲座 目录 第一讲 单片机最小系统STM32F103C6T6通过RA8889驱动彩色液晶屏播放视频 第二讲 单片机最小系统STM32F103C6T6控制RA8889驱动彩色液晶屏硬件框架 第三讲 单片机驱动彩色液晶屏 控制RA8889软件:如何初始化 第四讲 单片机驱动彩色液晶屏 控…

日志审计系统Agent项目创建——读取日志文件(Linux版本)

紧接着上一篇的分享,继续做日志文件的读取,点击连接即可日志文件初始化https://blog.csdn.net/wjl990316fddwjl/article/details/135553238 1、将指针移动到文件末尾 //文件移动到结尾fseek(fp, 0, SEEK_END); 2、定义当前指针的位置 lastPosition ft…

人工智能:我的学习之旅与认知探索(第1版)

🌟🌌 欢迎来到知识与创意的殿堂 — 远见阁小民的世界!🚀 🌟🧭 在这里,我们一起探索技术的奥秘,一起在知识的海洋中遨游。 🌟🧭 在这里,每个错误都…

2024年第1周,第一期技术动态

大家好,才是真的好。 今天周五,我们继续介绍与Domino相关产品新闻,以及互联网或其他IT行业动态等。 一、Notes/Domino V9和V10技术支持结束和假消息 今年2024年6月1号,HCL将结束IBM Notes/Domino 9.0.x和10.0.x产品的技术支持声…

【发票识别】支持pdf、ofd、图片格式(orc、信息提取)的发票

背景 为了能够满足识别各种发票的功能,特地开发了当前发票识别的功能,当前的功能支持pdf、ofd、图片格式的发票识别,使用到的技术包括文本提取匹配、ocr识别和信息提取等相关的技术,用到机器学习和深度学习的相关技术。 体验 体…

强化学习应用(三):基于Q-learning的无人机物流路径规划研究(提供Python代码)

一、Q-learning简介 Q-learning是一种强化学习算法,用于解决基于马尔可夫决策过程(MDP)的问题。它通过学习一个价值函数来指导智能体在环境中做出决策,以最大化累积奖励。 Q-learning算法的核心思想是通过不断更新一个称为Q值的…

类图作业

类图作业 一. 简答题(共5题,100分) (简答题) 在对类名、属性 /方法名时,通常会遵循什么样的规则?请举例说明。 正确答案: 对于类名通常采用 CamelCase格式(大写字母开头、混合大小写&#xff0…

Spark---RDD持久化

文章目录 1.RDD持久化1.1 RDD Cache 缓存1.2 RDD CheckPoint 检查点1.3 缓存和检查点区别 1.RDD持久化 在Spark中,持久化是将RDD存储在内存中,以便在多次计算之间重复使用。这可以显著减少不必要的计算,提高Spark应用程序的性能。 val line…

MATLAB - 四旋翼飞行器动力学方程

系列文章目录 前言 本例演示了如何使用 Symbolic Math Toolbox™(符号数学工具箱)推导四旋翼飞行器的连续时间非线性模型。具体来说,本例讨论了 getQuadrotorDynamicsAndJacobian 脚本,该脚本可生成四旋翼状态函数及其雅各布函数…

Hive基础知识(十):Hive导入数据的五种方式

1. 向表中装载数据(Load) 1)语法 hive> load data [local] inpath 数据的 path[overwrite] into table student [partition (partcol1val1,…)]; (1)load data:表示加载数据 (2)local:表示…

蓝桥杯练习题(五)

📑前言 本文主要是【算法】——蓝桥杯练习题(五)的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是听风与他🥇 ☁️博客首页:CSDN主页听风与他 …

UE4工程升级UE5教程及注意事项

原文链接:https://mp.weixin.qq.com/s/vSVu0VsNub0J62Nz7vM6cA虚幻引擎5迁移指南 | 虚幻引擎5.3文档 (unrealengine.com) 官方教程应该是从英文直接翻译过来的,过多词汇没修改,本篇重新整理修改一下,供各位参考。 本教程介绍&…

基于JAVA的数据可视化的智慧河南大屏 开源项目

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统展示四、核心代码4.1 数据模块 A4.2 数据模块 B4.3 数据模块 C4.4 数据模块 D4.5 数据模块 E 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBootMySQL的数据可视化的智慧河南大屏,包含了GDP、…

分裂联邦学习论文-混合联邦分裂学习GAN驱动的预测性多目标优化

论文标题:《Predictive GAN-Powered Multi-Objective Optimization for Hybrid Federated Split Learning》 期刊:IEEE Transactions on Communications, 2023 一、论文介绍 背景:联邦学习作为一种多设备协同训练的边缘智能算法&#xff0…

IDEA—初始化配置

注:以下红框圈的部分,均为已设置好的 外观与行为 编辑器 高级设置 按两次 shift 弹出提示问题解决

OpenCV-19图像的仿射变换

放射变换是图像旋转,缩放,平移的总称,具体的做法是通过一个矩阵和原图片坐标进行计算,得到新的坐标,完成变换,所以关键就是这个矩阵。 一、仿射变换之图像平移 使用API------warpAffine(src &…

Nightingale 夜莺监控系统 - 监控篇(2)

Author:rab 官方文档:https://flashcat.cloud/docs/content/flashcat-monitor/categraf/3-configuration/ 目录 前言一、Categraf 配置文件二、Input 插件配置文件2.1 插件说明2.2 通用配置2.2.1 配置采集频率 interval2.2.2 配置采集实例 instances2.2…