Skip to content

Commit

Permalink
feat: add predict() for all models;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Oct 3, 2023
1 parent 5557d36 commit 04c687f
Show file tree
Hide file tree
Showing 20 changed files with 280 additions and 65 deletions.
12 changes: 12 additions & 0 deletions pypots/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ def fit(
"""
raise NotImplementedError

@abstractmethod
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
raise NotImplementedError

@abstractmethod
def classify(
self,
Expand All @@ -117,6 +125,8 @@ def classify(
array-like, shape [n_samples],
Classification results of the given samples.
"""
# this is for old API compatibility, will be removed in the future.
# Please implement predict() instead.
raise NotImplementedError


Expand Down Expand Up @@ -402,4 +412,6 @@ def classify(
array-like, shape [n_samples],
Classification results of the given samples.
"""
# this is for old API compatibility, will be removed in the future.
# Please implement predict() instead.
raise NotImplementedError
32 changes: 26 additions & 6 deletions pypots/classification/brits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from typing import Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -29,6 +30,7 @@
)
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger


class _BRITS(imputation_BRITS, nn.Module):
Expand Down Expand Up @@ -326,23 +328,41 @@ def fit(
# Step 3: save the model if necessary
self._auto_save_model_if_necessary(training_finished=True)

def classify(self, X: Union[dict, str], file_type: str = "h5py"):
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForBRITS(X, return_labels=False, file_type=file_type)
test_set = DatasetForBRITS(test_set, return_labels=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
prediction_collector = []
classification_collector = []

with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
classification_pred = results["classification_pred"]
prediction_collector.append(classification_pred)
classification_collector.append(classification_pred)

classification = torch.cat(classification_collector).cpu().detach().numpy()
result_dict = {
"classification": classification,
}
return result_dict

predictions = torch.cat(prediction_collector)
return predictions.cpu().detach().numpy()
def classify(
self,
X: Union[dict, str],
file_type: str = "h5py",
) -> np.ndarray:
logger.warning(
"🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
)
result_dict = self.predict(X, file_type=file_type)
return result_dict["classification"]
31 changes: 25 additions & 6 deletions pypots/classification/grud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ...imputation.brits.modules import TemporalDecay
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger


class _GRUD(nn.Module):
Expand Down Expand Up @@ -303,23 +304,41 @@ def fit(
# Step 3: save the model if necessary
self._auto_save_model_if_necessary(training_finished=True)

def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForGRUD(X, return_labels=False, file_type=file_type)
test_set = DatasetForGRUD(test_set, return_labels=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
prediction_collector = []
classification_collector = []

with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
prediction = results["classification_pred"]
prediction_collector.append(prediction)
classification_collector.append(prediction)

classification = torch.cat(classification_collector).cpu().detach().numpy()
result_dict = {
"classification": classification,
}
return result_dict

predictions = torch.cat(prediction_collector)
return predictions.cpu().detach().numpy()
def classify(
self,
X: Union[dict, str],
file_type: str = "h5py",
) -> np.ndarray:
logger.warning(
"🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
)
result_dict = self.predict(X, file_type=file_type)
return result_dict["classification"]
31 changes: 25 additions & 6 deletions pypots/classification/raindrop/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,23 +519,42 @@ def fit(
# Step 3: save the model if necessary
self._auto_save_model_if_necessary(training_finished=True)

def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForGRUD(X, return_labels=False, file_type=file_type)
test_set = DatasetForGRUD(test_set, return_labels=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)

prediction_collector = []
classification_collector = []
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
prediction = results["classification_pred"]
prediction_collector.append(prediction)
classification_collector.append(prediction)

classification = torch.cat(classification_collector).cpu().detach().numpy()

result_dict = {
"classification": classification,
}
return result_dict

predictions = torch.cat(prediction_collector)
return predictions.cpu().detach().numpy()
def classify(
self,
X: Union[dict, str],
file_type: str = "h5py",
) -> np.ndarray:
logger.warning(
"🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
)
result_dict = self.predict(X, file_type=file_type)
return result_dict["classification"]
6 changes: 5 additions & 1 deletion pypots/classification/template/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,9 @@ def fit(
) -> None:
raise NotImplementedError

def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
raise NotImplementedError
38 changes: 36 additions & 2 deletions pypots/clustering/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
def fit(
self,
train_set: Union[dict, str],
val_set: Union[dict, str] = None,
file_type: str = "h5py",
) -> None:
"""Train the cluster.
Expand All @@ -78,11 +79,29 @@ def fit(
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include the key 'X'.
val_set :
The dataset for model validating, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
file_type :
The type of the given file if train_set is a path string.
The type of the given file if train_set and val_set are path strings.
"""
raise NotImplementedError

@abstractmethod
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
raise NotImplementedError

@abstractmethod
def cluster(
self,
Expand All @@ -105,6 +124,8 @@ def cluster(
array-like,
Clustering results.
"""
# this is for old API compatibility, will be removed in the future.
# Please implement predict() instead.
raise NotImplementedError


Expand Down Expand Up @@ -332,6 +353,7 @@ def _train_model(
def fit(
self,
train_set: Union[dict, str],
val_set: Union[dict, str] = None,
file_type: str = "h5py",
) -> None:
"""Train the cluster.
Expand All @@ -346,8 +368,18 @@ def fit(
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include the key 'X'.
val_set :
The dataset for model validating, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
file_type :
The type of the given file if train_set is a path string.
The type of the given file if train_set and val_set are path strings.
"""
raise NotImplementedError

Expand All @@ -373,4 +405,6 @@ def cluster(
array-like,
Clustering results.
"""
# this is for old API compatibility, will be removed in the future.
# Please implement predict() instead.
raise NotImplementedError
34 changes: 27 additions & 7 deletions pypots/clustering/crli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,14 +415,14 @@ def fit(
# Step 3: save the model if necessary
self._auto_save_model_if_necessary(training_finished=True)

def cluster(
def predict(
self,
X: Union[dict, str],
test_set: Union[dict, str],
file_type: str = "h5py",
return_latent: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
) -> dict:
self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForCRLI(X, return_labels=False, file_type=file_type)
test_set = DatasetForCRLI(test_set, return_labels=False, file_type=file_type)
test_loader = DataLoader(
test_set,
batch_size=self.batch_size,
Expand All @@ -443,13 +443,33 @@ def cluster(
clustering_latent = (
torch.cat(clustering_latent_collector).cpu().detach().numpy()
)
clustering_results = self.model.kmeans.fit_predict(clustering_latent)
clustering = self.model.kmeans.fit_predict(clustering_latent)
latent_collector = {
"clustering_latent": clustering_latent,
"imputation": imputation,
}

result_dict = {
"clustering": clustering,
}

if return_latent:
result_dict["latent"] = latent_collector

return result_dict

def cluster(
self,
X: Union[dict, str],
file_type: str = "h5py",
return_latent: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
logger.warning(
"🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
)

result_dict = self.predict(X, file_type, return_latent)
if return_latent:
return clustering_results, latent_collector
return result_dict["clustering"], result_dict["latent"]

return clustering_results
return result_dict["clustering"]
6 changes: 5 additions & 1 deletion pypots/clustering/template/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,9 @@ def fit(
) -> None:
raise NotImplementedError

def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
) -> dict:
raise NotImplementedError
Loading

0 comments on commit 04c687f

Please sign in to comment.