-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
61 lines (50 loc) · 2.14 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'Stefan Jansen'
import numpy as np
# Seed NumPy's global RNG at import time. This is a module-level side effect:
# it applies to every np.random.* call in the process (including the
# np.random.shuffle used by MultipleTimeSeriesCV.split), making draws repeatable.
np.random.seed(42)
def format_time(t):
    """Return a duration of *t* seconds as a zero-padded 'HH:MM:SS' string.

    *t* is rounded to whole seconds before being split into hours, minutes
    and seconds, so fractional inputs never render an out-of-range field
    such as '00:00:60'. Hours are not wrapped at 24 (360000 -> '100:00:00').

    Args:
        t: Numeric duration in seconds, e.g. the difference of two
            ``time()`` readings.

    Returns:
        The formatted 'HH:MM:SS' string.
    """
    # Round up front: rounding only the seconds field after divmod could
    # carry it to 60 (59.6 -> '60') without incrementing the minutes.
    total = int(round(t))
    m, s = divmod(total, 60)
    h, m = divmod(m, 60)
    return f'{h:02d}:{m:02d}:{s:02d}'
class MultipleTimeSeriesCV:
    """Generate (train_idx, test_idx) pairs of row positions for panel data.

    Assumes the index of ``X`` has a level named ``date_idx`` (e.g. a
    MultiIndex with levels 'symbol' and 'date'). Splits walk backwards from
    the most recent date: split ``i`` tests on the ``i``-th most recent block
    of ``test_period_length`` dates and trains on dates before that block.
    Training and test windows are separated by ``lookahead - 1`` dates so
    that training outcomes reaching into the test period are purged.
    """

    def __init__(self,
                 n_splits=3,
                 train_period_length=126,
                 test_period_length=21,
                 lookahead=None,
                 date_idx='date',
                 shuffle=False):
        """Store the split configuration.

        Args:
            n_splits: Number of (train, test) pairs ``split`` yields.
            train_period_length: Base length of each training window, in dates.
            test_period_length: Number of dates in each test window.
            lookahead: Target horizon in dates; sets the purge gap between
                training and test windows. ``None`` is treated as 1
                (adjacent windows, no gap).
            date_idx: Name of the index level holding the dates.
            shuffle: If True, shuffle each yielded training index array in
                place using NumPy's global RNG.
        """
        self.n_splits = n_splits
        self.lookahead = lookahead
        self.test_length = test_period_length
        self.train_length = train_period_length
        self.shuffle = shuffle
        self.date_idx = date_idx

    def split(self, X, y=None, groups=None):
        """Yield ``(train_idx, test_idx)`` arrays of integer row positions.

        Args:
            X: DataFrame or Series whose index has a ``date_idx`` level.
            y, groups: Unused; kept for scikit-learn-style splitter signatures.

        Yields:
            Tuples of 1-D integer arrays holding positions into ``X`` of the
            training and test rows, most recent split first.

        Raises:
            IndexError: if ``X`` spans too few distinct dates for the
                requested number and length of windows.
        """
        # A None lookahead used to raise TypeError; 1 means "no purge gap".
        lookahead = 1 if self.lookahead is None else self.lookahead
        # Read the date level directly instead of copying X via reset_index.
        dates = X.index.get_level_values(self.date_idx)
        # Distinct dates, most recent first, so offsets count back in time.
        days = sorted(dates.unique(), reverse=True)

        for i in range(self.n_splits):
            # Each window selects dates strictly after days[start] and up to
            # and including days[end] (days is descending, so start > end).
            test_end = i * self.test_length
            test_start = test_end + self.test_length
            # Leave lookahead - 1 dates between the last training date and
            # the first test date.
            train_end = test_start + lookahead - 1
            # NOTE(review): the training window spans
            # train_length + lookahead - 1 dates because lookahead enters both
            # bounds; kept unchanged to preserve existing splits.
            train_start = train_end + self.train_length + lookahead - 1

            train_mask = (dates > days[train_start]) & (dates <= days[train_end])
            test_mask = (dates > days[test_start]) & (dates <= days[test_end])
            train_idx = np.flatnonzero(train_mask)
            test_idx = np.flatnonzero(test_mask)
            if self.shuffle:
                # Shuffle the array that is actually yielded, not a copy.
                np.random.shuffle(train_idx)
            yield train_idx, test_idx

    def get_n_splits(self, X=None, y=None, groups=None):
        """Return the number of splits; the arguments are ignored."""
        return self.n_splits