diff --git a/mlforecast/core.py b/mlforecast/core.py
index a1cf2d0e..6b0049e0 100644
--- a/mlforecast/core.py
+++ b/mlforecast/core.py
@@ -222,14 +222,19 @@ def _fit(
         self.id_col = id_col
         self.target_col = target_col
         self.time_col = time_col
+        to_drop = [id_col, time_col, target_col]
+        self.static_features = static_features
         if static_features is None:
-            static_features = df.columns.drop([id_col, time_col, target_col])
-        elif id_col in static_features:
-            raise ValueError(
-                "Cannot use the id_col as a static feature. Please create a separate column."
-            )
-        self.static_features = (
-            df.set_index(id_col)[static_features].groupby(id_col, observed=True).head(1)
+            static_features = df.columns.drop([time_col, target_col]).tolist()
+        elif id_col not in static_features:
+            static_features = [id_col] + static_features
+        else:  # static_features defined and contain id_col
+            to_drop = [time_col, target_col]
+        self.static_features_ = (
+            df[static_features]
+            .groupby(id_col, observed=True)
+            .head(1)
+            .reset_index(drop=True)
         )
         sort_idxs = pd.core.sorting.lexsort_indexer([df[id_col], df[time_col]])
         self.restore_idxs = np.empty(df.shape[0], dtype=np.int32)
@@ -249,9 +254,7 @@ def _fit(
         self.last_dates = sorted_df.index.get_level_values(self.time_col)[
             self.ga.indptr[1:] - 1
         ]
-        self.features_order_ = (
-            df.columns.drop([id_col, time_col, target_col]).tolist() + self.features
-        )
+        self.features_order_ = df.columns.drop(to_drop).tolist() + self.features
         return self
 
     def _apply_transforms(self, updates_only: bool = False) -> Dict[str, np.ndarray]:
@@ -433,7 +436,7 @@ def _update_features(self) -> pd.DataFrame:
         features_df = pd.DataFrame(features, columns=self.features)
         features_df[self.id_col] = self.uids
         features_df[self.time_col] = self.curr_dates
-        return self.static_features.merge(features_df, on=self.id_col)
+        return self.static_features_.merge(features_df, on=self.id_col)
 
     def _get_raw_predictions(self) -> np.ndarray:
         return np.array(self.y_pred).ravel("F")
@@ -580,10 +583,19 @@ def update(self, df: pd.DataFrame) -> None:
         ).astype(self.last_dates.dtype)
         self.uids = sizes.index
         new_statics = df.iloc[new_sizes.cumsum() - 1].set_index(self.id_col)
-        orig_dtypes = self.static_features.dtypes
-        self.static_features = self.static_features.reindex(self.uids)
-        self.static_features.update(new_statics)
-        self.static_features = self.static_features.astype(orig_dtypes)
+        orig_dtypes = self.static_features_.dtypes
+        if pd.api.types.is_categorical_dtype(orig_dtypes[self.id_col]):
+            orig_categories = orig_dtypes[self.id_col].categories.tolist()
+            missing_categories = set(self.uids) - set(orig_categories)
+            if missing_categories:
+                orig_dtypes[self.id_col] = pd.CategoricalDtype(
+                    categories=orig_categories + list(missing_categories)
+                )
+        self.static_features_ = self.static_features_.set_index(self.id_col).reindex(
+            self.uids
+        )
+        self.static_features_.update(new_statics)
+        self.static_features_ = self.static_features_.reset_index().astype(orig_dtypes)
         self.ga = self.ga.append_several(
             new_sizes=sizes.values.astype(np.int32),
             new_values=values,
diff --git a/mlforecast/forecast.py b/mlforecast/forecast.py
index 046b5af1..8bd7295b 100644
--- a/mlforecast/forecast.py
+++ b/mlforecast/forecast.py
@@ -444,7 +444,7 @@ def predict(
             id_col=self.ts.id_col,
             time_col=self.ts.time_col,
             target_col=self.ts.target_col,
-            static_features=self.ts.static_features.columns,
+            static_features=self.ts.static_features,
             keep_last_n=self.ts.keep_last_n,
         )
         new_ts.max_horizon = self.ts.max_horizon
diff --git a/mlforecast/lgb_cv.py b/mlforecast/lgb_cv.py
index 8e30a0aa..6b1da444 100644
--- a/mlforecast/lgb_cv.py
+++ b/mlforecast/lgb_cv.py
@@ -45,7 +45,7 @@ def _update(bst, n):
 
 def _predict(ts, bst, valid, h, before_predict_callback, after_predict_callback):
     ex_cols_to_drop = [ts.id_col, ts.time_col, ts.target_col]
-    static_features = ts.static_features.columns.tolist()
+    static_features = ts.static_features_.columns.drop(ts.id_col).tolist()
     ex_cols_to_drop.extend(static_features)
     has_ex = not valid.columns.drop(ex_cols_to_drop).empty
     dynamic_dfs = (
diff --git a/nbs/core.ipynb b/nbs/core.ipynb
index 03cbff25..0975f72f 100644
--- a/nbs/core.ipynb
+++ b/nbs/core.ipynb
@@ -671,18 +671,20 @@
     "        self.id_col = id_col\n",
     "        self.target_col = target_col\n",
     "        self.time_col = time_col\n",
+    "        to_drop = [id_col, time_col, target_col]\n",
+    "        self.static_features = static_features\n",
     "        if static_features is None:\n",
-    "            static_features = df.columns.drop([id_col, time_col, target_col]) \n",
-    "        elif id_col in static_features:\n",
-    "            raise ValueError(\n",
-    "                'Cannot use the id_col as a static feature. Please create a separate column.'\n",
-    "            ) \n",
-    "        self.static_features = (\n",
+    "            static_features = df.columns.drop([time_col, target_col]).tolist()\n",
+    "        elif id_col not in static_features:\n",
+    "            static_features = [id_col] + static_features\n",
+    "        else:  # static_features defined and contain id_col\n",
+    "            to_drop = [time_col, target_col]\n",
+    "        self.static_features_ = (\n",
     "            df\n",
-    "            .set_index(id_col)\n",
     "            [static_features]\n",
     "            .groupby(id_col, observed=True)\n",
     "            .head(1)\n",
+    "            .reset_index(drop=True)\n",
     "        )\n",
     "        sort_idxs = pd.core.sorting.lexsort_indexer([df[id_col], df[time_col]])\n",
     "        self.restore_idxs = np.empty(df.shape[0], dtype=np.int32) \n",
@@ -700,7 +702,7 @@
     "            self.ga = self.ga.take_from_groups(slice(-keep_last_n, None))\n",
     "        self._ga = GroupedArray(self.ga.data, self.ga.indptr)\n",
     "        self.last_dates = sorted_df.index.get_level_values(self.time_col)[self.ga.indptr[1:] - 1]\n",
-    "        self.features_order_ = df.columns.drop([id_col, time_col, target_col]).tolist() + self.features\n",
+    "        self.features_order_ = df.columns.drop(to_drop).tolist() + self.features\n",
     "        return self\n",
     "\n",
     "    def _apply_transforms(self, updates_only: bool = False) -> Dict[str, np.ndarray]:\n",
@@ -880,7 +882,7 @@
     "        features_df = pd.DataFrame(features, columns=self.features)\n",
     "        features_df[self.id_col] = self.uids\n",
     "        features_df[self.time_col] = self.curr_dates \n",
-    "        return self.static_features.merge(features_df, on=self.id_col)\n",
+    "        return self.static_features_.merge(features_df, on=self.id_col)\n",
     "        \n",
     "    def _get_raw_predictions(self) -> np.ndarray:\n",
     "        return np.array(self.y_pred).ravel('F')\n",
@@ -1027,10 +1029,15 @@
     "        ).astype(self.last_dates.dtype)\n",
     "        self.uids = sizes.index\n",
     "        new_statics = df.iloc[new_sizes.cumsum() - 1].set_index(self.id_col)\n",
-    "        orig_dtypes = self.static_features.dtypes\n",
-    "        self.static_features = self.static_features.reindex(self.uids)\n",
-    "        self.static_features.update(new_statics)\n",
-    "        self.static_features = self.static_features.astype(orig_dtypes)\n",
+    "        orig_dtypes = self.static_features_.dtypes\n",
+    "        if pd.api.types.is_categorical_dtype(orig_dtypes[self.id_col]):\n",
+    "            orig_categories = orig_dtypes[self.id_col].categories.tolist()\n",
+    "            missing_categories = set(self.uids) - set(orig_categories)\n",
+    "            if missing_categories:\n",
+    "                orig_dtypes[self.id_col] = pd.CategoricalDtype(categories=orig_categories + list(missing_categories))\n",
+    "        self.static_features_ = self.static_features_.set_index(self.id_col).reindex(self.uids)\n",
+    "        self.static_features_.update(new_statics)\n",
+    "        self.static_features_ = self.static_features_.reset_index().astype(orig_dtypes)\n",
     "        self.ga = self.ga.append_several(\n",
     "            new_sizes=sizes.values.astype(np.int32),\n",
     "            new_values=values,\n",
@@ -1444,8 +1451,8 @@
    "outputs": [],
    "source": [
     "pd.testing.assert_frame_equal(\n",
-    "    ts.static_features,\n",
-    "    series.groupby('unique_id').tail(1).drop(columns=['ds', 'y']).set_index('unique_id'),\n",
+    "    ts.static_features_,\n",
+    "    series.groupby('unique_id').tail(1).drop(columns=['ds', 'y']).reset_index(drop=True),\n",
     ")"
    ]
   },
@@ -1465,8 +1472,8 @@
     "ts.fit_transform(series, id_col='unique_id', time_col='ds', target_col='y', static_features=['static_0'])\n",
     "\n",
     "pd.testing.assert_frame_equal(\n",
-    "    ts.static_features,\n",
-    "    series.groupby('unique_id').tail(1).set_index('unique_id')[['static_0']],\n",
+    "    ts.static_features_,\n",
+    "    series.groupby('unique_id').tail(1)[['unique_id', 'static_0']].reset_index(drop=True),\n",
     ")"
    ]
   },
@@ -1496,7 +1503,7 @@
     "expected_date_features = ['dayofweek', 'month', 'year']\n",
     "\n",
     "test_eq(ts.features, expected_lags + expected_transforms + expected_date_features)\n",
-    "test_eq(ts.static_features.columns.tolist() + ts.features, df.columns.drop(['unique_id', 'ds', 'y']).tolist())\n",
+    "test_eq(ts.static_features_.columns.tolist() + ts.features, df.columns.drop(['ds', 'y']).tolist())\n",
     "# we dropped 2 rows because of the lag 2 and 13 more to have the window of size 14\n",
     "test_eq(df.shape[0], series.shape[0] - (2 + 13) * ts.ga.ngroups)\n",
     "test_eq(ts.ga.data.size, ts.ga.ngroups * keep_last_n)"
@@ -1800,8 +1807,13 @@
     "expected['ds'] += pd.offsets.Day()\n",
     "pd.testing.assert_frame_equal(preds, expected)\n",
     "pd.testing.assert_frame_equal(\n",
-    "    ts.static_features,\n",
-    "    pd.concat([last_vals_two_series, new_serie.tail(1)]).set_index('unique_id')[['static_0', 'static_1']].astype(ts.static_features.dtypes)\n",
+    "    ts.static_features_,\n",
+    "    (\n",
+    "        pd.concat([last_vals_two_series, new_serie.tail(1)])\n",
+    "        [['unique_id', 'static_0', 'static_1']]\n",
+    "        .astype(ts.static_features_.dtypes)\n",
+    "        .reset_index(drop=True)\n",
+    "    )\n",
     ")"
    ]
   }
diff --git a/nbs/forecast.ipynb b/nbs/forecast.ipynb
index 2b768461..fdcc47b8 100644
--- a/nbs/forecast.ipynb
+++ b/nbs/forecast.ipynb
@@ -540,7 +540,7 @@
     "            id_col=self.ts.id_col,\n",
     "            time_col=self.ts.time_col,\n",
     "            target_col=self.ts.target_col, \n",
-    "            static_features=self.ts.static_features.columns,\n",
+    "            static_features=self.ts.static_features,\n",
     "            keep_last_n=self.ts.keep_last_n,\n",
     "        )\n",
     "        new_ts.max_horizon = self.ts.max_horizon\n",
@@ -3768,7 +3768,7 @@
     {
      "data": {
       "text/plain": [
-       "MLForecast(models=[LGBMRegressor, XGBRegressor], freq=<Day>, lag_features=['lag7', 'expanding_mean_lag1', 'rolling_mean_lag7_window_size14'], date_features=['dayofweek', 'month', <function even_day>], num_threads=2)"
+       "MLForecast(models=[LGBMRegressor], freq=<Day>, lag_features=['lag7', 'expanding_mean_lag1', 'rolling_mean_lag7_window_size14'], date_features=['dayofweek', 'month', <function even_day>], num_threads=2)"
       ]
      },
      "execution_count": null,
@@ -3780,12 +3780,8 @@
     "def even_day(dates):\n",
     "    return dates.day % 2 == 0\n",
     "\n",
-    "models = [\n",
-    "    lgb.LGBMRegressor(n_jobs=1, random_state=0),\n",
-    "    xgb.XGBRegressor(n_jobs=1, random_state=0),\n",
-    "]\n",
     "fcst = MLForecast(\n",
-    "    models,\n",
+    "    models=lgb.LGBMRegressor(n_jobs=1, random_state=0),\n",
     "    freq='D',\n",
     "    lags=[7],\n",
     "    lag_transforms={\n",
@@ -3795,7 +3791,7 @@
     "    date_features=['dayofweek', 'month', even_day],\n",
     "    num_threads=2,\n",
     ")\n",
-    "fcst.fit(series_with_prices, static_features=['static_0', 'product_id'])"
+    "fcst.fit(series_with_prices, static_features=['unique_id', 'static_0', 'product_id'])"
    ]
   },
   {
@@ -3815,7 +3811,8 @@
     {
      "data": {
       "text/plain": [
-       "['static_0',\n",
+       "['unique_id',\n",
+       " 'static_0',\n",
        " 'product_id',\n",
        " 'price',\n",
        " 'lag7',\n",
@@ -3873,7 +3870,6 @@
        "      <th>unique_id</th>\n",
        "      <th>ds</th>\n",
        "      <th>LGBMRegressor</th>\n",
-       "      <th>XGBRegressor</th>\n",
        "    </tr>\n",
        "  </thead>\n",
        "  <tbody>\n",
@@ -3881,99 +3877,88 @@
        "      <th>0</th>\n",
        "      <td>id_00</td>\n",
        "      <td>2001-05-15</td>\n",
-       "      <td>42.184358</td>\n",
-       "      <td>43.174004</td>\n",
+       "      <td>42.406978</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>1</th>\n",
        "      <td>id_00</td>\n",
        "      <td>2001-05-16</td>\n",
-       "      <td>50.186606</td>\n",
-       "      <td>50.842575</td>\n",
+       "      <td>50.076236</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>2</th>\n",
        "      <td>id_00</td>\n",
        "      <td>2001-05-17</td>\n",
-       "      <td>1.940786</td>\n",
-       "      <td>1.911936</td>\n",
+       "      <td>1.904567</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>3</th>\n",
        "      <td>id_00</td>\n",
        "      <td>2001-05-18</td>\n",
-       "      <td>10.432289</td>\n",
-       "      <td>9.788165</td>\n",
+       "      <td>10.259930</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>4</th>\n",
        "      <td>id_00</td>\n",
        "      <td>2001-05-19</td>\n",
-       "      <td>18.701071</td>\n",
-       "      <td>18.377850</td>\n",
+       "      <td>18.727878</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>...</th>\n",
        "      <td>...</td>\n",
        "      <td>...</td>\n",
        "      <td>...</td>\n",
-       "      <td>...</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>695</th>\n",
        "      <td>id_99</td>\n",
        "      <td>2001-05-17</td>\n",
-       "      <td>44.311743</td>\n",
-       "      <td>43.611797</td>\n",
+       "      <td>44.266018</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>696</th>\n",
        "      <td>id_99</td>\n",
        "      <td>2001-05-18</td>\n",
-       "      <td>1.909511</td>\n",
-       "      <td>1.922798</td>\n",
+       "      <td>1.936728</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>697</th>\n",
        "      <td>id_99</td>\n",
        "      <td>2001-05-19</td>\n",
-       "      <td>9.067718</td>\n",
-       "      <td>8.772107</td>\n",
+       "      <td>9.091219</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>698</th>\n",
        "      <td>id_99</td>\n",
        "      <td>2001-05-20</td>\n",
-       "      <td>14.967183</td>\n",
-       "      <td>15.344975</td>\n",
+       "      <td>15.262409</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>699</th>\n",
        "      <td>id_99</td>\n",
        "      <td>2001-05-21</td>\n",
-       "      <td>22.917440</td>\n",
-       "      <td>22.898575</td>\n",
+       "      <td>22.840666</td>\n",
        "    </tr>\n",
        "  </tbody>\n",
        "</table>\n",
-       "<p>700 rows × 4 columns</p>\n",
+       "<p>700 rows × 3 columns</p>\n",
        "</div>"
       ],
       "text/plain": [
-       "    unique_id         ds  LGBMRegressor  XGBRegressor\n",
-       "0       id_00 2001-05-15      42.184358     43.174004\n",
-       "1       id_00 2001-05-16      50.186606     50.842575\n",
-       "2       id_00 2001-05-17       1.940786      1.911936\n",
-       "3       id_00 2001-05-18      10.432289      9.788165\n",
-       "4       id_00 2001-05-19      18.701071     18.377850\n",
-       "..        ...        ...            ...           ...\n",
-       "695     id_99 2001-05-17      44.311743     43.611797\n",
-       "696     id_99 2001-05-18       1.909511      1.922798\n",
-       "697     id_99 2001-05-19       9.067718      8.772107\n",
-       "698     id_99 2001-05-20      14.967183     15.344975\n",
-       "699     id_99 2001-05-21      22.917440     22.898575\n",
-       "\n",
-       "[700 rows x 4 columns]"
+       "    unique_id         ds  LGBMRegressor\n",
+       "0       id_00 2001-05-15      42.406978\n",
+       "1       id_00 2001-05-16      50.076236\n",
+       "2       id_00 2001-05-17       1.904567\n",
+       "3       id_00 2001-05-18      10.259930\n",
+       "4       id_00 2001-05-19      18.727878\n",
+       "..        ...        ...            ...\n",
+       "695     id_99 2001-05-17      44.266018\n",
+       "696     id_99 2001-05-18       1.936728\n",
+       "697     id_99 2001-05-19       9.091219\n",
+       "698     id_99 2001-05-20      15.262409\n",
+       "699     id_99 2001-05-21      22.840666\n",
+       "\n",
+       "[700 rows x 3 columns]"
       ]
      },
      "execution_count": null,
@@ -4031,6 +4016,10 @@
     "non_std_series = series.copy()\n",
     "non_std_series['ds'] = non_std_series.groupby('unique_id').cumcount()\n",
     "non_std_series = non_std_series.rename(columns={'unique_id': 'some_id', 'ds': 'time', 'y': 'value'})\n",
+    "models = [\n",
+    "    lgb.LGBMRegressor(n_jobs=1, random_state=0),\n",
+    "    xgb.XGBRegressor(n_jobs=1, random_state=0),\n",
+    "]\n",
     "flow_params = dict(\n",
     "    models=models,\n",
     "    lags=[7],\n",
@@ -4392,14 +4381,14 @@
     "        id_col='some_id',\n",
     "        time_col='time',\n",
     "        target_col='value',\n",
-    "        static_features=['static_0', 'static_1'],\n",
+    "        static_features=['some_id', 'static_0', 'static_1'],\n",
     "    )\n",
     "    renamer = {'some_id': 'unique_id', 'time': 'ds', 'value': 'y'}\n",
     "    backtest_results = backtest_results.rename(columns=renamer)\n",
     "    renamed = data.rename(columns=renamer)\n",
     "    manual_results = []\n",
     "    for cutoff, train, valid in backtest_splits(renamed, n_windows, window_size, 'unique_id', 'ds', 1):\n",
-    "        fcst.fit(train, static_features=['static_0', 'static_1'])\n",
+    "        fcst.fit(train, static_features=['unique_id', 'static_0', 'static_1'])\n",
    "        if add_exogenous:\n",
     "            dynamic_dfs = [valid.drop(columns=['y', 'static_0', 'static_1']).reset_index()]\n",
     "        else:\n",
diff --git a/nbs/lgb_cv.ipynb b/nbs/lgb_cv.ipynb
index a01d9a56..3c76f33e 100644
--- a/nbs/lgb_cv.ipynb
+++ b/nbs/lgb_cv.ipynb
@@ -96,7 +96,7 @@
     "\n",
     "def _predict(ts, bst, valid, h, before_predict_callback, after_predict_callback):\n",
     "    ex_cols_to_drop = [ts.id_col, ts.time_col, ts.target_col]\n",
-    "    static_features = ts.static_features.columns.tolist()\n",
+    "    static_features = ts.static_features_.columns.drop(ts.id_col).tolist()\n",
    "    ex_cols_to_drop.extend(static_features)\n",
     "    has_ex = not valid.columns.drop(ex_cols_to_drop).empty\n",
     "    dynamic_dfs = [valid.drop(columns=static_features + [ts.target_col])] if has_ex else None\n",