PyG可解释学习模块torch_geometric.explain
- Philoshopy
- Explainer
- Explanations
- Explainer Algorithm
- Explanation Metrics
- 参考资料
torch_geometric.explain是PyTorch Geometric库中的一个模块,用于解释和可视化图神经网络(GNN)模型的预测结果。它提供了一些方法来解释模型的预测结果、边权重和节点重要性。
主要内容有:Philoshopy(哲学思想)、Explainer(解释器)、Explanations(解释)、Explainer Algorithm(解释器算法)、Explanation Metrics(解释度量)
Philoshopy
该模块提供了一组工具来解释 PyG 模型的预测或解释数据集的基本现象,详细信息可以参考“GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks”
。
我们使用torch_geometric.explain.Explanation
类来表示解释,该类是一个Data
对象,包含数据的节点、边、特征和任何属性的掩码。
torch_geometric.expllain.Explainer
类设计用于处理所有可解释性参数(有关更多详细信息,请参阅torch_geometric.explainn.config.ExplainerConfig
类):
- 使用的算法来自
torch_geometric.expllain.algorithm
模块中,例如GNNExplainer
- 要计算的解释类型(例如
explanation_type="phenomenon"
或者explanation_type="model"
) - 节点和边的不同类型掩码(例如
mask="object"
或者mask="attributes"
) - 掩码的任何后处理(例如
threshold_type="topk"
或者threshold_type="hard"
)
该类允许用户轻松比较不同的可解释性方法,并在不同类型的掩码之间轻松切换,同时确保高层的代码框架保持不变。
Explainer
基础类:object
图神经网络实例级解释的一个解释器类。
参数:
- model(
torch.nn.Module
)——要解释的模型 - algorithm(解释器算法)——解释算法
- explanation_type(解释类型或str)——要计算的解释类型。可能的值为:
- “model”:解释模型预测
- “phenomenon”:解释模型试图预测的现象。
在实践中,这意味着解释算法将计算其相对于模型输出(“model”)或目标输出(“phenomenon”)的损失。
- model_config——模型配置,参见
ModelConfig
,默认为None
- node_mask_type——要应用于节点的掩码类型。可能的值为:
- “None”:不会在节点上应用任何掩码。
- “object”:将屏蔽每个节点。
- “common_attributes”:将掩盖每个特征。
- “attributes”:将屏蔽所有节点上的每个特征。
- edge_mask_type——要应用于边的掩码类型。具有的可能值例如
node_mask_type
。默认为None - threshold_config——阈值设置,可选数值参见
ThresholdConfig
,默认为None
方法
get_prediction(*args, **kwargs)→ Tensor
:返回模型对输入图的预测。
如果模型模式为“regression”,则预测将作为标量值返回。如果模型模式为“multiclass_classification”或“binary_classifications”,则预测将作为预测类标签返回。get_masked_prediction(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], node_mask: Optional[Union[Tensor, Dict[str, Tensor]]] = None, edge_mask: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs)→ Tensor
:返回应用了节点和边掩码的输入图上的模型预测。__call__(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Optional[Tensor] = None, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]
:计算给定输入和目标的GNN的解释。
如果收到一条错误消息,如“Trying to backward through the graph a second time”,请确保提供的目标是用torch.no.grad()
计算的。- x——通志图或异质图的输入节点特征。
- edge_index——同质或异质图的输入边索引。
- target——模型的目标。如果解释类型是“phenomenon”,则必须提供解释对象。如果解释类型是“ model”,那么目标应该设置为 Nothing,并且会自动推断出来。(默认值: None)
- index——对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- **kwargs——要传递给GNN的其他参数。
get_target(prediction: Tensor)→ Tensor
:从给定的预测中返回模型的目标。
如果模型模式为“regression”类型,则按原样返回预测;如果模型模式类型为“multiclass_classification”或“binary_classifications”,则按预测类标签返回预测。
用于存储和验证高级解释参数的配置类。
参数:
- explanation_type——要计算的解释类型。可能的值为:
- “model”——解释模型预测。
- “pheonmenon”——解释模型试图预测的现象。
在实践中,这意味着解释算法将计算它们相对于模型输出(“model”)或目标输出(“pheonmenon”)的损失。
- node_mask_type——要应用于节点的掩码类型。可能的值为(默认值:None):
- “None”:不会在节点上应用任何掩码。
- “object”:将屏蔽每个节点。
- “common_attributes”:将掩盖每个特征。
- “attributes”:将屏蔽所有节点上的每个特征。
- edge_mask_type——要应用于边的掩码类型。具有的可能值例如
node_mask_type
。默认为None
用于存储模型参数的配置类。
参数:
- model——模型的模式。可能的值为:
- “binary_classification”:一个二分类模型。
- “multiclass_classification”:一种多类分类模型。
- “regression”:一个回归模型
- task_level——模型的任务级别。可能的值为:
- “node”:一个node-level预测模型
- “edge”:一个edge-level预测模型
- “graph”:一个graph-level预测模型
- return_type——模型的返回类型。可能的值为(默认值:None):
- “raw”:模型返回原始值。
- “probs”:模型返回概率值
- “log_probs”:模型返回对数概率
用于存储和验证阈值参数的配置类。
参数:
- threshold_type——要应用的阈值的类型。可能的值为:
- “None”:没有阈值被应用
- “hard”:将hard阈值应用于每个掩码。掩码中值低于该值的元素设置为0,其他元素设置为1。
- “topk”:soft阈值被应用于每个掩码。保留每个掩码的top obj:value元素,其他元素设置为0。
- “topk_hard”:“topk”相同,但保留的所有元素的值都设置为1。
- value——设置阈值时要使用的值。(默认值:None)
Explanations
基础类:Data
、ExplanationMixin
持有同质图的所有已得到的解释。解释对象是Data对象,可以包含节点属性和边属性。如果需要,它还可以保存原始图形。
参数:
-
node_mask——形状为
[num_nodes, 1], [1, num_features]
或[num_nodes, num_features]
的node-level掩码,默认为None -
edge_mask——形状为
[num_edges]
的edge_level掩码,默认为None -
kwargs——其他属性参数
方法: -
validate(raise_on_error: bool = True)→ bool
:验证解释对象的正确性。 -
get_explanation_subgraph()→ Explanation
:返回归纳子图,其中所有属性为零的节点和边都被屏蔽掉了。 -
get_complement_subgraph()→ Explanation
:返回归纳子图,其中具有任何属性的所有节点和边都被屏蔽掉。 -
visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[List[str]] = None, top_k: Optional[int] = None)
:通过对所有节点的节点掩码求和,创建节点要素重要性的条形图。
参数:- path: 保存绘图的路径。如果设置为“None”,将动态显示绘图。(默认值:None)
- feat_labels: 特征的标签。(默认为“None”)
- top_k:绘制top k 个特征。如果None,绘制所有特征。(默认值: None)
-
visualize_graph(path: Optional[str] = None, backend: Optional[str] = None)
: 使具有与边重要性相对应的边不透明度的解释图可视化。
参数:- path: 保存绘图的路径。如果设置为“None”,将动态显示绘图。(默认值:None)
- backend: 用于可视化的图形绘制后端(“
graphviz
”、“networkx
”)。如果设置为“None”,将根据可用的系统包使用最合适的可视化后端。(默认值:None)
基础类:HeteroData
、ExplanationMixin
包含所有已获得的对异构图的解释。解释对象是HeteroData对象,可以包含节点属性和边属性。如果需要,它还可以保存原始图形。
方法:
validate(raise_on_error: bool = True)→ bool
:验证解释对象的正确性。get_explanation_subgraph()→ HeteroExplanation
:返回归纳子图,其中所有属性为零的节点和边都被屏蔽掉了。get_complement_subgraph()→ HeteroExplanation
:返回归纳子图,其中具有任何属性的所有节点和边都被屏蔽掉。visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[Dict[str, List[str]]] = None, top_k: Optional[int] = None)
:通过对每个节点类型的所有节点的节点掩码求和,创建节点特征重要性的条形图。
参数:- path: 保存绘图的路径。如果设置为“None”,将动态显示绘图。(默认值:None)
- feat_labels: 特征的标签。(默认为“None”)
- top_k:绘制top k 个特征。如果None,绘制所有特征。(默认值: None)
Explainer Algorithm
1) ExplainerAlgorithm
:用于实现解释器算法的抽象基类。
方法
- abstract forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]
:计算解释。
参数:
- model:要解释的模型
- x:一个同质图或异质图的输入节点特征
- edge_index:一个同质图或异质图的输入边索引
- target:模型的目标
- index:对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- kwargs:传递给 model 的其他关键字参数。
- abstract supports()→ bool
:检查解释器是否支持self.explainer_config
、self.model_config
中提供的用户定义设置。
- property explainer_config: ExplainerConfig
:返回已连接的解释器配置。
- property model_config: ModelConfig
:返回已连接的模型配置
- connect(explainer_config: ExplainerConfig, model_config: ModelConfig)
:将解释器和模型配置连接到解释器算法。
2) DummyExplainer
:返回随机解释的伪解释程序(用于测试目的)。
基础类:ExplainerAlgorithm
方法
- forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], edge_attr: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]
:解释计算
参数:
- model:要解释的模型
- x:一个同质图或异质图的输入节点特征
- edge_index:一个同质图或异质图的输入边索引
- target:模型的目标
- index:对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- kwargs:传递给 model 的其他关键字参数。
- supports()→ bool
:检查解释器是否支持self.explainer_config
、self.model_config
中提供的用户定义设置。
-
GNNExplainer
:来自 “GNNExplainer: Generating Explanations for Graph Neural Networks” 论文中的GNN-Explainer模型用于识别在GNN预测中起关键作用的紧凑子图结构和节点特征。
基础类:ExplainerAlgorithm
有关使用GNNEexplainer的示例,请参见examples/explaine/gnn_explainer.py、examples/explain/gnn_eexplainer_ba_shapes.py和examples/explain/gn_explainer_link_pred.py。
参数:- epochs:要训练的epochs数。默认为100
- lr:学习率,默认为0.01
- kwargs:用于覆盖
coeffs
中默认设置的附加超参数。
方法:
forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Explanation
:计算解释supports()→ bool
:检查解释器是否支持self.explainer_config
、self.model_config
中提供的用户定义设置。
3)CaptumExplainer
:一种基于Captum的解释器,用于识别在GNN的预测中起关键作用的紧凑子图结构和节点特征。
基础类:ExplainerAlgorithm
这个解释器算法使用Captum来计算属性。目前,支持以下归因方法:- captum.attr.IntegratedGradients
- captum.attr.Saliency
- captum.attr.InputXGradient
- captum.attr.Deconvolution
- captum.attr.ShapleyValueSampling
- captum.attr.GuidedBackprop
参数: - attribution_method:要使用的Captum归因方法。可以是字符串或captum.attr方法。
- kwargs:Captum归因方法的其他参数。
方法: forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]
:计算解释supports()→ bool
:检查解释器是否支持self.explainer_config
、self.model_config
中提供的用户定义设置。
4)PGExplainer
: "Parameterized Explainer for Graph Neural Network"论文中的PGExplainer模型。
基础类:ExplainerAlgorithm
在内部,它利用神经网络来识别在GNN的预测中起关键作用的子图结构。重要的是,在生成解释之前,PGExplainer需要通过train()
进行训练:
explainer = Explainer(
model=model,
algorithm=PGExplainer(epochs=30, lr=0.003),
explanation_type='phenomenon',
edge_mask_type='object',
model_config=ModelConfig(...),
)
# 针对各种节点级别或图级别的预测进行训练:
for epoch in range(30):
for index in [...]: # Indices to train against.
loss = explainer.algorithm.train(epoch, model, x, edge_index,
target=target, index=index)
# 获得最终解释:
explanation = explainer(x, edge_index, target=target, index=0)
参数:
- epochs:要训练的epochs数。
- lr:学习率,默认为0.003
- kwargs:用于覆盖coeffs
中默认设置的附加超参数。
方法:
- reset_parameters()
:重置模型中所有科学系参数
- train(epoch: int, model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)
:训练基础的解释者模型。需要在能够做出预测之前被调用。
参数:
- epoch:训练阶段的当前阶段。
- model:要被解释的模型
- x:同质图的输入节点特征。
- edge_index:同质图的输入边索引。
- target:模型的目标
- index:对模型输出的索引进行解释。需要是一个单独的索引。(默认值:None)
- kwargs:传递给model的其他关键字参数。
- forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Explanation
:计算解释
参数:
- model:要解释的模型
- x:一个同质图或异质图的输入节点特征
- edge_index:一个同质图或异质图的输入边索引
- target:模型的目标
- index:对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- kwargs:传递给 model 的其他关键字参数。
- supports()→ bool
:检查解释器是否支持self.explainer_config
、self.model_config
中提供的用户定义设置。
5) AttentionExplainer
:使用基于注意力的GNN(例如,GATConv、GATv2Conv或TransformerConv)产生的注意力系数作为边解释的解释器。各层和头部的注意力得分将根据reduce argument进行汇总。
基础类:ExplainerAlgorithm
参数:
- reduce:降低各层和头部注意力得分的方法。(默认值:“max”)
方法:
- forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Explanation
:计算解释
- - supports()→ bool
:检查解释器是否支持self.explainer_config
、self.model_config
中提供的用户定义设置。
Explanation Metrics
解释的质量可以通过各种不同的方法来判断。PyG支持以下开箱即用的指标:
groundtruth_metrics
:将解释掩码与ground-truth解释掩码进行比较和评估。fidelity
:评估一个Explainer给出的Explanation的真实度 ,参见 “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” 论文characterization_score
:返回组件式特征化分数(the componentwise characterization score),参见 “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” 论文fidelity_curve_auc
:返回真实度曲线的AUC,参见 “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” 论文unfaithfulness
:评估一个Explanation对一个不足的GNN预测因子的真实度,参见 "Evaluating Explainability for Graph Neural Networks"论文
参考资料
- 可解释性研究(四)-GNNExplainer的内部实现
- GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks
- torch_geometric.explain官方文档
- GNNExplainer: Generating Explanations for Graph Neural Networks
- Captum
- PGExplainer
- Evaluating Explainability for Graph Neural Networks