From 37a9e4562ab384dc1b908af83219d86720e54f81 Mon Sep 17 00:00:00 2001 From: Jun Wang Date: Mon, 21 Aug 2023 21:57:53 +0800 Subject: [PATCH 01/17] feat: add model gpvae, usgan --- pypots/base.py | 5 +- pypots/classification/base.py | 1 - pypots/classification/raindrop/modules.py | 1 - pypots/clustering/base.py | 1 - pypots/clustering/crli/model.py | 1 - pypots/clustering/vader/model.py | 2 - pypots/forecasting/base.py | 1 - pypots/imputation/__init__.py | 10 +- pypots/imputation/gpvae/__init__.py | 12 + pypots/imputation/gpvae/data.py | 144 ++++ pypots/imputation/gpvae/model.py | 399 +++++++++++ pypots/imputation/gpvae/modules.py | 234 ++++++ pypots/imputation/usgan/__init__.py | 12 + pypots/imputation/usgan/data.py | 168 +++++ pypots/imputation/usgan/model.py | 832 ++++++++++++++++++++++ pypots/imputation/usgan/modules.py | 140 ++++ 16 files changed, 1947 insertions(+), 16 deletions(-) create mode 100644 pypots/imputation/gpvae/__init__.py create mode 100644 pypots/imputation/gpvae/data.py create mode 100644 pypots/imputation/gpvae/model.py create mode 100644 pypots/imputation/gpvae/modules.py create mode 100644 pypots/imputation/usgan/__init__.py create mode 100644 pypots/imputation/usgan/data.py create mode 100644 pypots/imputation/usgan/model.py create mode 100644 pypots/imputation/usgan/modules.py diff --git a/pypots/base.py b/pypots/base.py index ad11eda7..472f338c 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -96,7 +96,9 @@ def _setup_device(self, device: Union[None, str, torch.device, list]): self.device = device elif isinstance(device, list): if len(device) == 0: - raise ValueError("The list of devices should have at least 1 device, but got 0.") + raise ValueError( + "The list of devices should have at least 1 device, but got 0." + ) elif len(device) == 1: return self._setup_device(device[0]) # parallely training on multiple CUDA devices @@ -176,7 +178,6 @@ def _send_data_to_given_device(self, data): if isinstance(self.device, torch.device): # single device data = map(lambda x: x.to(self.device), data) else: # parallely training on multiple devices - # randomly choose one device to balance the workload # device = np.random.choice(self.device) diff --git a/pypots/classification/base.py b/pypots/classification/base.py index a30fd698..a16dbc01 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -256,7 +256,6 @@ def _train_model( training_loader: DataLoader, val_loader: DataLoader = None, ) -> None: - # each training starts from the very beginning, so reset the loss and model dict here self.best_loss = float("inf") self.best_model_dict = None diff --git a/pypots/classification/raindrop/modules.py b/pypots/classification/raindrop/modules.py index 76a992ef..191ff9c7 100644 --- a/pypots/classification/raindrop/modules.py +++ b/pypots/classification/raindrop/modules.py @@ -174,7 +174,6 @@ def forward( edge_attr: OptTensor = None, return_attention_weights=None, ) -> Tuple[torch.Tensor, Any]: - r""" Args: return_attention_weights (bool, optional): If set to :obj:`True`, diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py index 324e6718..fd9b7f0d 100644 --- a/pypots/clustering/base.py +++ b/pypots/clustering/base.py @@ -244,7 +244,6 @@ def _train_model( training_loader: DataLoader, val_loader: DataLoader = None, ) -> None: - """ Parameters diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index b5e2e14a..8b7a63a1 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -226,7 +226,6 
@@ def __init__(
         saving_path: Optional[str] = None,
         model_saving_strategy: Optional[str] = "best",
     ):
-
         super().__init__(
             n_clusters,
             batch_size,
diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py
index f2912cce..5a44da85 100644
--- a/pypots/clustering/vader/model.py
+++ b/pypots/clustering/vader/model.py
@@ -184,7 +184,6 @@ def forward(
         ) = self.get_results(X, missing_mask)

         if not training and not pretrain:
-
             results = {
                 "mu_tilde": mu_tilde,
                 "mu": mu_c,
@@ -403,7 +402,6 @@ def _train_model(
         training_loader: DataLoader,
         val_loader: DataLoader = None,
     ) -> None:
-
         # each training starts from the very beginning, so reset the loss and model dict here
         self.best_loss = float("inf")
         self.best_model_dict = None
diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py
index 5188999b..079f5925 100644
--- a/pypots/forecasting/base.py
+++ b/pypots/forecasting/base.py
@@ -242,7 +242,6 @@ def _train_model(
         training_loader: DataLoader,
         val_loader: DataLoader = None,
     ) -> None:
-
         # each training starts from the very beginning, so reset the loss and model dict here
         self.best_loss = float("inf")
         self.best_model_dict = None
diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py
index 9de8d0bc..3d513430 100644
--- a/pypots/imputation/__init__.py
+++ b/pypots/imputation/__init__.py
@@ -10,11 +10,7 @@
 from .saits import SAITS
 from .transformer import Transformer
 from .mrnn import MRNN
+from .gpvae import GPVAE
+from .usgan import USGAN

-__all__ = [
-    "SAITS",
-    "Transformer",
-    "BRITS",
-    "MRNN",
-    "LOCF",
-]
+__all__ = ["SAITS", "Transformer", "BRITS", "MRNN", "LOCF", "GPVAE", "USGAN"]
diff --git a/pypots/imputation/gpvae/__init__.py b/pypots/imputation/gpvae/__init__.py
new file mode 100644
index 00000000..f5ffb05e
--- /dev/null
+++ b/pypots/imputation/gpvae/__init__.py
@@ -0,0 +1,12 @@
+"""
+The package of the partially-observed time-series imputation method GP-VAE.
+"""
+
+# Created by Jun Wang
+# License: GPL-v3
+
+from .model import GPVAE
+
+__all__ = [
+    "GPVAE",
+]
diff --git a/pypots/imputation/gpvae/data.py b/pypots/imputation/gpvae/data.py
new file mode 100644
index 00000000..de7d7747
--- /dev/null
+++ b/pypots/imputation/gpvae/data.py
@@ -0,0 +1,144 @@
+"""
+Dataset class for model GP-VAE.
+"""
+
+# Created by Jun Wang
+# License: GPL-v3
+
+from typing import Union, Iterable
+
+import torch
+
+from ...data.base import BaseDataset
+from ...data.utils import torch_parse_delta
+
+
+class DatasetForGPVAE(BaseDataset):
+    """Dataset class for GP-VAE.
+
+    Parameters
+    ----------
+    data : dict or str,
+        The dataset for model input, should be a dictionary including keys as 'X' and 'y',
+        or a path string locating a data file.
+        If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+        which is time-series data for input, can contain missing values, and y should be array-like of shape
+        [n_samples], which is classification labels of X.
+        If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
+        key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
+
+    return_labels : bool, default = True,
+        Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
+        during training of classification models, the Dataset class will return labels in __getitem__() for model input.
+        Otherwise, labels won't be included in the data returned by __getitem__().
This parameter exists because we + need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5 + files, they already have both X and y saved. But we don't read labels from the file for validating and testing + with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for + distinction. + + file_type : str, default = "h5py" + The type of the given file if train_set and val_set are path strings. + """ + + def __init__( + self, + data: Union[dict, str], + return_labels: bool = True, + file_type: str = "h5py", + ): + super().__init__(data, return_labels, file_type) + + if not isinstance(self.data, str): + # calculate all delta here. + missing_mask = (~torch.isnan(self.X)).type(torch.float32) + X = torch.nan_to_num(self.X) + delta = torch_parse_delta(missing_mask) + + self.processed_data = { + "X": X, + "missing_mask": missing_mask, + "delta": delta, + } + + def _fetch_data_from_array(self, idx: int) -> Iterable: + """Fetch data from self.X if it is given. + + Parameters + ---------- + idx : int, + The index of the sample to be return. + + Returns + ------- + sample : list, + A list contains + + index : int tensor, + The index of the sample. + + X : tensor, + The feature vector for model input. + + missing_mask : tensor, + The mask indicates all missing values in X. + + delta : tensor, + The delta matrix contains time gaps of missing values. + + label (optional) : tensor, + The target label of the time-series sample. + """ + sample = [ + torch.tensor(idx), + # for forward + self.processed_data["X"][idx].to(torch.float32), + self.processed_data["missing_mask"][idx].to(torch.float32), + self.processed_data["delta"][idx].to(torch.float32), + ] + + if self.y is not None and self.return_labels: + sample.append(self.y[idx].to(torch.long)) + + return sample + + def _fetch_data_from_file(self, idx: int) -> Iterable: + """Fetch data with the lazy-loading strategy, i.e. only loading data from the file while requesting for samples. + Here the opened file handle doesn't load the entire dataset into RAM but only load the currently accessed slice. + + Parameters + ---------- + idx : int, + The index of the sample to be return. + + Returns + ------- + sample : list, + The collated data sample, a list including all necessary sample info. + """ + + if self.file_handle is None: + self.file_handle = self._open_file_handle() + + X = torch.from_numpy(self.file_handle["X"][idx]) + missing_mask = (~torch.isnan(X)).to(torch.float32) + X = torch.nan_to_num(X) + + forward = { + "X": X, + "missing_mask": missing_mask, + "deltas": torch_parse_delta(missing_mask), + } + + sample = [ + torch.tensor(idx), + # for forward + forward["X"], + forward["missing_mask"], + forward["deltas"], + ] + + # if the dataset has labels and is for training, then fetch it from the file + if "y" in self.file_handle.keys() and self.return_labels: + sample.append(torch.tensor(self.file_handle["y"][idx], dtype=torch.long)) + + return sample diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py new file mode 100644 index 00000000..3bf5a866 --- /dev/null +++ b/pypots/imputation/gpvae/model.py @@ -0,0 +1,399 @@ +""" +The implementation of GP-VAE for the partially-observed time-series imputation task. + +Refer to the paper Fortuin V, Baranchuk D, Rätsch G, et al. Gp-vae: Deep probabilistic time series imputation[C]//International conference on artificial intelligence and statistics. PMLR, 2020: 1651-1661. 
+ +Notes +----- +Pytorch implementation of the code from https://github.com/ratschlab/GP-VAE. + +""" + +# Created by Jun Wang +# License: GPL-v3 + + +from typing import Tuple, Union, Optional + +import h5py +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from .modules import * +from .data import DatasetForGPVAE +from ..base import BaseNNImputer +from ...optim.adam import Adam +from ...optim.base import Optimizer +from ...utils.metrics import cal_mae + + +class _GPVAE(nn.Module): + def __init__( + self, + input_dim, + time_length, + latent_dim, + device, + encoder_sizes=(64, 64), + encoder=Encoder, + decoder_sizes=(64, 64), + decoder=Decoder, + beta=1, + M=1, + K=1, + kernel="cauchy", + sigma=1.0, + length_scale=7.0, + kernel_scales=1, + ): + """GPVAE model with Gaussian Process prior + :param kernel: Gaussial Process kernel ["cauchy", "diffusion", "rbf", "matern"] + :param sigma: scale parameter for a kernel function + :param length_scale: length scale parameter for a kernel function + :param kernel_scales: number of different length scales over latent space dimensions + """ + super(_GPVAE, self).__init__() + self.kernel = kernel + self.sigma = sigma + self.length_scale = length_scale + self.kernel_scales = kernel_scales + + # Precomputed KL components for efficiency + self.pz_scale_inv = None + self.pz_scale_log_abs_determinant = None + self.prior = None + + self.input_dim = input_dim + self.time_length = time_length + self.latent_dim = latent_dim + self.beta = beta + self.encoder = encoder(input_dim, latent_dim, encoder_sizes).to(device) + self.decoder = decoder(latent_dim, input_dim, decoder_sizes).to(device) + self.device = device + self.M = M + self.K = K + + def encode(self, x): + return self.encoder(x) + + def decode(self, z): + if not torch.is_tensor(z): + z = torch.tensor(z).float() + num_dim = len(z.shape) + assert num_dim > 2 + return self.decoder(torch.transpose(z, num_dim - 1, num_dim - 2)) + + def __call__(self, inputs): + return self.decoder(self.encode(inputs).sample()).sample() + + def forward(self, inputs, training=True): + x = inputs["forward"]["X"] + m_mask = inputs["forward"]["missing_mask"] + delta = inputs["forward"]["deltas"] + x = x.repeat(self.M * self.K, 1, 1) + if m_mask is not None: + m_mask = m_mask.repeat(self.M * self.K, 1, 1) + m_mask = m_mask.type(torch.bool) + + pz = self._get_prior() + qz_x = self.encode(x) + z = qz_x.rsample() + px_z = self.decode(z) + + nll = -px_z.log_prob(x) + nll = torch.where(torch.isfinite(nll), nll, torch.zeros_like(nll)) + if m_mask is not None: + nll = torch.where(m_mask, nll, torch.zeros_like(nll)) + nll = nll.sum(dim=(1, 2)) + + if self.K > 1: + kl = qz_x.log_prob(z) - pz.log_prob(z) + kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl)) + kl = kl.sum(1) + + weights = -nll - kl + weights = torch.reshape(weights, [self.M, self.K, -1]) + + elbo = torch.logsumexp(weights, dim=1) + elbo = elbo.mean() + else: + kl = self.kl_divergence(qz_x, pz) + kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl)) + kl = kl.sum(1) + + elbo = -nll - self.beta * kl + elbo = elbo.mean() + + imputed_data = self.decode(self.encode(x).mean).mean * ~m_mask + x * m_mask + + if not training: + # if not in training mode, return the classification result only + return { + "imputed_data": imputed_data, + } + + results = { + "loss": -elbo.mean(), + "imputed_data": imputed_data, + } + return results + + def kl_divergence(self, a, b): + return torch.distributions.kl.kl_divergence(a, b) + + 
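# A minimal, self-contained sketch (with made-up shapes and values) of the masked-likelihood
# idea used by `_GPVAE.forward` above: the reconstruction NLL is accumulated only over
# observed entries, and the imputation keeps observed values while filling missing positions
# from the decoder's mean.
import torch

x = torch.randn(2, 5, 3)                        # [batch, time, features], NaNs already replaced
m_mask = torch.rand(2, 5, 3) > 0.3              # True where a value is observed
# stand-in for the decoder output distribution p(x|z); the real model predicts loc from z
px_z = torch.distributions.Normal(loc=torch.zeros_like(x), scale=torch.ones_like(x))

nll = -px_z.log_prob(x)                                                # element-wise NLL
nll = torch.where(m_mask, nll, torch.zeros_like(nll)).sum(dim=(1, 2))  # observed entries only

recon_mean = px_z.mean                          # stand-in for decode(encode(x).mean).mean
imputed = recon_mean * ~m_mask + x * m_mask     # keep observed values, impute the rest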
    def _get_prior(self):
+        if self.prior is None:
+            # Compute kernel matrices for each latent dimension
+            kernel_matrices = []
+            for i in range(self.kernel_scales):
+                if self.kernel == "rbf":
+                    kernel_matrices.append(
+                        rbf_kernel(self.time_length, self.length_scale / 2**i)
+                    )
+                elif self.kernel == "diffusion":
+                    kernel_matrices.append(
+                        diffusion_kernel(self.time_length, self.length_scale / 2**i)
+                    )
+                elif self.kernel == "matern":
+                    kernel_matrices.append(
+                        matern_kernel(self.time_length, self.length_scale / 2**i)
+                    )
+                elif self.kernel == "cauchy":
+                    kernel_matrices.append(
+                        cauchy_kernel(
+                            self.time_length, self.sigma, self.length_scale / 2**i
+                        )
+                    )
+
+            # Combine kernel matrices for each latent dimension
+            tiled_matrices = []
+            total = 0
+            for i in range(self.kernel_scales):
+                if i == self.kernel_scales - 1:
+                    multiplier = self.latent_dim - total
+                else:
+                    multiplier = int(np.ceil(self.latent_dim / self.kernel_scales))
+                    total += multiplier
+                tiled_matrices.append(
+                    torch.unsqueeze(kernel_matrices[i], 0).repeat(multiplier, 1, 1)
+                )
+            kernel_matrix_tiled = torch.cat(tiled_matrices)
+            assert len(kernel_matrix_tiled) == self.latent_dim
+            self.prior = torch.distributions.MultivariateNormal(
+                loc=torch.zeros(self.latent_dim, self.time_length, device=self.device),
+                covariance_matrix=kernel_matrix_tiled.to(self.device),
+            )
+
+        return self.prior
+
+
+class GPVAE(BaseNNImputer):
+    """The PyTorch implementation of the GP-VAE model :cite:``.
+
+    Parameters
+    ----------
+    n_steps :
+        The number of time steps in the time-series data sample.
+
+    n_features :
+        The number of features in the time-series data sample.
+
+    latent_size :
+        The size of the latent variable.
+
+    beta:
+        The weight of the KL divergence in the ELBO.
+
+    kernel:
+        The type of kernel function chosen in the Gaussian Process prior. ["cauchy", "diffusion", "rbf", "matern"]
+
+    batch_size :
+        The batch size for training and evaluating the model.
+
+    epochs :
+        The number of epochs for training the model.
+
+    patience :
+        The patience for the early-stopping mechanism. Given a positive integer, the training process will be
+        stopped when the model does not perform better after that number of epochs.
+        Leaving it default as None will disable the early-stopping.
+
+    optimizer :
+        The optimizer for model training.
+        If not given, will use a default Adam optimizer.
+
+    num_workers :
+        The number of subprocesses to use for data loading.
+        `0` means data loading will be in the main process, i.e. there won't be subprocesses.
+
+    device :
+        The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
+        If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
+        then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
+        If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the
+        model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices).
+        Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
+
+    saving_path :
+        The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
+        training into a tensorboard file). Will not save if not given.
+
+    model_saving_strategy :
+        The strategy to save model checkpoints. It has to be one of [None, "best", "better"].
+        No model will be saved when it is set as None.
+        The "best" strategy will only automatically save the best model after the training finished.
+ The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + + Attributes + ---------- + model : :class:`torch.nn.Module` + The underlying GPVAE model. + + optimizer : :class:`pypots.optim.Optimizer` + The optimizer for model training. + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + latent_size: int, + kernel: str = "cauchy", + beta: float = 0.2, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + + self.n_steps = n_steps + self.n_features = n_features + self.latent_size = latent_size + self.kernel = kernel + self.beta = beta + + # set up the model + self.model = _GPVAE( + input_dim=self.n_features, + time_length=self.n_steps, + latent_dim=self.latent_size, + kernel=self.kernel, + device=self.device, + beta=self.beta, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + # fetch data + ( + indices, + X, + missing_mask, + deltas, + ) = self._send_data_to_given_device(data) + + # assemble input data + inputs = { + "indices": indices, + "forward": { + "X": X, + "missing_mask": missing_mask, + "deltas": deltas, + }, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + return self._assemble_input_for_validating(data) + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForGPVAE( + train_set, return_labels=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if isinstance(val_set, str): + with h5py.File(val_set, "r") as hf: + # Here we read the whole validation set from the file to mask a portion for validation. + # In PyPOTS, using a file usually because the data is too big. However, the validation set is + # generally shouldn't be too large. For example, we have 1 billion samples for model training. + # We won't take 20% of them as the validation set because we want as much as possible data for the + # training stage to enhance the model's generalization ability. Therefore, 100,000 representative + # samples will be enough to validate the model. + val_set = { + "X": hf["X"][:], + "X_intact": hf["X_intact"][:], + "indicating_mask": hf["indicating_mask"][:], + } + val_set = DatasetForGPVAE(val_set, return_labels=False, file_type=file_type) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. 
+ + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(training_finished=True) + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + self.model.eval() # set the model as eval status to freeze it. + test_set = DatasetForGPVAE(X, return_labels=False, file_type=file_type) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputed_data = results["imputed_data"] + imputation_collector.append(imputed_data) + + imputation_collector = torch.cat(imputation_collector) + return imputation_collector.cpu().detach().numpy() diff --git a/pypots/imputation/gpvae/modules.py b/pypots/imputation/gpvae/modules.py new file mode 100644 index 00000000..41e74c4b --- /dev/null +++ b/pypots/imputation/gpvae/modules.py @@ -0,0 +1,234 @@ +""" +The implementation of GP-VAE for the partially-observed time-series imputation task. + +Refer to the paper Fortuin V, Baranchuk D, Rätsch G, et al. Gp-vae: Deep probabilistic time series imputation[C]//International conference on artificial intelligence and statistics. PMLR, 2020: 1651-1661. + +Notes +----- +Pytorch implementation of the code from https://github.com/ratschlab/GP-VAE. + +""" + +# Created by Jun Wang +# License: GPL-v3 + +from typing import Tuple, Union, Optional + +import h5py +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.distributions.multivariate_normal import MultivariateNormal + + +def rbf_kernel(T, length_scale): + xs = torch.arange(T).float() + xs_in = torch.unsqueeze(xs, 0) + xs_out = torch.unsqueeze(xs, 1) + distance_matrix = (xs_in - xs_out) ** 2 + distance_matrix_scaled = distance_matrix / length_scale**2 + kernel_matrix = torch.exp(-distance_matrix_scaled) + return kernel_matrix + + +def diffusion_kernel(T, length_scale): + assert length_scale < 0.5, ( + "length_scale has to be smaller than 0.5 for the " + "kernel matrix to be diagonally dominant" + ) + sigmas = torch.ones(T, T) * length_scale + sigmas_tridiag = torch.diagonal(sigmas, offset=0, dim1=-2, dim2=-1) + sigmas_tridiag += torch.diagonal(sigmas, offset=1, dim1=-2, dim2=-1) + sigmas_tridiag += torch.diagonal(sigmas, offset=-1, dim1=-2, dim2=-1) + kernel_matrix = sigmas_tridiag + torch.eye(T) * (1.0 - length_scale) + return kernel_matrix + + +def matern_kernel(T, length_scale): + xs = torch.arange(T).float() + xs_in = torch.unsqueeze(xs, 0) + xs_out = torch.unsqueeze(xs, 1) + distance_matrix = torch.abs(xs_in - xs_out) + distance_matrix_scaled = distance_matrix / torch.sqrt(length_scale).type( + torch.float32 + ) + kernel_matrix = torch.exp(-distance_matrix_scaled) + return kernel_matrix + + +def cauchy_kernel(T, sigma, length_scale): + xs = torch.arange(T).float() + xs_in = torch.unsqueeze(xs, 0) + xs_out = torch.unsqueeze(xs, 1) + distance_matrix = (xs_in - xs_out) ** 2 + distance_matrix_scaled = distance_matrix / length_scale**2 + kernel_matrix = sigma / (distance_matrix_scaled + 1.0) + + alpha = 0.001 + eye = torch.eye(kernel_matrix.shape[-1]) + return kernel_matrix + alpha * eye + + +def make_nn(input_size, output_size, hidden_sizes): + """Creates fully connected neural network + :param output_size: output dimensionality + :param hidden_sizes: tuple of hidden layer 
sizes. + The tuple length sets the number of hidden layers. + """ + layers = [] + for i in range(len(hidden_sizes)): + if i == 0: + layers.append( + nn.Linear(in_features=input_size, out_features=hidden_sizes[i]) + ) + else: + layers.append( + nn.Linear(in_features=hidden_sizes[i - 1], out_features=hidden_sizes[i]) + ) + layers.append(nn.ReLU()) + layers.append(nn.Linear(in_features=hidden_sizes[-1], out_features=output_size)) + return nn.Sequential(*layers) + + +class CustomConv1d(torch.nn.Conv1d): + def __init(self, in_channels, out_channels, kernal_size, padding): + super(CustomConv1d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernal_size + self.padding = padding + + def forward(self, x): + if len(x.shape) > 2: + shape = list(np.arange(len(x.shape))) + new_shape = [0, shape[-1]] + shape[1:-1] + out = super(CustomConv1d, self).forward(x.permute(*new_shape)) + shape = list(np.arange(len(out.shape))) + new_shape = [0, shape[-1]] + shape[1:-1] + if self.kernel_size[0] % 2 == 0: + out = F.pad(out, (0, -1), "constant", 0) + return out.permute(new_shape) + + return super(CustomConv1d, self).forward(x) + + +def make_cnn(input_size, output_size, hidden_sizes, kernel_size=3): + """Construct neural network consisting of + one 1d-convolutional layer that utilizes temporal dependences, + fully connected network + :param output_size: output dimensionality + :param hidden_sizes: tuple of hidden layer sizes. + The tuple length sets the number of hidden layers. + :param kernel_size: kernel size for convolutional layer + """ + padding = kernel_size // 2 + + cnn_layer = CustomConv1d( + input_size, hidden_sizes[0], kernel_size=kernel_size, padding=padding + ) + layers = [cnn_layer] + + for i, h in zip(hidden_sizes, hidden_sizes[1:]): + layers.extend([nn.Linear(i, h), nn.ReLU()]) + if isinstance(output_size, tuple): + net = nn.Sequential(*layers) + return [net] + [nn.Linear(hidden_sizes[-1], o) for o in output_size] + + layers.append(nn.Linear(hidden_sizes[-1], output_size)) + return nn.Sequential(*layers) + + +class Encoder(nn.Module): + def __init__(self, input_size, z_size, hidden_sizes=(128, 128), window_size=24): + """Encoder with 1d-convolutional network and multivariate Normal posterior + Used by GP-VAE with proposed banded covariance matrix + :param z_size: latent space dimensionality + :param hidden_sizes: tuple of hidden layer sizes. + The tuple length sets the number of hidden layers. 
+ :param window_size: kernel size for Conv1D layer + :param data_type: needed for some data specific modifications, e.g: + tf.nn.softplus is a more common and correct choice, however + tf.nn.sigmoid provides more stable performance on Physionet dataset + """ + super(Encoder, self).__init__() + self.z_size = int(z_size) + self.input_size = input_size + self.net, self.mu_layer, self.logvar_layer = make_cnn( + input_size, (z_size, z_size * 2), hidden_sizes, window_size + ) + + def __call__(self, x): + mapped = self.net(x) + batch_size = mapped.size(0) + time_length = mapped.size(1) + + # Obtain mean and precision matrix components + num_dim = len(mapped.shape) + mu = self.mu_layer(mapped) + logvar = self.logvar_layer(mapped) + mapped_mean = torch.transpose(mu, num_dim - 1, num_dim - 2) + mapped_covar = torch.transpose(logvar, num_dim - 1, num_dim - 2) + mapped_covar = torch.sigmoid(mapped_covar) + mapped_reshaped = mapped_covar.reshape(batch_size, self.z_size, 2 * time_length) + + dense_shape = [batch_size, self.z_size, time_length, time_length] + idxs_1 = np.repeat(np.arange(batch_size), self.z_size * (2 * time_length - 1)) + idxs_2 = np.tile( + np.repeat(np.arange(self.z_size), (2 * time_length - 1)), batch_size + ) + idxs_3 = np.tile( + np.concatenate([np.arange(time_length), np.arange(time_length - 1)]), + batch_size * self.z_size, + ) + idxs_4 = np.tile( + np.concatenate([np.arange(time_length), np.arange(1, time_length)]), + batch_size * self.z_size, + ) + idxs_all = np.stack([idxs_1, idxs_2, idxs_3, idxs_4], axis=1) + + mapped_values = mapped_reshaped[:, :, :-1].reshape(-1) + prec_sparse = torch.sparse_coo_tensor( + torch.LongTensor(idxs_all).t().to(mapped.device), + (mapped_values).to(mapped.device), + (dense_shape), + ) + prec_sparse = prec_sparse.coalesce() + prec_tril = prec_sparse.to_dense() + eye = ( + torch.eye(prec_tril.shape[-1]) + .unsqueeze(0) + .repeat(prec_tril.shape[0], prec_tril.shape[1], 1, 1) + .to(mapped.device) + ) + prec_tril = prec_tril + eye + cov_tril = torch.linalg.solve_triangular(prec_tril, eye, upper=True) + cov_tril = torch.where( + torch.isfinite(cov_tril), cov_tril, torch.zeros_like(cov_tril) + ).to(mapped.device) + + num_dim = len(cov_tril.shape) + cov_tril_lower = torch.transpose(cov_tril, num_dim - 1, num_dim - 2) + + z_dist = torch.distributions.MultivariateNormal( + loc=mapped_mean, scale_tril=(cov_tril_lower) + ) + return z_dist + + +class Decoder(nn.Module): + def __init__(self, input_size, output_size, hidden_sizes=(256, 256)): + """Decoder with Gaussian output distribution + :param output_size: output dimensionality + :param hidden_sizes: tuple of hidden layer sizes. + The tuple length sets the number of hidden layers. + """ + super(Decoder, self).__init__() + self.output_size = int(output_size) + self.net = make_nn(input_size, output_size, hidden_sizes) + + def __call__(self, x): + mu = self.net(x) + var = torch.ones_like(mu) + return torch.distributions.Normal(mu, var) diff --git a/pypots/imputation/usgan/__init__.py b/pypots/imputation/usgan/__init__.py new file mode 100644 index 00000000..fb388d94 --- /dev/null +++ b/pypots/imputation/usgan/__init__.py @@ -0,0 +1,12 @@ +""" +The package of the partially-observed time-series imputation method USGAN. 
+""" + +# Created by Jun Wang +# License: GLP-v3 + +from .model import USGAN + +__all__ = [ + "USGAN", +] diff --git a/pypots/imputation/usgan/data.py b/pypots/imputation/usgan/data.py new file mode 100644 index 00000000..2a92ee27 --- /dev/null +++ b/pypots/imputation/usgan/data.py @@ -0,0 +1,168 @@ +""" +Dataset class for model USGAN. +""" + +# Created by Jun Wang +# License: GLP-v3 + +from typing import Union, Iterable + +import torch + +from ...data.base import BaseDataset +from ...data.utils import torch_parse_delta + + +class DatasetForUSGAN(BaseDataset): + """Dataset class for USGAN. + + Parameters + ---------- + data : dict or str, + The dataset for model input, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for input, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + return_labels : bool, default = True, + Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example, + during training of classification models, the Dataset class will return labels in __getitem__() for model input. + Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we + need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5 + files, they already have both X and y saved. But we don't read labels from the file for validating and testing + with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for + distinction. + + file_type : str, default = "h5py" + The type of the given file if train_set and val_set are path strings. + """ + + def __init__( + self, + data: Union[dict, str], + return_labels: bool = True, + file_type: str = "h5py", + ): + super().__init__(data, return_labels, file_type) + + if not isinstance(self.data, str): + # calculate all delta here. + forward_missing_mask = (~torch.isnan(self.X)).type(torch.float32) + forward_X = torch.nan_to_num(self.X) + forward_delta = torch_parse_delta(forward_missing_mask) + backward_X = torch.flip(forward_X, dims=[1]) + backward_missing_mask = torch.flip(forward_missing_mask, dims=[1]) + backward_delta = torch_parse_delta(backward_missing_mask) + + self.processed_data = { + "forward": { + "X": forward_X, + "missing_mask": forward_missing_mask, + "delta": forward_delta, + }, + "backward": { + "X": backward_X, + "missing_mask": backward_missing_mask, + "delta": backward_delta, + }, + } + + def _fetch_data_from_array(self, idx: int) -> Iterable: + """Fetch data from self.X if it is given. + + Parameters + ---------- + idx : int, + The index of the sample to be return. + + Returns + ------- + sample : list, + A list contains + + index : int tensor, + The index of the sample. + + X : tensor, + The feature vector for model input. + + missing_mask : tensor, + The mask indicates all missing values in X. + + delta : tensor, + The delta matrix contains time gaps of missing values. + + label (optional) : tensor, + The target label of the time-series sample. 
+ """ + sample = [ + torch.tensor(idx), + # for forward + self.processed_data["forward"]["X"][idx].to(torch.float32), + self.processed_data["forward"]["missing_mask"][idx].to(torch.float32), + self.processed_data["forward"]["delta"][idx].to(torch.float32), + # for backward + self.processed_data["backward"]["X"][idx].to(torch.float32), + self.processed_data["backward"]["missing_mask"][idx].to(torch.float32), + self.processed_data["backward"]["delta"][idx].to(torch.float32), + ] + + if self.y is not None and self.return_labels: + sample.append(self.y[idx].to(torch.long)) + + return sample + + def _fetch_data_from_file(self, idx: int) -> Iterable: + """Fetch data with the lazy-loading strategy, i.e. only loading data from the file while requesting for samples. + Here the opened file handle doesn't load the entire dataset into RAM but only load the currently accessed slice. + + Parameters + ---------- + idx : int, + The index of the sample to be return. + + Returns + ------- + sample : list, + The collated data sample, a list including all necessary sample info. + """ + + if self.file_handle is None: + self.file_handle = self._open_file_handle() + + X = torch.from_numpy(self.file_handle["X"][idx]) + missing_mask = (~torch.isnan(X)).to(torch.float32) + X = torch.nan_to_num(X) + + forward = { + "X": X, + "missing_mask": missing_mask, + "deltas": torch_parse_delta(missing_mask), + } + + backward = { + "X": torch.flip(forward["X"], dims=[0]), + "missing_mask": torch.flip(forward["missing_mask"], dims=[0]), + } + backward["deltas"] = torch_parse_delta(backward["missing_mask"]) + + sample = [ + torch.tensor(idx), + # for forward + forward["X"], + forward["missing_mask"], + forward["deltas"], + # for backward + backward["X"], + backward["missing_mask"], + backward["deltas"], + ] + + # if the dataset has labels and is for training, then fetch it from the file + if "y" in self.file_handle.keys() and self.return_labels: + sample.append(torch.tensor(self.file_handle["y"][idx], dtype=torch.long)) + + return sample diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py new file mode 100644 index 00000000..cc4788f2 --- /dev/null +++ b/pypots/imputation/usgan/model.py @@ -0,0 +1,832 @@ +""" +The implementation of USGAN for the partially-observed time-series imputation task. + +Refer to the paper "Miao, X., Wu, Y., Wang, J., Gao, Y., Mao, X., & Yin, J. (2021). +Generative Semi-supervised Learning for Multivariate Time Series Imputation. AAAI 2021." + +Notes +----- +Partial implementation uses code from https://github.com/zjuwuyy-DL/Generative-Semi-supervised-Learning-for-Multivariate-Time-Series-Imputation. The bugs in the original implementation +are fixed here. 
+ +""" + +# Created by Jun Wang +# License: GPL-v3 + +from typing import Tuple, Union, Optional + +import h5py +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from .data import DatasetForUSGAN +from .modules import TemporalDecay, FeatureRegression + +# from ..brits.model import RITS +from ..base import BaseNNImputer +from ...optim.adam import Adam +from ...optim.base import Optimizer +from ...utils.metrics import cal_mae, cal_mse +from ...utils.logging import logger + + +class RITS(nn.Module): + """model RITS: Recurrent Imputation for Time Series + + Attributes + ---------- + n_steps : + sequence length (number of time steps) + + n_features : + number of features (input dimensions) + + rnn_hidden_size : + the hidden size of the RNN cell + + device : + specify running the model on which device, CPU/GPU + + rnn_cell : + the LSTM cell to model temporal data + + temp_decay_h : + the temporal decay module to decay RNN hidden state + + temp_decay_x : + the temporal decay module to decay data in the raw feature space + + hist_reg : + the temporal-regression module to project RNN hidden state into the raw feature space + + feat_reg : + the feature-regression module + + combining_weight : + the module used to generate the weight to combine history regression and feature regression + + Parameters + ---------- + n_steps : + sequence length (number of time steps) + + n_features : + number of features (input dimensions) + + rnn_hidden_size : + the hidden size of the RNN cell + + device : + specify running the model on which device, CPU/GPU + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + device: Union[str, torch.device], + ): + super().__init__() + self.n_steps = n_steps + self.n_features = n_features + self.rnn_hidden_size = rnn_hidden_size + self.device = device + + self.rnn_cell = nn.LSTMCell(self.n_features * 2, self.rnn_hidden_size) + self.temp_decay_h = TemporalDecay( + input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False + ) + self.temp_decay_x = TemporalDecay( + input_size=self.n_features, output_size=self.n_features, diag=True + ) + self.hist_reg = nn.Linear(self.rnn_hidden_size, self.n_features) + self.feat_reg = FeatureRegression(self.n_features) + self.combining_weight = nn.Linear(self.n_features * 2, self.n_features) + + def impute( + self, inputs: dict, direction: str + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """The imputation function. + Parameters + ---------- + inputs : + Input data, a dictionary includes feature values, missing masks, and time-gap values. + + direction : + A keyword to extract data from parameter `data`. 
+ + Returns + ------- + imputed_data : + [batch size, sequence length, feature number] + + hidden_states: tensor, + [batch size, RNN hidden size] + + reconstruction_loss : + reconstruction loss + + """ + values = inputs[direction]["X"] # feature values + masks = inputs[direction]["missing_mask"] # missing masks + deltas = inputs[direction]["deltas"] # time-gap values + + # create hidden states and cell states for the lstm cell + hidden_states = torch.zeros( + (values.size()[0], self.rnn_hidden_size), device=values.device + ) + cell_states = torch.zeros( + (values.size()[0], self.rnn_hidden_size), device=values.device + ) + + estimations = [] + reconstruction_loss = torch.tensor(0.0).to(values.device) + + # imputation period + for t in range(self.n_steps): + # data shape: [batch, time, features] + x = values[:, t, :] # values + m = masks[:, t, :] # mask + d = deltas[:, t, :] # delta, time gap + + gamma_h = self.temp_decay_h(d) + gamma_x = self.temp_decay_x(d) + + hidden_states = hidden_states * gamma_h # decay hidden states + x_h = self.hist_reg(hidden_states) + reconstruction_loss += cal_mae(x_h, x, m) + + x_c = m * x + (1 - m) * x_h + + z_h = self.feat_reg(x_c) + reconstruction_loss += cal_mae(z_h, x, m) + + alpha = torch.sigmoid(self.combining_weight(torch.cat([gamma_x, m], dim=1))) + + c_h = alpha * z_h + (1 - alpha) * x_h + reconstruction_loss += cal_mae(c_h, x, m) + + c_c = m * x + (1 - m) * c_h + estimations.append(c_h.unsqueeze(dim=1)) + + inputs = torch.cat([c_c, m], dim=1) + hidden_states, cell_states = self.rnn_cell( + inputs, (hidden_states, cell_states) + ) + + estimations = torch.cat(estimations, dim=1) + imputed_data = masks * values + (1 - masks) * estimations + return imputed_data, hidden_states, reconstruction_loss, estimations + + def forward(self, inputs: dict, direction: str = "forward") -> dict: + """Forward processing of the NN module. + Parameters + ---------- + inputs : + The input data. + + direction : + A keyword to extract data from parameter `data`. + + Returns + ------- + dict, + A dictionary includes all results. 
+ + """ + imputed_data, hidden_state, reconstruction_loss, estimations = self.impute( + inputs, direction + ) + # for each iteration, reconstruction_loss increases its value for 3 times + reconstruction_loss /= self.n_steps * 3 + + ret_dict = { + "consistency_loss": torch.tensor( + 0.0, device=imputed_data.device + ), # single direction, has no consistency loss + "reconstruction_loss": reconstruction_loss, + "imputed_data": imputed_data, + "final_hidden_state": hidden_state, + "estimations": estimations, + } + return ret_dict + + +class Discriminator(nn.Module): + def __init__( + self, + n_features: int, + rnn_hidden_size: int, + hint_rate: float = 0.7, + dropout_rate: float = 0.0, + device: Union[str, torch.device] = "cpu", + ): + super().__init__() + self.hint_rate = hint_rate + self.device = device + self.birnn = nn.GRU( + 2 * n_features, rnn_hidden_size, bidirectional=True, batch_first=True + ).to(device) + self.dropout = nn.Dropout(dropout_rate).to(device) + self.read_out = nn.Linear(2 * rnn_hidden_size, n_features).to(device) + + def forward(self, inputs: dict, training: bool = True) -> dict: + x = inputs["imputed_data"] + m = inputs["forward"]["missing_mask"] + + hint = ( + torch.rand_like(m, dtype=torch.float, device=self.device) < self.hint_rate + ) + hint = hint.byte() + h = hint * m + (1 - hint) * 0.5 + x_in = torch.cat([x, h], dim=-1) + + out, _ = self.birnn(x_in) + logits = self.read_out(self.dropout(out)) + return logits + + +class Generator(nn.Module): + """model Generator: + + Attributes + ---------- + n_steps : + sequence length (number of time steps) + + n_features : + number of features (input dimensions) + + rnn_hidden_size : + the hidden size of the RNN cell + + rits_f: RITS object + the forward RITS model + + rits_b: RITS object + the backward RITS model + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + device: Union[str, torch.device], + ): + super().__init__() + # data settings + self.n_steps = n_steps + self.n_features = n_features + # imputer settings + self.rnn_hidden_size = rnn_hidden_size + # create models + self.rits_f = RITS(n_steps, n_features, rnn_hidden_size, device) + self.rits_b = RITS(n_steps, n_features, rnn_hidden_size, device) + + @staticmethod + def _get_consistency_loss( + pred_f: torch.Tensor, pred_b: torch.Tensor + ) -> torch.Tensor: + """Calculate the consistency loss between the imputation from two RITS models. + + Parameters + ---------- + pred_f : + The imputation from the forward RITS. + + pred_b : + The imputation from the backward RITS (already gets reverted). + + Returns + ------- + float tensor, + The consistency loss. + + """ + loss = torch.abs(pred_f - pred_b).mean() * 1e-1 + return loss + + @staticmethod + def _reverse(ret: dict) -> dict: + """Reverse the array values on the time dimension in the given dictionary. + + Parameters + ---------- + ret : + + Returns + ------- + dict, + A dictionary contains values reversed on the time dimension from the given dict. + + """ + + def reverse_tensor(tensor_): + if tensor_.dim() <= 1: + return tensor_ + indices = range(tensor_.size()[1])[::-1] + indices = torch.tensor( + indices, dtype=torch.long, device=tensor_.device, requires_grad=False + ) + return tensor_.index_select(1, indices) + + for key in ret: + ret[key] = reverse_tensor(ret[key]) + + return ret + + def forward(self, inputs: dict, training: bool = True) -> dict: + """Forward processing of BRITS. + + Parameters + ---------- + inputs : + The input data. 
+ + Returns + ------- + dict, A dictionary includes all results. + """ + # Results from the forward RITS. + ret_f = self.rits_f(inputs, "forward") + # Results from the backward RITS. + ret_b = self._reverse(self.rits_b(inputs, "backward")) + + imputed_data = (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2 + estimation = (ret_f["estimations"] + ret_b["estimations"]) / 2 + + if not training: + # if not in training mode, return the classification result only + # return { + # "imputed_data": imputed_data, + # } + return imputed_data, estimation + + consistency_loss = self._get_consistency_loss( + ret_f["imputed_data"], ret_b["imputed_data"] + ) + + # `loss` is always the item for backward propagating to update the model + loss = ( + consistency_loss + + ret_f["reconstruction_loss"] + + ret_b["reconstruction_loss"] + ) + + results = { + "imputed_data": imputed_data, + "consistency_loss": consistency_loss, + "loss": loss, # will be used for backward propagating to update the model + } + + return results + + +class _USGAN(nn.Module): + """model USGAN: + USGAN consists of a generator, a discriminator, which are all built on bidirectional recurrent neural networks. + + Attributes + ---------- + n_steps : + sequence length (number of time steps) + + n_features : + number of features (input dimensions) + + rnn_hidden_size : + the hidden size of the RNN cell + + rits_f: RITS object + the forward RITS model + + rits_b: RITS object + the backward RITS model + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + lambda_mse: float, + hint_rate: float = 0.7, + dropout_rate: float = 0.0, + device: Union[str, torch.device] = "cpu", + ): + super().__init__() + self.generator = Generator(n_steps, n_features, rnn_hidden_size, device) + self.discriminator = Discriminator( + n_features, + rnn_hidden_size, + hint_rate=hint_rate, + dropout_rate=dropout_rate, + device=device, + ) + + self.lambda_mse = lambda_mse + self.device = device + + def forward( + self, + inputs: dict, + training_object: str = "generator", + training: bool = True, + ) -> dict: + assert training_object in [ + "generator", + "discriminator", + ], 'training_object should be "generator" or "discriminator"' + + X = inputs["forward"]["X"] + missing_mask = inputs["forward"]["missing_mask"] + batch_size, n_steps, n_features = X.shape + losses = {} + inputs["imputed_data"], inputs["reconstruction"] = self.generator( + inputs, training=False + ) + inputs["discrimination"] = self.discriminator(inputs, training=False) + if not training: + # if only run clustering, then no need to calculate loss + return inputs + + if training_object == "discriminator": + l_D = F.binary_cross_entropy_with_logits( + inputs["discrimination"], missing_mask + ) + losses["discrimination_loss"] = l_D + else: + inputs["discrimination"] = inputs["discrimination"].detach() + l_G = F.binary_cross_entropy_with_logits( + inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask + ) + l_rec = cal_mse(inputs["reconstruction"], X, missing_mask) + loss_gene = l_G + self.lambda_mse * l_rec + losses["generation_loss"] = loss_gene + losses["imputed_data"] = inputs["imputed_data"] + return losses + + +class USGAN(BaseNNImputer): + """The PyTorch implementation of the CRLI model :cite:`ma2021CRLI`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + n_clusters : + The number of clusters in the clustering task. 
+ + n_generator_layers : + The number of layers in the generator. + + rnn_hidden_size : + The size of the RNN hidden state, also the number of hidden units in the RNN cell. + + rnn_cell_type : + The type of RNN cell to use. Can be either "GRU" or "LSTM". + + decoder_fcn_output_dims : + The output dimensions of each layer in the FCN (fully-connected network) of the decoder. + + lambda_kmeans : + The weight of the k-means loss, + i.e. the item :math:`\\lambda` ahead of :math:`\\mathcal{L}_{k-means}` in Eq.13 of the original paper. + + G_steps : + The number of steps to train the generator in each iteration. + + D_steps : + The number of steps to train the discriminator in each iteration. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + G_optimizer : + The optimizer for the generator training. + If not given, will use a default Adam optimizer. + + D_optimizer : + The optimizer for the discriminator training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + + Attributes + ---------- + model : :class:`torch.nn.Module` + The underlying CRLI model. + + optimizer : :class:`pypots.optim.Optimizer` + The optimizer for model training. 
+ + """ + + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + lambda_mse: float = 1, + hint_rate: float = 0.7, + dropout_rate: float = 0.0, + G_steps: int = 1, + D_steps: int = 5, + batch_size: int = 32, + epochs: int = 100, + patience: Optional[int] = None, + G_optimizer: Optional[Optimizer] = Adam(), + D_optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: Optional[str] = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + assert G_steps > 0 and D_steps > 0, "G_steps and D_steps should both >0" + + self.n_steps = n_steps + self.n_features = n_features + self.G_steps = G_steps + self.D_steps = D_steps + + # set up the model + self.model = _USGAN( + n_steps, + n_features, + rnn_hidden_size, + lambda_mse, + hint_rate, + dropout_rate, + self.device, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.G_optimizer = G_optimizer + self.G_optimizer.init_optimizer(self.model.generator.parameters()) + self.D_optimizer = D_optimizer + self.D_optimizer.init_optimizer(self.model.discriminator.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + # fetch data + ( + indices, + X, + missing_mask, + deltas, + back_X, + back_missing_mask, + back_deltas, + ) = self._send_data_to_given_device(data) + + # assemble input data + inputs = { + "indices": indices, + "forward": { + "X": X, + "missing_mask": missing_mask, + "deltas": deltas, + }, + "backward": { + "X": back_X, + "missing_mask": back_missing_mask, + "deltas": back_deltas, + }, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + return self._assemble_input_for_validating(data) + + def _train_model( + self, + training_loader: DataLoader, + val_loader: DataLoader = None, + ) -> None: + # each training starts from the very beginning, so reset the loss and model dict here + self.best_loss = float("inf") + self.best_model_dict = None + + try: + training_step = 0 + epoch_train_loss_G_collector = [] + epoch_train_loss_D_collector = [] + for epoch in range(self.epochs): + self.model.train() + for idx, data in enumerate(training_loader): + training_step += 1 + inputs = self._assemble_input_for_training(data) + + step_train_loss_G_collector = [] + step_train_loss_D_collector = [] + # for _ in range(self.G_steps): + if idx % self.G_steps == 0: + self.G_optimizer.zero_grad() + results = self.model.forward( + inputs, training_object="generator" + ) + results["generation_loss"].backward() + self.G_optimizer.step() + step_train_loss_G_collector.append( + results["generation_loss"].item() + ) + + # for _ in range(self.D_steps): + if idx % self.D_steps == 0: + self.D_optimizer.zero_grad() + results = self.model.forward( + inputs, training_object="discriminator" + ) + results["discrimination_loss"].backward(retain_graph=True) + self.D_optimizer.step() + step_train_loss_D_collector.append( + results["discrimination_loss"].item() + ) + + mean_step_train_D_loss = np.mean(step_train_loss_D_collector) + mean_step_train_G_loss = np.mean(step_train_loss_G_collector) + + epoch_train_loss_D_collector.append(mean_step_train_D_loss) + epoch_train_loss_G_collector.append(mean_step_train_G_loss) + + # save 
training loss logs into the tensorboard file for every step if in need + # Note: the `training_step` is not the actual number of steps that Discriminator and Generator get + # trained, the actual number should be D_steps*training_step and G_steps*training_step accordingly + if self.summary_writer is not None: + loss_results = { + "generation_loss": mean_step_train_G_loss, + "discrimination_loss": mean_step_train_D_loss, + } + self._save_log_into_tb_file( + training_step, "training", loss_results + ) + mean_epoch_train_D_loss = np.mean(epoch_train_loss_D_collector) + mean_epoch_train_G_loss = np.mean(epoch_train_loss_G_collector) + logger.info( + f"epoch {epoch}: " + f"training loss_generator {mean_epoch_train_G_loss:.4f}, " + f"train loss_discriminator {mean_epoch_train_D_loss:.4f}" + ) + mean_loss = mean_epoch_train_G_loss + + if mean_loss < self.best_loss: + self.best_loss = mean_loss + self.best_model_dict = self.model.state_dict() + self.patience = self.original_patience + # save the model if necessary + self._auto_save_model_if_necessary( + training_finished=False, + saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", + ) + else: + self.patience -= 1 + if self.patience == 0: + logger.info( + "Exceeded the training patience. Terminating the training procedure..." + ) + break + except Exception as e: + logger.error(f"Exception: {e}") + if self.best_model_dict is None: + raise RuntimeError( + "Training got interrupted. Model was not trained. Please investigate the error printed above." + ) + else: + RuntimeWarning( + "Training got interrupted. Please investigate the error printed above.\n" + "Model got trained and will load the best checkpoint so far for testing.\n" + "If you don't want it, please try fit() again." + ) + + if np.equal(self.best_loss, float("inf")): + raise ValueError("Something is wrong. best_loss is Nan after training.") + + logger.info("Finished training.") + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForUSGAN( + train_set, return_labels=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if isinstance(val_set, str): + with h5py.File(val_set, "r") as hf: + # Here we read the whole validation set from the file to mask a portion for validation. + # In PyPOTS, using a file usually because the data is too big. However, the validation set is + # generally shouldn't be too large. For example, we have 1 billion samples for model training. + # We won't take 20% of them as the validation set because we want as much as possible data for the + # training stage to enhance the model's generalization ability. Therefore, 100,000 representative + # samples will be enough to validate the model. + val_set = { + "X": hf["X"][:], + "X_intact": hf["X_intact"][:], + "indicating_mask": hf["indicating_mask"][:], + } + val_set = DatasetForUSGAN(val_set, return_labels=False, file_type=file_type) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. 
+ + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(training_finished=True) + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + self.model.eval() # set the model as eval status to freeze it. + test_set = DatasetForUSGAN(X, return_labels=False, file_type=file_type) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputed_data = results["imputed_data"] + imputation_collector.append(imputed_data) + + imputation_collector = torch.cat(imputation_collector) + return imputation_collector.cpu().detach().numpy() diff --git a/pypots/imputation/usgan/modules.py b/pypots/imputation/usgan/modules.py new file mode 100644 index 00000000..c242dee4 --- /dev/null +++ b/pypots/imputation/usgan/modules.py @@ -0,0 +1,140 @@ +""" +The implementation of SSGAN for the partially-observed time-series imputation task. + +Refer to the paper "Miao, X., Wu, Y., Wang, J., Gao, Y., Mao, X., & Yin, J. (2021). +Generative Semi-supervised Learning for Multivariate Time Series Imputation. AAAI 2021." + +Notes +----- +Partial implementation uses code from https://github.com/zjuwuyy-DL/Generative-Semi-supervised-Learning-for-Multivariate-Time-Series-Imputation. The bugs in the original implementation +are fixed here. + +""" + +# Created by Jun Wang +# License: GPL-v3 + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from torch.nn.parameter import Parameter + + +class FeatureRegression(nn.Module): + """The module used to capture the correlation between features for imputation. + + Attributes + ---------- + W : tensor + The weights (parameters) of the module. + + b : tensor + The bias of the module. + + m (buffer) : tensor + The mask matrix, a squire matrix with diagonal entries all zeroes while left parts all ones. + It is applied to the weight matrix to mask out the estimation contributions from features themselves. + It is used to help enhance the imputation performance of the network. + + Parameters + ---------- + input_size : the feature dimension of the input + """ + + def __init__(self, input_size: int): + super().__init__() + self.W = Parameter(torch.Tensor(input_size, input_size)) + self.b = Parameter(torch.Tensor(input_size)) + + m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size) + self.register_buffer("m", m) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + std_dev = 1.0 / math.sqrt(self.W.size(0)) + self.W.data.uniform_(-std_dev, std_dev) + if self.b is not None: + self.b.data.uniform_(-std_dev, std_dev) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward processing of the NN module. + + Parameters + ---------- + x : tensor, + the input for processing + + Returns + ------- + output: tensor, + the processed result containing imputation from feature regression + + """ + output = F.linear(x, self.W * Variable(self.m), self.b) + return output + + +class TemporalDecay(nn.Module): + """The module used to generate the temporal decay factor gamma in the original paper. + + Attributes + ---------- + W: tensor, + The weights (parameters) of the module. + b: tensor, + The bias of the module. 
+ + Parameters + ---------- + input_size : int, + the feature dimension of the input + + output_size : int, + the feature dimension of the output + + diag : bool, + whether to product the weight with an identity matrix before forward processing + """ + + def __init__(self, input_size: int, output_size: int, diag: bool = False): + super().__init__() + self.diag = diag + self.W = Parameter(torch.Tensor(output_size, input_size)) + self.b = Parameter(torch.Tensor(output_size)) + + if self.diag: + assert input_size == output_size + m = torch.eye(input_size, input_size) + self.register_buffer("m", m) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + std_dev = 1.0 / math.sqrt(self.W.size(0)) + self.W.data.uniform_(-std_dev, std_dev) + if self.b is not None: + self.b.data.uniform_(-std_dev, std_dev) + + def forward(self, delta: torch.Tensor) -> torch.Tensor: + """Forward processing of the NN module. + + Parameters + ---------- + delta : tensor, shape [batch size, sequence length, feature number] + The time gaps. + + Returns + ------- + gamma : array-like, same shape with parameter `delta`, values in (0,1] + The temporal decay factor. + """ + if self.diag: + gamma = F.relu(F.linear(delta, self.W * Variable(self.m), self.b)) + else: + gamma = F.relu(F.linear(delta, self.W, self.b)) + gamma = torch.exp(-gamma) + return gamma From 7b2c413f50579c1a5f03d5033d5d529f7db41f48 Mon Sep 17 00:00:00 2001 From: Jun Wang Date: Tue, 5 Sep 2023 15:43:38 +0800 Subject: [PATCH 02/17] add parameter descriptions --- pypots/__init__.py | 2 +- pypots/base.py | 2 +- pypots/cli/doc.py | 2 +- pypots/imputation/gpvae/model.py | 81 ++++++++++++++++++++++-- pypots/imputation/gpvae/modules.py | 88 ++++++++++++++++++-------- pypots/imputation/saits/model.py | 69 ++++++++------------ pypots/imputation/transformer/model.py | 5 +- pypots/imputation/usgan/model.py | 78 +++++++++++++++-------- pypots/utils/__init__.py | 2 +- pypots/utils/metrics.py | 8 +-- 10 files changed, 227 insertions(+), 110 deletions(-) diff --git a/pypots/__init__.py b/pypots/__init__.py index 72846c15..b7a737b6 100644 --- a/pypots/__init__.py +++ b/pypots/__init__.py @@ -22,7 +22,7 @@ # # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. 
# 'X.Y.dev0' is the canonical version of 'X.Y.dev' -__version__ = "0.1.2" +__version__ = "0.1.1" __all__ = [ diff --git a/pypots/base.py b/pypots/base.py index 7a12fe94..472f338c 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -13,7 +13,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from .utils.file import create_dir_if_not_exist +from .utils.files import create_dir_if_not_exist from .utils.logging import logger diff --git a/pypots/cli/doc.py b/pypots/cli/doc.py index 2e0e6b5a..a5985497 100644 --- a/pypots/cli/doc.py +++ b/pypots/cli/doc.py @@ -9,7 +9,7 @@ import shutil from argparse import Namespace -from tsdb.utils.downloading import _download_and_extract +from tsdb.data_processing import _download_and_extract from ..cli.base import BaseCommand from ..utils.logging import logger diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index 3bf5a866..6e456010 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -30,6 +30,57 @@ class _GPVAE(nn.Module): + """model GPVAE with Gaussian Process prior + + Attributes + ---------- + Encoder : + the encoder in GPVAE + + Decoder : + the decoder in GPVAE + + Parameters + ---------- + input_dim : int, + the feature dimension of the input + + time_length : int, + the length of each time series + + latent_dim : int, + the feature dimension of the latent embedding + + device : str, + specify running the model on which device, CPU/GPU + + encoder_sizes : tuple, + the tuple of the network size in encoder + + decoder_sizes : tuple, + the tuple of the network size in decoder + + beta : float, + the weight of the KL divergence + + M : int, + the number of Monte Carlo samples for ELBO estimation + + K : int, + the number of importance weights for IWAE model + + kernel : str, + the Gaussial Process kernel ["cauchy", "diffusion", "rbf", "matern"] + + sigma : float, + the scale parameter for a kernel function + + length_scale : float, + the length scale parameter for a kernel function + + kernel_scales : int, + the number of different length scales over latent space dimensions + """ def __init__( self, input_dim, @@ -48,12 +99,6 @@ def __init__( length_scale=7.0, kernel_scales=1, ): - """GPVAE model with Gaussian Process prior - :param kernel: Gaussial Process kernel ["cauchy", "diffusion", "rbf", "matern"] - :param sigma: scale parameter for a kernel function - :param length_scale: length scale parameter for a kernel function - :param kernel_scales: number of different length scales over latent space dimensions - """ super(_GPVAE, self).__init__() self.kernel = kernel self.sigma = sigma @@ -256,8 +301,15 @@ def __init__( n_steps: int, n_features: int, latent_size: int, + encoder_sizes: tuple = (64, 64), + decoder_sizes: tuple = (64, 64), kernel: str = "cauchy", beta: float = 0.2, + M: int = 1, + K: int = 1, + sigma: float = 1.0, + length_scale: float = 7.0, + kernel_scales: int = 1, batch_size: int = 32, epochs: int = 100, patience: int = None, @@ -281,7 +333,15 @@ def __init__( self.n_features = n_features self.latent_size = latent_size self.kernel = kernel + self.encoder_sizes = encoder_sizes + self.decoder_sizes = decoder_sizes self.beta = beta + self.M = M + self.K = K + self.sigma = sigma + self.length_scale = length_scale + self.kernel_scales = kernel_scales + # set up the model self.model = _GPVAE( @@ -289,8 +349,15 @@ def __init__( time_length=self.n_steps, latent_dim=self.latent_size, kernel=self.kernel, - device=self.device, + encoder_sizes=self.encoder_sizes, + 
decoder_sizes=self.decoder_sizes, beta=self.beta, + M=self.M, + K=self.K, + sigma=self.sigma, + length_scale=self.length_scale, + kernel_scales=self.kernel_scales, + device=self.device, ) self._send_model_to_given_device() self._print_model_size() diff --git a/pypots/imputation/gpvae/modules.py b/pypots/imputation/gpvae/modules.py index 41e74c4b..69f68eda 100644 --- a/pypots/imputation/gpvae/modules.py +++ b/pypots/imputation/gpvae/modules.py @@ -72,10 +72,23 @@ def cauchy_kernel(T, sigma, length_scale): def make_nn(input_size, output_size, hidden_sizes): - """Creates fully connected neural network - :param output_size: output dimensionality - :param hidden_sizes: tuple of hidden layer sizes. - The tuple length sets the number of hidden layers. + """This function used to creates fully connected neural network. + + Parameters + ---------- + input_size : int, + the dimension of input embeddings + + output_size : int, + the dimension of out embeddings + + hidden_sizes : tuple, + the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers + + Returns + ------- + output: tensor + the processing embeddings """ layers = [] for i in range(len(hidden_sizes)): @@ -115,13 +128,28 @@ def forward(self, x): def make_cnn(input_size, output_size, hidden_sizes, kernel_size=3): - """Construct neural network consisting of - one 1d-convolutional layer that utilizes temporal dependences, - fully connected network - :param output_size: output dimensionality - :param hidden_sizes: tuple of hidden layer sizes. - The tuple length sets the number of hidden layers. - :param kernel_size: kernel size for convolutional layer + """This function used to construct neural network consisting of + one 1d-convolutional layer that utilizes temporal dependences, + fully connected network + + Parameters + ---------- + input_size : int, + the dimension of input embeddings + + output_size : int, + the dimension of out embeddings + + hidden_sizes : tuple, + the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers, + + kernel_size : int + kernel size for convolutional layer + + Returns + ------- + output: tensor + the processing embeddings """ padding = kernel_size // 2 @@ -142,15 +170,21 @@ def make_cnn(input_size, output_size, hidden_sizes, kernel_size=3): class Encoder(nn.Module): def __init__(self, input_size, z_size, hidden_sizes=(128, 128), window_size=24): - """Encoder with 1d-convolutional network and multivariate Normal posterior - Used by GP-VAE with proposed banded covariance matrix - :param z_size: latent space dimensionality - :param hidden_sizes: tuple of hidden layer sizes. - The tuple length sets the number of hidden layers. 
- :param window_size: kernel size for Conv1D layer - :param data_type: needed for some data specific modifications, e.g: - tf.nn.softplus is a more common and correct choice, however - tf.nn.sigmoid provides more stable performance on Physionet dataset + """This moudule is an encoder with 1d-convolutional network and multivariate Normal posterior used by GP-VAE with proposed banded covariance matrix + + Parameters + ---------- + input_size : int, + the feature dimension of the input + + z_size : int, + the feature dimension of the output latent embedding + + hidden_sizes : tuple, + the tuple of the hidden layer sizes, and the tuple length sets the number of hidden layers + + window_size : int + the kernel size for the Conv1D layer """ super(Encoder, self).__init__() self.z_size = int(z_size) @@ -164,7 +198,6 @@ def __call__(self, x): batch_size = mapped.size(0) time_length = mapped.size(1) - # Obtain mean and precision matrix components num_dim = len(mapped.shape) mu = self.mu_layer(mapped) logvar = self.logvar_layer(mapped) @@ -219,10 +252,15 @@ def __call__(self, x): class Decoder(nn.Module): def __init__(self, input_size, output_size, hidden_sizes=(256, 256)): - """Decoder with Gaussian output distribution - :param output_size: output dimensionality - :param hidden_sizes: tuple of hidden layer sizes. - The tuple length sets the number of hidden layers. + """This module is a decoder with Gaussian output distribution + + Parameters + ---------- + output_size : int, + the feature dimension of the output + + idden_sizes: tuple + the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers. """ super(Decoder, self).__init__() self.output_size = int(output_size) diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index 233fc780..d00ab610 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -24,8 +24,8 @@ from .data import DatasetForSAITS from ..base import BaseNNImputer +from ..transformer.modules import EncoderLayer, PositionalEncoding from ...data.base import BaseDataset -from ...modules.self_attention import EncoderLayer, PositionalEncoding from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.metrics import cal_mae @@ -35,8 +35,8 @@ class _SAITS(nn.Module): def __init__( self, n_layers: int, - n_steps: int, - n_features: int, + d_time: int, + d_feature: int, d_model: int, d_inner: int, n_heads: int, @@ -50,16 +50,15 @@ def __init__( ): super().__init__() self.n_layers = n_layers - self.n_steps = n_steps - # concatenate the feature vector and missing mask, hence double the number of features - actual_n_features = n_features * 2 - self.diagonal_attention_mask = diagonal_attention_mask + actual_d_feature = d_feature * 2 self.ORT_weight = ORT_weight self.MIT_weight = MIT_weight self.layer_stack_for_first_block = nn.ModuleList( [ EncoderLayer( + d_time, + actual_d_feature, d_model, d_inner, n_heads, @@ -67,6 +66,7 @@ def __init__( d_v, dropout, attn_dropout, + diagonal_attention_mask, ) for _ in range(n_layers) ] @@ -74,6 +74,8 @@ def __init__( self.layer_stack_for_second_block = nn.ModuleList( [ EncoderLayer( + d_time, + actual_d_feature, d_model, d_inner, n_heads, @@ -81,30 +83,26 @@ def __init__( d_v, dropout, attn_dropout, + diagonal_attention_mask, ) for _ in range(n_layers) ] ) self.dropout = nn.Dropout(p=dropout) - self.position_enc = PositionalEncoding(d_model, n_position=n_steps) - # for the 1st block - self.embedding_1 = nn.Linear(actual_n_features, d_model) - 
self.reduce_dim_z = nn.Linear(d_model, n_features) - # for the 2nd block - self.embedding_2 = nn.Linear(actual_n_features, d_model) - self.reduce_dim_beta = nn.Linear(d_model, n_features) - self.reduce_dim_gamma = nn.Linear(n_features, n_features) + self.position_enc = PositionalEncoding(d_model, n_position=d_time) + # for operation on time dim + self.embedding_1 = nn.Linear(actual_d_feature, d_model) + self.reduce_dim_z = nn.Linear(d_model, d_feature) + # for operation on measurement dim + self.embedding_2 = nn.Linear(actual_d_feature, d_model) + self.reduce_dim_beta = nn.Linear(d_model, d_feature) + self.reduce_dim_gamma = nn.Linear(d_feature, d_feature) # for delta decay factor - self.weight_combine = nn.Linear(n_features + n_steps, n_features) + self.weight_combine = nn.Linear(d_feature + d_time, d_feature) - def _process( - self, - inputs: dict, - diagonal_attention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, list]: + def _process(self, inputs: dict) -> Tuple[torch.Tensor, list]: X, masks = inputs["X"], inputs["missing_mask"] - # first DMSA block input_X_for_first = torch.cat([X, masks], dim=2) input_X_for_first = self.embedding_1(input_X_for_first) @@ -112,7 +110,7 @@ def _process( self.position_enc(input_X_for_first) ) # namely, term e in the math equation for encoder_layer in self.layer_stack_for_first_block: - enc_output, _ = encoder_layer(enc_output, diagonal_attention_mask) + enc_output, _ = encoder_layer(enc_output) X_tilde_1 = self.reduce_dim_z(enc_output) X_prime = masks * X + (1 - masks) * X_tilde_1 @@ -148,23 +146,9 @@ def _process( return X_c, [X_tilde_1, X_tilde_2, X_tilde_3] - def forward( - self, inputs: dict, diagonal_attention_mask: bool = False, training: bool = True - ) -> dict: + def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] - - if (training and self.diagonal_attention_mask) or ( - (not training) and diagonal_attention_mask - ): - diagonal_attention_mask = torch.eye(self.n_steps).to(X.device) - # then broadcast on the batch axis - diagonal_attention_mask = diagonal_attention_mask.unsqueeze(0) - else: - diagonal_attention_mask = None - - imputed_data, [X_tilde_1, X_tilde_2, X_tilde_3] = self._process( - inputs, diagonal_attention_mask - ) + imputed_data, [X_tilde_1, X_tilde_2, X_tilde_3] = self._process(inputs) if not training: # if not in training mode, return the classification result only @@ -443,8 +427,7 @@ def fit( def impute( self, X: Union[dict, str], - file_type: str = "h5py", - diagonal_attention_mask: bool = True, + file_type="h5py", ) -> np.ndarray: # Step 1: wrap the input data with classes Dataset and DataLoader self.model.eval() # set the model as eval status to freeze it. 
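[Note] Two details of the SAITS code touched above are easy to miss: the diagonal attention mask is nothing more than an identity matrix broadcast over the batch axis, and the final imputation only fills positions the missing mask marks as absent, leaving observed values untouched. A small illustrative sketch with toy shapes (the variable names here are placeholders, not the model's attributes):

import torch

n_steps, n_features, batch = 3, 2, 4
diag_mask = torch.eye(n_steps).unsqueeze(0)      # [1, n_steps, n_steps], broadcast over the batch

X = torch.randn(batch, n_steps, n_features)
masks = torch.randint(0, 2, X.shape).float()     # 1 = observed, 0 = missing
X_tilde = torch.zeros_like(X)                    # stand-in for a DMSA block's estimation
X_c = masks * X + (1 - masks) * X_tilde          # keep observed values, fill only the gaps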
@@ -461,9 +444,7 @@ def impute( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward( - inputs, diagonal_attention_mask, training=False - ) + results = self.model.forward(inputs, training=False) imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) diff --git a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py index 7e61f6ae..fd5e103b 100644 --- a/pypots/imputation/transformer/model.py +++ b/pypots/imputation/transformer/model.py @@ -22,9 +22,9 @@ from torch.utils.data import DataLoader from .data import DatasetForSAITS +from .modules import EncoderLayer, PositionalEncoding from ..base import BaseNNImputer from ...data.base import BaseDataset -from ...modules.self_attention import EncoderLayer, PositionalEncoding from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.metrics import cal_mae @@ -55,6 +55,8 @@ def __init__( self.layer_stack = nn.ModuleList( [ EncoderLayer( + d_time, + actual_d_feature, d_model, d_inner, n_heads, @@ -62,6 +64,7 @@ def __init__( d_v, dropout, attn_dropout, + False, ) for _ in range(n_layers) ] diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index cc4788f2..dc631af6 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -219,6 +219,26 @@ def forward(self, inputs: dict, direction: str = "forward") -> dict: class Discriminator(nn.Module): + """model Discriminator: built on BiRNN + + Parameters + ---------- + n_features : + the feature dimension of the input + + rnn_hidden_size : + the hidden size of the RNN cell + + hint_rate : + the hint rate for the input imputed_data + + dropout_rate : + the dropout rate for the output layer + + device : + specify running the model on which device, CPU/GPU + + """ def __init__( self, n_features: int, @@ -237,6 +257,17 @@ def __init__( self.read_out = nn.Linear(2 * rnn_hidden_size, n_features).to(device) def forward(self, inputs: dict, training: bool = True) -> dict: + """Forward processing of Discriminator. + + Parameters + ---------- + inputs : + The input data. + + Returns + ------- + dict, the logits of the probability of being the true value. + """ x = inputs["imputed_data"] m = inputs["forward"]["missing_mask"] @@ -344,7 +375,7 @@ def reverse_tensor(tensor_): return ret def forward(self, inputs: dict, training: bool = True) -> dict: - """Forward processing of BRITS. + """Forward processing of Generator. Parameters ---------- @@ -405,11 +436,17 @@ class _USGAN(nn.Module): rnn_hidden_size : the hidden size of the RNN cell - rits_f: RITS object - the forward RITS model + lambda_mse : + the weigth of the reconstruction loss + + hint_rate : + the hint rate for the discriminator + + dropout_rate : + the dropout rate for the last layer in Discriminator - rits_b: RITS object - the backward RITS model + device : + specify running the model on which device, CPU/GPU """ @@ -449,7 +486,6 @@ def forward( X = inputs["forward"]["X"] missing_mask = inputs["forward"]["missing_mask"] - batch_size, n_steps, n_features = X.shape losses = {} inputs["imputed_data"], inputs["reconstruction"] = self.generator( inputs, training=False @@ -487,24 +523,17 @@ class USGAN(BaseNNImputer): n_features : The number of features in the time-series data sample. - n_clusters : - The number of clusters in the clustering task. - - n_generator_layers : - The number of layers in the generator. 
- rnn_hidden_size : - The size of the RNN hidden state, also the number of hidden units in the RNN cell. - - rnn_cell_type : - The type of RNN cell to use. Can be either "GRU" or "LSTM". - - decoder_fcn_output_dims : - The output dimensions of each layer in the FCN (fully-connected network) of the decoder. + the hidden size of the RNN cell - lambda_kmeans : - The weight of the k-means loss, - i.e. the item :math:`\\lambda` ahead of :math:`\\mathcal{L}_{k-means}` in Eq.13 of the original paper. + lambda_mse : + the weigth of the reconstruction loss + + hint_rate : + the hint rate for the discriminator + + dropout_rate : + the dropout rate for the last layer in Discriminator G_steps : The number of steps to train the generator in each iteration. @@ -573,7 +602,7 @@ def __init__( hint_rate: float = 0.7, dropout_rate: float = 0.0, G_steps: int = 1, - D_steps: int = 5, + D_steps: int = 1, batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, @@ -675,7 +704,7 @@ def _train_model( step_train_loss_G_collector = [] step_train_loss_D_collector = [] - # for _ in range(self.G_steps): + if idx % self.G_steps == 0: self.G_optimizer.zero_grad() results = self.model.forward( @@ -687,7 +716,6 @@ def _train_model( results["generation_loss"].item() ) - # for _ in range(self.D_steps): if idx % self.D_steps == 0: self.D_optimizer.zero_grad() results = self.model.forward( diff --git a/pypots/utils/__init__.py b/pypots/utils/__init__.py index 5fd28d97..1376524d 100644 --- a/pypots/utils/__init__.py +++ b/pypots/utils/__init__.py @@ -8,7 +8,7 @@ __all__ = [ # content files in this package - "file.py", + "files", "logging", "metrics", "random", diff --git a/pypots/utils/metrics.py b/pypots/utils/metrics.py index ee7d69bb..058cc24a 100644 --- a/pypots/utils/metrics.py +++ b/pypots/utils/metrics.py @@ -54,7 +54,7 @@ def cal_mae( so the result is 1/2=0.5. """ - assert isinstance(predictions, type(targets)), ( + assert type(predictions) == type(targets), ( f"types of inputs and target must match, but got" f"type(inputs)={type(predictions)}, type(target)={type(targets)}" ) @@ -110,7 +110,7 @@ def cal_mse( """ - assert isinstance(predictions, type(targets)), ( + assert type(predictions) == type(targets), ( f"types of inputs and target must match, but got" f"type(inputs)={type(predictions)}, type(target)={type(targets)}" ) @@ -166,7 +166,7 @@ def cal_rmse( so the result is :math:`\\sqrt{1/2}=0.5`. """ - assert isinstance(predictions, type(targets)), ( + assert type(predictions) == type(targets), ( f"types of inputs and target must match, but got" f"type(inputs)={type(predictions)}, type(target)={type(targets)}" ) @@ -217,7 +217,7 @@ def cal_mre( so the result is :math:`\\sqrt{1/2}=0.5`. """ - assert isinstance(predictions, type(targets)), ( + assert type(predictions) == type(targets), ( f"types of inputs and target must match, but got" f"type(inputs)={type(predictions)}, type(target)={type(targets)}" ) From 54aa4e3b393ab72a377ab1b44bb1b6c91230164e Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 6 Sep 2023 16:56:46 +0800 Subject: [PATCH 03/17] Revert "add parameter descriptions" This reverts commit 7b2c413f50579c1a5f03d5033d5d529f7db41f48. 
--- pypots/__init__.py | 2 +- pypots/base.py | 2 +- pypots/cli/doc.py | 2 +- pypots/imputation/gpvae/model.py | 81 ++---------------------- pypots/imputation/gpvae/modules.py | 88 ++++++++------------------ pypots/imputation/saits/model.py | 69 ++++++++++++-------- pypots/imputation/transformer/model.py | 5 +- pypots/imputation/usgan/model.py | 78 ++++++++--------------- pypots/utils/__init__.py | 2 +- pypots/utils/metrics.py | 8 +-- 10 files changed, 110 insertions(+), 227 deletions(-) diff --git a/pypots/__init__.py b/pypots/__init__.py index b7a737b6..72846c15 100644 --- a/pypots/__init__.py +++ b/pypots/__init__.py @@ -22,7 +22,7 @@ # # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' -__version__ = "0.1.1" +__version__ = "0.1.2" __all__ = [ diff --git a/pypots/base.py b/pypots/base.py index 472f338c..7a12fe94 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -13,7 +13,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from .utils.files import create_dir_if_not_exist +from .utils.file import create_dir_if_not_exist from .utils.logging import logger diff --git a/pypots/cli/doc.py b/pypots/cli/doc.py index a5985497..2e0e6b5a 100644 --- a/pypots/cli/doc.py +++ b/pypots/cli/doc.py @@ -9,7 +9,7 @@ import shutil from argparse import Namespace -from tsdb.data_processing import _download_and_extract +from tsdb.utils.downloading import _download_and_extract from ..cli.base import BaseCommand from ..utils.logging import logger diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index 6e456010..3bf5a866 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -30,57 +30,6 @@ class _GPVAE(nn.Module): - """model GPVAE with Gaussian Process prior - - Attributes - ---------- - Encoder : - the encoder in GPVAE - - Decoder : - the decoder in GPVAE - - Parameters - ---------- - input_dim : int, - the feature dimension of the input - - time_length : int, - the length of each time series - - latent_dim : int, - the feature dimension of the latent embedding - - device : str, - specify running the model on which device, CPU/GPU - - encoder_sizes : tuple, - the tuple of the network size in encoder - - decoder_sizes : tuple, - the tuple of the network size in decoder - - beta : float, - the weight of the KL divergence - - M : int, - the number of Monte Carlo samples for ELBO estimation - - K : int, - the number of importance weights for IWAE model - - kernel : str, - the Gaussial Process kernel ["cauchy", "diffusion", "rbf", "matern"] - - sigma : float, - the scale parameter for a kernel function - - length_scale : float, - the length scale parameter for a kernel function - - kernel_scales : int, - the number of different length scales over latent space dimensions - """ def __init__( self, input_dim, @@ -99,6 +48,12 @@ def __init__( length_scale=7.0, kernel_scales=1, ): + """GPVAE model with Gaussian Process prior + :param kernel: Gaussial Process kernel ["cauchy", "diffusion", "rbf", "matern"] + :param sigma: scale parameter for a kernel function + :param length_scale: length scale parameter for a kernel function + :param kernel_scales: number of different length scales over latent space dimensions + """ super(_GPVAE, self).__init__() self.kernel = kernel self.sigma = sigma @@ -301,15 +256,8 @@ def __init__( n_steps: int, n_features: int, latent_size: int, - encoder_sizes: tuple = (64, 64), - decoder_sizes: tuple = (64, 64), kernel: str = "cauchy", beta: float 
= 0.2, - M: int = 1, - K: int = 1, - sigma: float = 1.0, - length_scale: float = 7.0, - kernel_scales: int = 1, batch_size: int = 32, epochs: int = 100, patience: int = None, @@ -333,15 +281,7 @@ def __init__( self.n_features = n_features self.latent_size = latent_size self.kernel = kernel - self.encoder_sizes = encoder_sizes - self.decoder_sizes = decoder_sizes self.beta = beta - self.M = M - self.K = K - self.sigma = sigma - self.length_scale = length_scale - self.kernel_scales = kernel_scales - # set up the model self.model = _GPVAE( @@ -349,15 +289,8 @@ def __init__( time_length=self.n_steps, latent_dim=self.latent_size, kernel=self.kernel, - encoder_sizes=self.encoder_sizes, - decoder_sizes=self.decoder_sizes, - beta=self.beta, - M=self.M, - K=self.K, - sigma=self.sigma, - length_scale=self.length_scale, - kernel_scales=self.kernel_scales, device=self.device, + beta=self.beta, ) self._send_model_to_given_device() self._print_model_size() diff --git a/pypots/imputation/gpvae/modules.py b/pypots/imputation/gpvae/modules.py index 69f68eda..41e74c4b 100644 --- a/pypots/imputation/gpvae/modules.py +++ b/pypots/imputation/gpvae/modules.py @@ -72,23 +72,10 @@ def cauchy_kernel(T, sigma, length_scale): def make_nn(input_size, output_size, hidden_sizes): - """This function used to creates fully connected neural network. - - Parameters - ---------- - input_size : int, - the dimension of input embeddings - - output_size : int, - the dimension of out embeddings - - hidden_sizes : tuple, - the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers - - Returns - ------- - output: tensor - the processing embeddings + """Creates fully connected neural network + :param output_size: output dimensionality + :param hidden_sizes: tuple of hidden layer sizes. + The tuple length sets the number of hidden layers. """ layers = [] for i in range(len(hidden_sizes)): @@ -128,28 +115,13 @@ def forward(self, x): def make_cnn(input_size, output_size, hidden_sizes, kernel_size=3): - """This function used to construct neural network consisting of - one 1d-convolutional layer that utilizes temporal dependences, - fully connected network - - Parameters - ---------- - input_size : int, - the dimension of input embeddings - - output_size : int, - the dimension of out embeddings - - hidden_sizes : tuple, - the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers, - - kernel_size : int - kernel size for convolutional layer - - Returns - ------- - output: tensor - the processing embeddings + """Construct neural network consisting of + one 1d-convolutional layer that utilizes temporal dependences, + fully connected network + :param output_size: output dimensionality + :param hidden_sizes: tuple of hidden layer sizes. + The tuple length sets the number of hidden layers. 
+ :param kernel_size: kernel size for convolutional layer """ padding = kernel_size // 2 @@ -170,21 +142,15 @@ def make_cnn(input_size, output_size, hidden_sizes, kernel_size=3): class Encoder(nn.Module): def __init__(self, input_size, z_size, hidden_sizes=(128, 128), window_size=24): - """This moudule is an encoder with 1d-convolutional network and multivariate Normal posterior used by GP-VAE with proposed banded covariance matrix - - Parameters - ---------- - input_size : int, - the feature dimension of the input - - z_size : int, - the feature dimension of the output latent embedding - - hidden_sizes : tuple, - the tuple of the hidden layer sizes, and the tuple length sets the number of hidden layers - - window_size : int - the kernel size for the Conv1D layer + """Encoder with 1d-convolutional network and multivariate Normal posterior + Used by GP-VAE with proposed banded covariance matrix + :param z_size: latent space dimensionality + :param hidden_sizes: tuple of hidden layer sizes. + The tuple length sets the number of hidden layers. + :param window_size: kernel size for Conv1D layer + :param data_type: needed for some data specific modifications, e.g: + tf.nn.softplus is a more common and correct choice, however + tf.nn.sigmoid provides more stable performance on Physionet dataset """ super(Encoder, self).__init__() self.z_size = int(z_size) @@ -198,6 +164,7 @@ def __call__(self, x): batch_size = mapped.size(0) time_length = mapped.size(1) + # Obtain mean and precision matrix components num_dim = len(mapped.shape) mu = self.mu_layer(mapped) logvar = self.logvar_layer(mapped) @@ -252,15 +219,10 @@ def __call__(self, x): class Decoder(nn.Module): def __init__(self, input_size, output_size, hidden_sizes=(256, 256)): - """This module is a decoder with Gaussian output distribution - - Parameters - ---------- - output_size : int, - the feature dimension of the output - - idden_sizes: tuple - the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers. + """Decoder with Gaussian output distribution + :param output_size: output dimensionality + :param hidden_sizes: tuple of hidden layer sizes. + The tuple length sets the number of hidden layers. 
""" super(Decoder, self).__init__() self.output_size = int(output_size) diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index d00ab610..233fc780 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -24,8 +24,8 @@ from .data import DatasetForSAITS from ..base import BaseNNImputer -from ..transformer.modules import EncoderLayer, PositionalEncoding from ...data.base import BaseDataset +from ...modules.self_attention import EncoderLayer, PositionalEncoding from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.metrics import cal_mae @@ -35,8 +35,8 @@ class _SAITS(nn.Module): def __init__( self, n_layers: int, - d_time: int, - d_feature: int, + n_steps: int, + n_features: int, d_model: int, d_inner: int, n_heads: int, @@ -50,15 +50,16 @@ def __init__( ): super().__init__() self.n_layers = n_layers - actual_d_feature = d_feature * 2 + self.n_steps = n_steps + # concatenate the feature vector and missing mask, hence double the number of features + actual_n_features = n_features * 2 + self.diagonal_attention_mask = diagonal_attention_mask self.ORT_weight = ORT_weight self.MIT_weight = MIT_weight self.layer_stack_for_first_block = nn.ModuleList( [ EncoderLayer( - d_time, - actual_d_feature, d_model, d_inner, n_heads, @@ -66,7 +67,6 @@ def __init__( d_v, dropout, attn_dropout, - diagonal_attention_mask, ) for _ in range(n_layers) ] @@ -74,8 +74,6 @@ def __init__( self.layer_stack_for_second_block = nn.ModuleList( [ EncoderLayer( - d_time, - actual_d_feature, d_model, d_inner, n_heads, @@ -83,26 +81,30 @@ def __init__( d_v, dropout, attn_dropout, - diagonal_attention_mask, ) for _ in range(n_layers) ] ) self.dropout = nn.Dropout(p=dropout) - self.position_enc = PositionalEncoding(d_model, n_position=d_time) - # for operation on time dim - self.embedding_1 = nn.Linear(actual_d_feature, d_model) - self.reduce_dim_z = nn.Linear(d_model, d_feature) - # for operation on measurement dim - self.embedding_2 = nn.Linear(actual_d_feature, d_model) - self.reduce_dim_beta = nn.Linear(d_model, d_feature) - self.reduce_dim_gamma = nn.Linear(d_feature, d_feature) + self.position_enc = PositionalEncoding(d_model, n_position=n_steps) + # for the 1st block + self.embedding_1 = nn.Linear(actual_n_features, d_model) + self.reduce_dim_z = nn.Linear(d_model, n_features) + # for the 2nd block + self.embedding_2 = nn.Linear(actual_n_features, d_model) + self.reduce_dim_beta = nn.Linear(d_model, n_features) + self.reduce_dim_gamma = nn.Linear(n_features, n_features) # for delta decay factor - self.weight_combine = nn.Linear(d_feature + d_time, d_feature) + self.weight_combine = nn.Linear(n_features + n_steps, n_features) - def _process(self, inputs: dict) -> Tuple[torch.Tensor, list]: + def _process( + self, + inputs: dict, + diagonal_attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, list]: X, masks = inputs["X"], inputs["missing_mask"] + # first DMSA block input_X_for_first = torch.cat([X, masks], dim=2) input_X_for_first = self.embedding_1(input_X_for_first) @@ -110,7 +112,7 @@ def _process(self, inputs: dict) -> Tuple[torch.Tensor, list]: self.position_enc(input_X_for_first) ) # namely, term e in the math equation for encoder_layer in self.layer_stack_for_first_block: - enc_output, _ = encoder_layer(enc_output) + enc_output, _ = encoder_layer(enc_output, diagonal_attention_mask) X_tilde_1 = self.reduce_dim_z(enc_output) X_prime = masks * X + (1 - masks) * X_tilde_1 @@ -146,9 +148,23 @@ def _process(self, 
inputs: dict) -> Tuple[torch.Tensor, list]: return X_c, [X_tilde_1, X_tilde_2, X_tilde_3] - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward( + self, inputs: dict, diagonal_attention_mask: bool = False, training: bool = True + ) -> dict: X, masks = inputs["X"], inputs["missing_mask"] - imputed_data, [X_tilde_1, X_tilde_2, X_tilde_3] = self._process(inputs) + + if (training and self.diagonal_attention_mask) or ( + (not training) and diagonal_attention_mask + ): + diagonal_attention_mask = torch.eye(self.n_steps).to(X.device) + # then broadcast on the batch axis + diagonal_attention_mask = diagonal_attention_mask.unsqueeze(0) + else: + diagonal_attention_mask = None + + imputed_data, [X_tilde_1, X_tilde_2, X_tilde_3] = self._process( + inputs, diagonal_attention_mask + ) if not training: # if not in training mode, return the classification result only @@ -427,7 +443,8 @@ def fit( def impute( self, X: Union[dict, str], - file_type="h5py", + file_type: str = "h5py", + diagonal_attention_mask: bool = True, ) -> np.ndarray: # Step 1: wrap the input data with classes Dataset and DataLoader self.model.eval() # set the model as eval status to freeze it. @@ -444,7 +461,9 @@ def impute( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward( + inputs, diagonal_attention_mask, training=False + ) imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) diff --git a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py index fd5e103b..7e61f6ae 100644 --- a/pypots/imputation/transformer/model.py +++ b/pypots/imputation/transformer/model.py @@ -22,9 +22,9 @@ from torch.utils.data import DataLoader from .data import DatasetForSAITS -from .modules import EncoderLayer, PositionalEncoding from ..base import BaseNNImputer from ...data.base import BaseDataset +from ...modules.self_attention import EncoderLayer, PositionalEncoding from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.metrics import cal_mae @@ -55,8 +55,6 @@ def __init__( self.layer_stack = nn.ModuleList( [ EncoderLayer( - d_time, - actual_d_feature, d_model, d_inner, n_heads, @@ -64,7 +62,6 @@ def __init__( d_v, dropout, attn_dropout, - False, ) for _ in range(n_layers) ] diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index dc631af6..cc4788f2 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -219,26 +219,6 @@ def forward(self, inputs: dict, direction: str = "forward") -> dict: class Discriminator(nn.Module): - """model Discriminator: built on BiRNN - - Parameters - ---------- - n_features : - the feature dimension of the input - - rnn_hidden_size : - the hidden size of the RNN cell - - hint_rate : - the hint rate for the input imputed_data - - dropout_rate : - the dropout rate for the output layer - - device : - specify running the model on which device, CPU/GPU - - """ def __init__( self, n_features: int, @@ -257,17 +237,6 @@ def __init__( self.read_out = nn.Linear(2 * rnn_hidden_size, n_features).to(device) def forward(self, inputs: dict, training: bool = True) -> dict: - """Forward processing of Discriminator. - - Parameters - ---------- - inputs : - The input data. - - Returns - ------- - dict, the logits of the probability of being the true value. 
- """ x = inputs["imputed_data"] m = inputs["forward"]["missing_mask"] @@ -375,7 +344,7 @@ def reverse_tensor(tensor_): return ret def forward(self, inputs: dict, training: bool = True) -> dict: - """Forward processing of Generator. + """Forward processing of BRITS. Parameters ---------- @@ -436,17 +405,11 @@ class _USGAN(nn.Module): rnn_hidden_size : the hidden size of the RNN cell - lambda_mse : - the weigth of the reconstruction loss - - hint_rate : - the hint rate for the discriminator - - dropout_rate : - the dropout rate for the last layer in Discriminator + rits_f: RITS object + the forward RITS model - device : - specify running the model on which device, CPU/GPU + rits_b: RITS object + the backward RITS model """ @@ -486,6 +449,7 @@ def forward( X = inputs["forward"]["X"] missing_mask = inputs["forward"]["missing_mask"] + batch_size, n_steps, n_features = X.shape losses = {} inputs["imputed_data"], inputs["reconstruction"] = self.generator( inputs, training=False @@ -523,17 +487,24 @@ class USGAN(BaseNNImputer): n_features : The number of features in the time-series data sample. + n_clusters : + The number of clusters in the clustering task. + + n_generator_layers : + The number of layers in the generator. + rnn_hidden_size : - the hidden size of the RNN cell + The size of the RNN hidden state, also the number of hidden units in the RNN cell. + + rnn_cell_type : + The type of RNN cell to use. Can be either "GRU" or "LSTM". + + decoder_fcn_output_dims : + The output dimensions of each layer in the FCN (fully-connected network) of the decoder. - lambda_mse : - the weigth of the reconstruction loss - - hint_rate : - the hint rate for the discriminator - - dropout_rate : - the dropout rate for the last layer in Discriminator + lambda_kmeans : + The weight of the k-means loss, + i.e. the item :math:`\\lambda` ahead of :math:`\\mathcal{L}_{k-means}` in Eq.13 of the original paper. G_steps : The number of steps to train the generator in each iteration. @@ -602,7 +573,7 @@ def __init__( hint_rate: float = 0.7, dropout_rate: float = 0.0, G_steps: int = 1, - D_steps: int = 1, + D_steps: int = 5, batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, @@ -704,7 +675,7 @@ def _train_model( step_train_loss_G_collector = [] step_train_loss_D_collector = [] - + # for _ in range(self.G_steps): if idx % self.G_steps == 0: self.G_optimizer.zero_grad() results = self.model.forward( @@ -716,6 +687,7 @@ def _train_model( results["generation_loss"].item() ) + # for _ in range(self.D_steps): if idx % self.D_steps == 0: self.D_optimizer.zero_grad() results = self.model.forward( diff --git a/pypots/utils/__init__.py b/pypots/utils/__init__.py index 1376524d..5fd28d97 100644 --- a/pypots/utils/__init__.py +++ b/pypots/utils/__init__.py @@ -8,7 +8,7 @@ __all__ = [ # content files in this package - "files", + "file.py", "logging", "metrics", "random", diff --git a/pypots/utils/metrics.py b/pypots/utils/metrics.py index 058cc24a..ee7d69bb 100644 --- a/pypots/utils/metrics.py +++ b/pypots/utils/metrics.py @@ -54,7 +54,7 @@ def cal_mae( so the result is 1/2=0.5. 
""" - assert type(predictions) == type(targets), ( + assert isinstance(predictions, type(targets)), ( f"types of inputs and target must match, but got" f"type(inputs)={type(predictions)}, type(target)={type(targets)}" ) @@ -110,7 +110,7 @@ def cal_mse( """ - assert type(predictions) == type(targets), ( + assert isinstance(predictions, type(targets)), ( f"types of inputs and target must match, but got" f"type(inputs)={type(predictions)}, type(target)={type(targets)}" ) @@ -166,7 +166,7 @@ def cal_rmse( so the result is :math:`\\sqrt{1/2}=0.5`. """ - assert type(predictions) == type(targets), ( + assert isinstance(predictions, type(targets)), ( f"types of inputs and target must match, but got" f"type(inputs)={type(predictions)}, type(target)={type(targets)}" ) @@ -217,7 +217,7 @@ def cal_mre( so the result is :math:`\\sqrt{1/2}=0.5`. """ - assert type(predictions) == type(targets), ( + assert isinstance(predictions, type(targets)), ( f"types of inputs and target must match, but got" f"type(inputs)={type(predictions)}, type(target)={type(targets)}" ) From c76fd6947d1e2d7de64187766c443fc9e2461f2d Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 6 Sep 2023 23:19:04 +0800 Subject: [PATCH 04/17] refactor: USGAN model; --- pypots/imputation/usgan/data.py | 132 +--------- pypots/imputation/usgan/model.py | 409 ++++------------------------- pypots/imputation/usgan/modules.py | 140 ---------- 3 files changed, 61 insertions(+), 620 deletions(-) delete mode 100644 pypots/imputation/usgan/modules.py diff --git a/pypots/imputation/usgan/data.py b/pypots/imputation/usgan/data.py index 2a92ee27..bd012c30 100644 --- a/pypots/imputation/usgan/data.py +++ b/pypots/imputation/usgan/data.py @@ -2,19 +2,16 @@ Dataset class for model USGAN. """ -# Created by Jun Wang +# Created by Jun Wang and Wenjie Du # License: GLP-v3 -from typing import Union, Iterable +from typing import Union -import torch +from ..brits.data import DatasetForBRITS -from ...data.base import BaseDataset -from ...data.utils import torch_parse_delta - -class DatasetForUSGAN(BaseDataset): - """Dataset class for USGAN. +class DatasetForUSGAN(DatasetForBRITS): + """Dataset class for USGAN, the same with the one for BRITS. Parameters ---------- @@ -47,122 +44,3 @@ def __init__( file_type: str = "h5py", ): super().__init__(data, return_labels, file_type) - - if not isinstance(self.data, str): - # calculate all delta here. - forward_missing_mask = (~torch.isnan(self.X)).type(torch.float32) - forward_X = torch.nan_to_num(self.X) - forward_delta = torch_parse_delta(forward_missing_mask) - backward_X = torch.flip(forward_X, dims=[1]) - backward_missing_mask = torch.flip(forward_missing_mask, dims=[1]) - backward_delta = torch_parse_delta(backward_missing_mask) - - self.processed_data = { - "forward": { - "X": forward_X, - "missing_mask": forward_missing_mask, - "delta": forward_delta, - }, - "backward": { - "X": backward_X, - "missing_mask": backward_missing_mask, - "delta": backward_delta, - }, - } - - def _fetch_data_from_array(self, idx: int) -> Iterable: - """Fetch data from self.X if it is given. - - Parameters - ---------- - idx : int, - The index of the sample to be return. - - Returns - ------- - sample : list, - A list contains - - index : int tensor, - The index of the sample. - - X : tensor, - The feature vector for model input. - - missing_mask : tensor, - The mask indicates all missing values in X. - - delta : tensor, - The delta matrix contains time gaps of missing values. 
- - label (optional) : tensor, - The target label of the time-series sample. - """ - sample = [ - torch.tensor(idx), - # for forward - self.processed_data["forward"]["X"][idx].to(torch.float32), - self.processed_data["forward"]["missing_mask"][idx].to(torch.float32), - self.processed_data["forward"]["delta"][idx].to(torch.float32), - # for backward - self.processed_data["backward"]["X"][idx].to(torch.float32), - self.processed_data["backward"]["missing_mask"][idx].to(torch.float32), - self.processed_data["backward"]["delta"][idx].to(torch.float32), - ] - - if self.y is not None and self.return_labels: - sample.append(self.y[idx].to(torch.long)) - - return sample - - def _fetch_data_from_file(self, idx: int) -> Iterable: - """Fetch data with the lazy-loading strategy, i.e. only loading data from the file while requesting for samples. - Here the opened file handle doesn't load the entire dataset into RAM but only load the currently accessed slice. - - Parameters - ---------- - idx : int, - The index of the sample to be return. - - Returns - ------- - sample : list, - The collated data sample, a list including all necessary sample info. - """ - - if self.file_handle is None: - self.file_handle = self._open_file_handle() - - X = torch.from_numpy(self.file_handle["X"][idx]) - missing_mask = (~torch.isnan(X)).to(torch.float32) - X = torch.nan_to_num(X) - - forward = { - "X": X, - "missing_mask": missing_mask, - "deltas": torch_parse_delta(missing_mask), - } - - backward = { - "X": torch.flip(forward["X"], dims=[0]), - "missing_mask": torch.flip(forward["missing_mask"], dims=[0]), - } - backward["deltas"] = torch_parse_delta(backward["missing_mask"]) - - sample = [ - torch.tensor(idx), - # for forward - forward["X"], - forward["missing_mask"], - forward["deltas"], - # for backward - backward["X"], - backward["missing_mask"], - backward["deltas"], - ] - - # if the dataset has labels and is for training, then fetch it from the file - if "y" in self.file_handle.keys() and self.return_labels: - sample.append(torch.tensor(self.file_handle["y"][idx], dtype=torch.long)) - - return sample diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index cc4788f2..365dcd40 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -4,17 +4,12 @@ Refer to the paper "Miao, X., Wu, Y., Wang, J., Gao, Y., Mao, X., & Yin, J. (2021). Generative Semi-supervised Learning for Multivariate Time Series Imputation. AAAI 2021." -Notes ------ -Partial implementation uses code from https://github.com/zjuwuyy-DL/Generative-Semi-supervised-Learning-for-Multivariate-Time-Series-Imputation. The bugs in the original implementation -are fixed here. 
- """ -# Created by Jun Wang +# Created by Jun Wang and Wenjie Du # License: GPL-v3 -from typing import Tuple, Union, Optional +from typing import Union, Optional import h5py import numpy as np @@ -24,201 +19,35 @@ from torch.utils.data import DataLoader from .data import DatasetForUSGAN -from .modules import TemporalDecay, FeatureRegression - -# from ..brits.model import RITS from ..base import BaseNNImputer +from ..brits.model import _BRITS from ...optim.adam import Adam from ...optim.base import Optimizer -from ...utils.metrics import cal_mae, cal_mse from ...utils.logging import logger -class RITS(nn.Module): - """model RITS: Recurrent Imputation for Time Series +class Discriminator(nn.Module): + """model Discriminator: built on BiRNN - Attributes + Parameters ---------- - n_steps : - sequence length (number of time steps) - n_features : - number of features (input dimensions) + the feature dimension of the input rnn_hidden_size : the hidden size of the RNN cell - device : - specify running the model on which device, CPU/GPU - - rnn_cell : - the LSTM cell to model temporal data - - temp_decay_h : - the temporal decay module to decay RNN hidden state - - temp_decay_x : - the temporal decay module to decay data in the raw feature space - - hist_reg : - the temporal-regression module to project RNN hidden state into the raw feature space + hint_rate : + the hint rate for the input imputed_data - feat_reg : - the feature-regression module - - combining_weight : - the module used to generate the weight to combine history regression and feature regression - - Parameters - ---------- - n_steps : - sequence length (number of time steps) - - n_features : - number of features (input dimensions) - - rnn_hidden_size : - the hidden size of the RNN cell + dropout_rate : + the dropout rate for the output layer device : specify running the model on which device, CPU/GPU """ - def __init__( - self, - n_steps: int, - n_features: int, - rnn_hidden_size: int, - device: Union[str, torch.device], - ): - super().__init__() - self.n_steps = n_steps - self.n_features = n_features - self.rnn_hidden_size = rnn_hidden_size - self.device = device - - self.rnn_cell = nn.LSTMCell(self.n_features * 2, self.rnn_hidden_size) - self.temp_decay_h = TemporalDecay( - input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False - ) - self.temp_decay_x = TemporalDecay( - input_size=self.n_features, output_size=self.n_features, diag=True - ) - self.hist_reg = nn.Linear(self.rnn_hidden_size, self.n_features) - self.feat_reg = FeatureRegression(self.n_features) - self.combining_weight = nn.Linear(self.n_features * 2, self.n_features) - - def impute( - self, inputs: dict, direction: str - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """The imputation function. - Parameters - ---------- - inputs : - Input data, a dictionary includes feature values, missing masks, and time-gap values. - - direction : - A keyword to extract data from parameter `data`. 
- - Returns - ------- - imputed_data : - [batch size, sequence length, feature number] - - hidden_states: tensor, - [batch size, RNN hidden size] - - reconstruction_loss : - reconstruction loss - - """ - values = inputs[direction]["X"] # feature values - masks = inputs[direction]["missing_mask"] # missing masks - deltas = inputs[direction]["deltas"] # time-gap values - - # create hidden states and cell states for the lstm cell - hidden_states = torch.zeros( - (values.size()[0], self.rnn_hidden_size), device=values.device - ) - cell_states = torch.zeros( - (values.size()[0], self.rnn_hidden_size), device=values.device - ) - - estimations = [] - reconstruction_loss = torch.tensor(0.0).to(values.device) - - # imputation period - for t in range(self.n_steps): - # data shape: [batch, time, features] - x = values[:, t, :] # values - m = masks[:, t, :] # mask - d = deltas[:, t, :] # delta, time gap - - gamma_h = self.temp_decay_h(d) - gamma_x = self.temp_decay_x(d) - - hidden_states = hidden_states * gamma_h # decay hidden states - x_h = self.hist_reg(hidden_states) - reconstruction_loss += cal_mae(x_h, x, m) - - x_c = m * x + (1 - m) * x_h - - z_h = self.feat_reg(x_c) - reconstruction_loss += cal_mae(z_h, x, m) - - alpha = torch.sigmoid(self.combining_weight(torch.cat([gamma_x, m], dim=1))) - - c_h = alpha * z_h + (1 - alpha) * x_h - reconstruction_loss += cal_mae(c_h, x, m) - - c_c = m * x + (1 - m) * c_h - estimations.append(c_h.unsqueeze(dim=1)) - - inputs = torch.cat([c_c, m], dim=1) - hidden_states, cell_states = self.rnn_cell( - inputs, (hidden_states, cell_states) - ) - - estimations = torch.cat(estimations, dim=1) - imputed_data = masks * values + (1 - masks) * estimations - return imputed_data, hidden_states, reconstruction_loss, estimations - - def forward(self, inputs: dict, direction: str = "forward") -> dict: - """Forward processing of the NN module. - Parameters - ---------- - inputs : - The input data. - - direction : - A keyword to extract data from parameter `data`. - - Returns - ------- - dict, - A dictionary includes all results. 
- - """ - imputed_data, hidden_state, reconstruction_loss, estimations = self.impute( - inputs, direction - ) - # for each iteration, reconstruction_loss increases its value for 3 times - reconstruction_loss /= self.n_steps * 3 - - ret_dict = { - "consistency_loss": torch.tensor( - 0.0, device=imputed_data.device - ), # single direction, has no consistency loss - "reconstruction_loss": reconstruction_loss, - "imputed_data": imputed_data, - "final_hidden_state": hidden_state, - "estimations": estimations, - } - return ret_dict - - -class Discriminator(nn.Module): def __init__( self, n_features: int, @@ -230,164 +59,45 @@ def __init__( super().__init__() self.hint_rate = hint_rate self.device = device - self.birnn = nn.GRU( - 2 * n_features, rnn_hidden_size, bidirectional=True, batch_first=True + self.biRNN = nn.GRU( + n_features * 2, rnn_hidden_size, bidirectional=True, batch_first=True ).to(device) self.dropout = nn.Dropout(dropout_rate).to(device) - self.read_out = nn.Linear(2 * rnn_hidden_size, n_features).to(device) - - def forward(self, inputs: dict, training: bool = True) -> dict: - x = inputs["imputed_data"] - m = inputs["forward"]["missing_mask"] - - hint = ( - torch.rand_like(m, dtype=torch.float, device=self.device) < self.hint_rate - ) - hint = hint.byte() - h = hint * m + (1 - hint) * 0.5 - x_in = torch.cat([x, h], dim=-1) - - out, _ = self.birnn(x_in) - logits = self.read_out(self.dropout(out)) - return logits - - -class Generator(nn.Module): - """model Generator: - - Attributes - ---------- - n_steps : - sequence length (number of time steps) - - n_features : - number of features (input dimensions) - - rnn_hidden_size : - the hidden size of the RNN cell + self.read_out = nn.Linear(rnn_hidden_size * 2, n_features).to(device) - rits_f: RITS object - the forward RITS model - - rits_b: RITS object - the backward RITS model - - """ - - def __init__( + def forward( self, - n_steps: int, - n_features: int, - rnn_hidden_size: int, - device: Union[str, torch.device], - ): - super().__init__() - # data settings - self.n_steps = n_steps - self.n_features = n_features - # imputer settings - self.rnn_hidden_size = rnn_hidden_size - # create models - self.rits_f = RITS(n_steps, n_features, rnn_hidden_size, device) - self.rits_b = RITS(n_steps, n_features, rnn_hidden_size, device) - - @staticmethod - def _get_consistency_loss( - pred_f: torch.Tensor, pred_b: torch.Tensor + imputed_X: torch.Tensor, + missing_mask: torch.Tensor, ) -> torch.Tensor: - """Calculate the consistency loss between the imputation from two RITS models. + """Forward processing of USGAN Discriminator. Parameters ---------- - pred_f : - The imputation from the forward RITS. - - pred_b : - The imputation from the backward RITS (already gets reverted). + imputed_X : torch.Tensor, + The original X with missing parts already imputed. - Returns - ------- - float tensor, - The consistency loss. - - """ - loss = torch.abs(pred_f - pred_b).mean() * 1e-1 - return loss - - @staticmethod - def _reverse(ret: dict) -> dict: - """Reverse the array values on the time dimension in the given dictionary. - - Parameters - ---------- - ret : + missing_mask : torch.Tensor, + The missing mask of X. Returns ------- - dict, - A dictionary contains values reversed on the time dimension from the given dict. - - """ + logits : torch.Tensor, + the logits of the probability of being the true value. 
- def reverse_tensor(tensor_): - if tensor_.dim() <= 1: - return tensor_ - indices = range(tensor_.size()[1])[::-1] - indices = torch.tensor( - indices, dtype=torch.long, device=tensor_.device, requires_grad=False - ) - return tensor_.index_select(1, indices) - - for key in ret: - ret[key] = reverse_tensor(ret[key]) - - return ret - - def forward(self, inputs: dict, training: bool = True) -> dict: - """Forward processing of BRITS. - - Parameters - ---------- - inputs : - The input data. - - Returns - ------- - dict, A dictionary includes all results. """ - # Results from the forward RITS. - ret_f = self.rits_f(inputs, "forward") - # Results from the backward RITS. - ret_b = self._reverse(self.rits_b(inputs, "backward")) - - imputed_data = (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2 - estimation = (ret_f["estimations"] + ret_b["estimations"]) / 2 - - if not training: - # if not in training mode, return the classification result only - # return { - # "imputed_data": imputed_data, - # } - return imputed_data, estimation - - consistency_loss = self._get_consistency_loss( - ret_f["imputed_data"], ret_b["imputed_data"] - ) - # `loss` is always the item for backward propagating to update the model - loss = ( - consistency_loss - + ret_f["reconstruction_loss"] - + ret_b["reconstruction_loss"] + hint = ( + torch.rand_like(missing_mask, dtype=torch.float, device=self.device) + < self.hint_rate ) + hint = hint.int() + h = hint * missing_mask + (1 - hint) * 0.5 + x_in = torch.cat([imputed_X, h], dim=-1) - results = { - "imputed_data": imputed_data, - "consistency_loss": consistency_loss, - "loss": loss, # will be used for backward propagating to update the model - } - - return results + out, _ = self.biRNN(x_in) + logits = self.read_out(self.dropout(out)) + return logits class _USGAN(nn.Module): @@ -405,11 +115,17 @@ class _USGAN(nn.Module): rnn_hidden_size : the hidden size of the RNN cell - rits_f: RITS object - the forward RITS model + lambda_mse : + the weigth of the reconstruction loss - rits_b: RITS object - the backward RITS model + hint_rate : + the hint rate for the discriminator + + dropout_rate : + the dropout rate for the last layer in Discriminator + + device : + specify running the model on which device, CPU/GPU """ @@ -424,7 +140,7 @@ def __init__( device: Union[str, torch.device] = "cpu", ): super().__init__() - self.generator = Generator(n_steps, n_features, rnn_hidden_size, device) + self.generator = _BRITS(n_steps, n_features, rnn_hidden_size, device) self.discriminator = Discriminator( n_features, rnn_hidden_size, @@ -447,14 +163,10 @@ def forward( "discriminator", ], 'training_object should be "generator" or "discriminator"' - X = inputs["forward"]["X"] missing_mask = inputs["forward"]["missing_mask"] - batch_size, n_steps, n_features = X.shape losses = {} - inputs["imputed_data"], inputs["reconstruction"] = self.generator( - inputs, training=False - ) - inputs["discrimination"] = self.discriminator(inputs, training=False) + results = self.generator(inputs, training=training) + inputs["discrimination"] = self.discriminator(inputs) if not training: # if only run clustering, then no need to calculate loss return inputs @@ -469,8 +181,7 @@ def forward( l_G = F.binary_cross_entropy_with_logits( inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask ) - l_rec = cal_mse(inputs["reconstruction"], X, missing_mask) - loss_gene = l_G + self.lambda_mse * l_rec + loss_gene = l_G + self.lambda_mse * results["loss"] losses["generation_loss"] = loss_gene 
losses["imputed_data"] = inputs["imputed_data"] return losses @@ -487,24 +198,17 @@ class USGAN(BaseNNImputer): n_features : The number of features in the time-series data sample. - n_clusters : - The number of clusters in the clustering task. - - n_generator_layers : - The number of layers in the generator. - rnn_hidden_size : - The size of the RNN hidden state, also the number of hidden units in the RNN cell. + the hidden size of the RNN cell - rnn_cell_type : - The type of RNN cell to use. Can be either "GRU" or "LSTM". + lambda_mse : + the weigth of the reconstruction loss - decoder_fcn_output_dims : - The output dimensions of each layer in the FCN (fully-connected network) of the decoder. + hint_rate : + the hint rate for the discriminator - lambda_kmeans : - The weight of the k-means loss, - i.e. the item :math:`\\lambda` ahead of :math:`\\mathcal{L}_{k-means}` in Eq.13 of the original paper. + dropout_rate : + the dropout rate for the last layer in Discriminator G_steps : The number of steps to train the generator in each iteration. @@ -573,7 +277,7 @@ def __init__( hint_rate: float = 0.7, dropout_rate: float = 0.0, G_steps: int = 1, - D_steps: int = 5, + D_steps: int = 1, batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, @@ -675,7 +379,7 @@ def _train_model( step_train_loss_G_collector = [] step_train_loss_D_collector = [] - # for _ in range(self.G_steps): + if idx % self.G_steps == 0: self.G_optimizer.zero_grad() results = self.model.forward( @@ -687,7 +391,6 @@ def _train_model( results["generation_loss"].item() ) - # for _ in range(self.D_steps): if idx % self.D_steps == 0: self.D_optimizer.zero_grad() results = self.model.forward( diff --git a/pypots/imputation/usgan/modules.py b/pypots/imputation/usgan/modules.py deleted file mode 100644 index c242dee4..00000000 --- a/pypots/imputation/usgan/modules.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -The implementation of SSGAN for the partially-observed time-series imputation task. - -Refer to the paper "Miao, X., Wu, Y., Wang, J., Gao, Y., Mao, X., & Yin, J. (2021). -Generative Semi-supervised Learning for Multivariate Time Series Imputation. AAAI 2021." - -Notes ------ -Partial implementation uses code from https://github.com/zjuwuyy-DL/Generative-Semi-supervised-Learning-for-Multivariate-Time-Series-Imputation. The bugs in the original implementation -are fixed here. - -""" - -# Created by Jun Wang -# License: GPL-v3 - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import Variable -from torch.nn.parameter import Parameter - - -class FeatureRegression(nn.Module): - """The module used to capture the correlation between features for imputation. - - Attributes - ---------- - W : tensor - The weights (parameters) of the module. - - b : tensor - The bias of the module. - - m (buffer) : tensor - The mask matrix, a squire matrix with diagonal entries all zeroes while left parts all ones. - It is applied to the weight matrix to mask out the estimation contributions from features themselves. - It is used to help enhance the imputation performance of the network. 
- - Parameters - ---------- - input_size : the feature dimension of the input - """ - - def __init__(self, input_size: int): - super().__init__() - self.W = Parameter(torch.Tensor(input_size, input_size)) - self.b = Parameter(torch.Tensor(input_size)) - - m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size) - self.register_buffer("m", m) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - std_dev = 1.0 / math.sqrt(self.W.size(0)) - self.W.data.uniform_(-std_dev, std_dev) - if self.b is not None: - self.b.data.uniform_(-std_dev, std_dev) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward processing of the NN module. - - Parameters - ---------- - x : tensor, - the input for processing - - Returns - ------- - output: tensor, - the processed result containing imputation from feature regression - - """ - output = F.linear(x, self.W * Variable(self.m), self.b) - return output - - -class TemporalDecay(nn.Module): - """The module used to generate the temporal decay factor gamma in the original paper. - - Attributes - ---------- - W: tensor, - The weights (parameters) of the module. - b: tensor, - The bias of the module. - - Parameters - ---------- - input_size : int, - the feature dimension of the input - - output_size : int, - the feature dimension of the output - - diag : bool, - whether to product the weight with an identity matrix before forward processing - """ - - def __init__(self, input_size: int, output_size: int, diag: bool = False): - super().__init__() - self.diag = diag - self.W = Parameter(torch.Tensor(output_size, input_size)) - self.b = Parameter(torch.Tensor(output_size)) - - if self.diag: - assert input_size == output_size - m = torch.eye(input_size, input_size) - self.register_buffer("m", m) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - std_dev = 1.0 / math.sqrt(self.W.size(0)) - self.W.data.uniform_(-std_dev, std_dev) - if self.b is not None: - self.b.data.uniform_(-std_dev, std_dev) - - def forward(self, delta: torch.Tensor) -> torch.Tensor: - """Forward processing of the NN module. - - Parameters - ---------- - delta : tensor, shape [batch size, sequence length, feature number] - The time gaps. - - Returns - ------- - gamma : array-like, same shape with parameter `delta`, values in (0,1] - The temporal decay factor. - """ - if self.diag: - gamma = F.relu(F.linear(delta, self.W * Variable(self.m), self.b)) - else: - gamma = F.relu(F.linear(delta, self.W, self.b)) - gamma = torch.exp(-gamma) - return gamma From bf0ac7b26ca1155149cfee29043cb6c84892ec4e Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 18 Sep 2023 21:33:34 +0800 Subject: [PATCH 05/17] docs: update the link of contributor svg; --- docs/about_us.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/about_us.rst b/docs/about_us.rst index aaaab944..370a3e0d 100644 --- a/docs/about_us.rst +++ b/docs/about_us.rst @@ -33,5 +33,5 @@ PyPOTS exists thanks to all the nice people (sorted by contribution time) who co .. 
raw:: html - + From 6d163b903355eac27b038aac88b7db507e4317a1 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 13 Sep 2023 18:50:29 +0800 Subject: [PATCH 06/17] fix: make US-GAN runnable now; --- pypots/imputation/__init__.py | 10 +++++++++- pypots/imputation/mrnn/module.py | 2 +- pypots/imputation/usgan/model.py | 20 ++++++++++++-------- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 3d513430..c865ae94 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -13,4 +13,12 @@ from .gpvae import GPVAE from .usgan import USGAN -__all__ = ["SAITS", "Transformer", "BRITS", "MRNN", "LOCF", "GPVAE" "USGAN"] +__all__ = [ + "SAITS", + "Transformer", + "BRITS", + "MRNN", + "LOCF", + "GPVAE", + "USGAN", +] diff --git a/pypots/imputation/mrnn/module.py b/pypots/imputation/mrnn/module.py index 873d2d73..a143d121 100644 --- a/pypots/imputation/mrnn/module.py +++ b/pypots/imputation/mrnn/module.py @@ -18,7 +18,7 @@ class FCN_Regression(nn.Module): def __init__(self, feature_num, rnn_hid_size): - super(FCN_Regression, self).__init__() + super().__init__() self.feat_reg = FeatureRegression(rnn_hid_size * 2) self.U = Parameter(torch.Tensor(feature_num, feature_num)) self.V1 = Parameter(torch.Tensor(feature_num, feature_num)) diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index 365dcd40..c171d810 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -163,27 +163,31 @@ def forward( "discriminator", ], 'training_object should be "generator" or "discriminator"' - missing_mask = inputs["forward"]["missing_mask"] + forward_X = inputs["forward"]["X"] + forward_missing_mask = inputs["forward"]["missing_mask"] losses = {} results = self.generator(inputs, training=training) - inputs["discrimination"] = self.discriminator(inputs) + inputs["discrimination"] = self.discriminator(forward_X, forward_missing_mask) if not training: - # if only run clustering, then no need to calculate loss - return inputs + # if only run imputation operation, then no need to calculate loss + return results if training_object == "discriminator": l_D = F.binary_cross_entropy_with_logits( - inputs["discrimination"], missing_mask + inputs["discrimination"], forward_missing_mask ) losses["discrimination_loss"] = l_D else: inputs["discrimination"] = inputs["discrimination"].detach() l_G = F.binary_cross_entropy_with_logits( - inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask + inputs["discrimination"], + 1 - forward_missing_mask, + weight=1 - forward_missing_mask, ) loss_gene = l_G + self.lambda_mse * results["loss"] losses["generation_loss"] = loss_gene - losses["imputed_data"] = inputs["imputed_data"] + + losses["imputed_data"] = results["imputed_data"] return losses @@ -202,7 +206,7 @@ class USGAN(BaseNNImputer): the hidden size of the RNN cell lambda_mse : - the weigth of the reconstruction loss + the weight of the reconstruction loss hint_rate : the hint rate for the discriminator From 286506984c173c5eac4a954b82e91a8597eb41ae Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 13 Sep 2023 19:11:32 +0800 Subject: [PATCH 07/17] feat: add the unit test for US-GAN; --- tests/test_imputation.py | 81 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 2 deletions(-) diff --git a/tests/test_imputation.py b/tests/test_imputation.py index 6094ce62..0db0900c 100644 --- a/tests/test_imputation.py +++ b/tests/test_imputation.py @@ -15,6 
+15,7 @@ from pypots.imputation import ( SAITS, Transformer, + USGAN, BRITS, MRNN, LOCF, @@ -194,6 +195,82 @@ def test_3_saving_path(self): self.transformer.load_model(saved_model_path) +class TestUSGAN(unittest.TestCase): + logger.info("Running tests for an imputation model US-GAN...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "US-GAN") + model_save_name = "saved_USGAN_model.pypots" + + # initialize an Adam optimizer + G_optimizer = Adam(lr=0.001, weight_decay=1e-5) + D_optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a BRITS model + us_gan = USGAN( + DATA["n_steps"], + DATA["n_features"], + 256, + epochs=EPOCH, + saving_path=saving_path, + G_optimizer=G_optimizer, + D_optimizer=D_optimizer, + ) + + @pytest.mark.xdist_group(name="imputation-usgan") + def test_0_fit(self): + self.us_gan.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-usgan") + def test_1_impute(self): + imputed_X = self.us_gan.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"US-GAN test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-usgan") + def test_2_parameters(self): + assert hasattr(self.us_gan, "model") and self.us_gan.model is not None + + assert ( + hasattr(self.us_gan, "G_optimizer") and self.us_gan.G_optimizer is not None + ) + assert ( + hasattr(self.us_gan, "D_optimizer") and self.us_gan.D_optimizer is not None + ) + + assert hasattr(self.us_gan, "best_loss") + self.assertNotEqual(self.us_gan.best_loss, float("inf")) + + assert ( + hasattr(self.us_gan, "best_model_dict") + and self.us_gan.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-usgan") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.us_gan) + + # save the trained model into file, and check if the path exists + self.us_gan.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.us_gan.load_model(saved_model_path) + + class TestBRITS(unittest.TestCase): logger.info("Running tests for an imputation model BRITS...") @@ -210,7 +287,7 @@ class TestBRITS(unittest.TestCase): DATA["n_features"], 256, epochs=EPOCH, - saving_path=f"{RESULT_SAVING_DIR_FOR_IMPUTATION}/BRITS", + saving_path=saving_path, optimizer=optimizer, ) @@ -279,7 +356,7 @@ class TestMRNN(unittest.TestCase): DATA["n_features"], 256, epochs=EPOCH, - saving_path=f"{RESULT_SAVING_DIR_FOR_IMPUTATION}/MRNN", + saving_path=saving_path, optimizer=optimizer, ) From 03bb9346a32599016f999c687935a8194bd00fbc Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sun, 17 Sep 2023 23:59:15 +0800 Subject: [PATCH 08/17] fix: make GP-VAE runnable now; --- pypots/imputation/__init__.py | 4 +- pypots/imputation/gpvae/data.py | 19 +---- pypots/imputation/gpvae/model.py | 115 +++++++++++++++++++++-------- pypots/imputation/gpvae/modules.py | 112 +++++++++++++++++----------- 4 files changed, 161 insertions(+), 89 deletions(-) diff --git 
a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index c865ae94..a6c4dcd8 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -6,11 +6,11 @@ # License: GPL-v3 from .brits import BRITS +from .gpvae import GPVAE from .locf import LOCF +from .mrnn import MRNN from .saits import SAITS from .transformer import Transformer -from .mrnn import MRNN -from .gpvae import GPVAE from .usgan import USGAN __all__ = [ diff --git a/pypots/imputation/gpvae/data.py b/pypots/imputation/gpvae/data.py index de7d7747..4f8b27c4 100644 --- a/pypots/imputation/gpvae/data.py +++ b/pypots/imputation/gpvae/data.py @@ -2,7 +2,7 @@ Dataset class for model GP-VAE. """ -# Created by Jun Wang +# Created by Jun Wang and Wenjie Du # License: GLP-v3 from typing import Union, Iterable @@ -14,7 +14,7 @@ class DatasetForGPVAE(BaseDataset): - """Dataset class for BRITS. + """Dataset class for GP-VAE. Parameters ---------- @@ -52,12 +52,10 @@ def __init__( # calculate all delta here. missing_mask = (~torch.isnan(self.X)).type(torch.float32) X = torch.nan_to_num(self.X) - delta = torch_parse_delta(missing_mask) self.processed_data = { "X": X, "missing_mask": missing_mask, - "delta": delta, } def _fetch_data_from_array(self, idx: int) -> Iterable: @@ -93,7 +91,6 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: # for forward self.processed_data["X"][idx].to(torch.float32), self.processed_data["missing_mask"][idx].to(torch.float32), - self.processed_data["delta"][idx].to(torch.float32), ] if self.y is not None and self.return_labels: @@ -123,18 +120,10 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: missing_mask = (~torch.isnan(X)).to(torch.float32) X = torch.nan_to_num(X) - forward = { - "X": X, - "missing_mask": missing_mask, - "deltas": torch_parse_delta(missing_mask), - } - sample = [ torch.tensor(idx), - # for forward - forward["X"], - forward["missing_mask"], - forward["deltas"], + X, + missing_mask, ] # if the dataset has labels and is for training, then fetch it from the file diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index 3bf5a866..aaa59b5d 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -1,19 +1,16 @@ """ The implementation of GP-VAE for the partially-observed time-series imputation task. -Refer to the paper Fortuin V, Baranchuk D, Rätsch G, et al. Gp-vae: Deep probabilistic time series imputation[C]//International conference on artificial intelligence and statistics. PMLR, 2020: 1651-1661. - -Notes ------ -Pytorch implementation of the code from https://github.com/ratschlab/GP-VAE. +Refer to the paper Fortuin V, Baranchuk D, Rätsch G, et al. +GP-VAE: Deep probabilistic time series imputation. AISTATS. PMLR, 2020: 1651-1661. 
""" -# Created by Jun Wang +# Created by Jun Wang and Wenjie Du # License: GPL-v3 -from typing import Tuple, Union, Optional +from typing import Union, Optional import h5py import numpy as np @@ -21,15 +18,65 @@ import torch.nn as nn from torch.utils.data import DataLoader -from .modules import * from .data import DatasetForGPVAE +from .modules import ( + Encoder, + rbf_kernel, + diffusion_kernel, + matern_kernel, + cauchy_kernel, + Decoder, +) from ..base import BaseNNImputer from ...optim.adam import Adam from ...optim.base import Optimizer -from ...utils.metrics import cal_mae class _GPVAE(nn.Module): + """model GPVAE with Gaussian Process prior + + Parameters + ---------- + input_dim : int, + the feature dimension of the input + + time_length : int, + the length of each time series + + latent_dim : int, + the feature dimension of the latent embedding + + device : str, + specify running the model on which device, CPU/GPU + + encoder_sizes : tuple, + the tuple of the network size in encoder + + decoder_sizes : tuple, + the tuple of the network size in decoder + + beta : float, + the weight of the KL divergence + + M : int, + the number of Monte Carlo samples for ELBO estimation + + K : int, + the number of importance weights for IWAE model + + kernel : str, + the Gaussial Process kernel ["cauchy", "diffusion", "rbf", "matern"] + + sigma : float, + the scale parameter for a kernel function + + length_scale : float, + the length scale parameter for a kernel function + + kernel_scales : int, + the number of different length scales over latent space dimensions + """ + def __init__( self, input_dim, @@ -37,9 +84,7 @@ def __init__( latent_dim, device, encoder_sizes=(64, 64), - encoder=Encoder, decoder_sizes=(64, 64), - decoder=Decoder, beta=1, M=1, K=1, @@ -48,13 +93,7 @@ def __init__( length_scale=7.0, kernel_scales=1, ): - """GPVAE model with Gaussian Process prior - :param kernel: Gaussial Process kernel ["cauchy", "diffusion", "rbf", "matern"] - :param sigma: scale parameter for a kernel function - :param length_scale: length scale parameter for a kernel function - :param kernel_scales: number of different length scales over latent space dimensions - """ - super(_GPVAE, self).__init__() + super().__init__() self.kernel = kernel self.sigma = sigma self.length_scale = length_scale @@ -69,8 +108,8 @@ def __init__( self.time_length = time_length self.latent_dim = latent_dim self.beta = beta - self.encoder = encoder(input_dim, latent_dim, encoder_sizes).to(device) - self.decoder = decoder(latent_dim, input_dim, decoder_sizes).to(device) + self.encoder = Encoder(input_dim, latent_dim, encoder_sizes).to(device) + self.decoder = Decoder(latent_dim, input_dim, decoder_sizes).to(device) self.device = device self.M = M self.K = K @@ -89,9 +128,8 @@ def __call__(self, inputs): return self.decoder(self.encode(inputs).sample()).sample() def forward(self, inputs, training=True): - x = inputs["forward"]["X"] - m_mask = inputs["forward"]["missing_mask"] - delta = inputs["forward"]["deltas"] + x = inputs["X"] + m_mask = inputs["missing_mask"] x = x.repeat(self.M * self.K, 1, 1) if m_mask is not None: m_mask = m_mask.repeat(self.M * self.K, 1, 1) @@ -256,8 +294,15 @@ def __init__( n_steps: int, n_features: int, latent_size: int, + encoder_sizes: tuple = (64, 64), + decoder_sizes: tuple = (64, 64), kernel: str = "cauchy", beta: float = 0.2, + M: int = 1, + K: int = 1, + sigma: float = 1.0, + length_scale: float = 7.0, + kernel_scales: int = 1, batch_size: int = 32, epochs: int = 100, patience: int = None, @@ 
-281,7 +326,14 @@ def __init__( self.n_features = n_features self.latent_size = latent_size self.kernel = kernel + self.encoder_sizes = encoder_sizes + self.decoder_sizes = decoder_sizes self.beta = beta + self.M = M + self.K = K + self.sigma = sigma + self.length_scale = length_scale + self.kernel_scales = kernel_scales # set up the model self.model = _GPVAE( @@ -289,8 +341,15 @@ def __init__( time_length=self.n_steps, latent_dim=self.latent_size, kernel=self.kernel, - device=self.device, + encoder_sizes=self.encoder_sizes, + decoder_sizes=self.decoder_sizes, beta=self.beta, + M=self.M, + K=self.K, + sigma=self.sigma, + length_scale=self.length_scale, + kernel_scales=self.kernel_scales, + device=self.device, ) self._send_model_to_given_device() self._print_model_size() @@ -305,17 +364,13 @@ def _assemble_input_for_training(self, data: list) -> dict: indices, X, missing_mask, - deltas, ) = self._send_data_to_given_device(data) # assemble input data inputs = { "indices": indices, - "forward": { - "X": X, - "missing_mask": missing_mask, - "deltas": deltas, - }, + "X": X, + "missing_mask": missing_mask, } return inputs diff --git a/pypots/imputation/gpvae/modules.py b/pypots/imputation/gpvae/modules.py index 41e74c4b..dea83f68 100644 --- a/pypots/imputation/gpvae/modules.py +++ b/pypots/imputation/gpvae/modules.py @@ -1,26 +1,19 @@ """ The implementation of GP-VAE for the partially-observed time-series imputation task. -Refer to the paper Fortuin V, Baranchuk D, Rätsch G, et al. Gp-vae: Deep probabilistic time series imputation[C]//International conference on artificial intelligence and statistics. PMLR, 2020: 1651-1661. +Refer to the paper Fortuin V, Baranchuk D, Rätsch G, et al. +GP-VAE: Deep probabilistic time series imputation. AISTATS. PMLR, 2020: 1651-1661. -Notes ------ -Pytorch implementation of the code from https://github.com/ratschlab/GP-VAE. """ -# Created by Jun Wang +# Created by Jun Wang and Wenjie Du # License: GPL-v3 -from typing import Tuple, Union, Optional - -import h5py import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import DataLoader -from torch.distributions.multivariate_normal import MultivariateNormal def rbf_kernel(T, length_scale): @@ -72,10 +65,23 @@ def cauchy_kernel(T, sigma, length_scale): def make_nn(input_size, output_size, hidden_sizes): - """Creates fully connected neural network - :param output_size: output dimensionality - :param hidden_sizes: tuple of hidden layer sizes. - The tuple length sets the number of hidden layers. + """This function used to creates fully connected neural network. 
+
+    Parameters
+    ----------
+    input_size : int,
+        the dimension of input embeddings
+
+    output_size : int,
+        the dimension of output embeddings
+
+    hidden_sizes : tuple,
+        the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers
+
+    Returns
+    -------
+    output: tensor
+        the processed embeddings
     """
     layers = []
     for i in range(len(hidden_sizes)):
@@ -94,11 +100,7 @@ def make_nn(input_size, output_size, hidden_sizes):
 
 class CustomConv1d(torch.nn.Conv1d):
     def __init(self, in_channels, out_channels, kernal_size, padding):
-        super(CustomConv1d, self).__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.kernel_size = kernal_size
-        self.padding = padding
+        super().__init__(in_channels, out_channels, kernal_size, padding)
 
     def forward(self, x):
         if len(x.shape) > 2:
@@ -115,13 +117,28 @@ def forward(self, x):
 
 def make_cnn(input_size, output_size, hidden_sizes, kernel_size=3):
-    """Construct neural network consisting of
-      one 1d-convolutional layer that utilizes temporal dependences,
-      fully connected network
-    :param output_size: output dimensionality
-    :param hidden_sizes: tuple of hidden layer sizes.
-        The tuple length sets the number of hidden layers.
-    :param kernel_size: kernel size for convolutional layer
+    """This function used to construct neural network consisting of
+    one 1d-convolutional layer that utilizes temporal dependences,
+    fully connected network
+
+    Parameters
+    ----------
+    input_size : int,
+        the dimension of input embeddings
+
+    output_size : int,
+        the dimension of output embeddings
+
+    hidden_sizes : tuple,
+        the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers,
+
+    kernel_size : int
+        kernel size for convolutional layer
+
+    Returns
+    -------
+    output: tensor
+        the processed embeddings
     """
 
     padding = kernel_size // 2
@@ -142,17 +159,24 @@ def make_cnn(input_size, output_size, hidden_sizes, kernel_size=3):
 
 class Encoder(nn.Module):
     def __init__(self, input_size, z_size, hidden_sizes=(128, 128), window_size=24):
-        """Encoder with 1d-convolutional network and multivariate Normal posterior
-        Used by GP-VAE with proposed banded covariance matrix
-        :param z_size: latent space dimensionality
-        :param hidden_sizes: tuple of hidden layer sizes.
-            The tuple length sets the number of hidden layers.
- :param window_size: kernel size for Conv1D layer - :param data_type: needed for some data specific modifications, e.g: - tf.nn.softplus is a more common and correct choice, however - tf.nn.sigmoid provides more stable performance on Physionet dataset + """This module is an encoder with 1d-convolutional network and multivariate Normal posterior used by GP-VAE with + proposed banded covariance matrix + + Parameters + ---------- + input_size : int, + the feature dimension of the input + + z_size : int, + the feature dimension of the output latent embedding + + hidden_sizes : tuple, + the tuple of the hidden layer sizes, and the tuple length sets the number of hidden layers + + window_size : int + the kernel size for the Conv1D layer """ - super(Encoder, self).__init__() + super().__init__() self.z_size = int(z_size) self.input_size = input_size self.net, self.mu_layer, self.logvar_layer = make_cnn( @@ -164,7 +188,6 @@ def __call__(self, x): batch_size = mapped.size(0) time_length = mapped.size(1) - # Obtain mean and precision matrix components num_dim = len(mapped.shape) mu = self.mu_layer(mapped) logvar = self.logvar_layer(mapped) @@ -219,12 +242,17 @@ def __call__(self, x): class Decoder(nn.Module): def __init__(self, input_size, output_size, hidden_sizes=(256, 256)): - """Decoder with Gaussian output distribution - :param output_size: output dimensionality - :param hidden_sizes: tuple of hidden layer sizes. - The tuple length sets the number of hidden layers. + """This module is a decoder with Gaussian output distribution + + Parameters + ---------- + output_size : int, + the feature dimension of the output + + hidden_sizes: tuple + the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers. """ - super(Decoder, self).__init__() + super().__init__() self.output_size = int(output_size) self.net = make_nn(input_size, output_size, hidden_sizes) From 80e6e22b3ea9864c9d9703b44bc0a143f69aad11 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Tue, 19 Sep 2023 16:19:44 +0800 Subject: [PATCH 09/17] feat: add the unit test for GP-VAE; --- tests/test_imputation.py | 72 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/tests/test_imputation.py b/tests/test_imputation.py index 0db0900c..64a0b1ff 100644 --- a/tests/test_imputation.py +++ b/tests/test_imputation.py @@ -16,6 +16,7 @@ SAITS, Transformer, USGAN, + GPVAE, BRITS, MRNN, LOCF, @@ -206,7 +207,7 @@ class TestUSGAN(unittest.TestCase): G_optimizer = Adam(lr=0.001, weight_decay=1e-5) D_optimizer = Adam(lr=0.001, weight_decay=1e-5) - # initialize a BRITS model + # initialize a US-GAN model us_gan = USGAN( DATA["n_steps"], DATA["n_features"], @@ -271,6 +272,75 @@ def test_3_saving_path(self): self.us_gan.load_model(saved_model_path) +class TestGPVAE(unittest.TestCase): + logger.info("Running tests for an imputation model GP-VAE...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "GP-VAE") + model_save_name = "saved_GPVAE_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a GP-VAE model + gp_vae = GPVAE( + DATA["n_steps"], + DATA["n_features"], + 256, + epochs=EPOCH, + saving_path=saving_path, + optimizer=optimizer, + ) + + @pytest.mark.xdist_group(name="imputation-gpvae") + def test_0_fit(self): + self.gp_vae.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-gpvae") + def test_1_impute(self): + imputed_X = self.gp_vae.impute(TEST_SET) + 
assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"GP-VAE test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-gpvae") + def test_2_parameters(self): + assert hasattr(self.gp_vae, "model") and self.gp_vae.model is not None + + assert hasattr(self.gp_vae, "optimizer") and self.gp_vae.optimizer is not None + + assert hasattr(self.gp_vae, "best_loss") + self.assertNotEqual(self.gp_vae.best_loss, float("inf")) + + assert ( + hasattr(self.gp_vae, "best_model_dict") + and self.gp_vae.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-GPVAE") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.gp_vae) + + # save the trained model into file, and check if the path exists + self.gp_vae.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.gp_vae.load_model(saved_model_path) + + class TestBRITS(unittest.TestCase): logger.info("Running tests for an imputation model BRITS...") From 778dba0b3cb8105d61c832cb1ec665e043c5b4bf Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 21 Sep 2023 14:49:39 +0800 Subject: [PATCH 10/17] refactor: simplify some code in GP-VAE; --- pypots/imputation/gpvae/model.py | 120 ++++++++++++++--------------- pypots/imputation/gpvae/modules.py | 15 ++-- 2 files changed, 63 insertions(+), 72 deletions(-) diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index aaa59b5d..6b613d4d 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -46,9 +46,6 @@ class _GPVAE(nn.Module): latent_dim : int, the feature dimension of the latent embedding - device : str, - specify running the model on which device, CPU/GPU - encoder_sizes : tuple, the tuple of the network size in encoder @@ -65,7 +62,7 @@ class _GPVAE(nn.Module): the number of importance weights for IWAE model kernel : str, - the Gaussial Process kernel ["cauchy", "diffusion", "rbf", "matern"] + the Gaussian Process kernel ["cauchy", "diffusion", "rbf", "matern"] sigma : float, the scale parameter for a kernel function @@ -82,7 +79,6 @@ def __init__( input_dim, time_length, latent_dim, - device, encoder_sizes=(64, 64), decoder_sizes=(64, 64), beta=1, @@ -92,6 +88,7 @@ def __init__( sigma=1.0, length_scale=7.0, kernel_scales=1, + window_size=24, ): super().__init__() self.kernel = kernel @@ -99,21 +96,20 @@ def __init__( self.length_scale = length_scale self.kernel_scales = kernel_scales - # Precomputed KL components for efficiency - self.pz_scale_inv = None - self.pz_scale_log_abs_determinant = None - self.prior = None - self.input_dim = input_dim self.time_length = time_length self.latent_dim = latent_dim self.beta = beta - self.encoder = Encoder(input_dim, latent_dim, encoder_sizes).to(device) - self.decoder = Decoder(latent_dim, input_dim, decoder_sizes).to(device) - self.device = device + self.encoder = Encoder(input_dim, latent_dim, encoder_sizes, window_size) + self.decoder = Decoder(latent_dim, input_dim, decoder_sizes) self.M = 
M self.K = K + # Precomputed KL components for efficiency + self.prior = self._init_prior() + # self.pz_scale_inv = None + # self.pz_scale_log_abs_determinant = None + def encode(self, x): return self.encoder(x) @@ -124,9 +120,6 @@ def decode(self, z): assert num_dim > 2 return self.decoder(torch.transpose(z, num_dim - 1, num_dim - 2)) - def __call__(self, inputs): - return self.decoder(self.encode(inputs).sample()).sample() - def forward(self, inputs, training=True): x = inputs["X"] m_mask = inputs["missing_mask"] @@ -135,7 +128,7 @@ def forward(self, inputs, training=True): m_mask = m_mask.repeat(self.M * self.K, 1, 1) m_mask = m_mask.type(torch.bool) - pz = self._get_prior() + # pz = self.prior() qz_x = self.encode(x) z = qz_x.rsample() px_z = self.decode(z) @@ -147,7 +140,7 @@ def forward(self, inputs, training=True): nll = nll.sum(dim=(1, 2)) if self.K > 1: - kl = qz_x.log_prob(z) - pz.log_prob(z) + kl = qz_x.log_prob(z) - self.prior.log_prob(z) kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl)) kl = kl.sum(1) @@ -157,7 +150,7 @@ def forward(self, inputs, training=True): elbo = torch.logsumexp(weights, dim=1) elbo = elbo.mean() else: - kl = self.kl_divergence(qz_x, pz) + kl = self.kl_divergence(qz_x, self.prior) kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl)) kl = kl.sum(1) @@ -178,53 +171,54 @@ def forward(self, inputs, training=True): } return results - def kl_divergence(self, a, b): + @staticmethod + def kl_divergence(a, b): + # TODO: different from the author's implementation return torch.distributions.kl.kl_divergence(a, b) - def _get_prior(self): - if self.prior is None: - # Compute kernel matrices for each latent dimension - kernel_matrices = [] - for i in range(self.kernel_scales): - if self.kernel == "rbf": - kernel_matrices.append( - rbf_kernel(self.time_length, self.length_scale / 2**i) - ) - elif self.kernel == "diffusion": - kernel_matrices.append( - diffusion_kernel(self.time_length, self.length_scale / 2**i) - ) - elif self.kernel == "matern": - kernel_matrices.append( - matern_kernel(self.time_length, self.length_scale / 2**i) - ) - elif self.kernel == "cauchy": - kernel_matrices.append( - cauchy_kernel( - self.time_length, self.sigma, self.length_scale / 2**i - ) + def _init_prior(self): + # Compute kernel matrices for each latent dimension + kernel_matrices = [] + for i in range(self.kernel_scales): + if self.kernel == "rbf": + kernel_matrices.append( + rbf_kernel(self.time_length, self.length_scale / 2**i) + ) + elif self.kernel == "diffusion": + kernel_matrices.append( + diffusion_kernel(self.time_length, self.length_scale / 2**i) + ) + elif self.kernel == "matern": + kernel_matrices.append( + matern_kernel(self.time_length, self.length_scale / 2**i) + ) + elif self.kernel == "cauchy": + kernel_matrices.append( + cauchy_kernel( + self.time_length, self.sigma, self.length_scale / 2**i ) - - # Combine kernel matrices for each latent dimension - tiled_matrices = [] - total = 0 - for i in range(self.kernel_scales): - if i == self.kernel_scales - 1: - multiplier = self.latent_dim - total - else: - multiplier = int(np.ceil(self.latent_dim / self.kernel_scales)) - total += multiplier - tiled_matrices.append( - torch.unsqueeze(kernel_matrices[i], 0).repeat(multiplier, 1, 1) ) - kernel_matrix_tiled = torch.cat(tiled_matrices) - assert len(kernel_matrix_tiled) == self.latent_dim - self.prior = torch.distributions.MultivariateNormal( - loc=torch.zeros(self.latent_dim, self.time_length, device=self.device), - 
covariance_matrix=kernel_matrix_tiled.to(self.device), + + # Combine kernel matrices for each latent dimension + tiled_matrices = [] + total = 0 + for i in range(self.kernel_scales): + if i == self.kernel_scales - 1: + multiplier = self.latent_dim - total + else: + multiplier = int(np.ceil(self.latent_dim / self.kernel_scales)) + total += multiplier + tiled_matrices.append( + torch.unsqueeze(kernel_matrices[i], 0).repeat(multiplier, 1, 1) ) + kernel_matrix_tiled = torch.cat(tiled_matrices) + assert len(kernel_matrix_tiled) == self.latent_dim + prior = torch.distributions.MultivariateNormal( + loc=torch.zeros(self.latent_dim, self.time_length), + covariance_matrix=kernel_matrix_tiled, + ) - return self.prior + return prior class GPVAE(BaseNNImputer): @@ -232,9 +226,6 @@ class GPVAE(BaseNNImputer): Parameters ---------- - latent_dim : - The size of the latent variable. - beta: The weight of KL divergence in EBLO. @@ -303,6 +294,7 @@ def __init__( sigma: float = 1.0, length_scale: float = 7.0, kernel_scales: int = 1, + window_size: int = 3, batch_size: int = 32, epochs: int = 100, patience: int = None, @@ -349,7 +341,7 @@ def __init__( sigma=self.sigma, length_scale=self.length_scale, kernel_scales=self.kernel_scales, - device=self.device, + window_size=window_size, ) self._send_model_to_given_device() self._print_model_size() diff --git a/pypots/imputation/gpvae/modules.py b/pypots/imputation/gpvae/modules.py index dea83f68..5ad81e09 100644 --- a/pypots/imputation/gpvae/modules.py +++ b/pypots/imputation/gpvae/modules.py @@ -99,8 +99,8 @@ def make_nn(input_size, output_size, hidden_sizes): class CustomConv1d(torch.nn.Conv1d): - def __init(self, in_channels, out_channels, kernal_size, padding): - super().__init__(in_channels, out_channels, kernal_size, padding) + def __init(self, in_channels, out_channels, kernel_size, padding): + super().__init__(in_channels, out_channels, kernel_size, padding) def forward(self, x): if len(x.shape) > 2: @@ -118,7 +118,7 @@ def forward(self, x): def make_cnn(input_size, output_size, hidden_sizes, kernel_size=3): """This function used to construct neural network consisting of - one 1d-convolutional layer that utilizes temporal dependences, + one 1d-convolutional layer that utilizes temporal dependencies, fully connected network Parameters @@ -183,7 +183,7 @@ def __init__(self, input_size, z_size, hidden_sizes=(128, 128), window_size=24): input_size, (z_size, z_size * 2), hidden_sizes, window_size ) - def __call__(self, x): + def forward(self, x): mapped = self.net(x) batch_size = mapped.size(0) time_length = mapped.size(1) @@ -235,14 +235,14 @@ def __call__(self, x): cov_tril_lower = torch.transpose(cov_tril, num_dim - 1, num_dim - 2) z_dist = torch.distributions.MultivariateNormal( - loc=mapped_mean, scale_tril=(cov_tril_lower) + loc=mapped_mean, scale_tril=cov_tril_lower ) return z_dist class Decoder(nn.Module): def __init__(self, input_size, output_size, hidden_sizes=(256, 256)): - """This module is a decoder with Gaussian output distribution + """This module is a decoder with Gaussian output distribution. Parameters ---------- @@ -253,10 +253,9 @@ def __init__(self, input_size, output_size, hidden_sizes=(256, 256)): the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers. 
""" super().__init__() - self.output_size = int(output_size) self.net = make_nn(input_size, output_size, hidden_sizes) - def __call__(self, x): + def forward(self, x): mu = self.net(x) var = torch.ones_like(mu) return torch.distributions.Normal(mu, var) From ca6e2cdab11238b02bfaf893346cc831de069211 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 21 Sep 2023 22:31:08 +0800 Subject: [PATCH 11/17] Refactor testing cases (#189) * refactor: clear up testing cases; * refactor: refactor code in Dataset classes for models; * refactor: adjust testing workflows according to refactored test cases; * fix: turn missing_mask into torch.float; * fix: error in BTTF testing case; * feat: using pip to manage dependencies in CI testing workflow, and using conda in Daily testing workflow; --- .github/workflows/testing_ci.yml | 59 +- .github/workflows/testing_daily.yml | 60 +- docs/pypots.forecasting.rst | 24 +- pypots/classification/grud/data.py | 2 +- .../template/{dataset.py => data.py} | 0 .../template/{dataset.py => data.py} | 0 pypots/clustering/vader/data.py | 12 +- pypots/data/base.py | 16 +- pypots/data/saving.py | 11 +- .../template/{dataset.py => data.py} | 0 pypots/imputation/brits/data.py | 26 +- pypots/imputation/gpvae/data.py | 9 +- pypots/imputation/saits/data.py | 20 +- .../template/{dataset.py => data.py} | 0 tests/classification/__init__.py | 6 + tests/classification/brits.py | 106 +++ tests/classification/config.py | 21 + tests/classification/grud.py | 105 +++ tests/classification/raindrop.py | 110 +++ tests/cli/__init__.py | 6 + tests/cli/config.py | 11 + tests/cli/dev.py | 92 ++ tests/cli/doc.py | 104 +++ tests/cli/env.py | 49 ++ tests/clustering/__init__.py | 6 + tests/clustering/config.py | 22 + tests/clustering/crli.py | 103 +++ .../vader.py} | 93 +-- tests/data/__init__.py | 6 + .../lazy_loading_strategy.py} | 77 +- tests/forecasting/__init__.py | 6 + .../bttf.py} | 14 +- tests/forecasting/config.py | 23 + tests/global_test_config.py | 13 + tests/imputation/__init__.py | 6 + tests/imputation/brits.py | 104 +++ tests/imputation/config.py | 25 + tests/imputation/gpvae.py | 104 +++ tests/imputation/locf.py | 46 + tests/imputation/mrnn.py | 104 +++ tests/imputation/saits.py | 110 +++ tests/imputation/transformer.py | 113 +++ tests/imputation/usgan.py | 111 +++ tests/optim/__init__.py | 6 + tests/optim/adadelta.py | 56 ++ tests/optim/adagrad.py | 56 ++ tests/optim/adam.py | 56 ++ tests/optim/adamw.py | 56 ++ tests/optim/config.py | 19 + tests/optim/rmsprop.py | 56 ++ tests/optim/sgd.py | 56 ++ tests/test_classification.py | 256 ------ tests/test_cli.py | 189 ----- tests/test_imputation.py | 503 ----------- tests/test_optim.py | 244 ------ tests/test_training_on_multi_gpus.py | 783 ------------------ tests/utils/__init__.py | 6 + tests/{test_utils.py => utils/logging.py} | 25 +- tests/utils/random.py | 36 + 59 files changed, 2111 insertions(+), 2227 deletions(-) rename pypots/classification/template/{dataset.py => data.py} (100%) rename pypots/clustering/template/{dataset.py => data.py} (100%) rename pypots/forecasting/template/{dataset.py => data.py} (100%) rename pypots/imputation/template/{dataset.py => data.py} (100%) create mode 100644 tests/classification/__init__.py create mode 100644 tests/classification/brits.py create mode 100644 tests/classification/config.py create mode 100644 tests/classification/grud.py create mode 100644 tests/classification/raindrop.py create mode 100644 tests/cli/__init__.py create mode 100644 tests/cli/config.py create mode 100644 tests/cli/dev.py 
create mode 100644 tests/cli/doc.py create mode 100644 tests/cli/env.py create mode 100644 tests/clustering/__init__.py create mode 100644 tests/clustering/config.py create mode 100644 tests/clustering/crli.py rename tests/{test_clustering.py => clustering/vader.py} (51%) create mode 100644 tests/data/__init__.py rename tests/{test_data.py => data/lazy_loading_strategy.py} (56%) create mode 100644 tests/forecasting/__init__.py rename tests/{test_forecasting.py => forecasting/bttf.py} (78%) create mode 100644 tests/forecasting/config.py create mode 100644 tests/imputation/__init__.py create mode 100644 tests/imputation/brits.py create mode 100644 tests/imputation/config.py create mode 100644 tests/imputation/gpvae.py create mode 100644 tests/imputation/locf.py create mode 100644 tests/imputation/mrnn.py create mode 100644 tests/imputation/saits.py create mode 100644 tests/imputation/transformer.py create mode 100644 tests/imputation/usgan.py create mode 100644 tests/optim/__init__.py create mode 100644 tests/optim/adadelta.py create mode 100644 tests/optim/adagrad.py create mode 100644 tests/optim/adam.py create mode 100644 tests/optim/adamw.py create mode 100644 tests/optim/config.py create mode 100644 tests/optim/rmsprop.py create mode 100644 tests/optim/sgd.py delete mode 100644 tests/test_classification.py delete mode 100644 tests/test_cli.py delete mode 100644 tests/test_imputation.py delete mode 100644 tests/test_optim.py delete mode 100644 tests/test_training_on_multi_gpus.py create mode 100644 tests/utils/__init__.py rename tests/{test_utils.py => utils/logging.py} (64%) create mode 100644 tests/utils/random.py diff --git a/.github/workflows/testing_ci.yml b/.github/workflows/testing_ci.yml index d339afe5..7e5b6780 100644 --- a/.github/workflows/testing_ci.yml +++ b/.github/workflows/testing_ci.yml @@ -15,43 +15,60 @@ jobs: runs-on: ${{ matrix.os }} defaults: run: - shell: bash -l {0} + shell: bash {0} strategy: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macOS-latest] - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.10"] + torch-version: ["1.13.1"] steps: - name: Check out the repo code uses: actions/checkout@v3 - - name: Set up Conda - uses: conda-incubator/setup-miniconda@v2 + - name: Determine the Python version + uses: haya14busa/action-cond@v1 + id: condval with: - activate-environment: pypots-test - python-version: ${{ matrix.python-version }} - environment-file: tests/environment_for_conda_test.yml - auto-activate-base: false + cond: ${{ matrix.python-version == 3.7 && matrix.os == 'macOS-latest' }} + # Note: the latest 3.7 subversion 3.7.17 for MacOS has "ModuleNotFoundError: No module named '_bz2'" + if_true: "3.7.16" + if_false: ${{ matrix.python-version }} + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ steps.condval.outputs.value }} + check-latest: true + cache: pip + cache-dependency-path: | + setup.cfg + + - name: Install PyTorch ${{ matrix.torch-version }}+cpu + # we have to install torch in advance because torch_sparse needs it for compilation, + # refer to https://github.com/rusty1s/pytorch_sparse/issues/156#issuecomment-1304869772 for details + run: | + which python + which pip + python -m pip install --upgrade pip + pip install torch==${{ matrix.torch-version }} -f https://download.pytorch.org/whl/cpu + python -c "import torch; print('PyTorch:', torch.__version__)" + + - name: Install other dependencies + run: | + pip install pypots + pip install torch-geometric torch-scatter 
torch-sparse -f "https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html" + pip install -e ".[dev]" - name: Fetch the test environment details run: | which python - conda info - conda list + pip list - name: Test with pytest run: | - # run tests separately here due to Segmentation Fault in test_clustering when run all in - # one command with `pytest` on MacOS. Bugs not caught, so this is a trade-off to avoid SF. - python -m pytest -rA tests/test_classification.py -n auto --cov=pypots --dist=loadgroup --cov-config=.coveragerc - python -m pytest -rA tests/test_imputation.py -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc - python -m pytest -rA tests/test_clustering.py -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc - python -m pytest -rA tests/test_forecasting.py -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc - python -m pytest -rA tests/test_optim.py -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc - python -m pytest -rA tests/test_data.py -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc - python -m pytest -rA tests/test_utils.py -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc - python -m pytest -rA tests/test_cli.py -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc + coverage run --source=pypots -m pytest -rA tests/*/* - name: Generate the LCOV report run: | @@ -61,4 +78,4 @@ jobs: uses: coverallsapp/github-action@master with: github-token: ${{ secrets.GITHUB_TOKEN }} - path-to-lcov: 'coverage.lcov' + path-to-lcov: "coverage.lcov" diff --git a/.github/workflows/testing_daily.yml b/.github/workflows/testing_daily.yml index f0b3ba61..5e41630f 100644 --- a/.github/workflows/testing_daily.yml +++ b/.github/workflows/testing_daily.yml @@ -10,61 +10,43 @@ jobs: runs-on: ${{ matrix.os }} defaults: run: - shell: bash {0} + shell: bash -l {0} strategy: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macOS-latest] - python-version: ["3.7", "3.8", "3.9", "3.10"] - torch-version: ["1.13.1"] + python-version: ["3.7", "3.10"] steps: - name: Check out the repo code uses: actions/checkout@v3 - - name: Determine the Python version - uses: haya14busa/action-cond@v1 - id: condval + - name: Set up Conda + uses: conda-incubator/setup-miniconda@v2 with: - cond: ${{ matrix.python-version == 3.7 && matrix.os == 'macOS-latest' }} - # Note: the latest 3.7 subversion 3.7.17 for MacOS has "ModuleNotFoundError: No module named '_bz2'" - if_true: "3.7.16" - if_false: ${{ matrix.python-version }} - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: ${{ steps.condval.outputs.value }} - check-latest: true - cache: pip - cache-dependency-path: | - setup.cfg - - - name: Install PyTorch ${{ matrix.torch-version }}+cpu - # we have to install torch in advance because torch_sparse needs it for compilation, - # refer to https://github.com/rusty1s/pytorch_sparse/issues/156#issuecomment-1304869772 for details - run: | - which python - which pip - python -m pip install --upgrade pip - pip install torch==${{ matrix.torch-version }} -f https://download.pytorch.org/whl/cpu - python -c "import torch; print('PyTorch:', torch.__version__)" - - - name: Install other dependencies - run: | - pip install pypots - pip install torch-geometric torch-scatter torch-sparse -f "https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html" - pip install -e ".[dev]" + activate-environment: 
pypots-test + python-version: ${{ matrix.python-version }} + environment-file: tests/environment_for_conda_test.yml + auto-activate-base: false - name: Fetch the test environment details run: | which python - pip list + conda info + conda list - name: Test with pytest run: | - coverage run --source=pypots -m pytest --ignore tests/test_training_on_multi_gpus.py - # ignore the test_training_on_multi_gpus.py because it requires multiple GPUs which are not available on GitHub Actions + # run tests separately here due to Segmentation Fault in test_clustering when run all in + # one command with `pytest` on MacOS. Bugs not caught, so this is a trade-off to avoid SF. + python -m pytest -rA tests/classification/* -n auto --cov=pypots --dist=loadgroup --cov-config=.coveragerc + python -m pytest -rA tests/imputation/* -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc + python -m pytest -rA tests/clustering/* -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc + python -m pytest -rA tests/forecasting/* -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc + python -m pytest -rA tests/optim/* -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc + python -m pytest -rA tests/data/* -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc + python -m pytest -rA tests/utils/* -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc + python -m pytest -rA tests/cli/* -n auto --cov=pypots --cov-append --dist=loadgroup --cov-config=.coveragerc - name: Generate the LCOV report run: | @@ -74,4 +56,4 @@ jobs: uses: coverallsapp/github-action@master with: github-token: ${{ secrets.GITHUB_TOKEN }} - path-to-lcov: "coverage.lcov" + path-to-lcov: 'coverage.lcov' diff --git a/docs/pypots.forecasting.rst b/docs/pypots.forecasting.rst index 2ae67b85..c4ac76b7 100644 --- a/docs/pypots.forecasting.rst +++ b/docs/pypots.forecasting.rst @@ -1,11 +1,31 @@ pypots.forecasting package ========================== +Subpackages +----------- -pypots.forecasting.bttf module +.. toctree:: + :maxdepth: 4 + + pypots.forecasting.bttf + pypots.forecasting.template + +Submodules +---------- + +pypots.forecasting.base module ------------------------------ -.. automodule:: pypots.forecasting.bttf +.. automodule:: pypots.forecasting.base + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + +Module contents +--------------- + +.. 
automodule:: pypots.forecasting :members: :undoc-members: :show-inheritance: diff --git a/pypots/classification/grud/data.py b/pypots/classification/grud/data.py index 52186017..edf1d4d0 100644 --- a/pypots/classification/grud/data.py +++ b/pypots/classification/grud/data.py @@ -123,7 +123,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: if self.file_handle is None: self.file_handle = self._open_file_handle() - X = torch.from_numpy(self.file_handle["X"][idx]) + X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32) missing_mask = (~torch.isnan(X)).to(torch.float32) X_filledLOCF = self.locf._locf_torch(X.unsqueeze(dim=0)).squeeze() X = torch.nan_to_num(X) diff --git a/pypots/classification/template/dataset.py b/pypots/classification/template/data.py similarity index 100% rename from pypots/classification/template/dataset.py rename to pypots/classification/template/data.py diff --git a/pypots/clustering/template/dataset.py b/pypots/clustering/template/data.py similarity index 100% rename from pypots/clustering/template/dataset.py rename to pypots/clustering/template/data.py diff --git a/pypots/clustering/vader/data.py b/pypots/clustering/vader/data.py index a3b2f91d..a8910b44 100644 --- a/pypots/clustering/vader/data.py +++ b/pypots/clustering/vader/data.py @@ -6,12 +6,12 @@ # License: GLP-v3 -from typing import Union +from typing import Union, Iterable -from ..crli.data import DatasetForCRLI +from ...data.base import BaseDataset -class DatasetForVaDER(DatasetForCRLI): +class DatasetForVaDER(BaseDataset): """Dataset class for model VaDER. Parameters @@ -45,3 +45,9 @@ def __init__( file_type: str = "h5py", ): super().__init__(data, return_labels, file_type) + + def _fetch_data_from_array(self, idx: int) -> Iterable: + return super()._fetch_data_from_array(idx) + + def _fetch_data_from_file(self, idx: int) -> Iterable: + return super()._fetch_data_from_file(idx) diff --git a/pypots/data/base.py b/pypots/data/base.py index 86b15fc2..1bef9f9c 100644 --- a/pypots/data/base.py +++ b/pypots/data/base.py @@ -204,13 +204,13 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: The collated data sample, a list including all necessary sample info. """ - X = self.X[idx] - missing_mask = ~torch.isnan(X) + X = self.X[idx].to(torch.float32) + missing_mask = (~torch.isnan(X)).to(torch.float32) X = torch.nan_to_num(X) sample = [ torch.tensor(idx), - X.to(torch.float32), - missing_mask.to(torch.float32), + X, + missing_mask, ] if self.y is not None and self.return_labels: @@ -279,13 +279,13 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: if self.file_handle is None: self.file_handle = self._open_file_handle() - X = torch.from_numpy(self.file_handle["X"][idx]) - missing_mask = ~torch.isnan(X) + X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32) + missing_mask = (~torch.isnan(X)).to(torch.float32) X = torch.nan_to_num(X) sample = [ torch.tensor(idx), - X.to(torch.float32), - missing_mask.to(torch.float32), + X, + missing_mask, ] # if the dataset has labels and is for training, then fetch it from the file diff --git a/pypots/data/saving.py b/pypots/data/saving.py index 8581ad50..61138df2 100644 --- a/pypots/data/saving.py +++ b/pypots/data/saving.py @@ -14,7 +14,11 @@ from pypots.utils.logging import logger -def save_dict_into_h5(data_dict: dict, saving_dir: str) -> None: +def save_dict_into_h5( + data_dict: dict, + saving_dir: str, + saving_name: str = "datasets.h5", +) -> None: """Save the given data (in a dictionary) into the given h5 file. 
Parameters @@ -25,6 +29,9 @@ def save_dict_into_h5(data_dict: dict, saving_dir: str) -> None: saving_dir : str, The h5 file to save the data. + saving_name : str, optional (default="datasets.h5") + The final name of the saved h5 file. + """ def save_set(handle, name, data): @@ -36,7 +43,7 @@ def save_set(handle, name, data): handle.create_dataset(name, data=data) create_dir_if_not_exist(saving_dir) - saving_path = os.path.join(saving_dir, "datasets.h5") + saving_path = os.path.join(saving_dir, saving_name) with h5py.File(saving_path, "w") as hf: for k, v in data_dict.items(): save_set(hf, k, v) diff --git a/pypots/forecasting/template/dataset.py b/pypots/forecasting/template/data.py similarity index 100% rename from pypots/forecasting/template/dataset.py rename to pypots/forecasting/template/data.py diff --git a/pypots/imputation/brits/data.py b/pypots/imputation/brits/data.py index f39e411c..342ede98 100644 --- a/pypots/imputation/brits/data.py +++ b/pypots/imputation/brits/data.py @@ -59,14 +59,14 @@ def __init__( self.processed_data = { "forward": { - "X": forward_X, - "missing_mask": forward_missing_mask, - "delta": forward_delta, + "X": forward_X.to(torch.float32), + "missing_mask": forward_missing_mask.to(torch.float32), + "delta": forward_delta.to(torch.float32), }, "backward": { - "X": backward_X, - "missing_mask": backward_missing_mask, - "delta": backward_delta, + "X": backward_X.to(torch.float32), + "missing_mask": backward_missing_mask.to(torch.float32), + "delta": backward_delta.to(torch.float32), }, } @@ -101,13 +101,13 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: sample = [ torch.tensor(idx), # for forward - self.processed_data["forward"]["X"][idx].to(torch.float32), - self.processed_data["forward"]["missing_mask"][idx].to(torch.float32), - self.processed_data["forward"]["delta"][idx].to(torch.float32), + self.processed_data["forward"]["X"][idx], + self.processed_data["forward"]["missing_mask"][idx], + self.processed_data["forward"]["delta"][idx], # for backward - self.processed_data["backward"]["X"][idx].to(torch.float32), - self.processed_data["backward"]["missing_mask"][idx].to(torch.float32), - self.processed_data["backward"]["delta"][idx].to(torch.float32), + self.processed_data["backward"]["X"][idx], + self.processed_data["backward"]["missing_mask"][idx], + self.processed_data["backward"]["delta"][idx], ] if self.y is not None and self.return_labels: @@ -133,7 +133,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: if self.file_handle is None: self.file_handle = self._open_file_handle() - X = torch.from_numpy(self.file_handle["X"][idx]) + X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32) missing_mask = (~torch.isnan(X)).to(torch.float32) X = torch.nan_to_num(X) diff --git a/pypots/imputation/gpvae/data.py b/pypots/imputation/gpvae/data.py index 4f8b27c4..8bb9be8c 100644 --- a/pypots/imputation/gpvae/data.py +++ b/pypots/imputation/gpvae/data.py @@ -10,7 +10,6 @@ import torch from ...data.base import BaseDataset -from ...data.utils import torch_parse_delta class DatasetForGPVAE(BaseDataset): @@ -51,7 +50,7 @@ def __init__( if not isinstance(self.data, str): # calculate all delta here. 
missing_mask = (~torch.isnan(self.X)).type(torch.float32) - X = torch.nan_to_num(self.X) + X = torch.nan_to_num(self.X).to(torch.float32) self.processed_data = { "X": X, @@ -89,8 +88,8 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: sample = [ torch.tensor(idx), # for forward - self.processed_data["X"][idx].to(torch.float32), - self.processed_data["missing_mask"][idx].to(torch.float32), + self.processed_data["X"][idx], + self.processed_data["missing_mask"][idx], ] if self.y is not None and self.return_labels: @@ -116,7 +115,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: if self.file_handle is None: self.file_handle = self._open_file_handle() - X = torch.from_numpy(self.file_handle["X"][idx]) + X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32) missing_mask = (~torch.isnan(X)).to(torch.float32) X = torch.nan_to_num(X) diff --git a/pypots/imputation/saits/data.py b/pypots/imputation/saits/data.py index 2fb80bc3..5ff679a5 100644 --- a/pypots/imputation/saits/data.py +++ b/pypots/imputation/saits/data.py @@ -88,15 +88,15 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: indicating_mask : tensor. The mask indicates artificially missing values in X. """ - X = self.X[idx] + X = self.X[idx].to(torch.float32) X_intact, X, missing_mask, indicating_mask = mcar(X, rate=self.rate) sample = [ torch.tensor(idx), - X_intact.to(torch.float32), - X.to(torch.float32), - missing_mask.to(torch.float32), - indicating_mask.to(torch.float32), + X_intact, + X, + missing_mask, + indicating_mask, ] if self.y is not None and self.return_labels: @@ -122,15 +122,15 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: if self.file_handle is None: self.file_handle = self._open_file_handle() - X = torch.from_numpy(self.file_handle["X"][idx]) + X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32) X_intact, X, missing_mask, indicating_mask = mcar(X, rate=self.rate) sample = [ torch.tensor(idx), - X_intact.to(torch.float32), - X.to(torch.float32), - missing_mask.to(torch.float32), - indicating_mask.to(torch.float32), + X_intact, + X, + missing_mask, + indicating_mask, ] # if the dataset has labels and is for training, then fetch it from the file diff --git a/pypots/imputation/template/dataset.py b/pypots/imputation/template/data.py similarity index 100% rename from pypots/imputation/template/dataset.py rename to pypots/imputation/template/data.py diff --git a/tests/classification/__init__.py b/tests/classification/__init__.py new file mode 100644 index 00000000..f0b4685e --- /dev/null +++ b/tests/classification/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 diff --git a/tests/classification/brits.py b/tests/classification/brits.py new file mode 100644 index 00000000..b1905c39 --- /dev/null +++ b/tests/classification/brits.py @@ -0,0 +1,106 @@ +""" +Test cases for BRITS classification model. 
+""" + +# Created by Wenjie Du +# License: GLP-v3 + +import os +import unittest + +import pytest + +from pypots.classification import BRITS +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_binary_classification_metrics +from tests.classification.config import ( + EPOCHS, + TRAIN_SET, + VAL_SET, + TEST_SET, + RESULT_SAVING_DIR_FOR_CLASSIFICATION, +) +from tests.global_test_config import ( + DATA, + DEVICE, + check_tb_and_model_checkpoints_existence, +) + + +class TestBRITS(unittest.TestCase): + logger.info("Running tests for a classification model BRITS...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "BRITS") + model_save_name = "saved_BRITS_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a BRITS model + brits = BRITS( + DATA["n_steps"], + DATA["n_features"], + n_classes=DATA["n_classes"], + rnn_hidden_size=256, + epochs=EPOCHS, + saving_path=saving_path, + model_saving_strategy="better", + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="classification-brits") + def test_0_fit(self): + self.brits.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="classification-brits") + def test_1_classify(self): + predictions = self.brits.classify(TEST_SET) + metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) + logger.info( + f'ROC_AUC: {metrics["roc_auc"]}, \n' + f'PR_AUC: {metrics["pr_auc"]},\n' + f'F1: {metrics["f1"]},\n' + f'Precision: {metrics["precision"]},\n' + f'Recall: {metrics["recall"]},\n' + ) + assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" + + @pytest.mark.xdist_group(name="classification-brits") + def test_2_parameters(self): + assert hasattr(self.brits, "model") and self.brits.model is not None + + assert hasattr(self.brits, "optimizer") and self.brits.optimizer is not None + + assert hasattr(self.brits, "best_loss") + self.assertNotEqual(self.brits.best_loss, float("inf")) + + assert ( + hasattr(self.brits, "best_model_dict") + and self.brits.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="classification-brits") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.brits) + + # save the trained model into file, and check if the path exists + self.brits.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.brits.load_model(saved_model_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/classification/config.py b/tests/classification/config.py new file mode 100644 index 00000000..35b17029 --- /dev/null +++ b/tests/classification/config.py @@ -0,0 +1,21 @@ +""" +Test configs for classification models. 
+""" + +# Created by Wenjie Du +# License: GLP-v3 + +import os + +from tests.global_test_config import ( + DATA, + RESULT_SAVING_DIR, +) + +EPOCHS = 5 + +TRAIN_SET = {"X": DATA["train_X"], "y": DATA["train_y"]} +VAL_SET = {"X": DATA["val_X"], "y": DATA["val_y"]} +TEST_SET = {"X": DATA["test_X"]} + +RESULT_SAVING_DIR_FOR_CLASSIFICATION = os.path.join(RESULT_SAVING_DIR, "classification") diff --git a/tests/classification/grud.py b/tests/classification/grud.py new file mode 100644 index 00000000..a662cb70 --- /dev/null +++ b/tests/classification/grud.py @@ -0,0 +1,105 @@ +""" +Test cases for GRUD classification model. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +import os +import unittest + +import pytest + +from pypots.classification import GRUD +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_binary_classification_metrics +from tests.classification.config import ( + EPOCHS, + TRAIN_SET, + VAL_SET, + TEST_SET, + RESULT_SAVING_DIR_FOR_CLASSIFICATION, +) +from tests.global_test_config import ( + DATA, + DEVICE, + check_tb_and_model_checkpoints_existence, +) + + +class TestGRUD(unittest.TestCase): + logger.info("Running tests for a classification model GRUD...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "GRUD") + model_save_name = "saved_GRUD_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a GRUD model + grud = GRUD( + DATA["n_steps"], + DATA["n_features"], + n_classes=DATA["n_classes"], + rnn_hidden_size=256, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="classification-grud") + def test_0_fit(self): + self.grud.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="classification-grud") + def test_1_classify(self): + predictions = self.grud.classify(TEST_SET) + metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) + logger.info( + f'ROC_AUC: {metrics["roc_auc"]}, \n' + f'PR_AUC: {metrics["pr_auc"]},\n' + f'F1: {metrics["f1"]},\n' + f'Precision: {metrics["precision"]},\n' + f'Recall: {metrics["recall"]},\n' + ) + assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" + + @pytest.mark.xdist_group(name="classification-grud") + def test_2_parameters(self): + assert hasattr(self.grud, "model") and self.grud.model is not None + + assert hasattr(self.grud, "optimizer") and self.grud.optimizer is not None + + assert hasattr(self.grud, "best_loss") + self.assertNotEqual(self.grud.best_loss, float("inf")) + + assert ( + hasattr(self.grud, "best_model_dict") + and self.grud.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="classification-grud") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.grud) + + # save the trained model into file, and check if the path exists + self.grud.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.grud.load_model(saved_model_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/classification/raindrop.py 
b/tests/classification/raindrop.py new file mode 100644 index 00000000..277164dc --- /dev/null +++ b/tests/classification/raindrop.py @@ -0,0 +1,110 @@ +""" +Test cases for Raindrop classification model. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +import os +import unittest + +import pytest + +from pypots.classification import Raindrop +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_binary_classification_metrics +from tests.classification.config import ( + EPOCHS, + TRAIN_SET, + VAL_SET, + TEST_SET, + RESULT_SAVING_DIR_FOR_CLASSIFICATION, +) +from tests.global_test_config import ( + DATA, + DEVICE, + check_tb_and_model_checkpoints_existence, +) + + +class TestRaindrop(unittest.TestCase): + logger.info("Running tests for a classification model Raindrop...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "Raindrop") + model_save_name = "saved_Raindrop_model.pypots" + + # initialize a Raindrop model + raindrop = Raindrop( + DATA["n_steps"], + DATA["n_features"], + DATA["n_classes"], + n_layers=2, + d_model=DATA["n_features"] * 4, + d_inner=256, + n_heads=2, + dropout=0.3, + d_static=0, + aggregation="mean", + sensor_wise_mask=False, + static=False, + epochs=EPOCHS, + saving_path=saving_path, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="classification-raindrop") + def test_0_fit(self): + self.raindrop.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="classification-raindrop") + def test_1_classify(self): + predictions = self.raindrop.classify(TEST_SET) + metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) + logger.info( + f'ROC_AUC: {metrics["roc_auc"]}, \n' + f'PR_AUC: {metrics["pr_auc"]},\n' + f'F1: {metrics["f1"]},\n' + f'Precision: {metrics["precision"]},\n' + f'Recall: {metrics["recall"]},\n' + ) + assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" + + @pytest.mark.xdist_group(name="classification-raindrop") + def test_2_parameters(self): + assert hasattr(self.raindrop, "model") and self.raindrop.model is not None + + assert ( + hasattr(self.raindrop, "optimizer") and self.raindrop.optimizer is not None + ) + + assert hasattr(self.raindrop, "best_loss") + self.assertNotEqual(self.raindrop.best_loss, float("inf")) + + assert ( + hasattr(self.raindrop, "best_model_dict") + and self.raindrop.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="classification-raindrop") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.raindrop) + + # save the trained model into file, and check if the path exists + self.raindrop.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.raindrop.load_model(saved_model_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 00000000..f0b4685e --- /dev/null +++ b/tests/cli/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 diff --git a/tests/cli/config.py b/tests/cli/config.py new file mode 100644 index 00000000..defdb211 --- /dev/null +++ 
b/tests/cli/config.py @@ -0,0 +1,11 @@ +""" +Test configs for CLI tools. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +import os + + +PROJECT_ROOT_DIR = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../..")) diff --git a/tests/cli/dev.py b/tests/cli/dev.py new file mode 100644 index 00000000..4387be29 --- /dev/null +++ b/tests/cli/dev.py @@ -0,0 +1,92 @@ +""" +Test cases for the functions and classes in package `pypots.cli.dev`. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +import os +import threading +import unittest +from argparse import Namespace +from copy import copy + +import pytest + +from pypots.cli.dev import dev_command_factory +from tests.cli.config import PROJECT_ROOT_DIR + + +def callback_func(): + raise TimeoutError("Time out.") + + +def time_out(interval, callback): + def decorator(func): + def wrapper(*args, **kwargs): + t = threading.Thread(target=func, args=args, kwargs=kwargs) + t.setDaemon(True) + t.start() + t.join(interval) # wait for interval seconds + if t.is_alive(): + return threading.Timer(0, callback).start() # invoke callback() + else: + return + + return wrapper + + return decorator + + +@pytest.mark.xfail(reason="Allow tests for CLI to fail") +class TestPyPOTSCLIDev(unittest.TestCase): + # set up the default arguments + default_arguments = { + "build": False, + "cleanup": False, + "run_tests": False, + "k": None, + "show_coverage": False, + "lint_code": False, + } + # `pypots-cli dev` must run under the project root dir + os.chdir(PROJECT_ROOT_DIR) + + @pytest.mark.xdist_group(name="cli-dev") + def test_0_build(self): + arguments = copy(self.default_arguments) + arguments["build"] = True + args = Namespace(**arguments) + dev_command_factory(args).run() + + @pytest.mark.xdist_group(name="cli-dev") + def test_1_run_tests(self): + arguments = copy(self.default_arguments) + arguments["run_tests"] = True + arguments["k"] = "try_to_find_a_non_existing_test_case" + args = Namespace(**arguments) + try: + dev_command_factory(args).run() + except RuntimeError: # try to find a non-existing test case, so RuntimeError will be raised + pass + except Exception as e: # other exceptions will cause an error and result in failed testing + raise e + + # Don't test --lint-code because Black will reformat the code and cause error when generating the coverage report + # @pytest.mark.xdist_group(name="cli-dev") + # def test_2_lint_code(self): + # arguments = copy(self.default_arguments) + # arguments["lint_code"] = True + # args = Namespace(**arguments) + # dev_command_factory(args).run() + + @pytest.mark.xdist_group(name="cli-dev") + def test_3_cleanup(self): + arguments = copy(self.default_arguments) + arguments["cleanup"] = True + args = Namespace(**arguments) + dev_command_factory(args).run() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/cli/doc.py b/tests/cli/doc.py new file mode 100644 index 00000000..85e4e190 --- /dev/null +++ b/tests/cli/doc.py @@ -0,0 +1,104 @@ +""" +Test cases for the functions and classes in package `pypots.cli.doc`. 
+""" + +# Created by Wenjie Du +# License: GLP-v3 + +import os +import threading +import unittest +from argparse import Namespace +from copy import copy + +import pytest + +from pypots.cli.doc import doc_command_factory +from pypots.utils.logging import logger +from tests.cli.config import PROJECT_ROOT_DIR + + +def callback_func(): + raise TimeoutError("Time out.") + + +def time_out(interval, callback): + def decorator(func): + def wrapper(*args, **kwargs): + t = threading.Thread(target=func, args=args, kwargs=kwargs) + t.setDaemon(True) + t.start() + t.join(interval) # wait for interval seconds + if t.is_alive(): + return threading.Timer(0, callback).start() # invoke callback() + else: + return + + return wrapper + + return decorator + + +@pytest.mark.xfail(reason="Allow tests for CLI to fail") +class TestPyPOTSCLIDoc(unittest.TestCase): + # set up the default arguments + default_arguments = { + "gene_rst": False, + "branch": "main", + "gene_html": False, + "view_doc": False, + "port": 9075, + "cleanup": False, + } + # `pypots-cli doc` must run under the project root dir + os.chdir(PROJECT_ROOT_DIR) + + @pytest.mark.xdist_group(name="cli-doc") + def test_0_gene_rst(self): + arguments = copy(self.default_arguments) + arguments["gene_rst"] = True + args = Namespace(**arguments) + doc_command_factory(args).run() + + logger.info("run again under a non-root dir") + try: + os.chdir(os.path.abspath(os.path.join(PROJECT_ROOT_DIR, "pypots"))) + doc_command_factory(args).run() + except RuntimeError: # try to run under a non-root dir, so RuntimeError will be raised + pass + except Exception as e: # other exceptions will cause an error and result in failed testing + raise e + finally: + os.chdir(PROJECT_ROOT_DIR) + + @pytest.mark.xdist_group(name="cli-doc") + def test_1_gene_html(self): + arguments = copy(self.default_arguments) + arguments["gene_html"] = True + args = Namespace(**arguments) + try: + doc_command_factory(args).run() + except Exception as e: # somehow we have some error when testing on Windows, so just print and pass below + logger.error(e) + + @pytest.mark.xdist_group(name="cli-doc") + @time_out(2, callback_func) # wait for two seconds + def test_2_view_doc(self): + arguments = copy(self.default_arguments) + arguments["view_doc"] = True + args = Namespace(**arguments) + try: + doc_command_factory(args).run() + except Exception as e: # somehow we have some error when testing on Windows, so just print and pass below + logger.error(e) + + @pytest.mark.xdist_group(name="cli-doc") + def test_3_cleanup(self): + arguments = copy(self.default_arguments) + arguments["cleanup"] = True + args = Namespace(**arguments) + doc_command_factory(args).run() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/cli/env.py b/tests/cli/env.py new file mode 100644 index 00000000..36b5b20e --- /dev/null +++ b/tests/cli/env.py @@ -0,0 +1,49 @@ +""" +Test cases for the functions and classes in package `pypots.cli.env`. 
+""" + +# Created by Wenjie Du +# License: GLP-v3 + +import os +import unittest +from argparse import Namespace +from copy import copy + +import pytest + +from pypots.cli.env import env_command_factory +from pypots.utils.logging import logger +from tests.cli.config import PROJECT_ROOT_DIR + + +@pytest.mark.xfail(reason="Allow tests for CLI to fail") +class TestPyPOTSCLIEnv(unittest.TestCase): + # set up the default arguments + default_arguments = { + "install": "optional", + "tool": "conda", + } + + # `pypots-cli env` must run under the project root dir + os.chdir(PROJECT_ROOT_DIR) + + @pytest.mark.xdist_group(name="cli-env") + def test_0_install_with_conda(self): + arguments = copy(self.default_arguments) + arguments["tool"] = "conda" + args = Namespace(**arguments) + try: + env_command_factory(args).run() + except Exception as e: # somehow we have some error when testing on Windows, so just print and pass below + logger.error(e) + + @pytest.mark.xdist_group(name="cli-env") + def test_1_install_with_pip(self): + arguments = copy(self.default_arguments) + arguments["tool"] = "pip" + args = Namespace(**arguments) + try: + env_command_factory(args).run() + except Exception as e: # somehow we have some error when testing on Windows, so just print and pass below + logger.error(e) diff --git a/tests/clustering/__init__.py b/tests/clustering/__init__.py new file mode 100644 index 00000000..f0b4685e --- /dev/null +++ b/tests/clustering/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 diff --git a/tests/clustering/config.py b/tests/clustering/config.py new file mode 100644 index 00000000..aa43d7dd --- /dev/null +++ b/tests/clustering/config.py @@ -0,0 +1,22 @@ +""" +Test configs for clustering models. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +import os + +from tests.global_test_config import ( + DATA, + RESULT_SAVING_DIR, +) + + +EPOCHS = 5 + +TRAIN_SET = {"X": DATA["train_X"]} +VAL_SET = {"X": DATA["val_X"]} +TEST_SET = {"X": DATA["test_X"]} + +RESULT_SAVING_DIR_FOR_CLUSTERING = os.path.join(RESULT_SAVING_DIR, "clustering") diff --git a/tests/clustering/crli.py b/tests/clustering/crli.py new file mode 100644 index 00000000..923911fd --- /dev/null +++ b/tests/clustering/crli.py @@ -0,0 +1,103 @@ +""" +Test cases for CRLI clustering model. 
+""" + +# Created by Wenjie Du +# License: GLP-v3 + + +import os +import unittest + +import pytest + +from pypots.clustering import CRLI +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_rand_index, cal_cluster_purity +from tests.clustering.config import ( + EPOCHS, + TRAIN_SET, + TEST_SET, + RESULT_SAVING_DIR_FOR_CLUSTERING, +) +from tests.global_test_config import ( + DATA, + DEVICE, + check_tb_and_model_checkpoints_existence, +) + + +class TestCRLI(unittest.TestCase): + logger.info("Running tests for a clustering model CRLI...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLUSTERING, "CRLI") + model_save_name = "saved_CRLI_model.pypots" + + # initialize an Adam optimizer + G_optimizer = Adam(lr=0.001, weight_decay=1e-5) + D_optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a CRLI model + crli = CRLI( + n_steps=DATA["n_steps"], + n_features=DATA["n_features"], + n_clusters=DATA["n_classes"], + n_generator_layers=2, + rnn_hidden_size=128, + epochs=EPOCHS, + saving_path=saving_path, + G_optimizer=G_optimizer, + D_optimizer=D_optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="clustering-crli") + def test_0_fit(self): + self.crli.fit(TRAIN_SET) + + @pytest.mark.xdist_group(name="clustering-crli") + def test_1_parameters(self): + assert hasattr(self.crli, "model") and self.crli.model is not None + + assert hasattr(self.crli, "G_optimizer") and self.crli.G_optimizer is not None + assert hasattr(self.crli, "D_optimizer") and self.crli.D_optimizer is not None + + assert hasattr(self.crli, "best_loss") + self.assertNotEqual(self.crli.best_loss, float("inf")) + + assert ( + hasattr(self.crli, "best_model_dict") + and self.crli.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="clustering-crli") + def test_2_cluster(self): + clustering = self.crli.cluster(TEST_SET) + RI = cal_rand_index(clustering, DATA["test_y"]) + CP = cal_cluster_purity(clustering, DATA["test_y"]) + logger.info(f"RI: {RI}\nCP: {CP}") + + @pytest.mark.xdist_group(name="clustering-crli") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.crli) + + # save the trained model into file, and check if the path exists + self.crli.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.crli.load_model(saved_model_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_clustering.py b/tests/clustering/vader.py similarity index 51% rename from tests/test_clustering.py rename to tests/clustering/vader.py index bbd4d014..71a6a91d 100644 --- a/tests/test_clustering.py +++ b/tests/clustering/vader.py @@ -1,5 +1,5 @@ """ -Test cases for clustering models. +Test cases for VaDER clustering model. 
""" # Created by Wenjie Du @@ -12,94 +12,22 @@ import numpy as np import pytest -from pypots.clustering import VaDER, CRLI +from pypots.clustering import VaDER from pypots.optim import Adam from pypots.utils.logging import logger from pypots.utils.metrics import cal_rand_index, cal_cluster_purity +from tests.clustering.config import ( + EPOCHS, + TRAIN_SET, + TEST_SET, + RESULT_SAVING_DIR_FOR_CLUSTERING, +) from tests.global_test_config import ( DATA, - RESULT_SAVING_DIR, + DEVICE, check_tb_and_model_checkpoints_existence, ) -EPOCHS = 5 - -TRAIN_SET = {"X": DATA["train_X"]} -VAL_SET = {"X": DATA["val_X"]} -TEST_SET = {"X": DATA["test_X"]} - -RESULT_SAVING_DIR_FOR_CLUSTERING = os.path.join(RESULT_SAVING_DIR, "clustering") - - -class TestCRLI(unittest.TestCase): - logger.info("Running tests for a clustering model CRLI...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLUSTERING, "CRLI") - model_save_name = "saved_CRLI_model.pypots" - - # initialize an Adam optimizer - G_optimizer = Adam(lr=0.001, weight_decay=1e-5) - D_optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a CRLI model - crli = CRLI( - n_steps=DATA["n_steps"], - n_features=DATA["n_features"], - n_clusters=DATA["n_classes"], - n_generator_layers=2, - rnn_hidden_size=128, - epochs=EPOCHS, - saving_path=saving_path, - G_optimizer=G_optimizer, - D_optimizer=D_optimizer, - ) - - @pytest.mark.xdist_group(name="clustering-crli") - def test_0_fit(self): - self.crli.fit(TRAIN_SET) - - @pytest.mark.xdist_group(name="clustering-crli") - def test_1_parameters(self): - assert hasattr(self.crli, "model") and self.crli.model is not None - - assert hasattr(self.crli, "G_optimizer") and self.crli.G_optimizer is not None - assert hasattr(self.crli, "D_optimizer") and self.crli.D_optimizer is not None - - assert hasattr(self.crli, "best_loss") - self.assertNotEqual(self.crli.best_loss, float("inf")) - - assert ( - hasattr(self.crli, "best_model_dict") - and self.crli.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="clustering-crli") - def test_2_cluster(self): - clustering = self.crli.cluster(TEST_SET) - RI = cal_rand_index(clustering, DATA["test_y"]) - CP = cal_cluster_purity(clustering, DATA["test_y"]) - logger.info(f"RI: {RI}\nCP: {CP}") - - @pytest.mark.xdist_group(name="clustering-crli") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.crli) - - # save the trained model into file, and check if the path exists - self.crli.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.crli.load_model(saved_model_path) - class TestVaDER(unittest.TestCase): logger.info("Running tests for a clustering model Transformer...") @@ -120,8 +48,9 @@ class TestVaDER(unittest.TestCase): d_mu_stddev=5, pretrain_epochs=20, epochs=EPOCHS, - saving_path=saving_path, optimizer=optimizer, + saving_path=saving_path, + device=DEVICE, ) @pytest.mark.xdist_group(name="clustering-vader") diff --git a/tests/data/__init__.py b/tests/data/__init__.py new file mode 100644 index 00000000..f0b4685e --- /dev/null +++ b/tests/data/__init__.py @@ -0,0 +1,6 @@ +""" + +""" 
+ +# Created by Wenjie Du +# License: GLP-v3 diff --git a/tests/test_data.py b/tests/data/lazy_loading_strategy.py similarity index 56% rename from tests/test_data.py rename to tests/data/lazy_loading_strategy.py index 27531098..8db1080c 100644 --- a/tests/test_data.py +++ b/tests/data/lazy_loading_strategy.py @@ -8,31 +8,28 @@ import os import unittest -import h5py import pytest from pypots.classification import BRITS, GRUD +from pypots.data.saving import save_dict_into_h5 from pypots.imputation import SAITS -from tests.global_test_config import DATA, DATA_SAVING_DIR from pypots.utils.logging import logger +from tests.global_test_config import DATA, DATA_SAVING_DIR - -TRAIN_SET = f"{DATA_SAVING_DIR}/train_set.h5" -VAL_SET = f"{DATA_SAVING_DIR}/val_set.h5" -TEST_SET = f"{DATA_SAVING_DIR}/test_set.h5" -IMPUTATION_TRAIN_SET = f"{DATA_SAVING_DIR}/imputation_train_set.h5" -IMPUTATION_VAL_SET = f"{DATA_SAVING_DIR}/imputation_val_set.h5" +TRAIN_SET_NAME = "train_set.h5" +TRAIN_SET_PATH = f"{DATA_SAVING_DIR}/{TRAIN_SET_NAME}" +VAL_SET_NAME = "val_set.h5" +VAL_SET_PATH = f"{DATA_SAVING_DIR}/{VAL_SET_NAME}" +TEST_SET_NAME = "test_set.h5" +TEST_SET_PATH = f"{DATA_SAVING_DIR}/{TEST_SET_NAME}" +IMPUTATION_TRAIN_SET_NAME = "imputation_train_set.h5" +IMPUTATION_TRAIN_SET_PATH = f"{DATA_SAVING_DIR}/{IMPUTATION_TRAIN_SET_NAME}" +IMPUTATION_VAL_SET_NAME = "imputation_val_set.h5" +IMPUTATION_VAL_SET_PATH = f"{DATA_SAVING_DIR}/{IMPUTATION_VAL_SET_NAME}" EPOCHS = 1 -def save_data_set_into_h5(data, path): - with h5py.File(path, "w") as hf: - for i in data.keys(): - tp = int if i == "y" else "float32" - hf.create_dataset(i, data=data[i].astype(tp)) - - class TestLazyLoadingClasses(unittest.TestCase): logger.info("Running tests for Dataset classes with lazy-loading strategy...") @@ -73,53 +70,63 @@ def test_0_save_datasets_into_files(self): # create the dir for saving files os.makedirs(DATA_SAVING_DIR, exist_ok=True) - if not os.path.exists(TRAIN_SET): - save_data_set_into_h5( - {"X": DATA["train_X"], "y": DATA["train_y"].astype(int)}, TRAIN_SET + if not os.path.exists(TRAIN_SET_PATH): + save_dict_into_h5( + {"X": DATA["train_X"], "y": DATA["train_y"].astype(float)}, + DATA_SAVING_DIR, + TRAIN_SET_NAME, ) - if not os.path.exists(VAL_SET): - save_data_set_into_h5( - {"X": DATA["val_X"], "y": DATA["val_y"].astype(int)}, VAL_SET + if not os.path.exists(VAL_SET_PATH): + save_dict_into_h5( + {"X": DATA["val_X"], "y": DATA["val_y"].astype(float)}, + DATA_SAVING_DIR, + VAL_SET_NAME, ) - if not os.path.exists(IMPUTATION_TRAIN_SET): - save_data_set_into_h5({"X": DATA["train_X"]}, IMPUTATION_TRAIN_SET) + if not os.path.exists(IMPUTATION_TRAIN_SET_PATH): + save_dict_into_h5( + {"X": DATA["train_X"]}, DATA_SAVING_DIR, IMPUTATION_TRAIN_SET_NAME + ) - if not os.path.exists(IMPUTATION_VAL_SET): - save_data_set_into_h5( + if not os.path.exists(IMPUTATION_VAL_SET_PATH): + save_dict_into_h5( { "X": DATA["val_X"], "X_intact": DATA["val_X_intact"], "indicating_mask": DATA["val_X_indicating_mask"], }, - IMPUTATION_VAL_SET, + DATA_SAVING_DIR, + IMPUTATION_VAL_SET_NAME, ) - if not os.path.exists(TEST_SET): - save_data_set_into_h5( + if not os.path.exists(TEST_SET_PATH): + save_dict_into_h5( { "X": DATA["test_X"], "X_intact": DATA["test_X_intact"], "indicating_mask": DATA["test_X_indicating_mask"], }, - TEST_SET, + DATA_SAVING_DIR, + TEST_SET_NAME, ) @pytest.mark.xdist_group(name="data-lazy-loading") def test_1_DatasetForMIT_BaseDataset(self): - self.saits.fit(train_set=IMPUTATION_TRAIN_SET, val_set=IMPUTATION_VAL_SET) - _ = 
self.saits.impute(X=TEST_SET) + self.saits.fit( + train_set=IMPUTATION_TRAIN_SET_PATH, val_set=IMPUTATION_VAL_SET_PATH + ) + _ = self.saits.impute(X=TEST_SET_PATH) @pytest.mark.xdist_group(name="data-lazy-loading") def test_2_DatasetForBRITS(self): - self.brits.fit(train_set=TRAIN_SET, val_set=VAL_SET) - _ = self.brits.classify(X=TEST_SET) + self.brits.fit(train_set=TRAIN_SET_PATH, val_set=VAL_SET_PATH) + _ = self.brits.classify(X=TEST_SET_PATH) @pytest.mark.xdist_group(name="data-lazy-loading") def test_3_DatasetForGRUD(self): - self.grud.fit(train_set=TRAIN_SET, val_set=VAL_SET) - _ = self.grud.classify(X=TEST_SET) + self.grud.fit(train_set=TRAIN_SET_PATH, val_set=VAL_SET_PATH) + _ = self.grud.classify(X=TEST_SET_PATH) if __name__ == "__main__": diff --git a/tests/forecasting/__init__.py b/tests/forecasting/__init__.py new file mode 100644 index 00000000..f0b4685e --- /dev/null +++ b/tests/forecasting/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: GPL-v3 diff --git a/tests/test_forecasting.py b/tests/forecasting/bttf.py similarity index 78% rename from tests/test_forecasting.py rename to tests/forecasting/bttf.py index d2e8e14b..8e6946e7 100644 --- a/tests/test_forecasting.py +++ b/tests/forecasting/bttf.py @@ -1,5 +1,5 @@ """ -Test cases for forecasting models. +Test cases for BTTF forecasting model. """ # Created by Wenjie Du @@ -12,12 +12,13 @@ from pypots.forecasting import BTTF from pypots.utils.logging import logger from pypots.utils.metrics import cal_mae +from tests.forecasting.config import ( + TEST_SET, + TEST_SET_INTACT, + N_PRED_STEP, +) from tests.global_test_config import DATA -EPOCHS = 5 -N_PRED_STEP = 4 -TEST_SET = {"X": DATA["test_X"][:, :-N_PRED_STEP]} - class TestBTTF(unittest.TestCase): logger.info("Running tests for a forecasting model BTTF...") @@ -37,8 +38,7 @@ class TestBTTF(unittest.TestCase): @pytest.mark.xdist_group(name="forecasting-bttf") def test_0_forecasting(self): predictions = self.bttf.forecast(TEST_SET) - logger.info(f"prediction shape: {predictions.shape}") - mae = cal_mae(predictions, DATA["test_X_intact"][:, -N_PRED_STEP:]) + mae = cal_mae(predictions, TEST_SET_INTACT["X"][:, -N_PRED_STEP:]) logger.info(f"prediction MAE: {mae}") diff --git a/tests/forecasting/config.py b/tests/forecasting/config.py new file mode 100644 index 00000000..0a2a9e78 --- /dev/null +++ b/tests/forecasting/config.py @@ -0,0 +1,23 @@ +""" +Test configs for forecasting models. +""" + +# Created by Wenjie Du +# License: GPL-v3 + +import os + +from tests.global_test_config import ( + DATA, + RESULT_SAVING_DIR, +) + +EPOCHS = 5 +N_PRED_STEP = 4 + +TRAIN_SET = {"X": DATA["train_X"]} +VAL_SET = {"X": DATA["val_X"]} +TEST_SET = {"X": DATA["test_X"][:, :-N_PRED_STEP]} +TEST_SET_INTACT = {"X": DATA["test_X_intact"]} + +RESULT_SAVING_DIR_FOR_FORECASTING = os.path.join(RESULT_SAVING_DIR, "forecasting") diff --git a/tests/global_test_config.py b/tests/global_test_config.py index f3349483..5e152734 100644 --- a/tests/global_test_config.py +++ b/tests/global_test_config.py @@ -7,7 +7,10 @@ import os +import torch + from pypots.data.generating import gene_incomplete_random_walk_dataset +from pypots.utils.logging import logger # Generate the unified data for testing and cache it first, DATA here is a singleton # Otherwise, file lock will cause bug if running test parallely with pytest-xdist.
@@ -20,6 +23,16 @@ RESULT_SAVING_DIR = "testing_results" +# set DEVICES to None if no cuda device is available, to avoid initialization failed while importing test classes +cuda_devices = [torch.device(i) for i in range(torch.cuda.device_count())] +if len(cuda_devices) > 2: + logger.info("❗️Detected multiple cuda devices, using all of them to run testing.") + DEVICE = cuda_devices +else: + # if having no multiple cuda devices, leave it as None to use the default device + DEVICE = None + + def check_tb_and_model_checkpoints_existence(model): # check the tensorboard file existence saved_files = os.listdir(model.saving_path) diff --git a/tests/imputation/__init__.py b/tests/imputation/__init__.py new file mode 100644 index 00000000..f0b4685e --- /dev/null +++ b/tests/imputation/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 diff --git a/tests/imputation/brits.py b/tests/imputation/brits.py new file mode 100644 index 00000000..bf0a70c3 --- /dev/null +++ b/tests/imputation/brits.py @@ -0,0 +1,104 @@ +""" +Test cases for BRITS imputation model. +""" + +# Created by Wenjie Du +# License: GPL-v3 + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import BRITS +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import ( + DATA, + DEVICE, + check_tb_and_model_checkpoints_existence, +) +from tests.imputation.config import ( + TRAIN_SET, + VAL_SET, + TEST_SET, + RESULT_SAVING_DIR_FOR_IMPUTATION, + EPOCHS, +) + + +class TestBRITS(unittest.TestCase): + logger.info("Running tests for an imputation model BRITS...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "BRITS") + model_save_name = "saved_BRITS_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a BRITS model + brits = BRITS( + DATA["n_steps"], + DATA["n_features"], + 256, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-brits") + def test_0_fit(self): + self.brits.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-brits") + def test_1_impute(self): + imputed_X = self.brits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." 
+ test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"BRITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-brits") + def test_2_parameters(self): + assert hasattr(self.brits, "model") and self.brits.model is not None + + assert hasattr(self.brits, "optimizer") and self.brits.optimizer is not None + + assert hasattr(self.brits, "best_loss") + self.assertNotEqual(self.brits.best_loss, float("inf")) + + assert ( + hasattr(self.brits, "best_model_dict") + and self.brits.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-brits") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.brits) + + # save the trained model into file, and check if the path exists + self.brits.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.brits.load_model(saved_model_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/imputation/config.py b/tests/imputation/config.py new file mode 100644 index 00000000..c225598b --- /dev/null +++ b/tests/imputation/config.py @@ -0,0 +1,25 @@ +""" +Test configs for imputation models. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +import os + +from tests.global_test_config import ( + DATA, + RESULT_SAVING_DIR, +) + +EPOCHS = 5 + +TRAIN_SET = {"X": DATA["train_X"]} +VAL_SET = { + "X": DATA["val_X"], + "X_intact": DATA["val_X_intact"], + "indicating_mask": DATA["val_X_indicating_mask"], +} +TEST_SET = {"X": DATA["test_X"]} + +RESULT_SAVING_DIR_FOR_IMPUTATION = os.path.join(RESULT_SAVING_DIR, "imputation") diff --git a/tests/imputation/gpvae.py b/tests/imputation/gpvae.py new file mode 100644 index 00000000..9c59c5b2 --- /dev/null +++ b/tests/imputation/gpvae.py @@ -0,0 +1,104 @@ +""" +Test cases for GP-VAE imputation model. 
+""" + +# Created by Wenjie Du +# License: GPL-v3 + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import GPVAE +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import ( + DATA, + DEVICE, + check_tb_and_model_checkpoints_existence, +) +from tests.imputation.config import ( + TRAIN_SET, + VAL_SET, + TEST_SET, + RESULT_SAVING_DIR_FOR_IMPUTATION, + EPOCHS, +) + + +class TestGPVAE(unittest.TestCase): + logger.info("Running tests for an imputation model GP-VAE...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "GP-VAE") + model_save_name = "saved_GPVAE_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a GP-VAE model + gp_vae = GPVAE( + DATA["n_steps"], + DATA["n_features"], + 256, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-gpvae") + def test_0_fit(self): + self.gp_vae.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-gpvae") + def test_1_impute(self): + imputed_X = self.gp_vae.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"GP-VAE test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-gpvae") + def test_2_parameters(self): + assert hasattr(self.gp_vae, "model") and self.gp_vae.model is not None + + assert hasattr(self.gp_vae, "optimizer") and self.gp_vae.optimizer is not None + + assert hasattr(self.gp_vae, "best_loss") + self.assertNotEqual(self.gp_vae.best_loss, float("inf")) + + assert ( + hasattr(self.gp_vae, "best_model_dict") + and self.gp_vae.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-gpvae") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.gp_vae) + + # save the trained model into file, and check if the path exists + self.gp_vae.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.gp_vae.load_model(saved_model_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/imputation/locf.py b/tests/imputation/locf.py new file mode 100644 index 00000000..8e54fbe0 --- /dev/null +++ b/tests/imputation/locf.py @@ -0,0 +1,46 @@ +""" +Test cases for LOCF imputation method. 
+""" + +# Created by Wenjie Du +# License: GPL-v3 + + +import unittest + +import numpy as np +import pytest + +from pypots.imputation import LOCF +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import ( + DATA, +) +from tests.imputation.config import ( + TEST_SET, +) + + +class TestLOCF(unittest.TestCase): + logger.info("Running tests for an imputation model LOCF...") + locf = LOCF(nan=0) + + @pytest.mark.xdist_group(name="imputation-locf") + def test_0_impute(self): + test_X_imputed = self.locf.impute(TEST_SET) + assert not np.isnan( + test_X_imputed + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + test_X_imputed, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"LOCF test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-locf") + def test_1_parameters(self): + assert hasattr(self.locf, "nan") and self.locf.nan is not None + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/imputation/mrnn.py b/tests/imputation/mrnn.py new file mode 100644 index 00000000..681a9121 --- /dev/null +++ b/tests/imputation/mrnn.py @@ -0,0 +1,104 @@ +""" +Test cases for MRNN imputation model. +""" + +# Created by Wenjie Du +# License: GPL-v3 + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import MRNN +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import ( + DATA, + DEVICE, + check_tb_and_model_checkpoints_existence, +) +from tests.imputation.config import ( + TRAIN_SET, + VAL_SET, + TEST_SET, + RESULT_SAVING_DIR_FOR_IMPUTATION, + EPOCHS, +) + + +class TestMRNN(unittest.TestCase): + logger.info("Running tests for an imputation model MRNN...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "MRNN") + model_save_name = "saved_MRNN_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a MRNN model + mrnn = MRNN( + DATA["n_steps"], + DATA["n_features"], + 256, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-mrnn") + def test_0_fit(self): + self.mrnn.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-mrnn") + def test_1_impute(self): + imputed_X = self.mrnn.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." 
+ test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"MRNN test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-mrnn") + def test_2_parameters(self): + assert hasattr(self.mrnn, "model") and self.mrnn.model is not None + + assert hasattr(self.mrnn, "optimizer") and self.mrnn.optimizer is not None + + assert hasattr(self.mrnn, "best_loss") + self.assertNotEqual(self.mrnn.best_loss, float("inf")) + + assert ( + hasattr(self.mrnn, "best_model_dict") + and self.mrnn.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-mrnn") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.mrnn) + + # save the trained model into file, and check if the path exists + self.mrnn.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.mrnn.load_model(saved_model_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/imputation/saits.py b/tests/imputation/saits.py new file mode 100644 index 00000000..647e8657 --- /dev/null +++ b/tests/imputation/saits.py @@ -0,0 +1,110 @@ +""" +Test cases for SAITS imputation model. +""" + +# Created by Wenjie Du +# License: GPL-v3 + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import SAITS +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import ( + DATA, + DEVICE, + check_tb_and_model_checkpoints_existence, +) +from tests.imputation.config import ( + TRAIN_SET, + VAL_SET, + TEST_SET, + RESULT_SAVING_DIR_FOR_IMPUTATION, + EPOCHS, +) + + +class TestSAITS(unittest.TestCase): + logger.info("Running tests for an imputation model SAITS...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "SAITS") + model_save_name = "saved_saits_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a SAITS model + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=2, + d_model=256, + d_inner=128, + n_heads=4, + d_k=64, + d_v=64, + dropout=0.1, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-saits") + def test_0_fit(self): + self.saits.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-saits") + def test_1_impute(self): + imputed_X = self.saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." 
+ test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-saits") + def test_2_parameters(self): + assert hasattr(self.saits, "model") and self.saits.model is not None + + assert hasattr(self.saits, "optimizer") and self.saits.optimizer is not None + + assert hasattr(self.saits, "best_loss") + self.assertNotEqual(self.saits.best_loss, float("inf")) + + assert ( + hasattr(self.saits, "best_model_dict") + and self.saits.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-saits") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.saits) + + # save the trained model into file, and check if the path exists + self.saits.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.saits.load_model(saved_model_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/imputation/transformer.py b/tests/imputation/transformer.py new file mode 100644 index 00000000..965b2cf7 --- /dev/null +++ b/tests/imputation/transformer.py @@ -0,0 +1,113 @@ +""" +Test cases for Transformer imputation model. +""" + +# Created by Wenjie Du +# License: GPL-v3 + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import Transformer +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import ( + DATA, + DEVICE, + check_tb_and_model_checkpoints_existence, +) +from tests.imputation.config import ( + TRAIN_SET, + VAL_SET, + TEST_SET, + RESULT_SAVING_DIR_FOR_IMPUTATION, + EPOCHS, +) + + +class TestTransformer(unittest.TestCase): + logger.info("Running tests for an imputation model Transformer...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "Transformer") + model_save_name = "saved_transformer_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a Transformer model + transformer = Transformer( + DATA["n_steps"], + DATA["n_features"], + n_layers=2, + d_model=256, + d_inner=128, + n_heads=4, + d_k=64, + d_v=64, + dropout=0.1, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-transformer") + def test_0_fit(self): + self.transformer.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-transformer") + def test_1_impute(self): + imputed_X = self.transformer.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." 
+ test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"Transformer test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-transformer") + def test_2_parameters(self): + assert hasattr(self.transformer, "model") and self.transformer.model is not None + + assert ( + hasattr(self.transformer, "optimizer") + and self.transformer.optimizer is not None + ) + + assert hasattr(self.transformer, "best_loss") + self.assertNotEqual(self.transformer.best_loss, float("inf")) + + assert ( + hasattr(self.transformer, "best_model_dict") + and self.transformer.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-transformer") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.transformer) + + # save the trained model into file, and check if the path exists + self.transformer.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.transformer.load_model(saved_model_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/imputation/usgan.py b/tests/imputation/usgan.py new file mode 100644 index 00000000..c91a17a1 --- /dev/null +++ b/tests/imputation/usgan.py @@ -0,0 +1,111 @@ +""" +Test cases for US-GAN imputation model. +""" + +# Created by Wenjie Du +# License: GPL-v3 + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import USGAN +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import ( + DATA, + DEVICE, + check_tb_and_model_checkpoints_existence, +) +from tests.imputation.config import ( + TRAIN_SET, + VAL_SET, + TEST_SET, + RESULT_SAVING_DIR_FOR_IMPUTATION, + EPOCHS, +) + + +class TestUSGAN(unittest.TestCase): + logger.info("Running tests for an imputation model US-GAN...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "US-GAN") + model_save_name = "saved_USGAN_model.pypots" + + # initialize an Adam optimizer + G_optimizer = Adam(lr=0.001, weight_decay=1e-5) + D_optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a US-GAN model + us_gan = USGAN( + DATA["n_steps"], + DATA["n_features"], + 256, + epochs=EPOCHS, + saving_path=saving_path, + G_optimizer=G_optimizer, + D_optimizer=D_optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-usgan") + def test_0_fit(self): + self.us_gan.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-usgan") + def test_1_impute(self): + imputed_X = self.us_gan.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." 
+        test_MAE = cal_mae(
+            imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"]
+        )
+        logger.info(f"US-GAN test_MAE: {test_MAE}")
+
+    @pytest.mark.xdist_group(name="imputation-usgan")
+    def test_2_parameters(self):
+        assert hasattr(self.us_gan, "model") and self.us_gan.model is not None
+
+        assert (
+            hasattr(self.us_gan, "G_optimizer") and self.us_gan.G_optimizer is not None
+        )
+        assert (
+            hasattr(self.us_gan, "D_optimizer") and self.us_gan.D_optimizer is not None
+        )
+
+        assert hasattr(self.us_gan, "best_loss")
+        self.assertNotEqual(self.us_gan.best_loss, float("inf"))
+
+        assert (
+            hasattr(self.us_gan, "best_model_dict")
+            and self.us_gan.best_model_dict is not None
+        )
+
+    @pytest.mark.xdist_group(name="imputation-usgan")
+    def test_3_saving_path(self):
+        # whether the root saving dir exists, which should be created by save_log_into_tb_file
+        assert os.path.exists(
+            self.saving_path
+        ), f"file {self.saving_path} does not exist"
+
+        # check if the tensorboard file and model checkpoints exist
+        check_tb_and_model_checkpoints_existence(self.us_gan)
+
+        # save the trained model into file, and check if the path exists
+        self.us_gan.save_model(
+            saving_dir=self.saving_path, file_name=self.model_save_name
+        )
+
+        # test loading the saved model, not necessary, but need to test
+        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+        self.us_gan.load_model(saved_model_path)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/optim/__init__.py b/tests/optim/__init__.py
new file mode 100644
index 00000000..f0b4685e
--- /dev/null
+++ b/tests/optim/__init__.py
@@ -0,0 +1,6 @@
+"""
+Test cases for optimizers.
+"""
+
+# Created by Wenjie Du
+# License: GPL-v3
diff --git a/tests/optim/adadelta.py b/tests/optim/adadelta.py
new file mode 100644
index 00000000..b69e5ea4
--- /dev/null
+++ b/tests/optim/adadelta.py
@@ -0,0 +1,56 @@
+"""
+Test cases for the optimizer Adadelta.
+"""
+
+# Created by Wenjie Du
+# License: GPL-v3
+
+import unittest
+
+import numpy as np
+import pytest
+
+from pypots.imputation import SAITS
+from pypots.optim import Adadelta
+from pypots.utils.logging import logger
+from pypots.utils.metrics import cal_mae
+from tests.global_test_config import DATA
+from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET
+
+
+class TestAdadelta(unittest.TestCase):
+    logger.info("Running tests for Adadelta...")
+
+    # initialize an Adadelta optimizer
+    adadelta = Adadelta(lr=0.001, weight_decay=1e-5)
+
+    # initialize a SAITS model for testing DatasetForMIT and BaseDataset
+    saits = SAITS(
+        DATA["n_steps"],
+        DATA["n_features"],
+        n_layers=1,
+        d_model=128,
+        d_inner=64,
+        n_heads=2,
+        d_k=64,
+        d_v=64,
+        dropout=0.1,
+        optimizer=adadelta,
+        epochs=EPOCHS,
+    )
+
+    @pytest.mark.xdist_group(name="optim-adadelta")
+    def test_0_fit(self):
+        self.saits.fit(TRAIN_SET, VAL_SET)
+        imputed_X = self.saits.impute(TEST_SET)
+        assert not np.isnan(
+            imputed_X
+        ).any(), "Output still has missing values after running impute()."
+        test_MAE = cal_mae(
+            imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"]
+        )
+        logger.info(f"SAITS test_MAE: {test_MAE}")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/optim/adagrad.py b/tests/optim/adagrad.py
new file mode 100644
index 00000000..21b4696a
--- /dev/null
+++ b/tests/optim/adagrad.py
@@ -0,0 +1,56 @@
+"""
+Test cases for the optimizer Adagrad.
+""" + +# Created by Wenjie Du +# License: GLP-v3 + +import unittest + +import numpy as np +import pytest + +from pypots.imputation import SAITS +from pypots.optim import Adagrad +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import DATA +from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET + + +class TestAdagrad(unittest.TestCase): + logger.info("Running tests for Adagrad...") + + # initialize an Adagrad optimizer + adagrad = Adagrad(lr=0.001, weight_decay=1e-5) + + # initialize a SAITS model for testing DatasetForMIT and BaseDataset + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adagrad, + epochs=EPOCHS, + ) + + @pytest.mark.xdist_group(name="optim-adagrad") + def test_0_fit(self): + self.saits.fit(TRAIN_SET, VAL_SET) + imputed_X = self.saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/optim/adam.py b/tests/optim/adam.py new file mode 100644 index 00000000..448f92b9 --- /dev/null +++ b/tests/optim/adam.py @@ -0,0 +1,56 @@ +""" +Test cases for the optimizer Adam. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +import unittest + +import numpy as np +import pytest + +from pypots.imputation import SAITS +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import DATA +from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET + + +class TestAdam(unittest.TestCase): + logger.info("Running tests for Adam...") + + # initialize an Adam optimizer + adam = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a SAITS model for testing DatasetForMIT and BaseDataset + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adam, + epochs=EPOCHS, + ) + + @pytest.mark.xdist_group(name="optim-adam") + def test_0_fit(self): + self.saits.fit(TRAIN_SET, VAL_SET) + imputed_X = self.saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/optim/adamw.py b/tests/optim/adamw.py new file mode 100644 index 00000000..a7941f43 --- /dev/null +++ b/tests/optim/adamw.py @@ -0,0 +1,56 @@ +""" +Test cases for the optimizer AdamW. 
+""" + +# Created by Wenjie Du +# License: GLP-v3 + +import unittest + +import numpy as np +import pytest + +from pypots.imputation import SAITS +from pypots.optim import AdamW +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import DATA +from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET + + +class TestAdamW(unittest.TestCase): + logger.info("Running tests for AdamW...") + + # initialize an AdamW optimizer + adamw = AdamW(lr=0.001, weight_decay=1e-5) + + # initialize a SAITS model for testing DatasetForMIT and BaseDataset + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adamw, + epochs=EPOCHS, + ) + + @pytest.mark.xdist_group(name="optim-adamw") + def test_0_fit(self): + self.saits.fit(TRAIN_SET, VAL_SET) + imputed_X = self.saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/optim/config.py b/tests/optim/config.py new file mode 100644 index 00000000..a0391027 --- /dev/null +++ b/tests/optim/config.py @@ -0,0 +1,19 @@ +""" +Test configs for optimizers. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from tests.global_test_config import DATA + +TRAIN_SET = {"X": DATA["train_X"]} +VAL_SET = { + "X": DATA["val_X"], + "X_intact": DATA["val_X_intact"], + "indicating_mask": DATA["val_X_indicating_mask"], +} +TEST_SET = {"X": DATA["test_X"]} + + +EPOCHS = 1 diff --git a/tests/optim/rmsprop.py b/tests/optim/rmsprop.py new file mode 100644 index 00000000..1fe61a0d --- /dev/null +++ b/tests/optim/rmsprop.py @@ -0,0 +1,56 @@ +""" +Test cases for the optimizer RMSprop. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +import unittest + +import numpy as np +import pytest + +from pypots.imputation import SAITS +from pypots.optim import RMSprop +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import DATA +from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET + + +class TestRMSprop(unittest.TestCase): + logger.info("Running tests for RMSprop...") + + # initialize a RMSprop optimizer + rmsprop = RMSprop(lr=0.001, weight_decay=1e-5) + + # initialize a SAITS model for testing DatasetForMIT and BaseDataset + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=rmsprop, + epochs=EPOCHS, + ) + + @pytest.mark.xdist_group(name="optim-rmsprop") + def test_0_fit(self): + self.saits.fit(TRAIN_SET, VAL_SET) + imputed_X = self.saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/optim/sgd.py b/tests/optim/sgd.py new file mode 100644 index 00000000..4b1c1998 --- /dev/null +++ b/tests/optim/sgd.py @@ -0,0 +1,56 @@ +""" +Test cases for the optimizer SGD. 
+""" + +# Created by Wenjie Du +# License: GLP-v3 + +import unittest + +import numpy as np +import pytest + +from pypots.imputation import SAITS +from pypots.optim import SGD +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import DATA +from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET + + +class TestSGD(unittest.TestCase): + logger.info("Running tests for SGD...") + + # initialize a SGD optimizer + sgd = SGD(lr=0.001, weight_decay=1e-5) + + # initialize a SAITS model for testing DatasetForMIT and BaseDataset + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=sgd, + epochs=EPOCHS, + ) + + @pytest.mark.xdist_group(name="optim-sgd") + def test_0_fit(self): + self.saits.fit(TRAIN_SET, VAL_SET) + imputed_X = self.saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_classification.py b/tests/test_classification.py deleted file mode 100644 index 2ef9c6d1..00000000 --- a/tests/test_classification.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -Test cases for classification models. -""" - -# Created by Wenjie Du -# License: GLP-v3 - -import os -import unittest - -import pytest - -from pypots.classification import BRITS, GRUD, Raindrop -from pypots.optim import Adam -from pypots.utils.logging import logger -from pypots.utils.metrics import cal_binary_classification_metrics -from tests.global_test_config import ( - DATA, - RESULT_SAVING_DIR, - check_tb_and_model_checkpoints_existence, -) - -EPOCHS = 5 - -TRAIN_SET = {"X": DATA["train_X"], "y": DATA["train_y"]} -VAL_SET = {"X": DATA["val_X"], "y": DATA["val_y"]} -TEST_SET = {"X": DATA["test_X"]} - -RESULT_SAVING_DIR_FOR_CLASSIFICATION = os.path.join(RESULT_SAVING_DIR, "classification") - - -class TestBRITS(unittest.TestCase): - logger.info("Running tests for a classification model BRITS...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "BRITS") - model_save_name = "saved_BRITS_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a BRITS model - brits = BRITS( - DATA["n_steps"], - DATA["n_features"], - n_classes=DATA["n_classes"], - rnn_hidden_size=256, - epochs=EPOCHS, - saving_path=saving_path, - model_saving_strategy="better", - optimizer=optimizer, - ) - - @pytest.mark.xdist_group(name="classification-brits") - def test_0_fit(self): - self.brits.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="classification-brits") - def test_1_classify(self): - predictions = self.brits.classify(TEST_SET) - metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) - logger.info( - f'ROC_AUC: {metrics["roc_auc"]}, \n' - f'PR_AUC: {metrics["pr_auc"]},\n' - f'F1: {metrics["f1"]},\n' - f'Precision: {metrics["precision"]},\n' - f'Recall: {metrics["recall"]},\n' - ) - assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" - - @pytest.mark.xdist_group(name="classification-brits") - def test_2_parameters(self): - assert hasattr(self.brits, "model") and self.brits.model is not None - - assert hasattr(self.brits, "optimizer") and self.brits.optimizer is not None - - 
assert hasattr(self.brits, "best_loss") - self.assertNotEqual(self.brits.best_loss, float("inf")) - - assert ( - hasattr(self.brits, "best_model_dict") - and self.brits.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="classification-brits") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.brits) - - # save the trained model into file, and check if the path exists - self.brits.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.brits.load_model(saved_model_path) - - -class TestGRUD(unittest.TestCase): - logger.info("Running tests for a classification model GRUD...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "GRUD") - model_save_name = "saved_GRUD_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a GRUD model - grud = GRUD( - DATA["n_steps"], - DATA["n_features"], - n_classes=DATA["n_classes"], - rnn_hidden_size=256, - epochs=EPOCHS, - saving_path=saving_path, - optimizer=optimizer, - ) - - @pytest.mark.xdist_group(name="classification-grud") - def test_0_fit(self): - self.grud.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="classification-grud") - def test_1_classify(self): - predictions = self.grud.classify(TEST_SET) - metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) - logger.info( - f'ROC_AUC: {metrics["roc_auc"]}, \n' - f'PR_AUC: {metrics["pr_auc"]},\n' - f'F1: {metrics["f1"]},\n' - f'Precision: {metrics["precision"]},\n' - f'Recall: {metrics["recall"]},\n' - ) - assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" - - @pytest.mark.xdist_group(name="classification-grud") - def test_2_parameters(self): - assert hasattr(self.grud, "model") and self.grud.model is not None - - assert hasattr(self.grud, "optimizer") and self.grud.optimizer is not None - - assert hasattr(self.grud, "best_loss") - self.assertNotEqual(self.grud.best_loss, float("inf")) - - assert ( - hasattr(self.grud, "best_model_dict") - and self.grud.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="classification-grud") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.grud) - - # save the trained model into file, and check if the path exists - self.grud.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.grud.load_model(saved_model_path) - - -class TestRaindrop(unittest.TestCase): - logger.info("Running tests for a classification model Raindrop...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "Raindrop") - model_save_name = "saved_Raindrop_model.pypots" - - # initialize a Raindrop model - 
raindrop = Raindrop( - DATA["n_steps"], - DATA["n_features"], - DATA["n_classes"], - n_layers=2, - d_model=DATA["n_features"] * 4, - d_inner=256, - n_heads=2, - dropout=0.3, - d_static=0, - aggregation="mean", - sensor_wise_mask=False, - static=False, - epochs=EPOCHS, - saving_path=saving_path, - ) - - @pytest.mark.xdist_group(name="classification-raindrop") - def test_0_fit(self): - self.raindrop.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="classification-raindrop") - def test_1_classify(self): - predictions = self.raindrop.classify(TEST_SET) - metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) - logger.info( - f'ROC_AUC: {metrics["roc_auc"]}, \n' - f'PR_AUC: {metrics["pr_auc"]},\n' - f'F1: {metrics["f1"]},\n' - f'Precision: {metrics["precision"]},\n' - f'Recall: {metrics["recall"]},\n' - ) - assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" - - @pytest.mark.xdist_group(name="classification-raindrop") - def test_2_parameters(self): - assert hasattr(self.raindrop, "model") and self.raindrop.model is not None - - assert ( - hasattr(self.raindrop, "optimizer") and self.raindrop.optimizer is not None - ) - - assert hasattr(self.raindrop, "best_loss") - self.assertNotEqual(self.raindrop.best_loss, float("inf")) - - assert ( - hasattr(self.raindrop, "best_model_dict") - and self.raindrop.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="classification-raindrop") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.raindrop) - - # save the trained model into file, and check if the path exists - self.raindrop.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.raindrop.load_model(saved_model_path) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_cli.py b/tests/test_cli.py deleted file mode 100644 index 4e9e9927..00000000 --- a/tests/test_cli.py +++ /dev/null @@ -1,189 +0,0 @@ -""" -Test cases for the functions and classes in package `pypots.cli`. 
-""" - -# Created by Wenjie Du -# License: GLP-v3 - -import os -import threading -import unittest -from argparse import Namespace -from copy import copy - -import pytest - -from pypots.cli.dev import dev_command_factory -from pypots.cli.doc import doc_command_factory -from pypots.cli.env import env_command_factory -from pypots.utils.logging import logger - -PROJECT_ROOT_DIR = os.path.abspath(os.path.join(os.path.abspath(__file__), "../..")) - - -def callback_func(): - raise TimeoutError("Time out.") - - -def time_out(interval, callback): - def decorator(func): - def wrapper(*args, **kwargs): - t = threading.Thread(target=func, args=args, kwargs=kwargs) - t.setDaemon(True) - t.start() - t.join(interval) # wait for interval seconds - if t.is_alive(): - return threading.Timer(0, callback).start() # invoke callback() - else: - return - - return wrapper - - return decorator - - -@pytest.mark.xfail(reason="Allow tests for CLI to fail") -class TestPyPOTSCLIDev(unittest.TestCase): - # set up the default arguments - default_arguments = { - "build": False, - "cleanup": False, - "run_tests": False, - "k": None, - "show_coverage": False, - "lint_code": False, - } - # `pypots-cli dev` must run under the project root dir - os.chdir(PROJECT_ROOT_DIR) - - @pytest.mark.xdist_group(name="cli-dev") - def test_0_build(self): - arguments = copy(self.default_arguments) - arguments["build"] = True - args = Namespace(**arguments) - dev_command_factory(args).run() - - @pytest.mark.xdist_group(name="cli-dev") - def test_1_run_tests(self): - arguments = copy(self.default_arguments) - arguments["run_tests"] = True - arguments["k"] = "try_to_find_a_non_existing_test_case" - args = Namespace(**arguments) - try: - dev_command_factory(args).run() - except RuntimeError: # try to find a non-existing test case, so RuntimeError will be raised - pass - except Exception as e: # other exceptions will cause an error and result in failed testing - raise e - - # Don't test --lint-code because Black will reformat the code and cause error when generating the coverage report - # @pytest.mark.xdist_group(name="cli-dev") - # def test_2_lint_code(self): - # arguments = copy(self.default_arguments) - # arguments["lint_code"] = True - # args = Namespace(**arguments) - # dev_command_factory(args).run() - - @pytest.mark.xdist_group(name="cli-dev") - def test_3_cleanup(self): - arguments = copy(self.default_arguments) - arguments["cleanup"] = True - args = Namespace(**arguments) - dev_command_factory(args).run() - - -@pytest.mark.xfail(reason="Allow tests for CLI to fail") -class TestPyPOTSCLIDoc(unittest.TestCase): - # set up the default arguments - default_arguments = { - "gene_rst": False, - "branch": "main", - "gene_html": False, - "view_doc": False, - "port": 9075, - "cleanup": False, - } - # `pypots-cli doc` must run under the project root dir - os.chdir(PROJECT_ROOT_DIR) - - @pytest.mark.xdist_group(name="cli-doc") - def test_0_gene_rst(self): - arguments = copy(self.default_arguments) - arguments["gene_rst"] = True - args = Namespace(**arguments) - doc_command_factory(args).run() - - logger.info("run again under a non-root dir") - try: - os.chdir(os.path.abspath(os.path.join(PROJECT_ROOT_DIR, "pypots"))) - doc_command_factory(args).run() - except RuntimeError: # try to run under a non-root dir, so RuntimeError will be raised - pass - except Exception as e: # other exceptions will cause an error and result in failed testing - raise e - finally: - os.chdir(PROJECT_ROOT_DIR) - - @pytest.mark.xdist_group(name="cli-doc") - def 
test_1_gene_html(self): - arguments = copy(self.default_arguments) - arguments["gene_html"] = True - args = Namespace(**arguments) - try: - doc_command_factory(args).run() - except Exception as e: # somehow we have some error when testing on Windows, so just print and pass below - logger.error(e) - - @pytest.mark.xdist_group(name="cli-doc") - @time_out(2, callback_func) # wait for two seconds - def test_2_view_doc(self): - arguments = copy(self.default_arguments) - arguments["view_doc"] = True - args = Namespace(**arguments) - try: - doc_command_factory(args).run() - except Exception as e: # somehow we have some error when testing on Windows, so just print and pass below - logger.error(e) - - @pytest.mark.xdist_group(name="cli-doc") - def test_3_cleanup(self): - arguments = copy(self.default_arguments) - arguments["cleanup"] = True - args = Namespace(**arguments) - doc_command_factory(args).run() - - -@pytest.mark.xfail(reason="Allow tests for CLI to fail") -class TestPyPOTSCLIEnv(unittest.TestCase): - # set up the default arguments - default_arguments = { - "install": "optional", - "tool": "conda", - } - - # `pypots-cli env` must run under the project root dir - os.chdir(PROJECT_ROOT_DIR) - - @pytest.mark.xdist_group(name="cli-env") - def test_0_install_with_conda(self): - arguments = copy(self.default_arguments) - arguments["tool"] = "conda" - args = Namespace(**arguments) - try: - env_command_factory(args).run() - except Exception as e: # somehow we have some error when testing on Windows, so just print and pass below - logger.error(e) - - @pytest.mark.xdist_group(name="cli-env") - def test_1_install_with_pip(self): - arguments = copy(self.default_arguments) - arguments["tool"] = "pip" - args = Namespace(**arguments) - try: - env_command_factory(args).run() - except Exception as e: # somehow we have some error when testing on Windows, so just print and pass below - logger.error(e) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_imputation.py b/tests/test_imputation.py deleted file mode 100644 index 64a0b1ff..00000000 --- a/tests/test_imputation.py +++ /dev/null @@ -1,503 +0,0 @@ -""" -Test cases for imputation models. 
-""" - -# Created by Wenjie Du -# License: GPL-v3 - - -import os.path -import unittest - -import numpy as np -import pytest - -from pypots.imputation import ( - SAITS, - Transformer, - USGAN, - GPVAE, - BRITS, - MRNN, - LOCF, -) -from pypots.optim import Adam -from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae -from tests.global_test_config import ( - DATA, - RESULT_SAVING_DIR, - check_tb_and_model_checkpoints_existence, -) - -EPOCH = 5 - -TRAIN_SET = {"X": DATA["train_X"]} -VAL_SET = { - "X": DATA["val_X"], - "X_intact": DATA["val_X_intact"], - "indicating_mask": DATA["val_X_indicating_mask"], -} -TEST_SET = {"X": DATA["test_X"]} - -RESULT_SAVING_DIR_FOR_IMPUTATION = os.path.join(RESULT_SAVING_DIR, "imputation") - - -class TestSAITS(unittest.TestCase): - logger.info("Running tests for an imputation model SAITS...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "SAITS") - model_save_name = "saved_saits_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a SAITS model - saits = SAITS( - DATA["n_steps"], - DATA["n_features"], - n_layers=2, - d_model=256, - d_inner=128, - n_heads=4, - d_k=64, - d_v=64, - dropout=0.1, - epochs=EPOCH, - saving_path=saving_path, - optimizer=optimizer, - ) - - @pytest.mark.xdist_group(name="imputation-saits") - def test_0_fit(self): - self.saits.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="imputation-saits") - def test_1_impute(self): - imputed_X = self.saits.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"SAITS test_MAE: {test_MAE}") - - @pytest.mark.xdist_group(name="imputation-saits") - def test_2_parameters(self): - assert hasattr(self.saits, "model") and self.saits.model is not None - - assert hasattr(self.saits, "optimizer") and self.saits.optimizer is not None - - assert hasattr(self.saits, "best_loss") - self.assertNotEqual(self.saits.best_loss, float("inf")) - - assert ( - hasattr(self.saits, "best_model_dict") - and self.saits.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="imputation-saits") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.saits) - - # save the trained model into file, and check if the path exists - self.saits.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.saits.load_model(saved_model_path) - - -class TestTransformer(unittest.TestCase): - logger.info("Running tests for an imputation model Transformer...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "Transformer") - model_save_name = "saved_transformer_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a Transformer model - transformer = Transformer( - DATA["n_steps"], - DATA["n_features"], - n_layers=2, - d_model=256, - d_inner=128, - n_heads=4, - d_k=64, - 
d_v=64, - dropout=0.1, - epochs=EPOCH, - saving_path=saving_path, - optimizer=optimizer, - ) - - @pytest.mark.xdist_group(name="imputation-transformer") - def test_0_fit(self): - self.transformer.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="imputation-transformer") - def test_1_impute(self): - imputed_X = self.transformer.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"Transformer test_MAE: {test_MAE}") - - @pytest.mark.xdist_group(name="imputation-transformer") - def test_2_parameters(self): - assert hasattr(self.transformer, "model") and self.transformer.model is not None - - assert ( - hasattr(self.transformer, "optimizer") - and self.transformer.optimizer is not None - ) - - assert hasattr(self.transformer, "best_loss") - self.assertNotEqual(self.transformer.best_loss, float("inf")) - - assert ( - hasattr(self.transformer, "best_model_dict") - and self.transformer.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="imputation-transformer") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.transformer) - - # save the trained model into file, and check if the path exists - self.transformer.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.transformer.load_model(saved_model_path) - - -class TestUSGAN(unittest.TestCase): - logger.info("Running tests for an imputation model US-GAN...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "US-GAN") - model_save_name = "saved_USGAN_model.pypots" - - # initialize an Adam optimizer - G_optimizer = Adam(lr=0.001, weight_decay=1e-5) - D_optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a US-GAN model - us_gan = USGAN( - DATA["n_steps"], - DATA["n_features"], - 256, - epochs=EPOCH, - saving_path=saving_path, - G_optimizer=G_optimizer, - D_optimizer=D_optimizer, - ) - - @pytest.mark.xdist_group(name="imputation-usgan") - def test_0_fit(self): - self.us_gan.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="imputation-usgan") - def test_1_impute(self): - imputed_X = self.us_gan.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"US-GAN test_MAE: {test_MAE}") - - @pytest.mark.xdist_group(name="imputation-usgan") - def test_2_parameters(self): - assert hasattr(self.us_gan, "model") and self.us_gan.model is not None - - assert ( - hasattr(self.us_gan, "G_optimizer") and self.us_gan.G_optimizer is not None - ) - assert ( - hasattr(self.us_gan, "D_optimizer") and self.us_gan.D_optimizer is not None - ) - - assert hasattr(self.us_gan, "best_loss") - self.assertNotEqual(self.us_gan.best_loss, float("inf")) - - assert ( - hasattr(self.us_gan, "best_model_dict") - and self.us_gan.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="imputation-usgan") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.us_gan) - - # save the trained model into file, and check if the path exists - self.us_gan.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.us_gan.load_model(saved_model_path) - - -class TestGPVAE(unittest.TestCase): - logger.info("Running tests for an imputation model GP-VAE...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "GP-VAE") - model_save_name = "saved_GPVAE_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a GP-VAE model - gp_vae = GPVAE( - DATA["n_steps"], - DATA["n_features"], - 256, - epochs=EPOCH, - saving_path=saving_path, - optimizer=optimizer, - ) - - @pytest.mark.xdist_group(name="imputation-gpvae") - def test_0_fit(self): - self.gp_vae.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="imputation-gpvae") - def test_1_impute(self): - imputed_X = self.gp_vae.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"GP-VAE test_MAE: {test_MAE}") - - @pytest.mark.xdist_group(name="imputation-gpvae") - def test_2_parameters(self): - assert hasattr(self.gp_vae, "model") and self.gp_vae.model is not None - - assert hasattr(self.gp_vae, "optimizer") and self.gp_vae.optimizer is not None - - assert hasattr(self.gp_vae, "best_loss") - self.assertNotEqual(self.gp_vae.best_loss, float("inf")) - - assert ( - hasattr(self.gp_vae, "best_model_dict") - and self.gp_vae.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="imputation-GPVAE") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.gp_vae) - - # save the trained model into file, and check if the path exists - self.gp_vae.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.gp_vae.load_model(saved_model_path) - - -class TestBRITS(unittest.TestCase): - logger.info("Running tests for an imputation model BRITS...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "BRITS") - model_save_name = "saved_BRITS_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a BRITS model - brits = BRITS( - DATA["n_steps"], - DATA["n_features"], - 256, - epochs=EPOCH, - saving_path=saving_path, - optimizer=optimizer, - ) - - @pytest.mark.xdist_group(name="imputation-brits") - def test_0_fit(self): - self.brits.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="imputation-brits") - def test_1_impute(self): - imputed_X = self.brits.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"BRITS test_MAE: {test_MAE}") - - @pytest.mark.xdist_group(name="imputation-brits") - def test_2_parameters(self): - assert hasattr(self.brits, "model") and self.brits.model is not None - - assert hasattr(self.brits, "optimizer") and self.brits.optimizer is not None - - assert hasattr(self.brits, "best_loss") - self.assertNotEqual(self.brits.best_loss, float("inf")) - - assert ( - hasattr(self.brits, "best_model_dict") - and self.brits.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="imputation-brits") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.brits) - - # save the trained model into file, and check if the path exists - self.brits.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.brits.load_model(saved_model_path) - - -class TestMRNN(unittest.TestCase): - logger.info("Running tests for an imputation model MRNN...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "MRNN") - model_save_name = "saved_MRNN_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a MRNN model - mrnn = MRNN( - DATA["n_steps"], - DATA["n_features"], - 256, - epochs=EPOCH, - saving_path=saving_path, - optimizer=optimizer, - ) - - @pytest.mark.xdist_group(name="imputation-mrnn") - def test_0_fit(self): - self.mrnn.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="imputation-mrnn") - def test_1_impute(self): - imputed_X = self.mrnn.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"MRNN test_MAE: {test_MAE}") - - @pytest.mark.xdist_group(name="imputation-mrnn") - def test_2_parameters(self): - assert hasattr(self.mrnn, "model") and self.mrnn.model is not None - - assert hasattr(self.mrnn, "optimizer") and self.mrnn.optimizer is not None - - assert hasattr(self.mrnn, "best_loss") - self.assertNotEqual(self.mrnn.best_loss, float("inf")) - - assert ( - hasattr(self.mrnn, "best_model_dict") - and self.mrnn.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="imputation-mrnn") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.mrnn) - - # save the trained model into file, and check if the path exists - self.mrnn.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.mrnn.load_model(saved_model_path) - - -class TestLOCF(unittest.TestCase): - logger.info("Running tests for an imputation model LOCF...") - locf = LOCF(nan=0) - - @pytest.mark.xdist_group(name="imputation-locf") - def test_0_impute(self): - test_X_imputed = self.locf.impute(TEST_SET) - assert not np.isnan( - test_X_imputed - ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( - test_X_imputed, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"LOCF test_MAE: {test_MAE}") - - @pytest.mark.xdist_group(name="imputation-locf") - def test_1_parameters(self): - assert hasattr(self.locf, "nan") and self.locf.nan is not None - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_optim.py b/tests/test_optim.py deleted file mode 100644 index 9be096fb..00000000 --- a/tests/test_optim.py +++ /dev/null @@ -1,244 +0,0 @@ -""" -Test cases for optimizers. 
-""" - -# Created by Wenjie Du -# License: GLP-v3 - -import unittest - -import h5py -import numpy as np -import pytest - -from pypots.imputation import SAITS -from pypots.optim import Adam, AdamW, Adagrad, Adadelta, SGD, RMSprop -from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae -from tests.global_test_config import DATA - -TRAIN_SET = {"X": DATA["train_X"]} -VAL_SET = { - "X": DATA["val_X"], - "X_intact": DATA["val_X_intact"], - "indicating_mask": DATA["val_X_indicating_mask"], -} -TEST_SET = {"X": DATA["test_X"]} - - -EPOCHS = 3 - - -def save_data_set_into_h5(data, path): - with h5py.File(path, "w") as hf: - for i in data.keys(): - tp = int if i == "y" else "float32" - hf.create_dataset(i, data=data[i].astype(tp)) - - -class TestAdam(unittest.TestCase): - logger.info("Running tests for Adam...") - - # initialize an Adam optimizer - adam = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a SAITS model for testing DatasetForMIT and BaseDataset - saits = SAITS( - DATA["n_steps"], - DATA["n_features"], - n_layers=1, - d_model=128, - d_inner=64, - n_heads=2, - d_k=64, - d_v=64, - dropout=0.1, - optimizer=adam, - epochs=EPOCHS, - ) - - @pytest.mark.xdist_group(name="optim-adam") - def test_0_fit(self): - self.saits.fit(TRAIN_SET, VAL_SET) - imputed_X = self.saits.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"SAITS test_MAE: {test_MAE}") - - -class TestAdamW(unittest.TestCase): - logger.info("Running tests for AdamW...") - - # initialize an AdamW optimizer - adamw = AdamW(lr=0.001, weight_decay=1e-5) - - # initialize a SAITS model for testing DatasetForMIT and BaseDataset - saits = SAITS( - DATA["n_steps"], - DATA["n_features"], - n_layers=1, - d_model=128, - d_inner=64, - n_heads=2, - d_k=64, - d_v=64, - dropout=0.1, - optimizer=adamw, - epochs=EPOCHS, - ) - - @pytest.mark.xdist_group(name="optim-adamw") - def test_0_fit(self): - self.saits.fit(TRAIN_SET, VAL_SET) - imputed_X = self.saits.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"SAITS test_MAE: {test_MAE}") - - -class TestAdagrad(unittest.TestCase): - logger.info("Running tests for Adagrad...") - - # initialize an Adagrad optimizer - adagrad = Adagrad(lr=0.001, weight_decay=1e-5) - - # initialize a SAITS model for testing DatasetForMIT and BaseDataset - saits = SAITS( - DATA["n_steps"], - DATA["n_features"], - n_layers=1, - d_model=128, - d_inner=64, - n_heads=2, - d_k=64, - d_v=64, - dropout=0.1, - optimizer=adagrad, - epochs=EPOCHS, - ) - - @pytest.mark.xdist_group(name="optim-adagrad") - def test_0_fit(self): - self.saits.fit(TRAIN_SET, VAL_SET) - imputed_X = self.saits.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"SAITS test_MAE: {test_MAE}") - - -class TestAdadelta(unittest.TestCase): - logger.info("Running tests for Adadelta...") - - # initialize an Adadelta optimizer - adadelta = Adadelta(lr=0.001, weight_decay=1e-5) - - # initialize a SAITS model for testing DatasetForMIT and BaseDataset - saits = SAITS( - DATA["n_steps"], - DATA["n_features"], - n_layers=1, - d_model=128, - d_inner=64, - n_heads=2, - d_k=64, - d_v=64, - dropout=0.1, - optimizer=adadelta, - epochs=EPOCHS, - ) - - @pytest.mark.xdist_group(name="optim-adadelta") - def test_0_fit(self): - self.saits.fit(TRAIN_SET, VAL_SET) - imputed_X = self.saits.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"SAITS test_MAE: {test_MAE}") - - -class TestSGD(unittest.TestCase): - logger.info("Running tests for SGD...") - - # initialize a SGD optimizer - sgd = SGD(lr=0.001, weight_decay=1e-5) - - # initialize a SAITS model for testing DatasetForMIT and BaseDataset - saits = SAITS( - DATA["n_steps"], - DATA["n_features"], - n_layers=1, - d_model=128, - d_inner=64, - n_heads=2, - d_k=64, - d_v=64, - dropout=0.1, - optimizer=sgd, - epochs=EPOCHS, - ) - - @pytest.mark.xdist_group(name="optim-sgd") - def test_0_fit(self): - self.saits.fit(TRAIN_SET, VAL_SET) - imputed_X = self.saits.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"SAITS test_MAE: {test_MAE}") - - -class TestRMSprop(unittest.TestCase): - logger.info("Running tests for RMSprop...") - - # initialize a RMSprop optimizer - rmsprop = RMSprop(lr=0.001, weight_decay=1e-5) - - # initialize a SAITS model for testing DatasetForMIT and BaseDataset - saits = SAITS( - DATA["n_steps"], - DATA["n_features"], - n_layers=1, - d_model=128, - d_inner=64, - n_heads=2, - d_k=64, - d_v=64, - dropout=0.1, - optimizer=rmsprop, - epochs=EPOCHS, - ) - - @pytest.mark.xdist_group(name="optim-rmsprop") - def test_0_fit(self): - self.saits.fit(TRAIN_SET, VAL_SET) - imputed_X = self.saits.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"SAITS test_MAE: {test_MAE}") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_training_on_multi_gpus.py b/tests/test_training_on_multi_gpus.py deleted file mode 100644 index b076cbfe..00000000 --- a/tests/test_training_on_multi_gpus.py +++ /dev/null @@ -1,783 +0,0 @@ -""" -Test cases for running models on multi cuda devices. 
-""" - -# Created by Wenjie Du -# License: GPL-v3 - - -import os.path -import unittest - -import numpy as np -import pytest -import torch - -from pypots.classification import BRITS, GRUD, Raindrop -from pypots.clustering import VaDER, CRLI -from pypots.forecasting import BTTF -from pypots.imputation import BRITS as ImputationBRITS -from pypots.imputation import ( - SAITS, - Transformer, - MRNN, - LOCF, -) -from pypots.optim import Adam -from pypots.utils.logging import logger -from pypots.utils.metrics import cal_binary_classification_metrics -from pypots.utils.metrics import cal_mae -from pypots.utils.metrics import cal_rand_index, cal_cluster_purity -from tests.global_test_config import ( - DATA, - RESULT_SAVING_DIR, - check_tb_and_model_checkpoints_existence, -) - -EPOCHS = 5 - -cuda_devices = [torch.device(i) for i in range(torch.cuda.device_count())] - -# set DEVICES to None if no cuda device is available, to avoid initialization failed while importing test classes -DEVICES = None if cuda_devices == [] else cuda_devices - -# global skip test if less than two cuda-enabled devices -LESS_THAN_TWO_DEVICES = len(cuda_devices) < 2 -pytestmark = pytest.mark.skipif( - LESS_THAN_TWO_DEVICES, reason="not enough cuda devices to run tests" -) - - -TRAIN_SET = {"X": DATA["train_X"], "y": DATA["train_y"]} - -VAL_SET = { - "X": DATA["val_X"], - "X_intact": DATA["val_X_intact"], - "indicating_mask": DATA["val_X_indicating_mask"], - "y": DATA["val_y"], -} -TEST_SET = {"X": DATA["test_X"]} - -RESULT_SAVING_DIR_FOR_IMPUTATION = os.path.join(RESULT_SAVING_DIR, "imputation") -RESULT_SAVING_DIR_FOR_CLASSIFICATION = os.path.join(RESULT_SAVING_DIR, "classification") -RESULT_SAVING_DIR_FOR_CLUSTERING = os.path.join(RESULT_SAVING_DIR, "clustering") - - -class TestSAITS(unittest.TestCase): - logger.info("Running tests for an imputation model SAITS...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "SAITS") - model_save_name = "saved_saits_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a SAITS model - saits = SAITS( - DATA["n_steps"], - DATA["n_features"], - n_layers=2, - d_model=256, - d_inner=128, - n_heads=4, - d_k=64, - d_v=64, - dropout=0.1, - epochs=EPOCHS, - saving_path=saving_path, - optimizer=optimizer, - num_workers=2, - device=DEVICES, - ) - - @pytest.mark.xdist_group(name="imputation-saits") - def test_0_fit(self): - self.saits.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="imputation-saits") - def test_1_impute(self): - imputed_X = self.saits.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"SAITS test_MAE: {test_MAE}") - - @pytest.mark.xdist_group(name="imputation-saits") - def test_2_parameters(self): - assert hasattr(self.saits, "model") and self.saits.model is not None - - assert hasattr(self.saits, "optimizer") and self.saits.optimizer is not None - - assert hasattr(self.saits, "best_loss") - self.assertNotEqual(self.saits.best_loss, float("inf")) - - assert ( - hasattr(self.saits, "best_model_dict") - and self.saits.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="imputation-saits") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.saits) - - # save the trained model into file, and check if the path exists - self.saits.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.saits.load_model(saved_model_path) - - -class TestTransformer(unittest.TestCase): - logger.info("Running tests for an imputation model Transformer...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "Transformer") - model_save_name = "saved_transformer_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a Transformer model - transformer = Transformer( - DATA["n_steps"], - DATA["n_features"], - n_layers=2, - d_model=256, - d_inner=128, - n_heads=4, - d_k=64, - d_v=64, - dropout=0.1, - epochs=EPOCHS, - saving_path=saving_path, - optimizer=optimizer, - num_workers=2, - device=DEVICES, - ) - - @pytest.mark.xdist_group(name="imputation-transformer") - def test_0_fit(self): - self.transformer.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="imputation-transformer") - def test_1_impute(self): - imputed_X = self.transformer.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"Transformer test_MAE: {test_MAE}") - - @pytest.mark.xdist_group(name="imputation-transformer") - def test_2_parameters(self): - assert hasattr(self.transformer, "model") and self.transformer.model is not None - - assert ( - hasattr(self.transformer, "optimizer") - and self.transformer.optimizer is not None - ) - - assert hasattr(self.transformer, "best_loss") - self.assertNotEqual(self.transformer.best_loss, float("inf")) - - assert ( - hasattr(self.transformer, "best_model_dict") - and self.transformer.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="imputation-transformer") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.transformer) - - # save the trained model into file, and check if the path exists - self.transformer.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.transformer.load_model(saved_model_path) - - -class TestImputationBRITS(unittest.TestCase): - logger.info("Running tests for an imputation model BRITS...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "BRITS") - model_save_name = "saved_BRITS_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a BRITS model - brits = ImputationBRITS( - DATA["n_steps"], - DATA["n_features"], - 256, - epochs=EPOCHS, - saving_path=f"{RESULT_SAVING_DIR_FOR_IMPUTATION}/BRITS", - optimizer=optimizer, - num_workers=2, - device=DEVICES, - ) - - @pytest.mark.xdist_group(name="imputation-brits") - def test_0_fit(self): - self.brits.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="imputation-brits") - def test_1_impute(self): - imputed_X = self.brits.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"BRITS test_MAE: {test_MAE}") - - @pytest.mark.xdist_group(name="imputation-brits") - def test_2_parameters(self): - assert hasattr(self.brits, "model") and self.brits.model is not None - - assert hasattr(self.brits, "optimizer") and self.brits.optimizer is not None - - assert hasattr(self.brits, "best_loss") - self.assertNotEqual(self.brits.best_loss, float("inf")) - - assert ( - hasattr(self.brits, "best_model_dict") - and self.brits.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="imputation-brits") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.brits) - - # save the trained model into file, and check if the path exists - self.brits.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.brits.load_model(saved_model_path) - - -class TestMRNN(unittest.TestCase): - logger.info("Running tests for an imputation model MRNN...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "MRNN") - model_save_name = "saved_MRNN_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a MRNN model - mrnn = MRNN( - DATA["n_steps"], - DATA["n_features"], - 256, - epochs=EPOCHS, - saving_path=f"{RESULT_SAVING_DIR_FOR_IMPUTATION}/MRNN", - optimizer=optimizer, - num_workers=2, - device=DEVICES, - ) - - @pytest.mark.xdist_group(name="imputation-mrnn") - def test_0_fit(self): - self.mrnn.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="imputation-mrnn") - def test_1_impute(self): - imputed_X = self.mrnn.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( - imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"MRNN test_MAE: {test_MAE}") - - @pytest.mark.xdist_group(name="imputation-mrnn") - def test_2_parameters(self): - assert hasattr(self.mrnn, "model") and self.mrnn.model is not None - - assert hasattr(self.mrnn, "optimizer") and self.mrnn.optimizer is not None - - assert hasattr(self.mrnn, "best_loss") - self.assertNotEqual(self.mrnn.best_loss, float("inf")) - - assert ( - hasattr(self.mrnn, "best_model_dict") - and self.mrnn.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="imputation-mrnn") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.mrnn) - - # save the trained model into file, and check if the path exists - self.mrnn.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.mrnn.load_model(saved_model_path) - - -class TestLOCF(unittest.TestCase): - logger.info("Running tests for an imputation model LOCF...") - locf = LOCF(nan=0) - - @pytest.mark.xdist_group(name="imputation-locf") - def test_0_impute(self): - test_X_imputed = self.locf.impute(TEST_SET) - assert not np.isnan( - test_X_imputed - ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( - test_X_imputed, DATA["test_X_intact"], DATA["test_X_indicating_mask"] - ) - logger.info(f"LOCF test_MAE: {test_MAE}") - - @pytest.mark.xdist_group(name="imputation-locf") - def test_1_parameters(self): - assert hasattr(self.locf, "nan") and self.locf.nan is not None - - -class TestClassificationBRITS(unittest.TestCase): - logger.info("Running tests for a classification model BRITS...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "BRITS") - model_save_name = "saved_BRITS_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a BRITS model - brits = BRITS( - DATA["n_steps"], - DATA["n_features"], - n_classes=DATA["n_classes"], - rnn_hidden_size=256, - epochs=EPOCHS, - saving_path=saving_path, - model_saving_strategy="better", - optimizer=optimizer, - num_workers=2, - device=DEVICES, - ) - - @pytest.mark.xdist_group(name="classification-brits") - def test_0_fit(self): - self.brits.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="classification-brits") - def test_1_classify(self): - predictions = self.brits.classify(TEST_SET) - metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) - logger.info( - f'ROC_AUC: {metrics["roc_auc"]}, \n' - f'PR_AUC: {metrics["pr_auc"]},\n' - f'F1: {metrics["f1"]},\n' - f'Precision: {metrics["precision"]},\n' - f'Recall: {metrics["recall"]},\n' - ) - assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" - - @pytest.mark.xdist_group(name="classification-brits") - def test_2_parameters(self): - assert hasattr(self.brits, "model") and self.brits.model is not None - - assert hasattr(self.brits, "optimizer") and self.brits.optimizer is not None - - assert hasattr(self.brits, "best_loss") - self.assertNotEqual(self.brits.best_loss, float("inf")) - - assert ( - hasattr(self.brits, 
"best_model_dict") - and self.brits.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="classification-brits") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.brits) - - # save the trained model into file, and check if the path exists - self.brits.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.brits.load_model(saved_model_path) - - -class TestGRUD(unittest.TestCase): - logger.info("Running tests for a classification model GRUD...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "GRUD") - model_save_name = "saved_GRUD_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a GRUD model - grud = GRUD( - DATA["n_steps"], - DATA["n_features"], - n_classes=DATA["n_classes"], - rnn_hidden_size=256, - epochs=EPOCHS, - saving_path=saving_path, - optimizer=optimizer, - num_workers=2, - device=DEVICES, - ) - - @pytest.mark.xdist_group(name="classification-grud") - def test_0_fit(self): - self.grud.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="classification-grud") - def test_1_classify(self): - predictions = self.grud.classify(TEST_SET) - metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) - logger.info( - f'ROC_AUC: {metrics["roc_auc"]}, \n' - f'PR_AUC: {metrics["pr_auc"]},\n' - f'F1: {metrics["f1"]},\n' - f'Precision: {metrics["precision"]},\n' - f'Recall: {metrics["recall"]},\n' - ) - assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" - - @pytest.mark.xdist_group(name="classification-grud") - def test_2_parameters(self): - assert hasattr(self.grud, "model") and self.grud.model is not None - - assert hasattr(self.grud, "optimizer") and self.grud.optimizer is not None - - assert hasattr(self.grud, "best_loss") - self.assertNotEqual(self.grud.best_loss, float("inf")) - - assert ( - hasattr(self.grud, "best_model_dict") - and self.grud.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="classification-grud") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.grud) - - # save the trained model into file, and check if the path exists - self.grud.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.grud.load_model(saved_model_path) - - -class TestRaindrop(unittest.TestCase): - logger.info("Running tests for a classification model Raindrop...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "Raindrop") - model_save_name = "saved_Raindrop_model.pypots" - - # initialize a Raindrop model - raindrop = Raindrop( - DATA["n_steps"], - DATA["n_features"], - DATA["n_classes"], - n_layers=2, 
- d_model=DATA["n_features"] * 4, - d_inner=256, - n_heads=2, - dropout=0.3, - d_static=0, - aggregation="mean", - sensor_wise_mask=False, - static=False, - epochs=EPOCHS, - saving_path=saving_path, - ) - - @pytest.mark.xdist_group(name="classification-raindrop") - def test_0_fit(self): - self.raindrop.fit(TRAIN_SET, VAL_SET) - - @pytest.mark.xdist_group(name="classification-raindrop") - def test_1_classify(self): - predictions = self.raindrop.classify(TEST_SET) - metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) - logger.info( - f'ROC_AUC: {metrics["roc_auc"]}, \n' - f'PR_AUC: {metrics["pr_auc"]},\n' - f'F1: {metrics["f1"]},\n' - f'Precision: {metrics["precision"]},\n' - f'Recall: {metrics["recall"]},\n' - ) - assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" - - @pytest.mark.xdist_group(name="classification-raindrop") - def test_2_parameters(self): - assert hasattr(self.raindrop, "model") and self.raindrop.model is not None - - assert ( - hasattr(self.raindrop, "optimizer") and self.raindrop.optimizer is not None - ) - - assert hasattr(self.raindrop, "best_loss") - self.assertNotEqual(self.raindrop.best_loss, float("inf")) - - assert ( - hasattr(self.raindrop, "best_model_dict") - and self.raindrop.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="classification-raindrop") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.raindrop) - - # save the trained model into file, and check if the path exists - self.raindrop.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.raindrop.load_model(saved_model_path) - - -class TestCRLI(unittest.TestCase): - logger.info("Running tests for a clustering model CRLI...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLUSTERING, "CRLI") - model_save_name = "saved_CRLI_model.pypots" - - # initialize an Adam optimizer - G_optimizer = Adam(lr=0.001, weight_decay=1e-5) - D_optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a CRLI model - crli = CRLI( - n_steps=DATA["n_steps"], - n_features=DATA["n_features"], - n_clusters=DATA["n_classes"], - n_generator_layers=2, - rnn_hidden_size=128, - epochs=EPOCHS, - saving_path=saving_path, - G_optimizer=G_optimizer, - D_optimizer=D_optimizer, - ) - - @pytest.mark.xdist_group(name="clustering-crli") - def test_0_fit(self): - self.crli.fit(TRAIN_SET) - - @pytest.mark.xdist_group(name="clustering-crli") - def test_1_parameters(self): - assert hasattr(self.crli, "model") and self.crli.model is not None - - assert hasattr(self.crli, "G_optimizer") and self.crli.G_optimizer is not None - assert hasattr(self.crli, "D_optimizer") and self.crli.D_optimizer is not None - - assert hasattr(self.crli, "best_loss") - self.assertNotEqual(self.crli.best_loss, float("inf")) - - assert ( - hasattr(self.crli, "best_model_dict") - and self.crli.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="clustering-crli") - def test_2_cluster(self): - clustering = self.crli.cluster(TEST_SET) - RI = cal_rand_index(clustering, DATA["test_y"]) - CP = cal_cluster_purity(clustering, DATA["test_y"]) - 
logger.info(f"RI: {RI}\nCP: {CP}") - - @pytest.mark.xdist_group(name="clustering-crli") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.crli) - - # save the trained model into file, and check if the path exists - self.crli.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.crli.load_model(saved_model_path) - - -class TestVaDER(unittest.TestCase): - logger.info("Running tests for a clustering model Transformer...") - - # set the log and model saving path - saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLUSTERING, "VaDER") - model_save_name = "saved_VaDER_model.pypots" - - # initialize an Adam optimizer - optimizer = Adam(lr=0.001, weight_decay=1e-5) - - # initialize a VaDER model - vader = VaDER( - n_steps=DATA["n_steps"], - n_features=DATA["n_features"], - n_clusters=DATA["n_classes"], - rnn_hidden_size=64, - d_mu_stddev=5, - pretrain_epochs=20, - epochs=EPOCHS, - saving_path=saving_path, - optimizer=optimizer, - num_workers=2, - device=DEVICES, - ) - - @pytest.mark.xdist_group(name="clustering-vader") - def test_0_fit(self): - self.vader.fit(TRAIN_SET) - - @pytest.mark.xdist_group(name="clustering-vader") - def test_1_cluster(self): - try: - clustering = self.vader.cluster(TEST_SET) - RI = cal_rand_index(clustering, DATA["test_y"]) - CP = cal_cluster_purity(clustering, DATA["test_y"]) - logger.info(f"RI: {RI}\nCP: {CP}") - except np.linalg.LinAlgError as e: - logger.error( - f"{e}\n" - "Got singular matrix, please try to retrain the model to fix this" - ) - - @pytest.mark.xdist_group(name="clustering-vader") - def test_2_parameters(self): - assert hasattr(self.vader, "model") and self.vader.model is not None - - assert hasattr(self.vader, "optimizer") and self.vader.optimizer is not None - - assert hasattr(self.vader, "best_loss") - self.assertNotEqual(self.vader.best_loss, float("inf")) - - assert ( - hasattr(self.vader, "best_model_dict") - and self.vader.best_model_dict is not None - ) - - @pytest.mark.xdist_group(name="clustering-vader") - def test_3_saving_path(self): - # whether the root saving dir exists, which should be created by save_log_into_tb_file - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" - - # check if the tensorboard file and model checkpoints exist - check_tb_and_model_checkpoints_existence(self.vader) - - # save the trained model into file, and check if the path exists - self.vader.save_model( - saving_dir=self.saving_path, file_name=self.model_save_name - ) - - # test loading the saved model, not necessary, but need to test - saved_model_path = os.path.join(self.saving_path, self.model_save_name) - self.vader.load_model(saved_model_path) - - -class TestBTTF(unittest.TestCase): - logger.info("Running tests for a forecasting model BTTF...") - - # initialize a BTTF model - pred_step = 4 - bttf = BTTF( - n_steps=DATA["n_steps"] - pred_step, - n_features=10, - pred_step=pred_step, - rank=10, - time_lags=[1, 2, 3, 5, 5 + 1, 5 + 2, 10, 10 + 1, 10 + 2], - burn_iter=5, - gibbs_iter=5, - multi_step=1, - ) - - @pytest.mark.xdist_group(name="forecasting-bttf") - def 
test_0_forecasting(self): - predictions = self.bttf.forecast({"X": DATA["test_X"][:, : -self.pred_step]}) - logger.info(f"prediction shape: {predictions.shape}") - mae = cal_mae(predictions, DATA["test_X_intact"][:, -self.pred_step :]) - logger.info(f"prediction MAE: {mae}") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..f0b4685e --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 diff --git a/tests/test_utils.py b/tests/utils/logging.py similarity index 64% rename from tests/test_utils.py rename to tests/utils/logging.py index 0fd48ec8..113f0dde 100644 --- a/tests/test_utils.py +++ b/tests/utils/logging.py @@ -1,5 +1,5 @@ """ -Test cases for the functions and classes in package `pypots.utils`. +Test cases for the functions and classes in package `pypots.utils.logging`. """ # Created by Wenjie Du @@ -9,10 +9,7 @@ import shutil import unittest -import torch - from pypots.utils.logging import Logger -from pypots.utils.random import set_random_seed class TestLogging(unittest.TestCase): @@ -49,25 +46,5 @@ def test_saving_log_into_file(self): shutil.rmtree("test_log", ignore_errors=True) -class TestRandom(unittest.TestCase): - def test_set_random_seed(self): - random_state1 = torch.get_rng_state() - torch.rand( - 1, 3 - ) # randomly generate something, the random state will be reset, so two states should be varying - random_state2 = torch.get_rng_state() - assert not torch.equal( - random_state1, random_state2 - ), "The random seed hasn't set, so two random states should be different." - - set_random_seed(26) - random_state1 = torch.get_rng_state() - set_random_seed(26) - random_state2 = torch.get_rng_state() - assert torch.equal( - random_state1, random_state2 - ), "The random seed has been set, two random states are not the same." - - if __name__ == "__main__": unittest.main() diff --git a/tests/utils/random.py b/tests/utils/random.py new file mode 100644 index 00000000..0d1a0ca0 --- /dev/null +++ b/tests/utils/random.py @@ -0,0 +1,36 @@ +""" +Test cases for the functions and classes in package `pypots.utils.random`. +""" + +# Created by Wenjie Du +# License: GPL-v3 + +import unittest + +import torch + +from pypots.utils.random import set_random_seed + + +class TestRandom(unittest.TestCase): + def test_set_random_seed(self): + random_state1 = torch.get_rng_state() + torch.rand( + 1, 3 + ) # randomly generate something, the random state will be reset, so two states should be varying + random_state2 = torch.get_rng_state() + assert not torch.equal( + random_state1, random_state2 + ), "The random seed hasn't set, so two random states should be different." + + set_random_seed(26) + random_state1 = torch.get_rng_state() + set_random_seed(26) + random_state2 = torch.get_rng_state() + assert torch.equal( + random_state1, random_state2 + ), "The random seed has been set, two random states are not the same." + + +if __name__ == "__main__": + unittest.main() From 0f1977b6dc18137c257eb91a937752e886d43940 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 21 Sep 2023 22:40:25 +0800 Subject: [PATCH 12/17] docs: update READE with new added models; --- README.md | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 7b591634..9c86f08a 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,8 @@ -##
 [README header banner: the "Welcome to PyPOTS" heading and the centered bold tagline "A Python Toolbox for Data Mining on Partially-Observed Time Series"; banner HTML markup not reproduced here]
@@ -161,6 +162,8 @@ PyPOTS supports imputation, classification, clustering, and forecasting tasks on
 | **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
 | Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 |
 | Neural Net | Transformer | Attention is All you Need [^2]; Self-Attention-based Imputation for Time Series [^1]; Note: proposed in [^2], and re-implemented as an imputation model in [^1]. | 2017 |
+| Neural Net | US-GAN | Generative Semi-supervised Learning for Multivariate Time Series Imputation [^10] | 2021 |
+| Neural Net | GP-VAE | GP-VAE: Deep Probabilistic Time Series Imputation [^11] | 2020 |
 | Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 |
 | Neural Net | M-RNN | Multi-directional Recurrent Neural Network [^9] | 2019 |
 | Naive | LOCF | Last Observation Carried Forward | - |
@@ -253,7 +256,7 @@ We care about the feedback from our users, so we're building PyPOTS community on
 If you have any suggestions or want to contribute ideas or share time-series related papers, join us and tell.
 PyPOTS community is open, transparent, and surely friendly. Let's work together to build and improve PyPOTS!
 
-
+[//]: # (Use APA reference style below)
 [^1]: Du, W., Cote, D., & Liu, Y. (2023). [SAITS: Self-Attention-based Imputation for Time Series](https://doi.org/10.1016/j.eswa.2023.119619). *Expert Systems with Applications*.
 [^2]: Vaswani, A., Shazeer, N.M., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). [Attention is All you Need](https://papers.nips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html). *NeurIPS 2017*.
 [^3]: Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018). [BRITS: Bidirectional Recurrent Imputation for Time Series](https://papers.nips.cc/paper/2018/hash/734e6bfcd358e25ac1db0a4241b95651-Abstract.html). *NeurIPS 2018*.
@@ -263,7 +266,8 @@
 [^7]: Jong, J.D., Emon, M.A., Wu, P., Karki, R., Sood, M., Godard, P., Ahmad, A., Vrooman, H.A., Hofmann-Apitius, M., & Fröhlich, H. (2019). [Deep learning for clustering of multivariate clinical patient trajectories with missing values](https://academic.oup.com/gigascience/article/8/11/giz134/5626377). *GigaScience*.
 [^8]: Chen, X., & Sun, L. (2021). [Bayesian Temporal Factorization for Multidimensional Time Series Prediction](https://arxiv.org/abs/1910.06366). *IEEE Transactions on Pattern Analysis and Machine Intelligence*.
 [^9]: Yoon, J., Zame, W. R., & van der Schaar, M. (2019). [Estimating Missing Data in Temporal Data Streams Using Multi-Directional Recurrent Neural Networks](https://ieeexplore.ieee.org/document/8485748). *IEEE Transactions on Biomedical Engineering*.
-
+[^10]: Miao, X., Wu, Y., Wang, J., Gao, Y., Mao, X., & Yin, J. (2021). [Generative Semi-supervised Learning for Multivariate Time Series Imputation](https://ojs.aaai.org/index.php/AAAI/article/view/17086). *AAAI 2021*.
+[^11]: Fortuin, V., Baranchuk, D., Raetsch, G., & Mandt, S. (2020). [GP-VAE: Deep Probabilistic Time Series Imputation](https://proceedings.mlr.press/v108/fortuin20a.html). *AISTATS 2020*.
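
For readers of this patch series, a minimal usage sketch of the two newly listed imputers, written against the same `fit`/`impute` API that the package's other imputation models expose in the test suite. The toy data and the hyperparameter names `latent_size` (GP-VAE) and `rnn_hidden_size` (US-GAN) are assumptions for illustration only; they are not taken from this diff.

```python
import numpy as np

from pypots.imputation import GPVAE, USGAN
from pypots.utils.metrics import cal_mae

# Toy dataset: 200 series, 24 time steps, 5 features, ~10% of values masked out as missing.
rng = np.random.default_rng(26)
X_intact = rng.standard_normal((200, 24, 5)).astype("float32")
missing_mask = rng.random(X_intact.shape) < 0.1
X = X_intact.copy()
X[missing_mask] = np.nan
train_set = {"X": X}

# GP-VAE: `latent_size` is an assumed hyperparameter name.
gpvae = GPVAE(n_steps=24, n_features=5, latent_size=8, epochs=5)
gpvae.fit(train_set)
gpvae_imputation = gpvae.impute(train_set)

# US-GAN: `rnn_hidden_size` is an assumed hyperparameter name.
usgan = USGAN(n_steps=24, n_features=5, rnn_hidden_size=64, epochs=5)
usgan.fit(train_set)
usgan_imputation = usgan.impute(train_set)

# Evaluate only on the artificially removed values, as the test suite does with cal_mae().
print("GP-VAE MAE:", cal_mae(gpvae_imputation, X_intact, missing_mask.astype("float32")))
print("US-GAN MAE:", cal_mae(usgan_imputation, X_intact, missing_mask.astype("float32")))
```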
 [README footer: "🏠 Visits" details block with the "PyPOTS visits" counter badge]
@@ -271,4 +275,4 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together
 [badge markup unchanged; this hunk only adds the missing newline at the end of README.md]
From 03cd4c6e43ca0c13310b1b35a359caaa7679d32c Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 21 Sep 2023 23:00:04 +0800 Subject: [PATCH 13/17] feat: run CI testing workflow with pytest-xdist to speed up; --- .github/workflows/testing_ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/testing_ci.yml b/.github/workflows/testing_ci.yml index 7e5b6780..b139d018 100644 --- a/.github/workflows/testing_ci.yml +++ b/.github/workflows/testing_ci.yml @@ -68,7 +68,8 @@ jobs: - name: Test with pytest run: | - coverage run --source=pypots -m pytest -rA tests/*/* + rm -rf tests/__pycache__ + coverage run --source=pypots -m pytest -rA tests/*/* -n auto --dist=loadgroup - name: Generate the LCOV report run: | From e7b92ccde6601430b19a5e0024d5543531c0c481 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 21 Sep 2023 23:30:18 +0800 Subject: [PATCH 14/17] docs: update .gitignore; --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 0841fdef..51294f38 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,8 @@ docs/_build .coverage .pytest_cache *__pycache__* -*testing_results* +*test* # ignore specific kinds of files like all PDFs *.pdf +*.ipynb From 8b1efbb532369fb5e9378979c982282eaf60d7e9 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 21 Sep 2023 23:30:51 +0800 Subject: [PATCH 15/17] feat: add cal_internal_cluster_validation_metrics(); --- pypots/utils/metrics.py | 89 ++++++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 36 deletions(-) diff --git a/pypots/utils/metrics.py b/pypots/utils/metrics.py index 85efb54d..cc349b50 100644 --- a/pypots/utils/metrics.py +++ b/pypots/utils/metrics.py @@ -574,73 +574,90 @@ def cal_cluster_purity( return cluster_purity -def cal_silhouette( - latent_rep: np.ndarray, - class_predictions: np.ndarray -) -> float: +def cal_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float: """Compute the mean Silhouette Coefficient of all samples. Parameters ---------- - latent_rep : - Latent representation learned by a clusterer. + X : array-like of shape (n_samples_a, n_features) + A feature array, or learned latent representation, that can be used for clustering. - class_predictions : - Clustering results returned by a clusterer. + predicted_labels : array-like of shape (n_samples) + Predicted labels for each sample. Returns ------- - silhouette : + silhouette_score : float Mean Silhouette Coefficient for all samples. """ - silhouette = metrics.silhouette_score(latent_rep, class_predictions) - return silhouette + silhouette_score = metrics.silhouette_score(X, predicted_labels) + return silhouette_score -def cal_chs( - latent_rep: np.ndarray, - class_predictions: np.ndarray -) -> float: +def cal_chs(X: np.ndarray, predicted_labels: np.ndarray) -> float: """Compute the Calinski and Harabasz score (also known as the Variance Ratio Criterion). - Parameters - ---------- - latent_rep : - Latent representation learned by a clusterer. - - class_predictions : - Clustering results returned by a clusterer. + X : array-like of shape (n_samples_a, n_features) + A feature array, or learned latent representation, that can be used for clustering. + predicted_labels : array-like of shape (n_samples) + Predicted labels for each sample. Returns ------- - chs : + calinski_harabasz_score : float The resulting Calinski-Harabasz score. 
""" - chs = metrics.calinski_harabasz_score(latent_rep, class_predictions) - return chs + calinski_harabasz_score = metrics.calinski_harabasz_score(X, predicted_labels) + return calinski_harabasz_score -def cal_dbs( - latent_rep: np.ndarray, - class_predictions: np.ndarray -) -> float: +def cal_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: """Compute the Davies-Bouldin score. Parameters ---------- - latent_rep : - Latent representation learned by a clusterer. + X : array-like of shape (n_samples_a, n_features) + A feature array, or learned latent representation, that can be used for clustering. - class_predictions : - Clustering results returned by a clusterer. + predicted_labels : array-like of shape (n_samples) + Predicted labels for each sample. Returns ------- - dbs : + davies_bouldin_score : float The resulting Davies-Bouldin score. """ - dbs = metrics.davies_bouldin_score(latent_rep, class_predictions) - return dbs + davies_bouldin_score = metrics.davies_bouldin_score(X, predicted_labels) + return davies_bouldin_score + + +def cal_internal_cluster_validation_metrics(X, predicted_labels): + """Computer all internal cluster validation metrics available in PyPOTS and return as a dictionary. + + Parameters + ---------- + X : array-like of shape (n_samples_a, n_features) + A feature array, or learned latent representation, that can be used for clustering. + + predicted_labels : array-like of shape (n_samples) + Predicted labels for each sample. + + Returns + ------- + internal_cluster_validation_metrics : dict + A dictionary contains all internal cluster validation metrics available in PyPOTS. + """ + + silhouette_score = cal_silhouette(X, predicted_labels) + calinski_harabasz_score = cal_chs(X, predicted_labels) + davies_bouldin_score = cal_dbs(X, predicted_labels) + + internal_cluster_validation_metrics = { + "silhouette_score": silhouette_score, + "calinski_harabasz_score": calinski_harabasz_score, + "davies_bouldin_score": davies_bouldin_score, + } + return internal_cluster_validation_metrics From 4dbccec5106e76c80feb80519ad3718e38e5140b Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 21 Sep 2023 23:34:27 +0800 Subject: [PATCH 16/17] docs: update docs; --- docs/pypots.data.rst | 9 +++++++++ docs/pypots.imputation.rst | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/docs/pypots.data.rst b/docs/pypots.data.rst index d792d6aa..fe7c4678 100644 --- a/docs/pypots.data.rst +++ b/docs/pypots.data.rst @@ -10,6 +10,15 @@ pypots.data.base module :show-inheritance: :inherited-members: +pypots.data.saving module +----------------------------- + +.. automodule:: pypots.data.saving + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + pypots.data.generating module ----------------------------- diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst index 0e31f8c8..a33e0fdf 100644 --- a/docs/pypots.imputation.rst +++ b/docs/pypots.imputation.rst @@ -19,6 +19,24 @@ pypots.imputation.transformer module :show-inheritance: :inherited-members: +pypots.imputation.usgan module +------------------------------ + +.. automodule:: pypots.imputation.usgan + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + +pypots.imputation.gpvae module +------------------------------ + +.. 
+   :members:
+   :undoc-members:
+   :show-inheritance:
+   :inherited-members:
+
 pypots.imputation.brits module
 ------------------------------
 
From f447b274448907375860995d96b8e4db5a773ae8 Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Fri, 22 Sep 2023 00:25:53 +0800
Subject: [PATCH 17/17] fix: run with `python -m` instead of `coverage -m` to fix the bug of reporting 0% coverage;

---
 .github/workflows/testing_ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/testing_ci.yml b/.github/workflows/testing_ci.yml
index b139d018..4cdfe5bc 100644
--- a/.github/workflows/testing_ci.yml
+++ b/.github/workflows/testing_ci.yml
@@ -69,7 +69,7 @@ jobs:
       - name: Test with pytest
         run: |
           rm -rf tests/__pycache__
-          coverage run --source=pypots -m pytest -rA tests/*/* -n auto --dist=loadgroup
+          python -m pytest -rA tests/*/* -n auto --cov=pypots --dist=loadgroup --cov-config=.coveragerc
 
       - name: Generate the LCOV report
         run: |
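
Since two of the commits above rewire the CI test command around pytest-xdist, a short hypothetical test module may help show what `--dist=loadgroup` relies on: every test carrying the same `xdist_group` name is dispatched to the same worker, so ordered steps such as fitting before imputing still run sequentially, while differently named groups are spread across workers. The class and group names below are illustrative only, not taken from the test suite.

```python
# Invoked in CI as:
#   python -m pytest -rA tests/*/* -n auto --cov=pypots --dist=loadgroup --cov-config=.coveragerc
import unittest

import pytest


class TestToyImputer(unittest.TestCase):
    """Hypothetical test class mirroring the xdist_group marker pattern used across tests/*/*."""

    @pytest.mark.xdist_group(name="imputation-toy")
    def test_0_fit(self):
        # Under --dist=loadgroup this always runs on the same worker as test_1_impute,
        # and before it, because both tests share the "imputation-toy" group.
        self.assertTrue(True)  # placeholder for model.fit(...)

    @pytest.mark.xdist_group(name="imputation-toy")
    def test_1_impute(self):
        self.assertTrue(True)  # placeholder for model.impute(...)


if __name__ == "__main__":
    unittest.main()
```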