diff --git a/docs/about_us.rst b/docs/about_us.rst index aaaab944..370a3e0d 100644 --- a/docs/about_us.rst +++ b/docs/about_us.rst @@ -33,5 +33,5 @@ PyPOTS exists thanks to all the nice people (sorted by contribution time) who co .. raw:: html - + diff --git a/pypots/base.py b/pypots/base.py index f55033e3..7a12fe94 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -96,7 +96,9 @@ def _setup_device(self, device: Union[None, str, torch.device, list]): self.device = device elif isinstance(device, list): if len(device) == 0: - raise ValueError("The list of devices should have at least 1 device, but got 0.") + raise ValueError( + "The list of devices should have at least 1 device, but got 0." + ) elif len(device) == 1: return self._setup_device(device[0]) # parallely training on multiple CUDA devices @@ -176,7 +178,6 @@ def _send_data_to_given_device(self, data): if isinstance(self.device, torch.device): # single device data = map(lambda x: x.to(self.device), data) else: # parallely training on multiple devices - # randomly choose one device to balance the workload # device = np.random.choice(self.device) diff --git a/pypots/classification/base.py b/pypots/classification/base.py index a30fd698..a16dbc01 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -256,7 +256,6 @@ def _train_model( training_loader: DataLoader, val_loader: DataLoader = None, ) -> None: - # each training starts from the very beginning, so reset the loss and model dict here self.best_loss = float("inf") self.best_model_dict = None diff --git a/pypots/classification/raindrop/modules.py b/pypots/classification/raindrop/modules.py index 76a992ef..191ff9c7 100644 --- a/pypots/classification/raindrop/modules.py +++ b/pypots/classification/raindrop/modules.py @@ -174,7 +174,6 @@ def forward( edge_attr: OptTensor = None, return_attention_weights=None, ) -> Tuple[torch.Tensor, Any]: - r""" Args: return_attention_weights (bool, optional): If set to :obj:`True`, diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py index 324e6718..fd9b7f0d 100644 --- a/pypots/clustering/base.py +++ b/pypots/clustering/base.py @@ -244,7 +244,6 @@ def _train_model( training_loader: DataLoader, val_loader: DataLoader = None, ) -> None: - """ Parameters diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index b5e2e14a..8b7a63a1 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -226,7 +226,6 @@ def __init__( saving_path: Optional[str] = None, model_saving_strategy: Optional[str] = "best", ): - super().__init__( n_clusters, batch_size, diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index f2912cce..5a44da85 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -184,7 +184,6 @@ def forward( ) = self.get_results(X, missing_mask) if not training and not pretrain: - results = { "mu_tilde": mu_tilde, "mu": mu_c, @@ -403,7 +402,6 @@ def _train_model( training_loader: DataLoader, val_loader: DataLoader = None, ) -> None: - # each training starts from the very beginning, so reset the loss and model dict here self.best_loss = float("inf") self.best_model_dict = None diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index 5188999b..079f5925 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -242,7 +242,6 @@ def _train_model( training_loader: DataLoader, val_loader: DataLoader = None, ) -> None: - # each training starts from the very beginning, so 
reset the loss and model dict here self.best_loss = float("inf") self.best_model_dict = None diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 9de8d0bc..a6c4dcd8 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -6,10 +6,12 @@ # License: GPL-v3 from .brits import BRITS +from .gpvae import GPVAE from .locf import LOCF +from .mrnn import MRNN from .saits import SAITS from .transformer import Transformer -from .mrnn import MRNN +from .usgan import USGAN __all__ = [ "SAITS", @@ -17,4 +19,6 @@ "BRITS", "MRNN", "LOCF", + "GPVAE", + "USGAN", ] diff --git a/pypots/imputation/gpvae/__init__.py b/pypots/imputation/gpvae/__init__.py new file mode 100644 index 00000000..f5ffb05e --- /dev/null +++ b/pypots/imputation/gpvae/__init__.py @@ -0,0 +1,12 @@ +""" +The package of the partially-observed time-series imputation method GP-VAE. +""" + +# Created by Jun Wang +# License: GLP-v3 + +from .model import GPVAE + +__all__ = [ + "GPVAE", +] diff --git a/pypots/imputation/gpvae/data.py b/pypots/imputation/gpvae/data.py new file mode 100644 index 00000000..4f8b27c4 --- /dev/null +++ b/pypots/imputation/gpvae/data.py @@ -0,0 +1,133 @@ +""" +Dataset class for model GP-VAE. +""" + +# Created by Jun Wang and Wenjie Du +# License: GLP-v3 + +from typing import Union, Iterable + +import torch + +from ...data.base import BaseDataset +from ...data.utils import torch_parse_delta + + +class DatasetForGPVAE(BaseDataset): + """Dataset class for GP-VAE. + + Parameters + ---------- + data : dict or str, + The dataset for model input, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for input, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + return_labels : bool, default = True, + Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example, + during training of classification models, the Dataset class will return labels in __getitem__() for model input. + Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we + need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5 + files, they already have both X and y saved. But we don't read labels from the file for validating and testing + with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for + distinction. + + file_type : str, default = "h5py" + The type of the given file if train_set and val_set are path strings. + """ + + def __init__( + self, + data: Union[dict, str], + return_labels: bool = True, + file_type: str = "h5py", + ): + super().__init__(data, return_labels, file_type) + + if not isinstance(self.data, str): + # calculate all delta here. + missing_mask = (~torch.isnan(self.X)).type(torch.float32) + X = torch.nan_to_num(self.X) + + self.processed_data = { + "X": X, + "missing_mask": missing_mask, + } + + def _fetch_data_from_array(self, idx: int) -> Iterable: + """Fetch data from self.X if it is given. + + Parameters + ---------- + idx : int, + The index of the sample to be return. 
+ + Returns + ------- + sample : list, + A list contains + + index : int tensor, + The index of the sample. + + X : tensor, + The feature vector for model input. + + missing_mask : tensor, + The mask indicates all missing values in X. + + delta : tensor, + The delta matrix contains time gaps of missing values. + + label (optional) : tensor, + The target label of the time-series sample. + """ + sample = [ + torch.tensor(idx), + # for forward + self.processed_data["X"][idx].to(torch.float32), + self.processed_data["missing_mask"][idx].to(torch.float32), + ] + + if self.y is not None and self.return_labels: + sample.append(self.y[idx].to(torch.long)) + + return sample + + def _fetch_data_from_file(self, idx: int) -> Iterable: + """Fetch data with the lazy-loading strategy, i.e. only loading data from the file while requesting for samples. + Here the opened file handle doesn't load the entire dataset into RAM but only load the currently accessed slice. + + Parameters + ---------- + idx : int, + The index of the sample to be return. + + Returns + ------- + sample : list, + The collated data sample, a list including all necessary sample info. + """ + + if self.file_handle is None: + self.file_handle = self._open_file_handle() + + X = torch.from_numpy(self.file_handle["X"][idx]) + missing_mask = (~torch.isnan(X)).to(torch.float32) + X = torch.nan_to_num(X) + + sample = [ + torch.tensor(idx), + X, + missing_mask, + ] + + # if the dataset has labels and is for training, then fetch it from the file + if "y" in self.file_handle.keys() and self.return_labels: + sample.append(torch.tensor(self.file_handle["y"][idx], dtype=torch.long)) + + return sample diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py new file mode 100644 index 00000000..6b613d4d --- /dev/null +++ b/pypots/imputation/gpvae/model.py @@ -0,0 +1,446 @@ +""" +The implementation of GP-VAE for the partially-observed time-series imputation task. + +Refer to the paper Fortuin V, Baranchuk D, Rätsch G, et al. +GP-VAE: Deep probabilistic time series imputation. AISTATS. PMLR, 2020: 1651-1661. 
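In brief, paraphrasing the paper cited above (the corresponding terms appear in the code below): each latent dimension is given a Gaussian Process prior over time, and training maximizes a beta-weighted ELBO whose reconstruction term is evaluated only on observed entries,

.. math::
    \mathcal{L}(x) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x_{\mathrm{obs}} \mid z)\big] - \beta\,\mathrm{KL}\big(q_\phi(z \mid x)\,\|\,p(z)\big), \qquad p(z_{j,1:T}) = \mathcal{N}\big(0, K_j\big),

where K_j is a [time_length, time_length] kernel matrix (Cauchy by default) and the missing mask zeroes out unobserved entries in the log-likelihood.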
+ +""" + +# Created by Jun Wang and Wenjie Du +# License: GPL-v3 + + +from typing import Union, Optional + +import h5py +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from .data import DatasetForGPVAE +from .modules import ( + Encoder, + rbf_kernel, + diffusion_kernel, + matern_kernel, + cauchy_kernel, + Decoder, +) +from ..base import BaseNNImputer +from ...optim.adam import Adam +from ...optim.base import Optimizer + + +class _GPVAE(nn.Module): + """model GPVAE with Gaussian Process prior + + Parameters + ---------- + input_dim : int, + the feature dimension of the input + + time_length : int, + the length of each time series + + latent_dim : int, + the feature dimension of the latent embedding + + encoder_sizes : tuple, + the tuple of the network size in encoder + + decoder_sizes : tuple, + the tuple of the network size in decoder + + beta : float, + the weight of the KL divergence + + M : int, + the number of Monte Carlo samples for ELBO estimation + + K : int, + the number of importance weights for IWAE model + + kernel : str, + the Gaussian Process kernel ["cauchy", "diffusion", "rbf", "matern"] + + sigma : float, + the scale parameter for a kernel function + + length_scale : float, + the length scale parameter for a kernel function + + kernel_scales : int, + the number of different length scales over latent space dimensions + """ + + def __init__( + self, + input_dim, + time_length, + latent_dim, + encoder_sizes=(64, 64), + decoder_sizes=(64, 64), + beta=1, + M=1, + K=1, + kernel="cauchy", + sigma=1.0, + length_scale=7.0, + kernel_scales=1, + window_size=24, + ): + super().__init__() + self.kernel = kernel + self.sigma = sigma + self.length_scale = length_scale + self.kernel_scales = kernel_scales + + self.input_dim = input_dim + self.time_length = time_length + self.latent_dim = latent_dim + self.beta = beta + self.encoder = Encoder(input_dim, latent_dim, encoder_sizes, window_size) + self.decoder = Decoder(latent_dim, input_dim, decoder_sizes) + self.M = M + self.K = K + + # Precomputed KL components for efficiency + self.prior = self._init_prior() + # self.pz_scale_inv = None + # self.pz_scale_log_abs_determinant = None + + def encode(self, x): + return self.encoder(x) + + def decode(self, z): + if not torch.is_tensor(z): + z = torch.tensor(z).float() + num_dim = len(z.shape) + assert num_dim > 2 + return self.decoder(torch.transpose(z, num_dim - 1, num_dim - 2)) + + def forward(self, inputs, training=True): + x = inputs["X"] + m_mask = inputs["missing_mask"] + x = x.repeat(self.M * self.K, 1, 1) + if m_mask is not None: + m_mask = m_mask.repeat(self.M * self.K, 1, 1) + m_mask = m_mask.type(torch.bool) + + # pz = self.prior() + qz_x = self.encode(x) + z = qz_x.rsample() + px_z = self.decode(z) + + nll = -px_z.log_prob(x) + nll = torch.where(torch.isfinite(nll), nll, torch.zeros_like(nll)) + if m_mask is not None: + nll = torch.where(m_mask, nll, torch.zeros_like(nll)) + nll = nll.sum(dim=(1, 2)) + + if self.K > 1: + kl = qz_x.log_prob(z) - self.prior.log_prob(z) + kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl)) + kl = kl.sum(1) + + weights = -nll - kl + weights = torch.reshape(weights, [self.M, self.K, -1]) + + elbo = torch.logsumexp(weights, dim=1) + elbo = elbo.mean() + else: + kl = self.kl_divergence(qz_x, self.prior) + kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl)) + kl = kl.sum(1) + + elbo = -nll - self.beta * kl + elbo = elbo.mean() + + imputed_data = self.decode(self.encode(x).mean).mean * 
~m_mask + x * m_mask + + if not training: + # if not in training mode, return the classification result only + return { + "imputed_data": imputed_data, + } + + results = { + "loss": -elbo.mean(), + "imputed_data": imputed_data, + } + return results + + @staticmethod + def kl_divergence(a, b): + # TODO: different from the author's implementation + return torch.distributions.kl.kl_divergence(a, b) + + def _init_prior(self): + # Compute kernel matrices for each latent dimension + kernel_matrices = [] + for i in range(self.kernel_scales): + if self.kernel == "rbf": + kernel_matrices.append( + rbf_kernel(self.time_length, self.length_scale / 2**i) + ) + elif self.kernel == "diffusion": + kernel_matrices.append( + diffusion_kernel(self.time_length, self.length_scale / 2**i) + ) + elif self.kernel == "matern": + kernel_matrices.append( + matern_kernel(self.time_length, self.length_scale / 2**i) + ) + elif self.kernel == "cauchy": + kernel_matrices.append( + cauchy_kernel( + self.time_length, self.sigma, self.length_scale / 2**i + ) + ) + + # Combine kernel matrices for each latent dimension + tiled_matrices = [] + total = 0 + for i in range(self.kernel_scales): + if i == self.kernel_scales - 1: + multiplier = self.latent_dim - total + else: + multiplier = int(np.ceil(self.latent_dim / self.kernel_scales)) + total += multiplier + tiled_matrices.append( + torch.unsqueeze(kernel_matrices[i], 0).repeat(multiplier, 1, 1) + ) + kernel_matrix_tiled = torch.cat(tiled_matrices) + assert len(kernel_matrix_tiled) == self.latent_dim + prior = torch.distributions.MultivariateNormal( + loc=torch.zeros(self.latent_dim, self.time_length), + covariance_matrix=kernel_matrix_tiled, + ) + + return prior + + +class GPVAE(BaseNNImputer): + """The PyTorch implementation of the GPVAE model :cite:``. + + Parameters + ---------- + beta: + The weight of KL divergence in EBLO. + + kernel: + The type of kernel function chosen in the Gaussain Process Proir. ["cauchy", "diffusion", "rbf", "matern"] + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. 
+ + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + + Attributes + ---------- + model : :class:`torch.nn.Module` + The underlying GPVAE model. + + optimizer : :class:`pypots.optim.Optimizer` + The optimizer for model training. + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + latent_size: int, + encoder_sizes: tuple = (64, 64), + decoder_sizes: tuple = (64, 64), + kernel: str = "cauchy", + beta: float = 0.2, + M: int = 1, + K: int = 1, + sigma: float = 1.0, + length_scale: float = 7.0, + kernel_scales: int = 1, + window_size: int = 3, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + + self.n_steps = n_steps + self.n_features = n_features + self.latent_size = latent_size + self.kernel = kernel + self.encoder_sizes = encoder_sizes + self.decoder_sizes = decoder_sizes + self.beta = beta + self.M = M + self.K = K + self.sigma = sigma + self.length_scale = length_scale + self.kernel_scales = kernel_scales + + # set up the model + self.model = _GPVAE( + input_dim=self.n_features, + time_length=self.n_steps, + latent_dim=self.latent_size, + kernel=self.kernel, + encoder_sizes=self.encoder_sizes, + decoder_sizes=self.decoder_sizes, + beta=self.beta, + M=self.M, + K=self.K, + sigma=self.sigma, + length_scale=self.length_scale, + kernel_scales=self.kernel_scales, + window_size=window_size, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + # fetch data + ( + indices, + X, + missing_mask, + ) = self._send_data_to_given_device(data) + + # assemble input data + inputs = { + "indices": indices, + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + return self._assemble_input_for_validating(data) + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForGPVAE( + train_set, return_labels=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if isinstance(val_set, str): + with h5py.File(val_set, "r") as hf: + # Here we read the whole validation set from the file to mask a portion for validation. + # In PyPOTS, using a file usually because the data is too big. However, the validation set is + # generally shouldn't be too large. For example, we have 1 billion samples for model training. 
+ # We won't take 20% of them as the validation set because we want as much as possible data for the + # training stage to enhance the model's generalization ability. Therefore, 100,000 representative + # samples will be enough to validate the model. + val_set = { + "X": hf["X"][:], + "X_intact": hf["X_intact"][:], + "indicating_mask": hf["indicating_mask"][:], + } + val_set = DatasetForGPVAE(val_set, return_labels=False, file_type=file_type) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(training_finished=True) + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + self.model.eval() # set the model as eval status to freeze it. + test_set = DatasetForGPVAE(X, return_labels=False, file_type=file_type) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputed_data = results["imputed_data"] + imputation_collector.append(imputed_data) + + imputation_collector = torch.cat(imputation_collector) + return imputation_collector.cpu().detach().numpy() diff --git a/pypots/imputation/gpvae/modules.py b/pypots/imputation/gpvae/modules.py new file mode 100644 index 00000000..5ad81e09 --- /dev/null +++ b/pypots/imputation/gpvae/modules.py @@ -0,0 +1,261 @@ +""" +The implementation of GP-VAE for the partially-observed time-series imputation task. + +Refer to the paper Fortuin V, Baranchuk D, Rätsch G, et al. +GP-VAE: Deep probabilistic time series imputation. AISTATS. PMLR, 2020: 1651-1661. 
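The kernel helpers in this file each return a [T, T] covariance matrix over time steps, from which the latent GP prior is built. A rough, illustrative sketch of how they fit together (not part of the library API; the values are placeholders):

    import torch
    from pypots.imputation.gpvae.modules import cauchy_kernel

    T = 24  # number of time steps (placeholder)
    K = cauchy_kernel(T, sigma=1.0, length_scale=7.0)  # (T, T) kernel matrix with a small diagonal jitter
    prior = torch.distributions.MultivariateNormal(loc=torch.zeros(T), covariance_matrix=K)
    z = prior.sample()  # one smooth latent trajectory of length T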
+ + +""" + +# Created by Jun Wang and Wenjie Du +# License: GPL-v3 + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def rbf_kernel(T, length_scale): + xs = torch.arange(T).float() + xs_in = torch.unsqueeze(xs, 0) + xs_out = torch.unsqueeze(xs, 1) + distance_matrix = (xs_in - xs_out) ** 2 + distance_matrix_scaled = distance_matrix / length_scale**2 + kernel_matrix = torch.exp(-distance_matrix_scaled) + return kernel_matrix + + +def diffusion_kernel(T, length_scale): + assert length_scale < 0.5, ( + "length_scale has to be smaller than 0.5 for the " + "kernel matrix to be diagonally dominant" + ) + sigmas = torch.ones(T, T) * length_scale + sigmas_tridiag = torch.diagonal(sigmas, offset=0, dim1=-2, dim2=-1) + sigmas_tridiag += torch.diagonal(sigmas, offset=1, dim1=-2, dim2=-1) + sigmas_tridiag += torch.diagonal(sigmas, offset=-1, dim1=-2, dim2=-1) + kernel_matrix = sigmas_tridiag + torch.eye(T) * (1.0 - length_scale) + return kernel_matrix + + +def matern_kernel(T, length_scale): + xs = torch.arange(T).float() + xs_in = torch.unsqueeze(xs, 0) + xs_out = torch.unsqueeze(xs, 1) + distance_matrix = torch.abs(xs_in - xs_out) + distance_matrix_scaled = distance_matrix / torch.sqrt(length_scale).type( + torch.float32 + ) + kernel_matrix = torch.exp(-distance_matrix_scaled) + return kernel_matrix + + +def cauchy_kernel(T, sigma, length_scale): + xs = torch.arange(T).float() + xs_in = torch.unsqueeze(xs, 0) + xs_out = torch.unsqueeze(xs, 1) + distance_matrix = (xs_in - xs_out) ** 2 + distance_matrix_scaled = distance_matrix / length_scale**2 + kernel_matrix = sigma / (distance_matrix_scaled + 1.0) + + alpha = 0.001 + eye = torch.eye(kernel_matrix.shape[-1]) + return kernel_matrix + alpha * eye + + +def make_nn(input_size, output_size, hidden_sizes): + """This function used to creates fully connected neural network. 
+ + Parameters + ---------- + input_size : int, + the dimension of input embeddings + + output_size : int, + the dimension of out embeddings + + hidden_sizes : tuple, + the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers + + Returns + ------- + output: tensor + the processing embeddings + """ + layers = [] + for i in range(len(hidden_sizes)): + if i == 0: + layers.append( + nn.Linear(in_features=input_size, out_features=hidden_sizes[i]) + ) + else: + layers.append( + nn.Linear(in_features=hidden_sizes[i - 1], out_features=hidden_sizes[i]) + ) + layers.append(nn.ReLU()) + layers.append(nn.Linear(in_features=hidden_sizes[-1], out_features=output_size)) + return nn.Sequential(*layers) + + +class CustomConv1d(torch.nn.Conv1d): + def __init(self, in_channels, out_channels, kernel_size, padding): + super().__init__(in_channels, out_channels, kernel_size, padding) + + def forward(self, x): + if len(x.shape) > 2: + shape = list(np.arange(len(x.shape))) + new_shape = [0, shape[-1]] + shape[1:-1] + out = super(CustomConv1d, self).forward(x.permute(*new_shape)) + shape = list(np.arange(len(out.shape))) + new_shape = [0, shape[-1]] + shape[1:-1] + if self.kernel_size[0] % 2 == 0: + out = F.pad(out, (0, -1), "constant", 0) + return out.permute(new_shape) + + return super(CustomConv1d, self).forward(x) + + +def make_cnn(input_size, output_size, hidden_sizes, kernel_size=3): + """This function used to construct neural network consisting of + one 1d-convolutional layer that utilizes temporal dependencies, + fully connected network + + Parameters + ---------- + input_size : int, + the dimension of input embeddings + + output_size : int, + the dimension of out embeddings + + hidden_sizes : tuple, + the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers, + + kernel_size : int + kernel size for convolutional layer + + Returns + ------- + output: tensor + the processing embeddings + """ + padding = kernel_size // 2 + + cnn_layer = CustomConv1d( + input_size, hidden_sizes[0], kernel_size=kernel_size, padding=padding + ) + layers = [cnn_layer] + + for i, h in zip(hidden_sizes, hidden_sizes[1:]): + layers.extend([nn.Linear(i, h), nn.ReLU()]) + if isinstance(output_size, tuple): + net = nn.Sequential(*layers) + return [net] + [nn.Linear(hidden_sizes[-1], o) for o in output_size] + + layers.append(nn.Linear(hidden_sizes[-1], output_size)) + return nn.Sequential(*layers) + + +class Encoder(nn.Module): + def __init__(self, input_size, z_size, hidden_sizes=(128, 128), window_size=24): + """This module is an encoder with 1d-convolutional network and multivariate Normal posterior used by GP-VAE with + proposed banded covariance matrix + + Parameters + ---------- + input_size : int, + the feature dimension of the input + + z_size : int, + the feature dimension of the output latent embedding + + hidden_sizes : tuple, + the tuple of the hidden layer sizes, and the tuple length sets the number of hidden layers + + window_size : int + the kernel size for the Conv1D layer + """ + super().__init__() + self.z_size = int(z_size) + self.input_size = input_size + self.net, self.mu_layer, self.logvar_layer = make_cnn( + input_size, (z_size, z_size * 2), hidden_sizes, window_size + ) + + def forward(self, x): + mapped = self.net(x) + batch_size = mapped.size(0) + time_length = mapped.size(1) + + num_dim = len(mapped.shape) + mu = self.mu_layer(mapped) + logvar = self.logvar_layer(mapped) + mapped_mean = torch.transpose(mu, num_dim - 1, num_dim - 2) + mapped_covar = 
torch.transpose(logvar, num_dim - 1, num_dim - 2) + mapped_covar = torch.sigmoid(mapped_covar) + mapped_reshaped = mapped_covar.reshape(batch_size, self.z_size, 2 * time_length) + + dense_shape = [batch_size, self.z_size, time_length, time_length] + idxs_1 = np.repeat(np.arange(batch_size), self.z_size * (2 * time_length - 1)) + idxs_2 = np.tile( + np.repeat(np.arange(self.z_size), (2 * time_length - 1)), batch_size + ) + idxs_3 = np.tile( + np.concatenate([np.arange(time_length), np.arange(time_length - 1)]), + batch_size * self.z_size, + ) + idxs_4 = np.tile( + np.concatenate([np.arange(time_length), np.arange(1, time_length)]), + batch_size * self.z_size, + ) + idxs_all = np.stack([idxs_1, idxs_2, idxs_3, idxs_4], axis=1) + + mapped_values = mapped_reshaped[:, :, :-1].reshape(-1) + prec_sparse = torch.sparse_coo_tensor( + torch.LongTensor(idxs_all).t().to(mapped.device), + (mapped_values).to(mapped.device), + (dense_shape), + ) + prec_sparse = prec_sparse.coalesce() + prec_tril = prec_sparse.to_dense() + eye = ( + torch.eye(prec_tril.shape[-1]) + .unsqueeze(0) + .repeat(prec_tril.shape[0], prec_tril.shape[1], 1, 1) + .to(mapped.device) + ) + prec_tril = prec_tril + eye + cov_tril = torch.linalg.solve_triangular(prec_tril, eye, upper=True) + cov_tril = torch.where( + torch.isfinite(cov_tril), cov_tril, torch.zeros_like(cov_tril) + ).to(mapped.device) + + num_dim = len(cov_tril.shape) + cov_tril_lower = torch.transpose(cov_tril, num_dim - 1, num_dim - 2) + + z_dist = torch.distributions.MultivariateNormal( + loc=mapped_mean, scale_tril=cov_tril_lower + ) + return z_dist + + +class Decoder(nn.Module): + def __init__(self, input_size, output_size, hidden_sizes=(256, 256)): + """This module is a decoder with Gaussian output distribution. + + Parameters + ---------- + output_size : int, + the feature dimension of the output + + hidden_sizes: tuple + the tuple of hidden layer sizes, and the tuple length sets the number of hidden layers. + """ + super().__init__() + self.net = make_nn(input_size, output_size, hidden_sizes) + + def forward(self, x): + mu = self.net(x) + var = torch.ones_like(mu) + return torch.distributions.Normal(mu, var) diff --git a/pypots/imputation/mrnn/module.py b/pypots/imputation/mrnn/module.py index 873d2d73..a143d121 100644 --- a/pypots/imputation/mrnn/module.py +++ b/pypots/imputation/mrnn/module.py @@ -18,7 +18,7 @@ class FCN_Regression(nn.Module): def __init__(self, feature_num, rnn_hid_size): - super(FCN_Regression, self).__init__() + super().__init__() self.feat_reg = FeatureRegression(rnn_hid_size * 2) self.U = Parameter(torch.Tensor(feature_num, feature_num)) self.V1 = Parameter(torch.Tensor(feature_num, feature_num)) diff --git a/pypots/imputation/usgan/__init__.py b/pypots/imputation/usgan/__init__.py new file mode 100644 index 00000000..fb388d94 --- /dev/null +++ b/pypots/imputation/usgan/__init__.py @@ -0,0 +1,12 @@ +""" +The package of the partially-observed time-series imputation method USGAN. +""" + +# Created by Jun Wang +# License: GLP-v3 + +from .model import USGAN + +__all__ = [ + "USGAN", +] diff --git a/pypots/imputation/usgan/data.py b/pypots/imputation/usgan/data.py new file mode 100644 index 00000000..bd012c30 --- /dev/null +++ b/pypots/imputation/usgan/data.py @@ -0,0 +1,46 @@ +""" +Dataset class for model USGAN. 
+""" + +# Created by Jun Wang and Wenjie Du +# License: GLP-v3 + +from typing import Union + +from ..brits.data import DatasetForBRITS + + +class DatasetForUSGAN(DatasetForBRITS): + """Dataset class for USGAN, the same with the one for BRITS. + + Parameters + ---------- + data : dict or str, + The dataset for model input, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for input, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + return_labels : bool, default = True, + Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example, + during training of classification models, the Dataset class will return labels in __getitem__() for model input. + Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we + need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5 + files, they already have both X and y saved. But we don't read labels from the file for validating and testing + with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for + distinction. + + file_type : str, default = "h5py" + The type of the given file if train_set and val_set are path strings. + """ + + def __init__( + self, + data: Union[dict, str], + return_labels: bool = True, + file_type: str = "h5py", + ): + super().__init__(data, return_labels, file_type) diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py new file mode 100644 index 00000000..c171d810 --- /dev/null +++ b/pypots/imputation/usgan/model.py @@ -0,0 +1,539 @@ +""" +The implementation of USGAN for the partially-observed time-series imputation task. + +Refer to the paper "Miao, X., Wu, Y., Wang, J., Gao, Y., Mao, X., & Yin, J. (2021). +Generative Semi-supervised Learning for Multivariate Time Series Imputation. AAAI 2021." 
+ +""" + +# Created by Jun Wang and Wenjie Du +# License: GPL-v3 + +from typing import Union, Optional + +import h5py +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from .data import DatasetForUSGAN +from ..base import BaseNNImputer +from ..brits.model import _BRITS +from ...optim.adam import Adam +from ...optim.base import Optimizer +from ...utils.logging import logger + + +class Discriminator(nn.Module): + """model Discriminator: built on BiRNN + + Parameters + ---------- + n_features : + the feature dimension of the input + + rnn_hidden_size : + the hidden size of the RNN cell + + hint_rate : + the hint rate for the input imputed_data + + dropout_rate : + the dropout rate for the output layer + + device : + specify running the model on which device, CPU/GPU + + """ + + def __init__( + self, + n_features: int, + rnn_hidden_size: int, + hint_rate: float = 0.7, + dropout_rate: float = 0.0, + device: Union[str, torch.device] = "cpu", + ): + super().__init__() + self.hint_rate = hint_rate + self.device = device + self.biRNN = nn.GRU( + n_features * 2, rnn_hidden_size, bidirectional=True, batch_first=True + ).to(device) + self.dropout = nn.Dropout(dropout_rate).to(device) + self.read_out = nn.Linear(rnn_hidden_size * 2, n_features).to(device) + + def forward( + self, + imputed_X: torch.Tensor, + missing_mask: torch.Tensor, + ) -> torch.Tensor: + """Forward processing of USGAN Discriminator. + + Parameters + ---------- + imputed_X : torch.Tensor, + The original X with missing parts already imputed. + + missing_mask : torch.Tensor, + The missing mask of X. + + Returns + ------- + logits : torch.Tensor, + the logits of the probability of being the true value. + + """ + + hint = ( + torch.rand_like(missing_mask, dtype=torch.float, device=self.device) + < self.hint_rate + ) + hint = hint.int() + h = hint * missing_mask + (1 - hint) * 0.5 + x_in = torch.cat([imputed_X, h], dim=-1) + + out, _ = self.biRNN(x_in) + logits = self.read_out(self.dropout(out)) + return logits + + +class _USGAN(nn.Module): + """model USGAN: + USGAN consists of a generator, a discriminator, which are all built on bidirectional recurrent neural networks. 
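A note on the Discriminator defined above: it receives the imputed series together with a hint tensor h = hint * missing_mask + (1 - hint) * 0.5, so the true mask is revealed only at a random subset of positions (controlled by hint_rate) and takes the neutral value 0.5 elsewhere. A rough shape sketch (illustrative only; the numbers are placeholders):

    import torch

    disc = Discriminator(n_features=10, rnn_hidden_size=32)
    imputed_X = torch.randn(4, 48, 10)      # (batch, n_steps, n_features), placeholder data
    missing_mask = torch.ones(4, 48, 10)    # 1 = observed, 0 = missing
    logits = disc(imputed_X, missing_mask)  # -> (4, 48, 10), per-entry "observed vs. imputed" logits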
+ + Attributes + ---------- + n_steps : + sequence length (number of time steps) + + n_features : + number of features (input dimensions) + + rnn_hidden_size : + the hidden size of the RNN cell + + lambda_mse : + the weigth of the reconstruction loss + + hint_rate : + the hint rate for the discriminator + + dropout_rate : + the dropout rate for the last layer in Discriminator + + device : + specify running the model on which device, CPU/GPU + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + lambda_mse: float, + hint_rate: float = 0.7, + dropout_rate: float = 0.0, + device: Union[str, torch.device] = "cpu", + ): + super().__init__() + self.generator = _BRITS(n_steps, n_features, rnn_hidden_size, device) + self.discriminator = Discriminator( + n_features, + rnn_hidden_size, + hint_rate=hint_rate, + dropout_rate=dropout_rate, + device=device, + ) + + self.lambda_mse = lambda_mse + self.device = device + + def forward( + self, + inputs: dict, + training_object: str = "generator", + training: bool = True, + ) -> dict: + assert training_object in [ + "generator", + "discriminator", + ], 'training_object should be "generator" or "discriminator"' + + forward_X = inputs["forward"]["X"] + forward_missing_mask = inputs["forward"]["missing_mask"] + losses = {} + results = self.generator(inputs, training=training) + inputs["discrimination"] = self.discriminator(forward_X, forward_missing_mask) + if not training: + # if only run imputation operation, then no need to calculate loss + return results + + if training_object == "discriminator": + l_D = F.binary_cross_entropy_with_logits( + inputs["discrimination"], forward_missing_mask + ) + losses["discrimination_loss"] = l_D + else: + inputs["discrimination"] = inputs["discrimination"].detach() + l_G = F.binary_cross_entropy_with_logits( + inputs["discrimination"], + 1 - forward_missing_mask, + weight=1 - forward_missing_mask, + ) + loss_gene = l_G + self.lambda_mse * results["loss"] + losses["generation_loss"] = loss_gene + + losses["imputed_data"] = results["imputed_data"] + return losses + + +class USGAN(BaseNNImputer): + """The PyTorch implementation of the CRLI model :cite:`ma2021CRLI`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + rnn_hidden_size : + the hidden size of the RNN cell + + lambda_mse : + the weight of the reconstruction loss + + hint_rate : + the hint rate for the discriminator + + dropout_rate : + the dropout rate for the last layer in Discriminator + + G_steps : + The number of steps to train the generator in each iteration. + + D_steps : + The number of steps to train the discriminator in each iteration. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + G_optimizer : + The optimizer for the generator training. + If not given, will use a default Adam optimizer. + + D_optimizer : + The optimizer for the discriminator training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. 
there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + + Attributes + ---------- + model : :class:`torch.nn.Module` + The underlying CRLI model. + + optimizer : :class:`pypots.optim.Optimizer` + The optimizer for model training. + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + lambda_mse: float = 1, + hint_rate: float = 0.7, + dropout_rate: float = 0.0, + G_steps: int = 1, + D_steps: int = 1, + batch_size: int = 32, + epochs: int = 100, + patience: Optional[int] = None, + G_optimizer: Optional[Optimizer] = Adam(), + D_optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: Optional[str] = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + assert G_steps > 0 and D_steps > 0, "G_steps and D_steps should both >0" + + self.n_steps = n_steps + self.n_features = n_features + self.G_steps = G_steps + self.D_steps = D_steps + + # set up the model + self.model = _USGAN( + n_steps, + n_features, + rnn_hidden_size, + lambda_mse, + hint_rate, + dropout_rate, + self.device, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.G_optimizer = G_optimizer + self.G_optimizer.init_optimizer(self.model.generator.parameters()) + self.D_optimizer = D_optimizer + self.D_optimizer.init_optimizer(self.model.discriminator.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + # fetch data + ( + indices, + X, + missing_mask, + deltas, + back_X, + back_missing_mask, + back_deltas, + ) = self._send_data_to_given_device(data) + + # assemble input data + inputs = { + "indices": indices, + "forward": { + "X": X, + "missing_mask": missing_mask, + "deltas": deltas, + }, + "backward": { + "X": back_X, + "missing_mask": back_missing_mask, + "deltas": back_deltas, + }, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + return self._assemble_input_for_validating(data) + + def _train_model( + self, + training_loader: DataLoader, + val_loader: DataLoader 
= None, + ) -> None: + # each training starts from the very beginning, so reset the loss and model dict here + self.best_loss = float("inf") + self.best_model_dict = None + + try: + training_step = 0 + epoch_train_loss_G_collector = [] + epoch_train_loss_D_collector = [] + for epoch in range(self.epochs): + self.model.train() + for idx, data in enumerate(training_loader): + training_step += 1 + inputs = self._assemble_input_for_training(data) + + step_train_loss_G_collector = [] + step_train_loss_D_collector = [] + + if idx % self.G_steps == 0: + self.G_optimizer.zero_grad() + results = self.model.forward( + inputs, training_object="generator" + ) + results["generation_loss"].backward() + self.G_optimizer.step() + step_train_loss_G_collector.append( + results["generation_loss"].item() + ) + + if idx % self.D_steps == 0: + self.D_optimizer.zero_grad() + results = self.model.forward( + inputs, training_object="discriminator" + ) + results["discrimination_loss"].backward(retain_graph=True) + self.D_optimizer.step() + step_train_loss_D_collector.append( + results["discrimination_loss"].item() + ) + + mean_step_train_D_loss = np.mean(step_train_loss_D_collector) + mean_step_train_G_loss = np.mean(step_train_loss_G_collector) + + epoch_train_loss_D_collector.append(mean_step_train_D_loss) + epoch_train_loss_G_collector.append(mean_step_train_G_loss) + + # save training loss logs into the tensorboard file for every step if in need + # Note: the `training_step` is not the actual number of steps that Discriminator and Generator get + # trained, the actual number should be D_steps*training_step and G_steps*training_step accordingly + if self.summary_writer is not None: + loss_results = { + "generation_loss": mean_step_train_G_loss, + "discrimination_loss": mean_step_train_D_loss, + } + self._save_log_into_tb_file( + training_step, "training", loss_results + ) + mean_epoch_train_D_loss = np.mean(epoch_train_loss_D_collector) + mean_epoch_train_G_loss = np.mean(epoch_train_loss_G_collector) + logger.info( + f"epoch {epoch}: " + f"training loss_generator {mean_epoch_train_G_loss:.4f}, " + f"train loss_discriminator {mean_epoch_train_D_loss:.4f}" + ) + mean_loss = mean_epoch_train_G_loss + + if mean_loss < self.best_loss: + self.best_loss = mean_loss + self.best_model_dict = self.model.state_dict() + self.patience = self.original_patience + # save the model if necessary + self._auto_save_model_if_necessary( + training_finished=False, + saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", + ) + else: + self.patience -= 1 + if self.patience == 0: + logger.info( + "Exceeded the training patience. Terminating the training procedure..." + ) + break + except Exception as e: + logger.error(f"Exception: {e}") + if self.best_model_dict is None: + raise RuntimeError( + "Training got interrupted. Model was not trained. Please investigate the error printed above." + ) + else: + RuntimeWarning( + "Training got interrupted. Please investigate the error printed above.\n" + "Model got trained and will load the best checkpoint so far for testing.\n" + "If you don't want it, please try fit() again." + ) + + if np.equal(self.best_loss, float("inf")): + raise ValueError("Something is wrong. 
best_loss is Nan after training.") + + logger.info("Finished training.") + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForUSGAN( + train_set, return_labels=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if isinstance(val_set, str): + with h5py.File(val_set, "r") as hf: + # Here we read the whole validation set from the file to mask a portion for validation. + # In PyPOTS, using a file usually because the data is too big. However, the validation set is + # generally shouldn't be too large. For example, we have 1 billion samples for model training. + # We won't take 20% of them as the validation set because we want as much as possible data for the + # training stage to enhance the model's generalization ability. Therefore, 100,000 representative + # samples will be enough to validate the model. + val_set = { + "X": hf["X"][:], + "X_intact": hf["X_intact"][:], + "indicating_mask": hf["indicating_mask"][:], + } + val_set = DatasetForUSGAN(val_set, return_labels=False, file_type=file_type) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(training_finished=True) + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + self.model.eval() # set the model as eval status to freeze it. 
+ test_set = DatasetForUSGAN(X, return_labels=False, file_type=file_type) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputed_data = results["imputed_data"] + imputation_collector.append(imputed_data) + + imputation_collector = torch.cat(imputation_collector) + return imputation_collector.cpu().detach().numpy() diff --git a/tests/test_imputation.py b/tests/test_imputation.py index 6094ce62..64a0b1ff 100644 --- a/tests/test_imputation.py +++ b/tests/test_imputation.py @@ -15,6 +15,8 @@ from pypots.imputation import ( SAITS, Transformer, + USGAN, + GPVAE, BRITS, MRNN, LOCF, @@ -194,6 +196,151 @@ def test_3_saving_path(self): self.transformer.load_model(saved_model_path) +class TestUSGAN(unittest.TestCase): + logger.info("Running tests for an imputation model US-GAN...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "US-GAN") + model_save_name = "saved_USGAN_model.pypots" + + # initialize an Adam optimizer + G_optimizer = Adam(lr=0.001, weight_decay=1e-5) + D_optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a US-GAN model + us_gan = USGAN( + DATA["n_steps"], + DATA["n_features"], + 256, + epochs=EPOCH, + saving_path=saving_path, + G_optimizer=G_optimizer, + D_optimizer=D_optimizer, + ) + + @pytest.mark.xdist_group(name="imputation-usgan") + def test_0_fit(self): + self.us_gan.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-usgan") + def test_1_impute(self): + imputed_X = self.us_gan.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." 
+ test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"US-GAN test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-usgan") + def test_2_parameters(self): + assert hasattr(self.us_gan, "model") and self.us_gan.model is not None + + assert ( + hasattr(self.us_gan, "G_optimizer") and self.us_gan.G_optimizer is not None + ) + assert ( + hasattr(self.us_gan, "D_optimizer") and self.us_gan.D_optimizer is not None + ) + + assert hasattr(self.us_gan, "best_loss") + self.assertNotEqual(self.us_gan.best_loss, float("inf")) + + assert ( + hasattr(self.us_gan, "best_model_dict") + and self.us_gan.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-usgan") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.us_gan) + + # save the trained model into file, and check if the path exists + self.us_gan.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.us_gan.load_model(saved_model_path) + + +class TestGPVAE(unittest.TestCase): + logger.info("Running tests for an imputation model GP-VAE...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "GP-VAE") + model_save_name = "saved_GPVAE_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a GP-VAE model + gp_vae = GPVAE( + DATA["n_steps"], + DATA["n_features"], + 256, + epochs=EPOCH, + saving_path=saving_path, + optimizer=optimizer, + ) + + @pytest.mark.xdist_group(name="imputation-gpvae") + def test_0_fit(self): + self.gp_vae.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-gpvae") + def test_1_impute(self): + imputed_X = self.gp_vae.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." 
+ test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"GP-VAE test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-gpvae") + def test_2_parameters(self): + assert hasattr(self.gp_vae, "model") and self.gp_vae.model is not None + + assert hasattr(self.gp_vae, "optimizer") and self.gp_vae.optimizer is not None + + assert hasattr(self.gp_vae, "best_loss") + self.assertNotEqual(self.gp_vae.best_loss, float("inf")) + + assert ( + hasattr(self.gp_vae, "best_model_dict") + and self.gp_vae.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-GPVAE") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.gp_vae) + + # save the trained model into file, and check if the path exists + self.gp_vae.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.gp_vae.load_model(saved_model_path) + + class TestBRITS(unittest.TestCase): logger.info("Running tests for an imputation model BRITS...") @@ -210,7 +357,7 @@ class TestBRITS(unittest.TestCase): DATA["n_features"], 256, epochs=EPOCH, - saving_path=f"{RESULT_SAVING_DIR_FOR_IMPUTATION}/BRITS", + saving_path=saving_path, optimizer=optimizer, ) @@ -279,7 +426,7 @@ class TestMRNN(unittest.TestCase): DATA["n_features"], 256, epochs=EPOCH, - saving_path=f"{RESULT_SAVING_DIR_FOR_IMPUTATION}/MRNN", + saving_path=saving_path, optimizer=optimizer, )
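For completeness, a minimal usage sketch of the two imputers added in this patch, mirroring the tests above (all data variables and hyperparameter values are placeholders; inputs follow the (n_samples, n_steps, n_features) convention with NaNs marking missing values):

    import numpy as np
    from pypots.imputation import GPVAE, USGAN
    from pypots.optim.adam import Adam

    # toy placeholder data: 100 samples, 48 steps, 10 features, ~20% missing
    X = np.random.randn(100, 48, 10)
    X[np.random.rand(*X.shape) < 0.2] = np.nan
    train_set = {"X": X}

    gp_vae = GPVAE(n_steps=48, n_features=10, latent_size=16, epochs=5, optimizer=Adam(lr=1e-3))
    gp_vae.fit(train_set)
    imputed = gp_vae.impute(train_set)  # ndarray of shape (100, 48, 10) with no NaNs left

    us_gan = USGAN(n_steps=48, n_features=10, rnn_hidden_size=64, epochs=5,
                   G_optimizer=Adam(lr=1e-3), D_optimizer=Adam(lr=1e-3))
    us_gan.fit(train_set)
    imputed = us_gan.impute(train_set)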