From 5c8a192d526a545e85f22d5bf8915738c2179e3f Mon Sep 17 00:00:00 2001 From: ourownstory Date: Thu, 12 Sep 2024 16:58:13 -0700 Subject: [PATCH] update tests --- tests/test_unit.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_unit.py b/tests/test_unit.py index 115c71e42..a5ca1b9cb 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -317,6 +317,7 @@ def check_split(df_in, df_len_expected, n_lags, n_forecasts, freq, p=0.1): def test_cv(): def check_folds(df, n_lags, n_forecasts, valid_fold_num, valid_fold_pct, fold_overlap_pct): + df, _, _, _ = df_utils.check_multiple_series_id(df) folds = df_utils.crossvalidation_split_df( df, n_lags, n_forecasts, valid_fold_num, valid_fold_pct, fold_overlap_pct ) @@ -338,8 +339,9 @@ def check_folds(df, n_lags, n_forecasts, valid_fold_num, valid_fold_pct, fold_ov assert all([x == y for (x, y) in zip(train_folds_samples, train_folds_should)]) len_df = 100 + df = pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df)}) check_folds( - df=pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df)}), + df=df, n_lags=0, n_forecasts=1, valid_fold_num=3, @@ -347,8 +349,9 @@ def check_folds(df, n_lags, n_forecasts, valid_fold_num, valid_fold_pct, fold_ov fold_overlap_pct=0.0, ) len_df = 1000 + df = pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df)}) check_folds( - df=pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df)}), + df=df, n_lags=50, n_forecasts=10, valid_fold_num=10,