文本匹配SimCSE模型代码详解以及训练自己的中文数据集

news2024/11/25 13:18:06

前言

在上一篇博客文本匹配中的示例代码中使用到了一个SimCSE模型,用来提取短文本的特征,然后计算特征相似度,最终达到文本匹配的目的。但是该示例代码中的短文本是用的英文短句,其实SimCSE模型也可以用于中文短文本的特征提取,本篇博客就基于苏沐剑发表于科学空间的中文任务还是SOTA吗?我们给SimCSE补充了一些实验博客中使用到的代码,来记录一下代码梳理的笔记,并且使用自己的数据集在这篇代码上进行训练。另外,关于这个模型的原理细节等,可以参考别的博主写的内容,还有就是作者的论文,这些会附在最后的参考链接。

代码详解

数据导入部分

数据导入部分的代码主要有三个步骤,(1)从txt中读取文本数据,常规操作,这里没什么可说的;

datasets = {
    '%s-%s' % (task_name, f):
    load_data('%s%s/%s.%s.data' % (data_path, task_name, task_name, f))
    for f in ['train', 'valid', 'test']
}

(2)将读取到的文本句子转换成id向量,同样也是常规操作;

def convert_to_ids(data, tokenizer, maxlen=64):
    """转换文本数据为id形式
    """
    a_token_ids, b_token_ids, labels = [], [], []
    for d in tqdm(data):
        token_ids = tokenizer.encode(d[0], maxlen=maxlen)[0]
        a_token_ids.append(token_ids)
        token_ids = tokenizer.encode(d[1], maxlen=maxlen)[0]
        b_token_ids.append(token_ids)
        labels.append(d[2])
    a_token_ids = sequence_padding(a_token_ids)
    b_token_ids = sequence_padding(b_token_ids)
    return a_token_ids, b_token_ids, labels

(3)第三步则是写了一个class,使用了一个生成器,完成数据batch读取。这里需要注意的是,每个batch中,同一个文本数据,输入了两次,一个batch中的两个一样的文本输入,由于模型最后一层的加入了dropout,模型输出结果是有些许差别的,这样有差别的输出,则可以互为label,这也是SimCSE模型巧妙的地方。

class data_generator(DataGenerator):
    """训练语料生成器
    """
    def __iter__(self, random=False):
        batch_token_ids = []
        for is_end, token_ids in self.sample(random):
            batch_token_ids.append(token_ids) ##同一条文本输入两次
            batch_token_ids.append(token_ids) ##同一条文本输入两次
            if len(batch_token_ids) == self.batch_size * 2 or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = np.zeros_like(batch_token_ids)
                batch_labels = np.zeros_like(batch_token_ids[:, :1])
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids = []

模型定义部分

这个模型的定义其实很简单,就是用bert作为特征提取的基础模型,然后再bert模型输出的基础上加上一个dropout操作,就是代码中的pooling层,核心代码就是下面几行

bert = build_transformer_model(
            config_path,
            checkpoint_path,
            model=model,
            with_pool='linear',
            dropout_rate=dropout_rate
        )
outputs, count = [], 0
while True:
	try:
      	output = bert.get_layer(
                'Transformer-%d-FeedForward-Norm' % count
           ).output
        outputs.append(output)
        count += 1
    except:
        break
output = bert.output
# 最后的编码器
encoder = Model(bert.inputs, output) 

模型的损失函数

模型的损失函数是所有代码中最难理解的部分,虽然代码只有十几行,但是最需要花费时间去理解的。
在阐述这个SimCSE模型的损失函数代码之前,首先要搞清楚,这个模型是要解决什么问题,其目的主要是为了提取短文本的特征,使得相似的句子,提取出来的特征距离更近,不同语义的句子,特征距离越远,这样使得提取出来的文本特征更具有辨识度,和人脸识别原理很类似,这就是对比学习模型系列想要达到的目的。

在了解了对比学习的大致原理之后,再来看代码,下面是解释

idxs = K.arange(0, K.shape(y_pred)[0])

这行代码就是模型输出的一个维度(模型输入的batchsize),构建一个索引,比如,模型输入batchsize为6,那idxs则就是[0,1,2,3,4,5]

idxs_1 = idxs[None, :]

这就是给idxs增加一个维度,使其变成[[0,1,2,3,4,5]]

idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]

这行代码比较关键,目的是让idxs向量中数值是奇数的赋值为它的前一个数,数值为偶数的则赋值为它后一个索引值,这个一前一后的赋值,就是它相似度最大的索引值(排除自己)。这里需要解释一下的是,这里每个索引值背后代表的是SimCSE模型输出的一个个的提取到的文本特征向量,维度是1*738,和bert模型输出应该是一样的维度。而这里为什么要取一前一后的赋值索引,这因为数据导入时候,在每个batch里面同一条文本被相邻的导入了两次,那么这两个相邻的文本,经过SimCSE模型提取到的特征也是最为相似的,其相似度要接近1,而每个batch里面不相邻的模型输出,则应该是0,这样模型才能达到收敛的效果

y_true = K.equal(idxs_1, idxs_2)
y_true = K.cast(y_true, K.floatx())

这两行代码就是可以将y_true变成一个batchsize * batchsize大小的相似度矩阵,相似度的规则和上面描述的一样

 y_pred = K.l2_normalize(y_pred, axis=1)
similarities = K.dot(y_pred, K.transpose(y_pred))
similarities = similarities - tf.eye(K.shape(y_pred)[0]) * 1e12
similarities = similarities * 20

这几行代码就是计算SimCSE模型预测出来每个batch里的每个文本特征之间的相似度,特征越相似,K.dot(y_pred, K.transpose(y_pred)),特征向量点乘越接近1,similarities = similarities - tf.eye(K.shape(y_pred)[0]) * 1e12,则是为了消除相似度矩阵对角线上的元素,即同一条特征自身与自身点乘的结果。

loss = K.categorical_crossentropy(y_true, similarities, from_logits=True)

最后用交叉熵损失来定义模型最后的输出损失

训练自己的数据

在这个模型需要训练自己的数据,首先是环境搭建:

jieba-0.42.1
bert4keras-0.10.5
keras-2.3.1
cudatoolkit 10.0.130
cudnn  7.6.0 
tensorflow-gpu  1.13.1

然后准备数据集,格式如下:

在这里插入图片描述

txt这个标签,0,1可以有,也可以没有

接着就是下载预训练模型,bert的模型,下载之后,修改eval.py中的数据集和预训练模型的路径,将其修改成自己的路径
在这里插入图片描述
最后运行代码训练模型即可得到预测结果

在这里插入图片描述

参考链接

SimCSE论文及源码解读
SimCSE的loss实现源码解读
SimCSE: Simple Contrastive Learning of Sentence Embeddings
princeton-nlp/SimCSE

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/333407.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用STM32 CUBE IDE配置STM32F7 用DMA传输多通道ADC数据

我的使用环境: 硬件:STM32F767ZGT6、串口1、ADC1、16MHz晶振、216MHz主频 软件:STM32 CUBE IDE 优点:不用定时触发采样,ADC数据是不停的实时更新,ADC数据的更新频率根据采样时钟和采样周期决定,…

负载均衡反向代理下的webshell上传+apache漏洞

目录一、负载均衡反向代理下的webshell上传1、nginx 负载均衡2、搭建环境3、负载均衡下的 WebShell连接的难点总结难点一、需要在每一台节点的相同位置都上传相同内容的 WebShell难点二、无法预测下次的请求交给哪台机器去执行。难点三、下载文件时,可能会出现飘逸&…

面试题:Redis的内存策略

1 Redis内存回收Redis之所以性能强,主要原因是基于内存存储,然而单节点的Redis内存不易过大,会影响主从同步和持久化性能我们可以通过修改配置文件设置Redis的最大内存:当内存存储到上限时,就无法存储更多的数据了。1.…

html控件Aspose.Html for .NET 授权须知

Aspose.Html for .NET是一种高级的HTML操作API,可让您直接在.NET应用程序中执行广泛的HTML操作任务,Aspose.Html for .NET允许创建,加载,编辑或转换(X)HTML文档,而无需额外的软件或工具。API还为…

Java开发学习(四十六)----MyBatisPlus新增语句之id生成策略控制及其简化配置

在前面有一篇博客:Java开发学习(四十一)----MyBatisPlus标准数据层(增删查改分页)开发,我们在新增的时候留了一个问题,就是新增成功后,主键ID是一个很长串的内容。 我们更想要的是按照数据库表字段进行自增…

CleanMyMac X4.12新版本下载及功能介绍

CleanMyMac X2023最新版终于迎来了又4.12,重新设计了 UI 元素,华丽的现代化风格显露无余。如今的CleanMyMac,早已不是单纯的系统清理工具。在逐渐融入系统优化、软件管理、文件管理等功能后,逐渐趋近于macOS的系统管家&#xff0c…

Python数据可视化(三)(pyecharts)

分享一些python-pyecharts作图小技巧,用于展示汇报。 一、特点 任何元素皆可配置pyecharts只支持python原生的数据类型,包括int,float,str,bool,dict,list动态展示,炫酷的效果,给人视觉冲击力 # 安装 pip install pyecharts fr…

算法训练营DAY51|300.最长递增子序列、674. 最长连续递增序列、718. 最长重复子数组

本期是求子序列的新的一期,题目前两道有一些相似之处,思路差不多,第三道有一点难度,但并不意味着第一道没有难度,没有做过该类型题的选手,并不容易解出题解。 300. 最长递增子序列 - 力扣(Leet…

22级浙江大学MBA笔试备考的若干经验分享

我是浙江大学2022级的一名新生,虽然没有参加提前批面试,但是通过笔试的有序备考最终也有幸上岸浙大,对于部分提前批面试没拿到优秀资格的考友,今天我想把自己的笔试上岸经验做个总结,给大家提供一个参考模版。 先…

TC358774XBG/TC358775XBG替代方案|CS5518替代TC358774XBG/TC358775XBG设计DSI转LVSD设计资料

TC358774XBG/TC358775XBG替代方案|CS5518替代TC358774XBG/TC358775XBG设计DSI转LVSD设计资料 TC358774XBG/TC358775XBG 芯片的主要功能是作为 DSI - LVDS 通信协议桥接,主芯片的视频数据可通过 DSI 链路流 出,以驱动兼容 LVDS 的显示板。换句话说&#x…

百度官宣在前,阿里、京东在后,互联网大厂向ChatGPT而生?

ChatGPT蹿红后,互联网科技公司都坐不住了。 最早,百度正式对外官宣类ChatGPT项目“文心一言”(ERNIE Bot)。据笔者了解,该产品将于三月份完成内测,面向公众开放。 紧随其后,阿里巴巴公布阿里版…

流浪地球 | 建筑人是如何看待小破球里的黑科技的?

大家好,这里是建模助手。 想问问大家今年贺岁档,都跟上没有,今天请允许我蹭一下热点表达一下作为一个科幻迷的爱国之情。 抛开大刘的想象力、各种硬核科技&以及大国情怀不提,破球2中的传承还是让小编很受感动,无…

【2023】Prometheus-Prometheus与Alertmanager配置详解

记录一下Prometheus与Alertmanager的配置参数等内容 目录1.Prometheus1.1.prometheus.yml1.2.告警规则定义2.alertmanager2.1.alertmanager.yml2.1.1.global:全局配置2.1.1.1.以email方式作为告警发送方2.1.1.2.以wechat方式作为告警发送方2.1.1.3.以webhook方式作为…

c++基础入门二

一、数组的引用int main() {int a 10, b 20;int ar[10] { 1,2,3,4,6,7 };int& x ar[0];int& p[5] ar;//errorint(&p)[10] ar;//引用整个数组的大小sizeof(ar)int(*p)[10] &ar;//typesize表示整个数组//只有在这三种情况下代表整个数组,其他情…

C++ 浅谈之 STL Vector

C 浅谈之 STL Vector HELLO,各位博友好,我是阿呆 🙈🙈🙈 这里是 C 浅谈系列,收录在专栏 C 语言中 😜😜😜 本系列阿呆将记录一些 C 语言重要的语法特性 🏃&…

18-考虑柔性负荷的综合能源系统低碳经济优化调度MATLAB程序

参考文献:考虑柔性负荷的综合能源系统低碳经济优化调度_薛开阳考虑用户侧柔性负荷的社区综合能源系统日前优化调度_刘蓉晖主要内容:基础模型参考刘蓉晖的论文,主要做了场景1、2、3;碳交易模型采用薛开阳论文中的。采用CPIEX求解某…

ArcGIS API for JavaScript 4.15系列(2)——Dojo中的dom操作

1、前言 ArcGIS API for JavaScript是基于Dojo框架编写的开发包,因此了解并掌握Dojo的相关基础知识是极为必要的。很多开发者都反馈过一个问题,那就是一看见ArcGIS API for JavaScript里那些奇形怪状的代码就觉得无从下手。有一点必须得承认&#xff1a…

Xshell 安装及使用方法

公网地址:47.XXX.XXX.229 私网地址:172.XXX.128.XXX 用户:root 密码:1234561,百度xshell,下载,安装Xshell 2,填写配置及使用方式 主机:47.XXX.XXX.229 用户:root 密码&a…

SpringCloud学习笔记 - 系统自适应限流 - Sentinel

1. Sentinel 系统自适应限流 Sentinel 系统自适应限流从整体维度对应用入口流量进行控制,结合应用的 Load、CPU 使用率、总体平均 RT、入口 QPS 和并发线程数等几个维度的监控指标,通过自适应的流控策略,让系统的入口流量和系统的负载达到一…

第五十一章 BFS进阶(一)——双端队列广搜

第五十一章 BFS进阶(一)——双端队列广搜一、原理二、例题1、问题2、分析三、代码一、原理 在介绍双端队列广搜之前,我们先回顾一下堆优化版本的dijkstradijkstradijkstra算法。 在这个算法中,我们使用的是小根堆来找到距离起点…