Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor] fix future regressor #1585

Merged
merged 18 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion neuralprophet/components/future_regressors/linear.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn

from neuralprophet import utils
from neuralprophet.components.future_regressors import FutureRegressors
from neuralprophet.utils_torch import init_parameter

Expand Down
2 changes: 1 addition & 1 deletion neuralprophet/components/future_regressors/neural_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn as nn

from neuralprophet.components.future_regressors import FutureRegressors
from neuralprophet.utils_torch import init_parameter, interprete_model
from neuralprophet.utils_torch import interprete_model

# from neuralprophet.utils_torch import init_parameter

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from collections import Counter, OrderedDict
from collections import Counter

import torch
import torch.nn as nn

from neuralprophet.components.future_regressors import FutureRegressors
from neuralprophet.utils_torch import init_parameter, interprete_model
from neuralprophet.utils_torch import interprete_model

# from neuralprophet.utils_torch import init_parameter

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from collections import Counter, OrderedDict
from collections import Counter

import torch
import torch.nn as nn

from neuralprophet.components.future_regressors import FutureRegressors
from neuralprophet.utils_torch import init_parameter, interprete_model
from neuralprophet.utils_torch import interprete_model

# from neuralprophet.utils_torch import init_parameter

Expand Down
16 changes: 6 additions & 10 deletions neuralprophet/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import types
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Callable, Iterable, List, Optional

Check failure on line 8 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / flake8

'typing.Iterable' imported but unused
from typing import OrderedDict as OrderedDictType
from typing import Type, Union

Expand All @@ -13,7 +13,7 @@
import pandas as pd
import torch

from neuralprophet import df_utils, np_types, utils, utils_torch

Check failure on line 16 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / flake8

'neuralprophet.utils' imported but unused
from neuralprophet.custom_loss_metrics import PinballLoss
from neuralprophet.hdays_utils import get_holidays_from_country

Expand Down Expand Up @@ -305,18 +305,16 @@
log.error("Invalid growth for global_local mode '{}'. Set to 'global'".format(self.trend_global_local))
self.trend_global_local = "global"

# If trend_local_reg < 0
if self.trend_local_reg < 0:

Check failure on line 308 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator "<" not supported for "None" (reportOptionalOperand)
log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg))
self.trend_local_reg = False

# If trend_local_reg = True
if self.trend_local_reg == True:
if self.trend_local_reg is True:
log.error("trend_local_reg = True. Default trend_local_reg value set to 1")
self.trend_local_reg = 1

# If Trend modelling is global.
if self.trend_global_local == "global" and self.trend_local_reg != False:
# If Trend modelling is global but local regularization is set.
if self.trend_global_local == "global" and self.trend_local_reg:
log.error("Trend modeling is '{}'. Setting the trend_local_reg to False".format(self.trend_global_local))
self.trend_local_reg = False

Expand Down Expand Up @@ -356,13 +354,13 @@
log.error("Invalid global_local mode '{}'. Set to 'global'".format(self.global_local))
self.global_local = "global"

self.periods = OrderedDict(

Check failure on line 357 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

No overloads for "__init__" match the provided arguments (reportCallIssue)
{

Check failure on line 358 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "dict[str, Season]" cannot be assigned to parameter "iterable" of type "Iterable[list[bytes]]" in function "__init__" (reportArgumentType)
"yearly": Season(
resolution=6,
period=365.25,
arg=self.yearly_arg,
global_local=(

Check failure on line 363 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.yearly_global_local
if self.yearly_global_local in ["global", "local"]
else self.global_local
Expand All @@ -373,7 +371,7 @@
resolution=3,
period=7,
arg=self.weekly_arg,
global_local=(

Check failure on line 374 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.weekly_global_local
if self.weekly_global_local in ["global", "local"]
else self.global_local
Expand All @@ -384,7 +382,7 @@
resolution=6,
period=1,
arg=self.daily_arg,
global_local=(

Check failure on line 385 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.daily_global_local if self.daily_global_local in ["global", "local"] else self.global_local
),
condition_name=None,
Expand All @@ -392,18 +390,16 @@
}
)

# If seasonality_local_reg < 0
if self.seasonality_local_reg < 0:

Check failure on line 393 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Operator "<" not supported for "None" (reportOptionalOperand)
log.error("Invalid negative seasonality_local_reg '{}'. Set to False".format(self.seasonality_local_reg))
self.seasonality_local_reg = False

# If seasonality_local_reg = True
if self.seasonality_local_reg == True:
if self.seasonality_local_reg is True:
log.error("seasonality_local_reg = True. Default seasonality_local_reg value set to 1")
self.seasonality_local_reg = 1

# If Season modelling is global.
if self.global_local == "global" and self.seasonality_local_reg != False:
# If Season modelling is global but local regularization is set.
if self.global_local == "global" and self.seasonality_local_reg:
log.error(
"Seasonality modeling is '{}'. Setting the seasonality_local_reg to False".format(self.global_local)
)
Expand All @@ -414,7 +410,7 @@
resolution=resolution,
period=period,
arg=arg,
global_local=global_local if global_local in ["global", "local"] else self.global_local,

Check failure on line 413 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "str" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__"   Type "str" is incompatible with type "SeasonGlobalLocalMode"     "str" is incompatible with type "Literal['global']"     "str" is incompatible with type "Literal['local']"     "str" is incompatible with type "Literal['glocal']" (reportArgumentType)
condition_name=condition_name,
)

Expand Down Expand Up @@ -490,7 +486,7 @@
regressors: OrderedDict = field(init=False) # contains RegressorConfig objects

def __post_init__(self):
self.regressors = None

Check failure on line 489 in neuralprophet/configure.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot assign to attribute "regressors" for class "ConfigFutureRegressors*"   "None" is incompatible with "OrderedDict[Unknown, Unknown]" (reportAttributeAccessIssue)


@dataclass
Expand Down
9 changes: 5 additions & 4 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,12 +1102,12 @@
# Only display the plot if the session is interactive, eg. do not show in github actions since it
# causes an error in the Windows and MacOS environment
if matplotlib.is_interactive():
fig

Check warning on line 1105 in neuralprophet/forecaster.py

View workflow job for this annotation

GitHub Actions / pyright

Expression value is unused (reportUnusedExpression)

self.fitted = True
return metrics_df

def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False):
def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, auto_extend=True):
"""Runs the model to make predictions.

Expects all data needed to be present in dataframe.
Expand Down Expand Up @@ -1176,7 +1176,7 @@
quantiles=self.config_train.quantiles,
components=components,
)
if periods_added[df_name] > 0:
if auto_extend and periods_added[df_name] > 0:
fcst = fcst[:-1]
else:
fcst = _reshape_raw_predictions_to_forecst_df(
Expand All @@ -1191,9 +1191,10 @@
quantiles=self.config_train.quantiles,
config_lagged_regressors=self.config_lagged_regressors,
)
if periods_added[df_name] > 0:
fcst = fcst[: -periods_added[df_name]]
if auto_extend and periods_added[df_name] > 0:
fcst = fcst[:-periods_added[df_name]]
forecast = pd.concat((forecast, fcst), ignore_index=True)

df = df_utils.return_df_in_original_format(forecast, received_ID_col, received_single_time_series)
self.predict_steps = self.n_forecasts
return df
Expand Down
1 change: 0 additions & 1 deletion neuralprophet/hdays_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@

return holiday_obj


def get_holidays_from_country(country: Union[str, Iterable[str], dict], df=None):

Check failure on line 44 in neuralprophet/hdays_utils.py

View workflow job for this annotation

GitHub Actions / flake8

expected 2 blank lines, found 1
"""
Return all possible holiday names of given countries

Expand Down
4 changes: 2 additions & 2 deletions neuralprophet/time_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,7 @@ def _add_batch_regularizations(self, loss, epoch, progress):
trend_glocal_loss = torch.zeros(1, dtype=torch.float, requires_grad=False)
# Glocal Trend
if self.config_trend is not None:
if self.config_trend.trend_global_local == "local" and self.config_trend.trend_local_reg != False:
if self.config_trend.trend_global_local == "local" and self.config_trend.trend_local_reg:
trend_glocal_loss = reg_func_trend_glocal(
self.trend.trend_k0, self.trend.trend_deltas, self.config_trend.trend_local_reg
)
Expand All @@ -949,7 +949,7 @@ def _add_batch_regularizations(self, loss, epoch, progress):
if self.config_seasonality is not None:
if (
self.config_seasonality.global_local in ["local", "glocal"]
and self.config_seasonality.seasonality_local_reg != False
and self.config_seasonality.seasonality_local_reg
):
seasonality_glocal_loss = reg_func_seasonality_glocal(
self.seasonality.season_params, self.config_seasonality.seasonality_local_reg
Expand Down
2 changes: 1 addition & 1 deletion neuralprophet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import sys
from collections import OrderedDict
from typing import IO, TYPE_CHECKING, BinaryIO, Iterable, Optional, Union

Check failure on line 8 in neuralprophet/utils.py

View workflow job for this annotation

GitHub Actions / flake8

'typing.Iterable' imported but unused

import numpy as np
import pandas as pd
Expand All @@ -13,7 +13,7 @@
import torch

from neuralprophet import utils_torch
from neuralprophet.hdays_utils import get_country_holidays

Check failure on line 16 in neuralprophet/utils.py

View workflow job for this annotation

GitHub Actions / flake8

'neuralprophet.hdays_utils.get_country_holidays' imported but unused
from neuralprophet.logger import ProgressBar

if TYPE_CHECKING:
Expand Down Expand Up @@ -375,7 +375,7 @@
seasonal_dims[name] = resolution
return seasonal_dims


Check warning on line 378 in neuralprophet/utils.py

View workflow job for this annotation

GitHub Actions / flake8

blank line contains whitespace
def config_events_to_model_dims(config_events: Optional[ConfigEvents], config_country_holidays):
"""
Convert user specified events configurations along with country specific
Expand Down
76 changes: 39 additions & 37 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions tests/test_model_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from neuralprophet import NeuralProphet, set_random_seed

log = logging.getLogger("NP.test")
log.setLevel("DEBUG")
log.parent.setLevel("WARNING")

try:
from plotly_resampler import register_plotly_resampler, unregister_plotly_resampler
from plotly_resampler import unregister_plotly_resampler

plotly_resampler_installed = True
except ImportError:
plotly_resampler_installed = False
log.error("Importing plotly failed. Interactive plots will not work.")

from neuralprophet import NeuralProphet, set_random_seed

DIR = pathlib.Path(__file__).parent.parent.absolute()
DATA_DIR = os.path.join(DIR, "tests", "test-data")
Expand Down
Loading