Skip to content

Commit

Permalink
allow id_col in static_features (#161)
Browse files: browse the repository at this point in the history
  • Loading branch information
jmoralez authored Jul 5, 2023
1 parent 03644e8 commit 8583b39
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 85 deletions.
42 changes: 27 additions & 15 deletions mlforecast/core.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -222,14 +222,19 @@ def _fit(
self.id_col = id_col
self.target_col = target_col
self.time_col = time_col
to_drop = [id_col, time_col, target_col]
self.static_features = static_features
if static_features is None:
static_features = df.columns.drop([id_col, time_col, target_col])
elif id_col in static_features:
raise ValueError(
"Cannot use the id_col as a static feature. Please create a separate column."
)
self.static_features = (
df.set_index(id_col)[static_features].groupby(id_col, observed=True).head(1)
static_features = df.columns.drop([time_col, target_col]).tolist()
elif id_col not in static_features:
static_features = [id_col] + static_features
else: # static_features defined and contain id_col
to_drop = [time_col, target_col]
self.static_features_ = (
df[static_features]
.groupby(id_col, observed=True)
.head(1)
.reset_index(drop=True)
)
sort_idxs = pd.core.sorting.lexsort_indexer([df[id_col], df[time_col]])
self.restore_idxs = np.empty(df.shape[0], dtype=np.int32)
Expand All @@ -249,9 +254,7 @@ def _fit(
self.last_dates = sorted_df.index.get_level_values(self.time_col)[
self.ga.indptr[1:] - 1
]
self.features_order_ = (
df.columns.drop([id_col, time_col, target_col]).tolist() + self.features
)
self.features_order_ = df.columns.drop(to_drop).tolist() + self.features
return self

def _apply_transforms(self, updates_only: bool = False) -> Dict[str, np.ndarray]:
Expand Down Expand Up @@ -433,7 +436,7 @@ def _update_features(self) -> pd.DataFrame:
features_df = pd.DataFrame(features, columns=self.features)
features_df[self.id_col] = self.uids
features_df[self.time_col] = self.curr_dates
return self.static_features.merge(features_df, on=self.id_col)
return self.static_features_.merge(features_df, on=self.id_col)

def _get_raw_predictions(self) -> np.ndarray:
return np.array(self.y_pred).ravel("F")
Expand Down Expand Up @@ -580,10 +583,19 @@ def update(self, df: pd.DataFrame) -> None:
).astype(self.last_dates.dtype)
self.uids = sizes.index
new_statics = df.iloc[new_sizes.cumsum() - 1].set_index(self.id_col)
orig_dtypes = self.static_features.dtypes
self.static_features = self.static_features.reindex(self.uids)
self.static_features.update(new_statics)
self.static_features = self.static_features.astype(orig_dtypes)
orig_dtypes = self.static_features_.dtypes
if pd.api.types.is_categorical_dtype(orig_dtypes[self.id_col]):
orig_categories = orig_dtypes[self.id_col].categories.tolist()
missing_categories = set(self.uids) - set(orig_categories)
if missing_categories:
orig_dtypes[self.id_col] = pd.CategoricalDtype(
categories=orig_categories + list(missing_categories)
)
self.static_features_ = self.static_features_.set_index(self.id_col).reindex(
self.uids
)
self.static_features_.update(new_statics)
self.static_features_ = self.static_features_.reset_index().astype(orig_dtypes)
self.ga = self.ga.append_several(
new_sizes=sizes.values.astype(np.int32),
new_values=values,
Expand Down
2 changes: 1 addition & 1 deletion mlforecast/forecast.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -444,7 +444,7 @@ def predict(
id_col=self.ts.id_col,
time_col=self.ts.time_col,
target_col=self.ts.target_col,
static_features=self.ts.static_features.columns,
static_features=self.ts.static_features,
keep_last_n=self.ts.keep_last_n,
)
new_ts.max_horizon = self.ts.max_horizon
Expand Down
2 changes: 1 addition & 1 deletion mlforecast/lgb_cv.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -45,7 +45,7 @@ def _update(bst, n):

def _predict(ts, bst, valid, h, before_predict_callback, after_predict_callback):
ex_cols_to_drop = [ts.id_col, ts.time_col, ts.target_col]
static_features = ts.static_features.columns.tolist()
static_features = ts.static_features_.columns.drop(ts.id_col).tolist()
ex_cols_to_drop.extend(static_features)
has_ex = not valid.columns.drop(ex_cols_to_drop).empty
dynamic_dfs = (
Expand Down
52 changes: 32 additions & 20 deletions nbs/core.ipynb
Original file line number | Diff line number | Diff line change
Expand Up @@ -671,18 +671,20 @@
" self.id_col = id_col\n",
" self.target_col = target_col\n",
" self.time_col = time_col\n",
" to_drop = [id_col, time_col, target_col]\n",
" self.static_features = static_features\n",
" if static_features is None:\n",
" static_features = df.columns.drop([id_col, time_col, target_col]) \n",
" elif id_col in static_features:\n",
" raise ValueError(\n",
" 'Cannot use the id_col as a static feature. Please create a separate column.'\n",
" ) \n",
" self.static_features = (\n",
" static_features = df.columns.drop([time_col, target_col]).tolist()\n",
" elif id_col not in static_features:\n",
" static_features = [id_col] + static_features\n",
" else: # static_features defined and contain id_col\n",
" to_drop = [time_col, target_col]\n",
" self.static_features_ = (\n",
" df\n",
" .set_index(id_col)\n",
" [static_features]\n",
" .groupby(id_col, observed=True)\n",
" .head(1)\n",
" .reset_index(drop=True)\n",
" )\n",
" sort_idxs = pd.core.sorting.lexsort_indexer([df[id_col], df[time_col]])\n",
" self.restore_idxs = np.empty(df.shape[0], dtype=np.int32) \n",
Expand All @@ -700,7 +702,7 @@
" self.ga = self.ga.take_from_groups(slice(-keep_last_n, None))\n",
" self._ga = GroupedArray(self.ga.data, self.ga.indptr)\n",
" self.last_dates = sorted_df.index.get_level_values(self.time_col)[self.ga.indptr[1:] - 1]\n",
" self.features_order_ = df.columns.drop([id_col, time_col, target_col]).tolist() + self.features\n",
" self.features_order_ = df.columns.drop(to_drop).tolist() + self.features\n",
" return self\n",
"\n",
" def _apply_transforms(self, updates_only: bool = False) -> Dict[str, np.ndarray]:\n",
Expand Down Expand Up @@ -880,7 +882,7 @@
" features_df = pd.DataFrame(features, columns=self.features)\n",
" features_df[self.id_col] = self.uids\n",
" features_df[self.time_col] = self.curr_dates \n",
" return self.static_features.merge(features_df, on=self.id_col)\n",
" return self.static_features_.merge(features_df, on=self.id_col)\n",
" \n",
" def _get_raw_predictions(self) -> np.ndarray:\n",
" return np.array(self.y_pred).ravel('F')\n",
Expand Down Expand Up @@ -1027,10 +1029,15 @@
" ).astype(self.last_dates.dtype)\n",
" self.uids = sizes.index\n",
" new_statics = df.iloc[new_sizes.cumsum() - 1].set_index(self.id_col)\n",
" orig_dtypes = self.static_features.dtypes\n",
" self.static_features = self.static_features.reindex(self.uids)\n",
" self.static_features.update(new_statics)\n",
" self.static_features = self.static_features.astype(orig_dtypes)\n",
" orig_dtypes = self.static_features_.dtypes\n",
" if pd.api.types.is_categorical_dtype(orig_dtypes[self.id_col]):\n",
" orig_categories = orig_dtypes[self.id_col].categories.tolist()\n",
" missing_categories = set(self.uids) - set(orig_categories)\n",
" if missing_categories:\n",
" orig_dtypes[self.id_col] = pd.CategoricalDtype(categories=orig_categories + list(missing_categories))\n",
" self.static_features_ = self.static_features_.set_index(self.id_col).reindex(self.uids)\n",
" self.static_features_.update(new_statics)\n",
" self.static_features_ = self.static_features_.reset_index().astype(orig_dtypes)\n",
" self.ga = self.ga.append_several(\n",
" new_sizes=sizes.values.astype(np.int32),\n",
" new_values=values,\n",
Expand Down Expand Up @@ -1444,8 +1451,8 @@
"outputs": [],
"source": [
"pd.testing.assert_frame_equal(\n",
" ts.static_features,\n",
" series.groupby('unique_id').tail(1).drop(columns=['ds', 'y']).set_index('unique_id'),\n",
" ts.static_features_,\n",
" series.groupby('unique_id').tail(1).drop(columns=['ds', 'y']).reset_index(drop=True),\n",
")"
]
},
Expand All @@ -1465,8 +1472,8 @@
"ts.fit_transform(series, id_col='unique_id', time_col='ds', target_col='y', static_features=['static_0'])\n",
"\n",
"pd.testing.assert_frame_equal(\n",
" ts.static_features,\n",
" series.groupby('unique_id').tail(1).set_index('unique_id')[['static_0']],\n",
" ts.static_features_,\n",
" series.groupby('unique_id').tail(1)[['unique_id', 'static_0']].reset_index(drop=True),\n",
")"
]
},
Expand Down Expand Up @@ -1496,7 +1503,7 @@
"expected_date_features = ['dayofweek', 'month', 'year']\n",
"\n",
"test_eq(ts.features, expected_lags + expected_transforms + expected_date_features)\n",
"test_eq(ts.static_features.columns.tolist() + ts.features, df.columns.drop(['unique_id', 'ds', 'y']).tolist())\n",
"test_eq(ts.static_features_.columns.tolist() + ts.features, df.columns.drop(['ds', 'y']).tolist())\n",
"# we dropped 2 rows because of the lag 2 and 13 more to have the window of size 14\n",
"test_eq(df.shape[0], series.shape[0] - (2 + 13) * ts.ga.ngroups)\n",
"test_eq(ts.ga.data.size, ts.ga.ngroups * keep_last_n)"
Expand Down Expand Up @@ -1800,8 +1807,13 @@
"expected['ds'] += pd.offsets.Day()\n",
"pd.testing.assert_frame_equal(preds, expected)\n",
"pd.testing.assert_frame_equal(\n",
" ts.static_features,\n",
" pd.concat([last_vals_two_series, new_serie.tail(1)]).set_index('unique_id')[['static_0', 'static_1']].astype(ts.static_features.dtypes)\n",
" ts.static_features_,\n",
" (\n",
" pd.concat([last_vals_two_series, new_serie.tail(1)])\n",
" [['unique_id', 'static_0', 'static_1']]\n",
" .astype(ts.static_features_.dtypes)\n",
" .reset_index(drop=True)\n",
" )\n",
")"
]
}
Expand Down
Loading

0 comments on commit 8583b39

Please sign in to comment.