-
Notifications
You must be signed in to change notification settings - Fork 0
/
annotate.py
121 lines (108 loc) · 2.96 KB
/
annotate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os, pdb, sys
import json
import random
from collections import defaultdict
from nltk.tokenize import sent_tokenize
# Conversation ids that annotate_data() never presents for annotation; they are
# compared against the keys of the loaded dialogue data, hence the ".json" suffix.
dev_list = [
    f"{stem}.json"
    for stem in (
        "MUL0129", "MUL0178", "MUL0476", "MUL0602", "MUL0603",
        "MUL1125", "MUL1160", "MUL1227", "MUL1381", "MUL2020",
        "MUL2251", "MUL2344", "MUL2418", "MUL2690",
        "PMUL0134", "PMUL0187", "PMUL0287", "PMUL0626", "PMUL0689",
        "PMUL1159", "PMUL1181", "PMUL1557", "PMUL1579", "PMUL1599",
        "PMUL1635", "PMUL1879", "PMUL2389", "PMUL2748", "PMUL2804",
        "PMUL3363", "PMUL3466", "PMUL3470", "PMUL3554", "PMUL4029",
        "PMUL4053", "PMUL4126", "PMUL4711",
        "SNG0019", "SNG01172", "SNG01297", "SNG02214", "SNG0271",
        "SNG0314", "SNG0494", "SNG0907", "SNG0910", "SNG1046",
        "SNG1069",
    )
]
def save_results(results, convo_id, version=1):
    """Write every annotation collected so far to the versioned results file.

    The whole ``results`` mapping is serialized to
    ``results/annotations/saliency_v{version}.json`` (relative to the current
    working directory), replacing the previous contents. A short progress
    line is then printed.

    Args:
        results: mapping of conversation id -> list of annotation examples.
        convo_id: id of the conversation just annotated; only used for the
            progress message.
        version: annotation file version.
    """
    save_path = os.path.join('results', 'annotations', f"saliency_v{version}.json")
    # The output directory may not exist yet (e.g. on a fresh checkout).
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # The file is cumulative, so write to a temporary file and atomically swap
    # it in. A crash or serialization error mid-write then leaves the previous
    # annotations intact instead of a truncated file.
    tmp_path = save_path + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_path, save_path)
    size = len(results.get(convo_id, []))
    print(f"Saved {size} more annotations for a total of {len(results)} conversations")
def create_prior(version=1):
    """Load previously saved annotations for ``version``, or start empty.

    Reads ``results/annotations/saliency_v{version}.json`` relative to the
    current working directory.

    Args:
        version: annotation file version.

    Returns:
        The parsed JSON from that file, or ``{}`` when the file does not exist.
    """
    prior_path = os.path.join('results', 'annotations', f"saliency_v{version}.json")
    if not os.path.exists(prior_path):
        return {}
    with open(prior_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def load_data():
    """Load the dialogue file ``assets/mwoz/dev.json`` (relative to the CWD).

    Returns:
        The parsed JSON. annotate_data treats it as a mapping of
        conversation id -> conversation.
    """
    data_path = os.path.join('assets', 'mwoz', 'dev.json')
    # JSON is UTF-8 by specification; do not depend on the platform's locale.
    with open(data_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def _parse_label(answer):
    """Map an annotator's (already stripped) answer to a saliency label.

    Returns True for a "salient" answer, False for a "not salient" answer,
    and None for anything else (the sentence is then left unlabeled).
    """
    if answer in ('s', 'salient', 'S', 'd', 'True', 'true', 'y'):
        return True
    if answer in ('n', 'not', 'N', 'h', 'False', 'false'):
        return False
    return None


def annotate_data(data, results, version):
    """Interactively label the sentences of unannotated conversations.

    Conversations are visited in random order. Any id already in ``results``
    or listed in ``dev_list`` is skipped. The speaker of each turn is
    printed, then the annotator is prompted once per sentence. Answers that
    ``_parse_label`` does not recognize leave that sentence unlabeled.

    After each conversation, the full ``results`` mapping is saved with
    ``save_results`` and the annotator is asked whether to continue. Any
    answer outside the continue set exits the process.

    Args:
        data: mapping of conversation id -> conversation dict. Each
            conversation needs a 'log' list of turns, each with a 'text'
            field.
        results: mapping of conversation id -> list of labeled examples;
            mutated in place.
        version: annotation file version forwarded to save_results.
    """
    speakers = ['customer', 'agent']
    convo_ids = list(data.keys())
    random.shuffle(convo_ids)
    for convo_id in convo_ids:
        if convo_id in results or convo_id in dev_list:
            continue
        conversation = data[convo_id]
        results[convo_id] = []
        prev_sent = ""
        # Turns are assumed to alternate, starting with the customer.
        for turn_index, turn in enumerate(conversation['log']):
            speaker = speakers[turn_index % 2]
            print(speaker)
            for sentence in sent_tokenize(turn['text']):
                label = _parse_label(input(sentence + " --> ").strip())
                if label is not None:
                    results[convo_id].append({
                        'speaker': speaker,
                        'previous': prev_sent,
                        'current': sentence,
                        'label': label,
                    })
                # Advance the context even when this sentence was skipped, so
                # 'previous' is always the sentence that actually preceded it.
                prev_sent = sentence
        save_results(results, convo_id, version)
        response = input("End of conversation. Continue?")
        # Tolerate capitalization and stray whitespace in the answer.
        if response.strip().lower() not in ['y', 'yes', 'ok', 'c', 'continue']:
            sys.exit()
if __name__ == "__main__":
    # Annotation version 2: prior labels are loaded from, and new labels are
    # saved to, the saliency_v2.json results file.
    version = 2
    conversations = load_data()
    annotate_data(conversations, create_prior(version), version)