-
Notifications
You must be signed in to change notification settings - Fork 17
/
script_explain.py
77 lines (64 loc) · 2.62 KB
/
script_explain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
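""" script_explain.py
Load a trained model, report its test accuracy, and explain its predictions
with GraphSVX (graph classification or node classification, depending on
the dataset).
"""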
import argparse
import random
import numpy as np
import torch
import warnings
warnings.filterwarnings("ignore")
import configs
from utils.io_utils import fix_seed
from src.data import prepare_data
from src.explainers import GraphSVX
from src.train import evaluate, test
def main():
    """Load a saved model, print its test accuracy, and explain it with GraphSVX.

    Every setting is read from ``configs.arg_parse()``. The model is loaded
    from ``models/<model>_model_<dataset>.pth``. The Mutagenicity and syn6
    datasets are explained with ``explain_graphs``; every other dataset is
    explained with ``explain``. The per-explanation sums and the explainer's
    ``base_values`` are printed at the end.
    """
    args = configs.arg_parse()
    fix_seed(args.seed)

    # Build the dataset (split ratio and input feature size come from the CLI).
    data = prepare_data(args.dataset, args.train_ratio,
                        args.input_dim, args.seed)

    # NOTE(review): torch.load unpickles a whole model object, so only load
    # checkpoints from trusted sources. Recent PyTorch releases default to
    # weights_only=True, which may reject such files - confirm the torch
    # version this script targets.
    checkpoint_file = 'models/{}_model_{}.pth'.format(args.model, args.dataset)
    model = torch.load(checkpoint_file)

    # Cora / PubMed go through `evaluate`, which yields a 2-tuple whose second
    # item is the test accuracy; other datasets use `test`, which returns the
    # accuracy directly.
    if args.dataset not in ['Cora', 'PubMed']:
        test_acc = test(data, model, data.test_mask)
    else:
        _, test_acc = evaluate(data, model, data.test_mask)
    print('Test accuracy is {:.4f}'.format(test_acc))

    explainer = GraphSVX(data, model, args.gpu)

    # Graph classification vs. node classification: the two explainer entry
    # points take the same positional arguments except for the eighth one,
    # which is a fixed task tag for graph datasets and args.hv otherwise.
    if args.dataset in ['Mutagenicity', 'syn6']:
        run_explainer = explainer.explain_graphs
        task_or_hv = 'graph_classification'
    else:
        run_explainer = explainer.explain
        task_or_hv = args.hv

    # The trailing positional True is forwarded unchanged; its meaning is
    # defined by the GraphSVX methods (not visible here).
    explanations = run_explainer(args.indexes,
                                 args.hops,
                                 args.num_samples,
                                 args.info,
                                 args.multiclass,
                                 args.fullempty,
                                 args.S,
                                 task_or_hv,
                                 args.feat,
                                 args.coal,
                                 args.g,
                                 args.regu,
                                 True)

    explanation_sums = list(map(np.sum, explanations))
    print('Sum explanations: ', explanation_sums)
    print('Base value: ', explainer.base_values)
if __name__ == "__main__":
    # Run the explanation pipeline only when executed directly, not on import.
    main()