-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgnn_explainer.py
84 lines (62 loc) · 2.92 KB
/
gnn_explainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
from torch_geometric.nn import GNNExplainer
from tqdm import tqdm
# Small positive floor added inside ``torch.log`` by the entropy
# regularizers, so the logarithm is never evaluated at exactly zero when a
# sigmoid mask value saturates at 0 or 1.
EPS = 1.0e-15
class TargetedGNNExplainer(GNNExplainer):
    r"""GNNExplainer variant that explains a node's output for a chosen class.

    :meth:`explain_node_with_target` fits a node-feature mask and an edge
    mask that preserve the model output ``log_logits[node, target_class]``
    for a caller-supplied ``target_class``. When no class is given, it falls
    back to the class the model predicts for that node.
    """

    def __loss__(self, node_idx, log_logits, target_class):
        r"""Return the mask-optimization loss for one node and one class.

        The loss is ``-log_logits[node_idx, target_class]`` plus a size
        penalty and an entropy penalty for each of two masks: the sigmoid of
        the edge mask and the sigmoid of the node-feature mask. Each penalty
        is weighted by the matching entry of ``self.coeffs``.

        Args:
            node_idx (int or LongTensor): Position of the explained node in
                ``log_logits``.
            log_logits (Tensor): Model output indexed as ``[node, class]``.
                The values are presumably log-probabilities, as the name
                suggests; confirm against the model.
            target_class (int or Tensor): Class whose output should be
                preserved. A per-node label tensor is also accepted (this is
                what the inherited :meth:`explain_node` passes). In that
                case the entry at ``node_idx`` is used.

        :rtype: :class:`Tensor`
        """
        if torch.is_tensor(target_class) and target_class.numel() > 1:
            # A multi-element tensor holds one label per node (the base
            # class's calling convention). Select this node's label.
            # Indexing with the whole tensor would broadcast to one loss
            # term per label, and ``loss.backward()`` needs a single value.
            target_class = target_class[node_idx]
        loss = -log_logits[node_idx, target_class]

        # Edge-mask regularizers: keep the mask small (size term) and push
        # each entry towards 0 or 1 (binary-entropy term). EPS keeps the
        # log arguments strictly positive.
        m = self.edge_mask.sigmoid()
        loss = loss + self.coeffs['edge_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['edge_ent'] * ent.mean()

        # The same two regularizers for the node-feature mask.
        m = self.node_feat_mask.sigmoid()
        loss = loss + self.coeffs['node_feat_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['node_feat_ent'] * ent.mean()
        return loss

    def explain_node_with_target(self, node_idx, x, edge_index,
                                 target_class=None, **kwargs):
        r"""Learn a node-feature mask and an edge mask that explain the
        model output for class :attr:`target_class` at node
        :attr:`node_idx`.

        Args:
            node_idx (int): The node to explain.
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            target_class (int, optional): The class to explain. If
                :obj:`None`, the class with the largest model output for
                :attr:`node_idx` is used. (default: :obj:`None`)
            **kwargs (optional): Additional arguments passed to the GNN
                module.

        Returns:
            A tuple ``(node_feat_mask, edge_mask)``:

            * ``node_feat_mask``: the sigmoid of the learned node-feature
              mask.
            * ``edge_mask``: the sigmoid of the learned edge mask, scattered
              into a tensor with one entry per column of the input
              :attr:`edge_index`. Entries not selected by the subgraph's
              hard edge mask stay zero.

        :rtype: (:class:`Tensor`, :class:`Tensor`)
        """
        self.model.eval()
        self.__clear_masks__()

        num_edges = edge_index.size(1)

        # Only operate on a k-hop subgraph around `node_idx`.
        x, edge_index, mapping, hard_edge_mask, kwargs = self.__subgraph__(
            node_idx, x, edge_index, **kwargs)

        if target_class is None:
            # No class requested: explain the model's own prediction.
            with torch.no_grad():
                log_logits = self.model(x=x, edge_index=edge_index, **kwargs)
                pred_label = log_logits.argmax(dim=-1)
                target_class = pred_label[mapping].item()

        pbar = None
        try:
            self.__set_masks__(x, edge_index)
            self.to(x.device)

            optimizer = torch.optim.Adam(
                [self.node_feat_mask, self.edge_mask], lr=self.lr)

            if self.log:  # pragma: no cover
                pbar = tqdm(total=self.epochs)
                pbar.set_description(f'Explain node {node_idx}')

            for _ in range(self.epochs):
                optimizer.zero_grad()
                h = x * self.node_feat_mask.view(1, -1).sigmoid()
                log_logits = self.model(x=h, edge_index=edge_index, **kwargs)
                loss = self.__loss__(mapping, log_logits, target_class)
                loss.backward()
                optimizer.step()
                if pbar is not None:  # pragma: no cover
                    pbar.update(1)

            node_feat_mask = self.node_feat_mask.detach().sigmoid()
            edge_mask = self.edge_mask.new_zeros(num_edges)
            edge_mask[hard_edge_mask] = self.edge_mask.detach().sigmoid()
        finally:
            # Always undo what ``__set_masks__`` installed and close the
            # progress bar. Otherwise a failed optimization would leave the
            # masks in place for later calls.
            if pbar is not None:  # pragma: no cover
                pbar.close()
            self.__clear_masks__()

        return node_feat_mask, edge_mask