From 04c687f326935ad59ecabc1066abcfeee0ed487c Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Tue, 3 Oct 2023 23:09:20 +0800 Subject: [PATCH] feat: add predict() for all models; --- pypots/classification/base.py | 12 ++++++++ pypots/classification/brits/model.py | 32 +++++++++++++++++---- pypots/classification/grud/model.py | 31 ++++++++++++++++---- pypots/classification/raindrop/model.py | 31 ++++++++++++++++---- pypots/classification/template/model.py | 6 +++- pypots/clustering/base.py | 38 +++++++++++++++++++++++-- pypots/clustering/crli/model.py | 34 +++++++++++++++++----- pypots/clustering/template/model.py | 6 +++- pypots/clustering/vader/model.py | 34 +++++++++++++++++----- pypots/forecasting/base.py | 12 ++++++++ pypots/forecasting/bttf/model.py | 30 ++++++++++++++----- pypots/forecasting/template/model.py | 6 +++- pypots/imputation/base.py | 26 +++++++++++++++++ pypots/imputation/brits/model.py | 8 +++--- pypots/imputation/gpvae/model.py | 7 +++-- pypots/imputation/locf/model.py | 5 ++-- pypots/imputation/mrnn/model.py | 7 +++-- pypots/imputation/saits/model.py | 7 +++-- pypots/imputation/transformer/model.py | 7 +++-- pypots/imputation/usgan/model.py | 6 ++-- 20 files changed, 280 insertions(+), 65 deletions(-) diff --git a/pypots/classification/base.py b/pypots/classification/base.py index 34d6743f..b52dfc13 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -95,6 +95,14 @@ def fit( """ raise NotImplementedError + @abstractmethod + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + raise NotImplementedError + @abstractmethod def classify( self, @@ -117,6 +125,8 @@ def classify( array-like, shape [n_samples], Classification results of the given samples. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError @@ -402,4 +412,6 @@ def classify( array-like, shape [n_samples], Classification results of the given samples. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError diff --git a/pypots/classification/brits/model.py b/pypots/classification/brits/model.py index c9419b4a..af1faec8 100644 --- a/pypots/classification/brits/model.py +++ b/pypots/classification/brits/model.py @@ -16,6 +16,7 @@ from typing import Optional, Union +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -29,6 +30,7 @@ ) from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger class _BRITS(imputation_BRITS, nn.Module): @@ -326,23 +328,41 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def classify(self, X: Union[dict, str], file_type: str = "h5py"): + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForBRITS(X, return_labels=False, file_type=file_type) + test_set = DatasetForBRITS(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, ) - prediction_collector = [] + classification_collector = [] with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) results = self.model.forward(inputs, training=False) classification_pred = results["classification_pred"] - prediction_collector.append(classification_pred) + classification_collector.append(classification_pred) + + classification = torch.cat(classification_collector).cpu().detach().numpy() + result_dict = { + "classification": classification, + } + return result_dict - predictions = torch.cat(prediction_collector) - return predictions.cpu().detach().numpy() + def classify( + self, + X: Union[dict, str], + file_type: str = "h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + result_dict = self.predict(X, file_type=file_type) + return result_dict["classification"] diff --git a/pypots/classification/grud/model.py b/pypots/classification/grud/model.py index 7d122d8f..60fd6482 100644 --- a/pypots/classification/grud/model.py +++ b/pypots/classification/grud/model.py @@ -23,6 +23,7 @@ from ...imputation.brits.modules import TemporalDecay from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger class _GRUD(nn.Module): @@ -303,23 +304,41 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForGRUD(X, return_labels=False, file_type=file_type) + test_set = DatasetForGRUD(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, ) - prediction_collector = [] + classification_collector = [] with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) results = self.model.forward(inputs, training=False) prediction = results["classification_pred"] - prediction_collector.append(prediction) + classification_collector.append(prediction) + + classification = torch.cat(classification_collector).cpu().detach().numpy() + result_dict = { + "classification": classification, + } + return result_dict - predictions = torch.cat(prediction_collector) - return predictions.cpu().detach().numpy() + def classify( + self, + X: Union[dict, str], + file_type: str = "h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + result_dict = self.predict(X, file_type=file_type) + return result_dict["classification"] diff --git a/pypots/classification/raindrop/model.py b/pypots/classification/raindrop/model.py index 06d68e8b..8b9d1670 100644 --- a/pypots/classification/raindrop/model.py +++ b/pypots/classification/raindrop/model.py @@ -519,9 +519,13 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForGRUD(X, return_labels=False, file_type=file_type) + test_set = DatasetForGRUD(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -529,13 +533,28 @@ def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: num_workers=self.num_workers, ) - prediction_collector = [] + classification_collector = [] with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) results = self.model.forward(inputs, training=False) prediction = results["classification_pred"] - prediction_collector.append(prediction) + classification_collector.append(prediction) + + classification = torch.cat(classification_collector).cpu().detach().numpy() + + result_dict = { + "classification": classification, + } + return result_dict - predictions = torch.cat(prediction_collector) - return predictions.cpu().detach().numpy() + def classify( + self, + X: Union[dict, str], + file_type: str = "h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + result_dict = self.predict(X, file_type=file_type) + return result_dict["classification"] diff --git a/pypots/classification/template/model.py b/pypots/classification/template/model.py index f7e2d15a..9e625ab2 100644 --- a/pypots/classification/template/model.py +++ b/pypots/classification/template/model.py @@ -101,5 +101,9 @@ def fit( ) -> None: raise NotImplementedError - def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: raise NotImplementedError diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py index c9a27ad4..43d2e82e 100644 --- a/pypots/clustering/base.py +++ b/pypots/clustering/base.py @@ -64,6 +64,7 @@ def __init__( def fit( self, train_set: Union[dict, str], + val_set: Union[dict, str] = None, file_type: str = "h5py", ) -> None: """Train the cluster. @@ -78,11 +79,29 @@ def fit( If it is a path string, the path should point to a data file, e.g. a h5 file, which contains key-value pairs like a dict, and it has to include the key 'X'. + val_set : + The dataset for model validating, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + file_type : - The type of the given file if train_set is a path string. + The type of the given file if train_set and val_set are path strings. + """ raise NotImplementedError + @abstractmethod + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + raise NotImplementedError + @abstractmethod def cluster( self, @@ -105,6 +124,8 @@ def cluster( array-like, Clustering results. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError @@ -332,6 +353,7 @@ def _train_model( def fit( self, train_set: Union[dict, str], + val_set: Union[dict, str] = None, file_type: str = "h5py", ) -> None: """Train the cluster. @@ -346,8 +368,18 @@ def fit( If it is a path string, the path should point to a data file, e.g. a h5 file, which contains key-value pairs like a dict, and it has to include the key 'X'. + val_set : + The dataset for model validating, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + file_type : - The type of the given file if train_set is a path string. + The type of the given file if train_set and val_set are path strings. + """ raise NotImplementedError @@ -373,4 +405,6 @@ def cluster( array-like, Clustering results. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index 35be0034..46805c38 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -415,14 +415,14 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def cluster( + def predict( self, - X: Union[dict, str], + test_set: Union[dict, str], file_type: str = "h5py", return_latent: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForCRLI(X, return_labels=False, file_type=file_type) + test_set = DatasetForCRLI(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -443,13 +443,33 @@ def cluster( clustering_latent = ( torch.cat(clustering_latent_collector).cpu().detach().numpy() ) - clustering_results = self.model.kmeans.fit_predict(clustering_latent) + clustering = self.model.kmeans.fit_predict(clustering_latent) latent_collector = { "clustering_latent": clustering_latent, "imputation": imputation, } + result_dict = { + "clustering": clustering, + } + + if return_latent: + result_dict["latent"] = latent_collector + + return result_dict + + def cluster( + self, + X: Union[dict, str], + file_type: str = "h5py", + return_latent: bool = False, + ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + + result_dict = self.predict(X, file_type, return_latent) if return_latent: - return clustering_results, latent_collector + return result_dict["clustering"], result_dict["latent"] - return clustering_results + return result_dict["clustering"] diff --git a/pypots/clustering/template/model.py b/pypots/clustering/template/model.py index d9b40cff..e51beb6d 100644 --- a/pypots/clustering/template/model.py +++ b/pypots/clustering/template/model.py @@ -101,5 +101,9 @@ def fit( ) -> None: raise NotImplementedError - def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: raise NotImplementedError diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index 9e9198f5..d6d910c5 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -604,14 +604,14 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def cluster( + def predict( self, - X: Union[dict, str], + test_set: Union[dict, str], file_type: str = "h5py", return_latent: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForVaDER(X, return_labels=False, file_type=file_type) + test_set = DatasetForVaDER(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -668,7 +668,7 @@ def func_to_apply( clustering_results = np.argmax(p, axis=0) clustering_results_collector.append(clustering_results) - clustering_results = np.concatenate(clustering_results_collector) + clustering = np.concatenate(clustering_results_collector) latent_collector = { "mu_tilde": np.concatenate(mu_tilde_collector), "stddev_tilde": np.concatenate(stddev_tilde_collector), @@ -679,7 +679,27 @@ def func_to_apply( "imputation": np.concatenate(imputed_X_collector), } + result_dict = { + "clustering": clustering, + } + + if return_latent: + result_dict["latent"] = latent_collector + + return result_dict + + def cluster( + self, + X: Union[dict, str], + file_type: str = "h5py", + return_latent: bool = False, + ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + + result_dict = self.predict(X, file_type, return_latent) if return_latent: - return clustering_results, latent_collector + return result_dict["clustering"], result_dict["latent"] - return clustering_results + return result_dict["clustering"] diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index 150aeacf..7189dabb 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -88,6 +88,14 @@ def fit( """ raise NotImplementedError + @abstractmethod + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + raise NotImplementedError + @abstractmethod def forecast( self, @@ -109,6 +117,8 @@ def forecast( array-like, shape [n_samples, prediction_horizon, n_features], Forecasting results. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError @@ -380,4 +390,6 @@ def forecast( array-like, shape [n_samples, prediction_horizon, n_features], Forecasting results. """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. raise NotImplementedError diff --git a/pypots/forecasting/bttf/model.py b/pypots/forecasting/bttf/model.py index 57712ee1..3f5aa9e6 100644 --- a/pypots/forecasting/bttf/model.py +++ b/pypots/forecasting/bttf/model.py @@ -34,6 +34,7 @@ ar4cast, ) from ..base import BaseForecaster +from ...utils.logging import logger def _BTTF( @@ -358,16 +359,16 @@ def fit( """ warnings.warn("Please run func forecast(X) directly.") - def forecast( + def predict( self, - X: Union[dict, str], + test_set: Union[dict, str], file_type: str = "h5py", - ) -> np.ndarray: + ) -> dict: assert not isinstance( - X, str + test_set, str ), "BTTF so far does not accept file input. It needs a specified Dataset class." - X = X["X"] + X = test_set["X"] X = X.transpose((0, 2, 1)) pred = BTTF_forecast( @@ -380,5 +381,20 @@ def forecast( self.burn_iter, self.gibbs_iter, ) - pred = pred.transpose((0, 2, 1)) - return pred + forecasting = pred.transpose((0, 2, 1)) + result_dict = { + "forecasting": forecasting, + } + return result_dict + + def forecast( + self, + X: Union[dict, str], + file_type: str = "h5py", + ) -> np.ndarray: + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + result_dict = self.predict(X, file_type=file_type) + forecasting = result_dict["forecasting"] + return forecasting diff --git a/pypots/forecasting/template/model.py b/pypots/forecasting/template/model.py index f761d795..3b665117 100644 --- a/pypots/forecasting/template/model.py +++ b/pypots/forecasting/template/model.py @@ -99,5 +99,9 @@ def fit( ) -> None: raise NotImplementedError - def forecast(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: raise NotImplementedError diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index 4519b0de..8ee46619 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -396,3 +396,29 @@ def fit( """ raise NotImplementedError + + @abstractmethod + def impute( + self, + X: Union[dict, str], + file_type: str = "h5py", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Parameters + ---------- + X : + The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (time steps), n_features], + Imputed data. + """ + # this is for old API compatibility, will be removed in the future. + # Please implement predict() instead. + raise NotImplementedError diff --git a/pypots/imputation/brits/model.py b/pypots/imputation/brits/model.py index 7806871d..54a47331 100644 --- a/pypots/imputation/brits/model.py +++ b/pypots/imputation/brits/model.py @@ -27,8 +27,8 @@ from ..base import BaseNNImputer from ...optim.adam import Adam from ...optim.base import Optimizer -from ...utils.metrics import cal_mae from ...utils.logging import logger +from ...utils.metrics import cal_mae class RITS(nn.Module): @@ -534,7 +534,7 @@ def predict( file_type: str = "h5py", ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForBRITS(X, return_labels=False, file_type=file_type) + test_set = DatasetForBRITS(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -550,9 +550,9 @@ def predict( imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) - imputation_collector = torch.cat(imputation_collector).cpu().detach().numpy() + imputation = torch.cat(imputation_collector).cpu().detach().numpy() result_dict = { - "imputation": imputation_collector, + "imputation": imputation, } return result_dict diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index 49fc58b5..56602c72 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -30,6 +30,7 @@ from ..base import BaseNNImputer from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger class _GPVAE(nn.Module): @@ -425,7 +426,7 @@ def predict( file_type="h5py", ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForGPVAE(X, return_labels=False, file_type=file_type) + test_set = DatasetForGPVAE(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -441,9 +442,9 @@ def predict( imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) - imputation_collector = torch.cat(imputation_collector).cpu().detach().numpy() + imputation = torch.cat(imputation_collector).cpu().detach().numpy() result_dict = { - "imputation": imputation_collector, + "imputation": imputation, } return result_dict diff --git a/pypots/imputation/locf/model.py b/pypots/imputation/locf/model.py index 37fea272..286cb0ec 100644 --- a/pypots/imputation/locf/model.py +++ b/pypots/imputation/locf/model.py @@ -13,6 +13,7 @@ import torch from ..base import BaseImputer +from ...utils.logging import logger class LOCF(BaseImputer): @@ -126,8 +127,8 @@ def predict( test_set: Union[dict, str], file_type: str = "h5py", ) -> dict: - assert not isinstance(X, str) - X = X["X"] + assert not isinstance(test_set, str) + X = test_set["X"] assert len(X.shape) == 3, ( f"Input X should have 3 dimensions [n_samples, n_steps, n_features], " diff --git a/pypots/imputation/mrnn/model.py b/pypots/imputation/mrnn/model.py index 115f0fbc..afdde3b4 100644 --- a/pypots/imputation/mrnn/model.py +++ b/pypots/imputation/mrnn/model.py @@ -21,6 +21,7 @@ from ..base import BaseNNImputer from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger from ...utils.metrics import cal_rmse @@ -292,7 +293,7 @@ def predict( file_type="h5py", ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForMRNN(X, return_labels=False, file_type=file_type) + test_set = DatasetForMRNN(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -308,9 +309,9 @@ def predict( imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) - imputation_collector = torch.cat(imputation_collector).cpu().detach().numpy() + imputation = torch.cat(imputation_collector).cpu().detach().numpy() result_dict = { - "imputation": imputation_collector, + "imputation": imputation, } return result_dict diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index 21346a9d..21e72cff 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -28,6 +28,7 @@ from ...modules.self_attention import EncoderLayer, PositionalEncoding from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger from ...utils.metrics import cal_mae @@ -447,7 +448,7 @@ def predict( ) -> dict: # Step 1: wrap the input data with classes Dataset and DataLoader self.model.eval() # set the model as eval status to freeze it. - test_set = BaseDataset(X, return_labels=False, file_type=file_type) + test_set = BaseDataset(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -467,9 +468,9 @@ def predict( imputation_collector.append(imputed_data) # Step 3: output collection and return - imputation_collector = torch.cat(imputation_collector).cpu().detach().numpy() + imputation = torch.cat(imputation_collector).cpu().detach().numpy() result_dict = { - "imputation": imputation_collector, + "imputation": imputation, } return result_dict diff --git a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py index f2ed2ea7..e7f9c02b 100644 --- a/pypots/imputation/transformer/model.py +++ b/pypots/imputation/transformer/model.py @@ -27,6 +27,7 @@ from ...modules.self_attention import EncoderLayer, PositionalEncoding from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger from ...utils.metrics import cal_mae @@ -369,7 +370,7 @@ def predict( file_type: str = "h5py", ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = BaseDataset(X, return_labels=False, file_type=file_type) + test_set = BaseDataset(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -385,9 +386,9 @@ def predict( imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) - imputation_collector = torch.cat(imputation_collector).cpu().detach().numpy() + imputation = torch.cat(imputation_collector).cpu().detach().numpy() result_dict = { - "imputation": imputation_collector, + "imputation": imputation, } return result_dict diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index b0ed2d24..f1a6ab27 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -492,7 +492,7 @@ def predict( file_type="h5py", ) -> dict: self.model.eval() # set the model as eval status to freeze it. - test_set = DatasetForUSGAN(X, return_labels=False, file_type=file_type) + test_set = DatasetForUSGAN(test_set, return_labels=False, file_type=file_type) test_loader = DataLoader( test_set, batch_size=self.batch_size, @@ -508,9 +508,9 @@ def predict( imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) - imputation_collector = torch.cat(imputation_collector).cpu().detach().numpy() + imputation = torch.cat(imputation_collector).cpu().detach().numpy() result_dict = { - "imputation": imputation_collector, + "imputation": imputation, } return result_dict