Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor] Componentstacker dataclass and abstraction of stacking and unstacking of components #1646

Merged
merged 30 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
b079fa5
remove stackers from TimeNet init
ourownstory Sep 11, 2024
299593c
refactor stackers to dict
ourownstory Sep 11, 2024
6f9a0e0
simplify and document set_compunents_stacker
ourownstory Sep 11, 2024
7580c71
refactor set_components_stacker arg to stacker
ourownstory Sep 11, 2024
24ac07e
add docstring to forward and introduce mode flag instead of passing c…
ourownstory Sep 11, 2024
37e3ab7
refactor include components
ourownstory Sep 12, 2024
27615cc
fix covar_weights
ourownstory Sep 12, 2024
bcf5726
simplify unstack components
ourownstory Sep 12, 2024
264d553
use dict to index unstack component functions
ourownstory Sep 12, 2024
1404fab
remove unused import
ourownstory Sep 12, 2024
c83fead
fix component_stacker
ourownstory Sep 12, 2024
f106f82
convert to dataclass
ourownstory Sep 12, 2024
63a64dd
simply stack function names
ourownstory Sep 12, 2024
86e37a1
fix references
ourownstory Sep 12, 2024
f59c6b9
fix trend/time
ourownstory Sep 12, 2024
7ce4fcb
fix pre-existing typo
ourownstory Sep 12, 2024
eb32047
improve seasonality stacker
ourownstory Sep 12, 2024
5f71d43
revert seasonalities
ourownstory Sep 12, 2024
d790087
rename stack/unstack
ourownstory Sep 12, 2024
348b3d7
update timenet
ourownstory Sep 12, 2024
83137f2
use stack function
ourownstory Sep 12, 2024
218a43d
kwargs
ourownstory Sep 12, 2024
ca7c52e
use stacker abstraction
ourownstory Sep 12, 2024
5357f8e
simplify names
ourownstory Sep 12, 2024
f209d97
names
ourownstory Sep 12, 2024
5a06d39
fix seasons
ourownstory Sep 12, 2024
74186a6
explicit update of feature_list and no double returns
ourownstory Sep 12, 2024
d3219d9
conform seasonalities
ourownstory Sep 12, 2024
d37b534
move stack_all to stacker
ourownstory Sep 12, 2024
a2f5c4a
ruff
ourownstory Sep 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from neuralprophet.plot_model_parameters_plotly import plot_parameters as plot_parameters_plotly
from neuralprophet.plot_utils import get_valid_configuration, log_warning_deprecation_plotly, select_plotting_backend
from neuralprophet.uncertainty import Conformal
from neuralprophet.utils_time_dataset import ComponentStacker

log = logging.getLogger("NP.forecaster")

Expand Down Expand Up @@ -1210,7 +1209,6 @@ def fit(
max_lags=self.config_model.max_lags,
config_seasonality=self.config_seasonality,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
)
dataset = self._create_dataset(df, predict_mode=False, components_stacker=train_components_stacker)
# Determine the max_number of epochs
Expand Down Expand Up @@ -1253,7 +1251,6 @@ def fit(
n_forecasts=self.config_model.n_forecasts,
config_seasonality=self.config_seasonality,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
)
dataset_val = self._create_dataset(df_val, predict_mode=False, components_stacker=val_components_stacker)
loader_val = DataLoader(dataset_val, batch_size=min(1024, len(dataset_val)), shuffle=False, drop_last=False)
Expand All @@ -1275,9 +1272,9 @@ def fit(
if not self.fitted:
self.model = self._init_model()

self.model.set_components_stacker(components_stacker=train_components_stacker, mode="train")
self.model.set_components_stacker(stacker=train_components_stacker, mode="train")
if validation_enabled:
self.model.set_components_stacker(components_stacker=val_components_stacker, mode="val")
self.model.set_components_stacker(stacker=val_components_stacker, mode="val")

# Find suitable learning rate if not set
if self.config_train.learning_rate is None:
Expand Down Expand Up @@ -1491,7 +1488,6 @@ def test(self, df: pd.DataFrame, verbose: bool = True):
max_lags=self.config_model.max_lags,
config_seasonality=self.config_seasonality,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
)
dataset = self._create_dataset(df, predict_mode=False, components_stacker=components_stacker)
self.model.set_components_stacker(components_stacker, mode="test")
Expand Down Expand Up @@ -2128,7 +2124,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5):
prev_n_lags = self.config_ar.n_lags
prev_max_lags = self.config_model.max_lags
prev_n_forecasts = self.config_model.n_forecasts
prev_predict_components_stacker = self.model.predict_components_stacker
prev_predict_components_stacker = self.model.components_stacker["predict"]

self.config_model.max_lags = 0
self.config_ar.n_lags = 0
Expand All @@ -2138,7 +2134,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5):
df = _check_dataframe(self, df, check_y=False, exogenous=False)
df = _normalize(df=df, config_normalization=self.config_normalization)
for df_name, df_i in df.groupby("ID"):
feature_unstackor = ComponentStacker(
feature_unstackor = utils_time_dataset.ComponentStacker(
n_lags=0,
max_lags=0,
n_forecasts=1,
Expand Down Expand Up @@ -2169,12 +2165,12 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5):
meta_name_tensor = None
elif self.model.config_seasonality.global_local in ["local", "glocal"]:
meta = OrderedDict()
time_input = feature_unstackor.unstack_component("time", inputs_tensor)
time_input = feature_unstackor.unstack("time", inputs_tensor)
meta["df_name"] = [df_name for _ in range(time_input.shape[0])]
meta_name_tensor = torch.tensor([self.model.id_dict[i] for i in meta["df_name"]]) # type: ignore
else:
meta_name_tensor = None
seasonalities_input = feature_unstackor.unstack_component("seasonalities", inputs_tensor)
seasonalities_input = feature_unstackor.unstack("seasonalities", inputs_tensor)
for name in self.config_seasonality.periods:
features = seasonalities_input[name]
quantile_index = self.config_model.quantiles.index(quantile)
Expand All @@ -2198,7 +2194,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5):
self.config_ar.n_lags = prev_n_lags
self.config_model.max_lags = prev_max_lags
self.config_model.n_forecasts = prev_n_forecasts
self.model.predict_components_stacker = prev_predict_components_stacker
self.model.components_stacker["predict"] = prev_predict_components_stacker

return df

Expand Down Expand Up @@ -2989,7 +2985,6 @@ def _predict_raw(self, df, df_name, include_components=False):
max_lags=self.config_model.max_lags,
config_seasonality=self.config_seasonality,
lagged_regressor_config=self.config_lagged_regressors,
feature_indices={},
)
dataset = self._create_dataset(df, predict_mode=True, components_stacker=components_stacker)
self.model.set_components_stacker(components_stacker, mode="predict")
Expand Down Expand Up @@ -3066,7 +3061,7 @@ def _predict_raw(self, df, df_name, include_components=False):
elif multiplicative:
# output absolute value of respective additive component
components[name] = value * trend * scale_y # type: ignore

self.model.reset_compute_components()
else:
components = None

Expand Down
52 changes: 17 additions & 35 deletions neuralprophet/time_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,54 +104,36 @@ def __init__(
self.df["ds"] = self.df["ds"].apply(lambda x: x.timestamp()) # Convert to Unix timestamp in seconds
self.df_tensors["ds"] = torch.tensor(self.df["ds"].values, dtype=torch.int64)

self.seasonalities = None
if self.config_seasonality is not None and hasattr(self.config_seasonality, "periods"):
self.calculate_seasonalities()

# Construct index map
self.sample2index_map, self.length = self.create_sample2index_map(self.df, self.df_tensors)

# Stack all features into one large tensor
self.components_stacker = components_stacker

self.stack_all_features()
self.all_features = self.stack_all_features()

def stack_all_features(self):
"""
Stack all features into one large tensor by calling individual stacking methods.
"""
feature_list = []

current_idx = 0

# Call individual stacking functions
current_idx = self.components_stacker.stack_trend_component(self.df_tensors, feature_list, current_idx)
current_idx = self.components_stacker.stack_targets_component(self.df_tensors, feature_list, current_idx)

current_idx = self.components_stacker.stack_lags_component(
self.df_tensors, feature_list, current_idx, self.config_ar.n_lags
)
current_idx = self.components_stacker.stack_lagged_regerssors_component(
self.df_tensors, feature_list, current_idx, self.config_lagged_regressors
)
current_idx = self.components_stacker.stack_additive_events_component(
self.df_tensors, feature_list, current_idx, self.additive_event_and_holiday_names
)
current_idx = self.components_stacker.stack_multiplicative_events_component(
self.df_tensors, feature_list, current_idx, self.multiplicative_event_and_holiday_names
)
current_idx = self.components_stacker.stack_additive_regressors_component(
self.df_tensors, feature_list, current_idx, self.additive_regressors_names
)
current_idx = self.components_stacker.stack_multiplicative_regressors_component(
self.df_tensors, feature_list, current_idx, self.multiplicative_regressors_names
)

if self.config_seasonality is not None and hasattr(self.config_seasonality, "periods"):
current_idx = self.components_stacker.stack_seasonalities_component(
feature_list, current_idx, self.config_seasonality, self.seasonalities
)
# Add seasonalities to df_tensors, this needs to be done after create_sample2index_map, before stacking.
self.df_tensors["seasonalities"] = self.seasonalities
component_args: dict = {
"time": {},
"targets": {},
"lags": {"n_lags": self.config_ar.n_lags},
"lagged_regressors": {"config": self.config_lagged_regressors},
"additive_events": {"names": self.additive_event_and_holiday_names},
"multiplicative_events": {"names": self.multiplicative_event_and_holiday_names},
"additive_regressors": {"names": self.additive_regressors_names},
"multiplicative_regressors": {"names": self.multiplicative_regressors_names},
"seasonalities": {"config": self.config_seasonality},
}

# Concatenate all features into one big tensor
self.all_features = torch.cat(feature_list, dim=1) # Concatenating along the second dimension
return self.components_stacker.stack_all_features(self.df_tensors, component_args)

def calculate_seasonalities(self):
"""Computes Fourier series components with the specified frequency and order."""
Expand Down
Loading
Loading