Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ourownstory committed Sep 13, 2024
1 parent 5c8a192 commit 804caae
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def check_split(df_in, df_len_expected, n_lags, n_forecasts, freq, p=0.1):
n_lags=n_lags,
n_forecasts=n_forecasts,
)
df, _, _, id_list = df_utils.check_multiple_series_id(df)
df_in, _, _, _ = df_utils.check_multiple_series_id(df_in)
df_in, _, _ = df_utils.check_dataframe(df_in, check_y=False)
df_in = _handle_missing_data(
df=df_in,
Expand Down Expand Up @@ -365,6 +365,7 @@ def check_folds_dict(
df, n_lags, n_forecasts, valid_fold_num, valid_fold_pct, fold_overlap_pct, global_model_cv_type="local"
):
"Does not work with global_model_cv_type == global-time or global_model_cv_type is None"
df, _, _, _ = df_utils.check_multiple_series_id(df)
folds = df_utils.crossvalidation_split_df(
df,
n_lags,
Expand Down Expand Up @@ -525,8 +526,9 @@ def test_reg_delay():

def test_double_crossvalidation():
len_df = 100
df = pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df), "ID": "__df__"})
folds_val, folds_test = df_utils.double_crossvalidation_split_df(
df=pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df)}),
df=df,
n_lags=0,
n_forecasts=1,
k=3,
Expand Down Expand Up @@ -554,8 +556,10 @@ def test_double_crossvalidation():
learning_rate=LR,
n_lags=2,
)
len_df = 100
df = pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df), "ID": "__df__"})
folds_val, folds_test = m.double_crossvalidation_split_df(
df=pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df)}),
df=df,
k=3,
valid_pct=0.3,
test_pct=0.15,
Expand All @@ -577,7 +581,10 @@ def test_double_crossvalidation():

# Raise not implemented error as double_crossvalidation is not compatible with many time series
with pytest.raises(NotImplementedError):
df = pd.DataFrame({"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df)})
len_df = 100
df = pd.DataFrame(
{"ds": pd.date_range(start="2017-01-01", periods=len_df), "y": np.arange(len_df), "ID": "__df__"}
)
df1 = df.copy(deep=True)
df1["ID"] = "df1"
df2 = df.copy(deep=True)
Expand Down

0 comments on commit 804caae

Please sign in to comment.