Skip to content

Commit

Permalink
Merge branch 'main' into upgrade-lightning-2
Browse files Browse the repository at this point in the history
  • Loading branch information
ourownstory authored Jun 21, 2024
2 parents 891ba53 + f1a3820 commit 0be90e7
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 12 deletions.
8 changes: 2 additions & 6 deletions neuralprophet/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,15 @@ def __post_init__(self):
log.error("Invalid growth for global_local mode '{}'. Set to 'global'".format(self.trend_global_local))
self.trend_global_local = "global"

# If trend_local_reg < 0
if self.trend_local_reg < 0:

Check failure on line 308 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator "<" not supported for "None" (reportOptionalOperand)
log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg))
self.trend_local_reg = False

# If trend_local_reg = True
if self.trend_local_reg is True:
log.error("trend_local_reg = True. Default trend_local_reg value set to 1")
self.trend_local_reg = 1

# If Trend modelling is global.
# If Trend modelling is global but local regularization is set.
if self.trend_global_local == "global" and self.trend_local_reg:
log.error("Trend modeling is '{}'. Setting the trend_local_reg to False".format(self.trend_global_local))
self.trend_local_reg = False
Expand Down Expand Up @@ -392,17 +390,15 @@ def __post_init__(self):
}
)

# If seasonality_local_reg < 0
if self.seasonality_local_reg < 0:

Check failure on line 393 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator "<" not supported for "None" (reportOptionalOperand)
log.error("Invalid negative seasonality_local_reg '{}'. Set to False".format(self.seasonality_local_reg))
self.seasonality_local_reg = False

# If seasonality_local_reg = True
if self.seasonality_local_reg is True:
log.error("seasonality_local_reg = True. Default seasonality_local_reg value set to 1")
self.seasonality_local_reg = 1

# If Season modelling is global.
# If Season modelling is global but local regularization is set.
if self.global_local == "global" and self.seasonality_local_reg:
log.error(
"Seasonality modeling is '{}'. Setting the seasonality_local_reg to False".format(self.global_local)
Expand Down
9 changes: 5 additions & 4 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ def fit(
self.fitted = True
return metrics_df

def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False):
def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, auto_extend=True):
"""Runs the model to make predictions.
Expects all data needed to be present in dataframe.
Expand Down Expand Up @@ -1177,7 +1177,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False):
quantiles=self.config_train.quantiles,
components=components,
)
if periods_added[df_name] > 0:
if auto_extend and periods_added[df_name] > 0:
fcst = fcst[:-1]
else:
fcst = _reshape_raw_predictions_to_forecst_df(
Expand All @@ -1192,9 +1192,10 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False):
quantiles=self.config_train.quantiles,
config_lagged_regressors=self.config_lagged_regressors,
)
if periods_added[df_name] > 0:
fcst = fcst[: -periods_added[df_name]]
if auto_extend and periods_added[df_name] > 0:
fcst = fcst[:-periods_added[df_name]]
forecast = pd.concat((forecast, fcst), ignore_index=True)

df = df_utils.return_df_in_original_format(forecast, received_ID_col, received_single_time_series)
self.predict_steps = self.n_forecasts
return df
Expand Down
1 change: 0 additions & 1 deletion neuralprophet/hdays_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def get_country_holidays(

return holiday_obj


def get_holidays_from_country(country: Union[str, Iterable[str], dict], df=None):

Check failure on line 44 in neuralprophet/hdays_utils.py

View workflow job for this annotation

GitHub Actions / flake8

expected 2 blank lines, found 1
"""
Return all possible holiday names of given countries
Expand Down
2 changes: 1 addition & 1 deletion neuralprophet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def config_seasonality_to_model_dims(config_seasonality: ConfigSeasonality):
seasonal_dims[name] = resolution
return seasonal_dims


Check warning on line 377 in neuralprophet/utils.py

View workflow job for this annotation

GitHub Actions / flake8

blank line contains whitespace
def config_events_to_model_dims(config_events: Optional[ConfigEvents], config_country_holidays):
"""
Convert user specified events configurations along with country specific
Expand Down

0 comments on commit 0be90e7

Please sign in to comment.