【SentenceTransformer系列】计算句子嵌入的概念(01/10)

news2024/11/24 19:00:32

一、说明

        要分清词嵌入和句子嵌入的区别。

        句子嵌入是指将句子或文档表示为固定长度的向量的过程,使得向量能够捕获句子的语义和上下文信息。它是自然语言处理 (NLP) 和机器学习中的常见任务,因为它可以帮助对句子之间的关系和相似性进行建模,并执行各种下游任务,例如情感分析、文本分类和机器翻译。句子嵌入有多种技术,例如 Word2Vec、GloVe 和 BERT。

二、计算句子嵌入

计算句子嵌入的基本函数如下所示:

from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')

#Our sentences we like to encode
sentences = ['This framework generates embeddings for each input sentence',
    'Sentences are passed as a list of string.', 
    'The quick brown fox jumps over the lazy dog.']

#Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)

#Print the embeddings
for sentence, embedding in zip(sentences, embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding)
    print("")

注意:即使我们谈论句子嵌入,您也可以将其用于较短的短语以及具有多个句子的较长文本。有关段落嵌入的更多说明,请参阅输入序列长度部分。

首先,我们加载一个句子转换器模型:

from sentence_transformers import SentenceTransformer
model = SentenceTransformer('model_name_or_path')

您可以指定预先训练的模型,也可以在光盘上传递路径以从该文件夹加载句子转换器模型。

如果可用,模型将在 GPU 上自动执行。您可以像这样为模型指定设备:

model = SentenceTransformer('model_name_or_path', device='cuda')

使用设备任何 pytorch 设备(如 CPU、cuda、cuda:0等)

对一组句子/文本进行编码的相关方法是。在下文中,您可以找到此方法接受的参数。一些相关参数是batch_size(取决于您的 GPU,不同的批大小是最佳的)以及convert_to_numpy(返回 numpy 矩阵)和convert_to_tensor(返回 pytorch 张量)。model.encode()

.classsentence_transformers。句子转换器(model_name_or_path: 可选[str] = 无, 模块: 可选[可迭代[torch.nn.modules.module.Module]] = 无,设备:可选[str] = 无,cache_folder : 可选[str] = 无, use_auth_token: 可选[联合[布尔值, str] ] = 无)

加载或创建一个句子转换器模型,该模型可用于将句子/文本映射到嵌入。

参数

  • model_name_or_path – 如果它是光盘上的文件路径,则从该路径加载模型。如果它不是路径,它首先尝试下载预先训练的 SentenceTransformer 模型。如果失败,则尝试使用该名称从Huggingface模型存储库构建模型。

  • 模块 – 此参数可用于从头开始创建自定义句子转换器模型。

  • 设备 – 应该用于计算的设备(如“cuda”/“CPU”)。如果为“无”,则检查是否可以使用 GPU。

  • cache_folder – 存储模型的路径

  • use_auth_token – HuggingFace 身份验证令牌以下载私有模型。

初始化由两个 nn 共享的内部模块状态。模块和脚本模块。

编码(句子: Union[str, List[str]], batch_size: int = 32, show_progress_bar: 可选[bool] = 无, output_value: str = 'sentence_ 嵌入'convert_to_numpy:布尔 = 真,convert_to_tensor:布尔 = 假,设备: 可选[str] = 无,normalize_embeddings:bool = 假) → 联合[列表 [火炬。Tensor], numpy.ndarray, torch.张量]

计算句子嵌入

参数

  • 句子 – 要嵌入的句子

  • batch_size – 用于计算的批大小

  • show_progress_bar – 编码句子时输出进度条

  • output_value – 默认sentence_embedding,用于获取句子嵌入。可以设置为 token_embeddings 以获取词片标记嵌入。设置为“无”,以获取所有输出值

  • convert_to_numpy – 如果为 true,则输出是 numpy 向量的列表。否则,它是一个 pytorch 张量列表。

  • convert_to_tensor – 如果为 true,您将获得一个大张量作为返回。覆盖convert_to_numpy中的任何设置

  • device – 用于计算的火炬设备

  • normalize_embeddings – 如果设置为 true,则返回的向量的长度为 1。在这种情况下,可以使用更快的点积 (util.dot_score) 而不是余弦相似性。

返回

默认情况下,返回张量列表。如果convert_to_tensor,则返回堆叠张量。如果convert_to_numpy,则返回 numpy 矩阵。

三、输入序列长度

        变压器模型如BERT / RoBERTa / DistilBERT等,运行时和内存要求随输入长度呈二次增长。这将变压器限制为一定长度的输入。BERT&Co.的一个共同值是512个单词,相当于大约300-400个单词(英语)。比这更长的文本被截断为前 x 个单词片段。

        默认情况下,提供的方法使用限制为 128 个字段,较长的输入将被截断。您可以像这样获取和设置最大序列长度:

from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')

print("Max Sequence Length:", model.max_seq_length)

#Change the length to 200
model.max_seq_length = 200

print("Max Sequence Length:", model.max_seq_length)

注意:您不能增加的长度高于相应变压器型号支持的最大长度。另请注意,如果模型是在短文本上训练的,则长文本的表示可能不是那么好。

四、存储和加载嵌入

        最简单的方法是使用 pickle 将预先计算的嵌入存储在光盘上并从光盘加载。如果您需要对大量句子进行编码,这尤其有用。

from sentence_transformers import SentenceTransformer
import pickle
model = SentenceTransformer('all-MiniLM-L6-v2')
sentences = ['This framework generates embeddings for each input sentence',
    'Sentences are passed as a list of string.',
    'The quick brown fox jumps over the lazy dog.']
embeddings = model.encode(sentences)
#Store sentences & embeddings on disc
with open('embeddings.pkl', "wb") as fOut:
    pickle.dump({'sentences': sentences, 'embeddings': embeddings}, fOut, protocol=pickle.HIGHEST_PROTOCOL)
#Load sentences & embeddings from disc
with open('embeddings.pkl', "rb") as fIn:
    stored_data = pickle.load(fIn)
    stored_sentences = stored_data['sentences']
    stored_embeddings = stored_data['embeddings']



五、进程/多GPU编码

        您可以使用多个 GPU(或 CPU 计算机上的多个进程)对输入文本进行编码。有关示例,请参阅:computing_embeddings_mutli_gpu.py。

        相关方法是 ,它启动多个用于编码的进程。start_multi_process_pool()

句子变形金刚。start_multi_process_pool(target_devices:可选[列表[str]] = 无)

        启动多进程以使用多个独立进程处理编码。 如果要在多个 GPU 上进行编码,建议使用此方法。建议 为每个 GPU 仅启动一个进程。此方法与encode_multi_process一起使用

        参数

target_devices – PyTorch 目标设备,例如 cuda:0、cuda:1...如果为“无”,则将使用所有可用的 CUDA 设备

返回

返回包含目标进程、输入队列和输出队列的字典。

六、使用转换器嵌入句子

        我们的大多数预训练模型都基于 Huggingface.co/Transformers,并且也托管在Huggingface的模型存储库中。可以在不安装句子转换器的情况下使用我们的句子嵌入模型:

from transformers import AutoTokenizer, AutoModel
import torch


#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask



#Sentences we want sentence embeddings for
sentences = ['This framework generates embeddings for each input sentence',
             'Sentences are passed as a list of string.',
             'The quick brown fox jumps over the lazy dog.']

#Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

#Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')

#Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)

#Perform pooling. In this case, mean pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

七、后记

        您可以在此处找到可用的型号:sentence-transformers (Sentence Transformers)

        在上面的示例中,我们在AutoModel(将加载BERT模型)之上添加了平均池化。我们也有具有最大池化的模型以及我们使用 CLS 令牌的模型。如何正确应用此池化,

请查看句子转换器/bert-base-nli-max-tokens 和 /sentence-transformers/bert-base-nli-cls-token。

参考资料地址:利用 Embedding 层学习词嵌入 · python深度学习 · 看云 (kancloud.cn)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/891122.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

接口自动化测试(添加课程接口调试,调试合同上传接口,合同列表查询接口,批量执行)

1、我们把信息截取一下 1.1 添加一个新的请求 1.2 对整个请求进行保存,Ctrl S 2、这一次我们添加的是课程添加接口,以后一个接口完成,之后Ctrl S 就能够保存 2.1 选择方法 2.2 设置请求头,参数数据后期我们通过配置设置就行 3、…

Lua 位和字节

一、位运算 从 Lua 5.3 版本开始&#xff0c;提供了针对数值类型的一组标准位运算符&#xff0c;与算数运算符不同的是&#xff0c;运算符只能用于整型数。 运算符描述&按位与|按位或&#xff5e;按位异或>>逻辑右移<<逻辑左移&#xff5e;&#xff08;一元运…

安全学习DAY17_信息打点-语言框架组件识别

信息打点-WEB打点-语言框架&开发组件 文章目录 信息打点-WEB打点-语言框架&开发组件本节涉及链接&工具本节知识&思维导图基础概念介绍框架&#xff1a;组件&#xff1a;Web架构 对应Web测试手法后端&#xff1a;前端组件&#xff1a;java居多&#xff0c;框架&…

RP2040开发板自制树莓派逻辑分析仪

目录 前言 1 准备工作和前提条件 1.1 Raspberry Pi Pico RP2040板子一个 1.2 Firmware-LogicAnalyzer-5.0.0.0-PICO.uf2固件 1.3 LogicAnalyzer-5.0.0.0-win-x64软件 2 操作指南 2.1 按住Raspberry Pi Pico开发板的BOOTSEL按键&#xff0c;再接上USB接口到电脑 2.2 刷入…

产品帮助中心怎么做?这两点不能忽略,让用户自助解决问题!

对于大部分线上产品&#xff0c;因为其功能和系统的复杂性&#xff0c;使得新手客户入门学习非常复杂&#xff0c;为了快速响应并且解决问题&#xff0c;一套系统完整的产品帮助中心必不可少&#xff01; 产品帮助中心 因此&#xff0c;对于很多产品开发者来说&#xff0c;借助…

pg简单使用

1.创建服务器 2.创建数据库 3.修改默认连接数据库 工具都是链接到这里 4.数据库代码工具

ByteBuffer 使用

ByteBuffer 使用 1 java.nio包中的类定义的缓冲区类型2 缓冲区常用属性2.1缓冲区的容量(capacity)2.2 缓冲区的位置(position)2.3 缓冲区的限制(limit)2.4 缓冲区的标记(mark)2.5 剩余容量 remaining/hasRemaining 3 缓冲区常用方法3.1 创建缓冲区3.1.1 allocate方法3.1.2 wrap…

交叉编译之wiringPi库,【全志H616,orangepi-zero2】

文章目录 书接上回wiringPi全志库下载建立软链接软连接软连接创建 硬链接硬链接创建 测试树莓派运行servo文件 结束 书接上回 上回已经完整的安装了全志的gcc交叉编译工具 https://blog.csdn.net/qq_52749711/article/details/132306764 wiringPi全志库下载 下载链接 先搞到…

Jmeter+ant+jenkins实现持续集成

jmeterantjenkins持续集成 一、下载并配置jmeter 首先下载jmeter工具&#xff0c;并配置好环境变量&#xff1b;参考&#xff1a;https://www.cnblogs.com/YouJeffrey/p/16029894.html jmeter默认保存的是.jtl格式的文件&#xff0c;要设置一下bin/jmeter.properties,文件内容…

中国电信物联网收入33亿元,用户达到4.73亿户!

近日&#xff0c;中国电信发布2023中期业绩&#xff0c;物联网迎来强劲增长&#xff0c;物联网收入33亿元&#xff0c;同比增长75.7%&#xff0c;物联网用户4.73亿户&#xff0c;同比增长31.5%。天翼物联自主研发的AIoT物联网平台&#xff0c;升级为云原生3AZ架构&#xff0c;提…

在线课堂录播直播管理系统SpringBoot+Vue

在线课堂录播直播管理系统SpringBootVue 文章目录 在线课堂录播直播管理系统SpringBootVue共三个端&#xff1a;后端、后台管理系统、前端&#xff0c;如要学习看评论区&#xff08;全部源码、文档、数据库&#xff09;。内置功能一、前端二、后台管理三、后端--代码全有。四、…

k8s 认证和权限控制

k8s 的认证机制是啥&#xff1f; 说到 k8s 的认证机制&#xff0c;其实之前咋那么也有提到过 ServiceAccouont &#xff0c;以及相应的 token &#xff0c;证书 crt&#xff0c;和基于 HTTP 的认证等等 k8s 会使用如上几种方式来获取客户端身份信息&#xff0c;不限于上面几种…

【数据结构OJ题】链表分割

原题链接&#xff1a;https://www.nowcoder.com/practice/0e27e0b064de4eacac178676ef9c9d70?tpId8&&tqId11004&rp2&ru/activity/oj&qru/ta/cracking-the-coding-interview/question-ranking 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2…

Java面向对象——封装以及this关键字

封 装 封装是面向对象编程&#xff08;OOP&#xff09;的三大特性之一&#xff0c;它将数据和操作数据的方法组合在一个单元内部&#xff0c;并对外部隐藏其具体实现细节。在Java中&#xff0c;封装是通过类的访问控制修饰符&#xff08;如 private、protected、public&#x…

Android Drawable转BitmapDrawable再提取Bitmap,Kotlin

Android Drawable转BitmapDrawable再提取Bitmap&#xff0c;Kotlin <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"…

C++ 结构体的对齐

C 结构体的对齐 flyfish 文章目录 C 结构体的对齐一 非对齐方式二 对齐方式示例1示例2 三 对齐到指定字节数 boundary 一 非对齐方式 也就是按照1字节对齐 #pragma pack(1) typedef unsigned char BYTE; typedef struct message {BYTE a[4];BYTE b[2];BYTE *c;BYTE d[4];} M…

阿里云ECS服务器企业级和共享型介绍_企业级常见问题解答FAQ

阿里云企业级服务器是什么&#xff1f;企业级和共享型有什么区别&#xff1f;企业级服务器具有独享且稳定的计算、存储、网络资源&#xff0c;如ECS计算型c6、通用型g8等都是企业级实例&#xff0c;阿里云百科分享什么是企业级云服务器、企业级实例的优势、企业级和共享型云服务…

如何收缩wsl2虚拟磁盘

简介 WSL2使用虚拟化层为它带来更高的性能和兼容性。但是&#xff0c;WSL2 的少数缺点之一是它使用虚拟磁盘 &#xff08;VHDX&#xff09; 来存储文件系统。这意味着您的虚拟磁盘占用了 100GB&#xff0c;但 WSL2 只需要 15GB... 所以要寻找一种缩小 WSL2 虚拟磁盘的方法&…

​Redis概述

目录 Redis - 概述 使用场景 如何安装 Window 下安装 Linux 下安装 docker直接进行安装 下载Redis镜像 Redis启动检查常用命令 Redis - 概述 redis是一款高性能的开源NOSQL系列的非关系型数据库,Redis是用C语言开发的一个开源的高键值对(key value)数据库,官方提供测试…

Leetcode每日一题:1444. 切披萨的方案数(2023.8.17 C++)

目录 1444. 切披萨的方案数 题目描述&#xff1a; 实现代码与解析&#xff1a; 二维后缀和 动态规划 原理思路&#xff1a; 1444. 切披萨的方案数 题目描述&#xff1a; 给你一个 rows x cols 大小的矩形披萨和一个整数 k &#xff0c;矩形包含两种字符&#xff1a; A …