动手学习RAG: 大模型向量模型微调 intfloat/e5-mistral-7b-instruct

news2024/9/19 10:45:06
  • 动手学习RAG: 向量模型
  • 动手学习RAG: moka-ai/m3e 模型微调deepspeed与对比学习
  • 动手学习RAG:rerank模型微调实践 bge-reranker-v2-m3
  • 动手学习RAG:迟交互模型colbert微调实践 bge-m3
  • 动手学习RAG: 大模型向量模型微调 intfloat/e5-mistral-7b-instruct
  • 动手学习RAG:大模型重排模型 bge-reranker-v2-gemma微调

在这里插入图片描述

1. 环境准备

pip install transformers
pip install open-retrievals
  • 注意安装时是pip install open-retrievals,但调用时只需要import retrievals
  • 欢迎关注最新的更新 https://github.com/LongxingTan/open-retrievals

2. 使用Mistral作为向量模型

这里直接将query_instruction和document_instruction写进了text里

from retrievals import AutoModelForEmbedding

model_name = 'intfloat/e5-mistral-7b-instruct'
model = AutoModelForEmbedding.from_pretrained(
            model_name,
            pooling_method='last',
            use_fp16=True,
        )

texts = [
'Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: how much protein should a female eat', 
'Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: summit define', 
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.", 
'Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments.'
]

embeds = model.encode(texts, normalize_embeddings=True)
print(embeds)

scores = (embeds[:2] @ embeds[2:].T) * 100
print(scores.tolist())

请添加图片描述

  • 也可以把prompt写在函数中
from retrievals import AutoModelForEmbedding

model_name = 'intfloat/e5-mistral-7b-instruct'
model = AutoModelForEmbedding.from_pretrained(
            model_name,
            pooling_method='last',
            use_fp16=True,
            query_instruction='Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ',
            document_instruction='',
        )


query_texts = ['how much protein should a female eat', 'summit define']
document_texts = ["As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.", 'Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments.']

query_embeds = model.encode(query_texts, normalize_embeddings=True, is_query=True)
print(query_embeds)

doc_embeds = model.encode(document_texts, normalize_embeddings=True, is_query=False)
print(doc_embeds)

scores = (query_embeds @ doc_embeds.T) * 100
print(scores.tolist())

3. LoRA微调E5-mistral向量模型

数据还是按照惯例采用t2-ranking

MODEL_NAME="intfloat/e5-mistral-7b-instruct"
TRAIN_DATA="/root/kag101/src/open-retrievals/t2/t2_ranking.jsonl"
OUTPUT_DIR="/root/kag101/src/open-retrievals/t2/ft_out"


torchrun --nproc_per_node 1 \
  -m retrievals.pipelines.embed \
  --output_dir $OUTPUT_DIR \
  --overwrite_output_dir \
  --model_name_or_path $MODEL_NAME \
  --pooling_method last \
  --do_train \
  --data_name_or_path $TRAIN_DATA \
  --positive_key positive \
  --negative_key negative \
  --use_lora True \
  --query_instruction 'Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ' \
  --document_instruction '' \
  --learning_rate 1e-5 \
  --bf16 \
  --num_train_epochs 3 \
  --per_device_train_batch_size 2 \
  --gradient_accumulation_steps 16 \
  --dataloader_drop_last True \
  --query_max_length 64 \
  --document_max_length 256 \
  --train_group_size 2 \
  --logging_strategy steps \
  --logging_steps 100 \
  --temperature 0.02 \
  --use_inbatch_negative false \
  --save_total_limit 1

请添加图片描述

由于trainer中可以使用多种方式使用多GPU,因此retrievals也都支持。

# torchrun --nnodes 1 --nproc-per-node 4
# deepspeed --include localhost:0,1,2,3
# CUDA_VISIBLE_DEVICES=1,2,3 python
# accelerate launch --config_file conf_ds.yaml \

accelerate launch \
    --config_file conf_llm.yaml \
    llm_finetune_for_embed.py \
    --model_name_or_path mistralai/Mistral-7B-v0.1 \
    --train_data  \
    --output_dir output \

4. 评测

微调前性能 c-mteb t2-ranking score
请添加图片描述

微调后性能

请添加图片描述
微调后,map从0.651上升到0.699,mrr从0.758上升到0.808

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2146115.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaScript高级——内存溢出和内存泄漏

1、闭包的缺点与解决方法 (1)缺点:函数执行完后,函数内的局部变量没有释放,占用内存时间会变长。 容易造成内存泄漏。 (2)解决:能不用闭包就不用。 及时释放。 2、内存溢出 ① 一…

Linux进阶 查看系统进程

操作系统中进程的生命周期是: 创建进程,(服务启动或软件的启动)进行运行状态进程等待状态进行唤醒进程结束一般主要关注是进行中间的三种状态,三种状态之间装换关系如下: 1、就绪状态:表示进程已经做好了运行的准备状态,只要获得内存空间,就可以立即执行。 2、阻塞状态:…

Maya---机械模型制作

材质效果(4)_哔哩哔哩_bilibili 三角面 四边面 多边面 *游戏允许出现三角面和四边面 游戏中一般是低模(几千个面) 动漫及影视是高模 机械由单独零件组合而成,需独立制作 低面模型到高面模型 卡线是为了将模型保…

电脑怎么设置开机密码?3个方法迅速搞定!

电脑已经成为了我们日常办公与学习的重要工具,其中保存有很多重要且需保密的资料,为电脑设置开机密码则是保护资料安全的第一步。那么,电脑怎么设置开机密码呢?今天,小编就为大家介绍3个设置电脑开机密码的方法&#x…

深度学习对抗海洋赤潮危机!浙大GIS实验室提出ChloroFormer模型,可提前预警海洋藻类爆发

2014 年 8 月,美国俄亥俄州托莱多市超 50 万名居民突然收到市政府的一则紧急通知——不得擅自饮用自来水! 水是人类生存的基本供给,此通告关系重大,发出后也引起了不小的恐慌。究其原因,其实是美国伊利湖爆发了大规模…

油烟机制造5G智能工厂物联数字孪生平台,推进制造业数字化转型

油烟机制造5G智能工厂物联数字孪生平台,是智能制造与信息技术的深度融合产物。数字孪生工业互联平台通过部署在工厂各个环节的传感器和设备,实时采集、分析和处理生产过程中的海量数据,构建出高度逼真的数字孪生模型。这一模型不仅能够真实反…

基于树莓派ubuntu20.04的ros-noetic小车

目录 一、小车的架构 1.1 总体的概述 1.2 驱动系统 1.3 控制系统 二、驱动系统开发 2.1 PC端Ubuntu20.04安装 2.2 树莓派Ubuntu20.04安装 2.3 PC端虚拟机设置静态IP 2.4 树莓派设置静态IP 2.5 树莓派启动ssh进行远程开发 2.5 arduino ide 开发环境搭建 2.5.1 PC…

深入探索Docker核心原理:从Libcontainer到runC的演化与实现

随着容器技术的发展,Docker从早期的Libcontainer逐步演化到runC,推动了容器运行时的标准化进程。Libcontainer是Docker容器的核心管理工具,而runC则在此基础上发展成为符合OCI(Open Container Initiative)标准的轻量级…

Vue常用PC端和移动端组件库、Element UI的基本使用(完整引入和按需引入)

目录 1. Vue常用PC端和移动端组件库2. Element UI的基本使用2.1 完整引入2.2 按需引入 1. Vue常用PC端和移动端组件库 提供常用的布局、按钮、输入框、下拉框等UI布局,以组件的形式提供。使用这些组件,结构、样式、交互就都有了 移动端常用UI组件库 Van…

windows10 修改默认输入法

右键桌面,选择个性化 左侧搜索 语言 选择编辑语言和键盘选项 点击键盘 默认替代输入法 选择你想要设置的。重启电脑。如下图

C语言18--头文件

头文件的作用 通常,一个常规的C语言程序会包含多个源码文件(.c),当某些公共资源需要在各个源码文件中使用时,为了避免多次编写相同的代码,一般的做法是将这些大家都需要用到的公共资源放入头文件&#xff…

光学超表面在成像和传感中的应用

光学超表面已成为解决笨重光学元件所带来的限制,极具前景的解决方案。与传统的折射传播技术相比,它们提供了一种紧凑、高效的光操纵方法,可对相位、偏振和发射进行先进的控制。本文概述了光学超表面、它们在成像和传感技术中的各种应用以及这…

Broadcast:Android中实现组件与进程间通信

目录 一,Broadcast和BroadcastReceiver 1,简介 2,广播使用 二,静态注册和动态注册 三,无序广播和有序广播 1,有序广播的使用 2,有序广播的截断 3,有序广播的信息传递 四&am…

力扣(LeetCode)每日一题 1184. 公交站间的距离

题目链接https://leetcode.cn/problems/distance-between-bus-stops/description/?envTypedaily-question&envId2024-09-16 环形公交路线上有 n 个站,按次序从 0 到 n - 1 进行编号。我们已知每一对相邻公交站之间的距离,distance[i] 表示编号为 i …

Python燃烧废气排放推断算法模型

🎯要点 宏观能耗场景模型参数化输入数据,分析可视化输出结果,使用场景时间序列数据模型及定量和定性指标使用线图和箱线图、饼图、散点图、堆积条形图、桑基图等可视化模型输出结果根据气体排放过程得出其时间序列关系,使用推断模…

nginx基础篇(一)

文章目录 学习链接概图一、Nginx简介1.1 背景介绍名词解释 1.2 常见服务器对比IISTomcatApacheLighttpd其他的服务器 1.3 Nginx的优点(1)速度更快、并发更高(2)配置简单,扩展性强(3)高可靠性(4)热部署(5)成本低、BSD许可证 1.4 Nginx的功能特性及常用功能基本HTTP服…

GlusterFS 分布式文件系统

一、GlusterFS 概述 1.1 什么是GlusterFS GlusterFS 是一个开源的分布式文件系统,它可以将多个存储服务器结合在一起,创建一个大的存储池,供客户端使用。它不需要单独的元数据服务器,这样可以提高系统的性能和可靠性。由于没有…

python毕业设计基于django+vue医院社区医疗挂号预约综合管理系统7918h-pycharm-flask

目录 技术栈和环境说明预期达到的目标具体实现截图系统设计Python技术介绍django框架介绍flask框架介绍解决的思路性能/安全/负载方面可行性分析论证python-flask核心代码部分展示python-django核心代码部分展示操作可行性技术路线感恩大学老师和同学详细视频演示源码获取 技术…

【Finetune】(二)、transformers之Prompt-Tuning微调

文章目录 0、prompt-tuning基本原理1、实战1.1、导包1.2、加载数据1.3、数据预处理1.4、创建模型1.5、Prompt Tuning*1.5.1、配置文件1.5.2、创建模型 1.6、配置训练参数1.7、创建训练器1.8、模型训练1.9、推理:加载预训练好的模型 0、prompt-tuning基本原理 prompt…

【机器学习】任务五:葡萄酒和鸢尾花数据集分类任务

目录 1.实验基础知识 1.1 集成学习 (1)随机森林 (2)梯度提升决策树(GBDT) (3)XGBoost (4)LightGBM 1.2 参数优化 (1)网格搜索…