Commit

Merge pull request #234 from WenjieDu/dev
Check if X_intact contains missing data for imputation models, check and list mismatched hyperparameters in the tuning mode
WenjieDu authored Nov 12, 2023
2 parents 939ad55 + b8d6689 commit c962530
Showing 17 changed files with 250 additions and 63 deletions.
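
Only nine of the 17 changed files are reproduced below; the hunk implementing the first half of the commit message (checking that X_intact contains no missing data for imputation models) is among those not shown. As a purely hypothetical sketch of such a guard (the helper name and message are made up, not taken from this diff):

```python
import numpy as np

def _check_intact_for_imputation(X_intact: np.ndarray) -> None:
    # Hypothetical guard: X_intact is the fully observed ground truth used
    # to score imputation (e.g. via cal_mae), so any NaN left inside it
    # would corrupt the error metrics.
    if np.isnan(X_intact).any():
        raise ValueError("X_intact contains missing values, but is expected to be fully observed.")
```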
9 changes: 7 additions & 2 deletions README.md
@@ -174,7 +174,12 @@ mae = cal_mae(imputation, X_intact, indicating_mask) # calculate mean absolute


## ❖ Available Algorithms
PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values. The currently available algorithms of four tasks are cataloged in the following table with four partitions. The paper references are all listed at the bottom of this readme file. Please refer to them if you want more details.
PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values.
The algorithms currently available for these four tasks are cataloged in the following table, which has four partitions.
The paper references are all listed at the bottom of this README file. Please refer to them if you want more details.

🌟 Since **v0.2**, all neural-network models in PyPOTS have hyperparameter-optimization support.
This functionality is implemented with the [Microsoft NNI](https://github.com/microsoft/nni) framework.
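
Under the hood, each tuning trial pulls one candidate hyperparameter set from the NNI tuner and trains a model with it, roughly as in `pypots/cli/tuning.py` further down in this diff. A minimal sketch of a single trial (the dataset paths are placeholders, and it assumes the sampled set covers the model's required arguments and that the model accepts an `optimizer` argument):

```python
import nni
from pypots.imputation import SAITS
from pypots.optim import Adam

tuner_params = nni.get_next_parameter()  # e.g. {"lr": 1e-3, "n_layers": 2, ...}
lr = tuner_params.pop("lr")  # the learning rate goes to the optimizer, not the model

model = SAITS(**tuner_params, optimizer=Adam(lr=lr))
model.fit(train_set="path/to/train.h5", val_set="path/to/val.h5")  # placeholder paths
```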

| ***`Imputation`*** | 🚥 | 🚥 | 🚥 |
|:----------------------:|:-----------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|
@@ -183,7 +188,7 @@ PyPOTS supports imputation, classification, clustering, and forecasting tasks on
| Neural Net | Transformer | Attention is All you Need [^2];<br>Self-Attention-based Imputation for Time Series [^1];<br><sub>Note: proposed in [^2], and re-implemented as an imputation model in [^1].</sub> | 2017 |
| Neural Net | CSDI | Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation [^12] | 2021 |
| Neural Net | US-GAN | Unsupervised GAN for Multivariate Time Series Imputation [^10] | 2021 |
| Neural Net | GP-VAE | Gaussian Process Variational Autoencoder [^11] | 2020 |
| Neural Net | GP-VAE | Gaussian Process Variational Autoencoder [^11] | 2020 |
| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 |
| Neural Net | M-RNN | Multi-directional Recurrent Neural Network [^9] | 2019 |
| Naive | LOCF | Last Observation Carried Forward | - |
2 changes: 2 additions & 0 deletions docs/conf.py
@@ -61,6 +61,8 @@
"python": ("https://docs.python.org/3", None),
"sphinx": ("https://www.sphinx-doc.org/en/master", None),
"torch": ("https://pytorch.org/docs/master/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"pandas": ("https://pandas.pydata.org/docs/", None),
}

# configs for sphinx.ext.imgmath
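
The two new mappings let Sphinx resolve references to NumPy and pandas objects in docstrings. A small illustrative docstring (a hypothetical function, not taken from the repo) showing what this enables:

```python
def impute(X):
    """Impute missing values in ``X``.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        With the ``numpy`` intersphinx mapping above, this role resolves to
        a link into the NumPy documentation when the docs are built.
    """
    ...
```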
22 changes: 14 additions & 8 deletions docs/index.rst
@@ -94,33 +94,35 @@ The rest of this readme file is organized as follows:

❖ PyPOTS Ecosystem
^^^^^^^^^^^^^^^^^^^
At PyPOTS, time series datasets are taken as coffee beans, and POTS datasets are incomplete coffee beans with missing parts that have their own meanings.
At PyPOTS, everything is related to coffee, something we are all familiar with. Yes, this is a coffee universe!
As you can see, there is a coffee pot in the PyPOTS logo.
And what else? Please read on ;-)

.. image:: https://pypots.com/figs/pypots_logos/TSDB_logo_FFBG.svg
:width: 130
:width: 150
:alt: TSDB logo
:align: left
:target: https://github.com/WenjieDu/TSDB

👈 To make various open-source time-series datasets readily available to our users,
PyPOTS gets supported by its ecosystem library *Time Series Data Beans (TSDB)*, a toolbox making loading time-series datasets super easy!
👈 At PyPOTS, time-series datasets are taken as coffee beans, and POTS datasets are incomplete coffee beans whose missing parts carry meanings of their own.
To make various public time-series datasets readily available to users,
*Time Series Data Beans (TSDB)* was created to make loading time-series datasets super easy!
Visit `TSDB <https://github.com/WenjieDu/TSDB>`_ right now to learn more about this handy tool 🛠, which currently supports a total of 168 open-source datasets!

.. image:: https://pypots.com/figs/pypots_logos/PyGrinder_logo_FFBG.svg
:width: 130
:width: 150
:alt: PyGrinder logo
:align: right
:target: https://github.com/WenjieDu/PyGrinder

👉 To simulate the real-world data beans with missingness, the ecosystem library `PyGrinder <https://github.com/WenjieDu/PyGrinder>`_,
a toolkit helping grind your coffee beans into incomplete ones, is created. Missing patterns fall into three categories according to Robin's theory:cite:`rubin1976missing`:
a toolkit helping grind your coffee beans into incomplete ones, is created. Missing patterns fall into three categories according to Rubin's theory :cite:`rubin1976missing`:
MCAR (missing completely at random), MAR (missing at random), and MNAR (missing not at random).
PyGrinder supports all of them and additional functionalities related to missingness.
With PyGrinder, you can introduce synthetic missing values into your datasets with a single line of code.

.. image:: https://pypots.com/figs/pypots_logos/BrewPOTS_logo_FFBG.svg
:width: 130
:width: 150
:alt: BrewPOTS logo
:align: left
:target: https://github.com/WenjieDu/BrewPOTS
@@ -130,7 +132,7 @@ Considering the future workload, PyPOTS tutorials is released in a single repo,
and you can find them in `BrewPOTS <https://github.com/WenjieDu/BrewPOTS>`_.
Take a look at it now, and learn how to brew your POTS datasets!

☕️ Enjoy it and have fun!
☕️ Welcome to the universe of PyPOTS. Enjoy it and have fun!


❖ Installation
@@ -154,6 +156,10 @@ Additionally, we present you a usage example of imputing missing values in time
^^^^^^^^^^^^^^^^^^^^^^^
PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values. The algorithms currently available for these four tasks are cataloged in the following table, which has four partitions. The paper references are all listed at the bottom of this readme file. Please refer to them if you want more details.


🌟 Since **v0.2**, all neural-network models in PyPOTS have hyperparameter-optimization support.
This functionality is implemented with the `Microsoft NNI <https://github.com/microsoft/nni>`_ framework.
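
A tuning run also needs a search space for NNI to sample from. A minimal, purely illustrative example (the hyperparameter names and ranges below are hypothetical, not a recommended configuration):

.. code-block:: python

    # Hypothetical NNI search space: the tuner samples one value per trial
    # from each entry; "lr" is then handed to the optimizer rather than the model.
    search_space = {
        "lr": {"_type": "loguniform", "_value": [1e-4, 1e-2]},
        "n_layers": {"_type": "choice", "_value": [1, 2, 3]},
        "d_model": {"_type": "choice", "_value": [64, 128, 256]},
    }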

============================== ================ ======================================================================================== ====== =========
Task Type Algorithm Year Reference
============================== ================ ======================================================================================== ====== =========
15 changes: 10 additions & 5 deletions pypots/classification/base.py
@@ -318,15 +318,20 @@ def _train_model(
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)

logger.info(
f"epoch {epoch}: "
f"training loss {mean_train_loss:.4f}, "
f"validating loss {mean_val_loss:.4f}"
f"Epoch {epoch} - "
f"training loss: {mean_train_loss:.4f}, "
f"validating loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
mean_loss = mean_train_loss

if np.isnan(mean_loss):
logger.warning(
f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
)

if mean_loss < self.best_loss:
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
@@ -363,7 +368,7 @@ def _train_model(
"If you don't want it, please try fit() again."
)

if np.equal(self.best_loss, float("inf")):
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is NaN after training.")

logger.info("Finished training.")
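
Taken together, the two changes in this hunk guard both ends of training: warn the moment an epoch's mean loss turns NaN, and fail afterwards if the best loss ended up NaN. A minimal standalone sketch of the pattern (the loss values are fabricated for illustration):

```python
import numpy as np

best_loss = float("inf")
for epoch, mean_loss in enumerate([0.91, 0.74, float("nan")]):  # fabricated losses
    if np.isnan(mean_loss):
        # NaN never compares as smaller, so it can silently stall model
        # selection; surface it as soon as it appears.
        print(f"Attention: got NaN loss in epoch {epoch}.")
    if mean_loss < best_loss:  # always False when mean_loss is NaN
        best_loss = mean_loss

if np.isnan(best_loss):
    raise ValueError("Something is wrong. best_loss is NaN after training.")
```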
34 changes: 26 additions & 8 deletions pypots/cli/tuning.py
@@ -10,9 +10,10 @@
from argparse import ArgumentParser, Namespace

from .base import BaseCommand
from ..classification import Raindrop, GRUD, BRITS
from ..classification import BRITS as BRITS_classification
from ..classification import Raindrop, GRUD
from ..clustering import CRLI, VaDER
from ..imputation import SAITS, Transformer, CSDI, USGAN, GPVAE, MRNN
from ..imputation import SAITS, Transformer, CSDI, USGAN, GPVAE, MRNN, BRITS
from ..optim import Adam
from ..utils.logging import logger

@@ -24,7 +25,6 @@
"but is missing in the current environment."
)


NN_MODELS = {
# imputation models
"pypots.imputation.SAITS": SAITS,
@@ -36,7 +36,7 @@
"pypots.imputation.MRNN": MRNN,
# classification models
"pypots.classification.GRUD": GRUD,
"pypots.classification.BRITS": BRITS,
"pypots.classification.BRITS": BRITS_classification,
"pypots.classification.Raindrop": Raindrop,
# clustering models
"pypots.clustering.CRLI": CRLI,
@@ -123,15 +123,33 @@ def checkup(self):
def run(self):
"""Execute the given command."""
if os.getenv("enable_tuning", False):
# fetch the next set of hyperparameters from NNI tuner
# fetch a new set of hyperparameters from NNI tuner
tuner_params = nni.get_next_parameter()
# get the specified NN class
# get the specified model class
model_class = NN_MODELS[self._model]
# pop out the learning rate
lr = tuner_params.pop("lr")

# check if hyperparameters match
model_all_arguments = model_class.__init__.__annotations__.keys()
tuner_params_set = set(tuner_params.keys())
model_arguments_set = set(model_all_arguments)
if_hyperparameter_match = tuner_params_set.issubset(model_arguments_set)
if not if_hyperparameter_match: # raise runtime error if mismatch
hyperparameter_intersection = tuner_params_set.intersection(
model_arguments_set
)
mismatched = tuner_params_set.difference(
set(hyperparameter_intersection)
)
raise RuntimeError(
f"Hyperparameters do not match. Mismatched hyperparameters "
f"(in the tuning configuration but not in the given model's arguments): {list(mismatched)}"
)

# initializing optimizer and model
# if tuning a GAN model, we need two optimizers
if "G_optimizer" in model_class.__init__.__annotations__.keys():
if "G_optimizer" in model_all_arguments:
# optimizer for the generator
tuner_params["G_optimizer"] = Adam(lr=lr)
# optimizer for the discriminator
@@ -144,4 +162,4 @@ def run(self):
# train the model and report to NNI
model.fit(train_set=self._train_set, val_set=self._val_set)
else:
logger.error("Argument `enable_tuning` is not set. Aborting...")
raise RuntimeError("Argument `enable_tuning` is not set. Aborting...")
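
The new hyperparameter check relies on `__init__.__annotations__`, which holds exactly the constructor's annotated parameters. A standalone sketch of the same set logic (`DemoModel` and all values here are made up):

```python
class DemoModel:
    def __init__(self, n_steps: int, n_features: int, n_layers: int = 2):
        self.n_steps, self.n_features, self.n_layers = n_steps, n_features, n_layers

# "d_model" is not an argument of DemoModel, so it must be reported.
tuner_params = {"n_steps": 24, "n_features": 7, "d_model": 128}

# Note: only *annotated* parameters show up in __annotations__.
model_arguments_set = set(DemoModel.__init__.__annotations__.keys())
tuner_params_set = set(tuner_params.keys())
if not tuner_params_set.issubset(model_arguments_set):
    mismatched = tuner_params_set - model_arguments_set
    raise RuntimeError(f"Hyperparameters do not match. Mismatched: {sorted(mismatched)}")
```

The intersection-then-difference detour in the hunk above computes the same set as the plain difference `tuner_params_set - model_arguments_set`.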
15 changes: 10 additions & 5 deletions pypots/clustering/base.py
@@ -317,15 +317,20 @@ def _train_model(

mean_val_loss = np.mean(epoch_val_loss_collector)
logger.info(
f"epoch {epoch}: "
f"training loss {mean_train_loss:.4f}, "
f"validating loss {mean_val_loss:.4f}"
f"Epoch {epoch} - "
f"training loss: {mean_train_loss:.4f}, "
f"validating loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
mean_loss = mean_train_loss

if np.isnan(mean_loss):
logger.warning(
f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
)

if mean_loss < self.best_loss:
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
@@ -357,7 +362,7 @@ def _train_model(
"If you don't want it, please try fit() again."
)

if np.equal(self.best_loss, float("inf")):
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is NaN after training.")

logger.info("Finished training.")
21 changes: 13 additions & 8 deletions pypots/clustering/crli/model.py
@@ -282,20 +282,25 @@ def _train_model(
}
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)
logger.info(
f"epoch {epoch}: "
f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
f"training loss_discriminator {mean_epoch_train_D_loss:.4f}, "
f"validating loss_generator {mean_val_G_loss:.4f}"
f"Epoch {epoch} - "
f"generator training loss: {mean_epoch_train_G_loss:.4f}, "
f"discriminator training loss: {mean_epoch_train_D_loss:.4f}, "
f"generator validating loss: {mean_val_G_loss:.4f}"
)
mean_loss = mean_val_G_loss
else:
logger.info(
f"epoch {epoch}: "
f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
f"training loss_discriminator {mean_epoch_train_D_loss:.4f}"
f"Epoch {epoch} - "
f"generator training loss: {mean_epoch_train_G_loss:.4f}, "
f"discriminator training loss: {mean_epoch_train_D_loss:.4f}"
)
mean_loss = mean_epoch_train_G_loss

if np.isnan(mean_loss):
logger.warning(
f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
)

if mean_loss < self.best_loss:
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
@@ -332,7 +337,7 @@ def _train_model(
"If you don't want it, please try fit() again."
)

if np.equal(self.best_loss, float("inf")):
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is NaN after training.")

logger.info("Finished training.")
15 changes: 10 additions & 5 deletions pypots/clustering/vader/model.py
@@ -300,15 +300,20 @@ def _train_model(
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)

logger.info(
f"epoch {epoch}: "
f"training loss {mean_train_loss:.4f}, "
f"validating loss {mean_val_loss:.4f}"
f"Epoch {epoch} - "
f"training loss: {mean_train_loss:.4f}, "
f"validating loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
mean_loss = mean_train_loss

if np.isnan(mean_loss):
logger.warning(
f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
)

if mean_loss < self.best_loss:
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
@@ -345,7 +350,7 @@ def _train_model(
"If you don't want it, please try fit() again."
)

if np.equal(self.best_loss, float("inf")):
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is NaN after training.")

logger.info("Finished training.")
15 changes: 10 additions & 5 deletions pypots/forecasting/base.py
@@ -303,15 +303,20 @@ def _train_model(
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)

logger.info(
f"epoch {epoch}: "
f"training loss {mean_train_loss:.4f}, "
f"validating loss {mean_val_loss:.4f}"
f"Epoch {epoch} - "
f"training loss: {mean_train_loss:.4f}, "
f"validating loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
mean_loss = mean_train_loss

if np.isnan(mean_loss):
logger.warning(
f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
)

if mean_loss < self.best_loss:
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
@@ -343,7 +348,7 @@ def _train_model(
"If you don't want it, please try fit() again."
)

if np.equal(self.best_loss, float("inf")):
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is NaN after training.")

logger.info("Finished training.")
