-
Notifications
You must be signed in to change notification settings - Fork 0
/
kscore.py
116 lines (100 loc) · 4.05 KB
/
kscore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python
import time, csv, gzip
import logging
from concrete import AnnotationMetadata, ServiceInfo
from concrete.search import SearchService
from concrete.search.ttypes import SearchResult, SearchCapability, SearchType, SearchQuery
from concrete.services.ttypes import ServicesException
from concrete.util import AnalyticUUIDGeneratorFactory, SearchServiceWrapper, SearchClientWrapper
class SearchHandler(SearchService.Iface):
    """Search service handler that forwards queries to another search client.

    The handler itself performs no retrieval: ``search`` delegates to the
    ``other`` object supplied at construction time.
    """

    def __init__(self, other, corpus_name, host, port):
        """Store the delegate and the descriptive settings.

        Args:
            other: Object exposing ``search(query)``; every query is
                forwarded to it.
            corpus_name: Name reported by ``getCorpora``.
            host: Stored on the instance; not read by any method here.
            port: Stored on the instance; not read by any method here.
        """
        self.other = other
        self.corpus_name = corpus_name
        self.port = port
        self.host = host

    def alive(self):
        """Report that the service is up (always ``True``)."""
        return True

    def about(self):
        """Return the ``ServiceInfo`` describing this service."""
        return ServiceInfo(name='search kscore', version='0.0')

    def getCapabilities(self):
        """Return the supported capabilities: sentence search only."""
        return [SearchCapability(SearchType.SENTENCES)]

    def getCorpora(self):
        """Return a one-element list holding the configured corpus name."""
        # The list must be returned, not raised: ``raise [..]`` fails with
        # TypeError because a list is not an exception.
        return [self.corpus_name]

    def search(self, query):
        """Forward ``query`` to the wrapped client and return its result."""
        return self.other.search(query)
def kscore(s, match_path="dev-match.tsv", wiki_path="WikiQA-dev.tsv.gz",
           k_vals=(1, 10, 100, 1000)):
    """Compute and print success@k for a search service.

    Each distinct query read from ``wiki_path`` is sent to ``s`` once per
    cutoff in ``k_vals``. Only results whose sentence id appears in the
    label file count toward the cutoff; a query is a success at k when at
    least one of those counted results has a label >= 1.

    Args:
        s: Object exposing ``search(query)`` whose return value has a
            ``searchResultItems`` iterable; each item must provide
            ``sentenceId.uuidString``.
        match_path: Tab-separated label file. Column index 3 is used as the
            sentence id and column index 4 as its integer label.
        wiki_path: Gzip-compressed tab-separated query file with one header
            row; column index 1 is used as the query text.
        k_vals: Result cutoffs to evaluate.

    Returns:
        dict mapping each k to ``(successes, scored_queries)``, where
        ``scored_queries`` counts the queries that had at least one
        labelled result at that cutoff.
    """
    # Sentence id -> label, kept as the raw string read from the file.
    answer_labels = {}
    with open(match_path) as match:
        for row in csv.reader(match, delimiter="\t", quotechar="'"):
            if row:  # csv yields [] for blank lines
                answer_labels[row[3]] = row[4]

    # Drop double quotes and question marks; turn slashes into spaces.
    cleanup = str.maketrans({'"': None, "/": " ", "?": None})
    tallies = {k: [0, 0] for k in k_vals}
    seen = set()

    with gzip.open(wiki_path, 'rt') as wiki:
        reader = csv.reader(wiki, delimiter="\t", quotechar="'")
        next(reader, None)  # skip the header row (tolerates an empty file)
        for row in reader:
            if not row:
                continue
            logging.debug("row: %s", row)
            query = row[1].translate(cleanup)
            # The same question appears on several rows; evaluate it once.
            if query in seen:
                continue
            seen.add(query)
            terms = query.split(" ")
            for k in tallies:
                request = SearchQuery(type=SearchType.SENTENCES, terms=terms,
                                      k=k, rawQuery=query)
                results = s.search(request)
                counted = 0
                correct = 0
                # An unset result list is treated as "no results".
                for item in results.searchResultItems or ():
                    if counted == k:
                        break
                    label = answer_labels.get(item.sentenceId.uuidString)
                    if label is None:
                        # Unlabelled results do not use up a slot in the top k.
                        continue
                    correct += int(label)
                    counted += 1
                if correct >= 1:
                    tallies[k][0] += 1
                if counted:
                    tallies[k][1] += 1

    print("Baseline success @k")
    for k, (hits, scored) in tallies.items():
        # Avoid ZeroDivisionError when no query had a labelled result.
        print("{}: {}".format(k, hits / scored if scored else "n/a"))
    return {k: (hits, scored) for k, (hits, scored) in tallies.items()}
if __name__ == "__main__":
    import argparse

    # Command-line options: where the backing search service listens.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument("-p", "--port", type=int, default=9090)
    cli.add_argument("--host", default="kdft")
    opts = cli.parse_args()

    logging.basicConfig(
        level='DEBUG',
        format='%(asctime)-15s %(levelname)s: %(message)s',
    )

    # Echo the target endpoint, one value per line.
    print(opts.host, opts.port, sep="\n")

    # NOTE(review): presumably waits for the search server to come up
    # before connecting -- confirm.
    time.sleep(10)

    with SearchClientWrapper(opts.host, opts.port) as client:
        # The handler's host/port are passed as empty strings; it only
        # forwards queries to the already-connected client.
        kscore(SearchHandler(client, "wikiQA", "", ""))