bert 相似度任务训练简单版本,faiss 寻找相似 topk

news2024/12/22 22:09:23

目录

任务

代码

train.py

predit.py

faiss 最相似的 topk


任务

使用 bert-base-chinese 训练相似度任务,参考:微调BERT模型实现相似性判断 - 知乎

参考他上面代码,他使用的是 BertForNextSentencePrediction 模型,BertForNextSentencePrediction 原本是设计用于下一个句子预测任务的。在BERT的原始训练中,模型会接收到一对句子,并试图预测第二个句子是否紧跟在第一个句子之后;所以使用这个模型标签(label)只能是 0,1,相当于二分类任务了

但其实在相似度任务中,我们每一条数据都是【text1\ttext2\tlabel】的形式,其中 label 代表相似度,可以给两个文本打分表示相似度,也可以映射为分类任务,0 代表不相似,1 代表相似,他这篇文章利用了这种思想,对新手还挺有用的。

现在我搞了一个招聘数据,里面有办公区域列,处理过了,每一行代表【地址1\t地址2\t相似度】

只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

利用这数据微调,没有使用验证数据集,就最后使用测试集来看看效果。

代码

train.py

import json
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction
from torch.utils.data import DataLoader, Dataset


# 能用gpu就用gpu
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

bacth_size = 32
epoch = 3
auto_save_batch = 5000
learning_rate = 2e-5


# 准备数据集
class MyDataset(Dataset):
    def __init__(self, data_file_paths):
        self.texts = []
        self.labels = []
        # 分词器用默认的
        self.tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
        # 自己实现对数据集的解析
        with open(data_file_paths, 'r', encoding='utf-8') as f:
            for line in f:
                text1, text2, label = line.split('\t')
                self.texts.append((text1, text2))
                self.labels.append(int(label))

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text1, text2 = self.texts[idx]
        label = self.labels[idx]
        encoded_text = self.tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        return encoded_text, label


# 训练数据文件路径
train_dataset = MyDataset('../data/train.txt')

# 定义模型
# num_labels=5 定义相似度评分有几个
model = BertForNextSentencePrediction.from_pretrained('../bert-base-chinese', num_labels=6)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
train_loader = DataLoader(train_dataset, batch_size=bacth_size, shuffle=True)
trained_data = 0
batch_after_last_save = 0
total_batch = 0
total_epoch = 0

for epoch in range(epoch):
    trained_data = 0
    for batch in train_loader:
        inputs, labels = batch
        # 不知道为啥,出来的数据维度是 (batch_size, 1, 128),需要把第二维去掉
        inputs['input_ids'] = inputs['input_ids'].squeeze(1)
        inputs['token_type_ids'] = inputs['token_type_ids'].squeeze(1)
        inputs['attention_mask'] = inputs['attention_mask'].squeeze(1)
        # 因为要用GPU,将数据传输到gpu上
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss, logits = outputs[:2]
        loss.backward()
        optimizer.step()
        trained_data += len(labels)
        trained_process = float(trained_data) / len(train_dataset)
        batch_after_last_save += 1
        total_batch += 1
        # 每训练 auto_save_batch 个 batch,保存一次模型
        if batch_after_last_save >= auto_save_batch:
            batch_after_last_save = 0
            model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')
            print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))
        print("训练进度:{:.2f}%, loss={:.4f}".format(trained_process * 100, loss.item()))
    total_epoch += 1
    model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')
    print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))

训练好后的文件,输出的最后一个文件夹才是效果最好的模型:

predit.py

import torch
from transformers import BertTokenizer, BertForNextSentencePrediction


tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertForNextSentencePrediction.from_pretrained('../output/cn_equal_model_3_171.pth')

with torch.no_grad():
    with open('../data/test.txt', 'r', encoding='utf8') as f:
        lines = f.readlines()
        correct = 0
        for i, line in enumerate(lines):
            text1, text2, label = line.split('\t')
            encoded_text = tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
            outputs = model(**encoded_text)
            res = torch.argmax(outputs.logits, dim=1).item()
            print(text1, text2, label, res)
            if str(res) == label.strip('\n'):
                correct += 1
            print(f'{i + 1}/{len(lines)}')
        print(f'acc:{correct / len(lines)}')

可以看到还是较好的学习了我数据特征:只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

faiss 最相似的 topk

使用 faiss 寻找 topk 相似的,从结果上看最相似的基本都还是找到排到较为靠前的位置

import torch
import faiss
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel


# 假设有一个数据集df,其中包含'index'列和'text'列
df = pd.read_csv('../data/DataAnalyst.csv', encoding='gbk')  # 根据实际情况加载数据集
df = df.dropna().drop_duplicates().reset_index()
df['index'] = df.index
df = df[['index', '公司所在商区']]  # 保留所需列
df['公司所在商区'] = df['公司所在商区'].map(lambda row: ','.join(eval(row)))

# device = torch.device('gpu' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

# 加载微调好的模型和tokenizer
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertModel.from_pretrained('../output/cn_equal_model_3_171.pth')
model.eval()


# 将数据集转化为模型所需的格式并计算所有样本的向量表示
def encode_texts(df):
    text_vectors = []
    for index, row in df.iterrows():
        text = row['公司所在商区']
        inputs = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        with torch.no_grad():
            embeddings = model(**inputs.to(device))['last_hidden_state'][:, 0]
            text_vectors.append(embeddings.cpu().numpy())
        print(f'{index + 1}/{len(df)}')
    return np.vstack(text_vectors)

# 加载数据集并计算所有样本的向量
print('enbedding all data...')
all_embeddings = encode_texts(df)

# 初始化Faiss索引
print('init faiss all embedding...')
index = faiss.IndexFlatIP(all_embeddings.shape[1])  # 使用内积空间,适用于余弦相似度
index.add(all_embeddings)
print('init faiss all embedding finish~~~')


# 定义查找最相似样本的函数
def find_top_k_similar(query_text, k=100):
    print('当前 query_text embedding.')
    query_embedding = encode_single_text(query_text)
    print('begin to search topk....')
    D, I = index.search(query_embedding, k)  # 返回距离和索引
    top_k_indices = df.iloc[I[0]].index.tolist()  # 将索引转换为原始数据集的索引
    return top_k_indices


# 编码单个文本的函数
def encode_single_text(text):
    inputs = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    with torch.no_grad():
        embedding = model(**inputs.to(device))['last_hidden_state'][:, 0].cpu().numpy()
    print('当前 query_text embedding finish!')
    return embedding


# 示例:找一个query_text的top10相似样本
query_text = "左家庄,国展,西坝河"
top10_indices = find_top_k_similar(query_text)
# 获取与查询文本最相似的前10条原始文本
top10_texts = [df.loc[index, '公司所在商区'] for index in top10_indices]

print(f"与'{query_text}'最相似的前100条样本及其文本:")
for i, (idx, text) in enumerate(zip(top10_indices, top10_texts)):
    print(f"{i+1}. 索引:{idx},文本:{text}")

数据

链接:https://pan.baidu.com/s/1Cpr-ZD9Neakt73naGdsVTw 
提取码:eryw 
链接:https://pan.baidu.com/s/1qHYjXC7UCeUsXVnYTQIPCg 
提取码:o8py 
链接:https://pan.baidu.com/s/1CTntG1Z6AIhiPt6i8Ad97Q 
提取码:x6sz 
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1490076.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在idea中用模板骨架初始创建maven管理的web项目时没有src有关的目录的解决方案

一.问题如下 二.解决方法 首先关闭当前项目,接着修改全局设置,重新创建项目 在VM Options中添加"-DarchetypeCataloginternal",点击ok保存 点击创建,如果创建成功没报错且有src,就ok了。 当然如果出现以下…

【C++】十大排序算法之 插入排序 希尔排序

本次介绍内容参考自:十大经典排序算法(C实现) - fengMisaka - 博客园 (cnblogs.com) 排序算法是《数据结构与算法》中最基本的算法之一。 十种常见排序算法可以分为两大类: 比较类排序:通过比较来决定元素间的相对次序…

大厂报价查询系统性能优化之道!

0 前言 机票查询系统,日均亿级流量,要求高吞吐,低延迟架构设计。提升缓存的效率以及实时计算模块长尾延迟,成为制约机票查询系统性能关键。本文介绍机票查询系统在缓存和实时计算两个领域的架构提升。 1 机票搜索服务概述 1.1 …

C++的类与对象(二)

目录 结构体内存对其规则 相关面试题 this指针 相关面试题 结构体内存对其规则 1、第一个成员在与结构体偏移量为0的地址处 2、其它成员变量要对齐到某个数字(对齐数)的整数倍的地址处 对齐数 编译器默认对齐数与该成员大小的较小值(v…

学习记录12-单片机代码几种常见命名规则

良好的编程习惯,决定了今后代码的质量。 有很多人平时不注意自己的代码规范,函数和变量命命随心所欲,造成一个星期就不认识自己的代码,于是今天就来分享一点关于软件代码常见的几种命名规则。 匈牙利命名法 匈牙利命名法广泛应用…

RBAC实战

一、权限控制概述 1.1、访问控制目的 在实际的组织中,为了完成组织的业务工作,需要在组织内部设置不同的职位,职位既表示一种业务分工,又表示一种责任与权利。根据业务分工的需要,职位被划分给不同群体,各…

C++:Vector的模拟实现

创作不易,感谢三连 !! 一,前言 在学习string类的时候,我们可能会发现遍历的话下标访问特别香,比迭代器用的舒服,但是下标其实只能是支持连续的空间,他的使用是非常具有局限性的&am…

迷不迷糊?前后端、三层架构和MVC傻傻分不清

现在的项目都讲究前后端分离,那到底什么是前后端,前后端和以前的MVC以及三层架构啥关系呢?今天就这个问题展开一下,方面后面的学习,因为前面讲的jsp、servlet和javabean根据实例,基本上有一个框架的理解了&…

基于STC12C5A60S2系列1T 8051单片机的TM1638键盘数码管模块的按键扫描、数码管显示按键值、显示按键LED应用

基于STC12C5A60S2系列1T 8051单片机的TM1638键盘数码管模块的按键扫描、数码管显示按键值、显示按键LED应用 STC12C5A60S2系列1T 8051单片机管脚图STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式及配置STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式介绍TM1638键盘…

【如何在Docker中,修改已经挂载的卷(Volume)】

曾梦想执剑走天涯,我是程序猿【AK】 提示:添加投票!!! 目录 简述概要知识图谱 简述概要 如何在Docker中,修改已经挂载的卷(Volume) 知识图谱 在Docker中,修改已经挂载…

消息队列-kafka-消息发送流程(源码跟踪)

官方网址 源码:https://kafka.apache.org/downloads 快速开始:https://kafka.apache.org/documentation/#gettingStarted springcloud整合 发送消息流程 主线程:主线程只负责组织消息,如果是同步发送会阻塞,如果是异…

安装Proxmox VE虚拟机平台

PVE是专业的虚拟机平台,可以利用它安装操作系统,如:Win、Linux、Mac、群晖等。 1. 下载镜像 访问PVE官网,下载最新的PVE镜像。 https://www.proxmox.com/en/downloads 2. 下载balenaEtcher balenaEtcher用于将镜像文件&#…

【Vue3】3-6 : 仿ElementPlus框架的el-button按钮组件实

文章目录 前言 本节内容实现需求完整代码如下: 前言 上节,我们学习了 slot插槽,组件内容的分发处理 本节内容 本小节利用前面学习的组件通信知识,来完成一个仿Element Plus框架的el-button按钮组件实现。 仿造的地址:uhttps://…

docker pull 拉取失败,设置docker国内镜像

遇到的问题 最近在拉取nginx时,显示如下错误:Error response from daemon: Get “https://registry-1.docker.io/v2/”: net/http: request canceled (Client.Timeout exceeded while awaiting headers)。 这个的问题是拉取镜像超时,通过检索…

基于Golang客户端实现Nacos服务注册发现和配置管理

基于Golang客户端实现Nacos服务注册发现和配置管理 背景 最近需要把Golang实现的一个web项目集成到基于Spring Cloud Alibaba的微服务体系中,走Spring Cloud Gateway网关路由实现统一的鉴权入口。 软件版本 组件名称组件版本Nacos2.2.0Go1.21.0Ginv1.9.1Nacos-s…

项目部署发布

目录 上传数据库 修改代码中的数据源配置 修改配置文件中的日志级别和日志目录 打包程序 ​编辑​编辑 上传程序 查看进程是否在运行 以及端口 云服务器开放端口(项目所需要的端口) 上传数据库 通过xshell控制服务器 创建目录 mkdir bit_forum 然后进入该目录 查看路…

【AI+CAD】(一)ezdxf 解析DXF文件

DXF文件格式理解 DXF文件格式是矢量图形文件格式,其详细说明了如何表示不同的图形元素。 DXF是一个矢量图形文件,它捕获CAD图形的所有元素,例如文本,线条和形状。更重要的是,DXF是用于在CAD应用程序之间传输数据的图形…

Java日志框架的纷争演进与传奇故事

在Java的世界里,日志记录是每一个应用不可或缺的部分。它帮助开发者了解应用的运行状态、调试问题、监控性能等。而在这背后,是一系列日志框架的发展与演进。今天,就让我们一起回顾这些日志框架的历史,探寻它们背后的故事。 1. Lo…

分布式数据库中全局自增序列的实现

自增序列广泛使用于数据库的开发和设计中,用于生产唯一主键、日志流水号等唯一ID的场景。传统数据库中使用Sequence和自增列的方式实现自增序列的功能,在分布式数据库中兼容Oracle和MySQL等传统数据库语法,也是基于Sequence和自增列的方式实现…

使用Visual Studio 2022 创建lib和dll并使用

概述:对于一个经常写javaWeb的人来说,使用Visual Studio似乎没什么必要,但是对于使用ffi的人来说,使用c或c编译器,似乎是必不可少的,下面我将讲述如何用Visual Studio 2022 来创建lib和dll,并使用。 静态库…