Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ourownstory committed Sep 12, 2024
1 parent 6649347 commit 5c8a192
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def check_split(df_in, df_len_expected, n_lags, n_forecasts, freq, p=0.1):

def test_cv():
def check_folds(df, n_lags, n_forecasts, valid_fold_num, valid_fold_pct, fold_overlap_pct):
df, _, _, _ = df_utils.check_multiple_series_id(df)
folds = df_utils.crossvalidation_split_df(
df, n_lags, n_forecasts, valid_fold_num, valid_fold_pct, fold_overlap_pct
)
Expand All @@ -338,17 +339,19 @@ def check_folds(df, n_lags, n_forecasts, valid_fold_num, valid_fold_pct, fold_ov
assert all([x == y for (x, y) in zip(train_folds_samples, train_folds_should)])

len_df = 100
df = pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df)})
check_folds(
df=pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df)}),
df=df,
n_lags=0,
n_forecasts=1,
valid_fold_num=3,
valid_fold_pct=0.1,
fold_overlap_pct=0.0,
)
len_df = 1000
df = pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df)})
check_folds(
df=pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df)}),
df=df,
n_lags=50,
n_forecasts=10,
valid_fold_num=10,
Expand Down

0 comments on commit 5c8a192

Please sign in to comment.