forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_classify.py
100 lines (89 loc) · 2.81 KB
/
test_classify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Infering Relational Data with Graph Convolutional Networks
"""
import argparse
from functools import partial
import torch as th
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from entity_classify import EntityClassify
def main(args):
# load graph data
if args.dataset == "aifb":
dataset = AIFBDataset()
elif args.dataset == "mutag":
dataset = MUTAGDataset()
elif args.dataset == "bgs":
dataset = BGSDataset()
elif args.dataset == "am":
dataset = AMDataset()
else:
raise ValueError()
g = dataset[0]
category = dataset.predict_category
num_classes = dataset.num_classes
test_mask = g.nodes[category].data.pop("test_mask")
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
labels = g.nodes[category].data.pop("labels")
# check cuda
use_cuda = args.gpu >= 0 and th.cuda.is_available()
if use_cuda:
th.cuda.set_device(args.gpu)
labels = labels.cuda()
test_idx = test_idx.cuda()
g = g.to("cuda:%d" % args.gpu)
# create model
model = EntityClassify(
g,
args.n_hidden,
num_classes,
num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2,
use_self_loop=args.use_self_loop,
)
model.load_state_dict(th.load(args.model_path))
if use_cuda:
model.cuda()
print("start testing...")
model.eval()
logits = model.forward()[category]
test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
test_acc = th.sum(
logits[test_idx].argmax(dim=1) == labels[test_idx]
).item() / len(test_idx)
print(
"Test Acc: {:.4f} | Test loss: {:.4f}".format(
test_acc, test_loss.item()
)
)
print()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="RGCN")
parser.add_argument(
"--n-hidden", type=int, default=16, help="number of hidden units"
)
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
parser.add_argument(
"--n-bases",
type=int,
default=-1,
help="number of filter weight matrices, default: -1 [use all]",
)
parser.add_argument(
"--n-layers", type=int, default=2, help="number of propagation rounds"
)
parser.add_argument(
"-d", "--dataset", type=str, required=True, help="dataset to use"
)
parser.add_argument(
"--model_path", type=str, help="path of the model to load from"
)
parser.add_argument(
"--use-self-loop",
default=False,
action="store_true",
help="include self feature as a special relation",
)
args = parser.parse_args()
print(args)
main(args)