bert 相似度任务训练,简单版本

news2024/9/22 5:42:02

目录

任务

代码

train.py

predit.py

数据


任务

使用 bert-base-chinese 训练相似度任务,参考:微调BERT模型实现相似性判断 - 知乎

参考他上面代码,他使用的是 BertForNextSentencePrediction 模型,BertForNextSentencePrediction 原本是设计用于下一个句子预测任务的。在BERT的原始训练中,模型会接收到一对句子,并试图预测第二个句子是否紧跟在第一个句子之后;所以使用这个模型标签(label)只能是 0,1,相当于二分类任务了

但其实在相似度任务中,我们每一条数据都是【text1\ttext2\tlabel】的形式,其中 label 代表相似度,可以给两个文本打分表示相似度,也可以映射为分类任务,0 代表不相似,1 代表相似,他这篇文章利用了这种思想,对新手还挺有用的。

现在我搞了一个招聘数据,里面有办公区域列,处理过了,每一行代表【地址1\t地址2\t相似度】

只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

利用这数据微调,没有使用验证数据集,就最后使用测试集来看看效果。

代码

train.py

import json
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction
from torch.utils.data import DataLoader, Dataset


# 能用gpu就用gpu
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

bacth_size = 32
epoch = 3
auto_save_batch = 5000
learning_rate = 2e-5


# 准备数据集
class MyDataset(Dataset):
    def __init__(self, data_file_paths):
        self.texts = []
        self.labels = []
        # 分词器用默认的
        self.tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
        # 自己实现对数据集的解析
        with open(data_file_paths, 'r', encoding='utf-8') as f:
            for line in f:
                text1, text2, label = line.split('\t')
                self.texts.append((text1, text2))
                self.labels.append(int(label))

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text1, text2 = self.texts[idx]
        label = self.labels[idx]
        encoded_text = self.tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        return encoded_text, label


# 训练数据文件路径
train_dataset = MyDataset('../data/train.txt')

# 定义模型
# num_labels=5 定义相似度评分有几个
model = BertForNextSentencePrediction.from_pretrained('../bert-base-chinese', num_labels=6)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
train_loader = DataLoader(train_dataset, batch_size=bacth_size, shuffle=True)
trained_data = 0
batch_after_last_save = 0
total_batch = 0
total_epoch = 0

for epoch in range(epoch):
    trained_data = 0
    for batch in train_loader:
        inputs, labels = batch
        # 不知道为啥,出来的数据维度是 (batch_size, 1, 128),需要把第二维去掉
        inputs['input_ids'] = inputs['input_ids'].squeeze(1)
        inputs['token_type_ids'] = inputs['token_type_ids'].squeeze(1)
        inputs['attention_mask'] = inputs['attention_mask'].squeeze(1)
        # 因为要用GPU,将数据传输到gpu上
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss, logits = outputs[:2]
        loss.backward()
        optimizer.step()
        trained_data += len(labels)
        trained_process = float(trained_data) / len(train_dataset)
        batch_after_last_save += 1
        total_batch += 1
        # 每训练 auto_save_batch 个 batch,保存一次模型
        if batch_after_last_save >= auto_save_batch:
            batch_after_last_save = 0
            model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')
            print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))
        print("训练进度:{:.2f}%, loss={:.4f}".format(trained_process * 100, loss.item()))
    total_epoch += 1
    model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')
    print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))

训练好后的文件,输出的最后一个文件夹才是效果最好的模型:

predit.py

import torch
from transformers import BertTokenizer, BertForNextSentencePrediction


tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertForNextSentencePrediction.from_pretrained('../output/cn_equal_model_3_171.pth')

with torch.no_grad():
    with open('../data/test.txt', 'r', encoding='utf8') as f:
        lines = f.readlines()
        correct = 0
        for i, line in enumerate(lines):
            text1, text2, label = line.split('\t')
            encoded_text = tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
            outputs = model(**encoded_text)
            res = torch.argmax(outputs.logits, dim=1).item()
            print(text1, text2, label, res)
            if str(res) == label.strip('\n'):
                correct += 1
            print(f'{i + 1}/{len(lines)}')
        print(f'acc:{correct / len(lines)}')

可以看到还是较好的学习了我数据特征:只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

数据

链接:https://pan.baidu.com/s/1Cpr-ZD9Neakt73naGdsVTw 
提取码:eryw 
链接:https://pan.baidu.com/s/1qHYjXC7UCeUsXVnYTQIPCg 
提取码:o8py 
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1481112.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第三百七十六回

文章目录 1 .概念介绍2. 实现方法3. 示例代码 我们在上一章回中介绍了在页面之间共传递数据相关的内容,本章回中将介绍如何拦截路由.闲话休提,让我们一起Talk Flutter吧。 1 .概念介绍 本章回中介绍的路由拦截是指在路由运行过程中,对路由做…

01tire算法

01tire算法 #include<bits/stdc.h> using namespace std; #define maxn 210000 int a[maxn], ch[maxn][2], val[maxn], n, ans, tot; void insert(int x) {int now 0;for (int j 31; j > 0; j -- ){int pos ((x >> i) & 1);if (!ch[now][pos])ch[now][po…

elasticsearch7.17 terms聚合性能提升90%+

背景 ES7 相比于 ES6 有多个层面的优化&#xff0c;对于开源的ES而言&#xff0c;升级是必经之路。 ES的使用场景非常多&#xff0c;在升级过程中可能会遇到非预期的结果&#xff1b; 比如之前文章提到的典型案例&#xff1a;ES7.17版本terms查询性能问题 ES7.17版本terms查…

服务端向客户端推送数据的实现方案

在日常的开发中&#xff0c;我们经常能碰见服务端需要主动推送给客户端数据的业务场景&#xff0c;比如数据大屏的实时数据&#xff0c;比如消息中心的未读消息&#xff0c;比如聊天功能等等。 本文主要介绍SSE的使用场景和如何使用SSE。 服务端向客户端推送数据的实现方案有哪…

maven项目导入mysql依赖

最近在B站跟着狂神学习Mybatis&#xff0c;学到P2就卡住了&#xff0c;搭建的maven项目一直无法导入依赖&#xff0c;在网上查找了很多相关的解决方法&#xff0c;project structure不知道点进去多少回&#xff0c;始终无法解决&#xff0c;后来把responsity文件夹删除重置了一…

【代码随想录python笔记整理】第十六课 · 出现频率最高的字母

前言:本笔记仅仅只是对内容的整理和自行消化,并不是完整内容,如有侵权,联系立删。 一、哈希表初步 在之前的学习中,我们使用数组、字符串、链表等等,假如需要找到某个节点,则都要从头开始,逐一比较,直到找到为止。为了能够直接通过要查找的记录找到其存储位置,我们选…

RV1126芯片概述

RV1126芯片概述 前言1 主要特性2 详细参数 前言 1 主要特性 四核 ARM Cortex-A7 and RISC-V MCU250ms快速开机2.0Tops NPU14M ISP with 3帧 HDR支持3个摄像头同时输入4K H.264/H.265 视频编码和解码 2 详细参数

【王道数据结构】【chapter8排序】【P371t5】

编写一个算法&#xff0c;在基于单链表表示的待排序关键字序列上进行简单选择排序 #include <iostream> #include <time.h> #include <stdlib.h> typedef struct node{int data;struct node *next; }node,*pnode;pnode buynode(int x) {pnode tmp(pnode) mal…

2024腾讯云服务器8888元代金券领取、主机价格表新鲜出炉!

腾讯云优惠活动2024新春采购节活动上线&#xff0c;云服务器价格已经出来了&#xff0c;云服务器61元一年起&#xff0c;配置和价格基本上和上个月没什么变化&#xff0c;但是新增了8888元代金券和会员续费优惠&#xff0c;腾讯云百科txybk.com整理腾讯云最新优惠活动云服务器配…

遥感影像处理(ENVI+ChatGPT+python+ GEE)处理高光谱及多光谱遥感数据

遥感技术主要通过卫星和飞机从远处观察和测量我们的环境&#xff0c;是理解和监测地球物理、化学和生物系统的基石。ChatGPT是由OpenAI开发的最先进的语言模型&#xff0c;在理解和生成人类语言方面表现出了非凡的能力。本文重点介绍ChatGPT在遥感中的应用&#xff0c;人工智能…

【vue3】命令式组件封装,message封装示例;(函数式组件?)

仅做代码示例&#xff1b;当然改进的地方还是不少的&#xff0c;仅作为该类组件封装方式的初步启发&#xff1b; 理想大成肯定是想要像 饿了么 这些组件库一样。 有的人叫这函数式组件&#xff0c;有的人叫这命令式组件&#xff0c;我个人还是偏向于命令式组件的称呼。因为以vu…

【JVM】JVM相关机制

1. JVM内存区域划分 1.1 内存区域划分简介 内存区域划分&#xff1a;实际上JVM也是一个进程&#xff0c;进程运行时需要向操作系统申请一些系统资源&#xff08;内存就是典型的资源&#xff09;&#xff0c;这些内存空间就支撑着后续Java程序的运行&#xff0c;而这些内存又会…

php 支持mssqlserver

系统不支持:sqlsrv 需要一下几个环节 1.准备检测php版本 查看 VC 版本 查看操作系统位数&#xff1a;X86(32位) 和X64 2.下载php的sqlserver库 extensionphp_sqlsrv_74_nts_x64.dll extensionphp_pdo_sqlsrv_74_nts_x64.dll extensionphp_sqlsrv_74_nts_x64 extensionphp_…

用HTML5的<canvas>元素实现刮刮乐游戏

用HTML5的<canvas>元素实现刮刮乐 用HTML5的<canvas>元素实现刮刮乐&#xff0c;要求&#xff1a;将上面的“图层”的图像可用鼠标刮去&#xff0c;露出下面的“图层”的图像。 示例从简单到复杂。 简单示例 准备两张图像&#xff0c;我这里上面的图像top_imag…

【Spring】spring中怎么解决循环依赖的问题

&#x1f34e;个人博客&#xff1a;个人主页 &#x1f3c6;个人专栏&#xff1a;Spring ⛳️ 功不唐捐&#xff0c;玉汝于成 目录 前言 正文 解决步骤 考虑 结语 我的其他博客 前言 在软件开发中&#xff0c;依赖注入是一种常见的设计模式&#xff0c;它可以帮助我们管…

探讨javascript的程序性能

如果阅读有疑问的话&#xff0c;欢迎评论或私信&#xff01;&#xff01; 本人会很热心的阐述自己的想法&#xff01;谢谢&#xff01;&#xff01;&#xff01; 文章目录 Web WorkerWorker之间通讯Worker销毁 Web Worker 当我们需要处理一些比较耗时的任务时&#xff0c;我们…

杭电OJ 2045 不容易系列之(3)—— LELE的RPG难题 C++

思路&#xff1a;我先模拟了一下1&#xff0c;2&#xff0c;3的情况&#xff0c;对应的是3 6 6&#xff0c;模拟到4的时候就有感觉了&#xff0c;1是不受到任何制约的&#xff0c;2到n-1是收到了前面一个的制约&#xff0c;n受到了n-1与1的制约&#xff0c;那么就可以去判断4 …

七通道NPN 达林顿管GC2003,专为符合标准 TTL 而制造,最高工作电压 50V,耐压 80V

GC2003 内部集成了 7 个 NPN 达林顿晶体管&#xff0c;连接的阵列&#xff0c;非常适合逻辑接口电平数字电路&#xff08;例 如 TTL&#xff0c;CMOS 或PMOS 上/NMOS&#xff09;和较高的电流/电压&#xff0c;如电灯电磁阀&#xff0c;继电器&#xff0c;打印机或其他类似的负…

构造pop链

反序列化视频笔记 第一步&#xff1a;找到目标触发echo调用$flag 第二步&#xff1a;触发_invoke函数调用appeng函数$varflag.php&#xff08;把对象当成函数&#xff09; 第三步&#xff1a;给$p赋值为对象&#xff0c;即function成为对象Modifier却被当成函数调用&#xff…

csv大数值不显示E科学计算法的解决方案

背景&#xff1a; 从其他系统获取到一个商品mid的大的数值的csv文件&#xff0c;然后使用excel打开的时候有各种问题&#xff0c;本文记录下怎么正确的展示这个大数值的csv文件 正确展示数值精度&#xff1a; 数值展示错误 正确展示的方法&#xff1a; 1使用文本编辑器比如…