动手学习RAG: moka-ai/m3e 模型微调deepspeed与对比学习

news2025/1/16 16:42:57
  • 动手学习RAG: 向量模型
  • 动手学习RAG:迟交互模型colbert微调实践 bge-m3

1. 环境准备

pip install open-retrievals

2. 使用M3E模型

from retrievals import AutoModelForEmbedding

embedder = AutoModelForEmbedding.from_pretrained('moka-ai/m3e-base', pooling_method='mean')
embedder

请添加图片描述

sentences = [
    '* Moka 此文本嵌入模型由 MokaAI 训练并开源,训练脚本使用 uniem',
    '* Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练',
    '* Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算,异质文本检索等功能,未来还会支持代码检索,ALL in one'
]

embeddings = embedder.encode(sentences)

for sentence, embedding in zip(sentences, embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding)
    print("")

请添加图片描述

3. deepspeed 微调m3e模型

数据仍然采用之前介绍的t2-ranking数据集

  • deepspeed配置保存为 ds_zero2_no_offload.json. 不过虽然设置了zero2,这里我只用了一张卡.
    • 关于deepspeed的分布式设置,可参考Tranformer分布式特辑
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 100,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1e-10
    },

    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 1e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 1e8,
        "contiguous_gradients": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

这里稍微修改了open-retrievals这里的代码,主要是修改了导入为包的导入,而不是相对引用。保存文件为embed.py

"""Embedding fine tune pipeline"""

import logging
import os
import pickle
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed

from retrievals import (
    EncodeCollator,
    EncodeDataset,
    PairCollator,
    RetrievalTrainDataset,
    TripletCollator,
)
from retrievals.losses import AutoLoss, InfoNCE, SimCSE, TripletLoss
from retrievals.models.embedding_auto import AutoModelForEmbedding
from retrievals.trainer import RetrievalTrainer

# os.environ["WANDB_LOG_MODEL"] = "false"
logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    causal_lm: bool = field(default=False, metadata={'help': "Whether the model is a causal lm or not"})
    lora_path: Optional[str] = field(default=None, metadata={'help': "Lora adapter save path"})


@dataclass
class DataArguments:
    data_name_or_path: str = field(default=None, metadata={"help": "Path to train data"})
    train_group_size: int = field(default=2)
    unfold_each_positive: bool = field(default=False)
    query_max_length: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    document_max_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    query_instruction: str = field(default=None, metadata={"help": "instruction for query"})
    document_instruction: str = field(default=None, metadata={"help": "instruction for document"})
    query_key: str = field(default=None)
    positive_key: str = field(default='positive')
    negative_key: str = field(default='negative')
    is_query: bool = field(default=False)
    encoding_save_file: str = field(default='embed.pkl')

    def __post_init__(self):
        # self.data_name_or_path = 'json'
        self.dataset_split = 'train'
        self.dataset_language = 'default'

        if self.data_name_or_path is not None:
            if not os.path.isfile(self.data_name_or_path) and not os.path.isdir(self.data_name_or_path):
                info = self.data_name_or_path.split('/')
                self.dataset_split = info[-1] if len(info) == 3 else 'train'
                self.data_name_or_path = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info)
                self.dataset_language = 'default'
                if ':' in self.data_name_or_path:
                    self.data_name_or_path, self.dataset_language = self.data_name_or_path.split(':')


@dataclass
class RetrieverTrainingArguments(TrainingArguments):
    train_type: str = field(default='pairwise', metadata={'help': "train type of point, pair, or list"})
    negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
    temperature: Optional[float] = field(default=0.02)
    fix_position_embedding: bool = field(
        default=False, metadata={"help": "Freeze the parameters of position embeddings"}
    )
    pooling_method: str = field(default='cls', metadata={"help": "the pooling method, should be cls or mean"})
    normalized: bool = field(default=True)
    loss_fn: str = field(default='infonce')
    use_inbatch_negative: bool = field(default=True, metadata={"help": "use documents in the same batch as negatives"})
    remove_unused_columns: bool = field(default=False)
    use_lora: bool = field(default=False)
    use_bnb_config: bool = field(default=False)
    do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
    report_to: Optional[List[str]] = field(
        default="none", metadata={"help": "The list of integrations to report the results and logs to."}
    )


def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, RetrieverTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("Model parameters %s", model_args)
    logger.info("Data parameters %s", data_args)

    set_seed(training_args.seed)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
    )
    if training_args.use_bnb_config:
        from transformers import BitsAndBytesConfig

        logger.info('Use quantization bnb config')
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    else:
        quantization_config = None

    if training_args.do_train:
        model = AutoModelForEmbedding.from_pretrained(
            model_name_or_path=model_args.model_name_or_path,
            pooling_method=training_args.pooling_method,
            use_lora=training_args.use_lora,
            quantization_config=quantization_config,
        )

        loss_fn = AutoLoss(
            loss_name=training_args.loss_fn,
            loss_kwargs={
                'use_inbatch_negative': training_args.use_inbatch_negative,
                'temperature': training_args.temperature,
            },
        )

        model = model.set_train_type(
            "pairwise",
            loss_fn=loss_fn,
        )

        train_dataset = RetrievalTrainDataset(
            args=data_args,
            tokenizer=tokenizer,
            positive_key=data_args.positive_key,
            negative_key=data_args.negative_key,
        )
        logger.info(f"Total training examples: {len(train_dataset)}")

        trainer = RetrievalTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=TripletCollator(
                tokenizer,
                query_max_length=data_args.query_max_length,
                document_max_length=data_args.document_max_length,
                positive_key=data_args.positive_key,
                negative_key=data_args.negative_key,
            ),
        )

        Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)

        trainer.train()
        # trainer.save_model(training_args.output_dir)
        model.save_pretrained(training_args.output_dir)

        if trainer.is_world_process_zero():
            tokenizer.save_pretrained(training_args.output_dir)

    if training_args.do_encode:
        model = AutoModelForEmbedding.from_pretrained(
            model_name_or_path=model_args.model_name_or_path,
            pooling_method=training_args.pooling_method,
            use_lora=training_args.use_lora,
            quantization_config=quantization_config,
            lora_path=model_args.lora_path,
        )

        max_length = data_args.query_max_length if data_args.is_query else data_args.document_max_length
        logger.info(f'Encoding will be saved in {training_args.output_dir}')

        encode_dataset = EncodeDataset(args=data_args, tokenizer=tokenizer, max_length=max_length, text_key='text')
        logger.info(f"Number of train samples: {len(encode_dataset)}, max_length: {max_length}")

        encode_loader = DataLoader(
            encode_dataset,
            batch_size=training_args.per_device_eval_batch_size,
            collate_fn=EncodeCollator(tokenizer, max_length=max_length, padding='max_length'),
            shuffle=False,
            drop_last=False,
            num_workers=training_args.dataloader_num_workers,
        )

        embeddings = model.encode(encode_loader, show_progress_bar=True, convert_to_numpy=True)
        lookup_indices = list(range(len(encode_dataset)))

        with open(os.path.join(training_args.output_dir, data_args.encoding_save_file), 'wb') as f:
            pickle.dump((embeddings, lookup_indices), f)


if __name__ == "__main__":
    main()

  • 最终调用文件 shell run.sh
MODEL_NAME="moka-ai/m3e-base"

TRAIN_DATA="/root/kag101/src/open-retrievals/t2/t2_ranking.jsonl"
OUTPUT_DIR="/root/kag101/src/open-retrievals/t2/ft_out"


# loss_fn: infonce, simcse

deepspeed -m --include localhost:0 embed.py \
  --deepspeed ds_zero2_no_offload.json \
  --output_dir $OUTPUT_DIR \
  --overwrite_output_dir \
  --model_name_or_path $MODEL_NAME \
  --do_train \
  --data_name_or_path $TRAIN_DATA \
  --positive_key positive \
  --negative_key negative \
  --pooling_method mean \
  --loss_fn infonce \
  --use_lora False \
  --query_instruction "" \
  --document_instruction "" \
  --learning_rate 3e-5 \
  --fp16 \
  --num_train_epochs 5 \
  --per_device_train_batch_size 32 \
  --dataloader_drop_last True \
  --query_max_length 64 \
  --document_max_length 256 \
  --train_group_size 4 \
  --logging_steps 100 \
  --temperature 0.02 \
  --save_total_limit 1 \
  --use_inbatch_negative false

请添加图片描述

4. 测试

微调前性能 c-mteb t2-ranking score

请添加图片描述

微调后性能

请添加图片描述

采用infoNCE损失函数,没有加in-batch negative,而关注的是困难负样本,经过微调map从0.654提升至0.692,mrr从0.754提升至0.805

对比一下非deepspeed而是直接torchrun的微调

  • map略低,mrr略高。猜测是因为deepspeed中设置的一些auto会和直接跑并不完全一样
    请添加图片描述

欢迎关注最新的更新https://github.com/LongxingTan/open-retrievals

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2132875.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Windows】获取进程缓解策略设置情况

目录 一、前言 二、主要概念 三、实现步骤 四、总结 原文出处链接:[https://blog.csdn.net/qq_59075481/article/details/142234952] 一、前言 在现代操作系统中,进程缓解策略(Process Mitigation Policy)提供了一种防御机制…

谷歌创始人谢尔盖·布林回归一线:承认错失先机,每天都在写代码

在科技界,有些名字永远闪耀着创新的光芒,谢尔盖布林就是其中之一。作为谷歌的联合创始人,布林在经历了一段时间的隐退后,宣布重返一线,投身于人工智能(AI)技术的研发。本周,在洛杉矶…

F12抓包10:UI自动化 - Elements(元素)定位页面元素

​课程大纲 1、前端基础 1.1 元素 元素是构成HTML文档的基本组成部分之一,定义了文档的结构和内容,比如段落、标题、链接等。 元素大致分为3种:基本结构、自闭合元素(self-closing element)、嵌套元素。 1、基本结构&…

Docker 常用命令(未完待续...)

Docker 常用命令(未完待续...) 命令的完整名称和别名帮助登录和搜索命令 (Login and Search Commands)其他管理命令 (Other Management Commands)镜像命令 (Image Commands)容器命令 (Container Commands)docker run 从镜像创建并运行一个新的容器docker…

Midjourney中秋特典-12张图附魔咒

第一张 魔咒 A Mid-Autumn Festival poster, a round bright moon, a Chinese-style pavilion with a scene of a reunion from Dream of the Red Chamber, a new Chinese style --ar 3:4 --v 6.1第二张 魔咒 The bright full moon hung in the night sky,clear in outline a…

【疑难杂症2024-005】docker-compose中设置容器的ip为固定ip后,服务无法启动

本文由Markdown语法编辑器编辑完成。 1.背景: 我们的产品是通过docker image的方式发布,并且编排在docker-compose.yml中发布。在同一个docker-compose.yml中的服务,相互之间,可以通过对方的服务名和端口,来直接访问…

动态规划算法---04.斐波那契数列模型_解码方法_C++

题目链接:91. 解码方法 - 力扣(LeetCode)https://leetcode.cn/problems/decode-ways/description/ 一、题目解析 题目: 题目大意:从题目中我们可以知道,解码就是在字符串s中由‘1’到‘26’的字符可以转化…

echarts饼图让部分数据显示在图外,部分显示在图内

echarts饼图让部分数据显示在图外,部分显示在图内 var dataList [{ value: 10, name: 商户 },{ value: 20, name: 充电桩 },{ value: 30, name: 业主 } ] option {series: [{type: pie,radius: 70%,data: dataList,labelLine: {show: true,position: outside,len…

JavaSE:4、流程控制

1、代码块与作用域 变量的使用范围,仅限于其定义时所处的代码块,也就是他的作用域。 目前所说的变量均为局部变量 public class Main {public static void main(String [] argv){int a10;{int b10;System.out.println(a);System.out.println(b);}Sys…

计算机网络八股总结

这里写目录标题 网络模型划分(五层和七层)及每一层的功能五层网络模型七层网络模型(OSI模型) 三次握手和四次挥手具体过程及原因三次握手四次挥手 TCP/IP协议组成UDP协议与TCP/IP协议的区别Http协议相关知识网络地址,子…

前端——标签二(超链接)

标签二 超链接标签:a 超链接,实现页面间的跳转和数据传输 a标签的属性 href:跳转路径(url)必须具备,表示点击后会跳转到哪个页面 target:页面打开方式。默认是 _self 如果是 _blank则用新的…

[基于 Vue CLI 5 + Vue 3 + Ant Design Vue 4 搭建项目] 02 配置 nodejs 淘宝镜像仓库

文章目录 为什么要配置淘宝镜像仓库呢如何查看镜像仓库如何配置镜像仓库 为什么要配置淘宝镜像仓库呢 主要是因为默认的镜像仓库是国外的,当我们使用 npm 安装依赖的时候会很慢或者失败,我们配置国内的镜像仓库这样就可以加速我们安装镜像的过程&#x…

突破瓶颈:Java并发编程的最佳实践与技巧,你了解了吗?

文章目录 1 什么是 Executor 和 ExecutorService ?这两个接口有什么区别?2 java.util.concurrent 标准库中 ExecutorService 的可用实现是什么 ?3 什么是 Java 内存模型( JMM )?描述下其目的和基本思想4 JM…

Dubbo精要

1、为什么需要 Dubbo? 分布式系统中的服务调用和协调问题:在分布式系统中,服务之间的相互依赖会导致复杂的通信和协调问题。Dubbo提供了高效的服务调用和自动注册、发现等功能,使得构建分布式应用程序更加容易。服务治理和服务调…

6 递归——509. 斐波那契数 ★

6 递归 509. 斐波那契数 斐波那契数列从0和1开始,后面的每一项数字都是前面两项数字的和。F(0) = 0,F(1) = 1,当n > 1时,F(n) = F(n − 1) + F(n − 2)。给定n,请计算 F(n)。 示例 1: 输入:n = 2 输出:1 解释:F(2) = F(1) + F(0) = 1 + 0 = 1 示例 2: 输入:n …

[000-01-008].第05节:OpenFeign特性-重试机制

我的后端学习大纲 SpringCloud学习大纲 1.1.重试机制的默认值: 1.重试机制默认是关闭的,给了默认值 1.2.测试重试机制的默认值: 1.3.开启Retryer功能: 1.修改配置文件YML的配置: 2.新增配置类: packa…

CentOs 入门必备基础知识详细讲解

​ 大家好,我是程序员小羊! 前言: CentOS(Community ENTerprise Operating System)是一个基于 Red Hat Enterprise Linux (RHEL) 源代码的开源操作系统,主要用于服务器和企业环境。下面是一个详细的入门知识…

JDBC数据库连接技术

JDBC数据库连接技术 基础篇 一、引言 1.1 数据的存储 我们在开发Java程序时,数据都是存储在内存中,属于临时存储,当程序停止或重启时,内存中的数据就丢失了!我们为了解决数据的长期存储问题,有如下解决方…

【Prompt Enhancer】如何优化prompt的内容

背景 在使用LLM的时候,提示词的好坏对模型的输出质量影响很大,提示词又是一个复杂工程,要写出优秀的提示词,需要丰富的经验。正因如此,各类Agent平台都会有自己的提示词增强功能,帮助用户编写提示词。 最…

Linux驱动.之platform平台总线驱动框架(二),正点原子

第五十四章 platform设备驱动实验 我们在前面几章编写的设备驱动都非常的简单,都是对IO进行最简单的读写操作。像I2C、SPI、LCD等这些复杂外设的驱动就不能这么去写了,Linux系统要考虑到驱动的可重用性,因此提出了驱动的分离与分层这样的软件…