Merge pull request #1120 from mindsdb/gluon_static_feats
Gluon static feats
paxcema authored Mar 13, 2023
2 parents a080f27 + ae18b5d commit 211670d
Showing 1 changed file with 32 additions and 4 deletions.
lightwood/mixer/gluonts.py: 36 changes (32 additions, 4 deletions)
@@ -5,6 +5,7 @@
import numpy as np
import pandas as pd
import mxnet as mx
+from sklearn.preprocessing import OrdinalEncoder

from gluonts.dataset.pandas import PandasDataset

@@ -36,6 +37,8 @@ def __init__(
early_stop_patience: int = 3,
distribution_output: str = '',
seed: int = 0,
+static_features_cat: Optional[list[str]] = None,
+static_features_real: Optional[list[str]] = None,
):
"""
Wrapper around GluonTS probabilistic deep learning models. For now, only DeepAR is supported.
@@ -71,6 +74,9 @@ def __init__(
self.patience = early_stop_patience
self.seed = seed
self.trains_once = True
+self.static_features_cat_encoders = {}
+self.static_features_cat = static_features_cat if static_features_cat else []
+self.static_features_real = static_features_real if static_features_real else []

dist_module = importlib.import_module('.'.join(['gluonts.mx.distribution',
*distribution_output.split(".")[:-1]]))
@@ -88,6 +94,9 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:

# prepare data
cat_ds = ConcatedEncodedDs([train_data, dev_data])
+for col in self.static_features_cat:
+self.static_features_cat_encoders[col] = OrdinalEncoder().fit(cat_ds.data_frame[col].values.reshape(-1, 1))

fit_groups = list(cat_ds.data_frame[self.grouped_by[0]].unique()) if self.grouped_by != ['__default'] else None
train_ds = self._make_initial_ds(cat_ds.data_frame, phase='train', groups=fit_groups)
batch_size = 32
@@ -99,6 +108,8 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
distr_output=self.distribution,
lags_seq=[i + 1 for i in range(self.window)],
batch_size=batch_size,
+use_feat_static_cat=True if self.static_features_cat else False,
+use_feat_static_real=True if self.static_features_real else False,
trainer=Trainer(
epochs=self.n_epochs,
num_batches_per_epoch=max(1, len(cat_ds.data_frame) // batch_size),
@@ -171,7 +182,13 @@ def _make_initial_ds(self, df=None, phase='predict', groups=None):
oby_col_name = '__gluon_timestamp'
gby = self.ts_analysis["tss"].group_by if self.ts_analysis["tss"].group_by else []
freq = self.ts_analysis['sample_freqs']['__default']
-keep_cols = [self.target] + [col for col in gby]
+keep_cols = [self.target] + [col for col in gby] + self.static_features_cat + self.static_features_real

+agg_map = {self.target: 'sum'}
+for col in self.static_features_cat:
+agg_map[col] = 'first'
+for col in self.static_features_real:
+agg_map[col] = 'mean'

if groups is None and gby:
groups = self.groups
@@ -207,15 +224,26 @@ def _make_initial_ds(self, df=None, phase='predict', groups=None):
return None

if gby:
-df = df.groupby(by=gby[0]).resample(freq).sum().reset_index(level=[0])
# @TODO: multiple group support and remove groups without enough data
+df = df.groupby(by=gby[0]).resample(freq).agg(agg_map).reset_index(level=[0])
else:
-df = df.resample(freq).sum()
+df = df.resample(freq).agg(agg_map)
gby = '__default_group'
df[gby] = '__default_group'

df[oby_col_name] = df.index
-ds = PandasDataset.from_long_dataframe(df, target=self.target, item_id=gby, freq=freq, timestamp=oby_col_name)
+for col in self.static_features_cat:
+df[col] = self.static_features_cat_encoders[col].transform(df[col].values.reshape(-1, 1))

+ds = PandasDataset.from_long_dataframe(
+df,
+target=self.target,
+item_id=gby,
+freq=freq,
+timestamp=oby_col_name,
+feat_static_real=self.static_features_real if self.static_features_real else None,
+feat_static_cat=self.static_features_cat if self.static_features_cat else None,
+)
return ds
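
For context, the pattern this change wires into the mixer is: ordinal-encode static categorical columns, declare them via feat_static_cat / feat_static_real when building the PandasDataset, and enable use_feat_static_cat / use_feat_static_real on the estimator. Below is a minimal standalone sketch of that pattern against GluonTS itself, not part of the commit; it assumes gluonts with the MXNet backend, pandas, and scikit-learn are installed, and every column name and value ("sales", "region", "store_size", the daily frequency, the hyperparameters) is made up for illustration.

# Minimal sketch, not part of the commit: static features with GluonTS DeepAR.
import pandas as pd
from sklearn.preprocessing import OrdinalEncoder
from gluonts.dataset.pandas import PandasDataset
from gluonts.mx.model.deepar import DeepAREstimator
from gluonts.mx.trainer import Trainer

# Long-format frame: one row per (series, timestamp), plus per-series static columns.
n = 60
df = pd.DataFrame({
    "timestamp": list(pd.date_range("2021-01-01", periods=n, freq="D")) * 2,
    "item_id": ["store_a"] * n + ["store_b"] * n,
    "sales": [float(i % 7) for i in range(2 * n)],   # target
    "region": ["north"] * n + ["south"] * n,         # static categorical
    "store_size": [120.0] * n + [45.0] * n,          # static real
})

# Static categorical features must be numeric, hence the OrdinalEncoder step.
df["region"] = OrdinalEncoder().fit_transform(df[["region"]]).ravel()

ds = PandasDataset.from_long_dataframe(
    df,
    target="sales",
    item_id="item_id",
    timestamp="timestamp",
    freq="D",
    feat_static_cat=["region"],
    feat_static_real=["store_size"],
)

estimator = DeepAREstimator(
    freq="D",
    prediction_length=7,
    use_feat_static_cat=True,
    use_feat_static_real=True,
    cardinality=[2],  # one entry per static categorical feature: its number of categories
    trainer=Trainer(epochs=1),
)
predictor = estimator.train(ds)

The diff above does the equivalent wiring inside the mixer: the encoders are fit per column in fit(), the encoded columns are carried through resampling in _make_initial_ds() via agg_map, and the feature names are forwarded to PandasDataset.from_long_dataframe.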


