-
Notifications
You must be signed in to change notification settings - Fork 154
/
test.py
127 lines (90 loc) · 4.26 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import __init__
import torch
from dataset import OGBNDataset
from utils.data_util import intersection, process_indexes
import numpy as np
from ogb.nodeproppred import Evaluator
from model import DeeperGCN
from args import ArgsInit
def _predict_partitioning(partition, dataset, model, device, split_idx_lists):
    """Run ``model`` on every subgraph of one partitioning and gather per-split predictions.

    Args:
        partition: One element of ``valid_data_list``, unpacked as
            ``(sg_nodes, sg_edges, sg_edges_index, _)``; the first three are indexed
            by cluster number.
        dataset: Supplies ``x`` (node features, indexed by node ids) and
            ``edge_attr`` (indexed by ``sg_edges_index[cluster]``).
        model: Called as ``model(x, node_ids, edges, edge_attr)``.
        device: Device the subgraph tensors are moved to before the forward pass.
        split_idx_lists: One list of global node ids per split (e.g. train/valid/test).

    Returns:
        A list with one numpy array per entry of ``split_idx_lists``: the predictions
        for that split's nodes, reordered with ``process_indexes``. The resulting order
        is presumably that of ``target[split_idx]`` -- confirm in ``utils.data_util``.
    """
    sg_nodes, sg_edges, sg_edges_index, _ = partition
    split_preds = [[] for _ in split_idx_lists]
    split_targets = [[] for _ in split_idx_lists]
    for cluster in range(len(sg_nodes)):
        nodes = sg_nodes[cluster]
        x = dataset.x[nodes].float().to(device)
        node_ids = torch.LongTensor(nodes).to(device)
        edge_attr = dataset.edge_attr[sg_edges_index[cluster]].to(device)
        pred = model(x, node_ids, sg_edges[cluster].to(device), edge_attr).cpu().detach()
        # Map each global node id to its row in this subgraph's prediction tensor.
        local_row = {node: row for row, node in enumerate(nodes)}
        for split, split_idx in enumerate(split_idx_lists):
            members = intersection(nodes, split_idx)
            split_targets[split] += members
            split_preds[split].append(pred[[local_row[node] for node in members]])
    ordered = []
    for preds, targets in zip(split_preds, split_targets):
        # Predictions were collected cluster by cluster; process_indexes restores split order.
        ordered.append(torch.cat(preds, 0).numpy()[process_indexes(targets)])
    return ordered


@torch.no_grad()
def multi_evaluate(valid_data_list, dataset, model, evaluator, device):
    """Evaluate ``model`` on train/valid/test, averaging predictions over several partitionings.

    For every partitioning in ``valid_data_list``, the model is run on each subgraph.
    The predictions for the train, valid and test nodes are collected and reordered.
    The per-partitioning predictions are then averaged element-wise and scored with
    ``evaluator.eval``. Runs under ``torch.no_grad()`` and leaves ``model`` in eval mode.

    Args:
        valid_data_list: Non-empty list of partitionings; each element is the 4-tuple
            consumed by ``_predict_partitioning``. Every partitioning must yield
            same-shape prediction arrays per split so they can be stacked.
        dataset: Provides ``y``, ``train_idx``, ``valid_idx``, ``test_idx``, ``x`` and
            ``edge_attr``.
        model: Model to evaluate; switched to eval mode here.
        evaluator: Object whose ``eval({"y_true": ..., "y_pred": ...})`` scores a split.
        device: Device used for the forward passes.

    Returns:
        dict with keys ``"train"``, ``"valid"``, ``"test"`` mapping to the
        corresponding ``evaluator.eval`` results.

    Raises:
        ValueError: if ``valid_data_list`` is empty (there is nothing to average).
    """
    if not valid_data_list:
        raise ValueError("valid_data_list must contain at least one graph partitioning")
    model.eval()
    target = dataset.y.detach().numpy()
    split_names = ("train", "valid", "test")
    split_idx_lists = [
        dataset.train_idx.tolist(),
        dataset.valid_idx.tolist(),
        dataset.test_idx.tolist(),
    ]
    runs = [
        _predict_partitioning(partition, dataset, model, device, split_idx_lists)
        for partition in valid_data_list
    ]
    eval_result = {}
    for split, (name, split_idx) in enumerate(zip(split_names, split_idx_lists)):
        # Stack into one ndarray first: torch.Tensor(list_of_ndarrays) falls back to a very
        # slow element-wise conversion. .float() keeps the float32 dtype torch.Tensor produced.
        stacked = np.stack([run[split] for run in runs])
        y_pred = torch.from_numpy(stacked).float().mean(dim=0)
        eval_result[name] = evaluator.eval({"y_true": target[split_idx], "y_pred": y_pred})
    return eval_result
def main():
    """Evaluate a saved DeeperGCN checkpoint on several graph partitionings and print the result.

    Reads options from ``ArgsInit``, builds the ``OGBNDataset`` and OGB ``Evaluator``,
    and draws ``args.num_evals`` partitionings with ``random_partition_graph`` /
    ``generate_sub_graphs``. It then loads the weights stored under
    ``'model_state_dict'`` in ``args.model_load_path``, prints the dict returned by
    ``multi_evaluate``, and finally calls ``model.print_params(final=True)``.
    """
    args = ArgsInit().args

    # A GPU is used only when requested AND available; otherwise fall back to CPU.
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(f"cuda:{args.device}")
    else:
        device = torch.device("cpu")

    dataset = OGBNDataset(dataset_name=args.dataset)
    args.num_tasks = dataset.num_tasks
    args.nf_path = dataset.extract_node_features(args.aggr)
    evaluator = Evaluator(args.dataset)

    valid_data_list = []
    for _ in range(args.num_evals):
        parts = dataset.random_partition_graph(dataset.total_no_of_nodes,
                                               cluster_number=args.valid_cluster_number)
        valid_data_list.append(
            dataset.generate_sub_graphs(parts, cluster_number=args.valid_cluster_number))

    model = DeeperGCN(args)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine (or
    # with a different device index); the weights are moved to `device` right after.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(args.model_load_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    result = multi_evaluate(valid_data_list, dataset, model, evaluator, device)
    print(result)
    model.print_params(final=True)


if __name__ == "__main__":
    main()