diff --git a/docs/source/tutorials/tutorial10.ipynb b/docs/source/tutorials/tutorial10.ipynb index bea2d186c..8feb525c0 100644 --- a/docs/source/tutorials/tutorial10.ipynb +++ b/docs/source/tutorials/tutorial10.ipynb @@ -73,7 +73,7 @@ "collapsed": false }, "source": [ - "Validation is performed by passing the validation set to the fit method during training. The resulting metrics show the performance of the model compared to our validation set." + "Validation is performed by passing the validation set to the fit method during training. The resulting metrics show the performance of the model compared to our validation set. " ] }, { @@ -356,6 +356,34 @@ "set_random_seed(0)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Aditionally, it is important to make sure to set the flag `deterministic` in the `fit` function to `True`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from neuralprophet import NeuralProphet\n", + "\n", + "# Load the dataset from the CSV file using pandas\n", + "df = pd.read_csv(\"https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial01.csv\")\n", + "\n", + "# Model and prediction\n", + "m = NeuralProphet()\n", + "\n", + "df_train, df_val = m.split_df(df, valid_p=0.2)\n", + "\n", + "# Set the deterministic flag to True\n", + "metrics = m.fit(df_train, validation_df=df_val, progress=None, deterministic=True)" + ] + }, { "attachments": {}, "cell_type": "markdown", diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index d80fcef14..9993fd832 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -905,6 +905,7 @@ def fit( checkpointing: bool = False, continue_training: bool = False, num_workers: int = 0, + deterministic: bool = False, ): """Train, and potentially evaluate model. @@ -1069,6 +1070,7 @@ def fit( checkpointing_enabled=checkpointing, continue_training=continue_training, num_workers=num_workers, + deterministic=deterministic, ) else: df_val, _, _, _ = df_utils.prep_or_copy_df(validation_df) @@ -1093,6 +1095,7 @@ def fit( checkpointing_enabled=checkpointing, continue_training=continue_training, num_workers=num_workers, + deterministic=deterministic, ) # Show training plot @@ -2714,6 +2717,7 @@ def _train( checkpointing_enabled: bool = False, continue_training=False, num_workers=0, + deterministic: bool = False, ): """ Execute model training procedure for a configured number of epochs. @@ -2771,6 +2775,7 @@ def _train( metrics_enabled=metrics_enabled, checkpointing_enabled=checkpointing_enabled, num_batches_per_epoch=len(train_loader), + deterministic=deterministic, ) # Tune hyperparams and train diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py index 33f7c51e6..c00c920ce 100644 --- a/neuralprophet/utils.py +++ b/neuralprophet/utils.py @@ -11,6 +11,7 @@ import pandas as pd import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from neuralprophet import utils_torch from neuralprophet.logger import ProgressBar @@ -710,6 +711,7 @@ def set_random_seed(seed: int = 0): """ np.random.seed(seed) torch.manual_seed(seed) + seed_everything(seed, workers=True) def set_logger_level(logger, log_level, include_handlers=False): @@ -818,6 +820,7 @@ def configure_trainer( metrics_enabled: bool = False, checkpointing_enabled: bool = False, num_batches_per_epoch: int = 100, + deterministic: bool = False, ): """ Configures the PyTorch Lightning trainer. @@ -888,6 +891,8 @@ def configure_trainer( else: config["logger"] = False + config["deterministic"] = deterministic + # Configure callbacks callbacks = [] has_custom_callbacks = True if "callbacks" in config else False diff --git a/poetry.lock b/poetry.lock index 1ff31a91e..23592dbe5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -305,13 +305,13 @@ files = [ [[package]] name = "bokeh" -version = "3.4.1" +version = "3.4.2" description = "Interactive plots and applications in the browser from Python" optional = true python-versions = ">=3.9" files = [ - {file = "bokeh-3.4.1-py3-none-any.whl", hash = "sha256:1e3c502a0a8205338fc74dadbfa321f8a0965441b39501e36796a47b4017b642"}, - {file = "bokeh-3.4.1.tar.gz", hash = "sha256:d824961e4265367b0750ce58b07e564ad0b83ca64b335521cd3421e9b9f10d89"}, + {file = "bokeh-3.4.2-py3-none-any.whl", hash = "sha256:931a43ee59dbf1720383ab904f8205e126b85561aac55592415b800c96f1b0eb"}, + {file = "bokeh-3.4.2.tar.gz", hash = "sha256:a16d5cc0abb93d2d270d70fc35851f3e1b9208814a985a4678e0ba5ef2d9cd42"}, ] [package.dependencies] @@ -629,63 +629,63 @@ test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] [[package]] name = "coverage" -version = "7.5.3" +version = "7.5.4" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.8" files = [ - {file = "coverage-7.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a6519d917abb15e12380406d721e37613e2a67d166f9fb7e5a8ce0375744cd45"}, - {file = "coverage-7.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aea7da970f1feccf48be7335f8b2ca64baf9b589d79e05b9397a06696ce1a1ec"}, - {file = "coverage-7.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:923b7b1c717bd0f0f92d862d1ff51d9b2b55dbbd133e05680204465f454bb286"}, - {file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62bda40da1e68898186f274f832ef3e759ce929da9a9fd9fcf265956de269dbc"}, - {file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8b7339180d00de83e930358223c617cc343dd08e1aa5ec7b06c3a121aec4e1d"}, - {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:25a5caf742c6195e08002d3b6c2dd6947e50efc5fc2c2205f61ecb47592d2d83"}, - {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:05ac5f60faa0c704c0f7e6a5cbfd6f02101ed05e0aee4d2822637a9e672c998d"}, - {file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:239a4e75e09c2b12ea478d28815acf83334d32e722e7433471fbf641c606344c"}, - {file = "coverage-7.5.3-cp310-cp310-win32.whl", hash = "sha256:a5812840d1d00eafae6585aba38021f90a705a25b8216ec7f66aebe5b619fb84"}, - {file = "coverage-7.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:33ca90a0eb29225f195e30684ba4a6db05dbef03c2ccd50b9077714c48153cac"}, - {file = "coverage-7.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f81bc26d609bf0fbc622c7122ba6307993c83c795d2d6f6f6fd8c000a770d974"}, - {file = "coverage-7.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7cec2af81f9e7569280822be68bd57e51b86d42e59ea30d10ebdbb22d2cb7232"}, - {file = "coverage-7.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55f689f846661e3f26efa535071775d0483388a1ccfab899df72924805e9e7cd"}, - {file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50084d3516aa263791198913a17354bd1dc627d3c1639209640b9cac3fef5807"}, - {file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341dd8f61c26337c37988345ca5c8ccabeff33093a26953a1ac72e7d0103c4fb"}, - {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ab0b028165eea880af12f66086694768f2c3139b2c31ad5e032c8edbafca6ffc"}, - {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5bc5a8c87714b0c67cfeb4c7caa82b2d71e8864d1a46aa990b5588fa953673b8"}, - {file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38a3b98dae8a7c9057bd91fbf3415c05e700a5114c5f1b5b0ea5f8f429ba6614"}, - {file = "coverage-7.5.3-cp311-cp311-win32.whl", hash = "sha256:fcf7d1d6f5da887ca04302db8e0e0cf56ce9a5e05f202720e49b3e8157ddb9a9"}, - {file = "coverage-7.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:8c836309931839cca658a78a888dab9676b5c988d0dd34ca247f5f3e679f4e7a"}, - {file = "coverage-7.5.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:296a7d9bbc598e8744c00f7a6cecf1da9b30ae9ad51c566291ff1314e6cbbed8"}, - {file = "coverage-7.5.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34d6d21d8795a97b14d503dcaf74226ae51eb1f2bd41015d3ef332a24d0a17b3"}, - {file = "coverage-7.5.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e317953bb4c074c06c798a11dbdd2cf9979dbcaa8ccc0fa4701d80042d4ebf1"}, - {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:705f3d7c2b098c40f5b81790a5fedb274113373d4d1a69e65f8b68b0cc26f6db"}, - {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1196e13c45e327d6cd0b6e471530a1882f1017eb83c6229fc613cd1a11b53cd"}, - {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:015eddc5ccd5364dcb902eaecf9515636806fa1e0d5bef5769d06d0f31b54523"}, - {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fd27d8b49e574e50caa65196d908f80e4dff64d7e592d0c59788b45aad7e8b35"}, - {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:33fc65740267222fc02975c061eb7167185fef4cc8f2770267ee8bf7d6a42f84"}, - {file = "coverage-7.5.3-cp312-cp312-win32.whl", hash = "sha256:7b2a19e13dfb5c8e145c7a6ea959485ee8e2204699903c88c7d25283584bfc08"}, - {file = "coverage-7.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:0bbddc54bbacfc09b3edaec644d4ac90c08ee8ed4844b0f86227dcda2d428fcb"}, - {file = "coverage-7.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f78300789a708ac1f17e134593f577407d52d0417305435b134805c4fb135adb"}, - {file = "coverage-7.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b368e1aee1b9b75757942d44d7598dcd22a9dbb126affcbba82d15917f0cc155"}, - {file = "coverage-7.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f836c174c3a7f639bded48ec913f348c4761cbf49de4a20a956d3431a7c9cb24"}, - {file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:244f509f126dc71369393ce5fea17c0592c40ee44e607b6d855e9c4ac57aac98"}, - {file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4c2872b3c91f9baa836147ca33650dc5c172e9273c808c3c3199c75490e709d"}, - {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dd4b3355b01273a56b20c219e74e7549e14370b31a4ffe42706a8cda91f19f6d"}, - {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f542287b1489c7a860d43a7d8883e27ca62ab84ca53c965d11dac1d3a1fab7ce"}, - {file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:75e3f4e86804023e991096b29e147e635f5e2568f77883a1e6eed74512659ab0"}, - {file = "coverage-7.5.3-cp38-cp38-win32.whl", hash = "sha256:c59d2ad092dc0551d9f79d9d44d005c945ba95832a6798f98f9216ede3d5f485"}, - {file = "coverage-7.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:fa21a04112c59ad54f69d80e376f7f9d0f5f9123ab87ecd18fbb9ec3a2beed56"}, - {file = "coverage-7.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5102a92855d518b0996eb197772f5ac2a527c0ec617124ad5242a3af5e25f85"}, - {file = "coverage-7.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d1da0a2e3b37b745a2b2a678a4c796462cf753aebf94edcc87dcc6b8641eae31"}, - {file = "coverage-7.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8383a6c8cefba1b7cecc0149415046b6fc38836295bc4c84e820872eb5478b3d"}, - {file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aad68c3f2566dfae84bf46295a79e79d904e1c21ccfc66de88cd446f8686341"}, - {file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e079c9ec772fedbade9d7ebc36202a1d9ef7291bc9b3a024ca395c4d52853d7"}, - {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bde997cac85fcac227b27d4fb2c7608a2c5f6558469b0eb704c5726ae49e1c52"}, - {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:990fb20b32990b2ce2c5f974c3e738c9358b2735bc05075d50a6f36721b8f303"}, - {file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3d5a67f0da401e105753d474369ab034c7bae51a4c31c77d94030d59e41df5bd"}, - {file = "coverage-7.5.3-cp39-cp39-win32.whl", hash = "sha256:e08c470c2eb01977d221fd87495b44867a56d4d594f43739a8028f8646a51e0d"}, - {file = "coverage-7.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:1d2a830ade66d3563bb61d1e3c77c8def97b30ed91e166c67d0632c018f380f0"}, - {file = "coverage-7.5.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:3538d8fb1ee9bdd2e2692b3b18c22bb1c19ffbefd06880f5ac496e42d7bb3884"}, - {file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"}, + {file = "coverage-7.5.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6cfb5a4f556bb51aba274588200a46e4dd6b505fb1a5f8c5ae408222eb416f99"}, + {file = "coverage-7.5.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2174e7c23e0a454ffe12267a10732c273243b4f2d50d07544a91198f05c48f47"}, + {file = "coverage-7.5.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2214ee920787d85db1b6a0bd9da5f8503ccc8fcd5814d90796c2f2493a2f4d2e"}, + {file = "coverage-7.5.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1137f46adb28e3813dec8c01fefadcb8c614f33576f672962e323b5128d9a68d"}, + {file = "coverage-7.5.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b385d49609f8e9efc885790a5a0e89f2e3ae042cdf12958b6034cc442de428d3"}, + {file = "coverage-7.5.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b4a474f799456e0eb46d78ab07303286a84a3140e9700b9e154cfebc8f527016"}, + {file = "coverage-7.5.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:5cd64adedf3be66f8ccee418473c2916492d53cbafbfcff851cbec5a8454b136"}, + {file = "coverage-7.5.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e564c2cf45d2f44a9da56f4e3a26b2236504a496eb4cb0ca7221cd4cc7a9aca9"}, + {file = "coverage-7.5.4-cp310-cp310-win32.whl", hash = "sha256:7076b4b3a5f6d2b5d7f1185fde25b1e54eb66e647a1dfef0e2c2bfaf9b4c88c8"}, + {file = "coverage-7.5.4-cp310-cp310-win_amd64.whl", hash = "sha256:018a12985185038a5b2bcafab04ab833a9a0f2c59995b3cec07e10074c78635f"}, + {file = "coverage-7.5.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:db14f552ac38f10758ad14dd7b983dbab424e731588d300c7db25b6f89e335b5"}, + {file = "coverage-7.5.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3257fdd8e574805f27bb5342b77bc65578e98cbc004a92232106344053f319ba"}, + {file = "coverage-7.5.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a6612c99081d8d6134005b1354191e103ec9705d7ba2754e848211ac8cacc6b"}, + {file = "coverage-7.5.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d45d3cbd94159c468b9b8c5a556e3f6b81a8d1af2a92b77320e887c3e7a5d080"}, + {file = "coverage-7.5.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed550e7442f278af76d9d65af48069f1fb84c9f745ae249c1a183c1e9d1b025c"}, + {file = "coverage-7.5.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7a892be37ca35eb5019ec85402c3371b0f7cda5ab5056023a7f13da0961e60da"}, + {file = "coverage-7.5.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8192794d120167e2a64721d88dbd688584675e86e15d0569599257566dec9bf0"}, + {file = "coverage-7.5.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:820bc841faa502e727a48311948e0461132a9c8baa42f6b2b84a29ced24cc078"}, + {file = "coverage-7.5.4-cp311-cp311-win32.whl", hash = "sha256:6aae5cce399a0f065da65c7bb1e8abd5c7a3043da9dceb429ebe1b289bc07806"}, + {file = "coverage-7.5.4-cp311-cp311-win_amd64.whl", hash = "sha256:d2e344d6adc8ef81c5a233d3a57b3c7d5181f40e79e05e1c143da143ccb6377d"}, + {file = "coverage-7.5.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:54317c2b806354cbb2dc7ac27e2b93f97096912cc16b18289c5d4e44fc663233"}, + {file = "coverage-7.5.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:042183de01f8b6d531e10c197f7f0315a61e8d805ab29c5f7b51a01d62782747"}, + {file = "coverage-7.5.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6bb74ed465d5fb204b2ec41d79bcd28afccf817de721e8a807d5141c3426638"}, + {file = "coverage-7.5.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3d45ff86efb129c599a3b287ae2e44c1e281ae0f9a9bad0edc202179bcc3a2e"}, + {file = "coverage-7.5.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5013ed890dc917cef2c9f765c4c6a8ae9df983cd60dbb635df8ed9f4ebc9f555"}, + {file = "coverage-7.5.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1014fbf665fef86cdfd6cb5b7371496ce35e4d2a00cda501cf9f5b9e6fced69f"}, + {file = "coverage-7.5.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3684bc2ff328f935981847082ba4fdc950d58906a40eafa93510d1b54c08a66c"}, + {file = "coverage-7.5.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:581ea96f92bf71a5ec0974001f900db495488434a6928a2ca7f01eee20c23805"}, + {file = "coverage-7.5.4-cp312-cp312-win32.whl", hash = "sha256:73ca8fbc5bc622e54627314c1a6f1dfdd8db69788f3443e752c215f29fa87a0b"}, + {file = "coverage-7.5.4-cp312-cp312-win_amd64.whl", hash = "sha256:cef4649ec906ea7ea5e9e796e68b987f83fa9a718514fe147f538cfeda76d7a7"}, + {file = "coverage-7.5.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cdd31315fc20868c194130de9ee6bfd99755cc9565edff98ecc12585b90be882"}, + {file = "coverage-7.5.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:02ff6e898197cc1e9fa375581382b72498eb2e6d5fc0b53f03e496cfee3fac6d"}, + {file = "coverage-7.5.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d05c16cf4b4c2fc880cb12ba4c9b526e9e5d5bb1d81313d4d732a5b9fe2b9d53"}, + {file = "coverage-7.5.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5986ee7ea0795a4095ac4d113cbb3448601efca7f158ec7f7087a6c705304e4"}, + {file = "coverage-7.5.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5df54843b88901fdc2f598ac06737f03d71168fd1175728054c8f5a2739ac3e4"}, + {file = "coverage-7.5.4-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ab73b35e8d109bffbda9a3e91c64e29fe26e03e49addf5b43d85fc426dde11f9"}, + {file = "coverage-7.5.4-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:aea072a941b033813f5e4814541fc265a5c12ed9720daef11ca516aeacd3bd7f"}, + {file = "coverage-7.5.4-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:16852febd96acd953b0d55fc842ce2dac1710f26729b31c80b940b9afcd9896f"}, + {file = "coverage-7.5.4-cp38-cp38-win32.whl", hash = "sha256:8f894208794b164e6bd4bba61fc98bf6b06be4d390cf2daacfa6eca0a6d2bb4f"}, + {file = "coverage-7.5.4-cp38-cp38-win_amd64.whl", hash = "sha256:e2afe743289273209c992075a5a4913e8d007d569a406ffed0bd080ea02b0633"}, + {file = "coverage-7.5.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b95c3a8cb0463ba9f77383d0fa8c9194cf91f64445a63fc26fb2327e1e1eb088"}, + {file = "coverage-7.5.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3d7564cc09dd91b5a6001754a5b3c6ecc4aba6323baf33a12bd751036c998be4"}, + {file = "coverage-7.5.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44da56a2589b684813f86d07597fdf8a9c6ce77f58976727329272f5a01f99f7"}, + {file = "coverage-7.5.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e16f3d6b491c48c5ae726308e6ab1e18ee830b4cdd6913f2d7f77354b33f91c8"}, + {file = "coverage-7.5.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbc5958cb471e5a5af41b0ddaea96a37e74ed289535e8deca404811f6cb0bc3d"}, + {file = "coverage-7.5.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a04e990a2a41740b02d6182b498ee9796cf60eefe40cf859b016650147908029"}, + {file = "coverage-7.5.4-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ddbd2f9713a79e8e7242d7c51f1929611e991d855f414ca9996c20e44a895f7c"}, + {file = "coverage-7.5.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b1ccf5e728ccf83acd313c89f07c22d70d6c375a9c6f339233dcf792094bcbf7"}, + {file = "coverage-7.5.4-cp39-cp39-win32.whl", hash = "sha256:56b4eafa21c6c175b3ede004ca12c653a88b6f922494b023aeb1e836df953ace"}, + {file = "coverage-7.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:65e528e2e921ba8fd67d9055e6b9f9e34b21ebd6768ae1c1723f4ea6ace1234d"}, + {file = "coverage-7.5.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:79b356f3dd5b26f3ad23b35c75dbdaf1f9e2450b6bcefc6d0825ea0aa3f86ca5"}, + {file = "coverage-7.5.4.tar.gz", hash = "sha256:a44963520b069e12789d0faea4e9fdb1e410cdc4aab89d94f7f55cbb7fef0353"}, ] [package.dependencies] @@ -777,33 +777,33 @@ files = [ [[package]] name = "debugpy" -version = "1.8.1" +version = "1.8.2" description = "An implementation of the Debug Adapter Protocol for Python" optional = false python-versions = ">=3.8" files = [ - {file = "debugpy-1.8.1-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:3bda0f1e943d386cc7a0e71bfa59f4137909e2ed947fb3946c506e113000f741"}, - {file = "debugpy-1.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dda73bf69ea479c8577a0448f8c707691152e6c4de7f0c4dec5a4bc11dee516e"}, - {file = "debugpy-1.8.1-cp310-cp310-win32.whl", hash = "sha256:3a79c6f62adef994b2dbe9fc2cc9cc3864a23575b6e387339ab739873bea53d0"}, - {file = "debugpy-1.8.1-cp310-cp310-win_amd64.whl", hash = "sha256:7eb7bd2b56ea3bedb009616d9e2f64aab8fc7000d481faec3cd26c98a964bcdd"}, - {file = "debugpy-1.8.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:016a9fcfc2c6b57f939673c874310d8581d51a0fe0858e7fac4e240c5eb743cb"}, - {file = "debugpy-1.8.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd97ed11a4c7f6d042d320ce03d83b20c3fb40da892f994bc041bbc415d7a099"}, - {file = "debugpy-1.8.1-cp311-cp311-win32.whl", hash = "sha256:0de56aba8249c28a300bdb0672a9b94785074eb82eb672db66c8144fff673146"}, - {file = "debugpy-1.8.1-cp311-cp311-win_amd64.whl", hash = "sha256:1a9fe0829c2b854757b4fd0a338d93bc17249a3bf69ecf765c61d4c522bb92a8"}, - {file = "debugpy-1.8.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3ebb70ba1a6524d19fa7bb122f44b74170c447d5746a503e36adc244a20ac539"}, - {file = "debugpy-1.8.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2e658a9630f27534e63922ebf655a6ab60c370f4d2fc5c02a5b19baf4410ace"}, - {file = "debugpy-1.8.1-cp312-cp312-win32.whl", hash = "sha256:caad2846e21188797a1f17fc09c31b84c7c3c23baf2516fed5b40b378515bbf0"}, - {file = "debugpy-1.8.1-cp312-cp312-win_amd64.whl", hash = "sha256:edcc9f58ec0fd121a25bc950d4578df47428d72e1a0d66c07403b04eb93bcf98"}, - {file = "debugpy-1.8.1-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:7a3afa222f6fd3d9dfecd52729bc2e12c93e22a7491405a0ecbf9e1d32d45b39"}, - {file = "debugpy-1.8.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d915a18f0597ef685e88bb35e5d7ab968964b7befefe1aaea1eb5b2640b586c7"}, - {file = "debugpy-1.8.1-cp38-cp38-win32.whl", hash = "sha256:92116039b5500633cc8d44ecc187abe2dfa9b90f7a82bbf81d079fcdd506bae9"}, - {file = "debugpy-1.8.1-cp38-cp38-win_amd64.whl", hash = "sha256:e38beb7992b5afd9d5244e96ad5fa9135e94993b0c551ceebf3fe1a5d9beb234"}, - {file = "debugpy-1.8.1-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:bfb20cb57486c8e4793d41996652e5a6a885b4d9175dd369045dad59eaacea42"}, - {file = "debugpy-1.8.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efd3fdd3f67a7e576dd869c184c5dd71d9aaa36ded271939da352880c012e703"}, - {file = "debugpy-1.8.1-cp39-cp39-win32.whl", hash = "sha256:58911e8521ca0c785ac7a0539f1e77e0ce2df753f786188f382229278b4cdf23"}, - {file = "debugpy-1.8.1-cp39-cp39-win_amd64.whl", hash = "sha256:6df9aa9599eb05ca179fb0b810282255202a66835c6efb1d112d21ecb830ddd3"}, - {file = "debugpy-1.8.1-py2.py3-none-any.whl", hash = "sha256:28acbe2241222b87e255260c76741e1fbf04fdc3b6d094fcf57b6c6f75ce1242"}, - {file = "debugpy-1.8.1.zip", hash = "sha256:f696d6be15be87aef621917585f9bb94b1dc9e8aced570db1b8a6fc14e8f9b42"}, + {file = "debugpy-1.8.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7ee2e1afbf44b138c005e4380097d92532e1001580853a7cb40ed84e0ef1c3d2"}, + {file = "debugpy-1.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f8c3f7c53130a070f0fc845a0f2cee8ed88d220d6b04595897b66605df1edd6"}, + {file = "debugpy-1.8.2-cp310-cp310-win32.whl", hash = "sha256:f179af1e1bd4c88b0b9f0fa153569b24f6b6f3de33f94703336363ae62f4bf47"}, + {file = "debugpy-1.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:0600faef1d0b8d0e85c816b8bb0cb90ed94fc611f308d5fde28cb8b3d2ff0fe3"}, + {file = "debugpy-1.8.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:8a13417ccd5978a642e91fb79b871baded925d4fadd4dfafec1928196292aa0a"}, + {file = "debugpy-1.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acdf39855f65c48ac9667b2801234fc64d46778021efac2de7e50907ab90c634"}, + {file = "debugpy-1.8.2-cp311-cp311-win32.whl", hash = "sha256:2cbd4d9a2fc5e7f583ff9bf11f3b7d78dfda8401e8bb6856ad1ed190be4281ad"}, + {file = "debugpy-1.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:d3408fddd76414034c02880e891ea434e9a9cf3a69842098ef92f6e809d09afa"}, + {file = "debugpy-1.8.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:5d3ccd39e4021f2eb86b8d748a96c766058b39443c1f18b2dc52c10ac2757835"}, + {file = "debugpy-1.8.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:62658aefe289598680193ff655ff3940e2a601765259b123dc7f89c0239b8cd3"}, + {file = "debugpy-1.8.2-cp312-cp312-win32.whl", hash = "sha256:bd11fe35d6fd3431f1546d94121322c0ac572e1bfb1f6be0e9b8655fb4ea941e"}, + {file = "debugpy-1.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:15bc2f4b0f5e99bf86c162c91a74c0631dbd9cef3c6a1d1329c946586255e859"}, + {file = "debugpy-1.8.2-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:5a019d4574afedc6ead1daa22736c530712465c0c4cd44f820d803d937531b2d"}, + {file = "debugpy-1.8.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40f062d6877d2e45b112c0bbade9a17aac507445fd638922b1a5434df34aed02"}, + {file = "debugpy-1.8.2-cp38-cp38-win32.whl", hash = "sha256:c78ba1680f1015c0ca7115671fe347b28b446081dada3fedf54138f44e4ba031"}, + {file = "debugpy-1.8.2-cp38-cp38-win_amd64.whl", hash = "sha256:cf327316ae0c0e7dd81eb92d24ba8b5e88bb4d1b585b5c0d32929274a66a5210"}, + {file = "debugpy-1.8.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:1523bc551e28e15147815d1397afc150ac99dbd3a8e64641d53425dba57b0ff9"}, + {file = "debugpy-1.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e24ccb0cd6f8bfaec68d577cb49e9c680621c336f347479b3fce060ba7c09ec1"}, + {file = "debugpy-1.8.2-cp39-cp39-win32.whl", hash = "sha256:7f8d57a98c5a486c5c7824bc0b9f2f11189d08d73635c326abef268f83950326"}, + {file = "debugpy-1.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:16c8dcab02617b75697a0a925a62943e26a0330da076e2a10437edd9f0bf3755"}, + {file = "debugpy-1.8.2-py2.py3-none-any.whl", hash = "sha256:16e16df3a98a35c63c3ab1e4d19be4cbc7fdda92d9ddc059294f18910928e0ca"}, + {file = "debugpy-1.8.2.zip", hash = "sha256:95378ed08ed2089221896b9b3a8d021e642c24edc8fef20e5d4342ca8be65c00"}, ] [[package]] @@ -897,13 +897,13 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc [[package]] name = "filelock" -version = "3.15.3" +version = "3.15.4" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.15.3-py3-none-any.whl", hash = "sha256:0151273e5b5d6cf753a61ec83b3a9b7d8821c39ae9af9d7ecf2f9e2f17404103"}, - {file = "filelock-3.15.3.tar.gz", hash = "sha256:e1199bf5194a2277273dacd50269f0d87d0682088a3c561c15674ea9005d8635"}, + {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, + {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, ] [package.extras] @@ -1256,13 +1256,13 @@ files = [ [[package]] name = "importlib-metadata" -version = "7.2.0" +version = "8.0.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.2.0-py3-none-any.whl", hash = "sha256:04e4aad329b8b948a5711d394fa8759cb80f009225441b4f2a02bd4d8e5f426c"}, - {file = "importlib_metadata-7.2.0.tar.gz", hash = "sha256:3ff4519071ed42740522d494d04819b666541b9752c43012f85afb2cc220fcc6"}, + {file = "importlib_metadata-8.0.0-py3-none-any.whl", hash = "sha256:15584cf2b1bf449d98ff8a6ff1abef57bf20f3ac6454f431736cd3e660921b2f"}, + {file = "importlib_metadata-8.0.0.tar.gz", hash = "sha256:188bd24e4c346d3f0a933f275c2fec67050326a856b9a359881d7c2a697e8812"}, ] [package.dependencies] @@ -1664,15 +1664,43 @@ files = [ {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, ] +[[package]] +name = "lightning-fabric" +version = "2.3.0" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "lightning-fabric-2.3.0.tar.gz", hash = "sha256:b75438e96caba280141ece3512fd613ba680c102fda90657af1bbd2ea5e95bc1"}, + {file = "lightning_fabric-2.3.0-py3-none-any.whl", hash = "sha256:fff33b1e48a283e486b4a51bc5100b8d6a14dd50278a613c6d964b058584672c"}, +] + +[package.dependencies] +fsspec = {version = ">=2022.5.0", extras = ["http"]} +lightning-utilities = ">=0.8.0" +numpy = ">=1.17.2" +packaging = ">=20.0" +torch = ">=2.0.0" +typing-extensions = ">=4.4.0" + +[package.extras] +all = ["bitsandbytes (>=0.42.0)", "deepspeed (>=0.8.2,<=0.9.3)", "lightning-utilities (>=0.8.0)", "torchmetrics (>=0.10.0)", "torchvision (>=0.15.0)"] +bitsandbytes = ["bitsandbytes (>=0.42.0)"] +deepspeed = ["deepspeed (>=0.8.2,<=0.9.3)"] +dev = ["bitsandbytes (>=0.42.0)", "click (==8.1.7)", "coverage (==7.3.1)", "deepspeed (>=0.8.2,<=0.9.3)", "lightning-utilities (>=0.8.0)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchmetrics (>=0.7.0)", "torchvision (>=0.15.0)"] +examples = ["lightning-utilities (>=0.8.0)", "torchmetrics (>=0.10.0)", "torchvision (>=0.15.0)"] +strategies = ["bitsandbytes (>=0.42.0)", "deepspeed (>=0.8.2,<=0.9.3)"] +test = ["click (==8.1.7)", "coverage (==7.3.1)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "tensorboardX (>=2.2)", "torchmetrics (>=0.7.0)"] + [[package]] name = "lightning-utilities" -version = "0.11.2" +version = "0.11.3.post0" description = "Lightning toolbox for across the our ecosystem." optional = false python-versions = ">=3.8" files = [ - {file = "lightning-utilities-0.11.2.tar.gz", hash = "sha256:adf4cf9c5d912fe505db4729e51d1369c6927f3a8ac55a9dff895ce5c0da08d9"}, - {file = "lightning_utilities-0.11.2-py3-none-any.whl", hash = "sha256:541f471ed94e18a28d72879338c8c52e873bb46f4c47644d89228faeb6751159"}, + {file = "lightning_utilities-0.11.3.post0-py3-none-any.whl", hash = "sha256:2aec1d067e5ab61a8978f879998850a97f9a3764ee54aade329552706b0d189b"}, + {file = "lightning_utilities-0.11.3.post0.tar.gz", hash = "sha256:7485fad0e3c5607a6bde4507935689c553a2c91325de2127b4bb8171a601e236"}, ] [package.dependencies] @@ -2417,6 +2445,7 @@ description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ + {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:004186d5ea6a57758fd6d57052a123c73a4815adf365eb8dd6a85c9eaa7535ff"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, ] @@ -3110,6 +3139,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3404,13 +3434,13 @@ files = [ [[package]] name = "setuptools" -version = "70.1.0" +version = "70.1.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-70.1.0-py3-none-any.whl", hash = "sha256:d9b8b771455a97c8a9f3ab3448ebe0b29b5e105f1228bba41028be116985a267"}, - {file = "setuptools-70.1.0.tar.gz", hash = "sha256:01a1e793faa5bd89abc851fa15d0a0db26f160890c7102cd8dce643e886b47f5"}, + {file = "setuptools-70.1.1-py3-none-any.whl", hash = "sha256:a58a8fde0541dab0419750bcc521fbdf8585f6e5cb41909df3a472ef7b81ca95"}, + {file = "setuptools-70.1.1.tar.gz", hash = "sha256:937a48c7cdb7a21eb53cd7f9b59e525503aa8abaf3584c730dc5f7a5bec3a650"}, ] [package.extras] @@ -3672,13 +3702,13 @@ files = [ [[package]] name = "tenacity" -version = "8.4.1" +version = "8.4.2" description = "Retry code until it succeeds" optional = false python-versions = ">=3.8" files = [ - {file = "tenacity-8.4.1-py3-none-any.whl", hash = "sha256:28522e692eda3e1b8f5e99c51464efcc0b9fc86933da92415168bc1c4e2308fa"}, - {file = "tenacity-8.4.1.tar.gz", hash = "sha256:54b1412b878ddf7e1f1577cd49527bad8cdef32421bd599beac0c6c3f10582fd"}, + {file = "tenacity-8.4.2-py3-none-any.whl", hash = "sha256:9e6f7cf7da729125c7437222f8a522279751cdfbe6b67bfe64f75d3a348661b2"}, + {file = "tenacity-8.4.2.tar.gz", hash = "sha256:cd80a53a79336edba8489e767f729e4f391c896956b57140b5d7511a64bbd3ef"}, ] [package.extras] @@ -4229,4 +4259,4 @@ plotly-resampler = ["plotly-resampler"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<=3.13" -content-hash = "2918a6a6306adfdc98192da9235ddc0863ed75d38aee3c7fdf045dccd505e9ef" +content-hash = "9b88d1985e17192fe1334a1002a9cbc0a1c88c669e302548d5854edacd6193ff" diff --git a/pyproject.toml b/pyproject.toml index 4e69ae072..114392f9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ plotly = ">=5.13.1" kaleido = "0.2.1" # required for plotly static image export plotly-resampler = { version = ">=0.9.2", optional = true } livelossplot = { version = ">=0.5.5", optional = true } +lightning-fabric = ">=2.0.0" [tool.poetry.extras] plotly-resampler = ["plotly-resampler"] diff --git a/tests/test_glocal.py b/tests/test_glocal.py index e631b616d..5c171d597 100644 --- a/tests/test_glocal.py +++ b/tests/test_glocal.py @@ -205,36 +205,12 @@ def test_wrong_option_global_local_modeling(): forecast_trend = m.predict_trend(test_df) forecast_seasonal_componets = m.predict_seasonal_components(test_df) - -def test_different_seasonality_modeling(): - # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES - log.info("Global Modeling + Global Normalization") - df = pd.read_csv(PEYTON_FILE, nrows=512) - df1_0 = df.iloc[:128, :].copy(deep=True) - df1_0["ID"] = "df1" - df2_0 = df.iloc[128:256, :].copy(deep=True) - df2_0["ID"] = "df2" - df3_0 = df.iloc[256:384, :].copy(deep=True) - df3_0["ID"] = "df3" - m = NeuralProphet( - n_forecasts=2, - n_lags=10, - epochs=EPOCHS, - batch_size=BATCH_SIZE, - learning_rate=LR, - season_global_local="local", - yearly_seasonality_glocal_mode="global", + log.info( + f"forecast = {forecast}, metrics = {metrics}, forecast_trend = {forecast_trend}, forecast_seasonal_componets = {forecast_seasonal_componets}" ) - train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True) - m.fit(train_df) - future = m.make_future_dataframe(test_df) - forecast = m.predict(future) - metrics = m.test(test_df) - forecast_trend = m.predict_trend(test_df) - forecast_seasonal_componets = m.predict_seasonal_components(test_df) -def test_adding_new_global_seasonality(): +def test_different_seasonality_modeling(): # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES log.info("Global Modeling + Global Normalization") df = pd.read_csv(PEYTON_FILE, nrows=512) @@ -253,7 +229,6 @@ def test_adding_new_global_seasonality(): season_global_local="local", yearly_seasonality_glocal_mode="global", ) - m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="global") train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True) m.fit(train_df) future = m.make_future_dataframe(test_df) @@ -262,142 +237,9 @@ def test_adding_new_global_seasonality(): forecast_trend = m.predict_trend(test_df) forecast_seasonal_componets = m.predict_seasonal_components(test_df) - -def test_adding_new_local_seasonality(): - # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES - log.info("Global Modeling + Global Normalization") - df = pd.read_csv(PEYTON_FILE, nrows=512) - df1_0 = df.iloc[:128, :].copy(deep=True) - df1_0["ID"] = "df1" - df2_0 = df.iloc[128:256, :].copy(deep=True) - df2_0["ID"] = "df2" - df3_0 = df.iloc[256:384, :].copy(deep=True) - df3_0["ID"] = "df3" - m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, season_global_local="global", trend_global_local="local") - m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="local") - train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True) - m.fit(train_df) - future = m.make_future_dataframe(test_df, n_historic_predictions=True) - forecast = m.predict(future) - metrics = m.test(test_df) - forecast_trend = m.predict_trend(test_df) - forecast_seasonal_componets = m.predict_seasonal_components(test_df) - - -def test_trend_local_reg(): - # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES - log.info("Global Modeling + Global Normalization") - df = pd.read_csv(PEYTON_FILE, nrows=512) - df1_0 = df.iloc[:128, :].copy(deep=True) - df1_0["ID"] = "df1" - df2_0 = df.iloc[128:256, :].copy(deep=True) - df2_0["ID"] = "df2" - df3_0 = df.iloc[256:384, :].copy(deep=True) - df3_0["ID"] = "df3" - for coef_i in [-30, 0, False, True]: - m = NeuralProphet( - n_forecasts=1, - epochs=EPOCHS, - batch_size=BATCH_SIZE, - learning_rate=LR, - trend_global_local="local", - trend_local_reg=coef_i, - ) - - m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="global") - train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True) - m.fit(train_df) - future = m.make_future_dataframe(test_df, n_historic_predictions=True) - forecast = m.predict(future) - metrics = m.test(test_df) - forecast_trend = m.predict_trend(test_df) - forecast_seasonal_componets = m.predict_seasonal_components(test_df) - - -def test_glocal_seasonality_reg(): - # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES - log.info("Global Modeling + Global Normalization") - df = pd.read_csv(PEYTON_FILE, nrows=512) - df1_0 = df.iloc[:128, :].copy(deep=True) - df1_0["ID"] = "df1" - df2_0 = df.iloc[128:256, :].copy(deep=True) - df2_0["ID"] = "df2" - df3_0 = df.iloc[256:384, :].copy(deep=True) - df3_0["ID"] = "df3" - for coef_i in [-30, 0, False, True]: - m = NeuralProphet( - n_forecasts=1, - epochs=EPOCHS, - batch_size=BATCH_SIZE, - learning_rate=LR, - season_global_local="local", - yearly_seasonality_glocal_mode="global", - glocal_seasonality_reg=coef_i, - ) - - m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="global") - train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True) - m.fit(train_df) - future = m.make_future_dataframe(test_df, n_historic_predictions=True) - forecast = m.predict(future) - metrics = m.test(test_df) - - -def test_trend_local_reg_if_global(): - # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES - log.info("Global Modeling + Global Normalization") - df = pd.read_csv(PEYTON_FILE, nrows=512) - df1_0 = df.iloc[:128, :].copy(deep=True) - df1_0["ID"] = "df1" - df2_0 = df.iloc[128:256, :].copy(deep=True) - df2_0["ID"] = "df2" - df3_0 = df.iloc[256:384, :].copy(deep=True) - df3_0["ID"] = "df3" - for coef_i in [-30, 0, False, True]: - m = NeuralProphet( - n_forecasts=1, - epochs=EPOCHS, - batch_size=BATCH_SIZE, - learning_rate=LR, - trend_global_local="global", - trend_local_reg=3, - ) - - train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True) - m.fit(train_df) - future = m.make_future_dataframe(test_df, n_historic_predictions=True) - forecast = m.predict(future) - metrics = m.test(test_df) - forecast_trend = m.predict_trend(test_df) - forecast_seasonal_componets = m.predict_seasonal_components(test_df) - - -def test_different_seasonality_modeling(): - # SEASONALITY GLOBAL LOCAL MODELLING - NO EXOGENOUS VARIABLES - log.info("Global Modeling + Global Normalization") - df = pd.read_csv(PEYTON_FILE, nrows=512) - df1_0 = df.iloc[:128, :].copy(deep=True) - df1_0["ID"] = "df1" - df2_0 = df.iloc[128:256, :].copy(deep=True) - df2_0["ID"] = "df2" - df3_0 = df.iloc[256:384, :].copy(deep=True) - df3_0["ID"] = "df3" - m = NeuralProphet( - n_forecasts=2, - n_lags=10, - epochs=EPOCHS, - batch_size=BATCH_SIZE, - learning_rate=LR, - season_global_local="local", - yearly_seasonality_glocal_mode="global", + log.info( + f"forecast = {forecast}, metrics = {metrics}, forecast_trend = {forecast_trend}, forecast_seasonal_componets = {forecast_seasonal_componets}" ) - train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True) - m.fit(train_df) - future = m.make_future_dataframe(test_df) - forecast = m.predict(future) - metrics = m.test(test_df) - forecast_trend = m.predict_trend(test_df) - forecast_seasonal_componets = m.predict_seasonal_components(test_df) def test_adding_new_global_seasonality(): @@ -427,6 +269,9 @@ def test_adding_new_global_seasonality(): metrics = m.test(test_df) forecast_trend = m.predict_trend(test_df) forecast_seasonal_componets = m.predict_seasonal_components(test_df) + log.info( + f"forecast = {forecast}, metrics = {metrics}, forecast_trend = {forecast_trend}, forecast_seasonal_componets = {forecast_seasonal_componets}" + ) def test_adding_new_local_seasonality(): @@ -448,6 +293,9 @@ def test_adding_new_local_seasonality(): metrics = m.test(test_df) forecast_trend = m.predict_trend(test_df) forecast_seasonal_componets = m.predict_seasonal_components(test_df) + log.info( + f"forecast = {forecast}, metrics = {metrics}, forecast_trend = {forecast_trend}, forecast_seasonal_componets = {forecast_seasonal_componets}" + ) def test_trend_local_reg(): @@ -478,6 +326,9 @@ def test_trend_local_reg(): metrics = m.test(test_df) forecast_trend = m.predict_trend(test_df) forecast_seasonal_componets = m.predict_seasonal_components(test_df) + log.info( + f"forecast = {forecast}, metrics = {metrics}, forecast_trend = {forecast_trend}, forecast_seasonal_componets = {forecast_seasonal_componets}" + ) def test_glocal_seasonality_reg(): @@ -490,7 +341,7 @@ def test_glocal_seasonality_reg(): df2_0["ID"] = "df2" df3_0 = df.iloc[256:384, :].copy(deep=True) df3_0["ID"] = "df3" - for coef_i in [-30, 0, False, True]: + for _ in [-30, 0, False, True]: m = NeuralProphet( n_forecasts=1, epochs=EPOCHS, @@ -498,7 +349,6 @@ def test_glocal_seasonality_reg(): learning_rate=LR, season_global_local="local", yearly_seasonality_glocal_mode="global", - seasonality_local_reg=coef_i, ) m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="global") @@ -507,6 +357,7 @@ def test_glocal_seasonality_reg(): future = m.make_future_dataframe(test_df, n_historic_predictions=True) forecast = m.predict(future) metrics = m.test(test_df) + log.info(f"forecast = {forecast}, metrics = {metrics}") def test_trend_local_reg_if_global(): @@ -536,3 +387,6 @@ def test_trend_local_reg_if_global(): metrics = m.test(test_df) forecast_trend = m.predict_trend(test_df) forecast_seasonal_componets = m.predict_seasonal_components(test_df) + log.info( + f"forecast = {forecast}, metrics = {metrics}, forecast_trend = {forecast_trend}, forecast_seasonal_componets = {forecast_seasonal_componets}" + ) diff --git a/tests/test_model_performance.py b/tests/test_model_performance.py index ac0af79e0..af512d535 100644 --- a/tests/test_model_performance.py +++ b/tests/test_model_performance.py @@ -139,7 +139,7 @@ def test_PeytonManning(): system_speed, std = get_system_speed() start = time.time() - metrics = m.fit(df_train, validation_df=df_test, freq="D") # , early_stopping=True) + metrics = m.fit(df_train, validation_df=df_test, freq="D", deterministic=True) # , early_stopping=True) end = time.time() accuracy_metrics = metrics.to_dict("records")[-1] @@ -165,7 +165,12 @@ def test_YosemiteTemps(): system_speed, std = get_system_speed() start = time.time() - metrics = m.fit(df_train, validation_df=df_test, freq="5min") # , early_stopping=True) + metrics = m.fit( + df_train, + validation_df=df_test, + freq="5min", + deterministic=True, + ) # , early_stopping=True) end = time.time() accuracy_metrics = metrics.to_dict("records")[-1] @@ -185,7 +190,7 @@ def test_AirPassengers(): system_speed, std = get_system_speed() start = time.time() - metrics = m.fit(df_train, validation_df=df_test, freq="MS") # , early_stopping=True) + metrics = m.fit(df_train, validation_df=df_test, freq="MS", deterministic=True) # , early_stopping=True) end = time.time() accuracy_metrics = metrics.to_dict("records")[-1] @@ -217,7 +222,12 @@ def test_EnergyPriceDaily(): system_speed, std = get_system_speed() start = time.time() - metrics = m.fit(df_train, validation_df=df_test, freq="D") # , early_stopping=True) + metrics = m.fit( + df_train, + validation_df=df_test, + freq="D", + deterministic=True, + ) # , early_stopping=True) end = time.time() accuracy_metrics = metrics.to_dict("records")[-1]