forked from thunlp/OpenKE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_subgraphs.py
64 lines (56 loc) · 3.63 KB
/
test_subgraphs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import pickle
import argparse
from subgraphs import Subgraph
from subgraphs import SUBTYPE
from subgraph_predictor import SubgraphPredictor
def parse_args(argv=None):
    """Parse the command line of the subgraph test script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the real command line; passing a list lets
            tests and other Python code drive the parser.

    Returns:
        ``argparse.Namespace`` holding one attribute per option below, named
        by that option's ``dest``.

    Note:
        argparse applies ``%``-formatting to every help string (so that
        ``%(default)s`` works), therefore a literal percent sign must be
        written as ``%%``. A bare ``%`` makes ``--help`` crash, and newer
        Python versions reject the parser outright.
    """
    parser = argparse.ArgumentParser(description='Read training/test file and run LSTM training or test.')
    parser.add_argument('--testfile', dest='test_file', type=str, help='File containing test queries.')
    parser.add_argument('--trainfile', dest='train_file', type=str, help='File containing training triples.')
    parser.add_argument('--modelfile', dest='model_file', type=str, help='Model file.')
    parser.add_argument('--weightsfile', dest='weights_file', type=str, help='Model weights file.')
    parser.add_argument('--subfile', dest='sub_file', type=str, help='File containing subgraphs metadata.')
    parser.add_argument('--subembdir', dest='subemb_dir', type=str, help='Dir containing subgraphs embeddings.')
    parser.add_argument('--embfile', dest='emb_file', type=str, help='File containing entity embeddings.')
    parser.add_argument('--entdict', dest='ent_dict', type=str,
                        default='/var/scratch2/uji300/OpenKE-results/fb15k237/misc/fb15k237-id-to-entity.pkl',
                        help='entity id dictionary.')
    parser.add_argument('--reldict', dest='rel_dict', type=str,
                        default='/var/scratch2/uji300/OpenKE-results/fb15k237/misc/fb15k237-id-to-relation.pkl',
                        help='relation id dictionary.')
    parser.add_argument('-rd', '--result-dir', dest='result_dir', type=str,
                        default="/var/scratch2/uji300/OpenKE-results/", help='Output dir.')
    # Required options: a ``default`` would never be used, so none is given.
    parser.add_argument('--topk', dest='topk', required=True, type=int)
    parser.add_argument('--testonly', dest='num_test_queries', required=False, type=int, default=-1,
                        help='Limit on the number of test queries (-1 means all).')
    parser.add_argument('--db', required=True, dest='db', type=str)
    parser.add_argument('--model', dest='model', type=str, default="transe", help='Embedding model name.')
    # "%%" renders as a literal "%" once argparse formats the help text.
    parser.add_argument('-stp', '--subgraph-threshold-percentage', dest='sub_threshold', default=0.1, type=float,
                        help='%% of top subgraphs to check the correctness of answers.')
    parser.add_argument('-th', '--threshold', dest='threshold', type=float, default=0.5,
                        help='Probability value that decides the boundary between class 0 and 1.')
    parser.add_argument('--score', dest='score_func', type=str, default="avg",
                        help='Score function to evaluate subgraphs on')
    return parser.parse_args(argv)
# Script body: parse the command line, prepare the output/log directories,
# build the subgraph predictor and run prediction on the test queries.
args = parse_args()
# os.path.join (rather than "+") so --result-dir works with or without a
# trailing separator; the trailing "" keeps the separator the old code added.
result_dir = os.path.join(args.result_dir, args.db, "out", "")
log_dir = os.path.join(args.result_dir, args.db, "logs", "")
os.makedirs(result_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
queries_file_path = args.test_file
emb_file = args.emb_file
sub_file = args.sub_file
subemb_dir = args.subemb_dir
# The log file is named after --subfile below, so it cannot be omitted; fail
# here with a clear message instead of a TypeError after the predictor setup.
if sub_file is None:
    raise SystemExit("error: --subfile is required")
# Kept with a trailing separator, exactly as before, since it is handed to
# SubgraphPredictor whose path handling is not visible here.
db_path = os.path.join(".", "benchmarks", args.db, "")
print("Initializing Subgraph predictor", flush=True)
mys = SubgraphPredictor(args.db, args.topk, emb_file, sub_file, subemb_dir, args.model,
                        args.train_file, db_path, args.sub_threshold, args.score_func)
mys.set_test_triples(queries_file_path, args.num_test_queries)
# id -> name dictionaries for entities and relations (--entdict / --reldict)
mys.init_entity_dict(args.ent_dict, args.rel_dict)
# Log file: <log_dir>/<subfile name without its extension>.log
base_name = os.path.splitext(os.path.basename(sub_file))[0]
logfile = os.path.join(log_dir, base_name + ".log")
mys.set_logfile(logfile)
mys.predict()
# NOTE(review): everything between the triple quotes below is a bare string
# literal, i.e. disabled code. It is never executed: mys.results() is not
# called and nothing is pickled to the ".out" file, so result_dir is created
# by this script but not written to. If re-enabled, `result_dir + base_name`
# relies on result_dir ending with a path separator.
'''
raw_result, fil_result = mys.results()
# Pickle the output
output_file = result_dir + base_name + ".out"
result_dict = {}
result_dict['raw'] = raw_result
result_dict['fil'] = fil_result
with open(output_file, 'wb') as fout:
pickle.dump(result_dict, fout, protocol = pickle.HIGHEST_PROTOCOL)
'''