【图神经网络】GNNExplainer代码解读及其PyG实现

news2025/1/11 20:41:14

GNNExplainer代码解读及其PyG实现

  • 使用GNNExplainer
  • GNNExplainer源码速读
    • 前向传播
    • 损失函数
  • 基于GNNExplainer图分类解释的PyG代码示例
  • 参考资料

接上一篇博客图神经网络的可解释性方法及GNNexplainer代码示例,我们这里简单分析GNNExplainer源码,并用PyTorch Geometric手动实现。
GNNExplainer的源码地址:https://github.com/RexYing/gnn-model-explainer

使用GNNExplainer

(1)安装:

git clone https://github.com/RexYing/gnn-model-explainer

推荐使用python3.7以及创建虚拟环境:

virtualenv venv -p /usr/local/bin/python3
source venv/bin/activate

(2)训练一个GCN模型

python train.py --dataset=EXPERIMENT_NAME

其中EXPERIMENT_NAME表示想要复现的实验名称。

训练GCN模型的完整选项列表:

python train.py --help

(3)解释一个GCN模型
要运行解释器,请运行以下内容:

python explainer_main.py --dataset=EXPERIMENT_NAME

(4)可视化解释
使用Tensorboard:优化的结果可以通过Tensorboard可视化。

tensorboard --logdir log

GNNExplainer源码速读

GNNExplainer会从2个角度解释图:

  • 边(edge):会生成一个edge mask,表示每条边在图中出现的概率,值为0-1之间的浮点数。edge mask也可以当作一个权重,可以取topk的edge连成的子图来解释。
  • 结点特征(node feature):node feature(NF)即结点向量,比如一个结点128维表示128个特征,那么它同时会生成一个NF mask来表示每个特征的权重,这个可以不要。

代码目录

  • explainer目录下的ExplainModel类定义了GNNExplainer网络的模块结构,继承torch.nn.Module:

    • 在初始化init的时候,用construct_edge_maskconstruct_feat_mask函数初始化要学习的两个mask(分别对应于两个nn.Parameter类型的变量: n × n n×n n×n维的maskd维全0的feat_mask);diag_mask即主对角线上是0,其余元素均为1的矩阵,用于_masked_adj函数。
    • _masked_adj函数将mask用sigmod或ReLU激活后,加上自身转置再除以2,以转为对称矩阵,然后乘上diag_mask,最终将原邻接矩阵adj变换为masked_adj
  • Explainer类实现了解释的逻辑,主函数是其中的explain,用于解释原模型在单节点的预测结果,主要步骤:

    1. 取子图的adj, x, label图解释:取graph_idx对应的整个计算图;节点解释:调用extract_neighborhood函数取该节点num_gc_layers阶数的邻居。
    2. 将传入的模型预测输出pred转为pred_label
    3. 构建ExplainModule,进行num_epochs轮训练(前向+反向传播)
adj   = torch.tensor(sub_adj, dtype=torch.float)
x     = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
label = torch.tensor(sub_label, dtype=torch.long)

if self.graph_mode:
	pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
	print("Graph predicted label: ", pred_label)
else:
	pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1)
	print("Node predicted label: ", pred_label[node_idx_new])

explainer = ExplainModule(
	adj=adj,
	x=x,
	model=self.model,
	label=label,
	args=self.args,
	writer=self.writer,
	graph_idx=self.graph_idx,
	graph_mode=self.graph_mode,
)
if self.args.gpu:
	explainer = explainer.cuda()

...

# NODE EXPLAINER
def explain_nodes(self, node_indices, args, graph_idx=0):
...

def explain_nodes_gnn_stats(self, node_indices, args, graph_idx=0, model="exp"):
...

# GRAPH EXPLAINER
def explain_graphs(self, graph_indices):
...

explain_nodesexplain_nodes_gnn_statsexplain_graphs这三个函数都是在它的基础上实现的。

下面分析其中的forwardloss函数。

前向传播

首先把待学习的参数mask和feat_mask分别乘上原邻接矩阵和特征向量,得到变换后的masked_adjx。前者通过调用_masked_adj函数完成,后者的实现如下:

feat_mask = (
	torch.sigmoid(self.feat_mask)
	if self.use_sigmoid
	else self.feat_mask
)
if marginalize:
	std_tensor = torch.ones_like(x, dtype=torch.float) / 2
	mean_tensor = torch.zeros_like(x, dtype=torch.float) - x
	z = torch.normal(mean=mean_tensor, std=std_tensor)
	x = x + z * (1 - feat_mask)
else:
	x = x * feat_mask

完整代码如下:
forward
这里需要说明的是marginalize为True的情况,参考论文中的Learning binary feature selector F:
Learning binary feature selector F

  • 如果同mask一样学习feature_mask,在某些情况下回导致重要特征也被忽略(学到的特征遮罩也是接近于0的值),因此,依据 X S X_S XS的经验边缘分布使用Monte Carlo方法来抽样得到 X = X S F X=X_S^F X=XSF.
  • 为了解决随机变量 X X X的反向传播的问题,引入了"重参数化"的技巧,即将其表示为一个无参的随机变量 Z Z Z的确定性变换: X = Z + ( X S − Z ) ⊙ F X=Z+(X_S-Z)\odot F X=Z+(XSZ)F s . t . ∑ j F j ≤ K F s.t. \sum_{j}F_j\le K_F s.t.jFjKF
    其中, Z Z Z是依据经验分布采样得到的 d d d维随机变量, K F K_F KF是表示保留的最大特征数的参数(utils/io_utils.py中的denoise_graph函数)。

接着将masked_adjx输入原始模型得到ExplainModule结果pred

损失函数

loss = pred_loss + size_loss + lap_loss + mask_ent_loss + feat_size_loss

可知,总的loss包含五项,除了对应于论文中损失函数公式的pred_loss,其余各项损失的作用参考论文Integrating additional constraints into explanations,它们的权重定义在coeffs中:

self.coeffs = {
	"size": 0.005,
	"feat_size": 1.0,
	"ent": 1.0,
	"feat_ent": 0.1,
	"grad": 0,
	"lap": 1.0,
}

Integrating additional constraints into explanations

  1. pred_loss
mi_obj = False
if mi_obj:
	pred_loss = -torch.sum(pred * torch.log(pred))
else:
	pred_label_node = pred_label if self.graph_mode else pred_label[node_idx]
	gt_label_node = self.label if self.graph_mode else self.label[0][node_idx]
	logit = pred[gt_label_node]
	pred_loss = -torch.log(logit)

其中pred是当前的预测结果,pred_label是原始特征上的预测结果。

  1. mask_ent_loss
# entropy
mask_ent = -mask * torch.log(mask) - (1 - mask) * torch.log(1 - mask)
mask_ent_loss = self.coeffs["ent"] * torch.mean(mask_ent)
  1. size_loss
# size
mask = self.mask
if self.mask_act == "sigmoid":
	mask = torch.sigmoid(self.mask)
elif self.mask_act == "ReLU":
	mask = nn.ReLU()(self.mask)
size_loss = self.coeffs["size"] * torch.sum(mask)
  1. feat_size_loss
# pre_mask_sum = torch.sum(self.feat_mask)
feat_mask = (
	torch.sigmoid(self.feat_mask) if self.use_sigmoid else self.feat_mask
)
feat_size_loss = self.coeffs["feat_size"] * torch.mean(feat_mask)
  1. lap_loss
# laplacian
D = torch.diag(torch.sum(self.masked_adj[0], 0))
m_adj = self.masked_adj if self.graph_mode else self.masked_adj[self.graph_idx]
L = D - m_adj
pred_label_t = torch.tensor(pred_label, dtype=torch.float)
if self.args.gpu:
	pred_label_t = pred_label_t.cuda()
	L = L.cuda()
if self.graph_mode:
	lap_loss = 0
else:
	lap_loss = (self.coeffs["lap"] * (pred_label_t @ L @ pred_label_t) / self.adj.numel())

补充

基于GNNExplainer图分类解释的PyG代码示例

对于图分类问题的解释,关键点有两个:

  • 要学习的Mask作用在整个图上,不用取子图
  • 标签预测和损失函数的对象是单个graph

实现代码如下:

#!/usr/bin/env python
# encoding: utf-8
# Created by BIT09 at 2023/4/28
import torch
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from math import sqrt
from tqdm import tqdm
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph, to_networkx

EPS = 1e-15


class GNNExplainer(torch.nn.Module):
    r"""
    Args:
        model (torch.nn.Module): The GNN module to explain.
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        log (bool, optional): If set to :obj:`False`, will not log any learning
            progress. (default: :obj:`True`)
    """

    coeffs = {
        'edge_size': 0.001,
        'node_feat_size': 1.0,
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
    }

    def __init__(self, model, epochs=100, lr=0.01, log=True, node=False):  # disable node_feat_mask by default
        super(GNNExplainer, self).__init__()
        self.model = model
        self.epochs = epochs
        self.lr = lr
        self.log = log
        self.node = node

    def __set_masks__(self, x, edge_index, init="normal"):
        (N, F), E = x.size(), edge_index.size(1)

        std = 0.1
        if self.node:
            self.node_feat_mask = torch.nn.Parameter(torch.randn(F) * 0.1)

        std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
        self.edge_mask = torch.nn.Parameter(torch.randn(E) * std)
        self.edge_mask = torch.nn.Parameter(torch.zeros(E) * 50)

        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = True
                module.__edge_mask__ = self.edge_mask

    def __clear_masks__(self):
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = False
                module.__edge_mask__ = None
        if self.node:
            self.node_feat_masks = None
        self.edge_mask = None

    def __num_hops__(self):
        num_hops = 0
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                num_hops += 1
        return num_hops

    def __flow__(self):
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                return module.flow
        return 'source_to_target'

    def __subgraph__(self, node_idx, x, edge_index, **kwargs):
        num_nodes, num_edges = x.size(0), edge_index.size(1)

        if node_idx is not None:
            subset, edge_index, mapping, edge_mask = k_hop_subgraph(
                node_idx, self.__num_hops__(), edge_index, relabel_nodes=True,
                num_nodes=num_nodes, flow=self.__flow__())
            x = x[subset]
        else:
            x = x
            edge_index = edge_index
            row, col = edge_index
            edge_mask = row.new_empty(row.size(0), dtype=torch.bool)
            edge_mask[:] = True
            mapping = None

        for key, item in kwargs:
            if torch.is_tensor(item) and item.size(0) == num_nodes:
                item = item[subset]
            elif torch.is_tensor(item) and item.size(0) == num_edges:
                item = item[edge_mask]
            kwargs[key] = item

        return x, edge_index, mapping, edge_mask, kwargs

    def __graph_loss__(self, log_logits, pred_label):
        loss = -torch.log(log_logits[0, pred_label])
        m = self.edge_mask.sigmoid()
        loss = loss + self.coeffs['edge_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['edge_ent'] * ent.mean()

        return loss

    def visualize_subgraph(self, node_idx, edge_index, edge_mask, y=None,
                           threshold=None, **kwargs):
        r"""Visualizes the subgraph around :attr:`node_idx` given an edge mask
        :attr:`edge_mask`.

        Args:
            node_idx (int): The node id to explain.
            edge_index (LongTensor): The edge indices.
            edge_mask (Tensor): The edge mask.
            y (Tensor, optional): The ground-truth node-prediction labels used
                as node colorings. (default: :obj:`None`)
            threshold (float, optional): Sets a threshold for visualizing
                important edges. If set to :obj:`None`, will visualize all
                edges with transparancy indicating the importance of edges.
                (default: :obj:`None`)
            **kwargs (optional): Additional arguments passed to
                :func:`nx.draw`.

        :rtype: :class:`matplotlib.axes.Axes`, :class:`networkx.DiGraph`
        """

        assert edge_mask.size(0) == edge_index.size(1)

        if node_idx is not None:
            # Only operate on a k-hop subgraph around `node_idx`.
            subset, edge_index, _, hard_edge_mask = k_hop_subgraph(
                node_idx, self.__num_hops__(), edge_index, relabel_nodes=True,
                num_nodes=None, flow=self.__flow__())

            edge_mask = edge_mask[hard_edge_mask]
            subset = subset.tolist()
            if y is None:
                y = torch.zeros(edge_index.max().item() + 1,
                                device=edge_index.device)
            else:
                y = y[subset].to(torch.float) / y.max().item()
                y = y.tolist()
        else:
            subset = []
            for index, mask in enumerate(edge_mask):
                node_a = edge_index[0, index]
                node_b = edge_index[1, index]
                if node_a not in subset:
                    subset.append(node_a.item())
                if node_b not in subset:
                    subset.append(node_b.item())
            y = [y for i in range(len(subset))]

        if threshold is not None:
            edge_mask = (edge_mask >= threshold).to(torch.float)

        data = Data(edge_index=edge_index, att=edge_mask, y=y,
                    num_nodes=len(y)).to('cpu')
        G = to_networkx(data, edge_attrs=['att'])  # , node_attrs=['y']
        mapping = {k: i for k, i in enumerate(subset)}
        G = nx.relabel_nodes(G, mapping)

        kwargs['with_labels'] = kwargs.get('with_labels') or True
        kwargs['font_size'] = kwargs.get('font_size') or 10
        kwargs['node_size'] = kwargs.get('node_size') or 800
        kwargs['cmap'] = kwargs.get('cmap') or 'cool'

        pos = nx.spring_layout(G)
        ax = plt.gca()
        for source, target, data in G.edges(data=True):
            ax.annotate(
                '', xy=pos[target], xycoords='data', xytext=pos[source],
                textcoords='data', arrowprops=dict(
                    arrowstyle="->",
                    alpha=max(data['att'], 0.1),
                    shrinkA=sqrt(kwargs['node_size']) / 2.0,
                    shrinkB=sqrt(kwargs['node_size']) / 2.0,
                    connectionstyle="arc3,rad=0.1",
                ))
        nx.draw_networkx_nodes(G, pos, node_color=y, **kwargs)
        nx.draw_networkx_labels(G, pos, **kwargs)

        return ax, G

    def explain_graph(self, data, **kwargs):
        self.model.eval()
        self.__clear_masks__()
        x, edge_index, batch = data.x, data.edge_index, data.batch

        num_edges = edge_index.size(1)

        # Only operate on a k-hop subgraph around `node_idx`.
        x, edge_index, _, hard_edge_mask, kwargs = self.__subgraph__(node_idx=None, x=x, edge_index=edge_index,
                                                                     **kwargs)
        # Get the initial prediction.
        with torch.no_grad():
            log_logits = self.model(data, **kwargs)
            probs_Y = torch.softmax(log_logits, 1)
            pred_label = probs_Y.argmax(dim=-1)

        self.__set_masks__(x, edge_index)
        self.to(x.device)

        if self.node:
            optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
                                         lr=self.lr)
        else:
            optimizer = torch.optim.Adam([self.edge_mask], lr=self.lr)

        epoch_losses = []
        for epoch in range(1, self.epochs + 1):
            epoch_loss = 0
            optimizer.zero_grad()
            if self.node:
                h = x * self.node_feat_mask.view(1, -1).sigmoid()

            log_logits = self.model(data, **kwargs)
            pred = torch.softmax(log_logits, 1)
            loss = self.__graph_loss__(pred, pred_label)
            loss.backward()

            optimizer.step()
            epoch_loss += loss.detach().item()
            epoch_losses.append(epoch_loss)

        edge_mask = self.edge_mask.detach().sigmoid()
        print(edge_mask)

        self.__clear_masks__()

        return edge_mask, epoch_losses

    def __repr__(self):
        return f'{self.__class__.__name__}()'

参考资料

  1. gnn-explainer
  2. 图神经网络的可解释性方法及GNNexplainer代码示例
  3. Pytorch实现GNNExplainer
  4. How to Explain Graph Neural Network — GNNExplainer
  5. https://gist.github.com/hongxuenong/9f7d4ce96352d4313358bc8368801707

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/496600.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023年中职组“网络空间安全”赛项XX市竞赛任务书

2023年中职组“网络空间安全”赛项 XX市竞赛任务书 一、竞赛时间 共计:180分钟 二、竞赛阶段 竞赛阶段 任务阶段 竞赛任务 竞赛时间 分值 第一阶段单兵模式系统渗透测试 任务一 SSH弱口令渗透测试 100分钟 100 任务二 Linux操作系统渗透测试 100 任…

deep learning system 笔记 自动微分 reverse mode AD

计算图 Computational Graph 图上的每个节点代表一个中间值边事输入输出的关系 forward 求导 forward mode AD 上图中从前向后,一步一步计算每个中间值对 x1的偏导,那么计算到 v7,就得到了整个函数对于 x1的偏导。 有limitation 对一个参数…

单机版部署Redis详细教程

概述 大多数企业都是基于Linux服务器来部署项目,而且Redis官方也没有提供Windows版本的安装包。因此课程中我们会基于Linux系统来安装Redis. 此处选择的Linux版本为CentOS 7. Redis的官方网站地址:https://redis.io/ 单机安装Redis 1.1.安装Redis依…

【IP地址与子网掩码】如何计算网络地址、广播地址、地址范围、主机个数、子网数(附详解与习题)

【写在前面】其实很多时候通过IP地址和子网掩码计算其网络地址、广播地址、可用IP,地址范围,主机数啥的,有些人不太清楚规则就只能瞎猜了,但是作为一个网络管理员还是一个基础常识的,这不因为最近备考网络管理员&#…

【数据结构】八大排序(二)

😛作者:日出等日落 📘 专栏:数据结构 在最黑暗的那段人生,是我自己把自己拉出深渊。没有那个人,我就做那个人。 …

API接口的对接流程和注意事项

一、对接API数据接口的步骤通常包括以下几个部分: 了解API:首先需要详细了解API的基本信息、请求格式、返回数据格式、错误码等相关信息。可以查看API的官方文档或者使用API探索工具。同时,还需要明确数据请求的频率和使用权限等限制。 ​​测…

恐怖,又要有多少人下岗!AI零成本设计主图,渗入10万亿电商市场

在电商平台上,主图是吸引消费者点击进入商品详情页的重要因素之一。 一张高点击的电商主图,不仅要能够吸引消费者的眼球,还要能够清晰地展示产品的特点和卖点。下面是一些制作高点击电商主图的建议。 1. 突出产品特点:在制作主图…

【Spring】Spring的事务管理

目录 1.Spring事务管理概述1.1 事务管理的核心接口1. PlatformTransactionManager2. TransactionDefinition3. TransactionStatus 1.2 事务管理的方式 2.声明式事务管理2.1 基于XML方式的声明式事务2.2 基于Annotation方式的声明式事务 1.Spring事务管理概述 Spring的事务…

惠普暗影精灵5 super 873-068rcn如何重装系统

惠普暗影精灵5 super 873-068rcn是一款家用游戏台式电脑,有时候你可能用久会遇到系统出现故障、中毒、卡顿等问题,或者你想要更换一个新的操作系统,这时候你就需要重装系统。重装系统可以让你的电脑恢复到出厂状态,清除所有的个人…

【vite+vue3.2 项目性能优化实战】打包体积分析插件rollup-plugin-visualizer视图分析

rollup-plugin-visualizer是一个用于Rollup构建工具的插件,它可以生成可视化的构建报告,帮助开发者更好地了解构建过程中的文件大小、依赖关系等信息。 使用rollup-plugin-visualizer插件,可以在构建完成后生成一个交互式的HTML报告&#xf…

【提示学习】Label prompt for multi-label text classification

论文信息 名称内容论文标题Label prompt for multi-label text classification论文地址https://link.springer.com/article/10.1007/s10489-022-03896-4研究领域NLP, 文本分类, 提示学习, 多标签提出模型LP-MTC(Label Prompt Multi-label Text Classification model)来源Appli…

Docker跨主机网络通信

常见的跨主机通信方案主要有以下几种: 形式描述Host模式容器直接使用宿主机的网络,这样天生就可以支持跨主机通信。这样方式虽然可以解决跨主机通信的问题,但应用场景很有限,容易出现端口冲突,也无法做到隔离网络环境…

buildroot系统调试苹果手机网络共享功能

苹果手机usb共享网络调试 首先了解usb基础知识,比如usb分为主设备和从设备进行通信, 1.HOST模式下是只能做主设备, 2.OTG模式下是可以即做主又可以做从,主设备即HCD,从设备即UDC(USB_GADGET &#xff09…

年后准备进腾讯的可以看看....

大家好~ 最近内卷严重,各种跳槽裁员,今天特意分享一套学习笔记 / 面试手册,年后跳槽的朋友想去腾讯的可以好好刷一刷,还是挺有必要的,它几乎涵盖了所有的软件测试技术栈,非常珍贵,肝完进大厂&a…

多态的原理

有了虚函数,会在类的对象增加一个指针,该指针就是虚函数表指针_vfptr;虚表本质就是函数指针数组,虚表里面存放着该对象的虚函数的地址; 派生类继承有虚函数基类的对象模型 子类继承父类的虚表指针时,是对父类的虚表指针进行了拷…

密码学:古典密码.

密码学:古典密码. 古典密码是密码学的一个类型,大部分加密方式是利用替换式密码或移项式密码,有时是两者的混合。古典密码在历史上普遍被使用,但到现代已经渐渐不常用了。一般来说,一种古典密码体制包含一个字母表(如…

MATLAB 点云均匀体素下采样(6)

MATLAB 点云均匀体素下采样的不同参数效果测试 (6) 一、实现效果二、算法介绍三、函数说明3.1 函数3.2 参数四、实现代码(详细注释!)一、实现效果 不同参数调整下的均匀体素下采样结果如下图所示,后续代码复制黏贴即可: 分别为0.3m,0.2m,0.1m尺度下的格网下采样结果…

【C++复习2】C++编译器的工作原理

如果你是一名newbird的话,建议观看如下视频加深你的理解,再看如下内容: https://www.bilibili.com/video/BV1N24y1B7nQ?p7 The cherno会额外告诉你如何将目标文件转换成汇编代码,CPU执行指令的过程以及编译器如何通过删除冗余变…

【MySQL】SQL优化

上一篇索引是针对查询语句进行优化,但在MySQL中可不仅有查询语句,针对其他的SQL语句同样也能进行优化 文章目录 1.插入数据2.主键优化3.order by 优化4.group by优化5.limit优化6.update优化 1.插入数据 插入数据所使用的关键字为insert,SQL语句为 insert into 表名(字段1,字…

Huntly: 一款超强大的自托管信息管理工具,支持管理RSS、自动保存网页、稍后阅读

Huntly是一款开源的自托管信息管理工具,旨在帮助用户更好地管理和处理各种信息。Huntly可以通过管理RSS、自动保存网页和稍后阅读等功能来帮助用户更有效地收集、保存和浏览信息。 github 地址:GitHub - lcomplete/huntly: Huntly, information manageme…