From f1a382029b1fab279874df449d9c37d85f271171 Mon Sep 17 00:00:00 2001 From: Maisa Ben Salah <76703998+MaiBe-ctrl@users.noreply.github.com> Date: Thu, 20 Jun 2024 22:47:42 -0700 Subject: [PATCH] [Minor] fix future regressor (#1585) * add io buffer support * added tests * remove typealias * add subdivions to country * removed perods_added logic * added auto_extend * fix linting issues * reversed auto_extended * removed subdisions * fix isTrue * update lock * remove subdivision * uncomment test_glocal lines 180 onward --------- Co-authored-by: Maisa Ben Salah Co-authored-by: Oskar Triebe --- .../components/future_regressors/linear.py | 1 - .../future_regressors/neural_nets.py | 2 +- .../future_regressors/shared_neural_nets.py | 5 +- .../shared_neural_nets_coef.py | 5 +- neuralprophet/configure.py | 16 ++-- neuralprophet/forecaster.py | 9 ++- neuralprophet/hdays_utils.py | 1 - neuralprophet/time_net.py | 4 +- neuralprophet/utils.py | 2 +- poetry.lock | 76 ++++++++++--------- tests/test_model_performance.py | 5 +- 11 files changed, 61 insertions(+), 65 deletions(-) diff --git a/neuralprophet/components/future_regressors/linear.py b/neuralprophet/components/future_regressors/linear.py index dbf9dd0ff..bb7051372 100644 --- a/neuralprophet/components/future_regressors/linear.py +++ b/neuralprophet/components/future_regressors/linear.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn -from neuralprophet import utils from neuralprophet.components.future_regressors import FutureRegressors from neuralprophet.utils_torch import init_parameter diff --git a/neuralprophet/components/future_regressors/neural_nets.py b/neuralprophet/components/future_regressors/neural_nets.py index 8ed580aa3..00d83d506 100644 --- a/neuralprophet/components/future_regressors/neural_nets.py +++ b/neuralprophet/components/future_regressors/neural_nets.py @@ -3,7 +3,7 @@ import torch.nn as nn from neuralprophet.components.future_regressors import FutureRegressors -from neuralprophet.utils_torch import init_parameter, interprete_model +from neuralprophet.utils_torch import interprete_model # from neuralprophet.utils_torch import init_parameter diff --git a/neuralprophet/components/future_regressors/shared_neural_nets.py b/neuralprophet/components/future_regressors/shared_neural_nets.py index eae2f96d7..85ed9ceb4 100644 --- a/neuralprophet/components/future_regressors/shared_neural_nets.py +++ b/neuralprophet/components/future_regressors/shared_neural_nets.py @@ -1,10 +1,9 @@ -from collections import Counter, OrderedDict +from collections import Counter -import torch import torch.nn as nn from neuralprophet.components.future_regressors import FutureRegressors -from neuralprophet.utils_torch import init_parameter, interprete_model +from neuralprophet.utils_torch import interprete_model # from neuralprophet.utils_torch import init_parameter diff --git a/neuralprophet/components/future_regressors/shared_neural_nets_coef.py b/neuralprophet/components/future_regressors/shared_neural_nets_coef.py index c5cbe9107..db53c2bc8 100644 --- a/neuralprophet/components/future_regressors/shared_neural_nets_coef.py +++ b/neuralprophet/components/future_regressors/shared_neural_nets_coef.py @@ -1,10 +1,9 @@ -from collections import Counter, OrderedDict +from collections import Counter -import torch import torch.nn as nn from neuralprophet.components.future_regressors import FutureRegressors -from neuralprophet.utils_torch import init_parameter, interprete_model +from neuralprophet.utils_torch import interprete_model # from neuralprophet.utils_torch import init_parameter diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index 3c2b8c3e4..784e00194 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -305,18 +305,16 @@ def __post_init__(self): log.error("Invalid growth for global_local mode '{}'. Set to 'global'".format(self.trend_global_local)) self.trend_global_local = "global" - # If trend_local_reg < 0 if self.trend_local_reg < 0: log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg)) self.trend_local_reg = False - # If trend_local_reg = True - if self.trend_local_reg == True: + if self.trend_local_reg is True: log.error("trend_local_reg = True. Default trend_local_reg value set to 1") self.trend_local_reg = 1 - # If Trend modelling is global. - if self.trend_global_local == "global" and self.trend_local_reg != False: + # If Trend modelling is global but local regularization is set. + if self.trend_global_local == "global" and self.trend_local_reg: log.error("Trend modeling is '{}'. Setting the trend_local_reg to False".format(self.trend_global_local)) self.trend_local_reg = False @@ -392,18 +390,16 @@ def __post_init__(self): } ) - # If seasonality_local_reg < 0 if self.seasonality_local_reg < 0: log.error("Invalid negative seasonality_local_reg '{}'. Set to False".format(self.seasonality_local_reg)) self.seasonality_local_reg = False - # If seasonality_local_reg = True - if self.seasonality_local_reg == True: + if self.seasonality_local_reg is True: log.error("seasonality_local_reg = True. Default seasonality_local_reg value set to 1") self.seasonality_local_reg = 1 - # If Season modelling is global. - if self.global_local == "global" and self.seasonality_local_reg != False: + # If Season modelling is global but local regularization is set. + if self.global_local == "global" and self.seasonality_local_reg: log.error( "Seasonality modeling is '{}'. Setting the seasonality_local_reg to False".format(self.global_local) ) diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 76371cc06..43f730948 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -1107,7 +1107,7 @@ def fit( self.fitted = True return metrics_df - def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False): + def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, auto_extend=True): """Runs the model to make predictions. Expects all data needed to be present in dataframe. @@ -1176,7 +1176,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False): quantiles=self.config_train.quantiles, components=components, ) - if periods_added[df_name] > 0: + if auto_extend and periods_added[df_name] > 0: fcst = fcst[:-1] else: fcst = _reshape_raw_predictions_to_forecst_df( @@ -1191,9 +1191,10 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False): quantiles=self.config_train.quantiles, config_lagged_regressors=self.config_lagged_regressors, ) - if periods_added[df_name] > 0: - fcst = fcst[: -periods_added[df_name]] + if auto_extend and periods_added[df_name] > 0: + fcst = fcst[:-periods_added[df_name]] forecast = pd.concat((forecast, fcst), ignore_index=True) + df = df_utils.return_df_in_original_format(forecast, received_ID_col, received_single_time_series) self.predict_steps = self.n_forecasts return df diff --git a/neuralprophet/hdays_utils.py b/neuralprophet/hdays_utils.py index 46dc61570..6303696fb 100644 --- a/neuralprophet/hdays_utils.py +++ b/neuralprophet/hdays_utils.py @@ -41,7 +41,6 @@ def get_country_holidays( return holiday_obj - def get_holidays_from_country(country: Union[str, Iterable[str], dict], df=None): """ Return all possible holiday names of given countries diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index 5a58db9b9..f2b7388eb 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -940,7 +940,7 @@ def _add_batch_regularizations(self, loss, epoch, progress): trend_glocal_loss = torch.zeros(1, dtype=torch.float, requires_grad=False) # Glocal Trend if self.config_trend is not None: - if self.config_trend.trend_global_local == "local" and self.config_trend.trend_local_reg != False: + if self.config_trend.trend_global_local == "local" and self.config_trend.trend_local_reg: trend_glocal_loss = reg_func_trend_glocal( self.trend.trend_k0, self.trend.trend_deltas, self.config_trend.trend_local_reg ) @@ -949,7 +949,7 @@ def _add_batch_regularizations(self, loss, epoch, progress): if self.config_seasonality is not None: if ( self.config_seasonality.global_local in ["local", "glocal"] - and self.config_seasonality.seasonality_local_reg != False + and self.config_seasonality.seasonality_local_reg ): seasonality_glocal_loss = reg_func_seasonality_glocal( self.seasonality.season_params, self.config_seasonality.seasonality_local_reg diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py index afa16ecb8..fb3f016e5 100644 --- a/neuralprophet/utils.py +++ b/neuralprophet/utils.py @@ -375,7 +375,7 @@ def config_seasonality_to_model_dims(config_seasonality: ConfigSeasonality): seasonal_dims[name] = resolution return seasonal_dims - + def config_events_to_model_dims(config_events: Optional[ConfigEvents], config_country_holidays): """ Convert user specified events configurations along with country specific diff --git a/poetry.lock b/poetry.lock index 1b9117f1d..681594e43 100644 --- a/poetry.lock +++ b/poetry.lock @@ -897,18 +897,18 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc [[package]] name = "filelock" -version = "3.15.1" +version = "3.15.3" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.15.1-py3-none-any.whl", hash = "sha256:71b3102950e91dfc1bb4209b64be4dc8854f40e5f534428d8684f953ac847fac"}, - {file = "filelock-3.15.1.tar.gz", hash = "sha256:58a2549afdf9e02e10720eaa4d4470f56386d7a6f72edd7d0596337af8ed7ad8"}, + {file = "filelock-3.15.3-py3-none-any.whl", hash = "sha256:0151273e5b5d6cf753a61ec83b3a9b7d8821c39ae9af9d7ecf2f9e2f17404103"}, + {file = "filelock-3.15.3.tar.gz", hash = "sha256:e1199bf5194a2277273dacd50269f0d87d0682088a3c561c15674ea9005d8635"}, ] [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -1256,22 +1256,22 @@ files = [ [[package]] name = "importlib-metadata" -version = "7.1.0" +version = "7.2.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, - {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, + {file = "importlib_metadata-7.2.0-py3-none-any.whl", hash = "sha256:04e4aad329b8b948a5711d394fa8759cb80f009225441b4f2a02bd4d8e5f426c"}, + {file = "importlib_metadata-7.2.0.tar.gz", hash = "sha256:3ff4519071ed42740522d494d04819b666541b9752c43012f85afb2cc220fcc6"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "importlib-resources" @@ -2417,6 +2417,7 @@ description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ + {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:004186d5ea6a57758fd6d57052a123c73a4815adf365eb8dd6a85c9eaa7535ff"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, ] @@ -2833,27 +2834,28 @@ files = [ [[package]] name = "psutil" -version = "5.9.8" +version = "6.0.0" description = "Cross-platform lib for process and system monitoring in Python." optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" -files = [ - {file = "psutil-5.9.8-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8"}, - {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73"}, - {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7"}, - {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36"}, - {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d"}, - {file = "psutil-5.9.8-cp27-none-win32.whl", hash = "sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e"}, - {file = "psutil-5.9.8-cp27-none-win_amd64.whl", hash = "sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631"}, - {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"}, - {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"}, - {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"}, - {file = "psutil-5.9.8-cp36-cp36m-win32.whl", hash = "sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee"}, - {file = "psutil-5.9.8-cp36-cp36m-win_amd64.whl", hash = "sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2"}, - {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"}, - {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"}, - {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"}, - {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"}, +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "psutil-6.0.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0"}, + {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3"}, + {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c"}, + {file = "psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35"}, + {file = "psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1"}, + {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, + {file = "psutil-6.0.0-cp36-cp36m-win32.whl", hash = "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14"}, + {file = "psutil-6.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c"}, + {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, + {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, + {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, + {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, ] [package.extras] @@ -3407,18 +3409,18 @@ files = [ [[package]] name = "setuptools" -version = "70.0.0" +version = "70.1.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, - {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, + {file = "setuptools-70.1.0-py3-none-any.whl", hash = "sha256:d9b8b771455a97c8a9f3ab3448ebe0b29b5e105f1228bba41028be116985a267"}, + {file = "setuptools-70.1.0.tar.gz", hash = "sha256:01a1e793faa5bd89abc851fa15d0a0db26f160890c7102cd8dce643e886b47f5"}, ] [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.10.0)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "six" @@ -3662,15 +3664,15 @@ widechars = ["wcwidth"] [[package]] name = "tbb" -version = "2021.12.0" +version = "2021.13.0" description = "IntelĀ® oneAPI Threading Building Blocks (oneTBB)" optional = false python-versions = "*" files = [ - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"}, - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"}, - {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"}, - {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, + {file = "tbb-2021.13.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:a2567725329639519d46d92a2634cf61e76601dac2f777a05686fea546c4fe4f"}, + {file = "tbb-2021.13.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:aaf667e92849adb012b8874d6393282afc318aca4407fc62f912ee30a22da46a"}, + {file = "tbb-2021.13.0-py3-none-win32.whl", hash = "sha256:6669d26703e9943f6164c6407bd4a237a45007e79b8d3832fe6999576eaaa9ef"}, + {file = "tbb-2021.13.0-py3-none-win_amd64.whl", hash = "sha256:3528a53e4bbe64b07a6112b4c5a00ff3c61924ee46c9c68e004a1ac7ad1f09c3"}, ] [[package]] diff --git a/tests/test_model_performance.py b/tests/test_model_performance.py index 8cf639d85..ac0af79e0 100644 --- a/tests/test_model_performance.py +++ b/tests/test_model_performance.py @@ -11,19 +11,20 @@ import plotly.graph_objects as go from plotly.subplots import make_subplots +from neuralprophet import NeuralProphet, set_random_seed + log = logging.getLogger("NP.test") log.setLevel("DEBUG") log.parent.setLevel("WARNING") try: - from plotly_resampler import register_plotly_resampler, unregister_plotly_resampler + from plotly_resampler import unregister_plotly_resampler plotly_resampler_installed = True except ImportError: plotly_resampler_installed = False log.error("Importing plotly failed. Interactive plots will not work.") -from neuralprophet import NeuralProphet, set_random_seed DIR = pathlib.Path(__file__).parent.parent.absolute() DATA_DIR = os.path.join(DIR, "tests", "test-data")