Merge pull request #239 from WenjieDu/dev
Make SAITS return attention weights in predict()
WenjieDu authored Nov 21, 2023
2 parents c962530 + 545b5ee commit 6e29488
Showing 23 changed files with 509 additions and 91 deletions.
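
The headline change: SAITS now exposes its attention weights through predict(). The SAITS diff itself is not among the files shown in this excerpt, so the sketch below is an assumption: it presumes SAITS follows the same return_latent_vars pattern that the CRLI and VaDER changes below introduce, and the attention-weight key name is illustrative.

```python
# A minimal sketch, assuming SAITS follows the return_latent_vars pattern
# introduced for CRLI and VaDER below; the attention-weight key name is
# illustrative, since the SAITS diff is not shown in this excerpt.
import numpy as np
from pypots.imputation import SAITS

X = np.random.randn(100, 48, 37)  # [n_samples, n_steps, n_features]
X[X < -1.5] = np.nan              # partially-observed time series

saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=64,
              d_inner=64, n_heads=4, d_k=16, d_v=16, epochs=5)
saits.fit(train_set={"X": X})

results = saits.predict(test_set={"X": X}, return_latent_vars=True)
imputed = results["imputation"]   # imputed series, as before
latents = results["latent_vars"]  # attention weights are expected here
```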
12 changes: 6 additions & 6 deletions pypots/base.py
@@ -330,17 +330,17 @@ def fit(
Parameters
----------
train_set : dict or str
- The dataset for model training, should be a dictionary including keys as 'X' and 'y',
- or a path string locating a data file.
+ The dataset for model training should be a dictionary including the key 'X',
+ or a path string locating a data file supported by PyPOTS (e.g. an h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for training and can contain missing values; y should be array-like of shape
[n_samples], holding the classification labels of X.
If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
key-value pairs like a dict and has to include the key 'X' (plus 'y' for classification tasks).
val_set : dict or str
- The dataset for model validating, should be a dictionary including keys as 'X' and 'y',
- or a path string locating a data file.
+ The dataset for model validation should be a dictionary including the key 'X',
+ or a path string locating a data file supported by PyPOTS (e.g. an h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for validation and can contain missing values; y should be array-like of shape
[n_samples], holding the classification labels of X.
@@ -364,8 +364,8 @@ def predict(
Parameters
----------
test_set : dict or str
- The dataset for model validating, should be a dictionary including keys as 'X' and 'y',
- or a path string locating a data file.
+ The dataset for model testing should be a dictionary including the key 'X',
+ or a path string locating a data file supported by PyPOTS (e.g. an h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for testing and can contain missing values; y should be array-like of shape
[n_samples], holding the classification labels of X.
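
The dataset convention these docstrings describe, as a minimal sketch; the h5 layout is an assumption inferred from the "key-value pairs like a dict" wording:

```python
# Both dataset forms accepted by fit()/predict(), per the docstrings above.
# The h5 layout is an assumption inferred from "key-value pairs like a dict".
import h5py
import numpy as np

X = np.random.randn(200, 24, 10)       # [n_samples, n_steps, n_features]
X[X > 2.0] = np.nan                    # missing values are allowed
y = np.random.randint(0, 2, size=200)  # labels, needed only for classification

train_set = {"X": X, "y": y}           # dict form, passed directly to fit()

with h5py.File("train_set.h5", "w") as hf:  # path-string form
    hf.create_dataset("X", data=X)
    hf.create_dataset("y", data=y)
# model.fit("train_set.h5") then loads the same key-value pairs from disk.
```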
4 changes: 4 additions & 0 deletions pypots/classification/base.py
@@ -424,6 +424,10 @@ def classify(
) -> np.ndarray:
"""Classify the input data with the trained model.
+ Warnings
+ --------
+ The method classify is deprecated. Please use `predict()` instead.
Parameters
----------
X :
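
Migrating off the deprecated method is a one-line change. A sketch, assuming predict() returns its labels under a "classification" key; that key is implied by this deprecation but not shown in this excerpt:

```python
# Migration sketch for the deprecated classify(); the "classification"
# result key is an assumption implied by this deprecation.
import numpy as np
from pypots.classification import BRITS

X_train = np.random.randn(200, 48, 37)
X_train[X_train < -1.5] = np.nan
y_train = np.random.randint(0, 2, size=200)

clf = BRITS(n_steps=48, n_features=37, rnn_hidden_size=64, n_classes=2, epochs=5)
clf.fit(train_set={"X": X_train, "y": y_train})

y_hat_old = clf.classify({"X": X_train})                   # deprecated, logs a warning
y_hat_new = clf.predict({"X": X_train})["classification"]  # preferred replacement
```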
20 changes: 20 additions & 0 deletions pypots/classification/brits/model.py
@@ -275,6 +275,26 @@ def classify(
X: Union[dict, str],
file_type: str = "h5py",
) -> np.ndarray:
+ """Classify the input data with the trained model.
+ Warnings
+ --------
+ The method classify is deprecated. Please use `predict()` instead.
+ Parameters
+ ----------
+ X :
+     The data samples for testing should be array-like of shape [n_samples, sequence length (time steps),
+     n_features], or a path string locating a data file, e.g. an h5 file.
+ file_type :
+     The type of the given file if X is a path string.
+ Returns
+ -------
+ array-like, shape [n_samples],
+     Classification results of the given samples.
+ """
+ logger.warning(
+     "🚨DeprecationWarning: The method classify is deprecated. Please use `predict` instead."
+ )
20 changes: 20 additions & 0 deletions pypots/classification/grud/model.py
@@ -246,6 +246,26 @@ def classify(
X: Union[dict, str],
file_type: str = "h5py",
) -> np.ndarray:
+ """Classify the input data with the trained model.
+ Warnings
+ --------
+ The method classify is deprecated. Please use `predict()` instead.
+ Parameters
+ ----------
+ X :
+     The data samples for testing should be array-like of shape [n_samples, sequence length (time steps),
+     n_features], or a path string locating a data file, e.g. an h5 file.
+ file_type :
+     The type of the given file if X is a path string.
+ Returns
+ -------
+ array-like, shape [n_samples],
+     Classification results of the given samples.
+ """
+ logger.warning(
+     "🚨DeprecationWarning: The method classify is deprecated. Please use `predict` instead."
+ )
20 changes: 20 additions & 0 deletions pypots/classification/raindrop/model.py
@@ -291,6 +291,26 @@ def classify(
X: Union[dict, str],
file_type: str = "h5py",
) -> np.ndarray:
+ """Classify the input data with the trained model.
+ Warnings
+ --------
+ The method classify is deprecated. Please use `predict()` instead.
+ Parameters
+ ----------
+ X :
+     The data samples for testing should be array-like of shape [n_samples, sequence length (time steps),
+     n_features], or a path string locating a data file, e.g. an h5 file.
+ file_type :
+     The type of the given file if X is a path string.
+ Returns
+ -------
+ array-like, shape [n_samples],
+     Classification results of the given samples.
+ """
+ logger.warning(
+     "🚨DeprecationWarning: The method classify is deprecated. Please use `predict` instead."
+ )
4 changes: 4 additions & 0 deletions pypots/clustering/base.py
@@ -417,6 +417,10 @@ def cluster(
) -> np.ndarray:
"""Cluster the input with the trained model.
+ Warnings
+ --------
+ The method cluster is deprecated. Please use `predict()` instead.
Parameters
----------
X :
77 changes: 61 additions & 16 deletions pypots/clustering/crli/model.py
@@ -11,7 +11,7 @@
# License: BSD-3-Clause

import os
- from typing import Union, Optional, Tuple
+ from typing import Union, Optional

import numpy as np
import torch
@@ -381,8 +381,35 @@ def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
- return_latent: bool = False,
+ return_latent_vars: bool = False,
) -> dict:
+ """Make predictions for the input data with the trained model.
+ Parameters
+ ----------
+ test_set : dict or str
+     The dataset for model testing should be a dictionary including the key 'X',
+     or a path string locating a data file supported by PyPOTS (e.g. an h5 file).
+     If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+     which is time-series data for testing and can contain missing values.
+     If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
+     key-value pairs like a dict and has to include the key 'X'.
+ file_type : str
+     The type of the given file if test_set is a path string.
+ return_latent_vars : bool
+     Whether to return the latent variables in CRLI, e.g. the latent representation from the
+     fully connected network in CRLI, etc.
+ Returns
+ -------
+ result_dict : dict
+     The dictionary containing the clustering results, and the latent variables if requested.
+ """

self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForCRLI(test_set, return_labels=False, file_type=file_type)
test_loader = DataLoader(
@@ -399,39 +426,57 @@ def predict(
inputs = self._assemble_input_for_testing(data)
inputs = self.model.forward(inputs, return_loss=False)
clustering_latent_collector.append(inputs["fcn_latent"])
- imputation_collector.append(inputs["imputation_latent"])
+ if return_latent_vars:
+     imputation_collector.append(inputs["imputation_latent"])

- imputation = torch.cat(imputation_collector).cpu().detach().numpy()
clustering_latent = (
torch.cat(clustering_latent_collector).cpu().detach().numpy()
)
clustering = self.model.kmeans.fit_predict(clustering_latent)
- latent_collector = {
-     "clustering_latent": clustering_latent,
-     "imputation_latent": imputation,
- }

result_dict = {
"clustering": clustering,
}

- if return_latent:
-     result_dict["latent"] = latent_collector
+ if return_latent_vars:
+     imputation = torch.cat(imputation_collector).cpu().detach().numpy()
+     latent_var_collector = {
+         "clustering_latent": clustering_latent,
+         "imputation_latent": imputation,
+     }
+     result_dict["latent_vars"] = latent_var_collector

return result_dict

def cluster(
self,
X: Union[dict, str],
file_type: str = "h5py",
- return_latent: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
+ ) -> np.ndarray:
"""Cluster the input with the trained model.
+ Warnings
+ --------
+ The method cluster is deprecated. Please use `predict()` instead.
Parameters
----------
X :
The data samples for testing should be array-like of shape [n_samples, sequence length (time steps),
n_features], or a path string locating a data file, e.g. an h5 file.
file_type :
The type of the given file if X is a path string.
Returns
-------
array-like,
Clustering results.
"""
+ logger.warning(
+     "🚨DeprecationWarning: The method cluster is deprecated. Please use `predict` instead."
+ )

- result_dict = self.predict(X, file_type, return_latent)
- if return_latent:
-     return result_dict["clustering"], result_dict["latent"]

+ result_dict = self.predict(X, file_type)
return result_dict["clustering"]
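
A usage sketch of the reworked CRLI predict(); the result keys come straight from the diff above, while the constructor arguments are illustrative:

```python
# Usage sketch for CRLI.predict() after this change; the "clustering",
# "latent_vars", "clustering_latent", and "imputation_latent" keys come
# from the diff above, the constructor arguments are illustrative.
from pypots.clustering import CRLI

crli = CRLI(n_steps=24, n_features=10, n_clusters=3,
            n_generator_layers=2, rnn_hidden_size=64, epochs=5)
crli.fit(train_set={"X": X})  # X as in the earlier dataset sketch

results = crli.predict({"X": X}, return_latent_vars=True)
labels = results["clustering"]                            # always returned
fcn_latent = results["latent_vars"]["clustering_latent"]  # only if requested
recon = results["latent_vars"]["imputation_latent"]
```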
95 changes: 69 additions & 26 deletions pypots/clustering/vader/model.py
@@ -12,7 +12,7 @@

import os
- from typing import Tuple, Union, Optional
+ from typing import Union, Optional

import numpy as np
import torch
@@ -394,8 +394,33 @@ def predict(
self,
test_set: Union[dict, str],
file_type: str = "h5py",
- return_latent: bool = False,
+ return_latent_vars: bool = False,
) -> dict:
+ """Make predictions for the input data with the trained model.
+ Parameters
+ ----------
+ test_set : dict or str
+     The dataset for model testing should be a dictionary including the key 'X',
+     or a path string locating a data file supported by PyPOTS (e.g. an h5 file).
+     If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+     which is time-series data for testing and can contain missing values.
+     If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
+     key-value pairs like a dict and has to include the key 'X'.
+ file_type : str
+     The type of the given file if test_set is a path string.
+ return_latent_vars : bool
+     Whether to return the latent variables in VaDER, e.g. mu and phi, etc.
+ Returns
+ -------
+ result_dict : dict
+     The dictionary containing the clustering results, and the latent variables if requested.
+ """
self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForVaDER(test_set, return_labels=False, file_type=file_type)
test_loader = DataLoader(
@@ -420,18 +445,12 @@

mu_tilde = results["mu_tilde"].cpu().numpy()
mu_tilde_collector.append(mu_tilde)
- stddev_tilde = results["stddev_tilde"].cpu().numpy()
- stddev_tilde_collector.append(stddev_tilde)
mu = results["mu"].cpu().numpy()
mu_collector.append(mu)
var = results["var"].cpu().numpy()
var_collector.append(var)
phi = results["phi"].cpu().numpy()
phi_collector.append(phi)
- z = results["z"].cpu().numpy()
- z_collector.append(z)
- imputation_latent = results["imputation_latent"].cpu().numpy()
- imputation_latent_collector.append(imputation_latent)

def func_to_apply(
mu_t_: np.ndarray,
@@ -454,38 +473,62 @@ def func_to_apply(
clustering_results = np.argmax(p, axis=0)
clustering_results_collector.append(clustering_results)

- clustering = np.concatenate(clustering_results_collector)
- latent_collector = {
-     "mu_tilde": np.concatenate(mu_tilde_collector),
-     "stddev_tilde": np.concatenate(stddev_tilde_collector),
-     "mu": np.concatenate(mu_collector),
-     "var": np.concatenate(var_collector),
-     "phi": np.concatenate(phi_collector),
-     "z": np.concatenate(z_collector),
-     "imputation_latent": np.concatenate(imputation_latent_collector),
- }
+ if return_latent_vars:
+     stddev_tilde = results["stddev_tilde"].cpu().numpy()
+     stddev_tilde_collector.append(stddev_tilde)
+     z = results["z"].cpu().numpy()
+     z_collector.append(z)
+     imputation_latent = results["imputation_latent"].cpu().numpy()
+     imputation_latent_collector.append(imputation_latent)

+ clustering = np.concatenate(clustering_results_collector)
result_dict = {
"clustering": clustering,
}

- if return_latent:
-     result_dict["latent"] = latent_collector
+ if return_latent_vars:
+     latent_var_collector = {
+         "mu_tilde": np.concatenate(mu_tilde_collector),
+         "stddev_tilde": np.concatenate(stddev_tilde_collector),
+         "mu": np.concatenate(mu_collector),
+         "var": np.concatenate(var_collector),
+         "phi": np.concatenate(phi_collector),
+         "z": np.concatenate(z_collector),
+         "imputation_latent": np.concatenate(imputation_latent_collector),
+     }
+     result_dict["latent_vars"] = latent_var_collector

return result_dict

def cluster(
self,
X: Union[dict, str],
file_type: str = "h5py",
- return_latent: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
+ ) -> np.ndarray:
"""Cluster the input with the trained model.
+ Warnings
+ --------
+ The method cluster is deprecated. Please use `predict()` instead.
Parameters
----------
X :
The data samples for testing should be array-like of shape [n_samples, sequence length (time steps),
n_features], or a path string locating a data file, e.g. an h5 file.
file_type :
The type of the given file if X is a path string.
Returns
-------
array-like,
Clustering results.
"""
+ logger.warning(
+     "🚨DeprecationWarning: The method cluster is deprecated. Please use `predict` instead."
+ )

- result_dict = self.predict(X, file_type, return_latent)
- if return_latent:
-     return result_dict["clustering"], result_dict["latent"]

+ result_dict = self.predict(X, file_type)
return result_dict["clustering"]
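
And the same pattern for VaDER; the mu and phi keys come from the diff above, while the constructor arguments are illustrative:

```python
# Usage sketch for VaDER.predict() after this change; the "clustering",
# "latent_vars", "mu", and "phi" keys come from the diff above.
from pypots.clustering import VaDER

vader = VaDER(n_steps=24, n_features=10, n_clusters=3,
              rnn_hidden_size=64, d_mu_stddev=2, epochs=5)
vader.fit(train_set={"X": X})  # X as in the earlier dataset sketch

results = vader.predict({"X": X}, return_latent_vars=True)
labels = results["clustering"]
mu = results["latent_vars"]["mu"]    # GMM component means
phi = results["latent_vars"]["phi"]  # GMM mixture weights
```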
(Diff truncated: the remaining files changed in this commit are not shown in this excerpt.)