diff --git a/README.md b/README.md index 9c86f08a..bc9535b5 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@

Welcome to PyPOTS

-**

A Python Toolbox for Data Mining on Partially-Observed Time Series

** +

a Python toolbox for data mining on Partially-Observed Time Series

diff --git a/docs/pypots.forecasting.rst b/docs/pypots.forecasting.rst index c4ac76b7..5cd6eaa1 100644 --- a/docs/pypots.forecasting.rst +++ b/docs/pypots.forecasting.rst @@ -1,31 +1,10 @@ pypots.forecasting package ========================== -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - pypots.forecasting.bttf - pypots.forecasting.template - -Submodules ----------- - -pypots.forecasting.base module +pypots.forecasting.bttf module ------------------------------ -.. automodule:: pypots.forecasting.base - :members: - :undoc-members: - :show-inheritance: - :inherited-members: - -Module contents ---------------- - -.. automodule:: pypots.forecasting +.. automodule:: pypots.forecasting.bttf :members: :undoc-members: :show-inheritance: diff --git a/docs/references.bib b/docs/references.bib index 214b69b8..8aa37c62 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -119,8 +119,8 @@ @article{du2023SAITS pages = {119619}, year = {2023}, issn = {0957-4174}, -doi = {https://doi.org/10.1016/j.eswa.2023.119619}, -url = {https://www.sciencedirect.com/science/article/pii/S0957417423001203}, +doi = {10.1016/j.eswa.2023.119619}, +url = {https://arxiv.org/abs/2202.08516}, author = {Wenjie Du and David Cote and Yan Liu}, } @article{fortuin2020GPVAEDeep, @@ -418,3 +418,35 @@ @inproceedings{reddi2018OnTheConvergence year={2018}, url={https://openreview.net/forum?id=ryQu7f-RZ}, } + +@article{hubert1985, + title={Comparing partitions}, + author={Hubert, Lawrence and Arabie, Phipps}, + journal={Journal of classification}, + volume={2}, + pages={193--218}, + year={1985}, + publisher={Springer} +} + +@article{steinley2004, + title={Properties of the hubert-arable adjusted rand index}, + author={Steinley, Douglas}, + journal={Psychological methods}, + volume={9}, + number={3}, + pages={386}, + year={2004}, + publisher={American Psychological Association} +} + +@article{calinski1974, + title={A dendrite method for cluster analysis}, + author={Cali{\'n}ski, Tadeusz and Harabasz, Jerzy}, + journal={Communications in Statistics-theory and Methods}, + volume={3}, + number={1}, + pages={1--27}, + year={1974}, + publisher={Taylor \& Francis} +} \ No newline at end of file diff --git a/environment-dev.yml b/environment-dev.yml index 281fb5a9..404a6a09 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -11,8 +11,8 @@ dependencies: #- conda-forge::python #- conda-forge::pip #- conda-forge::scipy - #- conda-forge::numpy >=1.23.3 # numpy should , otherwise may encounter "number not available" when torch>1.11 - #- conda-forge::scikit-learn >=0.24.1 + #- conda-forge::numpy + #- conda-forge::scikit-learn #- conda-forge::pandas <2.0.0 #- conda-forge::h5py #- conda-forge::tensorboard diff --git a/pypots/__init__.py b/pypots/__init__.py index 72846c15..731e11cd 100644 --- a/pypots/__init__.py +++ b/pypots/__init__.py @@ -22,7 +22,7 @@ # # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' -__version__ = "0.1.2" +__version__ = "0.1.3" __all__ = [ diff --git a/pypots/base.py b/pypots/base.py index 52b659c2..24cb21f3 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -7,9 +7,11 @@ import os from abc import ABC +from abc import abstractmethod from datetime import datetime from typing import Optional, Union +import numpy as np import torch from torch.utils.tensorboard import SummaryWriter @@ -209,6 +211,33 @@ def _save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None if ("loss" in item_name) or ("error" in item_name): self.summary_writer.add_scalar(f"{stage}/{item_name}", loss.sum(), step) + def _auto_save_model_if_necessary( + self, + training_finished: bool = True, + saving_name: str = None, + ): + """Automatically save the current model into a file if in need. + + Parameters + ---------- + training_finished : + Whether the training is already finished when invoke this function. + The saving_strategy "better" only works when training_finished is False. + The saving_strategy "best" only works when training_finished is True. + + saving_name : + The file name of the saved model. + + """ + if self.saving_path is not None and self.model_saving_strategy is not None: + name = self.__class__.__name__ if saving_name is None else saving_name + if not training_finished and self.model_saving_strategy == "better": + self.save_model(self.saving_path, name) + elif training_finished and self.model_saving_strategy == "best": + self.save_model(self.saving_path, name) + else: + return + def save_model( self, saving_dir: str, @@ -258,33 +287,6 @@ def save_model( f'Failed to save the model to "{saving_path}" because of the below error! \n{e}' ) - def _auto_save_model_if_necessary( - self, - training_finished: bool = True, - saving_name: str = None, - ): - """Automatically save the current model into a file if in need. - - Parameters - ---------- - training_finished : - Whether the training is already finished when invoke this function. - The saving_strategy "better" only works when training_finished is False. - The saving_strategy "best" only works when training_finished is True. - - saving_name : - The file name of the saved model. - - """ - if self.saving_path is not None and self.model_saving_strategy is not None: - name = self.__class__.__name__ if saving_name is None else saving_name - if not training_finished and self.model_saving_strategy == "better": - self.save_model(self.saving_path, name) - elif training_finished and self.model_saving_strategy == "best": - self.save_model(self.saving_path, name) - else: - return - def load_model(self, model_path: str) -> None: """Load the saved model from a disk file. @@ -317,6 +319,72 @@ def load_model(self, model_path: str) -> None: raise e logger.info(f"Model loaded successfully from {model_path}.") + @abstractmethod + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + """Train the classifier on the given data. + + Parameters + ---------- + train_set : dict or str + The dataset for model training, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for training, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + val_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : str + The type of the given file if train_set and val_set are path strings. + + """ + raise NotImplementedError + + @abstractmethod + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : str + The type of the given file if test_set is a path string. + + Returns + ------- + result_dict: dict + Prediction results in a Python Dictionary for the given samples. + It should be a dictionary including keys as 'imputation', 'classification', 'clustering', and 'forecasting'. + For sure, only the keys that relevant tasks are supported by the model will be returned. + """ + raise NotImplementedError + class BaseNNModel(BaseModel): """The abstract class for all neural-network models. @@ -400,7 +468,7 @@ def __init__( else: assert ( patience <= epochs - ), f"patience must be smaller than epoches which is {epochs}, but got patience={patience}" + ), f"patience must be smaller than epochs which is {epochs}, but got patience={patience}" # training hype-parameters self.batch_size = batch_size @@ -421,3 +489,20 @@ def _print_model_size(self) -> None: logger.info( f"Model initialized successfully with the number of trainable parameters: {num_params:,}" ) + + @abstractmethod + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + raise NotImplementedError + + @abstractmethod + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + raise NotImplementedError diff --git a/pypots/classification/base.py b/pypots/classification/base.py index 34d6743f..b52dfc13 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -95,6 +95,14 @@ def fit( """ raise NotImplementedError + @abstractmethod + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + raise NotImplementedError + @abstractmethod def classify( self, @@ -117,6 +125,8 @@ def classify( array-like, shape [n_samples], Classification results of the given samples. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError @@ -402,4 +412,6 @@ def classify( array-like, shape [n_samples], Classification results of the given samples. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError diff --git a/pypots/classification/brits/model.py b/pypots/classification/brits/model.py index bbccb7ce..af1faec8 100644 --- a/pypots/classification/brits/model.py +++ b/pypots/classification/brits/model.py @@ -16,6 +16,7 @@ from typing import Optional, Union +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -29,6 +30,7 @@ ) from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger class _BRITS(imputation_BRITS, nn.Module): @@ -173,13 +175,12 @@ class BRITS(BaseNNClassifier): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying BRITS model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Cao, Wei, Dong Wang, Jian Li, Hao Zhou, Lei Li, and Yitan Li. + "Brits: Bidirectional recurrent imputation for time series." + Advances in neural information processing systems 31 (2018). + `_ """ @@ -327,23 +328,41 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def classify(self, X: Union[dict, str], file_type: str = "h5py"): + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForBRITS(X, return_labels=False, file_type=file_type) + test_set = DatasetForBRITS(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, ) - prediction_collector = [] + classification_collector = [] with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) results = self.model.forward(inputs, training=False) classification_pred = results["classification_pred"] - prediction_collector.append(classification_pred) + classification_collector.append(classification_pred) - predictions = torch.cat(prediction_collector) - return predictions.cpu().detach().numpy() + classification = torch.cat(classification_collector).cpu().detach().numpy() + result_dict = { + "classification": classification, + } + return result_dict + + def classify( + self, + X: Union[dict, str], + file_type: str = "h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + result_dict = self.predict(X, file_type=file_type) + return result_dict["classification"] diff --git a/pypots/classification/grud/model.py b/pypots/classification/grud/model.py index a9e4f6e6..60fd6482 100644 --- a/pypots/classification/grud/model.py +++ b/pypots/classification/grud/model.py @@ -23,6 +23,7 @@ from ...imputation.brits.modules import TemporalDecay from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger class _GRUD(nn.Module): @@ -169,13 +170,13 @@ class GRUD(BaseNNClassifier): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying GRU-D model. + .. [1] `Che, Zhengping, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu. + "Recurrent neural networks for multivariate time series with missing values." + Scientific reports 8, no. 1 (2018): 6085. + `_ - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. """ def __init__( @@ -303,23 +304,41 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForGRUD(X, return_labels=False, file_type=file_type) + test_set = DatasetForGRUD(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, ) - prediction_collector = [] + classification_collector = [] with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) results = self.model.forward(inputs, training=False) prediction = results["classification_pred"] - prediction_collector.append(prediction) + classification_collector.append(prediction) + + classification = torch.cat(classification_collector).cpu().detach().numpy() + result_dict = { + "classification": classification, + } + return result_dict - predictions = torch.cat(prediction_collector) - return predictions.cpu().detach().numpy() + def classify( + self, + X: Union[dict, str], + file_type: str = "h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + result_dict = self.predict(X, file_type=file_type) + return result_dict["classification"] diff --git a/pypots/classification/raindrop/model.py b/pypots/classification/raindrop/model.py index 75bd1470..8b9d1670 100644 --- a/pypots/classification/raindrop/model.py +++ b/pypots/classification/raindrop/model.py @@ -367,14 +367,12 @@ class Raindrop(BaseNNClassifier): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying Raindrop model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. - + .. [1] `Zhang, Xiang, Marko Zeman, Theodoros Tsiligkaridis, and Marinka Zitnik. + "Graph-guided network for irregularly sampled multivariate time series." + International Conference on Learning Representations (ICLR). 2022. + `_ """ def __init__( @@ -521,9 +519,13 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForGRUD(X, return_labels=False, file_type=file_type) + test_set = DatasetForGRUD(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -531,13 +533,28 @@ def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: num_workers=self.num_workers, ) - prediction_collector = [] + classification_collector = [] with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) results = self.model.forward(inputs, training=False) prediction = results["classification_pred"] - prediction_collector.append(prediction) + classification_collector.append(prediction) + + classification = torch.cat(classification_collector).cpu().detach().numpy() - predictions = torch.cat(prediction_collector) - return predictions.cpu().detach().numpy() + result_dict = { + "classification": classification, + } + return result_dict + + def classify( + self, + X: Union[dict, str], + file_type: str = "h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + result_dict = self.predict(X, file_type=file_type) + return result_dict["classification"] diff --git a/pypots/classification/template/model.py b/pypots/classification/template/model.py index f7e2d15a..9e625ab2 100644 --- a/pypots/classification/template/model.py +++ b/pypots/classification/template/model.py @@ -101,5 +101,9 @@ def fit( ) -> None: raise NotImplementedError - def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: raise NotImplementedError diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py index c9a27ad4..43d2e82e 100644 --- a/pypots/clustering/base.py +++ b/pypots/clustering/base.py @@ -64,6 +64,7 @@ def __init__( def fit( self, train_set: Union[dict, str], + val_set: Union[dict, str] = None, file_type: str = "h5py", ) -> None: """Train the cluster. @@ -78,11 +79,29 @@ def fit( If it is a path string, the path should point to a data file, e.g. a h5 file, which contains key-value pairs like a dict, and it has to include the key 'X'. + val_set : + The dataset for model validating, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + file_type : - The type of the given file if train_set is a path string. + The type of the given file if train_set and val_set are path strings. + """ raise NotImplementedError + @abstractmethod + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + raise NotImplementedError + @abstractmethod def cluster( self, @@ -105,6 +124,8 @@ def cluster( array-like, Clustering results. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError @@ -332,6 +353,7 @@ def _train_model( def fit( self, train_set: Union[dict, str], + val_set: Union[dict, str] = None, file_type: str = "h5py", ) -> None: """Train the cluster. @@ -346,8 +368,18 @@ def fit( If it is a path string, the path should point to a data file, e.g. a h5 file, which contains key-value pairs like a dict, and it has to include the key 'X'. + val_set : + The dataset for model validating, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + file_type : - The type of the given file if train_set is a path string. + The type of the given file if train_set and val_set are path strings. + """ raise NotImplementedError @@ -373,4 +405,6 @@ def cluster( array-like, Clustering results. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index 26f0a769..46805c38 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -196,13 +196,13 @@ class CRLI(BaseNNClusterer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying CRLI model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Ma, Qianli, Chuxin Chen, Sen Li, and Garrison W. Cottrell. 2021. + "Learning Representations for Incomplete Time Series Clustering". + Proceedings of the AAAI Conference on Artificial Intelligence 35 (10):8837-46. + https://doi.org/10.1609/aaai.v35i10.17070. + `_ """ @@ -415,14 +415,14 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def cluster( + def predict( self, - X: Union[dict, str], + test_set: Union[dict, str], file_type: str = "h5py", return_latent: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForCRLI(X, return_labels=False, file_type=file_type) + test_set = DatasetForCRLI(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -443,13 +443,33 @@ def cluster( clustering_latent = ( torch.cat(clustering_latent_collector).cpu().detach().numpy() ) - clustering_results = self.model.kmeans.fit_predict(clustering_latent) + clustering = self.model.kmeans.fit_predict(clustering_latent) latent_collector = { "clustering_latent": clustering_latent, "imputation": imputation, } + result_dict = { + "clustering": clustering, + } + + if return_latent: + result_dict["latent"] = latent_collector + + return result_dict + + def cluster( + self, + X: Union[dict, str], + file_type: str = "h5py", + return_latent: bool = False, + ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + + result_dict = self.predict(X, file_type, return_latent) if return_latent: - return clustering_results, latent_collector + return result_dict["clustering"], result_dict["latent"] - return clustering_results + return result_dict["clustering"] diff --git a/pypots/clustering/crli/modules.py b/pypots/clustering/crli/modules.py index d5413e37..f6837647 100644 --- a/pypots/clustering/crli/modules.py +++ b/pypots/clustering/crli/modules.py @@ -65,8 +65,11 @@ def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]: ) output_collector = torch.empty((bz, n_steps, self.d_input), device=self.device) if self.cell_type == "LSTM": - # TODO: cell states should have different shapes - cell_states = torch.zeros((self.d_input, self.d_hidden), device=self.device) + cell_states = [ + torch.zeros((bz, self.d_hidden), device=self.device) + for i in range(self.n_layer) + ] + for step in range(n_steps): x = X[:, step, :] estimation = self.output_layer(hidden_state) @@ -76,13 +79,14 @@ def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]: ) for i in range(self.n_layer): if i == 0: - hidden_state, cell_states = self.model[i]( - imputed_x, (hidden_state, cell_states) + hidden_state, cell_state = self.model[i]( + imputed_x, (hidden_state, cell_states[i]) ) else: - hidden_state, cell_states = self.model[i]( - hidden_state, (hidden_state, cell_states) + hidden_state, cell_state = self.model[i]( + hidden_state, (hidden_state, cell_states[i]) ) + hidden_state_collector[:, step, :] = hidden_state elif self.cell_type == "GRU": @@ -168,19 +172,27 @@ def forward(self, inputs: dict) -> torch.Tensor: ] hidden_state_collector = torch.empty((bz, n_steps, 32), device=self.device) if self.cell_type == "LSTM": - cell_states = torch.zeros((self.d_input, self.d_hidden), device=self.device) + cell_states = [ + torch.zeros((bz, 32), device=self.device), + torch.zeros((bz, 16), device=self.device), + torch.zeros((bz, 8), device=self.device), + torch.zeros((bz, 16), device=self.device), + torch.zeros((bz, 32), device=self.device), + ] for step in range(n_steps): x = imputed_X[:, step, :] for i, rnn_cell in enumerate(self.rnn_cell_module_list): if i == 0: - hidden_state, cell_states = rnn_cell( - x, (hidden_states[i], cell_states) + hidden_state, cell_state = rnn_cell( + x, (hidden_states[i], cell_states[i]) ) else: - hidden_state, cell_states = rnn_cell( - hidden_states[i - 1], (hidden_states[i], cell_states) + hidden_state, cell_state = rnn_cell( + hidden_states[i - 1], (hidden_states[i], cell_states[i]) ) + cell_states[i] = cell_state hidden_states[i] = hidden_state + hidden_state_collector[:, step, :] = hidden_state elif self.cell_type == "GRU": diff --git a/pypots/clustering/template/model.py b/pypots/clustering/template/model.py index d9b40cff..e51beb6d 100644 --- a/pypots/clustering/template/model.py +++ b/pypots/clustering/template/model.py @@ -101,5 +101,9 @@ def fit( ) -> None: raise NotImplementedError - def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: raise NotImplementedError diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index 7c85ad13..d6d910c5 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -328,13 +328,15 @@ class VaDER(BaseNNClusterer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying VaDER model. + .. [1] `de Jong, Johann, Mohammad Asif Emon, Ping Wu, Reagon Karki, Meemansa Sood, Patrice Godard, + Ashar Ahmad, Henri Vrooman, Martin Hofmann-Apitius, and Holger Fröhlich. + "Deep learning for clustering of multivariate clinical patient trajectories with missing values." + GigaScience 8, no. 11 (2019): giz134. + `_ + - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. """ @@ -602,14 +604,14 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def cluster( + def predict( self, - X: Union[dict, str], + test_set: Union[dict, str], file_type: str = "h5py", return_latent: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForVaDER(X, return_labels=False, file_type=file_type) + test_set = DatasetForVaDER(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -666,7 +668,7 @@ def func_to_apply( clustering_results = np.argmax(p, axis=0) clustering_results_collector.append(clustering_results) - clustering_results = np.concatenate(clustering_results_collector) + clustering = np.concatenate(clustering_results_collector) latent_collector = { "mu_tilde": np.concatenate(mu_tilde_collector), "stddev_tilde": np.concatenate(stddev_tilde_collector), @@ -677,7 +679,27 @@ def func_to_apply( "imputation": np.concatenate(imputed_X_collector), } + result_dict = { + "clustering": clustering, + } + + if return_latent: + result_dict["latent"] = latent_collector + + return result_dict + + def cluster( + self, + X: Union[dict, str], + file_type: str = "h5py", + return_latent: bool = False, + ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + + result_dict = self.predict(X, file_type, return_latent) if return_latent: - return clustering_results, latent_collector + return result_dict["clustering"], result_dict["latent"] - return clustering_results + return result_dict["clustering"] diff --git a/pypots/data/__init__.py b/pypots/data/__init__.py index a3a68be9..dc1bfbf8 100644 --- a/pypots/data/__init__.py +++ b/pypots/data/__init__.py @@ -8,8 +8,9 @@ from .base import BaseDataset from .generating import ( gene_complete_random_walk, - gene_random_walk_for_classification, - gene_incomplete_random_walk_dataset, + gene_complete_random_walk_for_anomaly_detection, + gene_complete_random_walk_for_classification, + gene_random_walk, gene_physionet2012, ) from .load_specific_datasets import ( @@ -29,8 +30,9 @@ "BaseDataset", # data generation "gene_complete_random_walk", - "gene_random_walk_for_classification", - "gene_incomplete_random_walk_dataset", + "gene_complete_random_walk_for_anomaly_detection", + "gene_complete_random_walk_for_classification", + "gene_random_walk", "gene_physionet2012", # list and load datasets "list_supported_datasets", diff --git a/pypots/data/base.py b/pypots/data/base.py index 1bef9f9c..c163e9ef 100644 --- a/pypots/data/base.py +++ b/pypots/data/base.py @@ -77,7 +77,7 @@ def __init__( y = None if "y" not in data.keys() else data["y"] self.X, self.y = self._check_input(X, y) - self.sample_num = self._get_sample_num() + self.n_samples, self.n_steps, self.n_features = self._get_data_sizes() # set up function fetch_data() if isinstance(self.data, str): @@ -85,25 +85,31 @@ def __init__( else: self.fetch_data = self._fetch_data_from_array - def _get_sample_num(self) -> int: + def _get_data_sizes(self) -> Tuple[int, int, int]: """Determine the number of samples in the dataset and return the number. Returns ------- - sample_num : + n_samples : The number of the samples in the given dataset. """ + if isinstance(self.data, str): if self.file_handle is None: self.file_handle = self._open_file_handle() - sample_num = len(self.file_handle["X"]) + n_samples = len(self.file_handle["X"]) + first_sample = self.file_handle["X"][0] + n_steps = len(first_sample) + n_features = first_sample.shape[-1] else: - sample_num = len(self.X) + n_samples = len(self.X) + n_steps = len(self.X[0]) + n_features = self.X[0].shape[-1] - return sample_num + return n_samples, n_steps, n_features def __len__(self) -> int: - return self.sample_num + return self.n_samples @staticmethod def _check_input( diff --git a/pypots/data/generating.py b/pypots/data/generating.py index e80efe49..f0a20473 100644 --- a/pypots/data/generating.py +++ b/pypots/data/generating.py @@ -26,26 +26,26 @@ def gene_complete_random_walk( std: float = 1.0, random_state: Optional[int] = None, ) -> np.ndarray: - """Generate complete random walk time-series data. + """Generate complete random walk time-series data, i.e. having no missing values. Parameters ---------- - n_samples : + n_samples : int, default=1000 The number of training time-series samples to generate. n_steps: int, default=24 The number of time steps (length) of generated time-series samples. - n_features : + n_features : int, default=10 The number of features (dimensions) of generated time-series samples. - mu : + mu : float, default=0.0 Mean of the normal distribution, which random walk steps are sampled from. - std : + std : float, default=1.0 Standard deviation of the normal distribution, which random walk steps are sampled from. - random_state : + random_state : int, default=None Random seed for data generation. Returns @@ -63,7 +63,7 @@ def gene_complete_random_walk( return ts_samples -def gene_random_walk_for_classification( +def gene_complete_random_walk_for_classification( n_classes: int = 2, n_samples_each_class: int = 500, n_steps: int = 24, @@ -75,37 +75,39 @@ def gene_random_walk_for_classification( Parameters ---------- - n_classes : + n_classes : int, must >=1, default=2 Number of classes (types) of the generated data. - n_samples_each_class : + n_samples_each_class : int, default=500 Number of samples for each class to generate. - n_steps : + n_steps : int, default=24 Number of time steps in each sample. - n_features : + n_features : int, default=10 Number of features. - shuffle : + shuffle : bool, default=True Whether to shuffle generated samples. If not, you can separate samples of each class according to `n_samples_each_class`. For example, X_class0=X[:n_samples_each_class], X_class1=X[n_samples_each_class:n_samples_each_class*2] - random_state : + random_state : int, default=None Random seed for data generation. Returns ------- - X : + X : array, shape of [n_samples, n_steps, n_features] Generated time-series data. - y : + y : array, shape of [n_samples] Labels indicating classes of time-series samples. """ + assert n_classes > 1, f"n_classes should be >1, but got {n_classes}" + ts_collector = [] label_collector = [] @@ -149,39 +151,39 @@ def gene_complete_random_walk_for_anomaly_detection( Parameters ---------- - n_samples : + n_samples : int, default=1000 The number of training time-series samples to generate. - n_features : + n_features : int, default=10 The number of features (dimensions) of generated time-series samples. n_steps: int, default=24 The number of time steps (length) of generated time-series samples. - mu : + mu : float, default=0.0 Mean of the normal distribution, which random walk steps are sampled from. - std : + std : float, default=1.0 Standard deviation of the normal distribution, which random walk steps are sampled from. - anomaly_proportion : + anomaly_proportion : float, default=0.1 Proportion of anomaly samples in all samples. - anomaly_fraction : + anomaly_fraction : float, default=0.02 Fraction of anomaly points in each anomaly sample. - anomaly_scale_factor : + anomaly_scale_factor : float, default=2.0 Scale factor for value scaling to create anomaly points in time series samples. - random_state : + random_state : int, default=None Random seed for data generation. Returns ------- - X : + X : array, shape of [n_samples, n_steps, n_features] Generated time-series data. - y : + y : array, shape of [n_samples] Labels indicating if time-series samples are anomalies. """ assert ( @@ -225,35 +227,41 @@ def gene_complete_random_walk_for_anomaly_detection( return X, y -def gene_incomplete_random_walk_dataset( - n_steps=24, n_features=10, n_classes=2, n_samples_each_class=1000, missing_rate=0.1 +def gene_random_walk( + n_steps=24, + n_features=10, + n_classes=2, + n_samples_each_class=1000, + missing_rate=0.1, ) -> dict: """Generate a random-walk data. Parameters ---------- - n_steps : + n_steps : int, default=24 Number of time steps in each sample. - n_features : + n_features : int, default=10 Number of features. - n_classes : + n_classes : int, default=2 Number of classes (types) of the generated data. - n_samples_each_class : + n_samples_each_class : int, default=1000 Number of samples for each class to generate. - missing_rate : - The rate of randomly missing values to generate. + missing_rate : float, default=0.1 + The rate of randomly missing values to generate, should be in [0,1). Returns ------- data: dict, A dictionary containing the generated data. """ + assert 0 <= missing_rate < 1, "missing_rate must be in [0,1)" + # generate samples - X, y = gene_random_walk_for_classification( + X, y = gene_complete_random_walk_for_classification( n_classes=n_classes, n_samples_each_class=n_samples_each_class, n_steps=n_steps, @@ -262,12 +270,14 @@ def gene_incomplete_random_walk_dataset( # split into train/val/test sets train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.2) train_X, val_X, train_y, val_y = train_test_split(train_X, train_y, test_size=0.2) - # create random missing values - _, train_X, missing_mask, _ = mcar(train_X, missing_rate) - train_X = masked_fill(train_X, 1 - missing_mask, torch.nan) - _, val_X, missing_mask, _ = mcar(val_X, missing_rate) - val_X = masked_fill(val_X, 1 - missing_mask, torch.nan) - # test set is left to mask after normalization + + if missing_rate > 0: + # create random missing values + _, train_X, missing_mask, _ = mcar(train_X, missing_rate) + train_X = masked_fill(train_X, 1 - missing_mask, torch.nan) + _, val_X, missing_mask, _ = mcar(val_X, missing_rate) + val_X = masked_fill(val_X, 1 - missing_mask, torch.nan) + # test set is left to mask after normalization train_X = train_X.reshape(-1, n_features) val_X = val_X.reshape(-1, n_features) @@ -281,19 +291,6 @@ def gene_incomplete_random_walk_dataset( train_X = train_X.reshape(-1, n_steps, n_features) val_X = val_X.reshape(-1, n_steps, n_features) test_X = test_X.reshape(-1, n_steps, n_features) - - # mask values in the validation set as ground truth - val_X_intact, val_X, val_X_missing_mask, val_X_indicating_mask = mcar( - val_X, missing_rate - ) - val_X = masked_fill(val_X, 1 - val_X_missing_mask, torch.nan) - - # mask values in the test set as ground truth - test_X_intact, test_X, test_X_missing_mask, test_X_indicating_mask = mcar( - test_X, 0.3 - ) - test_X = masked_fill(test_X, 1 - test_X_missing_mask, torch.nan) - data = { "n_classes": n_classes, "n_steps": n_steps, @@ -302,13 +299,30 @@ def gene_incomplete_random_walk_dataset( "train_y": train_y, "val_X": val_X, "val_y": val_y, - "val_X_intact": val_X_intact, - "val_X_indicating_mask": val_X_indicating_mask, "test_X": test_X, "test_y": test_y, - "test_X_intact": test_X_intact, - "test_X_indicating_mask": test_X_indicating_mask, + "scaler": scaler, } + + if missing_rate > 0: + # mask values in the validation set as ground truth + val_X_intact, val_X, val_X_missing_mask, val_X_indicating_mask = mcar( + val_X, missing_rate + ) + val_X = masked_fill(val_X, 1 - val_X_missing_mask, torch.nan) + + # mask values in the test set as ground truth + test_X_intact, test_X, test_X_missing_mask, test_X_indicating_mask = mcar( + test_X, 0.3 + ) + test_X = masked_fill(test_X, 1 - test_X_missing_mask, torch.nan) + + data["val_X"] = val_X + data["val_X_intact"] = val_X_intact + data["val_X_indicating_mask"] = val_X_indicating_mask + data["test_X"] = test_X + data["test_X_intact"] = test_X_intact + data["test_X_indicating_mask"] = test_X_indicating_mask return data @@ -317,7 +331,7 @@ def gene_physionet2012(artificially_missing_rate: float = 0.1): Parameters ---------- - artificially_missing_rate : + artificially_missing_rate : float, default=0.1 The rate of artificially missing values to generate for model evaluation. This ratio is calculated based on the number of observed values, i.e. if artificially_missing_rate = 0.1, then 10% of the observed values will be randomly masked as missing data and hold out for model evaluation. diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index 150aeacf..7189dabb 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -88,6 +88,14 @@ def fit( """ raise NotImplementedError + @abstractmethod + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + raise NotImplementedError + @abstractmethod def forecast( self, @@ -109,6 +117,8 @@ def forecast( array-like, shape [n_samples, prediction_horizon, n_features], Forecasting results. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError @@ -380,4 +390,6 @@ def forecast( array-like, shape [n_samples, prediction_horizon, n_features], Forecasting results. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError diff --git a/pypots/forecasting/bttf/model.py b/pypots/forecasting/bttf/model.py index 500412a9..3f5aa9e6 100644 --- a/pypots/forecasting/bttf/model.py +++ b/pypots/forecasting/bttf/model.py @@ -34,6 +34,7 @@ ar4cast, ) from ..base import BaseForecaster +from ...utils.logging import logger def _BTTF( @@ -311,6 +312,13 @@ class BTTF(BaseForecaster): 2). ``n_steps - pred_step`` must be larger than ``max(time_lags)``; + References + ---------- + .. [1] `Chen, Xinyu, and Lijun Sun. + "Bayesian temporal factorization for multidimensional time series prediction." + IEEE Transactions on Pattern Analysis and Machine Intelligence 44, no. 9 (2021): 4659-4673. + `_ + """ def __init__( @@ -351,16 +359,16 @@ def fit( """ warnings.warn("Please run func forecast(X) directly.") - def forecast( + def predict( self, - X: Union[dict, str], + test_set: Union[dict, str], file_type: str = "h5py", - ) -> np.ndarray: + ) -> dict: assert not isinstance( - X, str + test_set, str ), "BTTF so far does not accept file input. It needs a specified Dataset class." - X = X["X"] + X = test_set["X"] X = X.transpose((0, 2, 1)) pred = BTTF_forecast( @@ -373,5 +381,20 @@ def forecast( self.burn_iter, self.gibbs_iter, ) - pred = pred.transpose((0, 2, 1)) - return pred + forecasting = pred.transpose((0, 2, 1)) + result_dict = { + "forecasting": forecasting, + } + return result_dict + + def forecast( + self, + X: Union[dict, str], + file_type: str = "h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + result_dict = self.predict(X, file_type=file_type) + forecasting = result_dict["forecasting"] + return forecasting diff --git a/pypots/forecasting/template/model.py b/pypots/forecasting/template/model.py index f761d795..3b665117 100644 --- a/pypots/forecasting/template/model.py +++ b/pypots/forecasting/template/model.py @@ -99,5 +99,9 @@ def fit( ) -> None: raise NotImplementedError - def forecast(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: raise NotImplementedError diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index 24bb4a70..8ee46619 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -95,6 +95,14 @@ def fit( """ raise NotImplementedError + @abstractmethod + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + raise NotImplementedError + @abstractmethod def impute( self, @@ -117,6 +125,8 @@ def impute( array-like, shape [n_samples, sequence length (time steps), n_features], Imputed data. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError @@ -409,4 +419,6 @@ def impute( array-like, shape [n_samples, sequence length (time steps), n_features], Imputed data. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError diff --git a/pypots/imputation/brits/model.py b/pypots/imputation/brits/model.py index 0ce03f97..54a47331 100644 --- a/pypots/imputation/brits/model.py +++ b/pypots/imputation/brits/model.py @@ -27,6 +27,7 @@ from ..base import BaseNNImputer from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger from ...utils.metrics import cal_mae @@ -394,13 +395,12 @@ class BRITS(BaseNNImputer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying BRITS model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Cao, Wei, Dong Wang, Jian Li, Hao Zhou, Lei Li, and Yitan Li. + "Brits: Bidirectional recurrent imputation for time series." + Advances in neural information processing systems 31 (2018). + `_ """ @@ -528,13 +528,13 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def impute( + def predict( self, - X: Union[dict, str], - file_type="h5py", - ) -> np.ndarray: + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForBRITS(X, return_labels=False, file_type=file_type) + test_set = DatasetForBRITS(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -550,5 +550,19 @@ def impute( imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) - imputation_collector = torch.cat(imputation_collector) - return imputation_collector.cpu().detach().numpy() + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index 350ff14e..56602c72 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -30,6 +30,7 @@ from ..base import BaseNNImputer from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger class _GPVAE(nn.Module): @@ -173,7 +174,6 @@ def forward(self, inputs, training=True): @staticmethod def kl_divergence(a, b): - # TODO: different from the author's implementation return torch.distributions.kl.kl_divergence(a, b) def _init_prior(self): @@ -222,36 +222,36 @@ def _init_prior(self): class GPVAE(BaseNNImputer): - """The PyTorch implementation of the GPVAE model :cite:``. + """The PyTorch implementation of the GPVAE model :cite:`fortuin2020GPVAEDeep`. Parameters ---------- - beta: + beta: float The weight of KL divergence in EBLO. - kernel: + kernel: str The type of kernel function chosen in the Gaussain Process Proir. ["cauchy", "diffusion", "rbf", "matern"] - batch_size : + batch_size : int The batch size for training and evaluating the model. - epochs : + epochs : int The number of epochs for training the model. - patience : + patience : int The patience for the early-stopping mechanism. Given a positive integer, the training process will be stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. - optimizer : + optimizer : pypots.optim.base.Optimizer The optimizer for model training. If not given, will use a default Adam optimizer. - num_workers : + num_workers : int The number of subprocesses to use for data loading. `0` means data loading will be in the main process, i.e. there won't be subprocesses. - device : + device : :class:`torch.device` or list The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. @@ -259,24 +259,24 @@ class GPVAE(BaseNNImputer): model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. - saving_path : + saving_path : str The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during training into a tensorboard file). Will not save if not given. - model_saving_strategy : + model_saving_strategy : str The strategy to save model checkpoints. It has to be one of [None, "best", "better"]. No model will be saved when it is set as None. The "best" strategy will only automatically save the best model after the training finished. The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying GPVAE model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Fortuin, V., Baranchuk, D., Raetsch, G. & Mandt, S.. (2020). + "GP-VAE: Deep Probabilistic Time Series Imputation". + Proceedings of the Twenty Third International Conference on Artificial Intelligence and Statistics, + in Proceedings of Machine Learning Research 108:1651-1661 + `_ """ @@ -420,13 +420,13 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def impute( + def predict( self, - X: Union[dict, str], + test_set: Union[dict, str], file_type="h5py", - ) -> np.ndarray: + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForGPVAE(X, return_labels=False, file_type=file_type) + test_set = DatasetForGPVAE(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -442,5 +442,19 @@ def impute( imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) - imputation_collector = torch.cat(imputation_collector) - return imputation_collector.cpu().detach().numpy() + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/locf/model.py b/pypots/imputation/locf/model.py index b1a23a98..286cb0ec 100644 --- a/pypots/imputation/locf/model.py +++ b/pypots/imputation/locf/model.py @@ -13,6 +13,7 @@ import torch from ..base import BaseImputer +from ...utils.logging import logger class LOCF(BaseImputer): @@ -121,13 +122,13 @@ def _locf_torch(self, X: torch.Tensor) -> torch.Tensor: return X_imputed - def impute( + def predict( self, - X: Union[dict, str], + test_set: Union[dict, str], file_type: str = "h5py", - ) -> np.ndarray: - assert not isinstance(X, str) - X = X["X"] + ) -> dict: + assert not isinstance(test_set, str) + X = test_set["X"] assert len(X.shape) == 3, ( f"Input X should have 3 dimensions [n_samples, n_steps, n_features], " @@ -145,4 +146,18 @@ def impute( "X must be type of list/np.ndarray/torch.Tensor, " f"but got {type(X)}" ) - return imputed_data + result_dict = { + "imputation": imputed_data, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/mrnn/model.py b/pypots/imputation/mrnn/model.py index 5d50cc32..afdde3b4 100644 --- a/pypots/imputation/mrnn/model.py +++ b/pypots/imputation/mrnn/model.py @@ -21,6 +21,7 @@ from ..base import BaseNNImputer from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger from ...utils.metrics import cal_rmse @@ -152,13 +153,13 @@ class MRNN(BaseNNImputer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying BRITS model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `J. Yoon, W. R. Zame and M. van der Schaar, + "Estimating Missing Data in Temporal Data Streams Using Multi-Directional Recurrent Neural Networks," + in IEEE Transactions on Biomedical Engineering, + vol. 66, no. 5, pp. 1477-1490, May 2019, doi: 10.1109/TBME.2018.2874712. + `_ """ @@ -286,13 +287,13 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def impute( + def predict( self, - X: Union[dict, str], + test_set: Union[dict, str], file_type="h5py", - ) -> np.ndarray: + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForMRNN(X, return_labels=False, file_type=file_type) + test_set = DatasetForMRNN(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -308,5 +309,19 @@ def impute( imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) - imputation_collector = torch.cat(imputation_collector) - return imputation_collector.cpu().detach().numpy() + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index 85731df7..21e72cff 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -28,6 +28,7 @@ from ...modules.self_attention import EncoderLayer, PositionalEncoding from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger from ...utils.metrics import cal_mae @@ -280,13 +281,12 @@ class SAITS(BaseNNImputer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying SAITS model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Du, Wenjie, David Côté, and Yan Liu. + "Saits: Self-attention-based imputation for time series". + Expert Systems with Applications 219 (2023): 119619. + `_ """ @@ -440,15 +440,15 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def impute( + def predict( self, - X: Union[dict, str], + test_set: Union[dict, str], file_type: str = "h5py", diagonal_attention_mask: bool = True, - ) -> np.ndarray: + ) -> dict: # Step 1: wrap the input data with classes Dataset and DataLoader self.model.eval() # set the model as eval status to freeze it. - test_set = BaseDataset(X, return_labels=False, file_type=file_type) + test_set = BaseDataset(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -468,6 +468,19 @@ def impute( imputation_collector.append(imputed_data) # Step 3: output collection and return - imputation_collector = torch.cat(imputation_collector) - imputed_data = imputation_collector.cpu().detach().numpy() - return imputed_data + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/template/model.py b/pypots/imputation/template/model.py index bb9107d7..6f229baa 100644 --- a/pypots/imputation/template/model.py +++ b/pypots/imputation/template/model.py @@ -99,5 +99,9 @@ def fit( ) -> None: raise NotImplementedError - def impute(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: raise NotImplementedError diff --git a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py index dfc925ad..e7f9c02b 100644 --- a/pypots/imputation/transformer/model.py +++ b/pypots/imputation/transformer/model.py @@ -27,6 +27,7 @@ from ...modules.self_attention import EncoderLayer, PositionalEncoding from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger from ...utils.metrics import cal_mae @@ -201,13 +202,18 @@ class Transformer(BaseNNImputer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying Transformer model. + .. [1] `Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, + and Illia Polosukhin. + "Attention is all you need." + Advances in neural information processing systems 30 (2017). + `_ - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [2] `Du, Wenjie, David Côté, and Yan Liu. + "Saits: Self-attention-based imputation for time series". + Expert Systems with Applications 219 (2023): 119619. + `_ """ @@ -358,9 +364,13 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def impute(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = BaseDataset(X, return_labels=False, file_type=file_type) + test_set = BaseDataset(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -376,5 +386,19 @@ def impute(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) - imputation_collector = torch.cat(imputation_collector) - return imputation_collector.cpu().detach().numpy() + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index c171d810..f1a6ab27 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -101,33 +101,7 @@ def forward( class _USGAN(nn.Module): - """model USGAN: - USGAN consists of a generator, a discriminator, which are all built on bidirectional recurrent neural networks. - - Attributes - ---------- - n_steps : - sequence length (number of time steps) - - n_features : - number of features (input dimensions) - - rnn_hidden_size : - the hidden size of the RNN cell - - lambda_mse : - the weigth of the reconstruction loss - - hint_rate : - the hint rate for the discriminator - - dropout_rate : - the dropout rate for the last layer in Discriminator - - device : - specify running the model on which device, CPU/GPU - - """ + """USGAN model""" def __init__( self, @@ -192,58 +166,58 @@ def forward( class USGAN(BaseNNImputer): - """The PyTorch implementation of the CRLI model :cite:`ma2021CRLI`. + """The PyTorch implementation of the USGAN model. Refer to :cite:`miao2021SSGAN`. Parameters ---------- - n_steps : + n_steps : int The number of time steps in the time-series data sample. - n_features : + n_features : int The number of features in the time-series data sample. - rnn_hidden_size : - the hidden size of the RNN cell + rnn_hidden_size : int + The hidden size of the RNN cell - lambda_mse : - the weight of the reconstruction loss + lambda_mse : float + The weight of the reconstruction loss - hint_rate : - the hint rate for the discriminator + hint_rate : float + The hint rate for the discriminator - dropout_rate : - the dropout rate for the last layer in Discriminator + dropout_rate : float + The dropout rate for the last layer in Discriminator - G_steps : + G_steps : int The number of steps to train the generator in each iteration. - D_steps : + D_steps : int The number of steps to train the discriminator in each iteration. - batch_size : + batch_size : int The batch size for training and evaluating the model. - epochs : + epochs : int The number of epochs for training the model. - patience : + patience : int The patience for the early-stopping mechanism. Given a positive integer, the training process will be stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. - G_optimizer : + G_optimizer : :class:`pypots.optim.Optimizer` The optimizer for the generator training. If not given, will use a default Adam optimizer. - D_optimizer : + D_optimizer : :class:`pypots.optim.Optimizer` The optimizer for the discriminator training. If not given, will use a default Adam optimizer. - num_workers : + num_workers : int The number of subprocesses to use for data loading. `0` means data loading will be in the main process, i.e. there won't be subprocesses. - device : + device : Union[str, torch.device, list] The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. @@ -251,24 +225,23 @@ class USGAN(BaseNNImputer): model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. - saving_path : + saving_path : str The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during training into a tensorboard file). Will not save if not given. - model_saving_strategy : + model_saving_strategy : str The strategy to save model checkpoints. It has to be one of [None, "best", "better"]. No model will be saved when it is set as None. The "best" strategy will only automatically save the best model after the training finished. The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. - Attributes + References ---------- - model : :class:`torch.nn.Module` - The underlying CRLI model. - - optimizer : :class:`pypots.optim.Optimizer` - The optimizer for model training. + .. [1] `Miao, Xiaoye, Yangyang Wu, Jun Wang, Yunjun Gao, Xudong Mao, and Jianwei Yin. 2021. + "Generative Semi-Supervised Learning for Multivariate Time Series Imputation". + Proceedings of the AAAI Conference on Artificial Intelligence 35 (10):8983-91. + `_ """ @@ -513,13 +486,13 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def impute( + def predict( self, - X: Union[dict, str], + test_set: Union[dict, str], file_type="h5py", - ) -> np.ndarray: + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForUSGAN(X, return_labels=False, file_type=file_type) + test_set = DatasetForUSGAN(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -535,5 +508,19 @@ def impute( imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) - imputation_collector = torch.cat(imputation_collector) - return imputation_collector.cpu().detach().numpy() + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/optim/adadelta.py b/pypots/optim/adadelta.py index 59e98f2a..4ff0037d 100644 --- a/pypots/optim/adadelta.py +++ b/pypots/optim/adadelta.py @@ -15,23 +15,25 @@ class Adadelta(Optimizer): - """The optimizer wrapper for PyTorch Adadelta. - https://pytorch.org/docs/stable/generated/torch.optim.Adadelta.html#torch.optim.Adadelta + """The optimizer wrapper for PyTorch Adadelta :class:`torch.optim.Adadelta`. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. - rho : + rho : float Coefficient used for computing a running average of squared gradients. - eps : + eps : float Term added to the denominator to improve numerical stability. - weight_decay : + weight_decay : float Weight decay (L2 penalty). + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. + """ def __init__( diff --git a/pypots/optim/adagrad.py b/pypots/optim/adagrad.py index 8a10f06c..b25efbc1 100644 --- a/pypots/optim/adagrad.py +++ b/pypots/optim/adagrad.py @@ -15,26 +15,28 @@ class Adagrad(Optimizer): - """The optimizer wrapper for PyTorch Adagrad. - https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html#torch.optim.Adagrad + """The optimizer wrapper for PyTorch Adagrad :class:`torch.optim.Adagrad`. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. - lr_decay : + lr_decay : float Learning rate decay. - weight_decay : + weight_decay : float Weight decay (L2 penalty). - eps : + eps : float Term added to the denominator to improve numerical stability. - initial_accumulator_value : + initial_accumulator_value : float A floating point value. Starting value for the accumulators, must be positive. + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. + """ def __init__( diff --git a/pypots/optim/adam.py b/pypots/optim/adam.py index c5e0e1af..9817b50e 100644 --- a/pypots/optim/adam.py +++ b/pypots/optim/adam.py @@ -15,25 +15,27 @@ class Adam(Optimizer): - """The optimizer wrapper for PyTorch Adam. - https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam + """The optimizer wrapper for PyTorch Adam :class:`torch.optim.Adam`. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. - betas : + betas : Tuple[float, float] Coefficients used for computing running averages of gradient and its square. - eps : + eps : float Term added to the denominator to improve numerical stability. - weight_decay : + weight_decay : float Weight decay (L2 penalty). - amsgrad : + amsgrad : bool Whether to use the AMSGrad variant of this algorithm from the paper :cite:`reddi2018OnTheConvergence`. + + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. """ def __init__( diff --git a/pypots/optim/adamw.py b/pypots/optim/adamw.py index 6a5191e4..26887b2c 100644 --- a/pypots/optim/adamw.py +++ b/pypots/optim/adamw.py @@ -15,25 +15,28 @@ class AdamW(Optimizer): - """The optimizer wrapper for PyTorch AdamW. - https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html#torch.optim.AdamW + """The optimizer wrapper for PyTorch AdamW :class:`torch.optim.AdamW`. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. - betas : + betas : Tuple[float, float] Coefficients used for computing running averages of gradient and its square. - eps : + eps : float Term added to the denominator to improve numerical stability. - weight_decay : + weight_decay : float Weight decay (L2 penalty). - amsgrad : + amsgrad : bool Whether to use the AMSGrad variant of this algorithm from the paper :cite:`reddi2018OnTheConvergence`. + + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. + """ def __init__( diff --git a/pypots/optim/base.py b/pypots/optim/base.py index db09fb3a..6a57ab7e 100644 --- a/pypots/optim/base.py +++ b/pypots/optim/base.py @@ -23,13 +23,16 @@ class Optimizer(ABC): - """The base wrapper for PyTorch optimizers, also is the base class for all optimizers in pypots.optim. + """The base wrapper for PyTorch optimizers, also is the base class for all optimizers in PyPOTS. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. + Attributes ---------- torch_optimizer : @@ -95,7 +98,7 @@ def step(self, closure: Optional[Callable] = None) -> None: ---------- closure : A closure that reevaluates the model and returns the loss. Optional for most optimizers. - Refer to the torch.optim.Optimizer.step() docstring for more details. + Refer to the :class:`torch.optim.Optimizer.step()` docstring for more details. """ self.torch_optimizer.step(closure) diff --git a/pypots/optim/lr_scheduler/__init__.py b/pypots/optim/lr_scheduler/__init__.py index ddb14350..89015847 100644 --- a/pypots/optim/lr_scheduler/__init__.py +++ b/pypots/optim/lr_scheduler/__init__.py @@ -1,9 +1,9 @@ """ Learning rate schedulers available in PyPOTS. Their functionalities are the same with those in PyTorch, the only difference that is also why we implement them is that you don't have to pass according optimizers -into them immediately while initializing them. Instead, you can pass them into pypots.optim.Optimizer -after initialization and call their `init_scheduler()` method in pypots.optim.Optimizer.init_optimizer() to initialize -schedulers together with optimizers. +into them immediately while initializing them. Instead, you can pass them into :class:`pypots.optim.base.Optimizer` +after initialization and call their `init_scheduler()` method in :class:`pypots.optim.base.Optimizer.init_optimizer()` +to initialize schedulers together with optimizers. """ # Created by Wenjie Du diff --git a/pypots/optim/lr_scheduler/base.py b/pypots/optim/lr_scheduler/base.py index 0aeffd8b..9c787ae7 100644 --- a/pypots/optim/lr_scheduler/base.py +++ b/pypots/optim/lr_scheduler/base.py @@ -37,12 +37,12 @@ def __init__(self, last_epoch=-1, verbose=False): self._step_count = 0 def init_scheduler(self, optimizer): - """Initialize the scheduler. This method should be called in pypots.optim.Optimizer.init_optimizer() - to initialize the scheduler together with the optimizer. + """Initialize the scheduler. This method should be called in + :class:`pypots.optim.base.Optimizer.init_optimizer()` to initialize the scheduler together with the optimizer. Parameters ---------- - optimizer: torch.optim.Optimizer, + optimizer: torch.optim.Optimizer The optimizer to be scheduled. """ @@ -113,8 +113,9 @@ def print_lr(is_verbose, group, lr): logger.info(f"Adjusting learning rate of group {group} to {lr:.4e}.") def step(self): - """Step could be called after every batch update. This should be called in ``pypots.optim.Optimizer.step()`` - after ``pypots.optim.Optimizer.torch_optimizer.step()``. + """Step could be called after every batch update. + This should be called in :class:`pypots.optim.base.Optimizer.step()` after + :class:`pypots.optim.base.Optimizer.torch_optimizer.step()`. """ # Raise a warning if old pattern is detected # https://github.com/pytorch/pytorch/issues/20124 diff --git a/pypots/optim/rmsprop.py b/pypots/optim/rmsprop.py index f00da68d..9451c0a4 100644 --- a/pypots/optim/rmsprop.py +++ b/pypots/optim/rmsprop.py @@ -15,29 +15,31 @@ class RMSprop(Optimizer): - """The optimizer wrapper for PyTorch RMSprop. - https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html#torch.optim.RMSprop + """The optimizer wrapper for PyTorch RMSprop :class:`torch.optim.RMSprop`. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. - momentum : + momentum : float Momentum factor. - alpha : + alpha : float Smoothing constant. - eps : + eps : float Term added to the denominator to improve numerical stability. - centered : + centered : bool If True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance - weight_decay : + weight_decay : float Weight decay (L2 penalty). + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. + """ def __init__( diff --git a/pypots/optim/sgd.py b/pypots/optim/sgd.py index 34cd07f0..b31baf5f 100644 --- a/pypots/optim/sgd.py +++ b/pypots/optim/sgd.py @@ -1,5 +1,5 @@ """ -The optimizer wrapper for PyTorch SGD. +The optimizer wrapper for PyTorch SGD :class:`torch.optim.SGD`. """ @@ -15,26 +15,28 @@ class SGD(Optimizer): - """The optimizer wrapper for PyTorch SGD. - https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD + """The optimizer wrapper for PyTorch SGD :class:`torch.optim.SGD`. Parameters ---------- - lr : + lr : float The learning rate of the optimizer. - momentum : + momentum : float Momentum factor. - weight_decay : + weight_decay : float Weight decay (L2 penalty). - dampening : + dampening : float Dampening for momentum. - nesterov : + nesterov : bool Whether to enable Nesterov momentum. + lr_scheduler : pypots.optim.lr_scheduler.base.LRScheduler + The learning rate scheduler of the optimizer. + """ def __init__( diff --git a/pypots/utils/metrics.py b/pypots/utils/metrics.py index ac239648..b6dbd5ae 100644 --- a/pypots/utils/metrics.py +++ b/pypots/utils/metrics.py @@ -497,6 +497,15 @@ def cal_rand_index( RI : Rand index. + References + ---------- + .. L. Hubert and P. Arabie, Comparing Partitions, Journal of + Classification 1985 + https://link.springer.com/article/10.1007%2FBF01908075 + + .. https://en.wikipedia.org/wiki/Simple_matching_coefficient + + .. https://en.wikipedia.org/wiki/Rand_index """ # # detailed implementation # n = len(targets) @@ -523,7 +532,7 @@ def cal_adjusted_rand_index( class_predictions: np.ndarray, targets: np.ndarray, ) -> float: - """Calculate adjusted Rand Index. Refer to :cite:`hubert1985AdjustedRI`. + """Calculate adjusted Rand Index. Parameters ---------- @@ -538,6 +547,18 @@ def cal_adjusted_rand_index( aRI : Adjusted Rand index. + References + ---------- + .. [1] `L. Hubert and P. Arabie, Comparing Partitions, + Journal of Classification 1985 + `_ + + .. [2] `D. Steinley, Properties of the Hubert-Arabie + adjusted Rand index, Psychological Methods 2004 + `_ + + .. [3] https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index + """ aRI = metrics.adjusted_rand_score(targets, class_predictions) return aRI @@ -644,7 +665,17 @@ def cal_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float: Returns ------- silhouette_score : float - Mean Silhouette Coefficient for all samples. + Mean Silhouette Coefficient for all samples. In short, the higher, the better. + + References + ---------- + .. [1] `Peter J. Rousseeuw (1987). "Silhouettes: a Graphical Aid to the + Interpretation and Validation of Cluster Analysis". Computational + and Applied Mathematics 20: 53-65. + `_ + + .. [2] `Wikipedia entry on the Silhouette Coefficient + `_ """ silhouette_score = metrics.silhouette_score(X, predicted_labels) @@ -659,10 +690,17 @@ def cal_chs(X: np.ndarray, predicted_labels: np.ndarray) -> float: predicted_labels : array-like of shape (n_samples) Predicted labels for each sample. + Returns ------- calinski_harabasz_score : float - The resulting Calinski-Harabasz score. + The resulting Calinski-Harabasz score. In short, the higher, the better. + + References + ---------- + .. [1] `T. Calinski and J. Harabasz, 1974. "A dendrite method for cluster + analysis". Communications in Statistics + `_ """ calinski_harabasz_score = metrics.calinski_harabasz_score(X, predicted_labels) @@ -683,7 +721,15 @@ def cal_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: Returns ------- davies_bouldin_score : float - The resulting Davies-Bouldin score. + The resulting Davies-Bouldin score. In short, the lower, the better. + + References + ---------- + .. [1] `Davies, David L.; Bouldin, Donald W. (1979). + "A Cluster Separation Measure" + IEEE Transactions on Pattern Analysis and Machine Intelligence. + PAMI-1 (2): 224-227 + `_ """ davies_bouldin_score = metrics.davies_bouldin_score(X, predicted_labels) diff --git a/tests/clustering/crli.py b/tests/clustering/crli.py index 191f58c8..2385d1e5 100644 --- a/tests/clustering/crli.py +++ b/tests/clustering/crli.py @@ -43,12 +43,27 @@ class TestCRLI(unittest.TestCase): D_optimizer = Adam(lr=0.001, weight_decay=1e-5) # initialize a CRLI model - crli = CRLI( + crli_gru = CRLI( n_steps=DATA["n_steps"], n_features=DATA["n_features"], n_clusters=DATA["n_classes"], n_generator_layers=2, rnn_hidden_size=128, + rnn_cell_type="GRU", + epochs=EPOCHS, + saving_path=saving_path, + G_optimizer=G_optimizer, + D_optimizer=D_optimizer, + device=DEVICE, + ) + + crli_lstm = CRLI( + n_steps=DATA["n_steps"], + n_features=DATA["n_features"], + n_clusters=DATA["n_classes"], + n_generator_layers=2, + rnn_hidden_size=128, + rnn_cell_type="LSTM", epochs=EPOCHS, saving_path=saving_path, G_optimizer=G_optimizer, @@ -58,34 +73,80 @@ class TestCRLI(unittest.TestCase): @pytest.mark.xdist_group(name="clustering-crli") def test_0_fit(self): - self.crli.fit(TRAIN_SET) + logger.info("Training CRLI-GRU...") + self.crli_gru.fit(TRAIN_SET) + logger.info("Training CRLI-LSTM...") + self.crli_lstm.fit(TRAIN_SET) @pytest.mark.xdist_group(name="clustering-crli") def test_1_parameters(self): - assert hasattr(self.crli, "model") and self.crli.model is not None + # GRU cell + assert hasattr(self.crli_gru, "model") and self.crli_gru.model is not None - assert hasattr(self.crli, "G_optimizer") and self.crli.G_optimizer is not None - assert hasattr(self.crli, "D_optimizer") and self.crli.D_optimizer is not None + assert ( + hasattr(self.crli_gru, "G_optimizer") + and self.crli_gru.G_optimizer is not None + ) + assert ( + hasattr(self.crli_gru, "D_optimizer") + and self.crli_gru.D_optimizer is not None + ) - assert hasattr(self.crli, "best_loss") - self.assertNotEqual(self.crli.best_loss, float("inf")) + assert hasattr(self.crli_gru, "best_loss") + self.assertNotEqual(self.crli_gru.best_loss, float("inf")) assert ( - hasattr(self.crli, "best_model_dict") - and self.crli.best_model_dict is not None + hasattr(self.crli_gru, "best_model_dict") + and self.crli_gru.best_model_dict is not None + ) + + # LSTM cell + assert hasattr(self.crli_lstm, "model") and self.crli_lstm.model is not None + + assert ( + hasattr(self.crli_lstm, "G_optimizer") + and self.crli_lstm.G_optimizer is not None + ) + assert ( + hasattr(self.crli_lstm, "D_optimizer") + and self.crli_lstm.D_optimizer is not None + ) + + assert hasattr(self.crli_lstm, "best_loss") + self.assertNotEqual(self.crli_lstm.best_loss, float("inf")) + + assert ( + hasattr(self.crli_lstm, "best_model_dict") + and self.crli_lstm.best_model_dict is not None ) @pytest.mark.xdist_group(name="clustering-crli") def test_2_cluster(self): - clustering, latent_collector = self.crli.cluster(TEST_SET, return_latent=True) + # GRU cell + clustering, latent_collector = self.crli_gru.cluster( + TEST_SET, return_latent=True + ) + external_metrics = cal_external_cluster_validation_metrics( + clustering, DATA["test_y"] + ) + internal_metrics = cal_internal_cluster_validation_metrics( + latent_collector["clustering_latent"], DATA["test_y"] + ) + logger.info(f"CRLI-GRU: {external_metrics}") + logger.info(f"CRLI-GRU:{internal_metrics}") + + # LSTM cell + clustering, latent_collector = self.crli_lstm.cluster( + TEST_SET, return_latent=True + ) external_metrics = cal_external_cluster_validation_metrics( clustering, DATA["test_y"] ) internal_metrics = cal_internal_cluster_validation_metrics( latent_collector["clustering_latent"], DATA["test_y"] ) - logger.info(f"{external_metrics}") - logger.info(f"{internal_metrics}") + logger.info(f"CRLI-LSTM: {external_metrics}") + logger.info(f"CRLI-LSTM: {internal_metrics}") @pytest.mark.xdist_group(name="clustering-crli") def test_3_saving_path(self): @@ -95,16 +156,16 @@ def test_3_saving_path(self): ), f"file {self.saving_path} does not exist" # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.crli) + check_tb_and_model_checkpoints_existence(self.crli_gru) # save the trained model into file, and check if the path exists - self.crli.save_model( + self.crli_gru.save_model( saving_dir=self.saving_path, file_name=self.model_save_name ) # test loading the saved model, not necessary, but need to test saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.crli.load_model(saved_model_path) + self.crli_gru.load_model(saved_model_path) if __name__ == "__main__": diff --git a/tests/global_test_config.py b/tests/global_test_config.py index 5e152734..62ad73bb 100644 --- a/tests/global_test_config.py +++ b/tests/global_test_config.py @@ -9,12 +9,18 @@ import torch -from pypots.data.generating import gene_incomplete_random_walk_dataset +from pypots.data.generating import gene_random_walk from pypots.utils.logging import logger # Generate the unified data for testing and cache it first, DATA here is a singleton # Otherwise, file lock will cause bug if running test parallely with pytest-xdist. -DATA = gene_incomplete_random_walk_dataset() +DATA = gene_random_walk( + n_steps=24, + n_features=10, + n_classes=2, + n_samples_each_class=1000, + missing_rate=0.1, +) # The directory for saving the dataset into files for testing DATA_SAVING_DIR = "h5data_for_tests"