diff --git a/lightwood/mixer/gluonts.py b/lightwood/mixer/gluonts.py index 1061241f6..ac7d2bc5f 100644 --- a/lightwood/mixer/gluonts.py +++ b/lightwood/mixer/gluonts.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd import mxnet as mx +from sklearn.preprocessing import OrdinalEncoder from gluonts.dataset.pandas import PandasDataset @@ -36,6 +37,8 @@ def __init__( early_stop_patience: int = 3, distribution_output: str = '', seed: int = 0, + static_features_cat: Optional[list[str]] = None, + static_features_real: Optional[list[str]] = None, ): """ Wrapper around GluonTS probabilistic deep learning models. For now, only DeepAR is supported. @@ -71,6 +74,9 @@ def __init__( self.patience = early_stop_patience self.seed = seed self.trains_once = True + self.static_features_cat_encoders = {} + self.static_features_cat = static_features_cat if static_features_cat else [] + self.static_features_real = static_features_real if static_features_real else [] dist_module = importlib.import_module('.'.join(['gluonts.mx.distribution', *distribution_output.split(".")[:-1]])) @@ -88,6 +94,9 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None: # prepare data cat_ds = ConcatedEncodedDs([train_data, dev_data]) + for col in self.static_features_cat: + self.static_features_cat_encoders[col] = OrdinalEncoder().fit(cat_ds.data_frame[col].values.reshape(-1, 1)) + fit_groups = list(cat_ds.data_frame[self.grouped_by[0]].unique()) if self.grouped_by != ['__default'] else None train_ds = self._make_initial_ds(cat_ds.data_frame, phase='train', groups=fit_groups) batch_size = 32 @@ -99,6 +108,8 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None: distr_output=self.distribution, lags_seq=[i + 1 for i in range(self.window)], batch_size=batch_size, + use_feat_static_cat=True if self.static_features_cat else False, + use_feat_static_real=True if self.static_features_real else False, trainer=Trainer( epochs=self.n_epochs, num_batches_per_epoch=max(1, len(cat_ds.data_frame) // batch_size), @@ -171,7 +182,13 @@ def _make_initial_ds(self, df=None, phase='predict', groups=None): oby_col_name = '__gluon_timestamp' gby = self.ts_analysis["tss"].group_by if self.ts_analysis["tss"].group_by else [] freq = self.ts_analysis['sample_freqs']['__default'] - keep_cols = [self.target] + [col for col in gby] + keep_cols = [self.target] + [col for col in gby] + self.static_features_cat + self.static_features_real + + agg_map = {self.target: 'sum'} + for col in self.static_features_cat: + agg_map[col] = 'first' + for col in self.static_features_real: + agg_map[col] = 'mean' if groups is None and gby: groups = self.groups @@ -207,15 +224,26 @@ def _make_initial_ds(self, df=None, phase='predict', groups=None): return None if gby: - df = df.groupby(by=gby[0]).resample(freq).sum().reset_index(level=[0]) # @TODO: multiple group support and remove groups without enough data + df = df.groupby(by=gby[0]).resample(freq).agg(agg_map).reset_index(level=[0]) else: - df = df.resample(freq).sum() + df = df.resample(freq).agg(agg_map) gby = '__default_group' df[gby] = '__default_group' df[oby_col_name] = df.index - ds = PandasDataset.from_long_dataframe(df, target=self.target, item_id=gby, freq=freq, timestamp=oby_col_name) + for col in self.static_features_cat: + df[col] = self.static_features_cat_encoders[col].transform(df[col].values.reshape(-1, 1)) + + ds = PandasDataset.from_long_dataframe( + df, + target=self.target, + item_id=gby, + freq=freq, + timestamp=oby_col_name, + feat_static_real=self.static_features_real if self.static_features_real else None, + feat_static_cat=self.static_features_cat if self.static_features_cat else None, + ) return ds