I mentioned GNNExplainer before, but many of the formulas in the paper are hard to get a feel for, so I went to GitHub to look at some implementations of GNNExplainer.
GNNExplainer interprets a graph from two perspectives:
- Edge: an edge mask is generated, representing the importance of each edge in the graph as a floating-point number between 0 and 1. The edge mask can also be treated as edge weights, and the explanation can be read off the subgraph formed by the top-k edges (see the sketch after this list).
- Node feature: the node feature (NF) mask works on the node vectors. For example, if a node is represented by a 128-dimensional feature vector, an NF mask of the same size is generated to weight each of the 128 features. This part can be omitted.
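For example, here is a minimal sketch of turning an edge mask into a top-k explanation subgraph (the helper name `topk_edge_subgraph` is mine, not from the paper or DIG):

```python
import torch

def topk_edge_subgraph(edge_index: torch.Tensor, edge_mask: torch.Tensor, k: int = 10):
    """Illustrative helper: keep the k edges with the largest mask values."""
    k = min(k, edge_mask.numel())
    topk = edge_mask.topk(k)                      # values/indices of the k largest entries
    sub_edge_index = edge_index[:, topk.indices]  # [2, k] subgraph of the top-k edges
    return sub_edge_index, topk.values

# toy usage: 4 edges with mask values in (0, 1)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
edge_mask = torch.tensor([0.9, 0.1, 0.7, 0.3])
print(topk_edge_subgraph(edge_index, edge_mask, k=2))
```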
Below is a simplified version of the explainer, adapted from the code in the DIG library.
```python
import torch
from torch import Tensor
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from math import sqrt
from configuration import data_args
from torch_geometric.data import Batch, Data
from torch.nn.functional import cross_entropy


class ExplainerBase(nn.Module):
    def __init__(self, model: nn.Module, epochs=0, lr=0, explain_graph=False, molecule=False):
        super().__init__()
        self.model = model
        self.lr = lr
        self.epochs = epochs
        self.explain_graph = explain_graph
        self.molecule = molecule
        self.mp_layers = [module for module in self.model.modules()
                          if isinstance(module, MessagePassing)]
        self.num_layers = len(self.mp_layers)

        self.ori_pred = None
        self.ex_labels = None
        self.edge_mask = None
        self.hard_edge_mask = None

        self.num_edges = None
        self.num_nodes = None
        self.device = None

    def __set_masks__(self, x, edge_index, init="normal"):
        (N, F), E = x.size(), edge_index.size(1)
        std = 0.1
        self.node_feat_mask = torch.nn.Parameter(
            torch.randn(F, requires_grad=True, device=self.device) * 0.1)

        std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
        self.edge_mask = torch.nn.Parameter(
            torch.randn(E, requires_grad=True, device=self.device) * std)
        # self.edge_mask = torch.nn.Parameter(100 * torch.ones(E, requires_grad=True))

        # Register the mask on every MessagePassing layer so that
        # propagate() applies it during message passing.
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = True
                module.__edge_mask__ = self.edge_mask

    def __clear_masks__(self):
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = False
                module.__edge_mask__ = None
        self.node_feat_masks = None
        self.edge_mask = None

    @property
    def __num_hops__(self):
        if self.explain_graph:
            return -1
        else:
            return self.num_layers

    def __flow__(self):
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                return module.flow
        return 'source_to_target'

    def forward(self, x: Tensor, edge_index: Tensor, **kwargs):
        self.num_edges = edge_index.shape[1]
        self.num_nodes = x.shape[0]
        self.device = x.device

    def eval_related_pred(self, x, edge_index, edge_masks, **kwargs):
        node_idx = kwargs.get('node_idx')
        node_idx = 0 if node_idx is None else node_idx  # graph level: 0, node level: node_idx

        related_preds = []
        for ex_label, edge_mask in enumerate(edge_masks):
            self.edge_mask.data = float('inf') * torch.ones(edge_mask.size(), device=data_args.device)
            ori_pred = self.model(x=x, edge_index=edge_index, **kwargs)

            self.edge_mask.data = edge_mask
            masked_pred = self.model(x=x, edge_index=edge_index, **kwargs)

            # mask out important elements for fidelity calculation
            self.edge_mask.data = - edge_mask  # keep Parameter's id
            maskout_pred = self.model(x=x, edge_index=edge_index, **kwargs)

            # zero_mask
            self.edge_mask.data = - float('inf') * torch.ones(edge_mask.size(), device=data_args.device)
            zero_mask_pred = self.model(x=x, edge_index=edge_index, **kwargs)

            related_preds.append({'zero': zero_mask_pred[node_idx],
                                  'masked': masked_pred[node_idx],
                                  'maskout': maskout_pred[node_idx],
                                  'origin': ori_pred[node_idx]})
        return related_preds


EPS = 1e-15


class GNNExplainer(ExplainerBase):
    r"""The GNN-Explainer model from the `"GNNExplainer: Generating
    Explanations for Graph Neural Networks"
    <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact
    subgraph structures and small subsets of node features that play a
    crucial role in a GNN's node-predictions.

    .. note:: For an example of using GNN-Explainer, see
        `examples/gnn_explainer.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/gnn_explainer.py>`_.

    Args:
        model (torch.nn.Module): The GNN module to explain.
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        log (bool, optional): If set to :obj:`False`, will not log any
            learning progress. (default: :obj:`True`)
    """

    coeffs = {
        'edge_size': 0.005,
        'node_feat_size': 1.0,
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
    }

    def __init__(self, model, epochs=50, lr=0.001, explain_graph=True, molecule=False):
        super(GNNExplainer, self).__init__(model, epochs, lr, explain_graph, molecule)

    def __loss__(self, raw_preds, x_label):
        loss = cross_entropy(raw_preds, x_label)

        m = self.edge_mask.sigmoid()
        loss = loss + self.coeffs['edge_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['edge_ent'] * ent.mean()

        if self.mask_features:
            m = self.node_feat_mask.sigmoid()
            loss = loss + self.coeffs['node_feat_size'] * m.sum()
            ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
            loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss

    def gnn_explainer_alg(self,
                          x: Tensor,
                          edge_index: Tensor,
                          ex_label: Tensor,
                          mask_features: bool = False,
                          **kwargs) -> Tensor:
        # initialize a mask
        patience = 10
        self.to(x.device)
        self.mask_features = mask_features

        # train to get the mask
        optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask], lr=self.lr)
        best_loss = 4.0
        count = 0
        for epoch in range(1, self.epochs + 1):
            if mask_features:
                h = x * self.node_feat_mask.view(1, -1).sigmoid()
            else:
                h = x
            raw_preds = self.model(data=Batch.from_data_list([Data(x=h, edge_index=edge_index)]))
            loss = self.__loss__(raw_preds, ex_label)
            # if epoch % 10 == 0:
            #     print(f'#D#Loss:{loss.item()}')

            # early stopping on the training loss
            is_best = (loss < best_loss)
            if not is_best:
                count += 1
            else:
                count = 0
                best_loss = loss
            if count >= patience:
                break

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return self.edge_mask.data

    def forward(self, x, edge_index, mask_features=False, positive=True, **kwargs):
        r"""Learns and returns a node feature mask and an edge mask that play a
        crucial role to explain the prediction made by the GNN for node
        :attr:`node_idx`.

        Args:
            data (Batch): batch from dataloader
            edge_index (LongTensor): The edge indices.
            pos_neg (Literal['pos', 'neg']): get positive or negative mask
            **kwargs (optional): Additional arguments passed to the GNN module.

        :rtype: (:class:`Tensor`, :class:`Tensor`)
        """
        self.model.eval()
        # self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)

        # Only operate on a k-hop subgraph around `node_idx`.
        # Calculate mask
        ex_label = torch.tensor([1]).to(data_args.device)
        self.__clear_masks__()
        self.__set_masks__(x, edge_index)
        edge_mask = self.gnn_explainer_alg(x, edge_index, ex_label)
        # edge_masks.append(self.gnn_explainer_alg(x, edge_index, ex_label))

        # with torch.no_grad():
        #     related_preds = self.eval_related_pred(x, edge_index, edge_masks, **kwargs)

        self.__clear_masks__()
        sorted_results = edge_mask.sort(descending=True)
        return edge_mask.detach(), sorted_results.indices.cpu(), edge_index.cpu()

    def __repr__(self):
        return f'{self.__class__.__name__}()'
```
GNNExplainer.forward
The entry point is GNNExplainer.forward; here I only explain samples classified as class 1.
```python
def forward(self, x, edge_index, mask_features=False, positive=True, **kwargs):
    r"""Learns and returns a node feature mask and an edge mask that play a
    crucial role to explain the prediction made by the GNN for node
    :attr:`node_idx`.

    Args:
        data (Batch): batch from dataloader
        edge_index (LongTensor): The edge indices.
        pos_neg (Literal['pos', 'neg']): get positive or negative mask
        **kwargs (optional): Additional arguments passed to the GNN module.

    :rtype: (:class:`Tensor`, :class:`Tensor`)
    """
    self.model.eval()
    # self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)

    # Only operate on a k-hop subgraph around `node_idx`.
    # Calculate mask
    ex_label = torch.tensor([1]).to(data_args.device)
    self.__clear_masks__()
    self.__set_masks__(x, edge_index)
    edge_mask = self.gnn_explainer_alg(x, edge_index, ex_label)
    # edge_masks.append(self.gnn_explainer_alg(x, edge_index, ex_label))

    # with torch.no_grad():
    #     related_preds = self.eval_related_pred(x, edge_index, edge_masks, **kwargs)

    self.__clear_masks__()
    sorted_results = edge_mask.sort(descending=True)
    return edge_mask.detach(), sorted_results.indices.cpu(), edge_index.cpu()
```
- The function first calls __clear_masks__() to remove any existing edge mask (even if none is set yet), then calls __set_masks__ to install an initial, randomly generated edge mask.
- The function returns the edge_mask together with the edge indices sorted by mask value. The edge mask itself is computed by gnn_explainer_alg. A minimal usage sketch follows.
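For reference, here is a hypothetical usage sketch, assuming `model` is a trained two-class GNN matching the interface above (it accepts a `data` Batch) and `x`/`edge_index` describe a single graph; all names here are illustrative:

```python
# A minimal usage sketch; `model`, `x`, `edge_index` are assumed to exist.
explainer = GNNExplainer(model, epochs=50, lr=0.001, explain_graph=True)

edge_mask, sorted_edge_ids, edge_index_cpu = explainer(x, edge_index)

# The first entries of `sorted_edge_ids` index the most important edges
# in `edge_index_cpu` (assumes the graph has at least 5 edges).
top5 = edge_index_cpu[:, sorted_edge_ids[:5]]
print(edge_mask.sigmoid())  # raw mask logits squashed to (0, 1) for readability
print(top5)
```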
ExplainerBase.__set_masks__
You only need a rough understanding of this code: it sets up the initial, randomly generated edge mask and NF mask. To keep the code simple I removed part of the NF-mask handling; refer to the DIG code for the full version.
```python
def __set_masks__(self, x, edge_index, init="normal"):
    (N, F), E = x.size(), edge_index.size(1)
    std = 0.1
    self.node_feat_mask = torch.nn.Parameter(
        torch.randn(F, requires_grad=True, device=self.device) * 0.1)

    std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
    self.edge_mask = torch.nn.Parameter(
        torch.randn(E, requires_grad=True, device=self.device) * std)
    # self.edge_mask = torch.nn.Parameter(100 * torch.ones(E, requires_grad=True))

    # Register the mask on every MessagePassing layer so that
    # propagate() applies it during message passing.
    for module in self.model.modules():
        if isinstance(module, MessagePassing):
            module.__explain__ = True
            module.__edge_mask__ = self.edge_mask
```
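Note that both masks are stored as unbounded logits and squashed by sigmoid wherever they are used. With the small std above, every edge starts with a weight close to 0.5, so no edge is favored before training. A quick check (the node and edge counts are made up):

```python
import torch
from math import sqrt

N, E = 100, 500  # hypothetical node and edge counts
std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
edge_mask = torch.randn(E) * std

# All initial weights cluster tightly around 0.5.
print(std)                                                   # ~0.14
print(edge_mask.sigmoid().min(), edge_mask.sigmoid().max())  # both near 0.5
```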
GNNExplainer.gnn_explainer_alg
This function computes an optimal edge mask; the NF mask is again omitted for now.
```python
def gnn_explainer_alg(self,
                      x: Tensor,
                      edge_index: Tensor,
                      ex_label: Tensor,
                      mask_features: bool = False,
                      **kwargs) -> Tensor:
    # initialize a mask
    patience = 10
    self.to(x.device)
    self.mask_features = mask_features

    # train to get the mask
    optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask], lr=self.lr)
    best_loss = 4.0
    count = 0
    for epoch in range(1, self.epochs + 1):
        if mask_features:
            h = x * self.node_feat_mask.view(1, -1).sigmoid()
        else:
            h = x
        raw_preds = self.model(data=Batch.from_data_list([Data(x=h, edge_index=edge_index)]))
        loss = self.__loss__(raw_preds, ex_label)
        # if epoch % 10 == 0:
        #     print(f'#D#Loss:{loss.item()}')

        # early stopping on the training loss
        is_best = (loss < best_loss)
        if not is_best:
            count += 1
        else:
            count = 0
            best_loss = loss
        if count >= patience:
            break

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return self.edge_mask.data
```
Here the edge mask (and the NF mask) is treated as a trainable parameter and optimized like an ordinary neural-network parameter. The loss function is as follows:
```python
def __loss__(self, raw_preds, x_label):
    loss = cross_entropy(raw_preds, x_label)

    m = self.edge_mask.sigmoid()
    loss = loss + self.coeffs['edge_size'] * m.sum()
    ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
    loss = loss + self.coeffs['edge_ent'] * ent.mean()

    if self.mask_features:
        m = self.node_feat_mask.sigmoid()
        loss = loss + self.coeffs['node_feat_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

    return loss
```
The raw_preds in this loss is the model output computed after the mask has been set, since __clear_masks__ is never called during the optimization.
The loss consists of three parts:
- The cross-entropy between the model output after applying the edge mask and the label. (I wonder whether using the model's output before the mask as the target would work better. This part also seems to depend heavily on the accuracy of the model itself, and the accuracy of many GNN-based code-vulnerability classifiers is worrying.)
- The size of the edge_mask itself (its sum).
- The discreteness of the edge_mask, measured by an element-wise entropy term: the closer the mask values are to 0 or 1, the lower this term.
The whole loss (without the NF mask) can be expressed as:
$$loss = CrossEntropy(f(data, model, edge\_mask), label) + Size(edge\_mask) + Discrete(edge\_mask)$$
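To get a feel for what the Size and Discrete terms reward, here is a toy computation using the same formulas as __loss__ (the mask values are made up): a sparse, near-binary mask scores lower on both terms than an undecided uniform one.

```python
import torch

EPS = 1e-15

def size_and_entropy(m: torch.Tensor):
    size = m.sum()  # Size term: total mask mass
    ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)  # Discrete term
    return size.item(), ent.mean().item()

uniform = torch.full((6,), 0.5)                               # every edge "half important"
sparse = torch.tensor([0.99, 0.99, 0.01, 0.01, 0.01, 0.01])   # near-binary, sparse

print(size_and_entropy(uniform))  # (3.0, 0.693...): large size, maximal entropy
print(size_and_entropy(sparse))   # (2.02, 0.056...): smaller size, near-zero entropy
```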
Finally you obtain an optimal (minimum-loss) edge mask. One question remains that the user-side code cannot answer: how does the edge mask take effect inside $f(data, model, edge\_mask)$? In other words, how does setting an edge mask change the model's forward computation? To answer this, we have to look at how torch_geometric supports GNNExplainer.
torch_geometric.nn GNNExplainer
The core class here is MessagePassing. Its propagate function, in the version without GNNExplainer support, looks like this:
```python
coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)

msg_kwargs = self.inspector.distribute('message', coll_dict)
out = self.message(**msg_kwargs)

aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)

update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)
```
In the simplest case, the message function directly returns x_j; likewise, update directly returns its input tensor. aggregate gathers the information from the surrounding nodes.
- out = self.message(**msg_kwargs): the out returned here has shape [edge_num, node_feature_dim]. It should be the vector representation of each edge's source node (the neighbor sending the message).
- out = self.aggregate(out, **aggr_kwargs): the out returned here has shape [node_num, node_feature_dim]. It should be the vector representation of each node after aggregation. The toy layer below makes these shapes concrete.
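Here is a toy MessagePassing layer (my own illustration, not library code) whose message returns x_j unchanged and whose aggregation is a plain sum:

```python
import torch
from torch_geometric.nn import MessagePassing

class PlainConv(MessagePassing):
    """Toy layer: each node sums its neighbors' raw feature vectors."""
    def __init__(self):
        super().__init__(aggr='add', flow='source_to_target')

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j: [edge_num, node_feature_dim], one row per edge
        return x_j

x = torch.eye(3)                                   # 3 nodes, 3-dim one-hot features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # edges 0->1, 1->2, 2->0
out = PlainConv()(x, edge_index)                   # [node_num, node_feature_dim]
print(out)  # node 1 receives node 0's vector, node 2 node 1's, node 0 node 2's
```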
The propagate function with GNNExplainer support added looks like this:
```python
coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)

msg_kwargs = self.inspector.distribute('message', coll_dict)
out = self.message(**msg_kwargs)

# For `GNNExplainer`, we require a separate message and aggregate
# procedure since this allows us to inject the `edge_mask` into the
# message passing computation scheme.
if self.__explain__:
    edge_mask = self.__edge_mask__.sigmoid()
    # Some ops add self-loops to `edge_index`. We need to do the
    # same for `edge_mask` (but do not train those).
    if out.size(self.node_dim) != edge_mask.size(0):
        loop = edge_mask.new_ones(size[0])
        edge_mask = torch.cat([edge_mask, loop], dim=0)
    assert out.size(self.node_dim) == edge_mask.size(0)
    out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))

aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)

update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)
```
You can see that a block of code is inserted between message and aggregate. Ignoring the guard condition and the shape adjustment, its essence is out = out * edge_mask. That is, for a node $x_i$, before aggregating the messages from its surrounding nodes, each neighbor's message vector is first multiplied by the corresponding edge-mask value. (My understanding here may be wrong; corrections are welcome.)
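This effect can be reproduced by hand with plain tensor ops (an illustrative re-implementation, not the library internals): scale each edge's message by its mask value, then sum the scaled messages into the target nodes.

```python
import torch

x = torch.eye(3)                                   # 3 nodes, 3-dim features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # source -> target
edge_mask = torch.tensor([1.0, 0.5, 0.0])          # already sigmoid-ed weights

src, dst = edge_index
messages = x[src]                           # [edge_num, feat_dim], like message()
masked = messages * edge_mask.view(-1, 1)   # the injected step: out = out * edge_mask

# like aggregate(): sum the masked messages into each target node
out = torch.zeros_like(x).index_add_(0, dst, masked)
print(out)  # node 1 gets node 0's full vector, node 2 half of node 1's, node 0 nothing
```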