Skip to content

Commit

Permalink
[Minor] fix future regressor (#1585)
Browse files Browse the repository at this point in the history
* add io buffer support

* added tests

* remove typealias

* add subdivions to country

* removed  perods_added logic

* added auto_extend

* fix linting issues

* reversed auto_extended

* removed subdisions

* fix isTrue

* update lock

* remove subdivision

* uncomment test_glocal lines 180 onward

---------

Co-authored-by: Maisa Ben Salah <[email protected]>
Co-authored-by: Oskar Triebe <[email protected]>
  • Loading branch information
3 people authored Jun 21, 2024
1 parent 9ef1aef commit f1a3820
Show file tree
Hide file tree
Showing 11 changed files with 61 additions and 65 deletions.
1 change: 0 additions & 1 deletion neuralprophet/components/future_regressors/linear.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn

from neuralprophet import utils
from neuralprophet.components.future_regressors import FutureRegressors
from neuralprophet.utils_torch import init_parameter

Expand Down
2 changes: 1 addition & 1 deletion neuralprophet/components/future_regressors/neural_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn as nn

from neuralprophet.components.future_regressors import FutureRegressors
from neuralprophet.utils_torch import init_parameter, interprete_model
from neuralprophet.utils_torch import interprete_model

# from neuralprophet.utils_torch import init_parameter

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from collections import Counter, OrderedDict
from collections import Counter

import torch
import torch.nn as nn

from neuralprophet.components.future_regressors import FutureRegressors
from neuralprophet.utils_torch import init_parameter, interprete_model
from neuralprophet.utils_torch import interprete_model

# from neuralprophet.utils_torch import init_parameter

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from collections import Counter, OrderedDict
from collections import Counter

import torch
import torch.nn as nn

from neuralprophet.components.future_regressors import FutureRegressors
from neuralprophet.utils_torch import init_parameter, interprete_model
from neuralprophet.utils_torch import interprete_model

# from neuralprophet.utils_torch import init_parameter

Expand Down
16 changes: 6 additions & 10 deletions neuralprophet/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,18 +305,16 @@ def __post_init__(self):
log.error("Invalid growth for global_local mode '{}'. Set to 'global'".format(self.trend_global_local))
self.trend_global_local = "global"

# If trend_local_reg < 0
if self.trend_local_reg < 0:

Check failure on line 308 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator "<" not supported for "None" (reportOptionalOperand)
log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg))
self.trend_local_reg = False

# If trend_local_reg = True
if self.trend_local_reg == True:
if self.trend_local_reg is True:
log.error("trend_local_reg = True. Default trend_local_reg value set to 1")
self.trend_local_reg = 1

# If Trend modelling is global.
if self.trend_global_local == "global" and self.trend_local_reg != False:
# If Trend modelling is global but local regularization is set.
if self.trend_global_local == "global" and self.trend_local_reg:
log.error("Trend modeling is '{}'. Setting the trend_local_reg to False".format(self.trend_global_local))
self.trend_local_reg = False

Expand Down Expand Up @@ -392,18 +390,16 @@ def __post_init__(self):
}
)

# If seasonality_local_reg < 0
if self.seasonality_local_reg < 0:

Check failure on line 393 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator "<" not supported for "None" (reportOptionalOperand)
log.error("Invalid negative seasonality_local_reg '{}'. Set to False".format(self.seasonality_local_reg))
self.seasonality_local_reg = False

# If seasonality_local_reg = True
if self.seasonality_local_reg == True:
if self.seasonality_local_reg is True:
log.error("seasonality_local_reg = True. Default seasonality_local_reg value set to 1")
self.seasonality_local_reg = 1

# If Season modelling is global.
if self.global_local == "global" and self.seasonality_local_reg != False:
# If Season modelling is global but local regularization is set.
if self.global_local == "global" and self.seasonality_local_reg:
log.error(
"Seasonality modeling is '{}'. Setting the seasonality_local_reg to False".format(self.global_local)
)
Expand Down
9 changes: 5 additions & 4 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,7 @@ def fit(
self.fitted = True
return metrics_df

def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False):
def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, auto_extend=True):
"""Runs the model to make predictions.
Expects all data needed to be present in dataframe.
Expand Down Expand Up @@ -1176,7 +1176,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False):
quantiles=self.config_train.quantiles,
components=components,
)
if periods_added[df_name] > 0:
if auto_extend and periods_added[df_name] > 0:
fcst = fcst[:-1]
else:
fcst = _reshape_raw_predictions_to_forecst_df(
Expand All @@ -1191,9 +1191,10 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False):
quantiles=self.config_train.quantiles,
config_lagged_regressors=self.config_lagged_regressors,
)
if periods_added[df_name] > 0:
fcst = fcst[: -periods_added[df_name]]
if auto_extend and periods_added[df_name] > 0:
fcst = fcst[:-periods_added[df_name]]
forecast = pd.concat((forecast, fcst), ignore_index=True)

df = df_utils.return_df_in_original_format(forecast, received_ID_col, received_single_time_series)
self.predict_steps = self.n_forecasts
return df
Expand Down
1 change: 0 additions & 1 deletion neuralprophet/hdays_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def get_country_holidays(

return holiday_obj


def get_holidays_from_country(country: Union[str, Iterable[str], dict], df=None):

Check failure on line 44 in neuralprophet/hdays_utils.py

View workflow job for this annotation

GitHub Actions / flake8

expected 2 blank lines, found 1
"""
Return all possible holiday names of given countries
Expand Down
4 changes: 2 additions & 2 deletions neuralprophet/time_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,7 @@ def _add_batch_regularizations(self, loss, epoch, progress):
trend_glocal_loss = torch.zeros(1, dtype=torch.float, requires_grad=False)
# Glocal Trend
if self.config_trend is not None:
if self.config_trend.trend_global_local == "local" and self.config_trend.trend_local_reg != False:
if self.config_trend.trend_global_local == "local" and self.config_trend.trend_local_reg:
trend_glocal_loss = reg_func_trend_glocal(
self.trend.trend_k0, self.trend.trend_deltas, self.config_trend.trend_local_reg
)
Expand All @@ -949,7 +949,7 @@ def _add_batch_regularizations(self, loss, epoch, progress):
if self.config_seasonality is not None:
if (
self.config_seasonality.global_local in ["local", "glocal"]
and self.config_seasonality.seasonality_local_reg != False
and self.config_seasonality.seasonality_local_reg
):
seasonality_glocal_loss = reg_func_seasonality_glocal(
self.seasonality.season_params, self.config_seasonality.seasonality_local_reg
Expand Down
2 changes: 1 addition & 1 deletion neuralprophet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def config_seasonality_to_model_dims(config_seasonality: ConfigSeasonality):
seasonal_dims[name] = resolution
return seasonal_dims


Check warning on line 378 in neuralprophet/utils.py

View workflow job for this annotation

GitHub Actions / flake8

blank line contains whitespace
def config_events_to_model_dims(config_events: Optional[ConfigEvents], config_country_holidays):
"""
Convert user specified events configurations along with country specific
Expand Down
76 changes: 39 additions & 37 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions tests/test_model_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from neuralprophet import NeuralProphet, set_random_seed

log = logging.getLogger("NP.test")
log.setLevel("DEBUG")
log.parent.setLevel("WARNING")

try:
from plotly_resampler import register_plotly_resampler, unregister_plotly_resampler
from plotly_resampler import unregister_plotly_resampler

plotly_resampler_installed = True
except ImportError:
plotly_resampler_installed = False
log.error("Importing plotly failed. Interactive plots will not work.")

from neuralprophet import NeuralProphet, set_random_seed

DIR = pathlib.Path(__file__).parent.parent.absolute()
DATA_DIR = os.path.join(DIR, "tests", "test-data")
Expand Down

0 comments on commit f1a3820

Please sign in to comment.