-
Notifications
You must be signed in to change notification settings - Fork 12
/
subsample_superglue.py
40 lines (34 loc) · 1.36 KB
/
subsample_superglue.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import json
import os
import random
# SuperGLUE task names; each one is a sub-directory of both the
# superglue/ and fewglue/ data trees used below.
tns = 'BoolQ CB COPA MultiRC RTE ReCoRD WSC WiC'.split()
# Root of the data tree, derived from the BASE environment variable.
# NOTE(review): os.getenv returns None when BASE is unset, which silently
# yields the literal path 'None/data' — export BASE before running.
data_dir = '{}/data'.format(os.getenv("BASE"))
# Number of training examples in every few-shot subset.
num_train = 32
# Create other random subsets of FewGLUE.
# For every task and each seed in {1, 2, 3}, draw `num_train` examples from
# the full SuperGLUE training split and write them, one JSON record per line,
# to fewglue/<task>/train.seed-<seed>.jsonl.
# Raises ValueError if a task cannot supply `num_train` eligible examples.
for tn in tns:
    with open(f'{data_dir}/superglue/{tn}/train.jsonl', encoding='utf-8') as f:
        # Skip blank lines, and make sure every record ends with exactly one
        # newline: a final line without '\n' would otherwise be glued onto
        # the next record when the subset is written back out.
        data = [line.rstrip('\n') + '\n' for line in f if line.strip()]
    for seed in [1, 2, 3]:
        idxs = list(range(len(data)))
        # A dedicated Random(seed) keeps each subset reproducible and
        # independent of the global random state.
        random.Random(seed).shuffle(idxs)
        if tn == 'WSC':
            # For WSC keep only examples labelled true, in shuffled order.
            # Parse the record instead of matching its raw text so the
            # filter does not depend on key order, spacing or a trailing '\n'.
            data_subset = []
            for idx in idxs:
                if json.loads(data[idx]).get('label') is True:
                    data_subset.append(data[idx])
                    if len(data_subset) >= num_train:
                        break
        else:
            data_subset = [data[idx] for idx in idxs[:num_train]]
        # Downstream code expects exactly num_train examples per subset;
        # fail loudly instead of silently writing a short file.
        if len(data_subset) < num_train:
            raise ValueError(
                f'{tn} (seed {seed}): only {len(data_subset)} eligible examples, '
                f'expected num_train ({num_train})'
            )
        out_dir = f'{data_dir}/fewglue/{tn}'
        os.makedirs(out_dir, exist_ok=True)
        with open(f'{out_dir}/train.seed-{seed}.jsonl', 'w', encoding='utf-8') as f:
            # writelines adds no separators; each record already ends in '\n'.
            f.writelines(data_subset)
# Shuffle FewGLUE Examples.
# Write a seed-0 permutation of each task's original FewGLUE training set
# (which must contain exactly `num_train` records) to
# fewglue/<task>/train.seed-0.jsonl.
# Raises ValueError if a task's train.jsonl has a different record count.
seed = 0
# One permutation of range(num_train), applied identically to every task.
idxs = list(range(num_train))
random.Random(seed).shuffle(idxs)
for tn in tns:
    with open(f'{data_dir}/fewglue/{tn}/train.jsonl', encoding='utf-8') as f:
        # Skip blank lines and normalise line endings so a missing final
        # newline cannot merge two records after reordering.
        data = [line.rstrip('\n') + '\n' for line in f if line.strip()]
    # Explicit raise (not assert) so the check survives `python -O`.
    if len(data) != num_train:
        raise ValueError(f'Expected len(data) ({len(data)}) == num_train ({num_train})')
    data_subset = [data[idx] for idx in idxs]
    with open(f'{data_dir}/fewglue/{tn}/train.seed-{seed}.jsonl', 'w', encoding='utf-8') as f:
        # writelines adds no separators; each record already ends in '\n'.
        f.writelines(data_subset)