【KG】TransE 及其实现

news2025/1/16 2:36:51

原文:https://yubincloud.github.io/notebook/pages/paper/kg/TransE/

TransE 及其实现

1. What is TransE?

TransE (Translating Embedding), an energy-based model for learning low-dimensional embeddings of entities.

核心思想:将 relationship 视为一个在 embedding space 的 translation。如果 (h, l, t) 存在,那么

Motivation:一是在 Knowledge Base 中,层次化的关系是非常常见的,translation 是一种很自然的用来表示它们的变换;二是近期一些从 text 中学习 word embedding 的研究发现,一些不同类型的实体之间的 1-to-1 的 relationship 可以被 model 表示为在 embedding space 中的一种 translation。

2. Learning TransE

TransE 的训练算法如下:

image-20220917220547146

2.1 输入参数

  • training set :用于训练的三元组的集合,entity 的集合为 ,rel. 的集合为
  • margin :损失函数中的间隔,这个在原 paper 中描述很模糊
  • 每个 entity 或 rel. 的 embedding dim

2.2 训练过程

初始化:对每一个 entity 和 rel. 的 embedding vector 用 xavier_uniform 分布来初始化,然后对它们实施 L1 or L2 正则化。

loop

  • 在 entity embedding 被更新前进行一次归一化,这是通过人为增加 embedding 的 norm 来防止 loss 在训练过程中极小化。
  • sample 出一个 mini-batch 的正样本集合
  • 初始化为空集,它表示本次 loop 用于训练 model 的数据集
  • for do:
    • 根据 (h, l, t) 构造出一个错误的三元组
    • 将 positive sample 和 negative sample 加入到
  • 计算 每一对 positive sample 和 negative sample 的 loss,然后累加起来用于更新 embedding matrix。每一对的 loss 计算方式为:

这个过程中,triplet 的 energy 就是指的 ,它衡量了 的距离,可以采用 L1 或 L2 norm,即 具体计算方式可见代码实现。

loss 的计算中,

关于 margin 的含义, 它相当于是一个正确 triple 与错误 triple 之前的间隔修正,margin 越大,则两个 triple 之前被修正的间隔就越大,则对于 embedding 的修正就越严格。我们看 ,我们希望是 越小越好,所以这一项前面为正号,希望 越大越好,所以这一项前面为负号。正常情况下, 一定是小于 的,所以 的结果应该是负值,那么 loss function 的外层取正就使得 loss = 0 了,所以 margin 的存在就使得负样本的 distance 必须与比正样本的 distance 大出 margin 的大小来才行,当两者差距足够大时,loss 就等于 0 了。

假设 处于理想情况下等于 0,那么由于 的存在, 如果不是很大的话,仍然会产生 loss,只有当 大于 时才会让 loss = 0,所以 越大,对 embedding 的修正就越严格。

错误三元组的构造方法:将 中的头实体、关系和尾实体其中之一随机替换为其他实体或关系来得到。

2.3 评价指标

链接预测是用来预测三元组 (h,r,t) 中缺失实体 h, t 或 r 的任务,对于每一个缺失的实体,模型将被要求用所有的知识图谱中的实体作为候选项进行计算,并进行排名,而不是单纯给出一个最优的预测结果。

  1. Mean rank - 正确三元组在测试样本中的得分排名,越小越好

首先对于每个 testing triple,以预测 tail entity 为例,我们将 中的 t 用 KG 中的每个 entity 来代替,然后通过 来计算分数,这样就可以得到一系列的分数,然后将这些分数排列。我们知道 f 函数值越小越好,那么在前面的排列中,排地越靠前越好。重点来了,我们去看每个 testing triple 中正确答案(也就是真实的 t)在上述序列中排多少位,比如 排 100, 排 200, 排 60 ....,之后对这些排名求平均,就得到 mean rank 值了。

  1. Hits@10 - 得分排名前 n 名的三元组中,正确三元组的占比,越大越好

还是按照上述进行 f 函数值排列,然后看每个 testing triple 正确答案是否排在序列的前十,如果在的话就计数 +1,最终 (排在前十的个数) / (总个数) 就等于 Hits@10。

在原论文中,由于这个 model 比较老了,其 baseline 也没啥参考性,就不做研究了,具体的实验可参考论文。

3. TransE 优缺点

优点:与以往模型相比,TransE 模型参数较少,计算复杂度低,却能直接建立实体和关系之间的复杂语义联系,在 WordNet 和 Freebase 等 dataset 上较以往模型的 performance 有了显著提升,特别是在大规模稀疏 KG 上,TransE 的性能尤其惊人。

缺点:在处理复杂关系(1-N、N-1 和 N-N)时,性能显著降低,这与 TransE 的模型假设有密切关系。假设有 (美国,总统,奥巴马)和(美国,总统,布什),这里的“总统”关系是典型的 1-N 的复杂关系,如果用 TransE 对其进行学习,则会有:

image-20220917220710170

那么这将会使奥巴马和布什的 vector 变得相同。所以由于这些复杂关系的存在,导致 TransE 学习得到的实体表示区分性较低。

4. TransE 实现

这里选择用 pytorch 来实现 TransE 模型。

4.1 __init__ 函数

其参数有:

  • ent_num:entity 的数量
  • rel_num:relationship 的数量
  • dim:每个 embedding vector 的维度
  • norm:在计算 时是使用 L1 norm 还是 L2 norm,即
  • margin:损失函数中的间隔,是个 hyper-parameter
  • :损失函数计算中的正则化项参数
class TransE(nn.Module):
    def __init__(self, ent_num, rel_num, device, dim=100, norm=1, margin=2.0, alpha=0.01):
        super(TransE, self).__init__()
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.device = device
        self.dim = dim
        self.norm = norm # 使用L1范数还是L2范数
        self.margin = margin
        self.alpha = alpha

        # 初始化实体和关系表示向量
        self.ent_embeddings = nn.Embedding(self.ent_num, self.dim)
        torch.nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        self.ent_embeddings.weight.data = F.normalize(self.ent_embeddings.weight.data, 21)

        self.rel_embeddings = nn.Embedding(self.rel_num, self.dim)
        torch.nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
        self.rel_embeddings.weight.data = F.normalize(self.rel_embeddings.weight.data, 21)

        # 损失函数
        self.criterion = nn.MarginRankingLoss(margin=self.margin)

初始化 embedding matrix 时,直接用 nn.Embedding 来完成,参数分别是 entity 的数量和每个 embedding vector 的维数,这样得到的就是一个 ent_num * dim 大小的 Embedding Matrix。

torch.nn.init.xavier_uniform_ 是一个服从均匀分布的 Glorot 初始化器,在这里做的就是对 Embedding Matrix 中每个位置填充一个 xavier_uniform 初始化的值,这些值从均匀分布 中采样得到,这里的 是:

在这里,对于 Embedding 这样的二维矩阵来说,fan_in 和 fan_out 就是矩阵的长和宽,gain 默认为 1。其完整具体行为可参考 pytorch 初始化器文档。

F.normalize(self.ent_embeddings.weight.data, 2, 1) 这一步就是对 ent_embeddings 的每一个值除以 dim = 1 上的 2 范数值,注意 ent_embeddings.weight.data 的 size 是 (ent_num, embs_dim)。具体来说就是这一步把每行都除以该行下所有元素平方和的开方,也就是

损失函数这里先跳过,之后计算损失的步骤一同来看。

4.2 从 ent_idx 到 ent_embs

由于 network 的输入是 ent_idx,因此需要将其根据 embedding matrix 转换成 ent_embs。我们通过 get_ent_resps 函数来完成,其实就是个静态查表的操作:

class TransE(nn.Module):
 ...
 def get_ent_resps(self, ent_idx): #[batch]
        return self.ent_embeddings(ent_idx) # [batch, emb]

4.3 计算 energy

它衡量了 的距离,可以采用 L1 或 L2 norm 来算,具体采用哪个由 __init__ 函数中的 self.norm 来决定:

class TransE(nn.Module):
 ...
 def distance(self, h_idx, r_idx, t_idx):
        h_embs = self.ent_embeddings(h_idx) # [batch, emb]
        r_embs = self.rel_embeddings(r_idx) # [batch, emb]
        t_embs = self.ent_embeddings(t_idx) # [batch, emb]
        scores = h_embs + r_embs - t_embs
  
  # norm 是计算 loss 时的正则化项
        norms = (torch.mean(h_embs.norm(p=self.norm, dim=1) - 1.0)
                 + torch.mean(r_embs ** 2) +
                 torch.mean(t_embs.norm(p=self.norm, dim=1) - 1.0)) / 3

        return scores.norm(p=self.norm, dim=1), norms

4.4 计算 loss

self.criterion 是通过实例化 MarginRankingLoss 得到的,这个类的初始化接收 margin 参数,实例化得到 self.criterion,其计算方式如下:

借助于此,我们可以实现计算 loss 的代码:

class TransE(nn.Module):
 ...
 def loss(self, positive_distances, negative_distances):
        target = torch.tensor([-1], dtype=torch.float, device=self.device)
        return self.criterion(positive_distances, negative_distances, target)

positive_distances 就是 ,negative_distances 就是 ,target = [-1],代入 criterion 的计算公式就是我们计算 一对正样本和负样本的 loss 了。

4.5 forward

class TransE(nn.Module):
 ...
 def forward(self, ph_idx, pr_idx, pt_idx, nh_idx, nr_idx, nt_idx):
        pos_distances, pos_norms = self.scoring(ph_idx, pr_idx, pt_idx)
        neg_distances, neg_norms = self.scoring(nh_idx, nr_idx, nt_idx)

        tmp_loss = self.loss(pos_distances, neg_distances)
        tmp_loss += self.alpha * pos_norms   # 正则化项
        tmp_loss += self.alpha * neg_norms   # 正则化项

        return tmp_loss, pos_distances, neg_distances

以上我们讲完了 TransE 模型的定义,接下来就是讲对 TransE 模型的训练了,只要理解了 TransE 模型的定义,其训练应该不是难事。


关于我的 TransE 及知识表示学习模型的实现:https://github.com/yubinCloud/KRL,仓库中的 transe.ipynb 在更改一下 dataset 的位置后即可运行,dataset 可以在 GitHub 中下载,比如 KGDatasets。

如果有疑问,欢迎讨论。

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/145312.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于R的Bilibili视频数据建模及分析——建模-因子分析篇

基于R的Bilibili视频数据建模及分析——建模-因子分析篇 文章目录基于R的Bilibili视频数据建模及分析——建模-因子分析篇0、写在前面1、数据分析1.1 建模-因子分析1.2 对数线性模型1.3 主成分分析1.4 因子分析1.5 多维标度法2、参考资料0、写在前面 实验环境 Python版本&#…

防火墙命令

启动: systemctl start firewalld 查看状态: systemctl status firewalld 停止:systemctl stop firewalld 禁用:systemctl disable firewalld 怎么开启一个端口呢 添加 firewall-cmd --zonepublic --add-port80/tcp --permanent …

easyx保姆级教程---->从游戏玩家到游戏制作者

请点击这里&#xff1a;安装教程 1.头文件 #include<easyx.h> //这个是只包含最新的API(函数接口) #include<graphics.h> //这个头文件包含了上面的&#xff0c;还包含了已经不推荐使用的函数2.窗口 1.初始化绘制窗口 initgraph(width,height,flag); //窗…

Domino Web应用中的搜索功能和结果选择问题

大家好&#xff0c;才是真的好。 还有不到十天Domino多瑙河版本就将发布&#xff0c;在此之前&#xff0c;我们还是讲述一下Web中的搜索技术。 废话不多说&#xff0c;我们直接上干货。 Notes应用的视图在Web浏览器中可以直接展现&#xff0c;并且可选择。 如果这样展现的话…

【QGIS入门实战精品教程】8.1:QGIS制作地图案例教程

文章目录 一、加载矢量数据二、加载影像底图三、美化矢量数据四、切换到排版视图五、添加经纬度格网六、添加其他修饰元素七、地图输出一、加载矢量数据 加载本实验数据基础数据.gpkg中的甘肃省政区矢量数据,如下所示: 二、加载影像底图 QGIS加载在线地图案例教程参考: 【…

5、Java中的JDBCJDBCUtilsJDBC控制事务getResource中文或有空格路径处理ResourceBundle演示

JDBC&#xff1a; 1. 概念&#xff1a;Java DataBase Connectivity Java 数据库连接&#xff0c; Java语言操作数据库 * JDBC本质&#xff1a;其实是官方&#xff08;sun公司&#xff09;定义的一套操作所有关系型数据库的规则&#xff0c;即接口。各个数据库厂商去实现这…

回收租赁商城系统功能拆解04讲-商品品牌

回收租赁系统适用于物品回收、物品租赁、二手买卖交易等三大场景。 可以快速帮助企业搭建类似闲鱼回收/爱回收/爱租机/人人租等回收租赁商城。 回收租赁系统支持智能评估回收价格&#xff0c;后台调整最终回收价&#xff0c;用户同意回收后系统即刻放款&#xff0c;用户微信零…

第04章 程序控制结构

在程序中&#xff0c;程序运行的流程控制决定程序是如何执行的。 顺序控制 介绍&#xff1a; 程序从上到下的逐行的执行&#xff0c;中间没有任何判断和跳转。 使用&#xff1a;java中定义变量时&#xff0c;采用合法的前向引用。如&#xff1a; public class Test{int num…

【虚幻引擎】UE4/UE5像素流在广域网上(云)部署(多实例)

一、选择云服务器 每个云平台都提供许多预设的镜像选择&#xff0c;由于像素流技术目前只支持Windows操作系统&#xff0c;所以我们需要选择Windows Server的镜像&#xff0c;2012/2016/2019皆可。我们这里选择了Windows Server 2016 R2 简体中文版的镜像&#xff0c;之所以选择…

【SSM整合】对Spring、SpringMVC、MyBatis的整合,以及Bootstrap的使用,简单的新闻管理系统

✅作者简介&#xff1a;热爱Java后端开发的一名学习者&#xff0c;大家可以跟我一起讨论各种问题喔。 &#x1f34e;个人主页&#xff1a;Hhzzy99 &#x1f34a;个人信条&#xff1a;坚持就是胜利&#xff01; &#x1f49e;当前专栏&#xff1a;【Spring】 &#x1f96d;本文内…

代码随想录第53天|● 1143.最长公共子序列 ● 1035.不相交的线 ● 53. 最大子序和 动态规划

1143.最长公共子序列 和718.最长重复子数组类似 包括二维数组初始化这些 不同之处在于递推公式主要就是两大情况&#xff1a; text1[i - 1] 与 text2[j - 1]相同&#xff0c;text1[i - 1] 与 text2[j - 1]不相同 如果text1[i - 1] 与 text2[j - 1]相同&#xff0c;那么找到了…

Windows/Linux日志分析

Windows日志分析 Windows系统日志是记录系统中硬件、软件和系统问题的信息&#xff0c;同时还可以监视系统中发生的事件。用户可以通过它来检查错误发生的原因&#xff0c;或者寻找受到攻击时攻击者留下的痕迹。 Windows主要有以下三类日志记录系统事件&#xff1a;应用程序日志…

【链表】leetcode707.设计链表(C/C++/Java/Js)

leetcode707.设计链表1 题目2 思路3 代码3.1 C版本3.2 C版本3.3 Java版本3.3.1 单链表3.3.2 双链表3.4 JavaScript版本4 总结1 题目 题源链接 设计链表的实现。您可以选择使用单链表或双链表。单链表中的节点应该具有两个属性&#xff1a;val 和 next。val 是当前节点的值&…

2022年地图产业研究报告

第一章 行业概况 地图是按照一定法则&#xff0c;有选择地以二维或多维形式与手段在平面或球面上表示地球&#xff08;或其它星球&#xff09;若干现象的图形或图像&#xff0c;它具有严格的数学基础、符号系统、文字注记&#xff0c;并能用地图概括原则&#xff0c;科学地反映…

canvasjs javascript-charts 3.7.3 Crack

canvasjs javascript-charts/ 3.7.3 具有 30 多种图表类型的 JavaScript 图表库 具有 10 倍性能和 30 多种图表类型的 JavaScript 图表和图形库。核心 JavaScript 图表库是独立的&#xff0c;但也带有流行框架的组件&#xff0c;如 React、Angular、Vue 等。图表响应迅速&#…

14、RH850 F1 RAM存储器介绍

前言: RAM——程序运行中数据的随机存取&#xff08;掉电后数据消失&#xff09;整个程序中&#xff0c;所用到的需要被改写的量&#xff0c;都存储在RAM中&#xff0c;“被改变的量”包括全局变量、局部变量、堆栈段&#xff0c;此专栏会有针对SPI的工作原理的详细介绍。 一、…

性能优化系列之如何选择合适的WebView内核?

文章の目录一、iOS UIWebView1、优点2、不足二、iOS WKWebView1、优势2、不足三、Android WebKit 和 Chromium四、Android 第三方1、X5 内核五、选型建议写在最后一、iOS UIWebView 1、优点 从 iOS 2 开始就作为 App 内展示 Web 内容的容器排版布局能力强 2、不足 内存泄露…

将两个对象以指定方法按指定轴对齐的DataFrame.align()方法

【小白从小学Python、C、Java】【计算机等级考试500强双证书】【Python-数据分析】将两个对象以指定方法按指定轴对齐DataFrame.align()选择题关于以下python代码说法错误的一项是?import pandas as pddf1 pd.DataFrame({"A": [1,2],"B":[3,4]})df2 pd.…

MySQL延时关联使查询速度提升N倍

以下内容也可以观看视频教程&#xff1a; https://space.bilibili.com/431152063先来看下面的sql语句&#xff1a; select * from orderinfo limit 1000000, 100目前orderinfo表中的数据大概是1亿行 查询耗时大概2秒多&#xff0c;如果将sql中的返回所有字段改成只返回dbid字段…

Linux驱动开发基础__APP怎么读取按键值

目录 1 妈妈怎么知道孩子醒了 2 APP读取按键的4种方法 2.1 查询方式 2.2 休眠-环形方式 2.3 poll方式 2.4 异步通知方式 在做单片机开发时&#xff0c;要读取 GPIO 按键&#xff0c;我们通常是执行一个循环&#xff0c;不断地 检测 GPIO 引脚电平有没有发生变化。但是在 Li…