bert 适合 embedding 的模型

news2024/11/19 5:29:41

目录

背景

embedding

求最相似的 topk

结果查看


背景

想要求两个文本的相似度,就单纯相似度,不要语义相似度,直接使用 bert 先 embedding 然后找出相似的文本,效果都不太好,试过 bert-base-chinese,bert-wwm,robert-wwm 这些,都有一个问题,那就是明明不相似的文本却在结果中变成了相似,真正相似的有没有,

例如:手机壳迷你版,与这条数据相似的应该都是跟手机壳有关的才合理,但结果不太好,明明不相关的,余弦相似度都能有有 0.9 以上的,所以问题出在 embedding 上,找了适合做 embedding 的模型,再去计算相似效果好了很多,合理很多。

之前写了一篇 bert+np.memap+faiss文本相似度匹配 topN-CSDN博客 是把流程打通,现在是找适合文本相似的来操作。

模型:

bge-small-zh-v1.5

bge-large-zh-v1.5

embedding

数据弄的几条测试数据,方便看那些相似

我用 bge-large-zh-v1.5 来操作,embedding 代码,为了知道 embedding 进度,加了进度条功能,同时打印了当前使用 embedding 的 bert 模型输出为度,这很重要,会影响求相似的 topk

import numpy as np
import pandas as pd
import time
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel
import torch


class TextEmbedder():
    def __init__(self, model_name="./bge-large-zh-v1.5"):
        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 自己电脑跑不起来 gpu
        self.device = torch.device("cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

    # 没加进度条的
    # def embed_sentences(self, sentences):
    #     encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    #     with torch.no_grad():
    #         model_output = self.model(**encoded_input)
    #         sentence_embeddings = model_output[0][:, 0]
    #     sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    #
    #     return sentence_embeddings
    
    # 加进度条
    def embed_sentences(self, sentences):
        embedded_sentences = []

        for sentence in tqdm(sentences):
            encoded_input = self.tokenizer([sentence], padding=True, truncation=True, return_tensors='pt')
            with torch.no_grad():
                model_output = self.model(**encoded_input)
                sentence_embedding = model_output[0][:, 0]
            sentence_embedding = torch.nn.functional.normalize(sentence_embedding, p=2)

            embedded_sentences.append(sentence_embedding.cpu().numpy())

        print('当前 bert 模型输出维度为,', embedded_sentences[0].shape[1])
        return np.array(embedded_sentences)

    def save_embeddings_to_memmap(self, sentences, output_file, dtype=np.float32):
        embeddings = self.embed_sentences(sentences)
        shape = embeddings.shape
        embeddings_memmap = np.memmap(output_file, dtype=dtype, mode='w+', shape=shape)
        embeddings_memmap[:] = embeddings[:]
        del embeddings_memmap  # 关闭并确保数据已写入磁盘


def read_data():
    data = pd.read_excel('新建 XLSX 工作表.xlsx')
    return data['addr'].to_list()


def main():
    # text_data = ["这是第一个句子", "这是第二个句子", "这是第三个句子"]
    text_data = read_data()

    embedder = TextEmbedder()

    # 设置输出文件路径
    output_filepath = 'sentence_embeddings.npy'

    # 将文本数据向量化并保存到内存映射文件
    embedder.save_embeddings_to_memmap(text_data, output_filepath)


if __name__ == "__main__":
    start = time.time()
    main()
    end = time.time()
    print(end - start)

求最相似的 topk

使用 faiss 索引需要设置 bert 模型的维度,所以我们前面打印出来了,要不然会报错,像这样的:

ValueError: cannot reshape array of size 10240 into shape (768)

所以  print('当前 bert 模型输出维度为,', embedded_sentences[0].shape[1]) 的值换上去,我这里打印的 1024

index = faiss.IndexFlatL2(1024)  # 假设BERT输出维度是768

# 确保embeddings_memmap是二维数组,如有需要转换
if len(embeddings_memmap.shape) == 1:
    embeddings_memmap = embeddings_memmap.reshape(-1, 1024)

完整代码 

import pandas as pd
import numpy as np
import faiss
from tqdm import tqdm


def search_top4_similarities(index_path, data, topk=4):
    embeddings_memmap = np.memmap(index_path, dtype=np.float32, mode='r')

    index = faiss.IndexFlatL2(768)  # 假设BERT输出维度是768

    # 确保embeddings_memmap是二维数组,如有需要转换
    if len(embeddings_memmap.shape) == 1:
        embeddings_memmap = embeddings_memmap.reshape(-1, 768)

    index.add(embeddings_memmap)

    results = []
    for i, text_emb in enumerate(tqdm(embeddings_memmap)):
        D, I = index.search(np.expand_dims(text_emb, axis=0), topk)  # 查找前topk个最近邻

        # 获取对应的 nature_df_img_id 的索引
        top_k_indices = I[0][:topk]  #
        # 根据索引提取 nature_df_img_id
        top_k_ids = [data.iloc[index]['index'] for index in top_k_indices]

        # 计算余弦相似度并构建字典
        cosine_similarities = [cosine_similarity(text_emb, embeddings_memmap[index]) for index in top_k_indices]
        top_similarity = dict(zip(top_k_ids, cosine_similarities))

        results.append((data['index'].to_list()[i], top_similarity))

    return results


# 使用余弦相似度公式,这里假设 cosine_similarity 是一个计算两个向量之间余弦相似度的函数
def cosine_similarity(vec1, vec2):
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))


def main_search():
    data = pd.read_excel('新建 XLSX 工作表.xlsx')
    data['index'] = data.index
    similarities = search_top4_similarities('sentence_embeddings.npy', data)

    # 输出结果
    similar_df = pd.DataFrame(similarities, columns=['id', 'top'])
    similar_df.to_csv('similarities.csv', index=False)

# 执行搜索并保存结果
main_search()

结果查看

看一看到余弦数值还是比较合理的,没有那种明明不相关但余弦值是 0.9 的情况了,这两个模型还是可以的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1547879.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

图书馆培训英文

introduction move1 move2 move3 要么定性,要么定量 implications 指导意义 Corpus语料库 collocation 激烈的竞争 competiton ADJ 批改网www.pigai.org acdademic phrasebank 在这里插入图片描述 好文章不是写出来的,是改出来的。

Spring事务-两种开启事务管理的方式:基于注解的声明式事务管理、基于编程式的事务管理

Spring事务-两种开启事务管理的方式 1、前期准备2、基于注解的声明式事务管理3、基于编程式的事务管理4、声明式事务失效的情况 例子:假设有一个银行转账的业务,其中涉及到从一个账户转钱到另一个账户。在这个业务中,我们需要保证要么两个账户…

FastAPI+React全栈开发08 安装MongoDB

Chapter02 Setting Up the Document Store with MongoDB 08 Installing MongoDB and friends FastAPIReact全栈开发08 安装MongoDB The MongoDB ecosystem is composed of different pieces of software, and I remember that when I was starting to play with it, there w…

I.MX6ULL_Linux_驱动篇(57)linux Regmap API驱动

我们在前面学习 I2C 和 SPI 驱动的时候,针对 I2C 和 SPI 设备寄存器的操作都是通过相关的 API 函数进行操作的。这样 Linux 内核中就会充斥着大量的重复、冗余代码,但是这些本质上都是对寄存器的操作,所以为了方便内核开发人员统一访问 I2C/S…

解决错误LibreSSL SSL_connect: SSL_ERROR_SYSCALL in connection to

react native pod第三方包或者git clone的时候遇到 OpenSSL SSL_connect: SSL_ERROR_SYSCALL in connection to github.com:443两种解决方案 方法一 修改计算机网络配置 由于使用 IPv6 的原因,可能会导致这一问题的出现 系统在解析hostname时使用了ipv6 可以配…

Linux:Jenkins全自动持续集成持续部署(4)

在上一章部署好了之后,还需要点击一下才能进行部署,本章的效果是:当gitlab上的代码发生了变化后,我们不需要做任何事情不需要去点击构建按钮,Jenkins直接自动检测变化,然后自动去集成部署Linux:…

【echart】数据可视化

什么是数据可视化? 数据可视化主要目的:借助于图形化手段,清晰有效地传达与沟通信息。 数据可视化可以把数据从冰冷的数字转换成图形,揭示蕴含在数据中的规律和道理。 如何绘制? echarts 图表的绘制,大体分为三步:…

进入消息传递的魔法之门:ActiveMQ原理与使用详解

嗨,亲爱的童鞋们!欢迎来到这个充满魔法的世界,今天我们将一同揭开消息中间件ActiveMQ的神秘面纱。如果你是一个对编程稍有兴趣,但又对消息中间件一知半解的小白,不要害怕,我将用最简单、最友好的语言为你呈…

mysql安装及操作

一、Mysql 1.1 MySQL数据库介绍 1.1.1 什么是数据库DB? DB的全称是database,即数据库的意思。数据库实际上就是一个文件集合,是一个存储数据的仓库,数据库是按照特定的格式把数据存储起来,用户可以对存储的数据进行…

秋招打卡算法题第一天

一年多没有刷过算法题了,已经不打算找计算机类工作了,但是思来想去,还是继续找吧。 1. 字符串最后一个单词的长度 public static void main(String[] args) {Scanner in new Scanner(System.in);while(in.hasNextInt()){String itemin.nextL…

2024/3/26 C++作业

定义一个矩形类(Rectangle),包含私有成员:长(length)、宽(width), 定义成员函数: 设置长度:void set_l(int l) 设置宽度:void set_w(int w) 获取长度:int…

Wireshark 抓包

启动时选择一个有信号的网卡双击打开,或者在 捕获选择里打开选择网卡。 然后输出下面的规则就可以抓到报文了。 最上面的三条是建立连接时的三次握手, 下面是发送数据hello 对应两条数据 最下面的4条是断时的4次挥手

【真题解析】题目 3151: 蓝桥杯2023年第十四届省赛真题-飞机降落【C++ DFS 超详解注释版本】

爆搜冥想 暴力枚举每一辆飞机对于每一个飞机都只存在两种情况,可以降落和不可以降落如果可以降落,计算降落后最早可以降落的时间pre,作为下一次递归的传参如果不可以降落,枚举下一辆飞机 注意这辆的降落有盘旋这种量子叠加态&…

mysql刨根问底

索引:排好序的数据结构 二叉树: 红黑树 hash表: b-tree: 叶子相同深度,叶节点指针空,索引元素不重复,从左到右递增排序 节点带data btree: 非叶子节点只存储索引,可…

HTTPS 从懵懵懂懂到认知清晰、从深度理解到落地实操

Https 在现代互联网应用中,网上诈骗、垃圾邮件、数据泄露的现象时有发生。为了数据安全,我们都会选择采用https技术。甚至iOS开发调用接口的时候,必须是https接口,才能调用。现在有部分浏览器也开始强制要求网站必须使用https&am…

Python(Socket) +Unreal(HTTP)

Python(Socket) Unreal(HTTP) python(Socket):UE:Post请求并发送本机IP 上班咯,好久没记笔记了。。。 局域网 UE的apk,请求Python的Socket 跑起Socket ,UE发 …

长安链共识算法切换:动态调整,灵活可变

#功能发布 长安链3.0正式版发布了多个重点功能,包括共识算法切换、支持java智能合约引擎、支持后量子密码、web3生态兼容等。我们接下来为大家详细介绍新功能的设计、应用与规划。 随着长安链应用愈加成熟与广泛,一些在生产中很实用的需求浮出水面。长安…

RabbitMQ 《简单消息》

package com.xzp.rabbitmq.simple; import com.rabbitmq.client.Channel; import com.rabbitmq.client.Connection; import com.xzp.rabbitmq.util.ConnectionUtil; /** * "Hello World!" * 简单消息 * 消息发送者 - R - 发送消息(生产者) …

nandgame中的汇编语言(Assembler Language)

配置一个汇编器,将符号指令转换为二进制机器码。汇编器指令有三个部分:目标、计算和(可选的)跳转条件。目标是操作的输出写入的寄存器。计算是ALU操作。请参阅ALU级别的位模式。跳转条件是将触发跳转的条件。请参阅条件级别以获取…

初识云原生、虚拟化、DevOps

文章目录 K8S虚拟化DevOpsdevops平台搭建工具大数据架构 K8S master 主节点,控制平台,Master节点负责核心的调度、管理和运维,不需要很高性能,不跑任务,通常一个就行了,也可以开多个主节点来提高集群可用度…