【可解释学习】PyG可解释学习模块torch_geometric.explain

news2024/12/23 6:51:40

PyG可解释学习模块torch_geometric.explain

  • Philoshopy
  • Explainer
  • Explanations
  • Explainer Algorithm
  • Explanation Metrics
  • 参考资料

torch_geometric.explain是PyTorch Geometric库中的一个模块,用于解释和可视化图神经网络(GNN)模型的预测结果。它提供了一些方法来解释模型的预测结果、边权重和节点重要性。

主要内容有:Philoshopy(哲学思想)、Explainer(解释器)、Explanations(解释)、Explainer Algorithm(解释器算法)、Explanation Metrics(解释度量)

Philoshopy

该模块提供了一组工具来解释 PyG 模型的预测或解释数据集的基本现象,详细信息可以参考“GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks”

我们使用torch_geometric.explain.Explanation类来表示解释,该类是一个Data对象,包含数据的节点、边、特征和任何属性的掩码。
torch_geometric.expllain.Explainer类设计用于处理所有可解释性参数(有关更多详细信息,请参阅torch_geometric.explainn.config.ExplainerConfig类):

  • 使用的算法来自torch_geometric.expllain.algorithm模块中,例如GNNExplainer
  • 要计算的解释类型(例如explanation_type="phenomenon"或者explanation_type="model"
  • 节点和边的不同类型掩码(例如mask="object"或者mask="attributes"
  • 掩码的任何后处理(例如threshold_type="topk"或者threshold_type="hard"

该类允许用户轻松比较不同的可解释性方法,并在不同类型的掩码之间轻松切换,同时确保高层的代码框架保持不变。

Explainer

Explainer
基础类:object
图神经网络实例级解释的一个解释器类。
参数

  • model(torch.nn.Module)——要解释的模型
  • algorithm(解释器算法)——解释算法
  • explanation_type(解释类型或str)——要计算的解释类型。可能的值为:
    • “model”:解释模型预测
    • “phenomenon”:解释模型试图预测的现象。
      在实践中,这意味着解释算法将计算其相对于模型输出(“model”)或目标输出(“phenomenon”)的损失。
  • model_config——模型配置,参见ModelConfig,默认为None
  • node_mask_type——要应用于节点的掩码类型。可能的值为:
    • “None”:不会在节点上应用任何掩码。
    • “object”:将屏蔽每个节点。
    • “common_attributes”:将掩盖每个特征。
    • “attributes”:将屏蔽所有节点上的每个特征。
  • edge_mask_type——要应用于边的掩码类型。具有的可能值例如node_mask_type。默认为None
  • threshold_config——阈值设置,可选数值参见ThresholdConfig,默认为None

方法

  • get_prediction(*args, **kwargs)→ Tensor:返回模型对输入图的预测。
    如果模型模式为“regression”,则预测将作为标量值返回。如果模型模式为“multiclass_classification”或“binary_classifications”,则预测将作为预测类标签返回。
  • get_masked_prediction(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], node_mask: Optional[Union[Tensor, Dict[str, Tensor]]] = None, edge_mask: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs)→ Tensor:返回应用了节点和边掩码的输入图上的模型预测。
  • __call__(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Optional[Tensor] = None, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]:计算给定输入和目标的GNN的解释。
    如果收到一条错误消息,如“Trying to backward through the graph a second time”,请确保提供的目标是用torch.no.grad()计算的。
    • x——通志图或异质图的输入节点特征。
    • edge_index——同质或异质图的输入边索引。
    • target——模型的目标。如果解释类型是“phenomenon”,则必须提供解释对象。如果解释类型是“ model”,那么目标应该设置为 Nothing,并且会自动推断出来。(默认值: None)
    • index——对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
    • **kwargs——要传递给GNN的其他参数。
  • get_target(prediction: Tensor)→ Tensor:从给定的预测中返回模型的目标。
    如果模型模式为“regression”类型,则按原样返回预测;如果模型模式类型为“multiclass_classification”或“binary_classifications”,则按预测类标签返回预测。

ExplainerConfig
用于存储和验证高级解释参数的配置类。
参数

  • explanation_type——要计算的解释类型。可能的值为:
    • “model”——解释模型预测。
    • “pheonmenon”——解释模型试图预测的现象。
      在实践中,这意味着解释算法将计算它们相对于模型输出(“model”)或目标输出(“pheonmenon”)的损失。
  • node_mask_type——要应用于节点的掩码类型。可能的值为(默认值:None):
    • “None”:不会在节点上应用任何掩码。
    • “object”:将屏蔽每个节点。
    • “common_attributes”:将掩盖每个特征。
    • “attributes”:将屏蔽所有节点上的每个特征。
  • edge_mask_type——要应用于边的掩码类型。具有的可能值例如node_mask_type。默认为None

ModelConfig
用于存储模型参数的配置类。
参数

  • model——模型的模式。可能的值为:
    • “binary_classification”:一个二分类模型。
    • “multiclass_classification”:一种多类分类模型。
    • “regression”:一个回归模型
  • task_level——模型的任务级别。可能的值为:
    • “node”:一个node-level预测模型
    • “edge”:一个edge-level预测模型
    • “graph”:一个graph-level预测模型
  • return_type——模型的返回类型。可能的值为(默认值:None):
    • “raw”:模型返回原始值。
    • “probs”:模型返回概率值
    • “log_probs”:模型返回对数概率

ThresholdConfig
用于存储和验证阈值参数的配置类。
参数

  • threshold_type——要应用的阈值的类型。可能的值为:
    • “None”:没有阈值被应用
    • “hard”:将hard阈值应用于每个掩码。掩码中值低于该值的元素设置为0,其他元素设置为1。
    • “topk”:soft阈值被应用于每个掩码。保留每个掩码的top obj:value元素,其他元素设置为0。
    • “topk_hard”:“topk”相同,但保留的所有元素的值都设置为1。
  • value——设置阈值时要使用的值。(默认值:None)

Explanations

Explanation
基础类:DataExplanationMixin
持有同质图的所有已得到的解释。解释对象是Data对象,可以包含节点属性和边属性。如果需要,它还可以保存原始图形。
参数:

  • node_mask——形状为[num_nodes, 1], [1, num_features][num_nodes, num_features]的node-level掩码,默认为None

  • edge_mask——形状为[num_edges]的edge_level掩码,默认为None

  • kwargs——其他属性参数
    方法

  • validate(raise_on_error: bool = True)→ bool:验证解释对象的正确性。

  • get_explanation_subgraph()→ Explanation:返回归纳子图,其中所有属性为零的节点和边都被屏蔽掉了。

  • get_complement_subgraph()→ Explanation:返回归纳子图,其中具有任何属性的所有节点和边都被屏蔽掉。

  • visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[List[str]] = None, top_k: Optional[int] = None):通过对所有节点的节点掩码求和,创建节点要素重要性的条形图。
    参数

    • path: 保存绘图的路径。如果设置为“None”,将动态显示绘图。(默认值:None)
    • feat_labels: 特征的标签。(默认为“None”)
    • top_k:绘制top k 个特征。如果None,绘制所有特征。(默认值: None)
  • visualize_graph(path: Optional[str] = None, backend: Optional[str] = None): 使具有与边重要性相对应的边不透明度的解释图可视化。
    参数

    • path: 保存绘图的路径。如果设置为“None”,将动态显示绘图。(默认值:None)
    • backend: 用于可视化的图形绘制后端(“graphviz”、“networkx”)。如果设置为“None”,将根据可用的系统包使用最合适的可视化后端。(默认值:None)

HeteroExplanation
基础类:HeteroDataExplanationMixin
包含所有已获得的对异构图的解释。解释对象是HeteroData对象,可以包含节点属性和边属性。如果需要,它还可以保存原始图形。
方法:

  • validate(raise_on_error: bool = True)→ bool:验证解释对象的正确性。
  • get_explanation_subgraph()→ HeteroExplanation:返回归纳子图,其中所有属性为零的节点和边都被屏蔽掉了。
  • get_complement_subgraph()→ HeteroExplanation:返回归纳子图,其中具有任何属性的所有节点和边都被屏蔽掉。
  • visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[Dict[str, List[str]]] = None, top_k: Optional[int] = None):通过对每个节点类型的所有节点的节点掩码求和,创建节点特征重要性的条形图。
    参数
    • path: 保存绘图的路径。如果设置为“None”,将动态显示绘图。(默认值:None)
    • feat_labels: 特征的标签。(默认为“None”)
    • top_k:绘制top k 个特征。如果None,绘制所有特征。(默认值: None)

Explainer Algorithm

1) ExplainerAlgorithm:用于实现解释器算法的抽象基类。
方法
- abstract forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]:计算解释。
参数
- model:要解释的模型
- x:一个同质图或异质图的输入节点特征
- edge_index:一个同质图或异质图的输入边索引
- target:模型的目标
- index:对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- kwargs:传递给 model 的其他关键字参数。
- abstract supports()→ bool:检查解释器是否支持self.explainer_configself.model_config中提供的用户定义设置。
- property explainer_config: ExplainerConfig:返回已连接的解释器配置。
- property model_config: ModelConfig:返回已连接的模型配置
- connect(explainer_config: ExplainerConfig, model_config: ModelConfig):将解释器和模型配置连接到解释器算法。
2) DummyExplainer:返回随机解释的伪解释程序(用于测试目的)。
基础类:ExplainerAlgorithm
方法
- forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], edge_attr: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]:解释计算
参数:
- model:要解释的模型
- x:一个同质图或异质图的输入节点特征
- edge_index:一个同质图或异质图的输入边索引
- target:模型的目标
- index:对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- kwargs:传递给 model 的其他关键字参数。
- supports()→ bool:检查解释器是否支持self.explainer_configself.model_config中提供的用户定义设置。

  • GNNExplainer:来自 “GNNExplainer: Generating Explanations for Graph Neural Networks” 论文中的GNN-Explainer模型用于识别在GNN预测中起关键作用的紧凑子图结构和节点特征。
    基础类:ExplainerAlgorithm
    有关使用GNNEexplainer的示例,请参见examples/explaine/gnn_explainer.py、examples/explain/gnn_eexplainer_ba_shapes.py和examples/explain/gn_explainer_link_pred.py。
    参数

    • epochs:要训练的epochs数。默认为100
    • lr:学习率,默认为0.01
    • kwargs:用于覆盖coeffs中默认设置的附加超参数。

    方法:

    • forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Explanation:计算解释
    • supports()→ bool:检查解释器是否支持self.explainer_configself.model_config中提供的用户定义设置。
      3) CaptumExplainer:一种基于Captum的解释器,用于识别在GNN的预测中起关键作用的紧凑子图结构和节点特征。
      基础类:ExplainerAlgorithm
      这个解释器算法使用Captum来计算属性。目前,支持以下归因方法:
    • captum.attr.IntegratedGradients
    • captum.attr.Saliency
    • captum.attr.InputXGradient
    • captum.attr.Deconvolution
    • captum.attr.ShapleyValueSampling
    • captum.attr.GuidedBackprop
      参数
    • attribution_method:要使用的Captum归因方法。可以是字符串或captum.attr方法。
    • kwargs:Captum归因方法的其他参数。
      方法
    • forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]:计算解释
    • supports()→ bool:检查解释器是否支持self.explainer_configself.model_config中提供的用户定义设置。
      4) PGExplainer: "Parameterized Explainer for Graph Neural Network"论文中的PGExplainer模型。
      基础类:ExplainerAlgorithm
      在内部,它利用神经网络来识别在GNN的预测中起关键作用的子图结构。重要的是,在生成解释之前,PGExplainer需要通过train()进行训练:
explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=30, lr=0.003),
    explanation_type='phenomenon',
    edge_mask_type='object',
    model_config=ModelConfig(...),
)

# 针对各种节点级别或图级别的预测进行训练:
for epoch in range(30):
    for index in [...]:  # Indices to train against.
        loss = explainer.algorithm.train(epoch, model, x, edge_index,
                                         target=target, index=index)

# 获得最终解释:
explanation = explainer(x, edge_index, target=target, index=0)

参数
- epochs:要训练的epochs数。
- lr:学习率,默认为0.003
- kwargs:用于覆盖coeffs中默认设置的附加超参数。
方法
- reset_parameters():重置模型中所有科学系参数
- train(epoch: int, model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs):训练基础的解释者模型。需要在能够做出预测之前被调用。
参数
- epoch:训练阶段的当前阶段。
- model:要被解释的模型
- x:同质图的输入节点特征。
- edge_index:同质图的输入边索引。
- target:模型的目标
- index:对模型输出的索引进行解释。需要是一个单独的索引。(默认值:None)
- kwargs:传递给model的其他关键字参数。
- forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Explanation:计算解释
参数:
- model:要解释的模型
- x:一个同质图或异质图的输入节点特征
- edge_index:一个同质图或异质图的输入边索引
- target:模型的目标
- index:对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- kwargs:传递给 model 的其他关键字参数。
- supports()→ bool:检查解释器是否支持self.explainer_configself.model_config中提供的用户定义设置。
5) AttentionExplainer:使用基于注意力的GNN(例如,GATConv、GATv2Conv或TransformerConv)产生的注意力系数作为边解释的解释器。各层和头部的注意力得分将根据reduce argument进行汇总。
基础类:ExplainerAlgorithm
参数:
- reduce:降低各层和头部注意力得分的方法。(默认值:“max”)
方法:
- forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Explanation:计算解释
- - supports()→ bool:检查解释器是否支持self.explainer_configself.model_config中提供的用户定义设置。

Explanation Metrics

解释的质量可以通过各种不同的方法来判断。PyG支持以下开箱即用的指标:

  • groundtruth_metrics:将解释掩码与ground-truth解释掩码进行比较和评估。
  • fidelity:评估一个Explainer给出的Explanation的真实度 ,参见 “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” 论文
  • characterization_score:返回组件式特征化分数(the componentwise characterization score),参见 “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” 论文
  • fidelity_curve_auc:返回真实度曲线的AUC,参见 “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” 论文
  • unfaithfulness:评估一个Explanation对一个不足的GNN预测因子的真实度,参见 "Evaluating Explainability for Graph Neural Networks"论文

参考资料

  1. 可解释性研究(四)-GNNExplainer的内部实现
  2. GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks
  3. torch_geometric.explain官方文档
  4. GNNExplainer: Generating Explanations for Graph Neural Networks
  5. Captum
  6. PGExplainer
  7. Evaluating Explainability for Graph Neural Networks

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/754826.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RestClient操作文档和DSL查询语法

一、 文档操作 1、新增文档 本案例中,hotel为索引库名,61083为文档idTestvoid testAddDocument() throws IOException {// 1.根据id查询酒店数据Hotel hotel hotelService.getById(61083L);// 2.转换为文档类型HotelDoc hotelDoc new HotelDoc(hotel…

【数据结构】二叉树——链式结构

目录 一、前置声明 二、二叉树的遍历 2.1 前序、中序以及后序遍历 2.2 层序遍历 三、节点个数以及高度 3.1 节点个数 3.2 叶子节点个数 3.3 第k层节点个数 3.4 二叉树的高度/深度 3.5 查找值为x的节点 四、二叉树的创建和销毁 4.1 构建二叉树 4.2 二叉树销毁 4.3 …

2023年7月14日,ArrayList底层

集合框架图: 集合和数组的区别 AarrayList ArrayList底层实现原理 ArrayList的底层实现是基于数组的动态扩容。 初始容量:当创建一个新的ArrayList对象时,它会分配一个初始容量为10的数组。这个初始容量可以根据需求进行调整。 //表示默认的…

在Python中优雅地用多进程:进程池 Pool、管道通信 Pipe、队列通信 Queue、共享内存 Manager Value

Python 自带的多进程库 multiprocessing 可实现多进程。我想用这些短例子示范如何优雅地用多线程。中文网络上,有些人只是翻译了旧版的 Python 官网的多进程文档。而我这篇文章会额外讲一讲下方加粗部分的内容。 创建进程 Process,fork 直接继承资源&am…

zigbee DL-20无线串口收发模块使用(双车通讯,电赛模块推荐)

前言 (1)通常有时候,我们可能会需要让两个MCU进行通讯。而zigbee是最适合两个MCU短距离通讯的模块。他使用极其简单,非常适合两款MCU之间的进行数据交互。 (2)在各类比赛中,经常出现需要两个MCU…

独立看门狗 IWDG

独立看门狗介绍 Q:什么是看门狗? A:可以理解为对于一只修勾的定时投喂,如果不给它吃东西就会狂叫,因此可以通过观察修勾的状态来判断喂它的人有没有正常工作。 在由单片机构成的微型计算机系统中,由于单…

【业务功能篇44】Mysql 海量数据查询优化,进行分区操作

业务场景:当前有个发料表,随着业务数据量增多,达到了几千万级别水平,查询的效率就越来越低了,针对当前的架构情况,我们进行了分区的设置,通过对时间字段,按年月,一个月作…

ios 启动页storyboard 使用记录

本文简单记录ios启动页storyboard 如何使用和注意事项。 xcode窗口简介 以xcode14为例,新建项目如下图,左边文件栏中的LaunchScreen.storyboard 为默认启动页布局。窗口中间部分是storyboard中的组件列表,右侧为预览,可以看到渲…

H3C-Cloud Lab-实验-DHCP实验

实验拓扑图: 实验需求: 1、按照图示为R1配置IP地址 2、配置R1为DHCP服务器,提供服务的地址池为192.168.1.0/24网段,网关为192.168.1.254,DNS服务器地址为202.103.24.68,202.103.0.117 3、192.168.1.10-1…

Camtasia Studio 2023 最新中文版,camtasiaStudio如何添加背景音乐

Camtasia2023的视频编辑工具可以帮助用户剪辑、裁剪、旋转、调整大小、添加特效、混合音频等。用户还可以使用Camtasia2023的字幕功能添加字幕和注释,以及使用其内置的特效和转场来提高视频的视觉效果。 Camtasia Studio 2023新功能介绍 的光标增强 由于光标在屏幕…

解决win10电脑无法访问局域网内其它共享文件问题

问题描述 今天需要上传文件到一个共享的局域网文件夹里,在我的电脑和浏览器访问//192.168.0.16//public都提升访问受限,开始以为是因为用户没授权,后来一般沟通后,发现其它电脑都能正常访问的,所以确定是自己电脑配置…

Caerulein,17650-98-5,雨蛙肽,以三氟醋酸盐形式提供的十肽分子

资料编辑|陕西新研博美生物科技有限公司小编MISSwu​ Caerulein |雨蛙素,雨蛙肽,蓝肽| CAS:17650-98-5 | 纯度:95% ------雨蛙素结构式---- ----试剂参数信息--- CAS号:17650-98-5 外观(Appearance&a…

java中使用POI生成Excel并导出

注:本文章中代码均为本地Demo版本,若后续代码更新将不会更新文章 需求说明及实现方式 根据从数据库查询出的数据,将其写入excel表并导出 我的想法是通过在实体属性上写自定义注解的方式去完成。因为我们在代码中可以通过反射的方式去获取实体…

js小写金额转大写 自动转换

// 小写转为大写convertCurrency(money) {var cnNums [零, 壹, 贰, 叁, 肆, 伍, 陆, 柒, 捌, 玖]var cnIntRadice [, 拾, 佰, 仟]var cnIntUnits [, 万, 亿, 兆]var cnDecUnits [角, 分, 毫, 厘]// var cnInteger 整var cnIntLast 元var maxNum 999999999999999.9999var…

vulnhub靶场red:1教程

靶场搭建 靶机下载地址:Red: 1 ~ VulnHub 难度:中等 信息收集 arp-scan -l 这里没截图忘记了,就只是发现主机 扫描端口 nmap --min-rate 1000 -p- 192.168.21.130 nmap -sT -sV -sC -O -p22,80 192.168.21.130 先看80端口 看到链接点一…

怎么又快又准的确定业务系统属于等保几级?

等保2.0政策已经落地严格执行了一段时间,但大家对于等保政策还有很多不清楚。这不不少人在问,怎么又快有准的确定业务系统属于等保几级? 怎么又快又准的确定业务系统属于等保几级? 【回答】:根据《信息安全等级保护管…

AtcoderABC255场

A - You should output ARC, though this is ABC.A - You should output ARC, though this is ABC. 题目大意 给定整数R和C以及一个2x2矩阵A,需要输出A R,C的值。 思路分析 简单的矩阵查找。根据给定的索引R和C,找到矩阵A中相应位置的元素&#xff0c…

实例014 OutLook界面

实例说明 程序主界面包括菜单栏、工具栏、状态栏和树状视图。OutLook界面美观、友好,是一个很实用的程序主界面,并且菜单栏和工具栏是可移动的。运行本例效果如图1.14所示。 图1.14 Out Look界面 技术要点 一般程序的菜单栏和工具栏是不可移动的&…

【Ajax】笔记-服务端响应JSON数据

服务端响应JSON数据 构建测试案例 键盘按键触发请求服务端&#xff1a; 键盘按下触发事件 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width,…

项目中前端如何实现无感刷新 token!

场景&#xff1a;线上平台有时会出现用户正在使用的时候&#xff0c;突然要用户去进行登录&#xff0c;这样会造成很不好的用户体验。 1.请求采用的是axios 2.平台的采用的 JWT(JSON Web Tokens) 进行用户登录鉴权。 原因&#xff1a; 1.突然跳转到登录页面&#xff0c;是…