Merge pull request #193 from WenjieDu/dev
Add internal and external cluster validation funcs, and enable CRLI and VaDER to return latent for advanced analysis
WenjieDu authored Sep 26, 2023
2 parents f7cf793 + 28b9fdc commit 09b494d
Showing 17 changed files with 162 additions and 33 deletions.
2 changes: 1 addition & 1 deletion pypots/base.py
@@ -383,7 +383,7 @@ def __init__(
         self,
         batch_size: int,
         epochs: int,
-        patience: int,
+        patience: Optional[int] = None,
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,
         saving_path: str = None,

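Note: the same signature change, `patience: int` to `patience: Optional[int] = None`, is applied across all base classes and models in this commit. A minimal sketch of the idiom (a hypothetical class, not PyPOTS code; reading `None` as "early stopping disabled" is an assumption based on the new default):

    from typing import Optional

    class TrainerConfig:
        # Hypothetical illustration of the new signature.
        def __init__(self, epochs: int = 100, patience: Optional[int] = None):
            # patience is now optional: None presumably disables early stopping,
            # while an int stops training after that many non-improving epochs.
            self.epochs = epochs
            self.patience = patience
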
2 changes: 1 addition & 1 deletion pypots/classification/base.py
@@ -178,7 +178,7 @@ def __init__(
         n_classes: int,
         batch_size: int,
         epochs: int,
-        patience: int,
+        patience: Optional[int] = None,
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,
         saving_path: str = None,

2 changes: 1 addition & 1 deletion pypots/classification/brits/model.py
@@ -193,7 +193,7 @@ def __init__(
         reconstruction_weight: float = 1,
         batch_size: int = 32,
         epochs: int = 100,
-        patience: int = None,
+        patience: Optional[int] = None,
         optimizer: Optional[Optimizer] = Adam(),
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,

2 changes: 1 addition & 1 deletion pypots/classification/grud/model.py
@@ -186,7 +186,7 @@ def __init__(
         rnn_hidden_size: int,
         batch_size: int = 32,
         epochs: int = 100,
-        patience: int = None,
+        patience: Optional[int] = None,
         optimizer: Optional[Optimizer] = Adam(),
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,

2 changes: 1 addition & 1 deletion pypots/classification/raindrop/model.py
@@ -393,7 +393,7 @@ def __init__(
         static=False,
         batch_size=32,
         epochs=100,
-        patience: int = None,
+        patience: Optional[int] = None,
         optimizer: Optional[Optimizer] = Adam(),
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,

2 changes: 1 addition & 1 deletion pypots/clustering/base.py
@@ -166,7 +166,7 @@ def __init__(
         n_clusters: int,
         batch_size: int,
         epochs: int,
-        patience: int,
+        patience: Optional[int] = None,
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,
         saving_path: str = None,

31 changes: 23 additions & 8 deletions pypots/clustering/crli/model.py
@@ -10,7 +10,7 @@
 # Created by Wenjie Du <[email protected]>
 # License: GLP-v3

-from typing import Union, Optional
+from typing import Union, Optional, Tuple

 import numpy as np
 import torch

@@ -50,7 +50,9 @@ def __init__(
             n_steps, rnn_hidden_size * 2, n_features, decoder_fcn_output_dims, device
         )  # fully connected network is included in Decoder
         self.kmeans = KMeans(
-            n_clusters=n_clusters
+            n_clusters=n_clusters,
+            n_init=10,  # FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the
+            # value of `n_init` explicitly to suppress the warning.
         )  # TODO: implement KMean with torch for gpu acceleration

         self.n_clusters = n_clusters

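Pinning `n_init=10` freezes the scikit-learn default before it changes to `'auto'` in scikit-learn 1.4, silencing the FutureWarning quoted in the comment. A standalone sketch of the same call (illustrative cluster count):

    from sklearn.cluster import KMeans

    # An explicit n_init keeps behavior stable across scikit-learn versions
    # and suppresses the FutureWarning about the default changing to "auto".
    kmeans = KMeans(n_clusters=8, n_init=10)
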
@@ -417,7 +419,8 @@ def cluster(
         self,
         X: Union[dict, str],
         file_type: str = "h5py",
-    ) -> np.ndarray:
+        return_latent: bool = False,
+    ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
         self.model.eval()  # set the model as eval status to freeze it.
         test_set = DatasetForCRLI(X, return_labels=False, file_type=file_type)
         test_loader = DataLoader(

@@ -426,15 +429,27 @@ def cluster(
             shuffle=False,
             num_workers=self.num_workers,
         )
-        latent_collector = []
+        clustering_latent_collector = []
+        imputation_collector = []

         with torch.no_grad():
             for idx, data in enumerate(test_loader):
                 inputs = self._assemble_input_for_testing(data)
                 inputs = self.model.forward(inputs, training=False)
-                latent_collector.append(inputs["fcn_latent"])
+                clustering_latent_collector.append(inputs["fcn_latent"])
+                imputation_collector.append(inputs["imputation"])

-        latent_collector = torch.cat(latent_collector).cpu().detach().numpy()
-        clustering = self.model.kmeans.fit_predict(latent_collector)
+        imputation = torch.cat(imputation_collector).cpu().detach().numpy()
+        clustering_latent = (
+            torch.cat(clustering_latent_collector).cpu().detach().numpy()
+        )
+        clustering_results = self.model.kmeans.fit_predict(clustering_latent)
+        latent_collector = {
+            "clustering_latent": clustering_latent,
+            "imputation": imputation,
+        }
+
+        if return_latent:
+            return clustering_results, latent_collector

-        return clustering
+        return clustering_results

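With the new `return_latent` flag, `cluster()` stays backward compatible while optionally exposing the latent fed to k-means and the imputation. A usage sketch based on the diff above, assuming a fitted CRLI model `crli` and a PyPOTS-style test set `test_set` (hypothetical variable names):

    # Default behavior: only the cluster assignments, as before.
    clustering = crli.cluster(test_set)

    # With return_latent=True: assignments plus a dict of intermediate results.
    clustering, latent = crli.cluster(test_set, return_latent=True)
    fcn_latent = latent["clustering_latent"]  # the input to KMeans.fit_predict
    imputed_series = latent["imputation"]     # the imputed time series
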
46 changes: 43 additions & 3 deletions pypots/clustering/vader/model.py
@@ -182,13 +182,17 @@ def forward(
             mu_tilde,
             stddev_tilde,
         ) = self.get_results(X, missing_mask)
+        imputed_X = X_reconstructed * (1 - missing_mask) + X * missing_mask

         if not training and not pretrain:
             results = {
                 "mu_tilde": mu_tilde,
+                "stddev_tilde": stddev_tilde,
                 "mu": mu_c,
                 "var": var_c,
                 "phi": phi_c,
+                "z": z,
+                "imputed_X": imputed_X,
             }
             # if only run clustering, then no need to calculate loss
             return results

@@ -260,6 +264,7 @@ def forward(
         results = {
             "loss": reconstruction_loss + self.alpha * latent_loss,
             "z": z,
+            "imputed_X": imputed_X,
         }

         return results

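The added `imputed_X` line is the usual masked composition: observed entries are kept where `missing_mask` is 1, and reconstructions fill the gaps where it is 0. A tiny numeric sketch of the same formula:

    import numpy as np

    X = np.array([1.0, 0.0, 3.0])              # 0.0 fills a missing slot
    missing_mask = np.array([1.0, 0.0, 1.0])   # 1 = observed, 0 = missing
    X_reconstructed = np.array([0.9, 2.1, 2.8])

    imputed_X = X_reconstructed * (1 - missing_mask) + X * missing_mask
    # -> [1.0, 2.1, 3.0]: observed values kept, the missing one imputed
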
@@ -343,7 +348,7 @@ def __init__(
         batch_size: int = 32,
         epochs: int = 100,
         pretrain_epochs: int = 10,
-        patience: int = None,
+        patience: Optional[int] = None,
         optimizer: Optional[Optimizer] = Adam(),
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,

@@ -462,7 +467,7 @@ def _train_model(
                     "Now quit to let you check your model training.\n"
                     "Please raise an issue https://github.com/WenjieDu/PyPOTS/issues if you have questions."
                 )
-                exit()
+                raise RuntimeError
             else:
                 reg_covar *= 2

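Raising `RuntimeError` instead of calling `exit()` lets callers handle the GMM-initialization failure rather than having the interpreter terminated. A sketch of the caller side (hypothetical names: a VaDER instance `vader` and a dataset `train_set`):

    import logging

    logger = logging.getLogger(__name__)

    try:
        vader.fit(train_set)
    except RuntimeError:
        # React instead of dying: log, retry with other hyperparameters, etc.
        logger.error("VaDER training failed during GMM initialization.")
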
@@ -597,7 +602,12 @@
         # Step 3: save the model if necessary
         self._auto_save_model_if_necessary(training_finished=True)

-    def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
+    def cluster(
+        self,
+        X: Union[dict, str],
+        file_type: str = "h5py",
+        return_latent: bool = False,
+    ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
         self.model.eval()  # set the model as eval status to freeze it.
         test_set = DatasetForVaDER(X, return_labels=False, file_type=file_type)
         test_loader = DataLoader(
@@ -606,6 +616,13 @@ def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
             shuffle=False,
             num_workers=self.num_workers,
         )
+        mu_tilde_collector = []
+        stddev_tilde_collector = []
+        mu_collector = []
+        var_collector = []
+        phi_collector = []
+        z_collector = []
+        imputed_X_collector = []
         clustering_results_collector = []

         with torch.no_grad():
@@ -614,9 +631,19 @@ def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
                 results = self.model.forward(inputs, training=False)

                 mu_tilde = results["mu_tilde"].cpu().numpy()
+                mu_tilde_collector.append(mu_tilde)
+                stddev_tilde = results["stddev_tilde"].cpu().numpy()
+                stddev_tilde_collector.append(stddev_tilde)
                 mu = results["mu"].cpu().numpy()
+                mu_collector.append(mu)
                 var = results["var"].cpu().numpy()
+                var_collector.append(var)
                 phi = results["phi"].cpu().numpy()
+                phi_collector.append(phi)
+                z = results["z"].cpu().numpy()
+                z_collector.append(z)
+                imputed_X = results["imputed_X"].cpu().numpy()
+                imputed_X_collector.append(imputed_X)

                 def func_to_apply(
                     mu_t_: np.ndarray,
@@ -640,4 +667,17 @@ def func_to_apply(
                 clustering_results_collector.append(clustering_results)

         clustering_results = np.concatenate(clustering_results_collector)
+        latent_collector = {
+            "mu_tilde": np.concatenate(mu_tilde_collector),
+            "stddev_tilde": np.concatenate(stddev_tilde_collector),
+            "mu": np.concatenate(mu_collector),
+            "var": np.concatenate(var_collector),
+            "phi": np.concatenate(phi_collector),
+            "z": np.concatenate(z_collector),
+            "imputation": np.concatenate(imputed_X_collector),
+        }
+
+        if return_latent:
+            return clustering_results, latent_collector
+
         return clustering_results

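VaDER's `cluster()` mirrors the CRLI change: with `return_latent=True` it returns the assignments plus the per-sample quantities collected above. A usage sketch, assuming a fitted model `vader` and a test set `test_set` (hypothetical variable names):

    clustering, latent = vader.cluster(test_set, return_latent=True)
    z = latent["z"]                 # latent representations
    mu_tilde = latent["mu_tilde"]   # encoder means
    imputed = latent["imputation"]  # imputed time series
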
2 changes: 1 addition & 1 deletion pypots/forecasting/base.py
@@ -165,7 +165,7 @@ def __init__(
         # n_forecasting_steps: int,
         batch_size: int,
         epochs: int,
-        patience: int,
+        patience: Optional[int] = None,
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,
         saving_path: str = None,

2 changes: 1 addition & 1 deletion pypots/imputation/base.py
@@ -173,7 +173,7 @@ def __init__(
         self,
         batch_size: int,
         epochs: int,
-        patience: int,
+        patience: Optional[int] = None,
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,
         saving_path: str = None,

2 changes: 1 addition & 1 deletion pypots/imputation/brits/model.py
@@ -411,7 +411,7 @@ def __init__(
         rnn_hidden_size: int,
         batch_size: int = 32,
         epochs: int = 100,
-        patience: int = None,
+        patience: Optional[int] = None,
         optimizer: Optional[Optimizer] = Adam(),
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,

2 changes: 1 addition & 1 deletion pypots/imputation/gpvae/model.py
@@ -297,7 +297,7 @@ def __init__(
         window_size: int = 3,
         batch_size: int = 32,
         epochs: int = 100,
-        patience: int = None,
+        patience: Optional[int] = None,
         optimizer: Optional[Optimizer] = Adam(),
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,

2 changes: 1 addition & 1 deletion pypots/imputation/mrnn/model.py
@@ -169,7 +169,7 @@ def __init__(
         rnn_hidden_size: int,
         batch_size: int = 32,
         epochs: int = 100,
-        patience: int = None,
+        patience: Optional[int] = None,
         optimizer: Optional[Optimizer] = Adam(),
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,

2 changes: 1 addition & 1 deletion pypots/imputation/transformer/model.py
@@ -227,7 +227,7 @@ def __init__(
         MIT_weight: int = 1,
         batch_size: int = 32,
         epochs: int = 100,
-        patience: int = None,
+        patience: Optional[int] = None,
         optimizer: Optional[Optimizer] = Adam(),
         num_workers: int = 0,
         device: Optional[Union[str, torch.device, list]] = None,

56 changes: 56 additions & 0 deletions pypots/utils/metrics.py
@@ -543,6 +543,31 @@ def cal_adjusted_rand_index(
     return aRI


+def cal_nmi(
+    class_predictions: np.ndarray,
+    targets: np.ndarray,
+) -> float:
+    """Calculate Normalized Mutual Information between two clusterings.
+
+    Parameters
+    ----------
+    class_predictions :
+        Clustering results returned by a clusterer.
+
+    targets :
+        Ground truth (correct) clustering results.
+
+    Returns
+    -------
+    NMI : float,
+        Normalized Mutual Information
+
+    """
+    NMI = metrics.normalized_mutual_info_score(targets, class_predictions)
+    return NMI
+
+
 def cal_cluster_purity(
     class_predictions: np.ndarray,
     targets: np.ndarray,
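A quick sanity check of the new `cal_nmi` wrapper; since it delegates to scikit-learn's normalized_mutual_info_score, label permutations do not matter:

    import numpy as np
    from pypots.utils.metrics import cal_nmi

    predictions = np.array([0, 0, 1, 1])
    targets = np.array([1, 1, 0, 0])  # the same partition, labels swapped
    print(cal_nmi(predictions, targets))  # 1.0: NMI is permutation-invariant
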
@@ -574,6 +599,37 @@ def cal_cluster_purity(
     return cluster_purity


+def cal_external_cluster_validation_metrics(class_predictions, targets):
+    """Compute all external cluster validation metrics available in PyPOTS and return as a dictionary.
+
+    Parameters
+    ----------
+    class_predictions :
+        Clustering results returned by a clusterer.
+
+    targets :
+        Ground truth (correct) clustering results.
+
+    Returns
+    -------
+    external_cluster_validation_metrics : dict
+        A dictionary containing all external cluster validation metrics available in PyPOTS.
+
+    """
+
+    ri = cal_rand_index(class_predictions, targets)
+    ari = cal_adjusted_rand_index(class_predictions, targets)
+    nmi = cal_nmi(class_predictions, targets)
+    cp = cal_cluster_purity(class_predictions, targets)
+
+    external_cluster_validation_metrics = {
+        "rand_index": ri,
+        "adjusted_rand_index": ari,
+        "nmi": nmi,
+        "cluster_purity": cp,
+    }
+    return external_cluster_validation_metrics
+
+
 def cal_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float:
     """Compute the mean Silhouette Coefficient of all samples.
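Usage sketch of the new aggregate helper (illustrative arrays; the result keys follow the diff above):

    import numpy as np
    from pypots.utils.metrics import cal_external_cluster_validation_metrics

    predictions = np.array([0, 0, 1, 1])
    targets = np.array([0, 1, 1, 1])
    metrics_dict = cal_external_cluster_validation_metrics(predictions, targets)
    print(metrics_dict)  # rand_index, adjusted_rand_index, nmi, cluster_purity
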
18 changes: 13 additions & 5 deletions tests/clustering/crli.py
@@ -14,7 +14,10 @@
 from pypots.clustering import CRLI
 from pypots.optim import Adam
 from pypots.utils.logging import logger
-from pypots.utils.metrics import cal_rand_index, cal_cluster_purity
+from pypots.utils.metrics import (
+    cal_external_cluster_validation_metrics,
+    cal_internal_cluster_validation_metrics,
+)
 from tests.clustering.config import (
     EPOCHS,
     TRAIN_SET,
@@ -74,10 +77,15 @@ def test_1_parameters(self):

     @pytest.mark.xdist_group(name="clustering-crli")
     def test_2_cluster(self):
-        clustering = self.crli.cluster(TEST_SET)
-        RI = cal_rand_index(clustering, DATA["test_y"])
-        CP = cal_cluster_purity(clustering, DATA["test_y"])
-        logger.info(f"RI: {RI}\nCP: {CP}")
+        clustering, latent_collector = self.crli.cluster(TEST_SET, return_latent=True)
+        external_metrics = cal_external_cluster_validation_metrics(
+            clustering, DATA["test_y"]
+        )
+        internal_metrics = cal_internal_cluster_validation_metrics(
+            latent_collector["clustering_latent"], DATA["test_y"]
+        )
+        logger.info(f"{external_metrics}")
+        logger.info(f"{internal_metrics}")

     @pytest.mark.xdist_group(name="clustering-crli")
     def test_3_saving_path(self):