-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset_utilities.py
117 lines (111 loc) · 3.43 KB
/
dataset_utilities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Created by xunannancy at 2024/03/04
from tqdm import tqdm
import collections
import os
import pickle
import json
import gzip
def _intent_dataset_folders(name):
    # Folder entries for an intent dataset: the dataset root itself plus one
    # subfolder per train split ('train', 'train_5', 'train_10'), in that order.
    root = './datasets/' + name
    folders = {name: root}
    for split in ('train', 'train_5', 'train_10'):
        folders[f'{name}_{split}'] = f'{root}/{split}'
    return folders


# Dataset name -> folder on disk (paths are relative to the working directory).
dataset_folder_dict = {
    'LF-WikiSeeAlso-320K': './datasets/LF-WikiSeeAlso-320K',
    'LF-Amazon-131K': './datasets/LF-Amazon-131K',
    'UFET': './datasets/UFET',
    'UFET_crowd': './datasets/UFET/crowd',
    'UFET_crowd_dev': './datasets/UFET/crowd_dev',
    **_intent_dataset_folders('BANKING77'),
    **_intent_dataset_folders('CLINC150'),
    **_intent_dataset_folders('HWU64'),
}
# Dataset name -> {split name: number of instances in that split}.
# For the *_train_5 / *_train_10 intent datasets, the train size equals the
# dataset's label count times 5 or 10 (presumably that many examples per
# label — confirm against the data preparation scripts).
instance_counts_dict = {
    'LF-WikiSeeAlso-320K': dict(train=693082, test=177515),
    'LF-Amazon-131K': dict(train=294805, test=134835),
    'UFET_crowd': dict(train=1998, test=1998, dev=1998),
    'UFET_crowd_dev': dict(train=3996, dev=1998, test=1998),
    'BANKING77_train_5': dict(train=385, test=3080, dev=1540),
    'CLINC150_train_5': dict(train=750, test=4500, dev=3000),
    'HWU64_train_5': dict(train=320, test=1076, dev=1076),
    'HWU64_train_10': dict(train=640, test=1076, dev=1076),
}
# Dataset name -> number of labels. 'wikiseealso' and 'amazon131k' carry the
# same counts as the LF-* entries next to them (presumably short aliases).
label_counts_dict = {
    'AmazonCat-13K': 13330,
    'LF-WikiSeeAlso-320K': 312330,
    'wikiseealso': 312330,
    'LF-Amazon-131K': 131073,
    'amazon131k': 131073,
    # All UFET variants share one label set size.
    **dict.fromkeys(('UFET', 'UFET_crowd', 'UFET_crowd_dev'), 10331),
    # Each intent dataset and its few-example train variants share a count.
    **{
        name + suffix: count
        for name, count in (('BANKING77', 77), ('CLINC150', 150), ('HWU64', 64))
        for suffix in ('', '_train_5', '_train_10')
    },
}
# Intent-dataset base name -> question-style prefix string (presumably
# prepended to each input text; confirm at the call site).
prefix_dict = {
    'BANKING77': "What's the intent of this customer query: ",
    'CLINC150': "What's the intent of this query: ",
    'HWU64': "What's the intent of this user utterance: ",
}
def label_check(dataset='LF-WikiSeeAlso-320K'):
    """Group a dataset's label titles by title, pickle the mapping, and report duplicates.

    Reads ``lbl.json.gz`` from the dataset folder. The file is read line by
    line, and each line is parsed as JSON with a ``'title'`` field. The
    function builds ``{title: [0-based line indices carrying that title]}``
    and pickles it, as a plain ``dict``, to
    ``<dataset folder>/dpr_processed/labels/label_title_id_mapping.pkl``.
    It creates the directory if needed. Finally it prints the number of
    distinct titles and the number of titles that occur more than once.

    Args:
        dataset: key into ``dataset_folder_dict``.

    Returns:
        None.

    Raises:
        KeyError: if ``dataset`` has no entry in ``dataset_folder_dict``.
    """
    dataset_folder = dataset_folder_dict[dataset]
    label_path = os.path.join(dataset_folder, 'lbl.json.gz')

    title_to_ids = collections.defaultdict(list)
    # `total` only sizes the progress bar; datasets without a label count
    # (e.g. 'BANKING77_train') get an open-ended bar instead of a KeyError.
    total = label_counts_dict.get(dataset)
    # `with` guarantees the gzip handle is closed even if parsing fails.
    with gzip.open(label_path) as label_file:
        for idx, line in tqdm(enumerate(label_file), total=total):
            title_to_ids[json.loads(line)['title']].append(idx)
    # Pickle a plain dict so readers of the .pkl get the same type as before.
    label_title_id_mapping = dict(title_to_ids)

    saved_path = os.path.join(dataset_folder, 'dpr_processed', 'labels', 'label_title_id_mapping.pkl')
    os.makedirs(os.path.dirname(saved_path), exist_ok=True)
    with open(saved_path, 'wb') as f:
        pickle.dump(label_title_id_mapping, f)
    print(f'label_title_id_mapping: {len(label_title_id_mapping)}')
    # Original author's note for the default LF-WikiSeeAlso-320K: 312330 labels -> 312312 distinct titles.

    # A title is repeated exactly when it maps to more than one label index.
    repetitive_labels = [title for title, ids in label_title_id_mapping.items() if len(ids) > 1]
    # Original author's note for LF-WikiSeeAlso-320K: 18, each appearing twice.
    print(f'repetitive_labels: {len(repetitive_labels)}')