Deeplearning4j 实战 (22):基于DSSM的语义匹配建模

news2025/1/20 14:54:50

Deeplearning4j 实战 (22):基于DSSM的语义匹配建模

Eclipse Deeplearning4j GitChat课程:Deeplearning4j 快速入门_专栏
Eclipse Deeplearning4j 系列博客:万宫玺的专栏_wangongxi_CSDN博客
Eclipse Deeplearning4j Github:https://github.com/eclipse/deeplearning4j
Eclipse Deeplearning4j 社区:https://community.konduit.ai/

DSSM是微软在2013年提出的,最早用于搜索引擎语义召回的双塔模型。目前在工业界也广泛用于推荐召回、搜索相关性排序、语义召回等环节。DSSM是一个轻量级模型,在线上serving的时候,可以通过对query向量和doc向量计算内积,得到的相似值用来衡量query和doc的相似度,从而进行进一步的排序。下面就分别从DSSM模型结构、基于DL4J的DSSM建模、对开源数据集LCQMC的建模等几个环节来介绍如何使用DSSM模型。当然,由于DSSM模型的论文发表时间较早,发表时给出的模型结构比较简单,在我们具体实现的时候,会做一些调整,具体在介绍模型搭建的部分会提到。

1. DSSM模型简述

在论文中,query和doc分别通过各自独立的神经网络映射成一个语义向量。需要注意的是,原论文中doc是一个包含正样本和负样本的集合。正样本取1个,负样本取4个。论文中有提到,正样本是搜素后被点击的样本,负样本则是随机选取的搜索未被点击的样本集合。通过分别计算query的语义向量和正负doc样本的语义向量的余弦相似度,再通过softmax函数得到正负样本的概率分布后,和label计算交叉熵损失。这就是DSSM模型的大致的idea。下面先看下论文中对于DSSM描绘的架构图:
在这里插入图片描述
通过模型架构图可以看到,论文中是使用最简单的MLP对输入进行映射。这里需要提一下word hashing的操作。由于2013年时候word embedding技术还不是较广泛的使用,因此论文中的word hashing是在n-gram语言模型的基础上,通过hash操作将接近50W的词表计算每个词的索引值。这在当时是一种比较高效的做法,目前由于硬件的进步以及embedding技术的进一步成熟,可以直接使用预训练的embedding向量或者做端到端的建模。因此,在第三部分中构建DSSM模型的过程中,我们也是使用的端到端的方案。
在这里插入图片描述

上面这张截图中模型训练的有关描述。就像在本节开始时候提到的,通过softmax计算query和每个doc的余弦相似度的值归一化概率分布。由于softmax函数与cosine相似度的一致性,因此相似度越高的query-doc pair,其softmax值也会越接近于1。在损失函数部分,使用的是经典的log loss。这部分没啥说的。
另外需要说明的是,从ranking loss的角度,论文中的loss应当属于list-wise loss。当然,如果将负样本减少到一个或者doc集合中只有一个正样本或负样本(softmax更改为sigmoid函数),那就退化成pair-wise loss或者point-wise loss。为了方便起见,在第三部分的建模过程中,我们会使用point-wise loss。
对于搜索场景来说,双塔的输入分别是query和doc。对于推荐场景来说,双塔的输入可以是user和item或者item和item,用于U2I的召回或者I2I的召回。

2. LCQMC数据集

LCQMC是哈工大和阿里共同开源的用于QA的数据集,详情可参见论文。下载链接为:地址。压缩包中共有三个文件,三个文件都是以制表符作为分隔符。我们先来看下用于训练的部分数据的截图:
在这里插入图片描述
文件中有三列。最后一列用1或者0来代表 text_a 和 text_b两列文本的是否相关。如果把text_a列文本看作是query,那text_b列可以看作是doc。用于验证的文件中的内容也和训练文件中的数据格式相似,这里就不做另外截图了。

最后提一下,训练样本数量是:238767,验证的样本数量是:12501。

3. 基于DL4J的DSSM模型构建

在第一部分中,我们提到DSSM的论文中双塔内部是使用MLP结构。考虑到MLP结构的单一性,我们使用Embedding+LSTM+MLP的结构作为双塔的内部结构。虽然query和doc对应的塔结构相同,但是不做参数的共享。另外,由于LCQMC数据集中label是1或者0,因此我们将DSSM的输出层改为sigmoid + binary cross entropy loss。具体我们先给出代码片段:

private static ComputationGraph getDSSM(final int QUERY_VOCAB_SIZE, final int DOC_VOCAB_SIZE, 
										final int VECTOR_SIZE) {
	ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
		    .updater(new Adam(5 * 1e-3))
		    .weightInit(WeightInit.XAVIER)
		    .seed(12345L)
		    .graphBuilder()
		    .addInputs("query", "doc")
		    .setInputTypes(InputType.recurrent(QUERY_VOCAB_SIZE), InputType.recurrent(DOC_VOCAB_SIZE))
		    .addLayer("query-embedding", new EmbeddingSequenceLayer.Builder()
		    					.nIn(QUERY_VOCAB_SIZE + 1)
		    					.nOut(VECTOR_SIZE).build(), "query")
		    .addLayer("query-embedding-lstm", new LSTM.Builder()
					.nIn(VECTOR_SIZE)
					.nOut(VECTOR_SIZE).activation(Activation.TANH).build(), "query-embedding")
		    .addLayer("doc-embedding", new EmbeddingSequenceLayer.Builder()
								.nIn(DOC_VOCAB_SIZE + 1)
								.nOut(VECTOR_SIZE).build(), "doc")
		    .addLayer("doc-embedding-lstm", new LSTM.Builder()
								.nIn(VECTOR_SIZE)
								.nOut(VECTOR_SIZE).activation(Activation.TANH).build(), "doc-embedding")
		    .addVertex("query-embedding-lstm-last-output", new LastTimeStepVertex("query"), "query-embedding-lstm")
		    .addVertex("doc-embedding-lstm-last-output", new LastTimeStepVertex("doc"), "doc-embedding-lstm")
		    .addLayer("query-output", new DenseLayer.Builder()
										.nIn(VECTOR_SIZE)
										.nOut(VECTOR_SIZE / 2)
										.activation(Activation.LEAKYRELU)
										.build(), "query-embedding-lstm-last-output")
		    .addLayer("doc-output", new DenseLayer.Builder()
										.nIn(VECTOR_SIZE)
										.nOut(VECTOR_SIZE / 2)
										.activation(Activation.LEAKYRELU)
										.build(), "doc-embedding-lstm-last-output")
		    .addVertex("query-output-l2-norm", new L2NormalizeVertex(), "query-output")
		    .addVertex("doc-output-l2-norm", new L2NormalizeVertex(), "doc-output")
		    .addVertex("cosing-similar", new ElementWiseVertex(ElementWiseVertex.Op.Product), "query-output-l2-norm", "doc-output-l2-norm")
		    .addLayer("out", new OutputLayer.Builder()
                    				.lossFunction(LossFunctions.LossFunction.XENT)	//bce
                    				.nIn(VECTOR_SIZE / 2).nOut(1).activation(Activation.SIGMOID).build(), "cosing-similar")
		    .setOutputs("out")
		    .build();

	ComputationGraph net = new ComputationGraph(conf);
	net.setListeners(new ScoreIterationListener(1));
	net.init();
	return net;
}

由于存在两个输入,因此使用DL4J中的ComputationGraph。这里需要说明的有几点:

  • LastTimeStepVertex的作用:获取LSTM最后一个time step输出的张量
  • L2NormalizeVertex的作用:L2归一化,将query和doc的向量转化为单位向量
  • ElementWiseVertex的作用:通过设置Op为点积,实际为计算query和doc单位向量的内积,因此L2NormalizeVertex + ElementWiseVertex联合起来的作用是计算向量间的余弦相似度值
  • 输出端使用sigmoid + bce 作point-wise的损失函数
    在这里插入图片描述

上面的截图中通过summary接口打印的模型结构和待训练参数。可见待训练参数68W。

另外,对于该静态方法,输入的几个参数QUERY_VOCAB_SIZE,DOC_VOCAB_SIZE,VECTOR_SIZE分别代表LCQMC数据集中text_a的词表大小和text_b的词表大小,以及词向量的大小。

需要指出的是,在第四部分进行建模的操作中,我们使用中文单字作为query和doc的最小粒度特征,而不做分词的处理。

4. DSSM模型训练和评估

首先介绍下数据处理的部分:

  • 读取训练文件,构建中文单字和单字的索引,存储在map结构中。同时记录最长的文本长度,用于后续的padding操作。
  • 再次读取文件,对每条记录构建MultiDataSet对象,并存储在LinkedList对象中。MultiDataSet对象中会存储query和doc作为输入,label作为输出,此外还有query和doc的mask张量,用于统一变长文本的处理。

我们看下具体的实现逻辑:

class DataSetInfo{
	public Map<String,Integer> queryDict = new TreeMap<>();
	public Map<String,Integer> docDict = new TreeMap<>();
	public int queryMaxLen = 0;
	public int docMaxLen = 0;
}

private static DataSetInfo preprocess(String filePath) {
	DataSetInfo info = new DataSetInfo();
	try(BufferedReader br = Files.newReader(new File(filePath), Charset.forName("UTF-8"))){
		String line = null;
		int lineIndex = 0;
		while( (line = br.readLine()) != null ) {
			if( lineIndex == 0 ) {
				lineIndex++;
				continue;
			}
			String[] splits = line.split("\t");
			if( null == splits || splits.length != 3 )continue;
			String query = splits[0];
			String doc = splits[1];
			if( query != null && query.length() > 0 ) {
				info.queryMaxLen = Math.max(query.length(), info.queryMaxLen);
				for( char c : query.toCharArray() ) {
					String charStr = String.valueOf(c);
					if( !info.queryDict.containsKey(charStr) ) {
						int curIndex = info.queryDict.size();
						info.queryDict.put(charStr, curIndex);
					}
				}
			}
			if( doc != null && doc.length() > 0 ) {
				info.docMaxLen = Math.max(doc.length(), info.docMaxLen);
				for( char c : doc.toCharArray() ) {
					String charStr = String.valueOf(c);
					if( !info.docDict.containsKey(charStr) ) {
						int curIndex = info.docDict.size();
						info.docDict.put(charStr, curIndex);
					}
				}
			}
		}
	}catch(Exception ex) {
		ex.printStackTrace();
	}finally {
		int curIndex = info.queryDict.size();
		info.queryDict.put("UNK", curIndex);
		//
		curIndex = info.docDict.size();
		info.docDict.put("UNK", curIndex);
	}
	return info;
}

这部分处理逻辑比较清晰,主要是先定义个DataSetInfo的类,里面包含了单字和单字索引的映射关系,还有最大文本长度。在finally部分,我们使用UNK代表所有未登录词。接着看下MultiDataSet的构造:

private static List<org.nd4j.linalg.dataset.api.MultiDataSet> getMultiDataIter(String filePath, DataSetInfo dataInfo) {
	List<org.nd4j.linalg.dataset.api.MultiDataSet> list = new LinkedList<>();
	try(BufferedReader br = Files.newReader(new File(filePath), Charset.forName("UTF-8"))){
		String line = null;
		int lineIndex = 0;
		while( (line = br.readLine()) != null ) {
			if( lineIndex == 0 ) {
				lineIndex++;
				continue;
			}
			String[] splits = line.split("\t");
			String query = splits[0];
			String doc = splits[1];
			String label = splits[2];
			if( query == null || query.isEmpty() ||
				doc == null || doc.isEmpty() || 
				label == null)continue;
			//
			double[][] queryIndexArray = new double[1][dataInfo.queryMaxLen];
			double[][] docIndexArray = new double[1][dataInfo.docMaxLen];
			double[][] queryIndexMaskArray = new double[1][dataInfo.queryMaxLen];
			double[][] docIndexMaskArray = new double[1][dataInfo.docMaxLen];
			double[][] labelIndexArray = new double[1][1];
			//
			for( int i = 0; i < query.length(); ++i ) {
				queryIndexArray[0][i] = dataInfo.queryDict.getOrDefault(String.valueOf(query.charAt(i)),
													dataInfo.queryDict.get("UNK"));
				queryIndexMaskArray[0][i] = 1.0;
			}
			for( int i = 0; i < doc.length(); ++i ) {
				docIndexArray[0][i] = dataInfo.docDict.getOrDefault(String.valueOf(doc.charAt(i)),
													dataInfo.docDict.get("UNK"));
				docIndexMaskArray[0][i] = 1.0;
			}
			labelIndexArray[0][0] = Double.parseDouble(label);
			//
			org.nd4j.linalg.dataset.api.MultiDataSet mds = 
					new MultiDataSet(new INDArray[] {Nd4j.create(queryIndexArray), Nd4j.create(docIndexArray)},
									new INDArray[] {Nd4j.create(labelIndexArray)},
									new INDArray[] {Nd4j.create(queryIndexMaskArray), Nd4j.create(docIndexMaskArray)},
									null);
			list.add(mds);
		}
	}catch(Exception ex) {
		ex.printStackTrace();
	}
	return list;
}

该部分逻辑主要是通过一个静态方法来读取训练文本中的每一行数据,并且针对text_a和text_b以及label转换成一个MultiDataSet对象,并存储在一个LinkedList对象中。需要注意的是Mask部分的处理。Mask张量中用1.0代表有效,0.0代表无效的部分。下面我们看下训练建模和评估的部分。

final int batchSize = 256;
final int embedding_size = 64;
DataSetInfo dataInfo = preprocess("data/lcqmc/train.tsv");
ComputationGraph dssm = getDSSM(dataInfo.queryDict.size(), dataInfo.docDict.size(), embedding_size);
System.out.println(dssm.summary());
List<org.nd4j.linalg.dataset.api.MultiDataSet> trainDataList = getMultiDataIter("data/lcqmc/train.tsv", dataInfo);
List<org.nd4j.linalg.dataset.api.MultiDataSet> testDataList = getMultiDataIter("data/lcqmc/test.tsv", dataInfo);
System.out.println("Finish Loading Train Data");
for(int epoch = 0; epoch < 5; ++epoch) {
	Collections.shuffle(trainDataList);
	MultiDataSetIterator trainIter = new IteratorMultiDataSetIterator(trainDataList.iterator(), batchSize);
	dssm.fit(trainIter);
	Evaluation eval = dssm.evaluate(new IteratorMultiDataSetIterator(testDataList.iterator(), batchSize));
	System.out.println(eval);
}

通过10个epoch的训练,我们最终在验证集上得到70%左右的准确率, loss值在0.4左右。
在这里插入图片描述

5. 总结

DSSM是一个经典的双塔模型,但其也有明显的缺点,就是两个塔之间是独立的,没有信息的交叉。这种信息的交叉对应推荐场景来说是很重要的。DSSM论文中的结构比较简单,是MLP为主,且输入层使用词袋模型进行处理,这其实忽略的上下文的语义信息,因此我们在实现的时候,使用LSTM模型来捕获序列的完整语义信息。当然,由于时间原因,我们这边并没有做分词处理,相信经过分词处理,在LCQMC数据集上的准确率可以进一步得到提升。另外,双塔的结构可以很灵活,内部可以直接上BERT来做,这里变体就太多,不做过多陈述了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1347973.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Spring Boot的美妆分享系统:打造个性化推荐、互动社区与智能决策

基于Spring Boot的美妆分享系统&#xff1a;打造个性化推荐、互动社区与智能决策 1. 项目介绍2. 管理员功能2.1 美妆管理2.2 页面管理2.3 链接管理2.4 评论管理2.5 用户管理2.6 公告管理 3. 用户功能3.1 登录注册3.2 分享商品3.3 问答3.4 我的分享3.5 我的收藏夹 4. 创新点4.1 …

【基础】【Python网络爬虫】【3.chrome 开发者工具】(详细笔记)

Python网络爬虫基础 chrome 开发者工具元素面板&#xff08;Elements)控制台面板&#xff08;Console&#xff09;资源面板&#xff08;Source&#xff09;网络面板&#xff08;Network&#xff09;工具栏Requests Table详情 chrome 开发者工具 ​ 当我们爬取不同的网站是&…

「微服务」Saga 模式 如何使用微服务实现业务事务-第二部分

在上一篇文章中&#xff0c;我们看到了实现分布式事务的一些挑战&#xff0c;以及如何使用Event / Choreography方法实现Saga的模式。在本文中&#xff0c;我们将讨论如何通过使用另一种类型的Saga实现&#xff08;称为Command或Orchestration&#xff09;来解决一些问题&#…

win10系统请将eNSP相关应用程序添加到windows firewall的允许程序列表,并允许其在公用网络上运行!的解决办法

很多学习网络的小伙伴&#xff0c;在下载安装eNSP后&#xff0c;打开程序跳出&#xff1a;请将eNSP相关应用程序添加到windows firewall的允许程序列表&#xff0c;并允许其在公用网络上运行! 是不是挺闹心的! 其实&#xff0c;原因是很简单&#xff0c;就是win10系统防火墙访…

Linux中安装了openjdk后jps command not found

一、问题场景 在Linux中用yum安装了openjdk-17&#xff0c;也在.bashrc中配置了环境变量JAVA_HOME以及bin目录的PATH。 但是在运行jps命令时依然报错找不到命令 二、原因分析 进入到$JAVA_HOME/bin目录查看&#xff0c;发现只有寥寥几个命令&#xff0c;压根没有jps命令&…

微信小程序发送模板消息-详解【有图】

前言 在发送模板消息之前我们要首先搞清楚微信小程序的逻辑是什么&#xff0c;这只是前端的一个demo实现&#xff0c;建议大家在后端处理&#xff0c;前端具体实现&#xff1a;如下图 1.获取小程序Id和密钥 我们注册完微信小程序后&#xff0c;可以在开发设置中看到以下内容&a…

4.35 构建onnx结构模型-Layernorm

前言 构建onnx方式通常有两种&#xff1a; 1、通过代码转换成onnx结构&#xff0c;比如pytorch —> onnx 2、通过onnx 自定义结点&#xff0c;图&#xff0c;生成onnx结构 本文主要是简单学习和使用两种不同onnx结构&#xff0c; 下面以 Layernorm 结点进行分析 方式 方法…

LAYABOX:2024新年寄语

2024新年寄语 过去的一年&#xff0c;尽管许多行业面临严峻挑战和发展压力&#xff0c;小游戏领域却逆势上扬&#xff0c;年产值首次突破400亿元大关&#xff0c;众多优质小游戏企业收获颇丰。 对此&#xff0c;祝福大家&#xff0c;2024一定更好&#xff01; 过去的一年&#…

Good Bye 2023

Good Bye 2023 Good Bye 2023 A. 2023 题意&#xff1a;序列a中所有数的乘积应为2023&#xff0c;现在给出序列中的n个数&#xff0c;找到剩下的k个数并输出&#xff0c;报告不可能。 思路&#xff1a;把所有已知的数字乘起来&#xff0c;判断是否整除2023&#xff0c;不够…

掌握C++11标准库(STL):理解STL的核心概念

深入探索C11标准库STL&#xff1a;新特性和优化技巧 一、前言二、容器简介三、迭代器简介四、map与unordered_map&#xff08;红黑树VS哈希表&#xff09;4.1、map和unordered_map的差别4.2、优缺点以及适用处4.3、小结 五、总结 一、前言 STL定义了强大的、基于模板的、可复用…

海康visionmaster-渲染控件:渲染控件加载本地图像的方法

描述 环境&#xff1a;VM4.0.0 VS2015 及以上 现象&#xff1a;渲染控件如何显示本地图像&#xff1f; 解答 思路&#xff1a;在 2.3.1 中&#xff0c;可以通过绑定流程或者模块来显示图像和渲染效果。因此&#xff0c;第一步&#xff0c; 可以使用在 VM 软件平台中给图像源模…

【网络面试(1)】浏览器如何实现生成HTTP消息

我们经常会使用浏览器访问各种网站&#xff0c;获取各种信息&#xff0c;帮助解决工作生活中的问题。那你知道&#xff0c;浏览器是怎么帮助我们实现对web服务器的访问&#xff0c;并返回给我们想要的信息呢&#xff1f; 1. 浏览器生成HTTP消息 我们平时使用的浏览器有很多种&…

LeetCode每日一题.04(不同路径)

一个机器人位于一个 m x n 网格的左上角 &#xff08;起始点在下图中标记为 “Start” &#xff09;。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角&#xff08;在下图中标记为 “Finish” &#xff09;。 问总共有多少条不同的路径&#xff1f; 示例 1…

Redis(认识NoSQL,认识redis,安装redis,redis桌面客户端,redis常见命令,redis的Java客户端)

文章目录 Redis快速入门1.初识Redis1.1.认识NoSQL1.1.1.结构化与非结构化1.1.2.关联和非关联1.1.3.查询方式1.1.4.事务1.1.5.总结 1.2.认识Redis1.3.安装Redis1.3.1.依赖库1.3.2.上传安装包并解压1.3.3.启动1.3.4.默认启动1.3.5.指定配置启动1.3.6.开机自启 1.4.Redis桌面客户端…

【Linux操作系统】探秘Linux奥秘:文件系统的管理与使用

&#x1f308;个人主页&#xff1a;Sarapines Programmer&#x1f525; 系列专栏&#xff1a;《操作系统实验室》&#x1f516;诗赋清音&#xff1a;柳垂轻絮拂人衣&#xff0c;心随风舞梦飞。 山川湖海皆可涉&#xff0c;勇者征途逐星辉。 目录 &#x1fa90;1 初识Linux OS &…

Vue(三):Vue 生命周期与工程化开发

2023 的最后一篇博客&#xff0c;祝大家元旦快乐&#xff0c;新的一年一起共勉&#xff01; 06. Vue 生命周期 6.1 基本介绍 生命周期就是一个 Vue 示例从 创建 到 销毁 的整个过程&#xff0c;创建、挂载、更新、销毁 有一些请求是必须在某个阶段完成之后或者某个阶段之前执行…

C++ 递归函数 详细解析——C++日常学习随笔

1. 递归函数 1.1 递归函数的定义 递归函数&#xff1a;即在函数体中出现调用自身的函数&#xff0c;即函数Func(Type a,……)直接或间接调用函数本身&#xff1b; 递归函数&#xff1a;在数学上&#xff0c;关于递归函数的定义如下&#xff1a;对于某一函数f(x)&#xff0c;其…

多维时序 | MATLAB实现SSA-CNN-GRU-SAM-Attention麻雀算法优化卷积网络结合门控循环单元网络融合空间注意力机制多变量时间序列预测

多维时序 | MATLAB实现SSA-CNN-GRU-SAM-Attention麻雀算法优化卷积网络结合门控循环单元网络融合空间注意力机制多变量时间序列预测 目录 多维时序 | MATLAB实现SSA-CNN-GRU-SAM-Attention麻雀算法优化卷积网络结合门控循环单元网络融合空间注意力机制多变量时间序列预测预测效…

计算机网络:知识回顾

0 本节主要内容 问题描述 解决思路 1 问题描述 通过一个应用场景来回顾计算机网络涉及到的协议&#xff08;所有层&#xff09;。如下图所示场景&#xff1a; 学生Bob将笔记本电脑用一根以太网电缆连接到学校的以太网交换机&#xff1b;交换机又与学校的路由器相连&#xf…

Windows10系统的音频不可用,使用疑难解答后提示【 一个或多个音频服务未运行】

一、问题描述 打开电脑&#xff0c;发现电脑右下角的音频图标显示为X&#xff08;即不可用&#xff0c;无法播放声音&#xff09;&#xff0c;使用音频自带的【声音问题疑难解答】&#xff08;选中音频图标&#xff0c;点击鼠标右键&#xff0c;然后选择“声音问题疑难解答(T)”…