torch_geometric.explain
Warning
This module is in active development and may not be stable. Access requires installing PyG from master.
Philosophy
This module provides a set of tools to explain the predictions of a PyG model or to explain the underlying phenomenon of a dataset (see the “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” paper for more details).
We represent explanations using the torch_geometric.explain.Explanation class, which is a Data object containing masks for the nodes, edges, features, and any attributes of the data.
The torch_geometric.explain.Explainer class is designed to handle all explainability parameters (see the torch_geometric.explain.config.ExplainerConfig class for more details):
- which algorithm from the torch_geometric.explain.algorithm module to use (e.g., GNNExplainer)
- the type of explanation to compute (e.g., explanation_type="phenomenon" or explanation_type="model")
- the different types of masks for nodes and edges (e.g., mask="object" or mask="attributes")
- any postprocessing of the masks (e.g., threshold_type="topk" or threshold_type="hard")
This class allows the user to easily compare different explainability methods and to easily switch between different types of masks, while making sure the high-level framework stays the same.
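As a minimal sketch of this workflow, the snippet below builds an Explainer around GNNExplainer and explains a single node. Here, model (a trained node-level GNN) and data (a homogeneous graph) are placeholder names, not part of the library:

```python
from torch_geometric.explain import Explainer
from torch_geometric.explain.algorithm import GNNExplainer

explainer = Explainer(
    model=model,  # placeholder: a trained node-level GNN
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)

# Explain the model prediction for node 10:
explanation = explainer(data.x, data.edge_index, index=10)
print(explanation.edge_mask)
print(explanation.node_mask)
```

Swapping GNNExplainer(epochs=200) for another algorithm (e.g., DummyExplainer()) changes the explanation method while the rest of the code stays unchanged.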
Explainer
- class Explainer(model: Module, algorithm: ExplainerAlgorithm, explanation_type: Union[ExplanationType, str], model_config: Union[ModelConfig, Dict[str, Any]], node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None, threshold_config: Optional[ThresholdConfig] = None)[source]
Bases: object
An explainer class for instance-level explanations of Graph Neural Networks.
- Parameters:
model (torch.nn.Module) – The model to explain.
algorithm (ExplainerAlgorithm) – The explanation algorithm.
explanation_type (ExplanationType or str) – The type of explanation to compute. The possible values are:
- "model": Explains the model prediction.
- "phenomenon": Explains the phenomenon that the model is trying to predict.
In practice, this means that the explanation algorithm will compute its loss with respect to either the model output ("model") or the target output ("phenomenon").
model_config (ModelConfig) – The model configuration. See ModelConfig for available options.
node_mask_type (MaskType or str, optional) – The type of mask to apply on nodes. The possible values are (default: None):
- None: Will not apply any mask on nodes.
- "object": Will mask each node.
- "common_attributes": Will mask each feature.
- "attributes": Will mask each feature across all nodes.
edge_mask_type (MaskType or str, optional) – The type of mask to apply on edges. Has the same possible values as node_mask_type. (default: None)
threshold_config (ThresholdConfig, optional) – The threshold configuration. See ThresholdConfig for available options. (default: None)
- get_prediction(*args, **kwargs) Tensor [source]
Returns the prediction of the model on the input graph.
If the model mode is "regression", the prediction is returned as a scalar value. If the model mode is "multiclass_classification" or "binary_classification", the prediction is returned as the predicted class label.
- Parameters:
*args – Arguments passed to the model.
**kwargs (optional) – Additional keyword arguments passed to the model.
- Return type: Tensor
- get_masked_prediction(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], node_mask: Optional[Union[Tensor, Dict[str, Tensor]]] = None, edge_mask: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs) Tensor [source]
Returns the prediction of the model on the input graph with node and edge masks applied.
- Return type: Tensor
- __call__(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Optional[Tensor] = None, index: Optional[Union[int, Tensor]] = None, **kwargs) Union[Explanation, HeteroExplanation] [source]
Computes the explanation of the GNN for the given inputs and target.
Note
If you get an error message like “Trying to backward through the graph a second time”, make sure that the target you provided was computed with torch.no_grad().
- Parameters:
x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.
edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.
target (torch.Tensor, optional) – The target of the model. If the explanation type is "phenomenon", the target has to be provided. If the explanation type is "model", the target should be set to None and will get automatically inferred. For classification tasks, the target needs to contain the class labels. (default: None)
index (Union[int, Tensor], optional) – The indices in the first dimension of the model output to explain. Can be a single index or a tensor of indices. If set to None, all model outputs will be explained. (default: None)
**kwargs – Additional arguments passed to the GNN.
- Return type: Union[Explanation, HeteroExplanation]
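For the "phenomenon" case, a hedged sketch (again treating model and data as placeholders) that passes an explicit target:

```python
explainer = Explainer(
    model=model,  # placeholder: a trained node-level classifier
    algorithm=GNNExplainer(epochs=200),
    explanation_type='phenomenon',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)

# In "phenomenon" mode the target is required; here we explain the
# ground-truth labels. A model-derived target should instead be computed
# under torch.no_grad() to avoid the backward error mentioned above.
explanation = explainer(data.x, data.edge_index, target=data.y, index=10)
```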
- get_target(prediction: Tensor) Tensor [source]
Returns the target of the model from a given prediction.
If the model mode is of type "regression", the prediction is returned as it is. If the model mode is of type "multiclass_classification" or "binary_classification", the prediction is returned as the predicted class label.
- Return type: Tensor
- class ExplainerConfig(explanation_type: Union[ExplanationType, str], node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None)[source]
Configuration class to store and validate high-level explanation parameters.
- Parameters:
explanation_type (ExplanationType or str) – The type of explanation to compute. The possible values are:
- "model": Explains the model prediction.
- "phenomenon": Explains the phenomenon that the model is trying to predict.
In practice, this means that the explanation algorithm will compute its loss with respect to either the model output ("model") or the target output ("phenomenon").
node_mask_type (MaskType or str, optional) – The type of mask to apply on nodes. The possible values are (default: None):
- None: Will not apply any mask on nodes.
- "object": Will mask each node.
- "common_attributes": Will mask each feature.
- "attributes": Will mask each feature across all nodes.
edge_mask_type (MaskType or str, optional) – The type of mask to apply on edges. Has the same possible values as node_mask_type. (default: None)
- class ModelConfig(mode: Union[ModelMode, str], task_level: Union[ModelTaskLevel, str], return_type: Optional[Union[ModelReturnType, str]] = None)[source]
Configuration class to store model parameters.
- Parameters:
mode (ModelMode or str) – The mode of the model. The possible values are:
- "binary_classification": A binary classification model.
- "multiclass_classification": A multiclass classification model.
- "regression": A regression model.
task_level (ModelTaskLevel or str) – The task level of the model. The possible values are:
- "node": A node-level prediction model.
- "edge": An edge-level prediction model.
- "graph": A graph-level prediction model.
return_type (ModelReturnType or str, optional) – The return type of the model. The possible values are (default: None):
- "raw": The model returns raw values.
- "probs": The model returns probabilities.
- "log_probs": The model returns log-probabilities.
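A small sketch of constructing a ModelConfig explicitly (equivalent to passing a plain dict to Explainer, as in the earlier sketch):

```python
from torch_geometric.explain.config import ModelConfig

# Describe a node-level multiclass classifier returning log-probabilities:
model_config = ModelConfig(
    mode='multiclass_classification',
    task_level='node',
    return_type='log_probs',
)
```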
- class ThresholdConfig(threshold_type: Union[ThresholdType, str], value: Union[float, int])[source]
Configuration class to store and validate threshold parameters.
- Parameters:
threshold_type (ThresholdType or str) – The type of threshold to apply. The possible values are:
- None: No threshold is applied.
- "hard": A hard threshold is applied to each mask. The elements of the mask with a value below the value are set to 0, the others are set to 1.
- "topk": A soft threshold is applied to each mask. The top value elements of each mask are kept, the others are set to 0.
- "topk_hard": Same as "topk", but values are set to 1 for all elements which are kept.
value (int or float, optional) – The value to use when thresholding. (default: None)
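For example, keeping only the 10 largest elements of each mask (a sketch; the configuration can also be passed to Explainer as a plain dict):

```python
from torch_geometric.explain.config import ThresholdConfig

# Keep the top-10 elements of each mask and zero out the rest:
threshold_config = ThresholdConfig(threshold_type='topk', value=10)
```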
Explanations
- class Explanation(x: Optional[Tensor] = None, edge_index: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None, y: Optional[Union[Tensor, int, float]] = None, pos: Optional[Tensor] = None, time: Optional[Tensor] = None, **kwargs)[source]
Bases: Data, ExplanationMixin
Holds all the obtained explanations of a homogeneous graph.
The explanation object is a Data object and can hold node attributions and edge attributions. It can also hold the original graph if needed.
- validate(raise_on_error: bool = True) bool [source]
Validates the correctness of the Explanation object.
- Return type: bool
- get_explanation_subgraph() Explanation [source]
Returns the induced subgraph, in which all nodes and edges with zero attribution are masked out.
- Return type: Explanation
- get_complement_subgraph() Explanation [source]
Returns the induced subgraph, in which all nodes and edges with any attribution are masked out.
- Return type: Explanation
- visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[List[str]] = None, top_k: Optional[int] = None)[source]
Creates a bar plot of the node feature importances by summing up the node mask across all nodes.
- Parameters:
path (str, optional) – The path to where the plot is saved. If set to None, will visualize the plot on-the-fly. (default: None)
feat_labels (List[str], optional) – The labels of features. (default: None)
top_k (int, optional) – Top k features to plot. If None, plots all features. (default: None)
- visualize_graph(path: Optional[str] = None, backend: Optional[str] = None, node_labels: Optional[List[str]] = None) None [source]
Visualizes the explanation graph with edge opacity corresponding to edge importance.
- Parameters:
path (str, optional) – The path to where the plot is saved. If set to None, will visualize the plot on-the-fly. (default: None)
backend (str, optional) – The graph drawing backend to use for visualization ("graphviz", "networkx"). If set to None, will use the most appropriate visualization backend based on available system packages. (default: None)
node_labels (List[str], optional) – The labels/IDs of nodes. (default: None)
- Return type: None
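Putting these together, a hedged sketch of post-processing an explanation (explainer and data are the placeholders from the earlier sketch; file paths are arbitrary):

```python
explanation = explainer(data.x, data.edge_index, index=10)

# Restrict to the nodes/edges that received non-zero attribution:
subgraph = explanation.get_explanation_subgraph()

# Plot the ten most important node features and the explanation graph:
explanation.visualize_feature_importance('feature_importance.png', top_k=10)
explanation.visualize_graph('subgraph.png', backend='networkx')
```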
- class HeteroExplanation(_mapping: Optional[Dict[str, Any]] = None, **kwargs)[source]
Bases: HeteroData, ExplanationMixin
Holds all the obtained explanations of a heterogeneous graph.
The explanation object is a HeteroData object and can hold node attributions and edge attributions. It can also hold the original graph if needed.
- validate(raise_on_error: bool = True) bool [source]
Validates the correctness of the Explanation object.
- Return type: bool
- get_explanation_subgraph() HeteroExplanation [source]
Returns the induced subgraph, in which all nodes and edges with zero attribution are masked out.
- Return type: HeteroExplanation
- get_complement_subgraph() HeteroExplanation [source]
Returns the induced subgraph, in which all nodes and edges with any attribution are masked out.
- Return type: HeteroExplanation
- visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[Dict[str, List[str]]] = None, top_k: Optional[int] = None)[source]
Creates a bar plot of the node feature importances by summing up node masks across all nodes for each node type.
- Parameters:
path (str, optional) – The path to where the plot is saved. If set to None, will visualize the plot on-the-fly. (default: None)
feat_labels (Dict[NodeType, List[str]], optional) – The labels of features for each node type. (default: None)
top_k (int, optional) – Top k features to plot. If None, plots all features. (default: None)
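A hypothetical sketch for the heterogeneous case, assuming hetero_explainer wraps a heterogeneous model and hetero_data is a HeteroData object (both names are placeholders):

```python
# Explaining a heterogeneous model yields a HeteroExplanation:
explanation = hetero_explainer(
    hetero_data.x_dict,
    hetero_data.edge_index_dict,
    index=10,
)

# Node masks are summed per node type; plot the five most important
# features of each type:
explanation.visualize_feature_importance('hetero_importance.png', top_k=5)
```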
Explainer Algorithms
ExplainerAlgorithm – An abstract base class for implementing explainer algorithms.
DummyExplainer – A dummy explainer that returns random explanations (useful for testing purposes).
GNNExplainer – The GNN-Explainer model from the "GNNExplainer: Generating Explanations for Graph Neural Networks" paper for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN.
CaptumExplainer – A Captum-based explainer for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN.
PGExplainer – The PGExplainer model from the "Parameterized Explainer for Graph Neural Network" paper.
AttentionExplainer – An explainer that uses the attention coefficients produced by an attention-based GNN (e.g., GATConv, GATv2Conv, or TransformerConv) as edge explanation.
GraphMaskExplainer – The GraphMask-Explainer model from the "Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking" paper for identifying layer-wise compact subgraph structures and node features that play a crucial role in the predictions made by a GNN.
Explanation Metrics
The quality of an explanation can be judged by a variety of different methods. PyG supports the following metrics out-of-the-box (a usage sketch follows the list):
groundtruth_metrics – Compares and evaluates an explanation mask with the ground-truth explanation mask.
fidelity – Evaluates the fidelity of an Explainer given an Explanation, as described in the "GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" paper.
characterization_score – Returns the componentwise characterization score as described in the "GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" paper.
fidelity_curve_auc – Returns the AUC for the fidelity curve as described in the "GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" paper.
unfaithfulness – Evaluates how faithful an Explanation is to an underlying GNN predictor, as described in the "Evaluating Explainability for Graph Neural Networks" paper.
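A hedged sketch of computing fidelity-based metrics, reusing the explainer and explanation placeholders from the sketches above:

```python
from torch_geometric.explain.metric import characterization_score, fidelity

# Fidelity+/fidelity- measure how the prediction changes when the
# explanation (resp. its complement) is removed from the input:
pos_fidelity, neg_fidelity = fidelity(explainer, explanation)

# Combine both components into a single characterization score:
score = characterization_score(pos_fidelity, neg_fidelity)
```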