-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
32 lines (22 loc) · 1.07 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import argparse
from bi_encoder.faiss_retriever import search_by_faiss
from msmarco_eval import compute_metrics_from_files
def get_args(argv=None):
    """Parse command-line options for FAISS retrieval and optional evaluation.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the standard
            command-line behaviour; passing a list lets tests or other
            scripts build the namespace programmatically.

    Returns:
        An ``argparse.Namespace`` with the attributes ``query_reps_path``,
        ``passage_reps_path``, ``qrels_file``, ``ranking_file`` (all ``str``
        or ``None``), ``use_gpu`` (``bool``) and ``depth`` (``int``).
    """
    parser = argparse.ArgumentParser(
        description="Run FAISS retrieval over precomputed representations "
                    "and optionally evaluate the ranking against qrels.")
    parser.add_argument("--query_reps_path", type=str, default=None,
                        help="Path to the query representations to search with.")
    parser.add_argument("--passage_reps_path", type=str, default=None,
                        help="Path to the passage representations to search over.")
    parser.add_argument("--qrels_file", type=str, default=None,
                        help="Relevance judgements; when given, the ranking is evaluated.")
    parser.add_argument("--ranking_file", type=str, default=None,
                        help="Path of the ranking file produced by the search.")
    parser.add_argument("--use_gpu", action='store_true', default=False,
                        help="Run the FAISS search on GPU.")
    parser.add_argument("--depth", type=int, default=1000,
                        help="Number of passages to retrieve per query.")
    return parser.parse_args(argv)
def format_metrics(metrics):
    """Render evaluation metrics as ``"name: value"`` strings.

    Args:
        metrics: Either a mapping of metric name to value, or an iterable of
            ``(name, value)`` pairs.

    Returns:
        A list of ``"name: value"`` strings in iteration order.
    """
    # Iterating a dict directly yields only its keys, so unpacking them into
    # (name, value) would fail or split the key string; use .items() instead.
    pairs = metrics.items() if hasattr(metrics, "items") else metrics
    return ['{}: {}'.format(name, value) for name, value in pairs]


if __name__ == "__main__":
    args = get_args()
    # Honour the --depth flag (default 1000) instead of a hard-coded depth.
    # NOTE(review): search_by_faiss presumably writes its ranking to
    # args.ranking_file, which is read back below for evaluation — confirm.
    search_by_faiss(args.query_reps_path, args.passage_reps_path, args.ranking_file,
                    batch_size=512, depth=args.depth, use_gpu=args.use_gpu)
    if args.qrels_file is not None:
        metrics = compute_metrics_from_files(args.qrels_file, args.ranking_file)
        print('#####################')
        for line in format_metrics(metrics):
            print(line)
        print('#####################')