Sentence-BERT实现文本匹配【回归目标函数】

news2025/1/11 2:47:40

引言

上篇文章我们通过Sentence-Bert提出的分类目标函数来训练句子嵌入模型,本文同样基于Sentence-Bert的架构,但改用回归目标函数。

架构

image-20210923000654664

如上图,计算两个句嵌入 u \pmb u u v \pmb v v​之间的余弦相似度,然后可以使用均方误差(mean-squared-error)作为目标函数。
L = ∣ ∣ y − cosine_sim ( u , v ) ∣ ∣ 2 \mathcal L = ||y - \text{cosine\_sim}(\pmb u,\pmb v)||_2 L=∣∣ycosine_sim(u,v)2
这里的 y y y是真实标签。

回归目标函数的预测不再是整数标签1或0了,而可以为数值。比如对于给定的句子对,可以计算相似度得分。此时推理流程与训练完全相同。

实现

实现采用类似Huggingface的形式,每个文件夹下面有一种模型。分为modelingargumentstrainer等不同的文件。不同的架构放置在不同的文件夹内。

modeling.py:

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from transformers.file_utils import ModelOutput

from transformers import (
    AutoModel,
    AutoTokenizer,
)

import numpy as np
from tqdm.autonotebook import trange
from typing import Optional


@dataclass
class BiOutput(ModelOutput):
    loss: Optional[Tensor] = None
    scores: Optional[Tensor] = None


class SentenceBert(nn.Module):
    def __init__(
        self,
        model_name: str,
        trust_remote_code: bool = True,
        max_length: int = None,
        num_classes: int = 2,
        pooling_mode: str = "mean",
        normalize_embeddings: bool = False,
    ) -> None:
        super().__init__()
        self.model_name = model_name
        self.normalize_embeddings = normalize_embeddings

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_remote_code
        )
        self.model = AutoModel.from_pretrained(
            model_name, trust_remote_code=trust_remote_code
        ).to(self.device)

        self.max_length = max_length
        self.pooling_mode = pooling_mode

        self.loss_fct = nn.MSELoss()

    def sentence_embedding(self, last_hidden_state, attention_mask):
        if self.pooling_mode == "mean":
            attention_mask = attention_mask.unsqueeze(-1).float()
            return torch.sum(last_hidden_state * attention_mask, dim=1) / torch.clamp(
                attention_mask.sum(1), min=1e-9
            )
        else:
            # cls
            return last_hidden_state[:, 0]

    def encode(
        self,
        sentences: str | list[str],
        batch_size: int = 64,
        convert_to_tensor: bool = True,
        show_progress_bar: bool = False,
    ):
        if isinstance(sentences, str):
            sentences = [sentences]

        all_embeddings = []

        for start_index in trange(
            0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar
        ):
            batch = sentences[start_index : start_index + batch_size]

            features = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors="pt",
                return_attention_mask=True,
                max_length=self.max_length,
            ).to(self.device)

            out_features = self.model(**features, return_dict=True)
            embeddings = self.sentence_embedding(
                out_features.last_hidden_state, features["attention_mask"]
            )
            if not self.training:
                embeddings = embeddings.detach()

            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

            if not convert_to_tensor:
                embeddings = embeddings.cpu()

            all_embeddings.extend(embeddings)

        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        else:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

        return all_embeddings

    def compute_loss(self, scores, labels):
        labels = torch.tensor(labels).float().to(self.device)
        return self.loss_fct(scores, labels.view(-1))

    def forward(self, source, target, labels) -> BiOutput:
        """
        Args:
            source :
            target :
        """
        source_embed = self.encode(source)
        target_embed = self.encode(target)

        scores = torch.cosine_similarity(source_embed, target_embed)

        loss = self.compute_loss(scores, labels)
        return BiOutput(loss, scores)

    def save_pretrained(self, output_dir: str):
        state_dict = self.model.state_dict()
        state_dict = type(state_dict)(
            {k: v.clone().cpu().contiguous() for k, v in state_dict.items()}
        )
        self.model.save_pretrained(output_dir, state_dict=state_dict)

整个模型的实现放到modeling.py文件中。

arguments.py:

from dataclasses import dataclass, field
from typing import Optional

import os


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model"
        }
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )


@dataclass
class DataArguments:
    train_data_path: str = field(
        default=None, metadata={"help": "Path to train corpus"}
    )
    eval_data_path: str = field(default=None, metadata={"help": "Path to eval corpus"})
    max_length: int = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization for input text."
        },
    )

    def __post_init__(self):
        if not os.path.exists(self.train_data_path):
            raise FileNotFoundError(
                f"cannot find file: {self.train_data_path}, please set a true path"
            )
        
        if not os.path.exists(self.eval_data_path):
            raise FileNotFoundError(
                f"cannot find file: {self.eval_data_path}, please set a true path"
            )

定义了模型和数据相关参数。

dataset.py:

from torch.utils.data import Dataset
from datasets import Dataset as dt
import pandas as pd

from utils import build_dataframe_from_csv


class PairDataset(Dataset):
    def __init__(self, data_path: str) -> None:

        df = build_dataframe_from_csv(data_path)
        self.dataset = dt.from_pandas(df, split="train")

        self.total_len = len(self.dataset)

    def __len__(self):
        return self.total_len

    def __getitem__(self, index) -> dict[str, str]:
        query1 = self.dataset[index]["query1"]
        query2 = self.dataset[index]["query2"]
        label = self.dataset[index]["label"]
        return {"query1": query1, "query2": query2, "label": label}


class PairCollator:
    def __call__(self, features) -> dict[str, list[str]]:
        queries1 = []
        queries2 = []
        labels = []

        for feature in features:
            queries1.append(feature["query1"])
            queries2.append(feature["query2"])
            labels.append(feature["label"])

        return {"source": queries1, "target": queries2, "labels": labels}

数据集类考虑了LCQMC数据集的格式,即成对的语句和一个数值标签。类似:

Hello.	Hi.	1
Nice to see you.	Nice	0

trainer.py:

import torch
from transformers.trainer import Trainer

from typing import Optional
import os
import logging

from modeling import SentenceBert

TRAINING_ARGS_NAME = "training_args.bin"
logger = logging.getLogger(__name__)


class BiTrainer(Trainer):

    def compute_loss(self, model: SentenceBert, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = outputs.loss

        return (loss, outputs) if return_outputs else loss

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        self.model.save_pretrained(output_dir)

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

继承🤗 Transformers的Trainer类,重写了compute_loss_save方法。

这样我们就可以利用🤗 Transformers来训练我们的模型了。

utils.py:

import torch
import pandas as pd
from scipy.stats import pearsonr, spearmanr
from typing import Tuple


def build_dataframe_from_csv(dataset_csv: str) -> pd.DataFrame:
    df = pd.read_csv(
        dataset_csv,
        sep="\t",
        header=None,
        names=["query1", "query2", "label"],
    )

    return df


def compute_spearmanr(x, y):
    return spearmanr(x, y).correlation


def compute_pearsonr(x, y):
    return pearsonr(x, y)[0]


def find_best_acc_and_threshold(scores, labels, high_score_more_similar: bool):
    """Copied from https://github.com/UKPLab/sentence-transformers/tree/master"""
    assert len(scores) == len(labels)
    rows = list(zip(scores, labels))

    rows = sorted(rows, key=lambda x: x[0], reverse=high_score_more_similar)
    print(rows)

    max_acc = 0
    best_threshold = -1
    # positive examples number so far
    positive_so_far = 0
    # remain negative examples
    remaining_negatives = sum(labels == 0)

    for i in range(len(rows) - 1):
        score, label = rows[i]
        if label == 1:
            positive_so_far += 1
        else:
            remaining_negatives -= 1

        acc = (positive_so_far + remaining_negatives) / len(labels)
        if acc > max_acc:
            max_acc = acc
            best_threshold = (rows[i][0] + rows[i + 1][0]) / 2

    return max_acc, best_threshold


def metrics(y: torch.Tensor, y_pred: torch.Tensor) -> Tuple[float, float, float, float]:
    TP = ((y_pred == 1) & (y == 1)).sum().float()  # True Positive
    TN = ((y_pred == 0) & (y == 0)).sum().float()  # True Negative
    FN = ((y_pred == 0) & (y == 1)).sum().float()  # False Negatvie
    FP = ((y_pred == 1) & (y == 0)).sum().float()  # False Positive
    p = TP / (TP + FP).clamp(min=1e-8)  # Precision
    r = TP / (TP + FN).clamp(min=1e-8)  # Recall
    F1 = 2 * r * p / (r + p).clamp(min=1e-8)  # F1 score
    acc = (TP + TN) / (TP + TN + FP + FN).clamp(min=1e-8)  # Accurary
    return acc, p, r, F1


def compute_metrics(predicts, labels):
    return metrics(labels, predicts)

定义了一些帮助函数,从sentence-transformers库中拷贝了寻找最佳准确率阈值的实现find_best_acc_and_threshold

除了准确率,还计算了句嵌入的余弦相似度与真实标签之间的斯皮尔曼等级相关系数指标。

最后定义训练和测试脚本。

train.py:

from transformers import set_seed, HfArgumentParser, TrainingArguments

import logging
from pathlib import Path

from datetime import datetime

from modeling import SentenceBert
from trainer import BiTrainer
from arguments import DataArguments, ModelArguments
from dataset import PairCollator, PairDataset

logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)


def main():
    parser = HfArgumentParser((TrainingArguments, DataArguments, ModelArguments))
    training_args, data_args, model_args = parser.parse_args_into_dataclasses()
	# 根据当前时间生成输出目录
    output_dir = f"{training_args.output_dir}/{model_args.model_name_or_path.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    training_args.output_dir = output_dir

    logger.info(f"Training parameters {training_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Model parameters {model_args}")
	# 设置随机种子
    set_seed(training_args.seed)
	# 加载预训练模型
    model = SentenceBert(
        model_args.model_name_or_path,
        trust_remote_code=True,
        max_length=data_args.max_length,
    )
	
    tokenizer = model.tokenizer
	# 构建训练和测试集
    train_dataset = PairDataset(data_args.train_data_path)
    eval_dataset = PairDataset(data_args.eval_data_path)
	# 传入参数
    trainer = BiTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=PairCollator(),
        tokenizer=tokenizer,
    )
    Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
	# 开始训练
    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()

训练

基于train.py定义了train.sh传入相关参数:

timestamp=$(date +%Y%m%d%H%M)
logfile="train_${timestamp}.log"

# change CUDA_VISIBLE_DEVICES
CUDA_VISIBLE_DEVICES=3 nohup python train.py \
    --model_name_or_path=hfl/chinese-macbert-large \
    --output_dir=output \
    --train_data_path=data/train.txt \
    --eval_data_path=data/dev.txt \
    --num_train_epochs=3 \
    --save_total_limit=5 \
    --learning_rate=2e-5 \
    --weight_decay=0.01 \
    --warmup_ratio=0.01 \
    --bf16=True \
    --eval_strategy=epoch \
    --save_strategy=epoch \
    --per_device_train_batch_size=64 \
    --report_to="none" \
    --remove_unused_columns=False \
    --max_length=128 \
    > "$logfile" 2>&1 &


以上参数根据个人环境修改,这里使用的是哈工大的chinese-macbert-large预训练模型。

注意:

  • --remove_unused_columns是必须的。
  • 通过bf16=True可以加速训练同时不影响效果。
  • 其他参数可以自己调整。
100%|██████████| 18655/18655 [1:17:23<00:00,  4.44it/s]
100%|██████████| 18655/18655 [1:17:23<00:00,  4.02it/s]
09/02/2024 21:02:41 - INFO - trainer - Saving model checkpoint to output/hfl-chinese-macbert-large-2024-09-02_19-45-12
{'eval_loss': 0.09294428676366806, 'eval_runtime': 56.1261, 'eval_samples_per_second': 156.825, 'eval_steps_per_second': 19.617, 'epoch': 5.0}
{'train_runtime': 4643.261, 'train_samples_per_second': 257.11, 'train_steps_per_second': 4.018, 'train_loss': 0.049199433276877584, 'epoch': 5.0}

这里训练了5轮,我们拿最后保存的模型output/hfl-chinese-macbert-large-2024-09-02_19-45-12进行测试。

参数忘改了,为了便于比较,实际上下面的结果是以3轮的训练结果验证的。

测试

test.py: 测试脚本见后文的完整代码。

test.sh:

# change CUDA_VISIBLE_DEVICES
CUDA_VISIBLE_DEVICES=0 python test.py \
    --model_name_or_path=output/hfl-chinese-macbert-large-2024-09-02_19-45-12/checkpoint-11193 \
    --test_data_path=data/test.txt

输出:

TestArguments(model_name_or_path='output/hfl-chinese-macbert-large-2024-09-02_19-45-12/checkpoint-11193', test_data_path='data/test.txt', max_length=64, batch_size=128)
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:11<00:00,  8.77it/s]
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:11<00:00,  8.89it/s]
max_acc: 0.8832, best_threshold: 0.794167
spearman corr: 0.7795 |  pearson_corr corr: 0.7668 | compute time: 22.25s
accuracy=0.883 precision=0.876 recal=0.893 f1 score=0.8843

测试集上的准确率达到88.3%,这种以回归目标函数进行训练的效果没有分类的好。

完整代码

完整代码: →点此←

本文代码是和某次commit相关的,Master分支上的代码随时可能会被优化。

参考

  1. [论文笔记]Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2101103.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何通过住宅代理优化SERP表现:提升SEO排名的实用指南

引言 什么是SERP&#xff1f;包含哪些内容&#xff1f; 为什么SERP对SEO至关重要&#xff1f; 如何优化SERP表现&#xff1f; 总结 引言 在当今竞争激烈的数字营销环境中&#xff0c;搜索引擎优化&#xff08;SEO&#xff09;已成为企业提升在线可见性和吸引有机流量的关键…

matlab2024a/2023/2022/2020/matlab2019 如何plot画局部放大图(已解决)

matlab 2024&#xff1b;matlab 2023&#xff1b;matlab 2022&#xff1b;matlab 2021&#xff1b;matlab 2020&#xff1b;matlab 2019 matlab 2017一下的 使用magnify.m 进行局部放大图操作是没有问题的。 新版本 采用magnify.m 很难操作。 为什么要局部放大 局部方…

【王树森】Few-Shot Learning (3/3):Pretraining + Fine Tuning(个人向笔记)

Preliminary Few-Shot Learning 很简单&#xff0c;但是却能达到比较高的准确度&#xff0c;相反一些复杂的模型反而不能达到很高的准确率 1. Cosine Similarity 余弦相似度可以衡量两个向量的相似度 假设两个向量的长度都是1&#xff1a;那么它们余弦相似度的计算方法如下…

HarmonyOS开发实战( Beta5版)线程间通信场景最佳实践

简介 在应用开发中&#xff0c;经常会需要处理一些耗时的任务&#xff0c;如果全部放在主线程中执行就会导致阻塞&#xff0c;从而引起卡顿或者掉帧现象&#xff0c;降低用户体验&#xff0c;此时就可以将这些耗时操作放到子线程中处理。通常情况下&#xff0c;子线程可以独立…

bcftools报错|The sequence “chr1“ not defined in the header: chr1.recode.vcf

1、报错信息 The sequence "chr1" not defined in the header: chr1.recode.vcf (Quick workaround: index the file.) 所使用的命令&#xff0c;目的是想合并所提取的特定染色体。 bcftools concat -O v / -o varscan.indel_merged.vcf chr1.recode.vcf chr2.reco…

超好用的图纸加密软件排行榜 | 2024图纸加密软件的七款最优选择!

数字化设计日益普及的今天&#xff0c;图纸作为设计与工程的核心载体&#xff0c;其安全性成为了企业和设计师们最为关注的焦点之一。 面对日益复杂的数据泄露风险&#xff0c;如何有效地保护图纸文件的安全呢&#xff1f; 下面&#xff0c;我们就来探讨一下2024图纸加密软件的…

Python的10个文件对比与合并高效策略

文末赠免费精品编程资料~~ 在日常编程或数据分析工作中&#xff0c;经常需要处理多个文件的对比与合并任务。Python因其强大的文件处理能力和丰富的库支持&#xff0c;成为了处理这类任务的理想选择。下面&#xff0c;我们将逐步探索10种高效的文件对比与合并策略&#xff0c;…

OpenGL/GLUT实践:粒子系统,并添加纹理、动态模糊、边界碰撞(电子科技大学信软图形与动画Ⅱ实验)

源码见GitHub&#xff1a;A-UESTCer-s-Code 文章目录 1 运行效果2 实验过程2.1 基本粒子系统2.1.1 定义粒子结构2.1.2 创建粒子并初始化2.1.2.1 创建粒子2.1.2.2 初始化 2.1.3 粒子状态更新与绘制2.1.3.1 绘制2.1.3.2 更新 2.1.4 实现效果 2.2 添加纹理2.2.1 纹理添加2.2.2 渲染…

PostgreSQL + PostGIS:空间数据存储及管理解决方案

在数据库领域&#xff0c;PostgreSQL 已成为最强大、最通用的选项之一。它管理大量数据的能力、对 SQL 标准的遵守以及可扩展的架构使其受到学术界和工业界的喜爱。然而&#xff0c;真正让 PostgreSQL 脱颖而出的原因之一是它与PostGIS的集成&#xff0c;这是一个允许您有效处理…

第七课,条件表达式与初识分支判断

一&#xff0c;什么是判断 判断&#xff0c;就是在做某件事前&#xff0c;先问问满不满足条件。 进行逻辑判断&#xff0c;是生活中常见的行为。 “今天出门你要带伞吗&#xff1f;” “那得看天气怎么样&#xff0c;如果下雨或者太阳太大就带伞&#xff0c;否则就不带。”…

内存卡乱码问题解析恢复方案

一、内存卡乱码现象探析 在数字化时代&#xff0c;内存卡作为便携式数据存储设备&#xff0c;广泛应用于手机、相机、行车记录仪等多种电子设备中。然而&#xff0c;不少用户在使用过程中会遇到内存卡乱码的问题&#xff0c;即原本有序存储的文件突然变得无法识别&#xff0c;…

【前端面试】设计循环双端队列javascript

题目 https://leetcode.cn/problems/design-circular-deque/description/ 存储循环队列的向量空间是循环的,用通俗的话来讲,就是我们在做next或者prev操作时,不会发生溢出 取模、或者直接判断是否为0/size返回一个值。 数组实现 用函数来实现一个类,定义容量、头尾指针…

青远生态为云南林业规划院定制开发的自然保护地规划智能编制系统顺利通过验收

8月30日&#xff0c;青远生态为云南省林业调查规划院开发的自然保护地规划智能编制系统顺利通过验收。该系统具有智能推荐规划内容、自动生成投资估算表、智能编制规划报告等功能&#xff0c;集合了拉丁名填充、表格制作等丰富实用的工具&#xff0c;显著提升了规划工作的效率和…

电力系统有滤波器还需要装电抗器吗

在电力系统中&#xff0c;滤波器和电抗器各有不同的功能&#xff0c;尽管它们都能改善电力质量。是否需要同时安装滤波器和电抗器&#xff0c;取决于系统的具体需求和现状。以下是一些考虑因素&#xff1a; 1、滤波器的功能&#xff1a; 谐波滤波&#xff1a;滤波器主要用于抑…

基于vue框架的超市会员管理系统设计与实现xeb8c(程序+源码+数据库+调试部署+开发环境)系统界面在最后面。

系统程序文件列表 项目功能&#xff1a;会员,商品分类,商品信息,订单信息,积分等级,礼品信息,礼品兑换 开题报告内容 基于Vue框架的超市会员管理系统设计与实现开题报告 一、研究背景与意义 随着消费者对个性化服务和优惠活动需求的增加&#xff0c;超市会员管理成为提升顾…

Docker安装及验证,小白必备

Docker安装 本教程以centos系统为例 1、Docker安装前准备工作 切换国内源 cp -a /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.bak #备份设置为华为云的yum wget -O /etc/yum.repos.d/CentOS-Base.repo https://repo.huaweicloud.com/repository…

专用于理解游戏场景的开源大模型-VideoGameBunny

大模型在游戏开发领域扮演了重要角色&#xff0c;从AI机器人生成到场景搭建覆盖各个领域。但在游戏场景理解、图像识别、内容描述方面很差。 为了解决这些难题&#xff0c;加拿大阿尔伯塔的研究人员专门开源了一款针对游戏领域的大模型VideoGameBunny&#xff08;以下简称“VG…

7-8月月报 | Apache SeaTunnel社区进展一览

各位热爱 Apache SeaTunnel 的小伙伴们&#xff0c;社区 7-8 月份月报来啦&#xff01;这两个月项目有了哪些进展&#xff1f;又有谁登上了我们社区的贡献者榜单呢&#xff1f;快来一睹为快吧。 Merge Stars 感谢以下小伙伴上两个月为 Apache SeaTunnel 项目和社区发展所做的…

非时序检查(Non-Sequential Check)

单元或宏&#xff08;macro&#xff09;的库文件可以将时序弧指定为非时序&#xff08;non-sequential&#xff09;检查&#xff0c;例如两个数据引脚之间的时序弧。非时序检查是指两个引脚之间的检查&#xff0c;两者都不是时钟。一个引脚是约束引脚&#xff0c;其作用类似于数…

WPF在MVVM架构下使用DataGrid并实现行删除

一、效果演示 二、Model创建 //User&#xff1a;用于绑定DataGrid控件的数据 private ObservableCollection<User> _users new ObservableCollection<User>();public ObservableCollection<User> Users{get { return _users; }set { _users value; }}//Sel…