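- Hands-on RAG: Embedding Models
- Hands-on RAG: Fine-tuning moka-ai/m3e with DeepSpeed and Contrastive Learning
- Hands-on RAG: Rerank Model Fine-tuning in Practice, bge-reranker-v2-m3
- Hands-on RAG: Late-Interaction Model (ColBERT) Fine-tuning in Practice, bge-m3
- Hands-on RAG: Fine-tuning an LLM-based Embedding Model, intfloat/e5-mistral-7b-instruct
- Hands-on RAG: Fine-tuning the LLM-based Reranker bge-reranker-v2-gemma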
1. Environment setup
pip install transformers
pip install open-retrievals
- Note: the package is installed as open-retrievals (pip install open-retrievals), but imported as retrievals (import retrievals).
- For the latest updates, see https://github.com/LongxingTan/open-retrievals
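Since the pip distribution name (open-retrievals) differs from the import name (retrievals), a quick check can confirm both after installation. The sketch below uses only the standard library (importlib.metadata, Python 3.8+); the helper check_install is a hypothetical name, not part of open-retrievals.

```python
import importlib.metadata
import importlib.util
from typing import Optional, Tuple


def check_install(dist_name: str = "open-retrievals",
                  module_name: str = "retrievals") -> Tuple[Optional[str], bool]:
    """Return (installed version of the pip distribution, whether the module is importable)."""
    try:
        # Look up the distribution by its pip name, e.g. "open-retrievals"
        version: Optional[str] = importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        version = None
    # Look up the module by its import name, e.g. "retrievals"; find_spec returns None if absent
    importable = importlib.util.find_spec(module_name) is not None
    return version, importable


if __name__ == "__main__":
    version, importable = check_install()
    print(f"open-retrievals version: {version}, 'import retrievals' works: {importable}")
```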
2. Reranking with an LLM
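from retrievals import LLMRanker
model_name = 'BAAI/bge-reranker-v2-gemma'
# causal_lm=True: the Gemma-based reranker is loaded as a causal language model
# use_fp16=True: load the weights in half precision to save GPU memory
model = LLMRanker.from_pretrained(
    model_name,
    causal_lm=True,
    use_fp16=True,
)
# Score a single [query, passage] pair
score = model.compute_score(['query', 'passage'])
print(score)
# Score a batch of pairs: one score per pair, a higher score means more relevant
scores = model.compute_score([
    ['what is panda?', 'hi'],
    ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.'],
])
print(scores)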
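In a RAG pipeline the reranker usually reorders a candidate list coming from a first-stage retriever. The sketch below wraps any scorer with the same call shape as compute_score above (a list of [query, passage] pairs in, one score per pair out) and sorts the passages by score. The helper names rerank and sigmoid are hypothetical. The optional sigmoid assumes the raw scores are unnormalized logits, as in the BGE reference implementation, where the score is the logit of the "Yes" token; whether LLMRanker returns raw logits is an assumption worth checking.

```python
import math
import numbers
from typing import Callable, List, Optional, Sequence, Tuple

# A scorer takes [[query, passage], ...] and returns one score per pair
# (a single pair may come back as a bare number).
Scorer = Callable[[List[List[str]]], object]


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def rerank(query: str,
           passages: Sequence[str],
           score_fn: Scorer,
           normalize: bool = False,
           top_k: Optional[int] = None) -> List[Tuple[str, float]]:
    """Score every (query, passage) pair and return passages sorted by descending score."""
    pairs = [[query, p] for p in passages]
    raw = score_fn(pairs)
    # Wrap a bare number so single-passage calls behave like batch calls
    if isinstance(raw, numbers.Real):
        scores = [float(raw)]
    else:
        scores = [float(s) for s in raw]
    if len(scores) != len(passages):
        raise ValueError("score_fn must return one score per passage")
    if normalize:
        # Map assumed logits to (0, 1); sigmoid is monotonic, so the order is unchanged
        scores = [sigmoid(s) for s in scores]
    ranked = sorted(zip(passages, scores), key=lambda item: item[1], reverse=True)
    return ranked if top_k is None else ranked[:top_k]
```

With the model loaded above, rerank('what is panda?', candidates, model.compute_score, top_k=3) would return the three highest-scoring candidates, assuming compute_score returns one score per pair as in the example. Sorting in plain Python keeps the helper independent of numpy or torch.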
3. Fine-tuning
# Base model, training data (JSONL) and output directory; adjust the paths to your environment
MODEL_NAME='BAAI/bge-reranker-v2-gemma'
TRAIN_DATA="/root/kag101/src/open-retrievals/t2/t2_ranking.jsonl"
OUTPUT_DIR="/root/kag101/src/open-retrievals/t2/ft_out"
torchrun --nproc_per_node 1 \
  -m retrievals.pipelines.rerank \
  --output_dir ${OUTPUT_DIR} \
  --overwrite_output_dir \
  --model_name_or_path $MODEL_NAME \
  --model_type llm \
  --causal_lm True \
  --use_lora True \
  --data_name_or_path $TRAIN_DATA \
  --task_prompt "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." \
  --query_instruction "A: " \
  --document_instruction 'B: ' \
  --positive_key positive \
  --negative_key negative \
  --learning_rate 2e-4 \
  --num_train_epochs 3 \
  --per_device_train_batch_size 4 \
  --gradient_accumulation_steps 16 \
  --dataloader_drop_last True \
  --max_len 256 \
  --train_group_size 4 \
  --logging_steps 10 \
  --save_steps 20000 \
  --save_total_limit 1 \
  --bf16
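The flags above describe how each training example is assembled. The sketch below is one reading of them, under assumptions not confirmed by the open-retrievals source: each JSONL record has a query field plus list-valued positive and negative fields (the names given by --positive_key and --negative_key); --train_group_size 4 means one positive and three sampled negatives per query, as in the BGE reranker training code; and the input text joins query_instruction + query, document_instruction + passage and task_prompt with newlines. build_group, format_pair and sequences_per_step are hypothetical helpers.

```python
import random
from typing import Dict, List, Optional, Tuple

TASK_PROMPT = ("Given a query A and a passage B, determine whether the passage contains an answer "
               "to the query by providing a prediction of either 'Yes' or 'No'.")


def build_group(record: Dict, train_group_size: int = 4,
                rng: Optional[random.Random] = None) -> List[str]:
    """One positive followed by (train_group_size - 1) sampled negatives (assumed layout)."""
    rng = rng or random.Random(0)
    positive = rng.choice(record["positive"])
    negatives = record["negative"]
    n_neg = train_group_size - 1
    if len(negatives) >= n_neg:
        sampled = rng.sample(negatives, n_neg)
    else:
        # Too few negatives: sample with replacement so every group has the same size
        sampled = [rng.choice(negatives) for _ in range(n_neg)]
    return [positive] + sampled


def format_pair(query: str, passage: str, task_prompt: str = TASK_PROMPT,
                query_instruction: str = "A: ", document_instruction: str = "B: ") -> str:
    """Hypothetical prompt layout modeled on the BGE LLM reranker, not open-retrievals' exact template."""
    return f"{query_instruction}{query}\n{document_instruction}{passage}\n{task_prompt}"


def sequences_per_step(per_device_batch: int, grad_accum: int, n_gpus: int,
                       group_size: int) -> Tuple[int, int]:
    """(queries, query-passage sequences) consumed per optimizer step, assuming batch size counts queries."""
    queries = per_device_batch * grad_accum * n_gpus
    return queries, queries * group_size


if __name__ == "__main__":
    # Hypothetical record in the assumed format
    record = {"query": "what is panda?",
              "positive": ["The giant panda is a bear species endemic to China."],
              "negative": ["hi", "Pandas is a Python library for data analysis.",
                           "The weather is nice today.", "Bamboo grows quickly."]}
    for passage in build_group(record):
        print(format_pair(record["query"], passage))
        print("---")
    # Values from the command above: batch 4, accumulation 16, 1 GPU, group size 4
    print(sequences_per_step(4, 16, 1, 4))  # (64, 256)
```

Under the same assumptions, one optimizer step covers 4 × 16 × 1 = 64 queries, i.e. 256 query-passage sequences of up to 256 tokens each; gradient accumulation is what lets an effective batch of this size fit on a single GPU.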
4. Evaluation
Evaluation is run with C-MTEB. Before fine-tuning, 10% of the dataset is held out as the test set.
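One simple way to hold out 10% of the data, assuming it is a JSONL file with one record (one query with its positives and negatives) per line, is to shuffle with a fixed seed and cut off the first 10%. The helper split_jsonl is hypothetical; the article does not say how its split was made.

```python
import json
import random
from pathlib import Path
from typing import Tuple, Union

PathLike = Union[str, Path]


def split_jsonl(src: PathLike, train_out: PathLike, test_out: PathLike,
                test_ratio: float = 0.1, seed: int = 42) -> Tuple[int, int]:
    """Shuffle JSONL records with a fixed seed and write a train/test split; returns the two sizes."""
    with open(src, encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    # A fixed seed makes the split reproducible across runs
    random.Random(seed).shuffle(records)
    n_test = max(1, int(len(records) * test_ratio))
    test, train = records[:n_test], records[n_test:]
    for path, rows in ((train_out, train), (test_out, test)):
        with open(path, "w", encoding="utf-8") as f:
            for row in rows:
                # ensure_ascii=False keeps Chinese text readable in the output files
                f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return len(train), len(test)
```

For example, split_jsonl("t2_ranking.jsonl", "t2_train.jsonl", "t2_test.jsonl") (hypothetical output names) would write 90% of the records to the training file, which would then be passed as --data_name_or_path, and 10% to the test file. Splitting by record keeps every query on one side of the split, so test queries are never seen during training.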
Metrics before fine-tuning:
Metrics after fine-tuning:
{
  "dataset_revision": null,
  "mteb_dataset_name": "CustomReranking",
  "mteb_version": "1.1.1",
  "test": {
    "evaluation_time": 77.35,
    "map": 0.7057362287508586,
    "mrr": 0.8166538440773136
  }
}
After fine-tuning, MAP rose from 0.637 to 0.706 and MRR from 0.734 to 0.817.
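MAP and MRR are both computed per query from the order in which the reranker places the labeled candidates, then averaged over queries. The sketch below uses the standard definitions; MTEB's reranking evaluator may differ in details (for example, it may cut MRR off at a fixed rank such as 10), so treat it as an illustration rather than a reproduction of the reported numbers. improvement is a hypothetical helper for the before/after comparison.

```python
from typing import Optional, Sequence, Tuple


def average_precision(relevance: Sequence[int]) -> float:
    """AP for one query; relevance holds 0/1 labels in the order the reranker ranked the candidates.

    Dividing by the number of hits equals dividing by the number of relevant documents
    here, because reranking a fixed candidate set keeps every relevant document in the list.
    """
    hits, precision_sum = 0, 0.0
    for rank, rel in enumerate(relevance, start=1):
        if rel:
            hits += 1
            precision_sum += hits / rank  # precision at this relevant position
    return precision_sum / hits if hits else 0.0


def reciprocal_rank(relevance: Sequence[int], k: Optional[int] = None) -> float:
    """1 / rank of the first relevant candidate, or 0 if none appears within the top k."""
    cutoff = relevance if k is None else relevance[:k]
    for rank, rel in enumerate(cutoff, start=1):
        if rel:
            return 1.0 / rank
    return 0.0


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def improvement(before: float, after: float) -> Tuple[float, float]:
    """(absolute gain, relative gain) between two metric values."""
    return after - before, (after - before) / before


if __name__ == "__main__":
    # Two toy queries: relevance labels in reranked order
    rankings = [[0, 1, 0, 1], [1, 0, 0]]
    print("MAP:", mean([average_precision(r) for r in rankings]))  # 0.75
    print("MRR:", mean([reciprocal_rank(r) for r in rankings]))    # 0.75
    # Values reported above
    print("MAP gain:", improvement(0.637, 0.7057362287508586))
    print("MRR gain:", improvement(0.734, 0.8166538440773136))
```

Applied to the reported values, the absolute gains are about 0.069 MAP and 0.083 MRR, or roughly 10.8% and 11.3% relative to the pre-fine-tuning scores.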