Skip to content

Commit

Permalink
[Minor] ComponentStacker dataclass and abstraction of stacking and un…
Browse files Browse the repository at this point in the history
…stacking of components (#1646)

* remove stackers from TimeNet init

* refactor stackers to dict

* simplify and document set_components_stacker

* refactor set_components_stacker arg to stacker

* add docstring to forward and introduce mode flag instead of passing components_stacker

* refactor include components

* fix covar_weights

* simplify unstack components

* use dict to index unstack component functions

* remove unused import

* fix component_stacker

* convert to dataclass

* simplify stack function names

* fix references

* fix trend/time

* fix pre-existing typo

* improve seasonality stacker

* revert seasonalities

* rename stack/unstack

* update timenet

* use stack function

* kwargs

* use stacker abstraction

* simplify names

* names

* fix seasons

* explicit update of feature_list and no double returns

* conform seasonalities

* move stack_all to stacker

* ruff
  • Loading branch information
ourownstory authored Sep 12, 2024
1 parent 773b67a commit fe309be
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 206 deletions.
21 changes: 8 additions & 13 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from neuralprophet.plot_model_parameters_plotly import plot_parameters as plot_parameters_plotly
from neuralprophet.plot_utils import get_valid_configuration, log_warning_deprecation_plotly, select_plotting_backend
from neuralprophet.uncertainty import Conformal
from neuralprophet.utils_time_dataset import ComponentStacker

log = logging.getLogger("NP.forecaster")

Expand Down Expand Up @@ -1210,7 +1209,6 @@ def fit(
max_lags=self.config_model.max_lags,
config_seasonality=self.config_seasonality,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
)
dataset = self._create_dataset(df, predict_mode=False, components_stacker=train_components_stacker)
# Determine the max_number of epochs
Expand Down Expand Up @@ -1253,7 +1251,6 @@ def fit(
n_forecasts=self.config_model.n_forecasts,
config_seasonality=self.config_seasonality,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
)
dataset_val = self._create_dataset(df_val, predict_mode=False, components_stacker=val_components_stacker)
loader_val = DataLoader(dataset_val, batch_size=min(1024, len(dataset_val)), shuffle=False, drop_last=False)
Expand All @@ -1275,9 +1272,9 @@ def fit(
if not self.fitted:
self.model = self._init_model()

self.model.set_components_stacker(components_stacker=train_components_stacker, mode="train")
self.model.set_components_stacker(stacker=train_components_stacker, mode="train")
if validation_enabled:
self.model.set_components_stacker(components_stacker=val_components_stacker, mode="val")
self.model.set_components_stacker(stacker=val_components_stacker, mode="val")

# Find suitable learning rate if not set
if self.config_train.learning_rate is None:
Expand Down Expand Up @@ -1491,7 +1488,6 @@ def test(self, df: pd.DataFrame, verbose: bool = True):
max_lags=self.config_model.max_lags,
config_seasonality=self.config_seasonality,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
)
dataset = self._create_dataset(df, predict_mode=False, components_stacker=components_stacker)
self.model.set_components_stacker(components_stacker, mode="test")
Expand Down Expand Up @@ -2128,7 +2124,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5):
prev_n_lags = self.config_ar.n_lags
prev_max_lags = self.config_model.max_lags
prev_n_forecasts = self.config_model.n_forecasts
prev_predict_components_stacker = self.model.predict_components_stacker
prev_predict_components_stacker = self.model.components_stacker["predict"]

self.config_model.max_lags = 0
self.config_ar.n_lags = 0
Expand All @@ -2138,7 +2134,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5):
df = _check_dataframe(self, df, check_y=False, exogenous=False)
df = _normalize(df=df, config_normalization=self.config_normalization)
for df_name, df_i in df.groupby("ID"):
feature_unstackor = ComponentStacker(
feature_unstackor = utils_time_dataset.ComponentStacker(
n_lags=0,
max_lags=0,
n_forecasts=1,
Expand Down Expand Up @@ -2169,12 +2165,12 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5):
meta_name_tensor = None
elif self.model.config_seasonality.global_local in ["local", "glocal"]:
meta = OrderedDict()
time_input = feature_unstackor.unstack_component("time", inputs_tensor)
time_input = feature_unstackor.unstack("time", inputs_tensor)
meta["df_name"] = [df_name for _ in range(time_input.shape[0])]
meta_name_tensor = torch.tensor([self.model.id_dict[i] for i in meta["df_name"]]) # type: ignore
else:
meta_name_tensor = None
seasonalities_input = feature_unstackor.unstack_component("seasonalities", inputs_tensor)
seasonalities_input = feature_unstackor.unstack("seasonalities", inputs_tensor)
for name in self.config_seasonality.periods:
features = seasonalities_input[name]
quantile_index = self.config_model.quantiles.index(quantile)
Expand All @@ -2198,7 +2194,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5):
self.config_ar.n_lags = prev_n_lags
self.config_model.max_lags = prev_max_lags
self.config_model.n_forecasts = prev_n_forecasts
self.model.predict_components_stacker = prev_predict_components_stacker
self.model.components_stacker["predict"] = prev_predict_components_stacker

return df

Expand Down Expand Up @@ -2989,7 +2985,6 @@ def _predict_raw(self, df, df_name, include_components=False):
max_lags=self.config_model.max_lags,
config_seasonality=self.config_seasonality,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
)
dataset = self._create_dataset(df, predict_mode=True, components_stacker=components_stacker)
self.model.set_components_stacker(components_stacker, mode="predict")
Expand Down Expand Up @@ -3066,7 +3061,7 @@ def _predict_raw(self, df, df_name, include_components=False):
elif multiplicative:
# output absolute value of respective additive component
components[name] = value * trend * scale_y # type: ignore

self.model.reset_compute_components()
else:
components = None

Expand Down
52 changes: 17 additions & 35 deletions neuralprophet/time_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,54 +104,36 @@ def __init__(
self.df["ds"] = self.df["ds"].apply(lambda x: x.timestamp()) # Convert to Unix timestamp in seconds
self.df_tensors["ds"] = torch.tensor(self.df["ds"].values, dtype=torch.int64)

self.seasonalities = None
if self.config_seasonality is not None and hasattr(self.config_seasonality, "periods"):
self.calculate_seasonalities()

# Construct index map
self.sample2index_map, self.length = self.create_sample2index_map(self.df, self.df_tensors)

# Stack all features into one large tensor
self.components_stacker = components_stacker

self.stack_all_features()
self.all_features = self.stack_all_features()

def stack_all_features(self):
"""
Stack all features into one large tensor by calling individual stacking methods.
"""
feature_list = []

current_idx = 0

# Call individual stacking functions
current_idx = self.components_stacker.stack_trend_component(self.df_tensors, feature_list, current_idx)
current_idx = self.components_stacker.stack_targets_component(self.df_tensors, feature_list, current_idx)

current_idx = self.components_stacker.stack_lags_component(
self.df_tensors, feature_list, current_idx, self.config_ar.n_lags
)
current_idx = self.components_stacker.stack_lagged_regerssors_component(
self.df_tensors, feature_list, current_idx, self.config_lagged_regressors
)
current_idx = self.components_stacker.stack_additive_events_component(
self.df_tensors, feature_list, current_idx, self.additive_event_and_holiday_names
)
current_idx = self.components_stacker.stack_multiplicative_events_component(
self.df_tensors, feature_list, current_idx, self.multiplicative_event_and_holiday_names
)
current_idx = self.components_stacker.stack_additive_regressors_component(
self.df_tensors, feature_list, current_idx, self.additive_regressors_names
)
current_idx = self.components_stacker.stack_multiplicative_regressors_component(
self.df_tensors, feature_list, current_idx, self.multiplicative_regressors_names
)

if self.config_seasonality is not None and hasattr(self.config_seasonality, "periods"):
current_idx = self.components_stacker.stack_seasonalities_component(
feature_list, current_idx, self.config_seasonality, self.seasonalities
)
# Add seasonalities to df_tensors, this needs to be done after create_sample2index_map, before stacking.
self.df_tensors["seasonalities"] = self.seasonalities
component_args: dict = {
"time": {},
"targets": {},
"lags": {"n_lags": self.config_ar.n_lags},
"lagged_regressors": {"config": self.config_lagged_regressors},
"additive_events": {"names": self.additive_event_and_holiday_names},
"multiplicative_events": {"names": self.multiplicative_event_and_holiday_names},
"additive_regressors": {"names": self.additive_regressors_names},
"multiplicative_regressors": {"names": self.multiplicative_regressors_names},
"seasonalities": {"config": self.config_seasonality},
}

# Concatenate all features into one big tensor
self.all_features = torch.cat(feature_list, dim=1) # Concatenating along the second dimension
return self.components_stacker.stack_all_features(self.df_tensors, component_args)

def calculate_seasonalities(self):
"""Computes Fourier series components with the specified frequency and order."""
Expand Down
Loading

0 comments on commit fe309be

Please sign in to comment.