diff --git a/README.md b/README.md
index 8408a13c..2ae9d140 100644
--- a/README.md
+++ b/README.md
@@ -174,7 +174,12 @@ mae = cal_mae(imputation, X_intact, indicating_mask) # calculate mean absolute
## ❖ Available Algorithms
-PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values. The currently available algorithms of four tasks are cataloged in the following table with four partitions. The paper references are all listed at the bottom of this readme file. Please refer to them if you want more details.
+PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values.
+The currently available algorithms for these four tasks are cataloged in the following table, which is split into four partitions.
+The paper references are all listed at the bottom of this readme file. Please refer to them if you want more details.
+
+🌟 Since **v0.2**, all neural-network models in PyPOTS have hyperparameter-optimization support.
+This functionality is implemented with the [Microsoft NNI](https://github.com/microsoft/nni) framework.
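+
+As a quick sketch, an NNI search space can be declared like the Python dict below (the parameter names here are illustrative and not tied to any specific PyPOTS model):
+
+```python
+# hypothetical search space for tuning; adapt the keys to the arguments of the model you tune
+search_space = {
+    "lr": {"_type": "loguniform", "_value": [1e-4, 1e-2]},  # sample the learning rate log-uniformly
+    "n_layers": {"_type": "choice", "_value": [1, 2, 3]},  # try 1, 2, or 3 layers
+}
+```
+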
| ***`Imputation`*** | 🚥 | 🚥 | 🚥 |
|:----------------------:|:-----------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|
@@ -183,7 +188,7 @@ PyPOTS supports imputation, classification, clustering, and forecasting tasks on
| Neural Net | Transformer | Attention is All you Need [^2];<br>Self-Attention-based Imputation for Time Series [^1];<br>Note: proposed in [^2], and re-implemented as an imputation model in [^1]. | 2017 |
| Neural Net | CSDI | Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation [^12] | 2021 |
| Neural Net | US-GAN | Unsupervised GAN for Multivariate Time Series Imputation [^10] | 2021 |
-| Neural Net | GP-VAE | Gaussian Process Variational Autoencoder [^11] | 2020 |
+| Neural Net | GP-VAE | Gaussian Process Variational Autoencoder [^11] | 2020 |
| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 |
| Neural Net | M-RNN | Multi-directional Recurrent Neural Network [^9] | 2019 |
| Naive | LOCF | Last Observation Carried Forward | - |
diff --git a/docs/conf.py b/docs/conf.py
index 1dff9c8a..d1f3f155 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -61,6 +61,8 @@
"python": ("https://docs.python.org/3", None),
"sphinx": ("https://www.sphinx-doc.org/en/master", None),
"torch": ("https://pytorch.org/docs/master/", None),
+ "numpy": ("https://numpy.org/doc/stable/", None),
+ "pandas": ("https://pandas.pydata.org/docs/", None),
}
# configs for sphinx.ext.imgmath
diff --git a/docs/index.rst b/docs/index.rst
index 75521d2f..917eb08e 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -94,33 +94,35 @@ The rest of this readme file is organized as follows:
❖ PyPOTS Ecosystem
^^^^^^^^^^^^^^^^^^^
-At PyPOTS, time series datasets are taken as coffee beans, and POTS datasets are incomplete coffee beans with missing parts that have their own meanings.
+At PyPOTS, everything is related to coffee, something we're all familiar with. Yes, this is a coffee universe!
As you can see, there is a coffee pot in the PyPOTS logo.
+And what else? Please read on ;-)
.. image:: https://pypots.com/figs/pypots_logos/TSDB_logo_FFBG.svg
- :width: 130
+ :width: 150
:alt: TSDB logo
:align: left
:target: https://github.com/WenjieDu/TSDB
-👈 To make various open-source time-series datasets readily available to our users,
-PyPOTS gets supported by its ecosystem library *Time Series Data Beans (TSDB)*, a toolbox making loading time-series datasets super easy!
+👈 At PyPOTS, time series datasets are regarded as coffee beans, and POTS datasets are incomplete coffee beans whose missing parts carry meanings of their own.
+To make various public time-series datasets readily available to users,
+*Time Series Data Beans (TSDB)* was created to make loading time-series datasets super easy!
Visit `TSDB <https://github.com/WenjieDu/TSDB>`_ right now to learn more about this handy tool 🛠, which now supports a total of 168 open-source datasets!
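+
+For a quick taste, loading one of them could look like the sketch below (the ``load_dataset`` helper name is an assumption here; see the TSDB docs for the exact API):
+
+.. code-block:: python
+
+   import tsdb
+
+   # download (if needed) and load a public dataset by its name
+   data = tsdb.load_dataset("physionet_2012")
+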
.. image:: https://pypots.com/figs/pypots_logos/PyGrinder_logo_FFBG.svg
- :width: 130
+ :width: 150
:alt: PyGrinder logo
:align: right
:target: https://github.com/WenjieDu/PyGrinder
👉 To simulate the real-world data beans with missingness, the ecosystem library `PyGrinder <https://github.com/WenjieDu/PyGrinder>`_,
-a toolkit helping grind your coffee beans into incomplete ones, is created. Missing patterns fall into three categories according to Robin's theory:cite:`rubin1976missing`:
+a toolkit helping grind your coffee beans into incomplete ones, is created. Missing patterns fall into three categories according to Rubin's theory :cite:`rubin1976missing`:
MCAR (missing completely at random), MAR (missing at random), and MNAR (missing not at random).
PyGrinder supports all of them and additional functionalities related to missingness.
With PyGrinder, you can introduce synthetic missing values into your datasets with a single line of code.
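+
+For instance, such a single line could look like the sketch below (assuming PyGrinder exposes an ``mcar`` function; please check its docs for the exact signature and return values):
+
+.. code-block:: python
+
+   import numpy as np
+   from pygrinder import mcar
+
+   X = np.random.randn(32, 24, 10)  # 32 samples, 24 time steps, 10 features
+   X_with_missingness = mcar(X, 0.1)  # mask ~10% of the observed values completely at random
+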
.. image:: https://pypots.com/figs/pypots_logos/BrewPOTS_logo_FFBG.svg
- :width: 130
+ :width: 150
:alt: BrewPOTS logo
:align: left
:target: https://github.com/WenjieDu/BrewPOTS
@@ -130,7 +132,7 @@ Considering the future workload, PyPOTS tutorials is released in a single repo,
and you can find them in `BrewPOTS <https://github.com/WenjieDu/BrewPOTS>`_.
Take a look at it now, and learn how to brew your POTS datasets!
-☕️ Enjoy it and have fun!
+☕️ Welcome to the universe of PyPOTS. Enjoy it and have fun!
❖ Installation
@@ -154,6 +156,10 @@ Additionally, we present you a usage example of imputing missing values in time
^^^^^^^^^^^^^^^^^^^^^^^
PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values. The currently available algorithms of four tasks are cataloged in the following table with four partitions. The paper references are all listed at the bottom of this readme file. Please refer to them if you want more details.
+
+🌟 Since **v0.2**, all neural-network models in PyPOTS have hyperparameter-optimization support.
+This functionality is implemented with the `Microsoft NNI <https://github.com/microsoft/nni>`_ framework.
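+
+Under the hood, the tuning entry point fetches each trial's candidate hyperparameters via
+``nni.get_next_parameter()`` and pops out the learning rate for the optimizer, roughly as in this
+simplified sketch (adapted from ``pypots/cli/tuning.py``):
+
+.. code-block:: python
+
+   import nni
+
+   tuner_params = nni.get_next_parameter()  # one candidate set of hyperparameters per NNI trial
+   lr = tuner_params.pop("lr")  # the learning rate belongs to the optimizer, not the model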
+
============================== ================ ======================================================================================== ====== =========
Task Type Algorithm Year Reference
============================== ================ ======================================================================================== ====== =========
diff --git a/pypots/classification/base.py b/pypots/classification/base.py
index 37221cc9..75ee4682 100644
--- a/pypots/classification/base.py
+++ b/pypots/classification/base.py
@@ -318,15 +318,20 @@ def _train_model(
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)
logger.info(
- f"epoch {epoch}: "
- f"training loss {mean_train_loss:.4f}, "
- f"validating loss {mean_val_loss:.4f}"
+ f"Epoch {epoch} - "
+ f"training loss: {mean_train_loss:.4f}, "
+ f"validating loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
- logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
+ logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
mean_loss = mean_train_loss
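+        # a NaN mean_loss can never become the best_loss below (NaN comparisons are always False), so warn early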
+ if np.isnan(mean_loss):
+ logger.warning(
+ f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
+ )
+
if mean_loss < self.best_loss:
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
@@ -363,7 +368,7 @@ def _train_model(
"If you don't want it, please try fit() again."
)
-        if np.equal(self.best_loss, float("inf")):
-            raise ValueError("Something is wrong. best_loss is Nan after training.")
+        if np.isnan(self.best_loss) or np.isinf(self.best_loss):
+            raise ValueError("Something is wrong. best_loss is NaN/Inf after training.")
logger.info("Finished training.")
diff --git a/pypots/cli/tuning.py b/pypots/cli/tuning.py
index 473f3ab6..baf9d200 100644
--- a/pypots/cli/tuning.py
+++ b/pypots/cli/tuning.py
@@ -10,9 +10,10 @@
from argparse import ArgumentParser, Namespace
from .base import BaseCommand
-from ..classification import Raindrop, GRUD, BRITS
+from ..classification import BRITS as BRITS_classification
+from ..classification import Raindrop, GRUD
from ..clustering import CRLI, VaDER
-from ..imputation import SAITS, Transformer, CSDI, USGAN, GPVAE, MRNN
+from ..imputation import SAITS, Transformer, CSDI, USGAN, GPVAE, MRNN, BRITS
from ..optim import Adam
from ..utils.logging import logger
@@ -24,7 +25,6 @@
"but is missing in the current environment."
)
-
NN_MODELS = {
# imputation models
"pypots.imputation.SAITS": SAITS,
@@ -36,7 +36,7 @@
"pypots.imputation.MRNN": MRNN,
# classification models
"pypots.classification.GRUD": GRUD,
- "pypots.classification.BRITS": BRITS,
+ "pypots.classification.BRITS": BRITS_classification,
"pypots.classification.Raindrop": Raindrop,
# clustering models
"pypots.clustering.CRLI": CRLI,
@@ -123,15 +123,33 @@ def checkup(self):
def run(self):
"""Execute the given command."""
if os.getenv("enable_tuning", False):
- # fetch the next set of hyperparameters from NNI tuner
+ # fetch a new set of hyperparameters from NNI tuner
tuner_params = nni.get_next_parameter()
- # get the specified NN class
+ # get the specified model class
model_class = NN_MODELS[self._model]
# pop out the learning rate
lr = tuner_params.pop("lr")
+            # check that every tuned hyperparameter exists among the model's arguments
+            model_all_arguments = model_class.__init__.__annotations__.keys()
+            tuner_params_set = set(tuner_params.keys())
+            model_arguments_set = set(model_all_arguments)
+            if not tuner_params_set.issubset(model_arguments_set):
+                # raise a runtime error listing the hyperparameters the model doesn't accept
+                mismatched = tuner_params_set.difference(model_arguments_set)
+                raise RuntimeError(
+                    f"Hyperparameters do not match. Mismatched hyperparameters "
+                    f"(in the tuning configuration but not in the given model's arguments): {list(mismatched)}"
+                )
+
+            # initialize the optimizer and the model
# if tuning a GAN model, we need two optimizers
- if "G_optimizer" in model_class.__init__.__annotations__.keys():
+ if "G_optimizer" in model_all_arguments:
# optimizer for the generator
tuner_params["G_optimizer"] = Adam(lr=lr)
# optimizer for the discriminator
@@ -144,4 +162,4 @@ def run(self):
# train the model and report to NNI
model.fit(train_set=self._train_set, val_set=self._val_set)
else:
- logger.error("Argument `enable_tuning` is not set. Aborting...")
+ raise RuntimeError("Argument `enable_tuning` is not set. Aborting...")
diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py
index b0d2e8a0..f00118ee 100644
--- a/pypots/clustering/base.py
+++ b/pypots/clustering/base.py
@@ -317,15 +317,20 @@ def _train_model(
mean_val_loss = np.mean(epoch_val_loss_collector)
logger.info(
- f"epoch {epoch}: "
- f"training loss {mean_train_loss:.4f}, "
- f"validating loss {mean_val_loss:.4f}"
+ f"Epoch {epoch} - "
+ f"training loss: {mean_train_loss:.4f}, "
+ f"validating loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
- logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
+ logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
mean_loss = mean_train_loss
+ if np.isnan(mean_loss):
+ logger.warning(
+ f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
+ )
+
if mean_loss < self.best_loss:
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
@@ -357,7 +362,7 @@ def _train_model(
"If you don't want it, please try fit() again."
)
-        if np.equal(self.best_loss, float("inf")):
-            raise ValueError("Something is wrong. best_loss is Nan after training.")
+        if np.isnan(self.best_loss) or np.isinf(self.best_loss):
+            raise ValueError("Something is wrong. best_loss is NaN/Inf after training.")
logger.info("Finished training.")
diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py
index dc95fd0a..40c90f18 100644
--- a/pypots/clustering/crli/model.py
+++ b/pypots/clustering/crli/model.py
@@ -282,20 +282,25 @@ def _train_model(
}
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)
logger.info(
- f"epoch {epoch}: "
- f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
- f"training loss_discriminator {mean_epoch_train_D_loss:.4f}, "
- f"validating loss_generator {mean_val_G_loss:.4f}"
+ f"Epoch {epoch} - "
+ f"generator training loss: {mean_epoch_train_G_loss:.4f}, "
+ f"discriminator training loss: {mean_epoch_train_D_loss:.4f}, "
+ f"generator validating loss: {mean_val_G_loss:.4f}"
)
mean_loss = mean_val_G_loss
else:
logger.info(
- f"epoch {epoch}: "
- f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
- f"training loss_discriminator {mean_epoch_train_D_loss:.4f}"
+ f"Epoch {epoch} - "
+ f"generator training loss: {mean_epoch_train_G_loss:.4f}, "
+ f"discriminator training loss: {mean_epoch_train_D_loss:.4f}"
)
mean_loss = mean_epoch_train_G_loss
+ if np.isnan(mean_loss):
+ logger.warning(
+ f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
+ )
+
if mean_loss < self.best_loss:
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
@@ -332,7 +337,7 @@ def _train_model(
"If you don't want it, please try fit() again."
)
-        if np.equal(self.best_loss, float("inf")):
-            raise ValueError("Something is wrong. best_loss is Nan after training.")
+        if np.isnan(self.best_loss) or np.isinf(self.best_loss):
+            raise ValueError("Something is wrong. best_loss is NaN/Inf after training.")
logger.info("Finished training.")
diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py
index 3639de33..3a3e1c5c 100644
--- a/pypots/clustering/vader/model.py
+++ b/pypots/clustering/vader/model.py
@@ -300,15 +300,20 @@ def _train_model(
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)
logger.info(
- f"epoch {epoch}: "
- f"training loss {mean_train_loss:.4f}, "
- f"validating loss {mean_val_loss:.4f}"
+ f"Epoch {epoch} - "
+ f"training loss: {mean_train_loss:.4f}, "
+ f"validating loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
- logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
+ logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
mean_loss = mean_train_loss
+ if np.isnan(mean_loss):
+ logger.warning(
+ f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
+ )
+
if mean_loss < self.best_loss:
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
@@ -345,7 +350,7 @@ def _train_model(
"If you don't want it, please try fit() again."
)
-        if np.equal(self.best_loss, float("inf")):
-            raise ValueError("Something is wrong. best_loss is Nan after training.")
+        if np.isnan(self.best_loss) or np.isinf(self.best_loss):
+            raise ValueError("Something is wrong. best_loss is NaN/Inf after training.")
logger.info("Finished training.")
diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py
index ca857a2f..cd86583e 100644
--- a/pypots/forecasting/base.py
+++ b/pypots/forecasting/base.py
@@ -303,15 +303,20 @@ def _train_model(
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)
logger.info(
- f"epoch {epoch}: "
- f"training loss {mean_train_loss:.4f}, "
- f"validating loss {mean_val_loss:.4f}"
+ f"Epoch {epoch} - "
+ f"training loss: {mean_train_loss:.4f}, "
+ f"validating loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
- logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
+ logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
mean_loss = mean_train_loss
+ if np.isnan(mean_loss):
+ logger.warning(
+ f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
+ )
+
if mean_loss < self.best_loss:
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
@@ -343,7 +348,7 @@ def _train_model(
"If you don't want it, please try fit() again."
)
-        if np.equal(self.best_loss, float("inf")):
-            raise ValueError("Something is wrong. best_loss is Nan after training.")
+        if np.isnan(self.best_loss) or np.isinf(self.best_loss):
+            raise ValueError("Something is wrong. best_loss is NaN/Inf after training.")
logger.info("Finished training.")
diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py
index 38f93b31..cd0cbe69 100644
--- a/pypots/imputation/base.py
+++ b/pypots/imputation/base.py
@@ -314,15 +314,20 @@ def _train_model(
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)
logger.info(
- f"epoch {epoch}: "
- f"training loss {mean_train_loss:.4f}, "
- f"validating loss {mean_val_loss:.4f}"
+ f"Epoch {epoch} - "
+ f"training loss: {mean_train_loss:.4f}, "
+ f"validating loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
- logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
+ logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
mean_loss = mean_train_loss
+ if np.isnan(mean_loss):
+ logger.warning(
+ f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
+ )
+
if mean_loss < self.best_loss:
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
@@ -359,7 +364,7 @@ def _train_model(
"If you don't want it, please try fit() again."
)
-        if np.equal(self.best_loss.item(), float("inf")):
-            raise ValueError("Something is wrong. best_loss is Nan after training.")
+        if np.isnan(self.best_loss) or np.isinf(self.best_loss):
+            raise ValueError("Something is wrong. best_loss is NaN/Inf after training.")
logger.info("Finished training.")
diff --git a/pypots/imputation/brits/model.py b/pypots/imputation/brits/model.py
index a791f75f..e9627c93 100644
--- a/pypots/imputation/brits/model.py
+++ b/pypots/imputation/brits/model.py
@@ -192,6 +192,16 @@ def fit(
"X_intact": hf["X_intact"][:],
"indicating_mask": hf["indicating_mask"][:],
}
+
+ # check if X_intact contains missing values
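+            # X_intact acts as the ground truth when computing the validation loss, so it must be complete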
+ if np.isnan(val_set["X_intact"]).any():
+ val_set["X_intact"] = np.nan_to_num(val_set["X_intact"], nan=0)
+ logger.warning(
+ "X_intact shouldn't contain missing data but has NaN values. "
+ "PyPOTS has imputed them with zeros by default to start the training for now. "
+ "Please double-check your data if you have concerns over this operation."
+ )
+
val_set = DatasetForBRITS(val_set, return_labels=False, file_type=file_type)
val_loader = DataLoader(
val_set,
diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py
index 21af4121..903a4066 100644
--- a/pypots/imputation/csdi/model.py
+++ b/pypots/imputation/csdi/model.py
@@ -15,6 +15,7 @@
from typing import Union, Optional
+import h5py
import numpy as np
import torch
from torch.utils.data import DataLoader
@@ -223,6 +224,29 @@ def fit(
)
val_loader = None
if val_set is not None:
+ if isinstance(val_set, str):
+ with h5py.File(val_set, "r") as hf:
+ # Here we read the whole validation set from the file to mask a portion for validation.
+                # In PyPOTS, a file is usually used because the data is too big to fit into memory. However,
+                # the validation set generally shouldn't be too large. For example, even if we have 1 billion
+                # samples for model training, we won't take 20% of them as the validation set, because we want
+                # as much data as possible for the training stage to enhance the model's generalization ability.
+                # Therefore, 100,000 representative samples are enough to validate the model.
+ val_set = {
+ "X": hf["X"][:],
+ "X_intact": hf["X_intact"][:],
+ "indicating_mask": hf["indicating_mask"][:],
+ }
+
+ # check if X_intact contains missing values
+ if np.isnan(val_set["X_intact"]).any():
+ val_set["X_intact"] = np.nan_to_num(val_set["X_intact"], nan=0)
+ logger.warning(
+ "X_intact shouldn't contain missing data but has NaN values. "
+ "PyPOTS has imputed them with zeros by default to start the training for now. "
+ "Please double-check your data if you have concerns over this operation."
+ )
+
val_set = DatasetForCSDI(val_set, return_labels=False, file_type=file_type)
val_loader = DataLoader(
val_set,
diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py
index d8445597..c8f61199 100644
--- a/pypots/imputation/gpvae/model.py
+++ b/pypots/imputation/gpvae/model.py
@@ -208,6 +208,16 @@ def fit(
"X_intact": hf["X_intact"][:],
"indicating_mask": hf["indicating_mask"][:],
}
+
+ # check if X_intact contains missing values
+ if np.isnan(val_set["X_intact"]).any():
+ val_set["X_intact"] = np.nan_to_num(val_set["X_intact"], nan=0)
+ logger.warning(
+ "X_intact shouldn't contain missing data but has NaN values. "
+ "PyPOTS has imputed them with zeros by default to start the training for now. "
+ "Please double-check your data if you have concerns over this operation."
+ )
+
val_set = DatasetForGPVAE(val_set, return_labels=False, file_type=file_type)
val_loader = DataLoader(
val_set,
diff --git a/pypots/imputation/mrnn/model.py b/pypots/imputation/mrnn/model.py
index 0b7cd731..25f14783 100644
--- a/pypots/imputation/mrnn/model.py
+++ b/pypots/imputation/mrnn/model.py
@@ -187,6 +187,16 @@ def fit(
"X_intact": hf["X_intact"][:],
"indicating_mask": hf["indicating_mask"][:],
}
+
+ # check if X_intact contains missing values
+ if np.isnan(val_set["X_intact"]).any():
+ val_set["X_intact"] = np.nan_to_num(val_set["X_intact"], nan=0)
+ logger.warning(
+ "X_intact shouldn't contain missing data but has NaN values. "
+ "PyPOTS has imputed them with zeros by default to start the training for now. "
+ "Please double-check your data if you have concerns over this operation."
+ )
+
val_set = DatasetForMRNN(val_set, return_labels=False, file_type=file_type)
val_loader = DataLoader(
val_set,
diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py
index 6cd68fe4..2eb52246 100644
--- a/pypots/imputation/saits/model.py
+++ b/pypots/imputation/saits/model.py
@@ -265,6 +265,15 @@ def fit(
"indicating_mask": hf["indicating_mask"][:],
}
+ # check if X_intact contains missing values
+ if np.isnan(val_set["X_intact"]).any():
+ val_set["X_intact"] = np.nan_to_num(val_set["X_intact"], nan=0)
+ logger.warning(
+ "X_intact shouldn't contain missing data but has NaN values. "
+ "PyPOTS has imputed them with zeros by default to start the training for now. "
+ "Please double-check your data if you have concerns over this operation."
+ )
+
val_set = BaseDataset(val_set, return_labels=False, file_type=file_type)
val_loader = DataLoader(
val_set,
diff --git a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py
index 8fa86508..f96b772e 100644
--- a/pypots/imputation/transformer/model.py
+++ b/pypots/imputation/transformer/model.py
@@ -262,6 +262,15 @@ def fit(
"indicating_mask": hf["indicating_mask"][:],
}
+ # check if X_intact contains missing values
+ if np.isnan(val_set["X_intact"]).any():
+ val_set["X_intact"] = np.nan_to_num(val_set["X_intact"], nan=0)
+ logger.warning(
+ "X_intact shouldn't contain missing data but has NaN values. "
+ "PyPOTS has imputed them with zeros by default to start the training for now. "
+ "Please double-check your data if you have concerns over this operation."
+ )
+
val_set = BaseDataset(val_set, return_labels=False, file_type=file_type)
val_loader = DataLoader(
val_set,
diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py
index ce51ca19..66cd6a6c 100644
--- a/pypots/imputation/usgan/model.py
+++ b/pypots/imputation/usgan/model.py
@@ -9,6 +9,7 @@
# Created by Jun Wang and Wenjie Du
# License: BSD-3-Clause
+import os
from typing import Union, Optional
import h5py
@@ -23,6 +24,11 @@
from ...optim.base import Optimizer
from ...utils.logging import logger
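+
+# NNI is an optional dependency, needed only when hyperparameter tuning is enabled,
+# so a failed import is deliberately ignored here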
+try:
+ import nni
+except ImportError:
+ pass
+
class USGAN(BaseNNImputer):
"""The PyTorch implementation of the USGAN model. Refer to :cite:`miao2021SSGAN`.
@@ -257,12 +263,43 @@ def _train_model(
)
mean_epoch_train_D_loss = np.mean(epoch_train_loss_D_collector)
mean_epoch_train_G_loss = np.mean(epoch_train_loss_G_collector)
- logger.info(
- f"epoch {epoch}: "
- f"training loss_generator {mean_epoch_train_G_loss:.4f}, "
- f"train loss_discriminator {mean_epoch_train_D_loss:.4f}"
- )
- mean_loss = mean_epoch_train_G_loss
+
+ if val_loader is not None:
+ self.model.eval()
+ epoch_val_loss_G_collector = []
+ with torch.no_grad():
+ for idx, data in enumerate(val_loader):
+ inputs = self._assemble_input_for_validating(data)
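+                        # keep training=True so the forward pass also computes the generation loss used for validation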
+ results = self.model.forward(inputs, training=True)
+ epoch_val_loss_G_collector.append(
+ results["generation_loss"].sum().item()
+ )
+ mean_val_G_loss = np.mean(epoch_val_loss_G_collector)
+ # save validating loss logs into the tensorboard file for every epoch if in need
+ if self.summary_writer is not None:
+ val_loss_dict = {
+ "generation_loss": mean_val_G_loss,
+ }
+ self._save_log_into_tb_file(epoch, "validating", val_loss_dict)
+ logger.info(
+ f"Epoch {epoch} - "
+ f"generator training loss: {mean_epoch_train_G_loss:.4f}, "
+ f"discriminator training loss: {mean_epoch_train_D_loss:.4f}, "
+ f"generator validating loss: {mean_val_G_loss:.4f}"
+ )
+ mean_loss = mean_val_G_loss
+ else:
+ logger.info(
+ f"Epoch {epoch} - "
+ f"generator training loss: {mean_epoch_train_G_loss:.4f}, "
+ f"discriminator training loss: {mean_epoch_train_D_loss:.4f}"
+ )
+ mean_loss = mean_epoch_train_G_loss
+
+ if np.isnan(mean_loss):
+ logger.warning(
+ f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
+ )
if mean_loss < self.best_loss:
self.best_loss = mean_loss
@@ -275,11 +312,18 @@ def _train_model(
)
else:
self.patience -= 1
- if self.patience == 0:
- logger.info(
- "Exceeded the training patience. Terminating the training procedure..."
- )
- break
+
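+            # when tuning is enabled, report the loss to NNI so the tuner can track and compare trials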
+ if os.getenv("enable_tuning", False):
+ nni.report_intermediate_result(mean_loss)
+ if epoch == self.epochs - 1 or self.patience == 0:
+ nni.report_final_result(self.best_loss)
+
+ if self.patience == 0:
+ logger.info(
+ "Exceeded the training patience. Terminating the training procedure..."
+ )
+ break
+
except Exception as e:
logger.error(f"Exception: {e}")
if self.best_model_dict is None:
@@ -293,7 +337,7 @@ def _train_model(
"If you don't want it, please try fit() again."
)
-        if np.equal(self.best_loss, float("inf")):
-            raise ValueError("Something is wrong. best_loss is Nan after training.")
+        if np.isnan(self.best_loss) or np.isinf(self.best_loss):
+            raise ValueError("Something is wrong. best_loss is NaN/Inf after training.")
logger.info("Finished training.")
@@ -329,6 +373,16 @@ def fit(
"X_intact": hf["X_intact"][:],
"indicating_mask": hf["indicating_mask"][:],
}
+
+ # check if X_intact contains missing values
+ if np.isnan(val_set["X_intact"]).any():
+ val_set["X_intact"] = np.nan_to_num(val_set["X_intact"], nan=0)
+ logger.warning(
+ "X_intact shouldn't contain missing data but has NaN values. "
+ "PyPOTS has imputed them with zeros by default to start the training for now. "
+ "Please double-check your data if you have concerns over this operation."
+ )
+
val_set = DatasetForUSGAN(val_set, return_labels=False, file_type=file_type)
val_loader = DataLoader(
val_set,