-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
35 lines (26 loc) · 853 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
import pandas as pd
from datasets import load_dataset
def group_data(data):
    """Group dataset rows by their ``subject`` field.

    Args:
        data: Iterable of row mappings. Each row must provide the keys
            ``question``, ``subject``, ``choices`` and ``answer``.

    Returns:
        A dict mapping each subject to a list of
        ``[question, subject, choices, answer]`` lists. Rows keep their
        input order within a subject, and subjects appear in the order
        they are first seen. Empty input yields an empty dict.

    Raises:
        KeyError: If a dict row is missing one of the four expected keys.
    """
    groups = {}
    for row in data:
        # Create each bucket lazily on first sight of its subject: a single
        # pass over the rows, no column-style indexing of ``data`` required,
        # and a deterministic (first-seen) key order.
        groups.setdefault(row['subject'], []).append(
            [
                row['question'],
                row['subject'],
                row['choices'],
                row['answer'],
            ]
        )
    return groups
def save_data(data, mode='test'):
    """Write each group of rows to ``data/<mode>/<key>.csv``.

    Args:
        data: Mapping from a group key (used as the CSV file name) to a list
            of ``[question, subject, choices, answer]`` rows.
        mode: Name of the sub-directory under ``data/`` that receives the
            files. Defaults to ``'test'``.

    The output directory is created if it does not exist yet; existing files
    with the same name are overwritten by ``DataFrame.to_csv``.
    """
    out_dir = os.path.join('data', mode)
    # to_csv does not create missing parent directories, so make sure the
    # target directory exists before writing the first file.
    os.makedirs(out_dir, exist_ok=True)
    columns = ['question', 'subject', 'choices', 'answer']
    for key, rows in data.items():
        frame = pd.DataFrame(rows, columns=columns)
        frame.to_csv(os.path.join(out_dir, f'{key}.csv'), index=False)
def make_data(name='cais/mmlu'):
    """Load the ``test`` and ``dev`` splits of a dataset and save them as CSVs.

    Args:
        name: Dataset identifier passed to ``load_dataset`` (with the
            ``'all'`` configuration). Defaults to ``'cais/mmlu'``.

    Both splits are grouped with ``group_data`` first, then each grouped
    split is written with ``save_data`` under its own mode name.
    """
    test_split, dev_split = load_dataset(name, 'all', split=['test', 'dev'])
    # Group both splits before any file is written; the dict keeps
    # insertion order, so 'test' is saved before 'dev'.
    grouped_by_mode = {
        'test': group_data(test_split),
        'dev': group_data(dev_split),
    }
    for mode, grouped in grouped_by_mode.items():
        save_data(grouped, mode)
# Script entry point: build the CSV files only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    make_data()