From fe309bec613e62285422a1737b75bb8d07600be5 Mon Sep 17 00:00:00 2001 From: Oskar Triebe Date: Thu, 12 Sep 2024 15:19:44 -0700 Subject: [PATCH] [Minor] Componentstacker dataclass and abstraction of stacking and unstacking of components (#1646) * remove stackers from TimeNet init * refactor stackers to dict * simplify and document set_compunents_stacker * refactor set_components_stacker arg to stacker * add docstring to forward and introduce mode flag instead of passing components_stacker * refactor include components * fix covar_weights * simplify unstack components * use dict to index unstack component functions * remove unused import * fix component_stacker * convert to dataclass * simply stack function names * fix references * fix trend/time * fix pre-existing typo * improve seasonality stacker * revert seasonalities * rename stack/unstack * update timenet * use stack function * kwargs * use stacker abstraction * simplify names * names * fix seasons * explicit update of feature_list and no double returns * conform seasonalities * move stack_all to stacker * ruff --- neuralprophet/forecaster.py | 21 +-- neuralprophet/time_dataset.py | 52 ++---- neuralprophet/time_net.py | 126 ++++++++------- neuralprophet/utils_time_dataset.py | 239 ++++++++++++++++------------ 4 files changed, 232 insertions(+), 206 deletions(-) diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 22aa9c58f..6f14e2684 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -44,7 +44,6 @@ from neuralprophet.plot_model_parameters_plotly import plot_parameters as plot_parameters_plotly from neuralprophet.plot_utils import get_valid_configuration, log_warning_deprecation_plotly, select_plotting_backend from neuralprophet.uncertainty import Conformal -from neuralprophet.utils_time_dataset import ComponentStacker log = logging.getLogger("NP.forecaster") @@ -1210,7 +1209,6 @@ def fit( max_lags=self.config_model.max_lags, 
config_seasonality=self.config_seasonality, lagged_regressor_config=self.config_lagged_regressors, - feature_indices={}, ) dataset = self._create_dataset(df, predict_mode=False, components_stacker=train_components_stacker) # Determine the max_number of epochs @@ -1253,7 +1251,6 @@ def fit( n_forecasts=self.config_model.n_forecasts, config_seasonality=self.config_seasonality, lagged_regressor_config=self.config_lagged_regressors, - feature_indices={}, ) dataset_val = self._create_dataset(df_val, predict_mode=False, components_stacker=val_components_stacker) loader_val = DataLoader(dataset_val, batch_size=min(1024, len(dataset_val)), shuffle=False, drop_last=False) @@ -1275,9 +1272,9 @@ def fit( if not self.fitted: self.model = self._init_model() - self.model.set_components_stacker(components_stacker=train_components_stacker, mode="train") + self.model.set_components_stacker(stacker=train_components_stacker, mode="train") if validation_enabled: - self.model.set_components_stacker(components_stacker=val_components_stacker, mode="val") + self.model.set_components_stacker(stacker=val_components_stacker, mode="val") # Find suitable learning rate if not set if self.config_train.learning_rate is None: @@ -1491,7 +1488,6 @@ def test(self, df: pd.DataFrame, verbose: bool = True): max_lags=self.config_model.max_lags, config_seasonality=self.config_seasonality, lagged_regressor_config=self.config_lagged_regressors, - feature_indices={}, ) dataset = self._create_dataset(df, predict_mode=False, components_stacker=components_stacker) self.model.set_components_stacker(components_stacker, mode="test") @@ -2128,7 +2124,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5): prev_n_lags = self.config_ar.n_lags prev_max_lags = self.config_model.max_lags prev_n_forecasts = self.config_model.n_forecasts - prev_predict_components_stacker = self.model.predict_components_stacker + prev_predict_components_stacker = self.model.components_stacker["predict"] 
self.config_model.max_lags = 0 self.config_ar.n_lags = 0 @@ -2138,7 +2134,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5): df = _check_dataframe(self, df, check_y=False, exogenous=False) df = _normalize(df=df, config_normalization=self.config_normalization) for df_name, df_i in df.groupby("ID"): - feature_unstackor = ComponentStacker( + feature_unstackor = utils_time_dataset.ComponentStacker( n_lags=0, max_lags=0, n_forecasts=1, @@ -2169,12 +2165,12 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5): meta_name_tensor = None elif self.model.config_seasonality.global_local in ["local", "glocal"]: meta = OrderedDict() - time_input = feature_unstackor.unstack_component("time", inputs_tensor) + time_input = feature_unstackor.unstack("time", inputs_tensor) meta["df_name"] = [df_name for _ in range(time_input.shape[0])] meta_name_tensor = torch.tensor([self.model.id_dict[i] for i in meta["df_name"]]) # type: ignore else: meta_name_tensor = None - seasonalities_input = feature_unstackor.unstack_component("seasonalities", inputs_tensor) + seasonalities_input = feature_unstackor.unstack("seasonalities", inputs_tensor) for name in self.config_seasonality.periods: features = seasonalities_input[name] quantile_index = self.config_model.quantiles.index(quantile) @@ -2198,7 +2194,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5): self.config_ar.n_lags = prev_n_lags self.config_model.max_lags = prev_max_lags self.config_model.n_forecasts = prev_n_forecasts - self.model.predict_components_stacker = prev_predict_components_stacker + self.model.components_stacker["predict"] = prev_predict_components_stacker return df @@ -2989,7 +2985,6 @@ def _predict_raw(self, df, df_name, include_components=False): max_lags=self.config_model.max_lags, config_seasonality=self.config_seasonality, lagged_regressor_config=self.config_lagged_regressors, - feature_indices={}, ) dataset = 
self._create_dataset(df, predict_mode=True, components_stacker=components_stacker) self.model.set_components_stacker(components_stacker, mode="predict") @@ -3066,7 +3061,7 @@ def _predict_raw(self, df, df_name, include_components=False): elif multiplicative: # output absolute value of respective additive component components[name] = value * trend * scale_y # type: ignore - + self.model.reset_compute_components() else: components = None diff --git a/neuralprophet/time_dataset.py b/neuralprophet/time_dataset.py index 84be0dc05..1044d63eb 100644 --- a/neuralprophet/time_dataset.py +++ b/neuralprophet/time_dataset.py @@ -104,54 +104,36 @@ def __init__( self.df["ds"] = self.df["ds"].apply(lambda x: x.timestamp()) # Convert to Unix timestamp in seconds self.df_tensors["ds"] = torch.tensor(self.df["ds"].values, dtype=torch.int64) + self.seasonalities = None if self.config_seasonality is not None and hasattr(self.config_seasonality, "periods"): self.calculate_seasonalities() # Construct index map self.sample2index_map, self.length = self.create_sample2index_map(self.df, self.df_tensors) + # Stack all features into one large tensor self.components_stacker = components_stacker - - self.stack_all_features() + self.all_features = self.stack_all_features() def stack_all_features(self): """ Stack all features into one large tensor by calling individual stacking methods. 
""" - feature_list = [] - - current_idx = 0 - - # Call individual stacking functions - current_idx = self.components_stacker.stack_trend_component(self.df_tensors, feature_list, current_idx) - current_idx = self.components_stacker.stack_targets_component(self.df_tensors, feature_list, current_idx) - - current_idx = self.components_stacker.stack_lags_component( - self.df_tensors, feature_list, current_idx, self.config_ar.n_lags - ) - current_idx = self.components_stacker.stack_lagged_regerssors_component( - self.df_tensors, feature_list, current_idx, self.config_lagged_regressors - ) - current_idx = self.components_stacker.stack_additive_events_component( - self.df_tensors, feature_list, current_idx, self.additive_event_and_holiday_names - ) - current_idx = self.components_stacker.stack_multiplicative_events_component( - self.df_tensors, feature_list, current_idx, self.multiplicative_event_and_holiday_names - ) - current_idx = self.components_stacker.stack_additive_regressors_component( - self.df_tensors, feature_list, current_idx, self.additive_regressors_names - ) - current_idx = self.components_stacker.stack_multiplicative_regressors_component( - self.df_tensors, feature_list, current_idx, self.multiplicative_regressors_names - ) - - if self.config_seasonality is not None and hasattr(self.config_seasonality, "periods"): - current_idx = self.components_stacker.stack_seasonalities_component( - feature_list, current_idx, self.config_seasonality, self.seasonalities - ) + # Add seasonalities to df_tensors, this needs to be done after create_sample2index_map, before stacking. 
+ self.df_tensors["seasonalities"] = self.seasonalities + component_args: dict = { + "time": {}, + "targets": {}, + "lags": {"n_lags": self.config_ar.n_lags}, + "lagged_regressors": {"config": self.config_lagged_regressors}, + "additive_events": {"names": self.additive_event_and_holiday_names}, + "multiplicative_events": {"names": self.multiplicative_event_and_holiday_names}, + "additive_regressors": {"names": self.additive_regressors_names}, + "multiplicative_regressors": {"names": self.multiplicative_regressors_names}, + "seasonalities": {"config": self.config_seasonality}, + } - # Concatenate all features into one big tensor - self.all_features = torch.cat(feature_list, dim=1) # Concatenating along the second dimension + return self.components_stacker.stack_all_features(self.df_tensors, component_args) def calculate_seasonalities(self): """Computes Fourier series components with the specified frequency and order.""" diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index 096da0e67..e6fd8d7e0 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -21,7 +21,6 @@ reg_func_trend, reg_func_trend_glocal, ) -from neuralprophet.utils_time_dataset import ComponentStacker from neuralprophet.utils_torch import init_parameter, interprete_model log = logging.getLogger("NP.time_net") @@ -55,17 +54,12 @@ def __init__( n_forecasts: int = 1, n_lags: int = 0, ar_layers: Optional[List[int]] = [], - compute_components_flag: bool = False, metrics: Optional[np_types.CollectMetricsMode] = {}, id_list: List[str] = ["__df__"], num_trends_modelled: int = 1, num_seasonalities_modelled: int = 1, num_seasonalities_modelled_dict: dict = None, meta_used_in_model: bool = False, - train_components_stacker: Optional[ComponentStacker] = None, - val_components_stacker: Optional[ComponentStacker] = None, - test_components_stacker: Optional[ComponentStacker] = None, - predict_components_stacker: Optional[ComponentStacker] = None, ): """ Parameters @@ -98,8 +92,6 
@@ def __init__( ---- The default value is ``[]``, which initializes no hidden layers. - compute_components_flag : bool - Flag whether to compute the components of the model or not. metrics : dict Dictionary of torchmetrics to be used during training and for evaluation. id_list : list @@ -143,15 +135,18 @@ def __init__( # General self.config_model = config_model self.config_model.n_forecasts = n_forecasts - self.train_components_stacker = train_components_stacker - self.val_components_stacker = val_components_stacker - self.test_components_stacker = test_components_stacker - self.predict_components_stacker = predict_components_stacker + # Components stackers to unpack the input tensor + self.components_stacker = { + "train": None, + "val": None, + "test": None, + "predict": None, + } # Lightning Config self.config_train = config_train self.config_normalization = config_normalization - self.compute_components_flag = compute_components_flag + self.include_components = False # flag to indicate if we are in include_components mode, set in prediction mode by set_compute_components self.config_model = config_model # Manual optimization: we are responsible for calling .backward(), .step(), .zero_grad(). @@ -317,15 +312,19 @@ def ar_weights(self) -> torch.Tensor: if isinstance(layer, nn.Linear): return layer.weight - def set_components_stacker(self, components_stacker, mode): - if mode == "train": - self.train_components_stacker = components_stacker - if mode == "val": - self.val_components_stacker = components_stacker - if mode == "test": - self.test_components_stacker = components_stacker - if mode == "predict": - self.predict_components_stacker = components_stacker + def set_components_stacker(self, stacker, mode): + """Set the components stacker for the given mode. + Parameters + ---------- + stacker : ComponentStacker + The components stacker to be set. 
+ mode : str + The mode for which the components stacker is to be set + options: ["train", "val", "test", "predict"] + """ + modes = ["train", "val", "test", "predict"] + assert mode in modes, f"mode must be one of {modes}" + self.components_stacker[mode] = stacker def get_covar_weights(self, covar_input=None) -> torch.Tensor: """ @@ -520,14 +519,25 @@ def forward_covar_net(self, covariates): def forward( self, input_tensor: torch.Tensor, - components_stacker=ComponentStacker, + mode: str, meta: Dict = None, - compute_components_flag: bool = False, - predict_mode: bool = False, ) -> torch.Tensor: - """This method defines the model forward pass.""" + """This method defines the model forward pass. + Parameters + ---------- + input_tensor : torch.Tensor + Input tensor of dims (batch, n_lags + n_forecasts, n_features) + mode : str operation mode ["train", "val", "test", "predict"] + meta : dict Static features of the time series + Returns + ------- + torch.Tensor Forecast tensor of dims (batch, n_forecasts, n_quantiles) + dict of components of the model if self.include_components is True, + each of dims (batch, n_forecasts, n_quantiles) + + """ - time_input = components_stacker.unstack_component(component_name="time", batch_tensor=input_tensor) + time_input = self.components_stacker[mode].unstack(component_name="time", batch_tensor=input_tensor) # Handle meta argument if meta is None and self.meta_used_in_model: name_id_dummy = self.id_list[0] @@ -557,7 +567,7 @@ def forward( # Unpack and process seasonalities seasonalities_input = None if self.config_seasonality and self.config_seasonality.periods: - seasonalities_input = components_stacker.unstack_component( + seasonalities_input = self.components_stacker[mode].unstack( component_name="seasonalities", batch_tensor=input_tensor ) s = self.seasonality(s=seasonalities_input, meta=meta) @@ -571,15 +581,15 @@ def forward( additive_events_input = None multiplicative_events_input = None if self.events_dims is not None: - if 
"additive_events" in components_stacker.feature_indices: - additive_events_input = components_stacker.unstack_component( + if "additive_events" in self.components_stacker[mode].feature_indices: + additive_events_input = self.components_stacker[mode].unstack( component_name="additive_events", batch_tensor=input_tensor ) additive_events = self.scalar_features_effects(additive_events_input, self.event_params["additive"]) additive_components_nonstationary += additive_events components["additive_events"] = additive_events - if "multiplicative_events" in components_stacker.feature_indices: - multiplicative_events_input = components_stacker.unstack_component( + if "multiplicative_events" in self.components_stacker[mode].feature_indices: + multiplicative_events_input = self.components_stacker[mode].unstack( component_name="multiplicative_events", batch_tensor=input_tensor ) multiplicative_events = self.scalar_features_effects( @@ -591,15 +601,15 @@ def forward( # Unpack and process regressors additive_regressors_input = None multiplicative_regressors_input = None - if "additive_regressors" in components_stacker.feature_indices: - additive_regressors_input = components_stacker.unstack_component( + if "additive_regressors" in self.components_stacker[mode].feature_indices: + additive_regressors_input = self.components_stacker[mode].unstack( component_name="additive_regressors", batch_tensor=input_tensor ) additive_regressors = self.future_regressors(additive_regressors_input, "additive") additive_components_nonstationary += additive_regressors components["additive_regressors"] = additive_regressors - if "multiplicative_regressors" in components_stacker.feature_indices: - multiplicative_regressors_input = components_stacker.unstack_component( + if "multiplicative_regressors" in self.components_stacker[mode].feature_indices: + multiplicative_regressors_input = self.components_stacker[mode].unstack( component_name="multiplicative_regressors", batch_tensor=input_tensor ) 
multiplicative_regressors = self.future_regressors(multiplicative_regressors_input, "multiplicative") @@ -608,8 +618,8 @@ def forward( # Unpack and process lags lags_input = None - if "lags" in components_stacker.feature_indices: - lags_input = components_stacker.unstack_component(component_name="lags", batch_tensor=input_tensor) + if "lags" in self.components_stacker[mode].feature_indices: + lags_input = self.components_stacker[mode].unstack(component_name="lags", batch_tensor=input_tensor) nonstationary_components = ( trend[:, : self.n_lags, 0] + additive_components_nonstationary[:, : self.n_lags, 0] @@ -623,7 +633,7 @@ def forward( # Unpack and process covariates covariates_input = None if self.config_lagged_regressors and self.config_lagged_regressors.regressors is not None: - covariates_input = components_stacker.unstack_component( + covariates_input = self.components_stacker[mode].unstack( component_name="lagged_regressors", batch_tensor=input_tensor ) covariates = self.forward_covar_net(covariates=covariates_input) @@ -640,10 +650,12 @@ def forward( prediction = predictions_nonstationary + additive_components # Correct crossing quantiles - prediction_with_quantiles = self._compute_quantile_forecasts_from_diffs(prediction, predict_mode) + prediction_with_quantiles = self._compute_quantile_forecasts_from_diffs( + diffs=prediction, predict_mode=mode != "train" + ) # Compute components if required - if compute_components_flag: + if self.include_components: components = self.compute_components( time_input, seasonalities_input, @@ -750,8 +762,12 @@ def compute_components( ) return components - def set_compute_components(self, compute_components_flag): - self.compute_components_flag = compute_components_flag + def set_compute_components(self, include_components): + self.prev_include_components = self.include_components + self.include_components = include_components + + def reset_compute_components(self): + self.include_components = self.prev_include_components def 
loss_func(self, time, predicted, targets): loss = None @@ -773,15 +789,15 @@ def training_step(self, batch, batch_idx): epoch_float = self.trainer.current_epoch + batch_idx / float(self.train_steps_per_epoch) self.train_progress = epoch_float / float(self.config_train.epochs) - targets = self.train_components_stacker.unstack_component("targets", batch_tensor=inputs_tensor) - time = self.train_components_stacker.unstack_component("time", batch_tensor=inputs_tensor) + targets = self.components_stacker["train"].unstack("targets", batch_tensor=inputs_tensor) + time = self.components_stacker["train"].unstack("time", batch_tensor=inputs_tensor) # Global-local if self.meta_used_in_model: meta_name_tensor = torch.tensor([self.id_dict[i] for i in meta["df_name"]], device=self.device) else: meta_name_tensor = None # Run forward calculation - predicted, _ = self.forward(inputs_tensor, self.train_components_stacker, meta_name_tensor) + predicted, _ = self.forward(inputs_tensor, mode="train", meta=meta_name_tensor) # Store predictions in self for later network visualization self.train_epoch_prediction = predicted # Calculate loss @@ -818,15 +834,15 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): inputs_tensor, meta = batch - targets = self.val_components_stacker.unstack_component("targets", batch_tensor=inputs_tensor) - time = self.val_components_stacker.unstack_component("time", batch_tensor=inputs_tensor) + targets = self.components_stacker["val"].unstack("targets", batch_tensor=inputs_tensor) + time = self.components_stacker["val"].unstack("time", batch_tensor=inputs_tensor) # Global-local if self.meta_used_in_model: meta_name_tensor = torch.tensor([self.id_dict[i] for i in meta["df_name"]], device=self.device) else: meta_name_tensor = None # Run forward calculation - predicted, _ = self.forward(inputs_tensor, self.val_components_stacker, meta_name_tensor) + predicted, _ = self.forward(inputs_tensor, mode="val", 
meta=meta_name_tensor) # Calculate loss loss, reg_loss = self.loss_func(time, predicted, targets) # Metrics @@ -840,15 +856,15 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): inputs_tensor, meta = batch - targets = self.test_components_stacker.unstack_component("targets", batch_tensor=inputs_tensor) - time = self.test_components_stacker.unstack_component("time", batch_tensor=inputs_tensor) + targets = self.components_stacker["test"].unstack("targets", batch_tensor=inputs_tensor) + time = self.components_stacker["test"].unstack("time", batch_tensor=inputs_tensor) # Global-local if self.meta_used_in_model: meta_name_tensor = torch.tensor([self.id_dict[i] for i in meta["df_name"]], device=self.device) else: meta_name_tensor = None # Run forward calculation - predicted, _ = self.forward(inputs_tensor, self.test_components_stacker, meta_name_tensor) + predicted, _ = self.forward(inputs_tensor, mode="test", meta=meta_name_tensor) # Calculate loss loss, reg_loss = self.loss_func(time, predicted, targets) # Metrics @@ -872,10 +888,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): # Run forward calculation prediction, components = self.forward( inputs_tensor, - self.predict_components_stacker, - meta_name_tensor, - self.compute_components_flag, - predict_mode=True, + mode="predict", + meta=meta_name_tensor, ) return prediction, components diff --git a/neuralprophet/utils_time_dataset.py b/neuralprophet/utils_time_dataset.py index de09b4d9b..4aa756f44 100644 --- a/neuralprophet/utils_time_dataset.py +++ b/neuralprophet/utils_time_dataset.py @@ -1,66 +1,116 @@ from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Optional import torch +from neuralprophet.configure_components import LaggedRegressors, Seasonalities + +@dataclass class ComponentStacker: - def __init__( - self, - n_lags, - n_forecasts, - max_lags, - feature_indices={}, - config_seasonality=None, - 
lagged_regressor_config=None, - ): + """ + ComponentStacker is a utility class that helps in stacking and unstacking the different components of the time series data. + Args: + n_lags (int): Number of lags used in the model. + n_forecasts (int): Number of forecasts to be made. + max_lags (int): Maximum number of lags used in the model. + feature_indices (dict): A dictionary containing the start and end indices of different features in the tensor. + config_seasonality (object, optional): Configuration object that defines the seasonality periods. + lagged_regressor_config (dict, optional): Configuration dictionary that defines the lagged regressors and their properties. + """ + + n_lags: int + n_forecasts: int + max_lags: int + feature_indices: dict = field(default_factory=dict) + config_seasonality: Optional[Seasonalities] = None + lagged_regressor_config: Optional[LaggedRegressors] = None + stack_func: dict = field(init=False) + unstack_func: dict = field(init=False) + + def __post_init__(self): + """ + Initializes mappings to component stacking and unstacking functions. """ - Initializes the ComponentStacker with the necessary parameters. 
+ self.stack_func = { + "targets": self.stack_targets, + "time": self.stack_time, + "seasonalities": self.stack_seasonalities, + "lagged_regressors": self.stack_lagged_regressors, + "lags": self.stack_lags, + "additive_events": self.stack_additive_events, + "multiplicative_events": self.stack_multiplicative_events, + "additive_regressors": self.stack_additive_regressors, + "multiplicative_regressors": self.stack_multiplicative_regressors, + } + self.unstack_func = { + "targets": self.unstack_targets, + "time": self.unstack_time, + "seasonalities": self.unstack_seasonalities, + "lagged_regressors": self.unstack_lagged_regressors, + "lags": self.unstack_lags, + "additive_events": self.unstack_additive_events, + "multiplicative_events": self.unstack_multiplicative_events, + "additive_regressors": self.unstack_additive_regressors, + "multiplicative_regressors": self.unstack_multiplicative_regressors, + } + + def unstack(self, component_name, batch_tensor): + """ + Routes the unstacking process to the appropriate function based on the component name. Args: - n_lags (int): Number of lags used in the model. - n_forecasts (int): Number of forecasts to be made. - max_lags (int): Maximum number of lags used in the model. - feature_indices (dict): A dictionary containing the start and end indices of different features in the tensor. - config_seasonality (object, optional): Configuration object that defines the seasonality periods. - lagged_regressor_config (dict, optional): Configuration dictionary that defines the lagged regressors and their properties. + component_name (str): The name of the component to unstack. + + Returns: + Various: The output of the specific unstacking function. 
""" - self.n_lags = n_lags - self.n_forecasts = n_forecasts - self.max_lags = max_lags - self.feature_indices = feature_indices - self.config_seasonality = config_seasonality - self.lagged_regressor_config = lagged_regressor_config - - def unstack_component(self, component_name, batch_tensor): + assert component_name in self.unstack_func, f"Unknown component name: {component_name}" + return self.unstack_func[component_name](batch_tensor) + + def stack(self, component_name, df_tensors, feature_list, current_idx, **kwargs): """ Routes the unstackion process to the appropriate function based on the component name. Args: - component_name (str): The name of the component to unstack. + component_name (str): The name of the component to stack. + df_tensors + feature_list + current_idx + kwargs for specific component, mostly component configuration Returns: - Various: The output of the specific unstackion function. + current_idx: the current index in the stack of features. """ - if component_name == "targets": - return self.unstack_targets(batch_tensor) - elif component_name == "time": - return self.unstack_time(batch_tensor) - elif component_name == "seasonalities": - return self.unstack_seasonalities(batch_tensor) - elif component_name == "lagged_regressors": - return self.unstack_lagged_regressors(batch_tensor) - elif component_name == "lags": - return self.unstack_lags(batch_tensor) - elif component_name == "additive_events": - return self.unstack_additive_events(batch_tensor) - elif component_name == "multiplicative_events": - return self.unstack_multiplicative_events(batch_tensor) - elif component_name == "additive_regressors": - return self.unstack_additive_regressors(batch_tensor) - elif component_name == "multiplicative_regressors": - return self.unstack_multiplicative_regressors(batch_tensor) - else: - raise ValueError(f"Unknown component name: {component_name}") + assert component_name in self.stack_func, f"Unknown component name: {component_name}" + return 
self.stack_func[component_name]( + df_tensors=df_tensors, feature_list=feature_list, current_idx=current_idx, **kwargs + ) + + def stack_all_features(self, df_tensors, component_args): + """ + Stack all features into one large tensor by calling individual stacking methods. + Concatenation along the second dimension (dim=1). + + Args: + df_tensors: Dictionary containing the tensors for different features. + component_args: Dictionary containing the configuration of different components. + """ + feature_list = [] + current_idx = 0 + + for component_name, args in component_args.items(): + feature_list, current_idx = self.stack( + component_name=component_name, + df_tensors=df_tensors, + feature_list=feature_list, + current_idx=current_idx, + **args, + ) + + # Concatenate all features into one big tensor + return torch.cat(feature_list, dim=1) # Concatenating along the second dimension def unstack_targets(self, batch_tensor): targets_start_idx, targets_end_idx = self.feature_indices["targets"] @@ -166,16 +216,16 @@ def unstack_multiplicative_regressors(self, batch_tensor): regressors_start_idx, regressors_end_idx = self.feature_indices["multiplicative_regressors"] return batch_tensor[:, regressors_start_idx : regressors_end_idx + 1].unsqueeze(1) - def stack_trend_component(self, df_tensors, feature_list, current_idx): + def stack_time(self, df_tensors, feature_list, current_idx): """ Stack the trend (time) feature. """ time_tensor = df_tensors["t"].unsqueeze(-1) # Shape: [T, 1] feature_list.append(time_tensor) self.feature_indices["time"] = (current_idx, current_idx) - return current_idx + 1 + return feature_list, current_idx + 1 - def stack_lags_component(self, df_tensors, feature_list, current_idx, n_lags): + def stack_lags(self, df_tensors, feature_list, current_idx, n_lags): """ Stack the lags feature. 
""" @@ -183,10 +233,10 @@ def stack_lags_component(self, df_tensors, feature_list, current_idx, n_lags): lags_tensor = df_tensors["y_scaled"].unsqueeze(-1) feature_list.append(lags_tensor) self.feature_indices["lags"] = (current_idx, current_idx) - return current_idx + 1 - return current_idx + current_idx = current_idx + 1 + return feature_list, current_idx - def stack_targets_component(self, df_tensors, feature_list, current_idx): + def stack_targets(self, df_tensors, feature_list, current_idx): """ Stack the targets feature. """ @@ -194,41 +244,33 @@ def stack_targets_component(self, df_tensors, feature_list, current_idx): targets_tensor = df_tensors["y_scaled"].unsqueeze(-1) feature_list.append(targets_tensor) self.feature_indices["targets"] = (current_idx, current_idx) - return current_idx + 1 - return current_idx + current_idx = current_idx + 1 + return feature_list, current_idx - def stack_lagged_regerssors_component(self, df_tensors, feature_list, current_idx, config_lagged_regressors): + def stack_lagged_regressors(self, df_tensors, feature_list, current_idx, config): """ Stack the lagged regressor features. 
""" - if config_lagged_regressors is not None and config_lagged_regressors.regressors is not None: - lagged_regressor_tensors = [ - df_tensors[name].unsqueeze(-1) for name in config_lagged_regressors.regressors.keys() - ] + if config is not None and config.regressors is not None: + lagged_regressor_tensors = [df_tensors[name].unsqueeze(-1) for name in config.regressors.keys()] stacked_lagged_regressor_tensor = torch.cat(lagged_regressor_tensors, dim=-1) feature_list.append(stacked_lagged_regressor_tensor) num_features = stacked_lagged_regressor_tensor.size(-1) - for i, name in enumerate(config_lagged_regressors.regressors.keys()): + for i, name in enumerate(config.regressors.keys()): self.feature_indices[f"lagged_regressor_{name}"] = ( current_idx + i, current_idx + i + 1, ) - return current_idx + num_features - return current_idx - - def stack_additive_events_component( - self, - df_tensors, - feature_list, - current_idx, - additive_event_and_holiday_names, - ): + current_idx = current_idx + num_features + return feature_list, current_idx + + def stack_additive_events(self, df_tensors, feature_list, current_idx, names): """ Stack the additive event and holiday features. 
""" - if additive_event_and_holiday_names: + if names: additive_events_tensor = torch.cat( - [df_tensors[name].unsqueeze(-1) for name in additive_event_and_holiday_names], + [df_tensors[name].unsqueeze(-1) for name in names], dim=1, ) feature_list.append(additive_events_tensor) @@ -236,67 +278,60 @@ def stack_additive_events_component( current_idx, current_idx + additive_events_tensor.size(1) - 1, ) - return current_idx + additive_events_tensor.size(1) - return current_idx + current_idx = current_idx + additive_events_tensor.size(1) + return feature_list, current_idx - def stack_multiplicative_events_component( - self, df_tensors, feature_list, current_idx, multiplicative_event_and_holiday_names - ): + def stack_multiplicative_events(self, df_tensors, feature_list, current_idx, names): """ Stack the multiplicative event and holiday features. """ - if multiplicative_event_and_holiday_names: - multiplicative_events_tensor = torch.cat( - [df_tensors[name].unsqueeze(-1) for name in multiplicative_event_and_holiday_names], dim=1 - ) + if names: + multiplicative_events_tensor = torch.cat([df_tensors[name].unsqueeze(-1) for name in names], dim=1) feature_list.append(multiplicative_events_tensor) self.feature_indices["multiplicative_events"] = ( current_idx, current_idx + multiplicative_events_tensor.size(1) - 1, ) - return current_idx + multiplicative_events_tensor.size(1) - return current_idx + current_idx = current_idx + multiplicative_events_tensor.size(1) + return feature_list, current_idx - def stack_additive_regressors_component(self, df_tensors, feature_list, current_idx, additive_regressors_names): + def stack_additive_regressors(self, df_tensors, feature_list, current_idx, names): """ Stack the additive regressor features. 
""" - if additive_regressors_names: - additive_regressors_tensor = torch.cat( - [df_tensors[name].unsqueeze(-1) for name in additive_regressors_names], dim=1 - ) + if names: + additive_regressors_tensor = torch.cat([df_tensors[name].unsqueeze(-1) for name in names], dim=1) feature_list.append(additive_regressors_tensor) self.feature_indices["additive_regressors"] = ( current_idx, current_idx + additive_regressors_tensor.size(1) - 1, ) - return current_idx + additive_regressors_tensor.size(1) - return current_idx + current_idx = current_idx + additive_regressors_tensor.size(1) + return feature_list, current_idx - def stack_multiplicative_regressors_component( - self, df_tensors, feature_list, current_idx, multiplicative_regressors_names - ): + def stack_multiplicative_regressors(self, df_tensors, feature_list, current_idx, names): """ Stack the multiplicative regressor features. """ - if multiplicative_regressors_names: + if names: multiplicative_regressors_tensor = torch.cat( - [df_tensors[name].unsqueeze(-1) for name in multiplicative_regressors_names], dim=1 + [df_tensors[name].unsqueeze(-1) for name in names], dim=1 ) # Shape: [batch_size, num_multiplicative_regressors, 1] feature_list.append(multiplicative_regressors_tensor) self.feature_indices["multiplicative_regressors"] = ( current_idx, - current_idx + len(multiplicative_regressors_names) - 1, + current_idx + len(names) - 1, ) - return current_idx + len(multiplicative_regressors_names) - return current_idx + current_idx = current_idx + len(names) + return feature_list, current_idx - def stack_seasonalities_component(self, feature_list, current_idx, config_seasonality, seasonalities): + def stack_seasonalities(self, df_tensors, feature_list, current_idx, config): """ Stack the seasonality features. 
""" - if config_seasonality and config_seasonality.periods: - for seasonality_name, features in seasonalities.items(): + # if config is not None and hasattr(config, "periods"): + if config and config.periods: + for seasonality_name, features in df_tensors["seasonalities"].items(): seasonal_tensor = features feature_list.append(seasonal_tensor) self.feature_indices[f"seasonality_{seasonality_name}"] = ( @@ -304,4 +339,4 @@ def stack_seasonalities_component(self, feature_list, current_idx, config_season current_idx + seasonal_tensor.size(1), ) current_idx += seasonal_tensor.size(1) - return current_idx + return feature_list, current_idx