forked from OpenNMT/OpenNMT-py
-
Notifications
You must be signed in to change notification settings - Fork 5
/
coref.py
84 lines (75 loc) · 3.84 KB
/
coref.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import sys
import spacy
# Build contrastive pronoun test sets: for every pronoun in a reference
# sentence that the spaCy coreference extension links to an antecedent, emit
# copies of the sentence with that pronoun replaced by a deliberate mistake.
#
# Usage: python coref.py SRC_FILE REF_FILE [MISTAKE_TYPE]
#   SRC_FILE      source sentences, one per line; copied through unchanged
#   REF_FILE      reference sentences, one per line, parallel to SRC_FILE
#   MISTAKE_TYPE  ANTEC | TYPE | NUM | GEND (default: NUM)
#
# Writes <MISTAKE_TYPE>_contr.src.txt, <MISTAKE_TYPE>_contr.ref.txt and
# <MISTAKE_TYPE>_contr.contr.txt (one line per contrastive example, aligned
# across the three files) and prints the mistake type and example count.

DEFAULT_MISTAKE_TYPE = "NUM"
# ANTEC: pronoun -> text of an antecedent mention it corefers with.
# TYPE:  pronoun -> same paradigm slot in another form (he -> him, himself).
# NUM:   pronoun -> any other pronoun of the same form except a gender swap
#        (person and number may change: he -> they, he -> i).
# GEND:  pronoun -> the opposite-gender pronoun of the same form (he -> she).
# A person-only ("PERS") variant is not implemented.
MISTAKE_TYPES = ("ANTEC", "TYPE", "NUM", "GEND")

males = ['he', 'his', 'him', 'himself']
females = ['she', 'hers', 'her', 'herself']
# Parallel pronoun paradigms: position i is the same person/number/gender in
# every list (1sg, 2sg, 3sg masc, 3sg fem, 3sg neut, 1pl, 2pl, 3pl); None
# marks a gap (no standalone possessive pronoun for "it").
# NOTE(review): matching is case-sensitive against these lowercase lists, so
# capitalised tokens ("He", "I") are never perturbed -- confirm the corpus is
# lowercased.
subj_pro = ['i', 'you', 'he', 'she', 'it', 'we', 'you', 'they']
obj_pro = ['me', 'you', 'him', 'her', 'it', 'us', 'you', 'them']
poss_adj = ['my', 'your', 'his', 'her', 'its', 'our', 'your', 'their']
poss_pro = ['mine', 'yours', 'his', 'hers', None, 'ours', 'yours', 'theirs']
ref_pro = ['myself', 'yourself', 'himself', 'herself', 'itself', 'ourselves', 'yourselves', 'themselves']
pros = [subj_pro, obj_pro, poss_adj, poss_pro, ref_pro]


def _is_gender_swap(word, other):
    """Return True if one pronoun is masculine and the other feminine."""
    return ((word in males and other in females)
            or (word in females and other in males))


def _substitutes(word, mistake_type):
    """Return replacement pronouns for ``word`` under a TYPE/NUM/GEND mistake.

    ``word`` is looked up in every paradigm that contains it ("her" is both
    an object pronoun and a possessive adjective), so the result may contain
    repeats; callers deduplicate.
    """
    out = []
    for paradigm in pros:
        if word not in paradigm:
            continue
        if mistake_type == "TYPE":
            # Same slot (person/number/gender), different syntactic form.
            idx = paradigm.index(word)
            out.extend(other[idx] for other in pros
                       if other[idx] is not None and other[idx] != word)
            continue
        for item in paradigm:
            if item is None or item == word:
                continue
            swap = _is_gender_swap(word, item)
            # GEND wants exactly the gender swaps; NUM wants everything else.
            wanted = swap if mistake_type == "GEND" else not swap
            if wanted:
                out.append(item)
    return out


def make_examples(original, corefs, mistake_type):
    """Return contrastive variants of one tokenized reference sentence.

    Args:
        original: list of token strings of the reference sentence.
        corefs: dict mapping a token index in ``original`` to the list of
            antecedent mentions it corefers with (only ``str()`` of a mention
            is used).
        mistake_type: one of ``MISTAKE_TYPES``.

    Returns:
        Space-joined sentences, each equal to ``original`` with the token at
        one coreferent index replaced, without duplicates, in first-seen order.

    Raises:
        ValueError: if ``mistake_type`` is not in ``MISTAKE_TYPES``.
    """
    if mistake_type not in MISTAKE_TYPES:
        raise ValueError("unknown mistake type %r; expected one of %s"
                         % (mistake_type, ", ".join(MISTAKE_TYPES)))
    examples = []
    for k, mentions in corefs.items():
        if not mentions:
            continue
        word = original[k]
        if mistake_type == "ANTEC":
            replacements = [str(m) for m in mentions if str(m) != word]
        else:
            # Substitutes depend only on the pronoun, not on which mention it
            # corefers with, so compute them once per token (the old per-mention
            # loop emitted every example once per mention).
            replacements = _substitutes(word, mistake_type)
        for rep in replacements:
            examples.append(" ".join(original[:k] + [rep] + original[k + 1:]))
    # The same variant can arise repeatedly (repeated mention text, a pronoun
    # in two paradigms, "you"/"your" listed twice); keep one copy of each.
    return list(dict.fromkeys(examples))


def coref_mentions(doc):
    """Map token index -> main mentions of the coref clusters containing it.

    Mentions that contain the token itself are dropped; tokens left with no
    mention are omitted from the result.
    """
    corefs = {}
    for i, tok in enumerate(doc):
        mentions = []
        try:
            for cluster in tok._.coref_clusters:
                mentions.append(cluster.main)
        except Exception:
            # Best effort, as before: the coref extension can fail here
            # (presumably when the doc has no clusters -- confirm); keep
            # whatever was collected.
            pass
        mentions = [m for m in mentions if tok not in m]
        if mentions:
            corefs[i] = mentions
    return corefs


def main(argv=None):
    """Generate the contrastive files from the paths given in ``argv``.

    ``argv`` defaults to ``sys.argv``: [prog, SRC_FILE, REF_FILE,
    optional MISTAKE_TYPE].
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        sys.exit("usage: coref.py SRC_FILE REF_FILE [%s]"
                 % "|".join(MISTAKE_TYPES))
    mistake_type = argv[3] if len(argv) > 3 else DEFAULT_MISTAKE_TYPE
    if mistake_type not in MISTAKE_TYPES:
        # Fail before the (slow) model load instead of writing empty files.
        sys.exit("unknown mistake type %r; expected one of %s"
                 % (mistake_type, ", ".join(MISTAKE_TYPES)))

    nlp = spacy.load('en_coref_lg')
    prefix = mistake_type + "_contr"
    count = 0
    with open(argv[1], encoding="utf-8") as src_file, \
            open(argv[2], encoding="utf-8") as ref_file, \
            open(prefix + ".src.txt", "w", encoding="utf-8") as fsrc, \
            open(prefix + ".ref.txt", "w", encoding="utf-8") as fref, \
            open(prefix + ".contr.txt", "w", encoding="utf-8") as fcon:
        for src, line in zip(src_file, ref_file):
            line = line.strip()
            doc = nlp(line)
            corefs = coref_mentions(doc)
            if not corefs:
                continue
            original = [str(t) for t in doc]
            for example in make_examples(original, corefs, mistake_type):
                count += 1
                fsrc.write(src.strip() + "\n")
                fref.write(line + "\n")
                fcon.write(example + "\n")
    print(mistake_type, count)


if __name__ == "__main__":
    main()