Skip to content

Commit

Permalink
feat: add predict() as a unified function for all models;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Oct 2, 2023
1 parent f8e0d4a commit 5557d36
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 87 deletions.
141 changes: 113 additions & 28 deletions pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

import os
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from typing import Optional, Union

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

Expand Down Expand Up @@ -209,6 +211,33 @@ def _save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None
if ("loss" in item_name) or ("error" in item_name):
self.summary_writer.add_scalar(f"{stage}/{item_name}", loss.sum(), step)

def _auto_save_model_if_necessary(
    self,
    training_finished: bool = True,
    saving_name: Optional[str] = None,
) -> None:
    """Save the current model into a file if the saving strategy calls for it.

    Nothing is saved unless both ``self.saving_path`` and
    ``self.model_saving_strategy`` are set.

    Parameters
    ----------
    training_finished :
        Whether the training is already finished when this function is invoked.
        The saving strategy "better" only works when training_finished is False.
        The saving strategy "best" only works when training_finished is True.

    saving_name :
        The file name of the saved model. When None, the class name of
        ``self`` is used instead.
    """
    if self.saving_path is None or self.model_saving_strategy is None:
        return

    # "better" applies while training is still running,
    # "best" applies once training has finished.
    matching_strategy = "best" if training_finished else "better"
    if self.model_saving_strategy == matching_strategy:
        name = self.__class__.__name__ if saving_name is None else saving_name
        self.save_model(self.saving_path, name)

def save_model(
self,
saving_dir: str,
Expand Down Expand Up @@ -258,33 +287,6 @@ def save_model(
f'Failed to save the model to "{saving_path}" because of the below error! \n{e}'
)

def _auto_save_model_if_necessary(
    self,
    training_finished: bool = True,
    saving_name: Optional[str] = None,
) -> None:
    """Automatically save the current model into a file if in need.

    The model is only saved when both ``self.saving_path`` and
    ``self.model_saving_strategy`` are not None.

    Parameters
    ----------
    training_finished :
        Whether the training is already finished when this function is invoked.
        The saving_strategy "better" only works when training_finished is False.
        The saving_strategy "best" only works when training_finished is True.

    saving_name :
        The file name of the saved model. Falls back to the class name of
        ``self`` when it is None.
    """
    if self.saving_path is None or self.model_saving_strategy is None:
        return

    # Pick the one strategy that is allowed to save at this stage of training.
    expected_strategy = "best" if training_finished else "better"
    if self.model_saving_strategy != expected_strategy:
        return

    file_name = self.__class__.__name__ if saving_name is None else saving_name
    self.save_model(self.saving_path, file_name)

def load_model(self, model_path: str) -> None:
"""Load the saved model from a disk file.
Expand Down Expand Up @@ -317,6 +319,72 @@ def load_model(self, model_path: str) -> None:
raise e
logger.info(f"Model loaded successfully from {model_path}.")

@abstractmethod
def fit(
    self,
    train_set: Union[dict, str],
    val_set: Optional[Union[dict, str]] = None,
    file_type: str = "h5py",
) -> None:
    """Train the model on the given data. Concrete models must override this.

    Parameters
    ----------
    train_set : dict or str
        The dataset for model training, given either as a dictionary or as a
        path string locating a data file (e.g. a h5 file) that contains
        key-value pairs like a dict.
        'X' should be array-like of shape
        [n_samples, sequence length (time steps), n_features]; it is the
        time-series data for training and can contain missing values.
        'y' should be array-like of shape [n_samples] and holds the
        classification labels of X.
        NOTE(review): presumably 'y' only matters for models that consume
        labels -- confirm per subclass.

    val_set : dict or str, optional
        The dataset for model validating, in the same format as ``train_set``.

    file_type : str
        The type of the given file if train_set and val_set are path strings.

    Raises
    ------
    NotImplementedError
        Always, because this definition is only an abstract placeholder.
    """
    raise NotImplementedError

@abstractmethod
def predict(
    self,
    test_set: Union[dict, str],
    file_type: str = "h5py",
) -> dict:
    """Make predictions for the input data with the trained model.

    Concrete models must override this.

    Parameters
    ----------
    test_set : dict or str
        The dataset to predict on, given either as a dictionary or as a path
        string locating a data file (e.g. a h5 file) that contains key-value
        pairs like a dict.
        'X' should be array-like of shape
        [n_samples, sequence length (time steps), n_features] and can contain
        missing values.
        NOTE(review): the previous text also required a 'y' key of shape
        [n_samples] with classification labels -- confirm whether labels are
        really needed at prediction time.

    file_type : str
        The type of the given file if test_set is a path string.

    Returns
    -------
    result_dict : dict
        Prediction results for the given samples. Possible keys are
        'imputation', 'classification', 'clustering', and 'forecasting';
        only the keys of the tasks supported by the model are returned.

    Raises
    ------
    NotImplementedError
        Always, because this definition is only an abstract placeholder.
    """
    raise NotImplementedError


class BaseNNModel(BaseModel):
"""The abstract class for all neural-network models.
Expand Down Expand Up @@ -400,7 +468,7 @@ def __init__(
else:
assert (
patience <= epochs
), f"patience must be smaller than epoches which is {epochs}, but got patience={patience}"
), f"patience must be smaller than epochs which is {epochs}, but got patience={patience}"

# training hyper-parameters
self.batch_size = batch_size
Expand All @@ -421,3 +489,20 @@ def _print_model_size(self) -> None:
logger.info(
f"Model initialized successfully with the number of trainable parameters: {num_params:,}"
)

@abstractmethod
def fit(
    self,
    train_set: Union[dict, str],
    val_set: Optional[Union[dict, str]] = None,
    file_type: str = "h5py",
) -> None:
    """Abstract training entry point; subclasses must provide the implementation.

    Parameters
    ----------
    train_set : dict or str
        The training data, as a dictionary or a path string.

    val_set : dict or str, optional
        The validation data, as a dictionary or a path string.

    file_type : str
        The type of the data file; "h5py" by default.
    """
    raise NotImplementedError

@abstractmethod
def predict(
    self,
    test_set: Union[dict, str],
    file_type: str = "h5py",
) -> dict:
    """Abstract prediction entry point; subclasses must provide the implementation.

    Parameters
    ----------
    test_set : dict or str
        The data to predict on, as a dictionary or a path string.

    file_type : str
        The type of the data file; "h5py" by default.

    Returns
    -------
    dict
        The prediction results (produced by the overriding implementation).
    """
    raise NotImplementedError
34 changes: 10 additions & 24 deletions pypots/imputation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ def fit(
"""
raise NotImplementedError

@abstractmethod
def predict(
    self,
    test_set: Union[dict, str],
    file_type: str = "h5py",
) -> dict:
    """Predict on ``test_set``; to be implemented by each concrete model.

    Parameters
    ----------
    test_set : dict or str
        The data to predict on, as a dictionary or a path string.

    file_type : str
        The type of the data file; "h5py" by default.

    Returns
    -------
    dict
        The prediction results (produced by the overriding implementation).
    """
    raise NotImplementedError

@abstractmethod
def impute(
self,
Expand All @@ -117,6 +125,8 @@ def impute(
array-like, shape [n_samples, sequence length (time steps), n_features],
Imputed data.
"""
# This method is kept for old API compatibility and will be removed in the future.
# Please implement predict() instead.
raise NotImplementedError


Expand Down Expand Up @@ -386,27 +396,3 @@ def fit(
"""
raise NotImplementedError

@abstractmethod
def impute(
    self,
    X: Union[dict, str],
    file_type: str = "h5py",
) -> np.ndarray:
    """Impute missing values in the given data with the trained model.

    Concrete models must override this.

    Parameters
    ----------
    X :
        The data samples for testing. Should be array-like of shape
        [n_samples, sequence length (time steps), n_features], or a path
        string locating a data file, e.g. h5 file.

    file_type :
        The type of the given file if X is a path string.

    Returns
    -------
    array-like, shape [n_samples, sequence length (time steps), n_features],
        Imputed data.

    Raises
    ------
    NotImplementedError
        Always, because this definition is only an abstract placeholder.
    """
    raise NotImplementedError
27 changes: 21 additions & 6 deletions pypots/imputation/brits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.metrics import cal_mae
from ...utils.logging import logger


class RITS(nn.Module):
Expand Down Expand Up @@ -527,11 +528,11 @@ def fit(
# Step 3: save the model if necessary
self._auto_save_model_if_necessary(training_finished=True)

def impute(
def predict(
self,
X: Union[dict, str],
file_type="h5py",
) -> np.ndarray:
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForBRITS(X, return_labels=False, file_type=file_type)
test_loader = DataLoader(
Expand All @@ -549,5 +550,19 @@ def impute(
imputed_data = results["imputed_data"]
imputation_collector.append(imputed_data)

imputation_collector = torch.cat(imputation_collector)
return imputation_collector.cpu().detach().numpy()
imputation_collector = torch.cat(imputation_collector).cpu().detach().numpy()
result_dict = {
"imputation": imputation_collector,
}
return result_dict

def impute(
    self,
    X: Union[dict, str],
    file_type="h5py",
) -> np.ndarray:
    """Deprecated wrapper around ``predict`` that returns only the imputation.

    Logs a deprecation warning, forwards ``X`` and ``file_type`` to
    ``self.predict``, and returns the "imputation" entry of the result.

    Parameters
    ----------
    X :
        The data passed through to ``predict`` as its test set.

    file_type :
        The type of the given file, passed through to ``predict``.

    Returns
    -------
    The value stored under the "imputation" key of the ``predict`` result.
    """
    deprecation_notice = (
        "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
    )
    logger.warning(deprecation_notice)
    return self.predict(X, file_type=file_type)["imputation"]
24 changes: 19 additions & 5 deletions pypots/imputation/gpvae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,11 @@ def fit(
# Step 3: save the model if necessary
self._auto_save_model_if_necessary(training_finished=True)

def impute(
def predict(
self,
X: Union[dict, str],
test_set: Union[dict, str],
file_type="h5py",
) -> np.ndarray:
) -> dict:
self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForGPVAE(X, return_labels=False, file_type=file_type)
test_loader = DataLoader(
Expand All @@ -441,5 +441,19 @@ def impute(
imputed_data = results["imputed_data"]
imputation_collector.append(imputed_data)

imputation_collector = torch.cat(imputation_collector)
return imputation_collector.cpu().detach().numpy()
imputation_collector = torch.cat(imputation_collector).cpu().detach().numpy()
result_dict = {
"imputation": imputation_collector,
}
return result_dict

def impute(
    self,
    X: Union[dict, str],
    file_type="h5py",
) -> np.ndarray:
    """Return the imputation produced by ``predict`` (deprecated entry point).

    A deprecation warning is logged on every call; the arguments are handed
    to ``self.predict`` unchanged and only the "imputation" entry of its
    result is returned.

    Parameters
    ----------
    X :
        The data passed through to ``predict`` as its test set.

    file_type :
        The type of the given file, passed through to ``predict``.

    Returns
    -------
    The value stored under the "imputation" key of the ``predict`` result.
    """
    logger.warning(
        "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
    )
    prediction = self.predict(X, file_type=file_type)
    imputation = prediction["imputation"]
    return imputation
22 changes: 18 additions & 4 deletions pypots/imputation/locf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ def _locf_torch(self, X: torch.Tensor) -> torch.Tensor:

return X_imputed

def impute(
def predict(
self,
X: Union[dict, str],
test_set: Union[dict, str],
file_type: str = "h5py",
) -> np.ndarray:
) -> dict:
assert not isinstance(X, str)
X = X["X"]

Expand All @@ -145,4 +145,18 @@ def impute(
"X must be type of list/np.ndarray/torch.Tensor, " f"but got {type(X)}"
)

return imputed_data
result_dict = {
"imputation": imputed_data,
}
return result_dict

def impute(
    self,
    X: Union[dict, str],
    file_type="h5py",
) -> np.ndarray:
    """Old-API alias kept for compatibility; delegates to ``predict``.

    Emits a deprecation warning through the logger, calls ``self.predict``
    with the same arguments, and returns the "imputation" entry of the
    returned dictionary.

    Parameters
    ----------
    X :
        The data passed through to ``predict`` as its test set.

    file_type :
        The type of the given file, passed through to ``predict``.

    Returns
    -------
    The value stored under the "imputation" key of the ``predict`` result.
    """
    message = "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
    logger.warning(message)
    return self.predict(X, file_type=file_type)["imputation"]
24 changes: 19 additions & 5 deletions pypots/imputation/mrnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,11 @@ def fit(
# Step 3: save the model if necessary
self._auto_save_model_if_necessary(training_finished=True)

def impute(
def predict(
self,
X: Union[dict, str],
test_set: Union[dict, str],
file_type="h5py",
) -> np.ndarray:
) -> dict:
self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForMRNN(X, return_labels=False, file_type=file_type)
test_loader = DataLoader(
Expand All @@ -308,5 +308,19 @@ def impute(
imputed_data = results["imputed_data"]
imputation_collector.append(imputed_data)

imputation_collector = torch.cat(imputation_collector)
return imputation_collector.cpu().detach().numpy()
imputation_collector = torch.cat(imputation_collector).cpu().detach().numpy()
result_dict = {
"imputation": imputation_collector,
}
return result_dict

def impute(
    self,
    X: Union[dict, str],
    file_type="h5py",
) -> np.ndarray:
    """Deprecated: use ``predict`` instead.

    This method logs a deprecation warning, runs ``self.predict`` on the
    given input, and returns the "imputation" entry of its result.

    Parameters
    ----------
    X :
        The data passed through to ``predict`` as its test set.

    file_type :
        The type of the given file, passed through to ``predict``.

    Returns
    -------
    The value stored under the "imputation" key of the ``predict`` result.
    """
    logger.warning(
        "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
    )
    predicted = self.predict(X, file_type=file_type)
    return predicted["imputation"]
Loading

0 comments on commit 5557d36

Please sign in to comment.