NLP: SBERT介绍及sentence-transformers库的使用

news2024/11/23 7:57:31

1. Sentence-BERT

  Sentence-BERT(简写SBERT)模型是BERT模型最有趣的变体之一,通过扩展预训练的BERT模型来获得固定长度的句子特征,主要用于句子对分类、计算两个句子之间的相似度任务。

1.1 计算句子特征

  SBERT模型同样是将句子标记送入预训练的BERT模型来获取句子特征的,但这里并不使用 R [ C L S ] R_{[CLS]} R[CLS]作为最终的句子特征。在SBERT中,通过汇聚所有标记的特征来计算整个句子的特征。具体的汇聚方法有两种:平均汇聚和最大汇聚。

  • 平均汇聚:使用平均汇聚来获取句子特征。这种方法得到的句子的特征将包含所有词语(Token)的意义。
  • 最大汇聚:使用最大汇聚来获取句子特征。这种方法得到的句子的特征将仅包含重要词语(Token)的意义。
    在这里插入图片描述

1.2 SBERT架构

  SBERT模型使用二元组网络架构来执行以一对句子作为输入的任务,并使用三元组网络架构来实现三元组损失函数。

1.2.1 使用二元组网络架构的SBERT模型

  SBERT通过二元组网络(两个共享同样权重的相同网络)架构对执行句子对任务的预训练的BERT模型进行微调。句子对任务具体包括以下两种:

  • 句子对分类任务: 判断句子对是否相似。相似则返回1,不相似则返回0。其SBERT模型架构为:
    在这里插入图片描述
  • 句子对回归任务:计算两个给定句子之间的语义相似度。其对应的SBERT架构为:在这里插入图片描述
1.2.2 使用三元组网络架构的SBERT模型

  三元组网络架构的SBERT模型的任务计算出一个特征,使锚定句和正向句之间的相似度高,锚定句和负向句之间的相似度低。其架构如下:
在这里插入图片描述

2. 计算文本相似度

2.1 bi-encoder VS cross-encoder

  bi-encoder和cross-encoder是语义匹配、文本相似度、信息检索场景下下常用的两种模型架构。这两者都基于深度学习模型(如BERT等)进行编码和比较文本之间的相似度,但它们在计算方式、效率和适用场景上有显著的区别。

2.1.1 bi-encoder

  bi-encoder是一种独立编码方式,即输入的两个文本会被分别编码为独立的向量,然后通过计算这两个向量的相似度来判断文本之间的关系。使用bi-encoder方式计算文本相似度的案例如下:

from sentence_transformers import SentenceTransformer
#加载预训练的sentence transformer模型
model = SentenceTransformer('all-MiniLM-L6-v2')
sentences=["这个商品挺好用的","这个商品一点也不好用"]
embeddings=model.encode(sentences)
similarity=model.similarity(embeddings[0],embeddings[1])
print(similarity) #0.5868
2.1.2 cross-encoder

  cross-encoder是一种联合编码方式,即将两个文本拼接在一起作为模型的输入,模型会通过对两个文本的联合表示来直接输出一个相似度分数。这种方式可以更好地捕捉两个文本之间的复杂交互信息,因此在诸如问答匹配、精确文本相似度计算等需要细粒度判断的任务上表现更好。具体使用方式如下:

from sentence_transformers.cross_encoder import CrossEncoder
model=CrossEncoder("cross-encoder/stsb-distilroberta-base")
query="这个产品挺好用的"
corpus=["这个产品很好",
        "这个产品的设计有很大问题",
        "这个产品不好用"]
ranks=model.rank(query,corpus)
for rank in ranks:
    print(f"{rank['score']:.2f}\t{corpus[rank['corpus_id']]}")

3 微调SBERT

  接下来我们使用STSB数据集对SBERT模型进行微调。具体代码如下

from datasets import load_dataset
from sentence_transformers import losses
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerTrainer,
)
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator,SimilarityFunction
from datasets import load_dataset

train = load_dataset("sentence-transformers/stsb",split='train')
dev = load_dataset("sentence-transformers/stsb",split='validation')
test= load_dataset("sentence-transformers/stsb",split='test')

model=SentenceTransformer('FacebookAI/xlm-roberta-base')

loss=losses.CoSENTLoss(model=model)

args=SentenceTransformerTrainingArguments(output_dir='models/model1',
                                          num_train_epochs=1,
                                          per_device_train_batch_size=16,
                                          per_device_eval_batch_size=16,
                                          warmup_ratio=0.1,
                                          eval_strategy='steps',
                                          eval_steps=100,
                                          save_strategy='steps',
                                          save_total_limit=2,
                                          bf16=False,)

dev_evaluator=EmbeddingSimilarityEvaluator(
    sentences1=dev['sentence1'],
    sentences2=dev['sentence2'],
    scores=dev['score'],
    main_similarity=SimilarityFunction.COSINE,
    name='dev-evaluator')

dev_evaluator(model)

trainer=SentenceTransformerTrainer(model=model,
                                   args=args,
                                   train_dataset=train,
                                   eval_dataset=dev,
                                   loss=loss,
                                   evaluator=dev_evaluator)   
trainer.train()                        

test_evaluator=EmbeddingSimilarityEvaluator(
    sentences1=test['sentence1'],
    sentences2=test['sentence2'],
    scores=test['score'],
    main_similarity=SimilarityFunction.COSINE,
    name='test-evaluator')
test_evaluator(model)
model.save_pretrained('models/model1')

关于上述代码,需要说明以下几点:

  • 训练和评估SBERT的数据类型必须是datasets.Datasetdatasets.DatasetDict
  • 数据集的格式必须和损失函数、评估器相匹配。如果损失函数需要标签字段,那么数据集必须有“label”或“score”字段;其他名称非“label”或“score”的字段将自动归属于Inputs字段。所以在进行后续步骤时,必须将数据集中的无法标签删除,同时要保证数据集中的字段顺序与对应损失函数中要求的顺序一致。
  • 需要根据具体的任务以及数据集的形式选择合适的损失函数,没有哪种损失函数可以解决所有的问题。SBERT提供的损失函数列表如下:
    https://www.sbert.net/docs/sentence_transformer/loss_overview.html
  • 微调后的模型可以和其他预训练的模型一样使用,比如计算文本相似度,这里不再赘述。

参考资料

  1. BERT基础教程: Transformer大模型实战
  2. https://baijiahao.baidu.com/s?id=1801193891938395467
  3. https://www.sbert.net

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2201001.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OmniH2O——通用灵巧且可全身远程操作并学习的人形机器人(其前身H2O是HumanPlus的重要参考)

前言 由于我司一直在针对各个工厂、公司、客户特定的业务场景,做解决方案或定制开发,所以针对每一个场景,我们都会反复考虑用什么样的机器人做定制开发 于此,便不可避免的追踪国内外最前沿的机器人技术进展,本来准备…

数据库管理-第249期 23ai:全球分布式数据库-请求路由与查询过程(20241008)

数据库管理249期 2024-10-08 数据库管理-第249期 23ai:全球分布式数据库-请求路由与查询过程(20241008)1 客户端应用请求路由1.1 分片键1.2 Oracle连接驱动 2 查询过程和查询协调器2.1 指定一致性级别2.2 高可用与性能 总结 数据库管理-第249…

拍立淘API接口以图搜商品列表功能实现技术分享item_search_img|返回商品列表商品id商品价格url

开发背景 在电商平台的快速发展中,用户对于商品搜索的效率和准确性提出了越来越高的要求。传统的基于关键词的搜索方式,虽然在一定程度上满足了用户的需求,但在面对复杂的商品信息和多样化的用户搜索意图时,仍存在诸多局限性。为…

PyTorch搭建GNN(GCN、GraphSAGE和GAT)实现多节点、单节点内多变量输入多变量输出时空预测

目录 I. 前言II. 数据集说明III. 模型3.1 GCN3.2 GraphSAGE3.3 GAT IV. 训练与测试V. 实验结果 I. 前言 前面已经写了很多关于时间序列预测的文章: 深入理解PyTorch中LSTM的输入和输出(从input输入到Linear输出)PyTorch搭建LSTM实现时间序列…

IO相关,标准输入输出及错误提示

一、IO简介 1.1 IO的过程 操作系统的概念:向下统筹控制硬件,向上为用户提供接口。 操作系统的组成 内核 外壳(shell) linux的五大功能:进程管理、内存管理、文件管理、设备管理、网络管理。 最早接触的IO&#xf…

01背包,CF 1974E - Money Buys Happiness

目录 一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 二、解题报告 1、思路分析 2、复杂度 3、代码详解 一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 1974E - Money Buys Happiness 二、解题报告 1、思路分析 问我们能够到达…

docker简述

1.安装dockers,配置docker软件仓库 安装,可能需要开代理,这里我提前使用了下好的包安装 启动docker systemctl enable --now docker查看是否安装成功 2.简单命令 拉取镜像,也可以提前下载使用以下命令上传 docker load -i imag…

深度学习笔记(持续更新)

注:本文所有深度学习内容都是基于PyTorch,PyTorch作为一个开源的深度学习框架,具有可以动态计算图、拥有简洁易用的API、支持GPU加速等特点,在计算机视觉、自然语言处理、强化学习等方面有广泛应用。 使用matplotlib绘图&#xff…

Linux 常用命令详解,线上问题排查必备

comm 比较文件行 comm 是 Linux 系统下的用于比较两个已排序文件的命令行工具。主要用于找出文件之间的差异或相同之处,例如两个文件中相同的行、仅在第一个文件中的行以及仅在第二个文件中的行。 基本语法 comm [OPTION] FILE1 FILE2可选参数OPTION如下&#xf…

图像分类-demo(Lenet),tensorflow和Alexnet

目录 demo(Lenet) 代码实现基本步骤: TensorFlow 一、核心概念 二、主要特点 三、简单实现 参数: 模型编译 模型训练 模型评估 Alexnet model.py train.py predict.py demo(Lenet) PyTorch提供了一个名为“torchvision”的附加库,其中包含…

芯课堂 | FatFs文件系统的移植及应用指南

1、FatFs文件系统简介 FatFs是用于小型嵌入式系统的通用FAT/exFAT文件系统模块。FatFs模块是按照ANSI C(C89)编写的,与磁盘控制层完全分离。因此,它独立于平台和存储设备,具有良好的硬件平台独立性。它可以集成到资源有…

这个问题做项目的时给某些客户普及过,这里再给你普及一下

有些因素不是地理概念,没错!但与地理有关!可以通过地理位置将他们链接起来,再结合其它业务数据,完成数据分析!例如百度地图会将:餐饮、文化、交通、住宿、甚至价格、天气与位置关联分析&#xf…

S7---基本介绍

目录 高通S7和S7 Pro Gen 1声音平台 音频性能的新层次 高通XPAN技术 卓越的听力增强 高通第四代ANC 特征 QualcommS7 Pro Gen 1附加功能 QualcommS7 Pro Gen 1框图 高通S7和S7 Pro Gen 1声音平台 声音被重新想象。QualcommS7声音平台旨在开启一个新的高级音频性能级别。…

Unity转Unreal5之从入门到精通 Spline(样条曲线)组件的使用

前言 Spline 组件 能编辑 样条曲线,定义一条路径,路径上的点可以通过距离起点的长度获取,因此可以实现 物体沿路径连续移动 的效果或者 物体沿路径分布 的效果。 今天我们就来实现一个简单的Spline样条曲线的Demo 实现一个沿路径运动的功能 1.新建一个基于 Actor 的蓝图…

JavaSE——集合1:Collection接口(Iterator和增强for遍历集合)

目录 一、集合框架体系(重要) 二、集合引入 (一)集合的理解与好处 三、Collection接口 (一)Collection接口实现类的特点 (二)Collection接口常用方法 (三)Collection接口遍历元素的方式(Iterator和增强for) 1.使用Iterator(迭代器) 1.1Iterator(迭代器)介绍 1.2Itera…

使用cv::FileStorage对yaml文件进行读写

问题描述:记录使用cv::FileStorage对yaml文件进行读写 参考官网:OpenCV: cv::FileStorage Class Reference WRITE:根据文件路径写文件,如果文件不存在会新建,文件存在则变空白 FileStorage fs(filepath, FileStorag…

新增数据集 SDK、“关系抽取”文本标注、优化模型监控和管理|ModelWhale 版本更新

ModelWhale 带来了新一轮的版本更新,期待为大家带来更优质的使用体验。 本次更新中,ModelWhale 主要进行了以下功能迭代: 数据管理:新增 mw_python_sdk 支持通过查看、下载、制作、更新数据集 文本标注:新增“关系抽取…

【DFDT】DFDT: An End-to-End DeepFake Detection Framework Using Vision Transformer

文章目录 DFDT: An End-to-End DeepFake Detection Framework Using Vision Transformerkey points贡献方法补丁提取和嵌入基于注意力的补丁选择多流transformer块多尺度分类器实验DFDT: An End-to-End DeepFake Detection Framework Using Vision Transformer 会议/期刊:App…

Apache Linkis + OceanBase:如何提升数据分析效率

计算中间件 Apache Linkis 构建了一个计算中间件层,以实现上层应用程序和底层数据引擎之间的连接、治理和编排。目前,已经支持通过数据源的功能,实现用户通过Linkis 对接并使用 OceanBase数据库。 本文详细阐述了在 Apache Linkis v1.3.2中&a…

【虚拟化】内核级虚拟化技术KVM介绍,全/半虚拟化的区别,使用libvirt搭建虚拟化平台(go/java/c++)

【虚拟化】内核级虚拟化技术KVM介绍,全/半虚拟化的区别,使用libvirt搭建虚拟化平台(go/java/c) 文章目录 1、虚拟化技术分类与架构(KVM,Xen),全/半虚拟化的区别2、libvirt介绍3、使用…