如何利用DGL官方库中的rgcn链接预测代码跑自己的数据集(如何在DGL库的链接预测数据集模块定义自己的数据集类)

news2025/1/12 20:55:46

最近在忙我的省创,是有关于知识图谱的,其中有一个内容是使用rgcn的链接预测方法跑自己的数据集,我是用的dgl库中给出的在pytorch环境下实现rgcn的链接预测的代码,相关链接贴在这里:

dgl库中关于rgcn的介绍文档

dgl库中在pytorch环境下实现rgcn的链接预测的代码

这个代码给的示例就是使用FB15k237数据集,调用方法是这样的:

from dgl.data.knowledge_graph import FB15k237Dataset
data = FB15k237Dataset(reverse=False)
graph = data[0]
print("graph",graph)

这里就调用了FB15k237数据集,返回的的data[0]就是使用dgl库使用该数据集构建的图g

我一开始想用自己的数据构图,然后使用rgcn的代码跑我自己的数据集,但是我不知道它的构图是如何实现的,于是我修改了rgcn的代码,实现了自己的构图方式如下,就是使用入结点出节点和边的编号列表构图:

g = dgl.graph((src, dst), num_nodes=num_nodes)
g.edata[dgl.ETYPE] = rel

鉴于rgcn示例里使用的FB15k237数据集的图的属性有'train_mask''test_mask'等属性,我就把rgcn代码里有关构图的部分全改成我自己的了,修改过后的完整可运行rgcn代码如下。

这个代码需要自己提供entity.txtrelation.txttrain.txtvalid.txttest.txt五个文件,entity.txtrelation.txt分别代表实体编号到实体描述的映射,关系编号到关系描述的映射,类似这样:

在这里插入图片描述
train.txtvalid.txttest.txt这三个文件就代表训练集,验证集和测试集的已经被映射为编号的(h,r,t)格式的三元组,类似这样:

在这里插入图片描述
在代码中写入对应的自己的数据集已经处理好的这五个文件的地址,运行下面的文件就可以运行完整的rgcn代码了:

import numpy as np
import torch
import torch.nn as nn
import scipy as sp
import torch.nn.functional as F
import dgl
from dgl.data.knowledge_graph import FB15k237Dataset
from dgl.data.knowledge_graph import FB15kDataset
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch import RelGraphConv
import tqdm

# for building training/testing graphs
def get_subset_g(g, mask, num_rels, bidirected=False):
    src, dst = g.edges()
    sub_src = src[mask]
    sub_dst = dst[mask]
    sub_rel = g.edata['etype'][mask]

    if bidirected:
        sub_src, sub_dst = torch.cat([sub_src, sub_dst]), torch.cat([sub_dst, sub_src])
        sub_rel = torch.cat([sub_rel, sub_rel + num_rels])

    sub_g = dgl.graph((sub_src, sub_dst), num_nodes=g.num_nodes())
    sub_g.edata[dgl.ETYPE] = sub_rel
    return sub_g

class GlobalUniform:
    def __init__(self, g, sample_size):
        self.sample_size = sample_size
        self.eids = np.arange(g.num_edges(),dtype='int64')
    def sample(self):
        return torch.from_numpy(np.random.choice(self.eids, self.sample_size))

class NegativeSampler:
    def __init__(self, k=10): # negative sampling rate = 10
        self.k = k

    def sample(self, pos_samples, num_nodes):
        batch_size = len(pos_samples)
        neg_batch_size = batch_size * self.k
        neg_samples = np.tile(pos_samples, (self.k, 1))

        values = np.random.randint(num_nodes, size=neg_batch_size)
        choices = np.random.uniform(size=neg_batch_size)
        subj = choices > 0.5
        obj = choices <= 0.5
        neg_samples[subj, 0] = values[subj]
        neg_samples[obj, 2] = values[obj]
        samples = np.concatenate((pos_samples, neg_samples))

        # binary labels indicating positive and negative samples
        labels = np.zeros(batch_size * (self.k + 1), dtype=np.float32)
        labels[:batch_size] = 1

        return torch.from_numpy(samples), torch.from_numpy(labels)

class SubgraphIterator:
    def __init__(self, g, num_rels, sample_size=30000, num_epochs=6000):
        self.g = g
        self.num_rels = num_rels
        self.sample_size = sample_size
        self.num_epochs = num_epochs
        self.pos_sampler = GlobalUniform(g, sample_size)
        self.neg_sampler = NegativeSampler()

    def __len__(self):
        return self.num_epochs

    def __getitem__(self, i):
        eids = self.pos_sampler.sample()
        src, dst = self.g.find_edges(eids)
        src, dst = src.numpy(), dst.numpy()
        rel = self.g.edata[dgl.ETYPE][eids].numpy()

        # relabel nodes to have consecutive node IDs
        uniq_v, edges = np.unique((src, dst), return_inverse=True)
        num_nodes = len(uniq_v)
        # edges is the concatenation of src, dst with relabeled ID
        src, dst = np.reshape(edges, (2, -1))
        relabeled_data = np.stack((src, rel, dst)).transpose()

        samples, labels = self.neg_sampler.sample(relabeled_data, num_nodes)

        # use only half of the positive edges
        chosen_ids = np.random.choice(np.arange(self.sample_size),
                                      size=int(self.sample_size / 2),
                                      replace=False)
        src = src[chosen_ids]
        dst = dst[chosen_ids]
        rel = rel[chosen_ids]
        src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))
        rel = np.concatenate((rel, rel + self.num_rels))
        sub_g = dgl.graph((src, dst), num_nodes=num_nodes)
        sub_g.edata[dgl.ETYPE] = torch.from_numpy(rel)
        sub_g.edata['norm'] = dgl.norm_by_dst(sub_g).unsqueeze(-1)
        uniq_v = torch.from_numpy(uniq_v).view(-1).long()

        return sub_g, uniq_v, samples, labels

class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, num_rels):
        super().__init__()
        # two-layer RGCN
        self.emb = nn.Embedding(num_nodes, h_dim)
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='bdd',
                                  num_bases=100, self_loop=True)
        self.conv2 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='bdd',
                                  num_bases=100, self_loop=True)
        self.dropout = nn.Dropout(0.2)

    def forward(self, g, nids):
        x = self.emb(nids)
        h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata['norm']))
        h = self.dropout(h)
        h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata['norm'])
        return self.dropout(h)

class LinkPredict(nn.Module):
    def __init__(self, num_nodes, num_rels, h_dim = 500, reg_param=0.01):
        super().__init__()
        self.rgcn = RGCN(num_nodes, h_dim, num_rels * 2)
        self.reg_param = reg_param
        self.w_relation = nn.Parameter(torch.Tensor(num_rels, h_dim))
        nn.init.xavier_uniform_(self.w_relation,
                                gain=nn.init.calculate_gain('relu'))

    def calc_score(self, embedding, triplets):
        s = embedding[triplets[:,0]]
        r = self.w_relation[triplets[:,1]]
        o = embedding[triplets[:,2]]
        score = torch.sum(s * r * o, dim=1)
        return score

    def forward(self, g, nids):
        return self.rgcn(g, nids)

    def regularization_loss(self, embedding):
        return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2))

    def get_loss(self, embed, triplets, labels):
        # each row in the triplets is a 3-tuple of (source, relation, destination)
        score = self.calc_score(embed, triplets)
        predict_loss = F.binary_cross_entropy_with_logits(score, labels)
        reg_loss = self.regularization_loss(embed)
        return predict_loss + self.reg_param * reg_loss

def filter(triplets_to_filter, target_s, target_r, target_o, num_nodes, filter_o=True):
    """Get candidate heads or tails to score"""
    target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
    # Add the ground truth node first
    if filter_o:
        candidate_nodes = [target_o]
    else:
        candidate_nodes = [target_s]
    for e in range(num_nodes):
        triplet = (target_s, target_r, e) if filter_o else (e, target_r, target_o)
        # Do not consider a node if it leads to a real triplet
        if triplet not in triplets_to_filter:
            candidate_nodes.append(e)
    return torch.LongTensor(candidate_nodes)

def perturb_and_get_filtered_rank(emb, w, s, r, o, test_size, triplets_to_filter, filter_o=True):
    """Perturb subject or object in the triplets"""
    num_nodes = emb.shape[0]
    ranks = []
    for idx in tqdm.tqdm(range(test_size), desc="Evaluate"):
        target_s = s[idx]
        target_r = r[idx]
        target_o = o[idx]
        candidate_nodes = filter(triplets_to_filter, target_s, target_r,
                                 target_o, num_nodes, filter_o=filter_o)
        if filter_o:
            emb_s = emb[target_s]
            emb_o = emb[candidate_nodes]
        else:
            emb_s = emb[candidate_nodes]
            emb_o = emb[target_o]
        target_idx = 0
        emb_r = w[target_r]
        emb_triplet = emb_s * emb_r * emb_o
        scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))

        _, indices = torch.sort(scores, descending=True)
        rank = int((indices == target_idx).nonzero())
        ranks.append(rank)
    return torch.LongTensor(ranks)

def calc_mrr(emb, w,  triplets_to_filter, batch_size=100, filter=True):
    with torch.no_grad():
        test_triplets = triplets_to_filter
        s, r, o = test_triplets[:,0], test_triplets[:,1], test_triplets[:,2]
        test_size = len(s)
        triplets_to_filter = {tuple(triplet) for triplet in triplets_to_filter.tolist()}
        ranks_s = perturb_and_get_filtered_rank(emb, w, s, r, o, test_size,
                                                triplets_to_filter, filter_o=False)
        ranks_o = perturb_and_get_filtered_rank(emb, w, s, r, o,
                                                test_size, triplets_to_filter)
        ranks = torch.cat([ranks_s, ranks_o])
        ranks += 1 # change to 1-indexed
        mrr = torch.mean(1.0 / ranks.float()).item()
        mr = torch.mean(ranks.float()).item()
        print("MRR (filtered): {:.6f}".format(mrr))
        print("MR (filtered): {:.6f}".format(mr))
        hits=[1,3,10]
        for hit in hits:
            avg_count = torch.mean((ranks <= hit).float())
            print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count.item()))
    return mrr

def train(dataloader, test_g, test_nids, triplets, device, model_state_file, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    best_mrr = 0
    for epoch, batch_data in enumerate(dataloader): # single graph batch
        model.train()
        g, train_nids, edges, labels = batch_data
        g = g.to(device)
        train_nids = train_nids.to(device)
        edges = edges.to(device)
        labels = labels.to(device)
        embed = model(g, train_nids)
        loss = model.get_loss(embed, edges, labels)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # clip gradients
        optimizer.step()
        print("Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}".format(epoch, loss.item(), best_mrr))
        if (epoch + 1) % 500 == 0:
            # perform validation on CPU because full graph is too large
            model = model.cpu()
            model.eval()
            embed = model(test_g, test_nids)
            mrr = calc_mrr(embed, model.w_relation,  triplets,
                           batch_size=500)
            # save best model
            if best_mrr < mrr:
                best_mrr = mrr
                torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
            model = model.to(device)

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Training with DGL built-in RGCN module')

    # load and preprocess dataset
    # data = FB15k237Dataset(reverse=False)
    # data = FB15kDataset(reverse=False)
    entityfile=r'data/entity.txt'
    relationfile=r'data/relation.txt'
    f1 = open(entityfile, 'r')
    f2 = open(relationfile, 'r')
    entity=[]
    relation=[]
    for line in f1:
        l=line.strip().split("\t")
        entity.append(int(l[0]))
    for line in f2:
        l=line.strip().split("\t")
        relation.append(int(l[0]))
    num_nodes=len(entity)
    num_rels=len(relation)
    n_entities=num_nodes
    print("# entities:",num_nodes)
    print("# relations:",num_rels)
    
    trainfile=r'data/train.txt'
    f3 = open(trainfile, 'r')
    src_train=[]
    rel_train=[]
    dst_train=[]
    for line in f3:
        l=line.strip().split("\t")
        h=int(l[0])
        r=int(l[1])
        t=int(l[2])
        src_train.append(h)
        rel_train.append(r)
        dst_train.append(t)
    print("# training edges: ",len(src_train))
    src_train=torch.LongTensor(src_train)
    rel_train=torch.LongTensor(rel_train)
    dst_train=torch.LongTensor(dst_train)
    train_g = dgl.graph((src_train, dst_train), num_nodes=num_nodes)
    train_g.edata[dgl.ETYPE] = rel_train
    
    src_test, dst_test = torch.cat([src_train, dst_train]), torch.cat([dst_train,src_train])
    rel_test = torch.cat([rel_train, rel_train + num_rels])
    test_g = dgl.graph((src_test, dst_test), num_nodes=num_nodes)
    test_g.edata[dgl.ETYPE] = rel_test
    test_g.edata['norm'] = dgl.norm_by_dst(test_g).unsqueeze(-1)
    test_nids = torch.arange(0, num_nodes)
    
    subg_iter = SubgraphIterator(train_g, num_rels) # uniform edge sampling
    dataloader = GraphDataLoader(subg_iter, batch_size=1, collate_fn=lambda x: x[0])

    validfile=r'data/valid.txt'
    f4 = open(validfile, 'r')
    num_valid=0
    for line in f4:
        num_valid+=1
    print("# validation edges: ",num_valid)
    
    # Prepare data for metric computation
    testfile=r'data/test.txt'
    f5 = open(testfile, 'r')
    src=[]
    rel=[]
    dst=[]
    for line in f5:
        l=line.strip().split("\t")
        h=int(l[0])
        r=int(l[1])
        t=int(l[2])
        src.append(h)
        rel.append(r)
        dst.append(t)
    print("# testing edges: ",len(src))
    src=torch.LongTensor(src)
    rel=torch.LongTensor(rel)
    dst=torch.LongTensor(dst)
    triplets_test = torch.stack([src,rel, dst], dim=1)

    # create RGCN model
    model = LinkPredict(num_nodes, num_rels).to(device)

    # train
    model_state_file = 'model_state.pth'
    train(dataloader, test_g, test_nids, triplets_test, device, model_state_file, model)

    # testing
    print("Testing...")
    checkpoint = torch.load(model_state_file)
    model = model.cpu() # test on CPU
    model.eval()
    model.load_state_dict(checkpoint['state_dict'])
    embed = model(test_g, test_nids)
    best_mrr = calc_mrr(embed, model.w_relation,triplets_test,
                        batch_size=500)
    print("Best MRR {:.4f} achieved using the epoch {:04d}".format(best_mrr, checkpoint['epoch']))

但是,这个代码的效果并不太好,贴在这里只是做个过程记录,同样的数据集,为什么这样简单的构图效果就没有dgl库里自己构图的效果好呢?说实话我也不知道(°ー°〃)我也看了dgl库里处理数据然后构图的代码,确实要精细很多,我就认为是预处理数据的方式不一样导致效果的差别吧。因此下面要说的就是如何在如何在DGL库的链接预测数据集模块定义自己的数据集类,将自己的数据集输入,使用dgl库中处理数据的方法处理我们的数据,再像刚刚调用FB15k237数据集那样调用自己的数据集。

- step 1 :

找到你的dgl.data.knowledge_graph.py文件,(我这里使用的版本是dgl 0.9.0),在这个文件中,定义了FB15k237DatasetFB15DatasetWN18Dataset三个常用的知识图谱数据集类,我们添加一个自己的数据集类MyDataset(其实就是copy了一下别的类(°ー°〃))

在这里插入图片描述
name改成mydata:

class MyDataset(KnowledgeGraphDataset):
    
    def __init__(self, reverse=True, raw_dir=None, force_reload=False,
                 verbose=True, transform=None):
        name = 'mydata'
        super(MyDataset, self).__init__(name, reverse, raw_dir,
                                              force_reload, verbose, transform)

    def __getitem__(self, idx):
        r"""Gets the graph object """
        return super(MyDataset, self).__getitem__(idx)

    def __len__(self):
        r"""The number of graphs in the dataset."""
        return super(MyDataset, self).__len__()

- step 2

找到你的dgl.data.dgl_dataset.py文件,找到下图对应的代码位置,加入框框内的代码:
(至于为什么要这样呢,,,,自己看代码吧,虽然我也很想做记录,方便自己下次看懂,但是感觉要讲的话将不太清楚,打半天字解释不如自己看看代码咋写的 ┭┮﹏┭┮)

if self.name=='mydata':
     return os.path.join(self.raw_dir)

在这里插入图片描述

- step 3

在rgcn的链接预测代码里调用一下自己的数据就好啦,下面是一个简单的demo,这样就可以调用自己的数据集类了。

from dgl.data.knowledge_graph import MyDataset
dataset = MyDataset(raw_dir=r'你自己装数据集的文件夹位置',reverse=False)

在这里插入图片描述
- step 4

还有十分重要的一点就是,数据集的格式,我是把自己的数据集都设成了和它调用的FB15k237数据集一样的格式,因为step 3中要写入的文件夹地址内要包含的文件有5个:entities.dictrelations.dicttrain.txtvalid.txttest.txt

在这里插入图片描述

entities.dictrelations.dict分别代表实体编号到实体描述的映射,关系编号到关系描述的映射,类似这样:

在这里插入图片描述

train.txtvalid.txttest.txt这三个文件代表训练集,验证集和测试集的还没有被映射为编号的(h,r,t)格式的三元组,类似这样:(它们中间的间隔均是'\t')

在这里插入图片描述

把我改过的最终的rgcn代码贴在下面,做个记录,其中我对calc_mrr函数做了修改的,它原本的代码里只有mrr一个评估指标,我增加了mrhist@1hist@3hist@10这几个指标,在代码里看吧:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.data.knowledge_graph import FB15k237Dataset
from dgl.data.knowledge_graph import FB15kDataset
from dgl.data.knowledge_graph import MyDataset
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch import RelGraphConv
import tqdm

# for building training/testing graphs
def get_subset_g(g, mask, num_rels, bidirected=False):
    src, dst = g.edges()
    sub_src = src[mask]
    sub_dst = dst[mask]
    sub_rel = g.edata['etype'][mask]

    if bidirected:
        sub_src, sub_dst = torch.cat([sub_src, sub_dst]), torch.cat([sub_dst, sub_src])
        sub_rel = torch.cat([sub_rel, sub_rel + num_rels])

    sub_g = dgl.graph((sub_src, sub_dst), num_nodes=g.num_nodes())
    sub_g.edata[dgl.ETYPE] = sub_rel
    return sub_g

class GlobalUniform:
    def __init__(self, g, sample_size):
        self.sample_size = sample_size
        self.eids = np.arange(g.num_edges())

    def sample(self):
        return torch.from_numpy(np.random.choice(self.eids, self.sample_size))

class NegativeSampler:
    def __init__(self, k=10): # negative sampling rate = 10
        self.k = k

    def sample(self, pos_samples, num_nodes):
        batch_size = len(pos_samples)
        neg_batch_size = batch_size * self.k
        neg_samples = np.tile(pos_samples, (self.k, 1))

        values = np.random.randint(num_nodes, size=neg_batch_size)
        choices = np.random.uniform(size=neg_batch_size)
        subj = choices > 0.5
        obj = choices <= 0.5
        neg_samples[subj, 0] = values[subj]
        neg_samples[obj, 2] = values[obj]
        samples = np.concatenate((pos_samples, neg_samples))

        # binary labels indicating positive and negative samples
        labels = np.zeros(batch_size * (self.k + 1), dtype=np.float32)
        labels[:batch_size] = 1

        return torch.from_numpy(samples), torch.from_numpy(labels)

class SubgraphIterator:
    def __init__(self, g, num_rels, sample_size=30000, num_epochs=6000):
        self.g = g
        self.num_rels = num_rels
        self.sample_size = sample_size
        self.num_epochs = num_epochs
        self.pos_sampler = GlobalUniform(g, sample_size)
        self.neg_sampler = NegativeSampler()

    def __len__(self):
        return self.num_epochs

    def __getitem__(self, i):
        eids = self.pos_sampler.sample()
        src, dst = self.g.find_edges(eids)
        src, dst = src.numpy(), dst.numpy()
        rel = self.g.edata[dgl.ETYPE][eids].numpy()

        # relabel nodes to have consecutive node IDs
        uniq_v, edges = np.unique((src, dst), return_inverse=True)
        num_nodes = len(uniq_v)
        # edges is the concatenation of src, dst with relabeled ID
        src, dst = np.reshape(edges, (2, -1))
        relabeled_data = np.stack((src, rel, dst)).transpose()

        samples, labels = self.neg_sampler.sample(relabeled_data, num_nodes)

        # use only half of the positive edges
        chosen_ids = np.random.choice(np.arange(self.sample_size),
                                      size=int(self.sample_size / 2),
                                      replace=False)
        src = src[chosen_ids]
        dst = dst[chosen_ids]
        rel = rel[chosen_ids]
        src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))
        rel = np.concatenate((rel, rel + self.num_rels))
        sub_g = dgl.graph((src, dst), num_nodes=num_nodes)
        sub_g.edata[dgl.ETYPE] = torch.from_numpy(rel)
        sub_g.edata['norm'] = dgl.norm_by_dst(sub_g).unsqueeze(-1)
        uniq_v = torch.from_numpy(uniq_v).view(-1).long()

        return sub_g, uniq_v, samples, labels

class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, num_rels):
        super().__init__()
        # two-layer RGCN
        self.emb = nn.Embedding(num_nodes, h_dim)
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='bdd',
                                  num_bases=100, self_loop=True)
        self.conv2 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='bdd',
                                  num_bases=100, self_loop=True)
        self.dropout = nn.Dropout(0.2)

    def forward(self, g, nids):
        x = self.emb(nids)
        h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata['norm']))
        h = self.dropout(h)
        h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata['norm'])
        return self.dropout(h)

class LinkPredict(nn.Module):
    def __init__(self, num_nodes, num_rels, h_dim = 500, reg_param=0.01):
        super().__init__()
        self.rgcn = RGCN(num_nodes, h_dim, num_rels * 2)
        self.reg_param = reg_param
        self.w_relation = nn.Parameter(torch.Tensor(num_rels, h_dim))
        nn.init.xavier_uniform_(self.w_relation,
                                gain=nn.init.calculate_gain('relu'))

    def calc_score(self, embedding, triplets):
        s = embedding[triplets[:,0]]
        r = self.w_relation[triplets[:,1]]
        o = embedding[triplets[:,2]]
        score = torch.sum(s * r * o, dim=1)
        return score

    def forward(self, g, nids):
        return self.rgcn(g, nids)

    def regularization_loss(self, embedding):
        return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2))

    def get_loss(self, embed, triplets, labels):
        # each row in the triplets is a 3-tuple of (source, relation, destination)
        score = self.calc_score(embed, triplets)
        predict_loss = F.binary_cross_entropy_with_logits(score, labels)
        reg_loss = self.regularization_loss(embed)
        return predict_loss + self.reg_param * reg_loss

def filter(triplets_to_filter, target_s, target_r, target_o, num_nodes, filter_o=True):
    """Get candidate heads or tails to score"""
    target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
    # Add the ground truth node first
    if filter_o:
        candidate_nodes = [target_o]
    else:
        candidate_nodes = [target_s]
    for e in range(num_nodes):
        triplet = (target_s, target_r, e) if filter_o else (e, target_r, target_o)
        # Do not consider a node if it leads to a real triplet
        if triplet not in triplets_to_filter:
            candidate_nodes.append(e)
    return torch.LongTensor(candidate_nodes)

def perturb_and_get_filtered_rank(emb, w, s, r, o, test_size, triplets_to_filter, filter_o=True):
    """Perturb subject or object in the triplets"""
    num_nodes = emb.shape[0]
    ranks = []
    for idx in tqdm.tqdm(range(test_size), desc="Evaluate"):
        target_s = s[idx]
        target_r = r[idx]
        target_o = o[idx]
        candidate_nodes = filter(triplets_to_filter, target_s, target_r,
                                 target_o, num_nodes, filter_o=filter_o)
        if filter_o:
            emb_s = emb[target_s]
            emb_o = emb[candidate_nodes]
        else:
            emb_s = emb[candidate_nodes]
            emb_o = emb[target_o]
        target_idx = 0
        emb_r = w[target_r]
        emb_triplet = emb_s * emb_r * emb_o
        scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))

        _, indices = torch.sort(scores, descending=True)
        rank = int((indices == target_idx).nonzero())
        ranks.append(rank)
    return torch.LongTensor(ranks)

def calc_mrr(emb, w, test_mask, triplets_to_filter, batch_size=100, filter=True):
    with torch.no_grad():
        test_triplets = triplets_to_filter[test_mask]
        s, r, o = test_triplets[:,0], test_triplets[:,1], test_triplets[:,2]
        test_size = len(s)
        triplets_to_filter = {tuple(triplet) for triplet in triplets_to_filter.tolist()}
        ranks_s = perturb_and_get_filtered_rank(emb, w, s, r, o, test_size,
                                                triplets_to_filter, filter_o=False)
        ranks_o = perturb_and_get_filtered_rank(emb, w, s, r, o,
                                                test_size, triplets_to_filter)
        ranks = torch.cat([ranks_s, ranks_o])
        ranks += 1 # change to 1-indexed
        mrr = torch.mean(1.0 / ranks.float()).item()
        mr = torch.mean(ranks.float()).item()
        print("MRR (filtered): {:.6f}".format(mrr))
        print("MR (filtered): {:.6f}".format(mr))
        hits=[1,3,10]
        for hit in hits:
            avg_count = torch.mean((ranks <= hit).float())
            print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count.item()))
    return mrr

def train(dataloader, test_g, test_nids, test_mask, triplets, device, model_state_file, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    best_mrr = 0
    for epoch, batch_data in enumerate(dataloader): # single graph batch
        model.train()
        g, train_nids, edges, labels = batch_data
        g = g.to(device)
        train_nids = train_nids.to(device)
        edges = edges.to(device)
        labels = labels.to(device)

        embed = model(g, train_nids)
        loss = model.get_loss(embed, edges, labels)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # clip gradients
        optimizer.step()
        print("Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}".format(epoch, loss.item(), best_mrr))
        if (epoch + 1) % 500 == 0:
            # perform validation on CPU because full graph is too large
            model = model.cpu()
            model.eval()
            embed = model(test_g, test_nids)
            mrr = calc_mrr(embed, model.w_relation, test_mask, triplets,
                           batch_size=500)
            # save best model
            if best_mrr < mrr:
                best_mrr = mrr
                torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
            model = model.to(device)

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Training with DGL built-in RGCN module')

    # load and preprocess dataset
    # data = FB15k237Dataset(reverse=False)
    data = MyDataset(raw_dir=r'data/FB15k237',reverse=False)
    
    g = data[0]
    num_nodes = g.num_nodes()
    num_rels = data.num_rels
    train_g = get_subset_g(g, g.edata['train_mask'], num_rels)
    test_g = get_subset_g(g, g.edata['train_mask'], num_rels, bidirected=True)
    test_g.edata['norm'] = dgl.norm_by_dst(test_g).unsqueeze(-1)
    test_nids = torch.arange(0, num_nodes)
    test_mask = g.edata['test_mask']
    subg_iter = SubgraphIterator(train_g, num_rels) # uniform edge sampling
    dataloader = GraphDataLoader(subg_iter, batch_size=1, collate_fn=lambda x: x[0])

    # Prepare data for metric computation
    src, dst = g.edges()
    triplets = torch.stack([src, g.edata['etype'], dst], dim=1)

    # create RGCN model
    model = LinkPredict(num_nodes, num_rels).to(device)

    # train
    model_state_file = 'model_state.pth'
    train(dataloader, test_g, test_nids, test_mask, triplets, device, model_state_file, model)

    # testing
    print("Testing...")
    checkpoint = torch.load(model_state_file)
    model = model.cpu() # test on CPU
    model.eval()
    model.load_state_dict(checkpoint['state_dict'])
    embed = model(test_g, test_nids)
    best_mrr = calc_mrr(embed, model.w_relation, test_mask, triplets,
                        batch_size=500)
    print("Best MRR {:.4f} achieved using the epoch {:04d}".format(best_mrr, checkpoint['epoch']))

跑代码的输出图如下:

在这里插入图片描述

🆗,over!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/27873.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

联盟快应用-如何进行测试?

官方文档&#xff1a;快应用-无需安装&#xff0c;即点即用-厂商联盟官方网站 什么是快应用&#xff1f; 可以简单理解为是另一种小程序。 快应用是一种新的应用形态&#xff0c;以往的手机端应用主要有两种方式&#xff1a;网页、原生应用&#xff1b;网页无需安装&#xff…

阻塞车间调度

阻塞车间调度 当前机器上的作业处理必须保留在该机器上&#xff0c;直到下一台机器可用于处理为止。也就是说如果该作业要执行的下一个工序的机器被使用&#xff0c;则该机器必须被占用。 n个作业必须在m个机器f个工厂上进行处理&#xff0c;在每一个工厂中连续机器之间没有缓…

Android11 framework Handler

Android11 framework Handler引言Handler工作流程MessageQueue主要函数Looper主要函数思考1.一个线程有几个handler&#xff0c;有几个looper2.为什么handler会有内存泄漏3.如果想要在子线程new Handler怎么做&#xff1f;4.子线程中的loop如果消息队列中没有消息处理的时候怎么…

深入底层学git:目录中包含的秘密

1.Git简介 Git具有最优的存储能力&#xff0c;在没有远端git服务器的情况下&#xff0c;git本地就可以独立作为版本管控系统&#xff0c;这其中.git裸仓库中起了关键作用&#xff0c;那么我们一起来看看.git下都放了哪些文件。 打开Git Bash&#xff0c;切换到项目目录&#x…

王道考研——操作系统(第二章 进程管理)(进程;线程)

一、进程的概念、组成、特征 进程的概念 进程的组成——PCB 进程的组成——程序段、数据段 知识滚雪球&#xff1a;程序是如何运行的&#xff1f; 进程的组成 进程的特征 知识回顾与重要考点 二、进程的状态与转换 进程的状态——创建态、就绪态 进程的状态——运行态 进程的…

刷题日记【第十二篇】-笔试必刷题【洗牌+MP3光标位置+年终奖+迷宫问题】

洗牌【编程题】 import java.util.*;public class Main {// 左: i --> 2*i;// 右: in --> 2*i 1;private static void playCard(int[] cards, int n, int k ) {for (int i 0; i < k; i) {//一次洗牌的顺序int[] newCards new int[cards.length];//遍历编号为0-n-1…

【Servlet】2:认识一下Web服务器——Tomcat

目录 第三章 | Tomcat 认识与配置 | 章节概述 | HTTP服务器概述 | Tomcat 安装与配置 | Tomcat 的目录结构、端口号 第四章 | Tomcat 基本使用 | 章节概述 | 本地Tomcat 静态资源网站访问 | IDEATomcat 静态资源网站访问 | IDEA中最基础web项目的目录结构 本文章属于后…

从零开始操作系统-07:APIC

这一节主要主要是APIC。 所需要的文件在Github&#xff1a;https://github.com/yongkangluo/Ubuntu20.04OS/tree/main/Files/Lec7-ExternalInterrupt 历史方法&#xff1a;PIC&#xff08;Programmable Interrupt Controller&#xff09; Intel 8259&#xff1a; APIC&#…

小侃设计模式(十三)-策略模式

1.概述 策略模式&#xff08;Strategy Pattern&#xff09;是一种比较简单的模式&#xff0c;它定义了算法家族&#xff0c;分别封装起来&#xff0c;让它们之间可以互相替换&#xff0c;此模式让算法的变化&#xff0c;不会影响到使用算法的客户。策略模式具有较强的实用性&a…

ARM学习扫盲篇(一):CPSRSPSR、LcacheDcache、w/parityw/ECC

1、CPSR&SPSR CPSR—程序状态寄存器(current program status register) SPSR—程序状态保存寄存器&#xff08;saved program status register&#xff09; Icache&Dcache icache用来缓存指令&#xff1b; dcache用来缓存数据&#xff0c;dcache用的前提是mmu要启动…

(续)SSM整合之SSM整合笔记(ContextLoaderListener)(P177-178)

目录 ContextLoaderListener 一 ContextLoaderListener 二 测试ContextLoaderListener 1 新建模块spring_listener com.atguigu 2. 导入依赖 3 .转web 4 .web.xml 5 springmvc.xml 6 .spring.xml 7 首页index.html 8 控制层 HelloController 9 service接口…

【24计算机考研】备考前必须了解的避坑小知识,建议收藏

前言 我们可能已经了解到最近两三年的考研趋势&#xff0c;疫情的原因&#xff0c;不断增加的二战三战考生&#xff0c;导致每年考研人数持续增长&#xff0c;那么&#xff0c;如何在相同的时间里&#xff0c;赶超你的竞争对手&#xff0c;避坑 绝对是很重要的。 考研将是一场…

【Spring】——9、如何指定初始化和销毁的方法?

&#x1f4eb;作者简介&#xff1a;zhz小白 公众号&#xff1a;小白的Java进阶之路 专业技能&#xff1a; 1、Java基础&#xff0c;并精通多线程的开发&#xff0c;熟悉JVM原理 2、熟悉Java基础&#xff0c;并精通多线程的开发&#xff0c;熟悉JVM原理&#xff0c;具备⼀定的线…

(STM32)从零开始的RT-Thread之旅--SPI驱动ST7735(3)使用DMA

上一篇&#xff1a; (STM32)从零开始的RT-Thread之旅--SPI驱动ST7735(2) 上一篇完成了ST7735驱动的移植&#xff0c;并已经可以通过SPI在屏幕上显示字符了&#xff0c;这一章会把SPI修改为DMA的传输方式。由于RTT对于STM32H7的SPI的DMA传输方式目前支持的并不好&#xff0c;这…

Vuex3使用教程(待续)

Vuex定义 以下是Vue官网对于Vuex的定义&#xff1a; Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式 库。它采用集中式存储管理应用的所有组件的状态&#xff0c;并以相应的规则保证状态以一种可预测的方式发生变化。 从官方定义上看&#xff1a; Vuex提供了一个全部组件…

Java注释:单行、多行和文档注释

注释是对程序语言的说明&#xff0c;有助于开发者和用户之间的交流&#xff0c;方便理解程序。注释不是编程语句&#xff0c;因此被编译器忽略。 Java入门基础视频教程&#xff0c;java零基础自学就选黑马程序员Java入门教程&#xff08;含Java项目和Java真题&#xff09; Ja…

【Django】Django4.1.2使用xadmin避坑指南(二)

上一篇【Django】Django4.1.2使用xadmin避坑指南调完后&#xff0c;还是继续有问题&#xff0c;没事&#xff0c;咱们继续&#xff0c;必须啃下硬骨头~ 文章目录环境问题一&#xff1a;if not ContentType._meta.installed:这一句报错&#xff1a;AttributeError: Options obje…

《深度学习进阶 自然语言处理》第八章:Attention介绍

文章目录8.1 Attention结构8.1.1 seq2seq存在的问题8.1.2 编码器的改进8.1.3 解码器的改进8.2 Attention的应用8.3 总结之前文章链接&#xff1a; 开篇介绍&#xff1a;《深度学习进阶 自然语言处理》书籍介绍 第一章&#xff1a;《深度学习进阶 自然语言处理》第一章&#xf…

SSH连接WSL2踩坑记录与增加端口转换规则,实现外网与WSL2的连接

SSH连接WSL2踩坑记录 文章目录SSH连接WSL2踩坑记录1. 在WSL里的操作2. ssh连接3. 可能出现的错误4. 再配置端口转发到WSL1. 在WSL里的操作 1.1 重装openssh-server sudo remove openssh-server # 如果已经安装了&#xff0c;建设先卸载 sudo apt install openssh-server…

Ansys Lumerical | 行波 Mach-Zehnder 调制器仿真分析

前言 本示例描述了行波 Mach-Zehnder 调制器的完整多物理场&#xff08;电气、光学、射频&#xff09;仿真&#xff0c;最后在INTERCONNECT中进行了紧凑模型电路仿真。计算了相对相移、光学传输、传输线带宽和眼图等关键结果。 综述 此示例中5毫米长的Si波导由5毫米长的Al共面…