-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathsplit-data.py
49 lines (38 loc) · 1.36 KB
/
split-data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# -*- coding: utf-8 -*-
"""
Usage:
split-data.py [-p partitions] [-o output] DATASET
Options:
-p partitions number of cross-folds [default: 5]
-o output destination directory [default: data-split]
DATASET name of data set to load
"""
from docopt import docopt
from lkdemo import datasets, log
from pathlib import Path
from seedbank import init_file
import lenskit.crossfold as xf
def main(args):
    """
    Partition a data set's ratings by user and save each test partition.

    ``args`` is a mapping of parsed command-line values (as returned by
    ``docopt``).  Three entries are read from it:

    * ``DATASET`` -- name of the attribute to look up on ``lkdemo.datasets``
    * ``-p`` -- number of partitions, converted with ``int``
    * ``-o`` -- directory the output files are written to

    One ``test-<n>.parquet`` file is written per partition, numbered from 1.
    Only the test sets are saved; training-set export is disabled.

    The function logs through the module-level ``_log``, which must be bound
    before it is called.
    """
    name = args.get('DATASET')
    n_parts = int(args.get('-p'))
    dest = args.get('-o')

    # the data set name is passed along so it becomes part of the RNG seed
    init_file('params.yaml', 'split-data', name)

    _log.info('locating data set %s', name)
    dataset = getattr(datasets, name)

    _log.info('loading ratings')
    ratings = dataset.ratings

    out_dir = Path(dest)
    out_dir.mkdir(parents=True, exist_ok=True)
    _log.info('writing to %s', out_dir)

    # number of rows handed to the SampleN partition method for each user
    rows_per_user = 5
    sampler = xf.SampleN(rows_per_user)
    folds = xf.partition_users(ratings, n_parts, sampler)

    for fold_no, split in enumerate(folds, start=1):
        _log.info('writing test set %d', fold_no)
        test_frame = split.test
        # give the index an explicit name before it is stored in the file
        test_frame.index.name = 'index'
        test_frame.to_parquet(out_dir / f'test-{fold_no}.parquet')
# Script entry point: the statements below run only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    # main() reads `_log` as a module-level global, so it has to be bound
    # before main() is called.
    _log = log.script(__file__)
    # docopt builds the argument mapping from the module docstring (__doc__).
    args = docopt(__doc__)
    main(args)