H2-FDetector模型解析

news2024/11/16 2:48:55

文章目录

  • 1. H2FDetector_layer 类
  • 2. RelationAware 类
  • 3. MultiRelationH2FDetectorLayer 类
  • 4. H2FDetector 类

这个实现包括三个主要部分:H2FDetector_layer、MultiRelationH2FDetectorLayer 和 H2FDetector。每个部分都有其独特的功能和职责。下面是这些组件的详细实现和解释。

1. H2FDetector_layer 类

这是一个基本的 GNN 层,处理图卷积和注意力机制。

  • 这是基本的图卷积层,包含注意力机制和关系感知的边签名计算。
class H2FDetector_layer(nn.Module):
    def __init__(self, input_dim, output_dim, head, relation_aware, etype, dropout, if_sum=False):
        super().__init__()
        self.etype = etype
        self.head = head
        self.hd = output_dim
        self.if_sum = if_sum
        self.relation_aware = relation_aware
        self.w_liner = nn.Linear(input_dim, output_dim * head)
        self.atten = nn.Linear(2 * self.hd, 1)
        self.relu = nn.ReLU()
        self.leakyrelu = nn.LeakyReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['feat'] = h
            g.apply_edges(self.sign_edges, etype=self.etype)
            h = self.w_liner(h)
            g.ndata['h'] = h
            g.update_all(message_func=self.message, reduce_func=self.reduce, etype=self.etype)
            out = g.ndata['out']
            return out

    def message(self, edges):
        src = edges.src
        src_features = edges.data['sign'].view(-1, 1) * src['h']
        src_features = src_features.view(-1, self.head, self.hd)
        z = torch.cat([src_features, edges.dst['h'].view(-1, self.head, self.hd)], dim=-1)
        alpha = self.atten(z)
        alpha = self.leakyrelu(alpha)
        return {'atten': alpha, 'sf': src_features}

    def reduce(self, nodes):
        alpha = nodes.mailbox['atten']
        sf = nodes.mailbox['sf']
        alpha = self.softmax(alpha)
        out = torch.sum(alpha * sf, dim=1)
        if not self.if_sum:
            out = out.view(-1, self.head * self.hd)
        else:
            out = out.sum(dim=-2)
        return {'out': out}

    def sign_edges(self, edges):
        src = edges.src['feat']
        dst = edges.dst['feat']
        score = self.relation_aware(src, dst)
        return {'sign': torch.sign(score)}

这里是对 H2FDetector_layer 类的详细解释。这个类定义了一个图神经网络(GNN)层,它使用注意力机制来对图中的节点进行特征提取和更新。下面是对每一部分代码的详细解释。

class H2FDetector_layer(nn.Module):
    def __init__(self, input_dim, output_dim, head, relation_aware, etype, dropout, if_sum=False):
        super().__init__()
        self.etype = etype
        self.head = head
        self.hd = output_dim
        self.if_sum = if_sum
        self.relation_aware = relation_aware
        self.w_liner = nn.Linear(input_dim, output_dim * head)
        self.atten = nn.Linear(2 * self.hd, 1)
        self.relu = nn.ReLU()
        self.leakyrelu = nn.LeakyReLU()
        self.softmax = nn.Softmax(dim=1)

在这里插入图片描述
2.

def forward(self, g, h):
    with g.local_scope():
        g.ndata['feat'] = h
        g.apply_edges(self.sign_edges, etype=self.etype)
        h = self.w_liner(h)
        g.ndata['h'] = h
        g.update_all(message_func=self.message, reduce_func=self.reduce, etype=self.etype)
        out = g.ndata['out']
        return out

在这里插入图片描述
3.

def message(self, edges):
    src = edges.src
    src_features = edges.data['sign'].view(-1, 1) * src['h']
    src_features = src_features.view(-1, self.head, self.hd)
    z = torch.cat([src_features, edges.dst['h'].view(-1, self.head, self.hd)], dim=-1)
    alpha = self.atten(z)
    alpha = self.leakyrelu(alpha)
    return {'atten': alpha, 'sf': src_features}

在这里插入图片描述
4.

def reduce(self, nodes):
    alpha = nodes.mailbox['atten']
    sf = nodes.mailbox['sf']
    alpha = self.softmax(alpha)
    out = torch.sum(alpha * sf, dim=1)
    if not self.if_sum:
        out = out.view(-1, self.head * self.hd)
    else:
        out = out.sum(dim=-2)
    return {'out': out}

在这里插入图片描述
5.

def sign_edges(self, edges):
    src = edges.src['feat']
    dst = edges.dst['feat']
    score = self.relation_aware(src, dst)
    return {'sign': torch.sign(score)}

在这里插入图片描述
6.
在这里插入图片描述

2. RelationAware 类

这是一个关系感知的模块,用于计算边的关系权重。

  • 关系感知模块,用于计算边的关系权重。
class RelationAware(nn.Module):
    def __init__(self, input_dim, output_dim, dropout):
        super().__init__()
        self.d_liner = nn.Linear(input_dim, output_dim)
        self.f_liner = nn.Linear(3 * output_dim, 1)
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, dst):
        src = self.d_liner(src)
        dst = self.d_liner(dst)
        diff = src - dst
        e_feats = torch.cat([src, dst, diff], dim=1)
        e_feats = self.dropout(e_feats)
        score = self.f_liner(e_feats).squeeze()
        score = self.tanh(score)
        return score

RelationAware 类是一个关系感知模块,用于计算图中边的关系权重。它通过处理源节点和目标节点的特征,生成一个关系得分。这个模块在图神经网络(GNN)中常用于捕捉节点之间的关系,从而增强模型的表达能力。
1.

class RelationAware(nn.Module):
    def __init__(self, input_dim, output_dim, dropout):
        super().__init__()
        self.d_liner = nn.Linear(input_dim, output_dim)
        self.f_liner = nn.Linear(3 * output_dim, 1)
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(dropout)

在这里插入图片描述
2.
在这里插入图片描述

3. MultiRelationH2FDetectorLayer 类

这是一个处理多种关系的 GNN 层。

  • 处理多种关系的图卷积层,包含对不同关系类型的处理逻辑。
class MultiRelationH2FDetectorLayer(nn.Module):
    def __init__(self, input_dim, output_dim, head, dataset, dropout, if_sum=False):
        super().__init__()
        self.relation = copy.deepcopy(dataset.etypes)
        self.relation.remove('homo')
        self.n_relation = len(self.relation)
        if not if_sum:
            self.liner = nn.Linear(self.n_relation * output_dim * head, output_dim * head)
        else:
            self.liner = nn.Linear(self.n_relation * output_dim, output_dim)
        self.relation_aware = RelationAware(input_dim, output_dim * head, dropout)
        self.minelayers = nn.ModuleDict()
        self.dropout = nn.Dropout(dropout)
        for e in self.relation:
            self.minelayers[e] = H2FDetector_layer(input_dim, output_dim, head, self.relation_aware, e, dropout, if_sum)

    def forward(self, g, h):
        hs = []
        for e in self.relation:
            he = self.minelayers[e](g, h)
            hs.append(he)
        h = torch.cat(hs, dim=1)
        h = self.dropout(h)
        h = self.liner(h)
        return h

    def loss(self, g, h):
        with g.local_scope():
            g.ndata['feat'] = h
            agg_h = self.forward(g, h)

            g.apply_edges(self.score_edges, etype='homo')
            edges_score = g.edges['homo'].data['score']
            edge_train_mask = g.edges['homo'].data['train_mask'].bool()
            edge_train_label = g.edges['homo'].data['label'][edge_train_mask]
            edge_train_pos = edge_train_label == 1
            edge_train_neg = edge_train_label == -1
            edge_train_pos_index = edge_train_pos.nonzero().flatten().detach().cpu().numpy()
            edge_train_neg_index = edge_train_neg.nonzero().flatten().detach().cpu().numpy()
            edge_train_pos_index = np.random.choice(edge_train_pos_index, size=len(edge_train_neg_index))
            index = np.concatenate([edge_train_pos_index, edge_train_neg_index])
            index.sort()
            edge_train_score = edges_score[edge_train_mask]
            # hinge loss
            edge_diff_loss = hinge_loss(edge_train_label[index], edge_train_score[index])

            train_mask = g.ndata['train_mask'].bool()
            train_h = agg_h[train_mask]
            train_label = g.ndata['label'][train_mask]
            train_pos = train_label == 1
            train_neg = train_label == 0
            train_pos_index = train_pos.nonzero().flatten().detach().cpu().numpy()
            train_neg_index = train_neg.nonzero().flatten().detach().cpu().numpy()
            train_neg_index = np.random.choice(train_neg_index, size=len(train_pos_index))
            node_index = np.concatenate([train_neg_index, train_pos_index])
            node_index.sort()
            pos_prototype = torch.mean(train_h[train_pos], dim=0).view(1, -1)
            neg_prototype = torch.mean(train_h[train_neg], dim=0).view(1, -1)
            train_h_loss = train_h[node_index]
            pos_prototypes = pos_prototype.expand(train_h_loss.shape)
            neg_prototypes = neg_prototype.expand(train_h_loss.shape)
            diff_pos = -F.pairwise_distance(train_h_loss, pos_prototypes)
            diff_neg = -F.pairwise_distance(train_h_loss, neg_prototypes)
            diff_pos = diff_pos.view(-1, 1)
            diff_neg = diff_neg.view(-1, 1)
            diff = torch.cat([diff_neg, diff_pos], dim=1)
            diff_loss = F.cross_entropy(diff, train_label[node_index])

            return agg_h, edge_diff_loss, diff_loss

    def score_edges(self, edges):
        src = edges.src['feat']
        dst = edges.dst['feat']
        score = self.relation_aware(src, dst)
        return {'score': score}

4. H2FDetector 类

这是一个多层的 GNN 模型,用于构建一个关系感知的图神经网络模型。

  • 多层的关系感知图神经网络模型,包含前向传播和损失计算方法。
class H2FDetector(nn.Module):
    def __init__(self, args, g):
        super().__init__()
        self.n_layer = args.n_layer
        self.input_dim = g.nodes['r'].data['feature'].shape[1]
        self.intra_dim = args.intra_dim
        self.n_class = args.n_class
        self.gamma1 = args.gamma1
        self.gamma2 = args.gamma2
        self.n_layer = args.n_layer
        self.mine_layers = nn.ModuleList()
        if args.n_layer == 1:
            self.mine_layers.append(MultiRelationH2FDetectorLayer(self.input_dim, self.n_class, args.head, g, args.dropout, if_sum=True))
        else:
            self.mine_layers.append(MultiRelationH2FDetectorLayer(self.input_dim, self.intra_dim, args.head, g, args.dropout))
            for _ in range(1, self.n_layer - 1):
                self.mine_layers.append(MultiRelationH2FDetectorLayer(self.intra_dim * args.head, self.intra_dim, args.head, g, args.dropout))
            self.mine_layers.append(MultiRelationH2FDetectorLayer(self.intra_dim * args.head, self.n_class, args.head, g, args.dropout, if_sum=True))
        self.dropout = nn.Dropout(args.dropout)
        self.relu = nn.ReLU()

    def forward(self, g):
        feats = g.ndata['feature'].float()
        h = self.mine_layers[0](g, feats)
        if self.n_layer > 1:
            h = self.relu(h)
            h = self.dropout(h)
            for i in range(1, len(self.mine_layers) - 1):
                h = self.mine_layers[i](g, h)
                h = self.relu(h)
                h = self.dropout(h)
            h = self.mine_layers[-1](g, h)
        return h

    def loss(self, g):
        feats = g.ndata['feature'].float()
        train_mask = g.ndata['train_mask'].bool()
        train_label = g.ndata['label'][train_mask]
        train_pos = train_label == 1
        train_neg = train_label == 0

        pos_index = train_pos.nonzero().flatten().detach().cpu().numpy()
        neg_index = train_neg.nonzero().flatten().detach().cpu().numpy()
        neg_index = np.random.choice(neg_index, size=len(pos_index), replace=False)
        index = np.concatenate([pos_index, neg_index])
        index.sort()
        h, edge_loss, prototype_loss = self.mine_layers[0].loss(g, feats)
        if self.n_layer > 1:
            h = self.relu(h)
            h = self.dropout(h)
            for i in range(1, len(self.mine_layers) - 1):
                h, e_loss, p_loss = self.mine_layers[i].loss(g, h)
                h = self.relu(h)
                h = self.dropout(h)
                edge_loss += e_loss
                prototype_loss += p_loss
            h, e_loss, p_loss = self.mine_layers[-1].loss(g, h)
            edge_loss += e_loss
            prototype_loss += p_loss
        model_loss = F.cross_entropy(h[train_mask][index], train_label[index])
        loss = model_loss + self.gamma1 * edge_loss + self.gamma2 * prototype_loss
        return loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1678310.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

thinkphp8扩展think-swoole4.0-事件监听代码

首先服务端配置监听 swoole.php <?phpreturn [http > [enable > true,host > 0.0.0.0,port > 8000,worker_num > swoole_cpu_num(),options > [],],websocket > [enable > true,handler > \think\swo…

【达梦数据库】搭建 DM->mysql dblink

DM->mysql dblink 1安装mysql odbc rpm -ivh mysql-connector-odbc-5.3.14-1.el7.x86_64.rpm2mysql创建远程用户与远程数据库 mysql> show databases; ------------------------- | Database | ------------------------- | information_schema | …

行测练习题

、、 【任意直角三角形&#xff0c;斜边的中点到三个顶点的距离相等。】 因此无人机的投影点一定为直角三角形斜边中点&#xff0c;之后根据勾股定理可以求得高度为500. 、、

桌椅3D模型素材从哪下载比较好?

对于室内设计师而言&#xff0c;经常需要用到桌椅3D模型来完成自己的设计方案&#xff0c;那么从哪里能下载高质量的桌椅3D模型素材呢? 1、建e网&#xff1a;建e网的3D模型库不仅数量庞大&#xff0c;而且质量上乘。模型制作精细&#xff0c;纹理清晰&#xff0c;可以直接用于…

【Open AI】GPT-4o深夜发布:视觉、听觉跨越式升级

北京时间5月14日1点整&#xff0c;OpenAI 召开了首场春季发布会&#xff0c;CTO Mira Murati 在台上和团队用短短不到30分钟的时间&#xff0c;揭开了最新旗舰模型 GPT-4o 的神秘面纱&#xff0c;以及基于 GPT-4o 的 ChatGPT&#xff0c;均为免费使用。 本文内容来自OpenAI网站…

vue-cropper裁剪图片 vue

效果图 1.配置环境 npm install vue-cropper 2.代码 <template><div class"cropper-content"><div class"cropper-box"><div class"cropper"><vue-cropper ref"cropper" :img"option.img" :…

5 个免费使用 GPT-4o 的方法

5 个免费使用 GPT-4o 的方法 虽然距离 OpenAI 发布 GPT-4o 已过去一天&#xff0c;我仍然对 GPT-4o 感到震撼。Demo 中语音助手功能实在是太令人惊叹了——它咯咯的笑声、准确的语气感叹和歌唱方式让 Siri 和 Google Assistant 显得相形见绌。 虽然备受期待的语音助手功能还要…

Elasticsearch:向量相似度技术和评分

作者&#xff1a;来自 Elastic Valentin Crettaz 当需要搜索自由文本并且 CtrlF / CmdF 不再有效时&#xff0c;使用词法搜索引擎通常是你想到的下一个合理选择。 词汇搜索引擎擅长分析要搜索的文本并将其标记为可在搜索时匹配的术语&#xff0c;但在理解和理解被索引和搜索的…

Acrel-2000L/A 绝缘监测系统设备 对多个绝缘检测仪进行统一数据管理

一、产品简介 Acrel-2000L/A 绝缘监测系统设备适用于 1kV 及以下低压配电系统。该设备可以集中采集监测显示绝缘监测仪的数据&#xff0c;实现最多 8 个绝缘监测仪的数据&#xff0c;并且实时记录告警信息和曲线查询。匹配的绝缘监测仪可以是 AIM-T300、AIM-T500 和 AIM-T500L。…

python 两种colorbar 最大最小和分类的绘制

1 colorbar 按照自定义的最值绘制 归一化方法使用Normalize(vmin0, vmax40.0) import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt import matplotlib.cm as cm import matplotlib.colors as mcolors from matplotlib import rcParams from matplot…

【全开源】国际版JAVA同城服务美容美发到店服务上门服务系统源码支持Android+IOS+H5

国际版同城服务美容美发到店与上门服务系统&#xff1a;一站式打造美丽新体验 随着人们生活水平的提高和审美观念的升级&#xff0c;美容美发服务已成为人们日常生活中不可或缺的一部分。为了满足全球消费者的多样化需求&#xff0c;我们推出了“国际版同城服务美容美发到店与…

基于单片机的光照检测系统—光敏电阻

基于单片机的光照检测系统 &#xff08;仿真&#xff0b;程序&#xff0b;原理图&#xff0b;设计报告&#xff09; 功能介绍 具体功能&#xff1a; 1.光敏电阻实时采集环境光照值&#xff1b; 2.采用ADC0804将模拟值转换为数字量&#xff1b; 3.四位数码管显示当前的光照…

(gpt4o教程)gpt-4o如何开启和使用呢?

我发现&#xff0c;很多人反馈他的官网里没有gpt-4o的选项&#xff0c;下面介绍一下怎么查看是否使用了gpt-4o模型。 一、使用方法 1. 官网网站直接使用 2. 通过Open API申请接口使用 3. 通过LLM基准测试竞技场体验 还有其他方法&#xff0c;就不一一举例了。可以先看看上…

Linux之内存管理-malloc \kmalloc\vmalloc

1、malloc 函数 1.1分配内存小于128k,调用brk malloc是C库实现的函数&#xff0c;C库维护了一个缓存&#xff0c;当内存够用时&#xff0c;malloc直接从C库缓存分配&#xff0c;只有当C库缓存不够用&#xff1b; 当申请的内存小于128K时&#xff0c;通过系统调用brk&#xff…

提升写作效率的秘密武器:一个资深编辑的AI写作体验

有句话说:“写作是一项你坐在打字机前流血的工作。”而如今,各类生成式软件的涌现似乎打破了写作这一古老的艺术形式壁垒。过去,作家们独自在书桌前冥思苦想,如今,一款名为“玲珑AI工具”的ai写作助手正悄然改变着文案写作行业的创作生态,成为提升写作效率的秘密武器。 在传统…

STL—string类(1)

一、string类 1、为什么要学习string&#xff1f; C语言中&#xff0c;字符串是以\0结尾的一些字符的集合&#xff0c;为了操作方便&#xff0c;C标准库中提供了一些str系列的库函数&#xff0c;但是这些库函数与字符串是分离开的&#xff0c;不太符合OOP&#xff08;面向对象…

JVS物联网、无忧企业文档、规则引擎5.14功能新增说明

项目介绍 JVS是企业级数字化服务构建的基础脚手架&#xff0c;主要解决企业信息化项目交付难、实施效率低、开发成本高的问题&#xff0c;采用微服务配置化的方式&#xff0c;提供了 低代码数据分析物联网的核心能力产品&#xff0c;并构建了协同办公、企业常用的管理工具等&am…

uniapp 配置请求代理+请求封装

uniapp官网提供了三种方式&#xff1a;什么是跨域 | uni-app官网 1. 通过uniapp自带浏览器 打开项目是不存在跨域的 第二种方式&#xff1a; "h5" : {"template" : "static/index.html","devServer": {"proxy": {&quo…

汇凯金业:3个高效的黄金投资技巧

黄金投资中的高效技巧往往承载了许多投资前辈的智慧与经验教训&#xff0c;成为新手投资者宝贵的学习资料。历史上积累的黄金投资经验可以作为新投资者的学习榜样。 3个高效的黄金投资技巧 一、稳健的中长期投资策略 在金属投资领域虽然不乏短线交易高手&#xff0c;但新手投资…

BFS和DFS优先搜索算法

1. BFS与DFS 1.1 BFS DFS即Depth First Search&#xff0c;深度优先搜索。它是一种图遍历算法&#xff0c;它从一个起始点开始&#xff0c;逐层扩展搜索范围&#xff0c;直到找到目标节点为止。 这种算法通常用于解决“最短路径”问题&#xff0c;比如在迷宫中找到从起点到终…