Interpretability research - internal implementation of GNNExplainer

Posted by ramli on Tue, 14 Dec 2021 22:48:50 +0100

I mentioned GNNExplainer before, but many of the formulas in the paper are hard to grasp, so I went to GitHub to look at some implementations of GNNExplainer.

GNNExplainer interprets a graph from two perspectives:

  • Edge: an edge mask is generated, giving each edge in the graph a floating-point importance score between 0 and 1. The edge mask can also be treated as edge weights, and the subgraph connected by the top-k edges serves as the explanation (see the sketch after this list).
  • Node feature: node features (NF) are the node vectors. For example, if each node is represented by a 128-dimensional feature vector, an NF mask of length 128 is generated, giving the weight of each feature. This part can be omitted.
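
For instance, picking the top-k edges out of a learned edge mask could look like the following minimal sketch (my own illustration, not from the paper or DIG; the toy edge_index and edge_mask values are made up):

import torch

# Hypothetical values: in practice, edge_mask comes from a trained explainer.
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])       # toy graph with 4 edges
edge_mask = torch.tensor([0.9, 0.1, 0.7, 0.3])  # learned edge importances

k = 2
topk_indices = edge_mask.topk(k).indices            # the k most important edges
explanation_subgraph = edge_index[:, topk_indices]  # edges (0 -> 1) and (2 -> 3)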

Below is a simplified version of the explainer, based on the implementation in DIG.

import torch
from torch import Tensor
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from math import sqrt

from configuration import data_args

from torch_geometric.data import Batch, Data
from torch.nn.functional import cross_entropy

class ExplainerBase(nn.Module):
    def __init__(self, model: nn.Module, epochs=0, lr=0, explain_graph=False, molecule=False):
        super().__init__()
        self.model = model
        self.lr = lr
        self.epochs = epochs
        self.explain_graph = explain_graph
        self.molecule = molecule
        self.mp_layers = [module for module in self.model.modules() if isinstance(module, MessagePassing)]
        self.num_layers = len(self.mp_layers)

        self.ori_pred = None
        self.ex_labels = None
        self.edge_mask = None
        self.hard_edge_mask = None

        self.num_edges = None
        self.num_nodes = None
        self.device = None


    def __set_masks__(self, x, edge_index, init="normal"):
        (N, F), E = x.size(), edge_index.size(1)

        std = 0.1
        self.node_feat_mask = torch.nn.Parameter(torch.randn(F, requires_grad=True, device=self.device) * std)

        std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
        self.edge_mask = torch.nn.Parameter(torch.randn(E, requires_grad=True, device=self.device) * std)
        # self.edge_mask = torch.nn.Parameter(100 * torch.ones(E, requires_grad=True))

        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = True
                module.__edge_mask__ = self.edge_mask

    def __clear_masks__(self):
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = False
                module.__edge_mask__ = None
        self.node_feat_mask = None
        self.edge_mask = None

    @property
    def __num_hops__(self):
        if self.explain_graph:
            return -1
        else:
            return self.num_layers

    def __flow__(self):
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                return module.flow
        return 'source_to_target'



    def forward(self,
                x: Tensor,
                edge_index: Tensor,
                **kwargs
                ):
        self.num_edges = edge_index.shape[1]
        self.num_nodes = x.shape[0]
        self.device = x.device


    def eval_related_pred(self, x, edge_index, edge_masks, **kwargs):

        node_idx = kwargs.get('node_idx')
        node_idx = 0 if node_idx is None else node_idx  # graph level: 0, node level: node_idx
        related_preds = []

        # `enumerate` yields the class index as `ex_label` for each mask
        for ex_label, edge_mask in enumerate(edge_masks):

            # +inf sigmoids to 1 everywhere, reproducing the original prediction
            self.edge_mask.data = float('inf') * torch.ones(edge_mask.size(), device=data_args.device)
            ori_pred = self.model(x=x, edge_index=edge_index, **kwargs)

            self.edge_mask.data = edge_mask
            masked_pred = self.model(x=x, edge_index=edge_index, **kwargs)

            # mask out important elements for fidelity calculation
            self.edge_mask.data = - edge_mask  # keep Parameter's id
            maskout_pred = self.model(x=x, edge_index=edge_index, **kwargs)

            # zero_mask: -inf sigmoids to 0 everywhere, removing every edge
            self.edge_mask.data = - float('inf') * torch.ones(edge_mask.size(), device=data_args.device)
            zero_mask_pred = self.model(x=x, edge_index=edge_index, **kwargs)

            related_preds.append({'zero': zero_mask_pred[node_idx],
                                  'masked': masked_pred[node_idx],
                                  'maskout': maskout_pred[node_idx],
                                  'origin': ori_pred[node_idx]})


        return related_preds



EPS = 1e-15


class GNNExplainer(ExplainerBase):
    r"""The GNN-Explainer model from the `"GNNExplainer: Generating
    Explanations for Graph Neural Networks"
    <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph
    structures and small subsets of node features that play a crucial role in a
    GNN's node-predictions.

    .. note::

        For an example of using GNN-Explainer, see `examples/gnn_explainer.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        gnn_explainer.py>`_.

    Args:
        model (torch.nn.Module): The GNN module to explain.
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`50`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.001`)
    """

    coeffs = {
        'edge_size': 0.005,
        'node_feat_size': 1.0,
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
    }

    def __init__(self, model, epochs=50, lr=0.001, explain_graph=True, molecule=False):
        super(GNNExplainer, self).__init__(model, epochs, lr, explain_graph, molecule)

    def __loss__(self, raw_preds, x_label):
        loss = cross_entropy(raw_preds, x_label)
        m = self.edge_mask.sigmoid()
        loss = loss + self.coeffs['edge_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['edge_ent'] * ent.mean()

        if self.mask_features:
            m = self.node_feat_mask.sigmoid()
            loss = loss + self.coeffs['node_feat_size'] * m.sum()
            ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
            loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss

    def gnn_explainer_alg(self, x: Tensor, edge_index: Tensor, ex_label: Tensor, mask_features: bool = False, **kwargs) -> Tensor:
        # initialize a mask
        patience = 10
        self.to(x.device)
        self.mask_features = mask_features

        # train to get the mask
        optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
                                     lr=self.lr)

        best_loss = 4.0
        count = 0
        for epoch in range(1, self.epochs + 1):
            if mask_features:
                h = x * self.node_feat_mask.view(1, -1).sigmoid()
            else:
                h = x
            raw_preds = self.model(data=Batch.from_data_list([Data(x=h, edge_index=edge_index)]))
            loss = self.__loss__(raw_preds, ex_label)
            # if epoch % 10 == 0:
            #     print(f'#D#Loss:{loss.item()}')

            is_best = (loss.item() < best_loss)

            if not is_best:
                count += 1
            else:
                count = 0
                best_loss = loss.item()  # store a float so the graph is not retained

            if count >= patience:
                break

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return self.edge_mask.data

    def forward(self, x, edge_index, mask_features=False,
                positive=True, **kwargs):
        r"""Learns and returns an edge mask that plays a crucial role in
        explaining the prediction made by the GNN.

        Args:
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            mask_features (bool, optional): If :obj:`True`, also learns a
                node feature mask. (default: :obj:`False`)
            **kwargs (optional): Additional arguments passed to the GNN module.

        :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`)
        """
        self.model.eval()
        # self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)
        # Only operate on a k-hop subgraph around `node_idx`.
        # Calculate mask

        ex_label = torch.tensor([1]).to(data_args.device)  # only explain the class-1 prediction
        self.__clear_masks__()
        self.__set_masks__(x, edge_index)
        edge_mask = self.gnn_explainer_alg(x, edge_index, ex_label)
        # edge_masks.append(self.gnn_explainer_alg(x, edge_index, ex_label))


        # with torch.no_grad():
        #     related_preds = self.eval_related_pred(x, edge_index, edge_masks, **kwargs)
        self.__clear_masks__()
        sorted_results = edge_mask.sort(descending=True)
        return edge_mask.detach(), sorted_results.indices.cpu(), edge_index.cpu()


    def __repr__(self):
        return f'{self.__class__.__name__}()'
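
Before walking through the individual functions, here is a minimal usage sketch (my own, not part of the pasted code). It assumes model is a trained PyG model and data is a torch_geometric.data.Data object:

# Hypothetical usage: `model` and `data` are assumed to already exist.
explainer = GNNExplainer(model, epochs=50, lr=0.001, explain_graph=True)
edge_mask, sorted_indices, edge_index = explainer(data.x, data.edge_index)

# The subgraph connected by the top-10 edges serves as the explanation.
top10_edges = edge_index[:, sorted_indices[:10]]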

GNNExplainer.forward

The entry function is GNNExplainer.forward; here I only explain samples classified as class 1.

def forward(self, x, edge_index, mask_features=False,
                positive=True, **kwargs):
        r"""Learns and returns an edge mask that plays a crucial role in
        explaining the prediction made by the GNN.

        Args:
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            mask_features (bool, optional): If :obj:`True`, also learns a
                node feature mask. (default: :obj:`False`)
            **kwargs (optional): Additional arguments passed to the GNN module.

        :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`)
        """
        self.model.eval()
        # self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)
        # Only operate on a k-hop subgraph around `node_idx`.
        # Calculate mask

        ex_label = torch.tensor([1]).to(data_args.device)  # only explain the class-1 prediction
        self.__clear_masks__()
        self.__set_masks__(x, edge_index)
        edge_mask = self.gnn_explainer_alg(x, edge_index, ex_label)
        # edge_masks.append(self.gnn_explainer_alg(x, edge_index, ex_label))


        # with torch.no_grad():
        #     related_preds = self.eval_related_pred(x, edge_index, edge_masks, **kwargs)
        self.__clear_masks__()
        sorted_results = edge_mask.sort(descending=True)
        return edge_mask.detach(), sorted_results.indices.cpu(), edge_index.cpu()
  • The function first calls __clear_masks__() to remove any existing edge mask (even if none has been set yet), then calls __set_masks__ to set up an initial, randomly generated edge mask.

  • The function returns the edge mask together with the edge indices sorted by mask value in descending order. The edge mask itself is computed by gnn_explainer_alg.

ExplainerBase.__set_masks__

Just skim this code; it mainly sets up the initially randomly generated edge mask and NF mask. To keep the code simple I trimmed the NF mask handling; refer to the DIG code for the full version.

def __set_masks__(self, x, edge_index, init="normal"):
        (N, F), E = x.size(), edge_index.size(1)

        std = 0.1
        self.node_feat_mask = torch.nn.Parameter(torch.randn(F, requires_grad=True, device=self.device) * std)

        std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
        self.edge_mask = torch.nn.Parameter(torch.randn(E, requires_grad=True, device=self.device) * std)
        # self.edge_mask = torch.nn.Parameter(100 * torch.ones(E, requires_grad=True))

        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = True
                module.__edge_mask__ = self.edge_mask

GNNExplainer.gnn_explainer_alg

It mainly computes an optimal edge mask; the NF mask is omitted for now.

def gnn_explainer_alg(self, x: Tensor, edge_index: Tensor, ex_label: Tensor, mask_features: bool = False, **kwargs) -> Tensor:
        # initialize a mask
        patience = 10
        self.to(x.device)
        self.mask_features = mask_features

        # train to get the mask
        optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
                                     lr=self.lr)

        best_loss = 4.0
        count = 0
        for epoch in range(1, self.epochs + 1):
            if mask_features:
                h = x * self.node_feat_mask.view(1, -1).sigmoid()
            else:
                h = x
            raw_preds = self.model(data=Batch.from_data_list([Data(x=h, edge_index=edge_index)]))
            loss = self.__loss__(raw_preds, ex_label)
            # if epoch % 10 == 0:
            #     print(f'#D#Loss:{loss.item()}')

            is_best = (loss.item() < best_loss)

            if not is_best:
                count += 1
            else:
                count = 0
                best_loss = loss.item()  # store a float so the graph is not retained

            if count >= patience:
                break

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return self.edge_mask.data

Here, the edge mask (and the NF mask) is treated as a trainable parameter and optimized by gradient descent. The loss function is as follows:

    def __loss__(self, raw_preds, x_label):
        loss = cross_entropy(raw_preds, x_label)
        m = self.edge_mask.sigmoid()
        loss = loss + self.coeffs['edge_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['edge_ent'] * ent.mean()

        if self.mask_features:
            m = self.node_feat_mask.sigmoid()
            loss = loss + self.coeffs['node_feat_size'] * m.sum()
            ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
            loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss

The raw_preds in this loss is the model output computed with the mask in place, since __clear_masks__ is never called to remove the mask during training.

The loss consists of three parts:

  • The cross-entropy between the model output with the edge mask applied and the label (I wonder whether using the model's output before masking in place of the label would be better). This term depends heavily on the accuracy of the model itself, and the accuracy of many GNN-based code vulnerability classifiers is worrying.
  • The size of the edge mask itself (the sum of its sigmoided values).
  • The discreteness of the edge mask (its element-wise entropy; the closer the values are to 0 or 1, the lower the penalty). A numeric check follows this list.
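
As a quick numeric check of the discreteness term (my own sketch, reusing the entropy formula from __loss__): a uniform mask at 0.5 has maximal entropy, while a near-binary mask is barely penalized.

import torch

EPS = 1e-15

def mask_entropy(m):
    # element-wise binary entropy, exactly as in __loss__
    ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
    return ent.mean().item()

print(mask_entropy(torch.tensor([0.5, 0.5, 0.5, 0.5])))      # ~0.693 (maximum)
print(mask_entropy(torch.tensor([0.99, 0.01, 0.99, 0.01])))  # ~0.056 (near 0/1)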

The whole loss can be expressed as follows (without the NF mask):

loss = CrossEntropy(f(data, model, edge_mask), label) + Size(edge_mask) + Discrete(edge_mask)

Finally, you can get an optimal (minimum-loss) edge mask. There is one question that the user-level code cannot answer: how does the edge mask actually enter f(data, model, edge_mask)? That is, how does adding an edge mask affect the model's computation? To answer this, we must explore how torch_geometric supports GNNExplainer.

GNNExplainer support in torch_geometric.nn

The core class here is MessagePassing. Its propagate function, in the version without GNNExplainer support, looks like this:

coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)
            
msg_kwargs = self.inspector.distribute('message', coll_dict)    
out = self.message(**msg_kwargs)
 
aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)
 
update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)

In the example, the message function directly returns x_j. Similarly, update directly returns its input tensor. aggregate aggregates the information from the surrounding nodes.

  • out = self.message(**msg_kwargs): the returned out has shape [edge_num, node_feature_dim]; it holds the vector representation of each edge's source node (x_j, under the default source_to_target flow).
  • out = self.aggregate(out, **aggr_kwargs): the returned out has shape [node_num, node_feature_dim]; it is the vector representation of each node after aggregation. A minimal illustrative layer follows this list.
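
To make these shapes concrete, here is a minimal illustrative layer (my own sketch, not code from torch_geometric) whose message returns x_j unchanged and whose aggregation is a plain sum:

import torch
from torch_geometric.nn import MessagePassing

class PlainConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add', flow='source_to_target')

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j: [edge_num, node_feature_dim], one row per edge
        return x_j

x = torch.randn(4, 16)                             # 4 nodes, 16 features each
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # 3 edges
out = PlainConv()(x, edge_index)                   # [node_num, 16] after aggregation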

The propagate function with GNNExplainer support added looks like this:

coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)

msg_kwargs = self.inspector.distribute('message', coll_dict)
out = self.message(**msg_kwargs)

# For `GNNExplainer`, we require a separate message and aggregate
# procedure since this allows us to inject the `edge_mask` into the
# message passing computation scheme.
if self.__explain__:
    edge_mask = self.__edge_mask__.sigmoid()
    # Some ops add self-loops to `edge_index`. We need to do the
    # same for `edge_mask` (but do not train those).
    if out.size(self.node_dim) != edge_mask.size(0):
        loop = edge_mask.new_ones(size[0])
        edge_mask = torch.cat([edge_mask, loop], dim=0)
    assert out.size(self.node_dim) == edge_mask.size(0)
    out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))

aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)

update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)

You can see that a piece of code is inserted between message and aggregate. Ignoring the guard condition and the shape handling, its core is out = out * edge_mask. That is, for a node x_i, when aggregating information from its neighbors, the neighbors' message vectors are first multiplied by the corresponding edge mask values. (My understanding here may be wrong; corrections are welcome.) A toy check follows.
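
A toy check of that masking step (my own sketch with made-up values):

import torch

out = torch.ones(3, 4)                     # messages: 3 edges, 4-dim features
edge_mask = torch.tensor([1.0, 0.5, 0.0])  # already sigmoided, for illustration
out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))  # broadcast over features
# edge 0 passes fully, edge 1 is halved, edge 2 contributes nothing to aggregation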

Topics: Pytorch Deep Learning