动手学习RAG:迟交互模型colbert微调实践 bge-m3

news2024/11/25 18:43:20
  • 动手学习RAG: 向量模型
  • 动手学习RAG: BGE向量模型微调实践]()
  • 动手学习RAG: BCEmbedding 向量模型 微调实践]()
  • BCE ranking 微调实践]()
  • GTE向量与排序模型 微调实践]()
  • 模型微调中的模型序列长度]()
  • 相似度与温度系数

本文我们来进行ColBERT模型的实践,按惯例,还是以open-retrievals中的代码为蓝本。在RAG兴起之后,ColBERT也获得了更多的关注。ColBERT整体结构和双塔特别相似,但迟交互式也就意味着比起一般ranking模型,交互来的更晚一些。
请添加图片描述

准备环境

pip install transformers
pip install open-retrievals

准备数据

还是采用C-MTEB/T2Reranking数据。

  • 每个样本有query, positive, negative。其中query和positive构成正样本对,query和negative构成负样本对
    请添加图片描述

使用

由于ColBERT作为迟交互式模型,既可以像向量模型一样生成向量,也可以计算相似度。BAAI/bge-m3中的colbert模型是基于XLMRoberta训练而来,因此使用ColBERT可以直接从bge-m3中加载预训练权重。

import transformers
from retrievals import ColBERT
model_name_or_path: str =  'BAAI/bge-m3' 
model = ColBERT.from_pretrained(
    model_name_or_path,
    colbert_dim=1024,    
    use_fp16=True,
    loss_fn=ColbertLoss(use_inbatch_negative=True),
)

model

请添加图片描述

  • 生成向量的方法
sentences_1 = ["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."]
sentences_2 = ['A dog is chasing car.', 'A man is playing a guitar.']

output_1 = model.encode(sentences_1, normalize_embeddings=True)
print(output_1.shape, output_1)

output_2 = model.encode(sentences_2, normalize_embeddings=True)
print(output_2.shape, output_2)

请添加图片描述

  • 计算句子对 相似度的方法
sentences = [
    ["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."],
    ["In 1974, I won the championship in Southeast Asia in my first kickboxing match", 'A man is playing a guitar.'],
]

scores_list = model.compute_score(sentences)
print(scores_list)

请添加图片描述

微调

尝试了两种方法来做,一种是调包自己写代码,一种是采用open-retrievals中的代码写shell脚本。这里我们采用第一种,另外一种方法可参考文章最后番外中的微调

import transformers
from transformers import AutoTokenizer, TrainingArguments, get_cosine_schedule_with_warmup, AdamW
from retrievals import AutoModelForRanking, RerankCollator, RerankTrainDataset, RerankTrainer, ColBERT, RetrievalTrainDataset, ColBertCollator
from retrievals.losses import ColbertLoss
transformers.logging.set_verbosity_error()


model_name_or_path: str = 'BAAI/bge-m3'

learning_rate: float = 1e-5
batch_size: int = 2
epochs: int = 1
output_dir: str = './checkpoints'

train_dataset = RetrievalTrainDataset(
    'C-MTEB/T2Reranking', positive_key='positive', negative_key='negative', dataset_split='dev'
)


tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)

data_collator = ColBertCollator(
    tokenizer,
    query_max_length=64,
    document_max_length=128,
    positive_key='positive',
    negative_key='negative',
)
model = ColBERT.from_pretrained(
    model_name_or_path,
    colbert_dim=1024,
    loss_fn=ColbertLoss(use_inbatch_negative=False),
)

optimizer = AdamW(model.parameters(), lr=learning_rate)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)

training_args = TrainingArguments(
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    num_train_epochs=epochs,
    output_dir = './checkpoints',
    remove_unused_columns=False,
    gradient_accumulation_steps=8,
    logging_steps=100,

)
trainer = RerankTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()

model.save_pretrained(output_dir)

训练过程中会加载BAAI/bge-m3模型权重
请添加图片描述
损失函数下降
请添加图片描述

{'loss': 7.4858, 'grad_norm': 30.484981536865234, 'learning_rate': 4.076305220883534e-06, 'epoch': 0.6024096385542169}
{'loss': 1.18, 'grad_norm': 28.68316650390625, 'learning_rate': 3.072289156626506e-06, 'epoch': 1.2048192771084336}
{'loss': 1.1399, 'grad_norm': 14.203865051269531, 'learning_rate': 2.068273092369478e-06, 'epoch': 1.8072289156626506}
{'loss': 1.1261, 'grad_norm': 24.30337905883789, 'learning_rate': 1.0642570281124499e-06, 'epoch': 2.4096385542168672}
{'train_runtime': 471.8191, 'train_samples_per_second': 33.827, 'train_steps_per_second': 1.055, 'train_loss': 2.4146631079984, 'epoch': 3.0}

评测

在C-MTEB中进行评测。微调前保留10%的数据集作为测试集验证

from datasets import load_dataset

dataset = load_dataset("C-MTEB/T2Reranking", split="dev")
ds = dataset.train_test_split(test_size=0.1, seed=42)

ds_train = ds["train"].filter(
    lambda x: len(x["positive"]) > 0 and len(x["negative"]) > 0
)

ds_train.to_json("t2_ranking.jsonl", force_ascii=False)

微调前的指标:
请添加图片描述
微调后的指标:
请添加图片描述

{
  "dataset_revision": null,
  "mteb_dataset_name": "CustomReranking",
  "mteb_version": "1.1.1",
  "test": {
    "evaluation_time": 221.45,
    "map": 0.6950128151840831,
    "mrr": 0.8193114944390455
  }
}

番外:从语言模型直接训练ColBERT

之前的例子里是从BAAI/bge-m3继续微调,这里再跑一个从hfl/chinese-roberta-wwm-ext训练一个ColBERT模型

  • 注意,从头跑需要设置更大的学习率与更多的epochs
MODEL_NAME='hfl/chinese-roberta-wwm-ext'
TRAIN_DATA="/root/kaggle101/src/open-retrievals/t2/t2_ranking.jsonl"
OUTPUT_DIR="/root/kaggle101/src/open-retrievals/t2/ft_out"

cd /root/open-retrievals/src

torchrun --nproc_per_node 1 \
  --module retrievals.pipelines.rerank \
  --output_dir $OUTPUT_DIR \
  --overwrite_output_dir \
  --model_name_or_path $MODEL_NAME \
  --tokenizer_name $MODEL_NAME \
  --model_type colbert \
  --do_train \
  --data_name_or_path $TRAIN_DATA \
  --positive_key positive \
  --negative_key negative \
  --learning_rate 5e-5 \
  --bf16 \
  --num_train_epochs 5 \
  --per_device_train_batch_size 32 \
  --dataloader_drop_last True \
  --query_max_length 128 \
  --max_length 256 \
  --train_group_size 4 \
  --unfold_each_positive false \
  --save_total_limit 1 \
  --logging_steps 100 \
  --use_inbatch_negative False

微调后指标

{
  "dataset_revision": null,
  "mteb_dataset_name": "CustomReranking",
  "mteb_version": "1.1.1",
  "test": {
    "evaluation_time": 75.38,
    "map": 0.6865308507184888,
    "mrr": 0.8039965986394558
  }
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2128999.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

5G毫米波阵列天线仿真——CDF计算(手动AC远场)

之前写过两个关于阵列天线获取CDF的方法,一个通过Realized Gain,一个通过Power Flow, 三个案例中都是3D中直接波束扫描,并没有展示场路结合的情况。这期我们用Power Flow的方法,手动合并AC任务的波束计算CDF。 还是用…

[Power save]wifi省电模式

管理帧 beacon DTIM AP的beacon中携带TIM(Traffic indication Map)字段,里面包含DTIM Count,DTIM Period,Bitmap Control和Part Virt Bmap字段 DTIM Period:AP缓存数据的能力,处于PS状态下的…

B2B销售:成功所需的工具

谈到B2B销售,拥有合适的工具可以带来巨大的差异。合适的工具可以提高效率和效能,简化操作,节省成本并提供竞争优势。 探索优化B2B销售栈的重要组成部分时,我们可以发现,正确的技术能让您的业务在未来取得成功。 电子…

前端 + 接口请求实现 vue 动态路由

前端 接口请求实现 vue 动态路由 在 Vue 应用中,通过前端结合后端接口请求来实现动态路由是一种常见且有效的权限控制方案。这种方法允许前端根据用户的角色和权限,动态生成和加载路由,而不是在应用启动时就固定所有的路由配置。 实现原理…

C语言-综合案例:通讯录

传送门:C语言-第九章-加餐:文件位置指示器与二进制读写 目录 第一节:思路整理 第二节:代码编写 2-1.通讯录初始化 2-2.功能选择 2-3.增加 和 扩容 2-4.查看 2-5.查找 2-6.删除 2-7.修改 2-8.退出 第三节:测试 下期…

基于SpringBoot+Vue的超市外卖管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于JavaSpringBootVueMySQL的…

OceanBase 企业版OMS 4.2.3的使用

OceanBase 企业版OMS 4.2.3的使用 一、界面说明 1.1 概览 1.2 数据迁移 1.3 数据同步 1.4 数据源管理 1.5 运维监控 1.6 系统管理 二、功能说明 注意: 在数据迁移与数据同步的功能中,如果涉及到增量操作: 1.需要使用sys租户的用…

828华为云征文 | 华为云Flexusx与Docker技术融合,打造个性化WizNote服务

前言 华为云Flexusx携手Docker技术,创新融合打造高效个性化WizNote服务。Flexusx的柔性算力与Docker的容器化优势相结合,实现资源灵活配置与性能优化,助力企业轻松构建稳定、高效的云端笔记平台。828华为云企业上云节特惠来袭,Fle…

【无标题】Efinity 0基础进行流水灯项目撰写(FPGA)

文章目录 前言一、定义概念 缩写1. 二、性质1.2. 三、使用步骤编译常见错误1. 没加分号2. end 写多了 编译成功的标志总结参考文献 前言 数电课设 使用 FPGAIDE 使用 Efinity 一、定义概念 缩写 1. 二、性质 1. 2. 三、使用步骤 python代码块matlab代码块c代码块编译…

你真的了解Canvas吗--解密二【ZRender篇】

书接上文你真的了解Canvas吗--解密一【ZRender篇】 目录 入口 挖掘 继承 _init step-1:取所有key值 ​​​​​​​ step-2:定义构造函数BezierCurveShape …

PMP--一模--解题--1-10

文章目录 14.敏捷--方法--替代敏捷方法--看板1、 [单选] 根据项目的特点,项目经理建议选择一种敏捷方法,该方法限制团队成员在任何给定时间执行的任务数。此方法还允许团队提高工作过程中问题和瓶颈的可见性。项目经理建议采用以下哪种方法? …

金属铬厂商分析:前十强厂商占有大约64.0%的市场份额

金属铬是一种灰色、有光泽、硬而脆的过渡金属。铬是不锈钢的主要添加剂,可增加耐腐蚀性。 据QYResearch调研团队最新报告“全球金属铬市场报告2024-2030”显示,预计2030年全球金属铬市场规模将达到11.8亿美元,未来几年年复合增长率CAGR为6.5%…

【数据结构与算法】受限线性表 --- 栈

【数据结构与算法】受限线性表 — 栈 文章目录 【数据结构与算法】受限线性表 --- 栈前言一、栈的基本概念二、栈的顺序存储三、栈的分文件编写四、栈的链式存储五、栈的应用案例-就近分配六、 中缀表达式转后缀表达式以及基于后缀表达式运算总结 前言 本篇文章就栈的基本概念…

医院HISPACS存储备份 要求全周期保存

内蒙古赤峰医院 HIS PACS数据备份 ,质量好可以满足15-30年存储的空间需求,用着放心

NIDS——suricata(三)

一、监控ICMP流量 1、ICMP流量特征 四大特征分别为:消息类型(Type)、代码(Code)、校验和(Checksum)、数据字段(Data Field)。这里我们使用 type消息类型。 ICMP 消息的类…

Cookie、Web Storage介绍

概述 Cookie、LocalStorage、SessionStorage、IndexDB这些作为浏览器的存储入口,也是经典的八股文了,本文再次冷饭热吃来介绍这些API,主要是因为在其他文章中看到了一些个人感觉有用的小知识点,所以在这记录一下,以便…

招加盟商视频怎么拍效果好?

一定没有人比我更适合分享这篇文章了,我自己曾经就是做宣传片的,而且还有一家酸奶品牌做全国招商。 我来分享下加盟招商视频怎么拍效果好? 一、定脚步 说个题外话,以前我做传媒公司的时候,很多客户找我做宣传片&…

Ubuntu系统Docker部署数据库管理工具DbGate并实现远程查询数据

文章目录 前言1. 安装Docker2. 使用Docker拉取DbGate镜像3. 创建并启动DbGate容器4. 本地连接测试5. 公网远程访问本地DbGate容器5.1 内网穿透工具安装5.2 创建远程连接公网地址5.3 使用固定公网地址远程访问 前言 本文主要介绍如何在Linux Ubuntu系统中使用Docker部署DbGate数…

Windows安装字体的几种方式

1.选中你要安装的字体,然后右键,点击“为所有用户安装”,这样字体就会安装在C:\Windows\Fonts;如果你点了“安装”,那么就会安装在C:\Users\你电脑用户的名字\AppData\Local\Microsoft\Windows\Fonts。 像电脑自带的字…