diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index 8595287b3..f64315e26 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -305,17 +305,15 @@ def __post_init__(self): log.error("Invalid growth for global_local mode '{}'. Set to 'global'".format(self.trend_global_local)) self.trend_global_local = "global" - # If trend_local_reg < 0 if self.trend_local_reg < 0: log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg)) self.trend_local_reg = False - # If trend_local_reg = True if self.trend_local_reg is True: log.error("trend_local_reg = True. Default trend_local_reg value set to 1") self.trend_local_reg = 1 - # If Trend modelling is global. + # If Trend modelling is global but local regularization is set. if self.trend_global_local == "global" and self.trend_local_reg: log.error("Trend modeling is '{}'. Setting the trend_local_reg to False".format(self.trend_global_local)) self.trend_local_reg = False @@ -392,17 +390,15 @@ def __post_init__(self): } ) - # If seasonality_local_reg < 0 if self.seasonality_local_reg < 0: log.error("Invalid negative seasonality_local_reg '{}'. Set to False".format(self.seasonality_local_reg)) self.seasonality_local_reg = False - # If seasonality_local_reg = True if self.seasonality_local_reg is True: log.error("seasonality_local_reg = True. Default seasonality_local_reg value set to 1") self.seasonality_local_reg = 1 - # If Season modelling is global. + # If Season modelling is global but local regularization is set. if self.global_local == "global" and self.seasonality_local_reg: log.error( "Seasonality modeling is '{}'. Setting the seasonality_local_reg to False".format(self.global_local) diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 5e94131ef..16872d202 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -1108,7 +1108,7 @@ def fit( self.fitted = True return metrics_df - def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False): + def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, auto_extend=True): """Runs the model to make predictions. Expects all data needed to be present in dataframe. @@ -1177,7 +1177,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False): quantiles=self.config_train.quantiles, components=components, ) - if periods_added[df_name] > 0: + if auto_extend and periods_added[df_name] > 0: fcst = fcst[:-1] else: fcst = _reshape_raw_predictions_to_forecst_df( @@ -1192,9 +1192,10 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False): quantiles=self.config_train.quantiles, config_lagged_regressors=self.config_lagged_regressors, ) - if periods_added[df_name] > 0: - fcst = fcst[: -periods_added[df_name]] + if auto_extend and periods_added[df_name] > 0: + fcst = fcst[:-periods_added[df_name]] forecast = pd.concat((forecast, fcst), ignore_index=True) + df = df_utils.return_df_in_original_format(forecast, received_ID_col, received_single_time_series) self.predict_steps = self.n_forecasts return df diff --git a/neuralprophet/hdays_utils.py b/neuralprophet/hdays_utils.py index 46dc61570..6303696fb 100644 --- a/neuralprophet/hdays_utils.py +++ b/neuralprophet/hdays_utils.py @@ -41,7 +41,6 @@ def get_country_holidays( return holiday_obj - def get_holidays_from_country(country: Union[str, Iterable[str], dict], df=None): """ Return all possible holiday names of given countries diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py index 33f7c51e6..e7f86ead0 100644 --- a/neuralprophet/utils.py +++ b/neuralprophet/utils.py @@ -374,7 +374,7 @@ def config_seasonality_to_model_dims(config_seasonality: ConfigSeasonality): seasonal_dims[name] = resolution return seasonal_dims - + def config_events_to_model_dims(config_events: Optional[ConfigEvents], config_country_holidays): """ Convert user specified events configurations along with country specific