From 2f40ee33a586fded9017320b5cf02b581b6a9273 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sun, 24 Sep 2023 18:51:30 +0800 Subject: [PATCH 1/3] feat: invoke cal_external_cluster_validation_metrics() in testing cases; --- tests/clustering/crli.py | 8 +++++--- tests/clustering/vader.py | 9 +++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/clustering/crli.py b/tests/clustering/crli.py index 923911fd..44c07673 100644 --- a/tests/clustering/crli.py +++ b/tests/clustering/crli.py @@ -15,6 +15,7 @@ from pypots.optim import Adam from pypots.utils.logging import logger from pypots.utils.metrics import cal_rand_index, cal_cluster_purity +from pypots.utils.visualization import plot_clustering_results from tests.clustering.config import ( EPOCHS, TRAIN_SET, @@ -75,9 +76,10 @@ def test_1_parameters(self): @pytest.mark.xdist_group(name="clustering-crli") def test_2_cluster(self): clustering = self.crli.cluster(TEST_SET) - RI = cal_rand_index(clustering, DATA["test_y"]) - CP = cal_cluster_purity(clustering, DATA["test_y"]) - logger.info(f"RI: {RI}\nCP: {CP}") + external_metrics = cal_external_cluster_validation_metrics( + clustering, DATA["test_y"] + ) + logger.info(f"{external_metrics}") @pytest.mark.xdist_group(name="clustering-crli") def test_3_saving_path(self): diff --git a/tests/clustering/vader.py b/tests/clustering/vader.py index 71a6a91d..cd7a8f72 100644 --- a/tests/clustering/vader.py +++ b/tests/clustering/vader.py @@ -15,7 +15,7 @@ from pypots.clustering import VaDER from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_rand_index, cal_cluster_purity +from pypots.utils.metrics import cal_external_cluster_validation_metrics from tests.clustering.config import ( EPOCHS, TRAIN_SET, @@ -61,9 +61,10 @@ def test_0_fit(self): def test_1_cluster(self): try: clustering = self.vader.cluster(TEST_SET) - RI = cal_rand_index(clustering, DATA["test_y"]) - CP = cal_cluster_purity(clustering, DATA["test_y"]) - logger.info(f"RI: {RI}\nCP: {CP}") + external_metrics = cal_external_cluster_validation_metrics( + clustering, DATA["test_y"] + ) + logger.info(f"{external_metrics}") except np.linalg.LinAlgError as e: logger.error( f"{e}\n" From d5ed0202d7a00bf12102363fa870f679618f32f5 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Tue, 26 Sep 2023 13:46:45 +0800 Subject: [PATCH 2/3] feat: enable two clustering models to return latent for advanced analysis; --- pypots/clustering/crli/model.py | 27 ++++++++++++++------ pypots/clustering/vader/model.py | 42 +++++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index ac64eb47..26f0a769 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -10,7 +10,7 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Union, Optional +from typing import Union, Optional, Tuple import numpy as np import torch @@ -419,7 +419,8 @@ def cluster( self, X: Union[dict, str], file_type: str = "h5py", - ) -> np.ndarray: + return_latent: bool = False, + ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: self.model.eval() # set the model as eval status to freeze it. test_set = DatasetForCRLI(X, return_labels=False, file_type=file_type) test_loader = DataLoader( @@ -428,15 +429,27 @@ def cluster( shuffle=False, num_workers=self.num_workers, ) - latent_collector = [] + clustering_latent_collector = [] + imputation_collector = [] with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) inputs = self.model.forward(inputs, training=False) - latent_collector.append(inputs["fcn_latent"]) + clustering_latent_collector.append(inputs["fcn_latent"]) + imputation_collector.append(inputs["imputation"]) - latent_collector = torch.cat(latent_collector).cpu().detach().numpy() - clustering = self.model.kmeans.fit_predict(latent_collector) + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + clustering_latent = ( + torch.cat(clustering_latent_collector).cpu().detach().numpy() + ) + clustering_results = self.model.kmeans.fit_predict(clustering_latent) + latent_collector = { + "clustering_latent": clustering_latent, + "imputation": imputation, + } + + if return_latent: + return clustering_results, latent_collector - return clustering + return clustering_results diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index 2ed8c035..7c85ad13 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -182,13 +182,17 @@ def forward( mu_tilde, stddev_tilde, ) = self.get_results(X, missing_mask) + imputed_X = X_reconstructed * (1 - missing_mask) + X * missing_mask if not training and not pretrain: results = { "mu_tilde": mu_tilde, + "stddev_tilde": stddev_tilde, "mu": mu_c, "var": var_c, "phi": phi_c, + "z": z, + "imputed_X": imputed_X, } # if only run clustering, then no need to calculate loss return results @@ -260,6 +264,7 @@ def forward( results = { "loss": reconstruction_loss + self.alpha * latent_loss, "z": z, + "imputed_X": imputed_X, } return results @@ -597,7 +602,12 @@ def fit( # Step 3: save the model if necessary self._auto_save_model_if_necessary(training_finished=True) - def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: + def cluster( + self, + X: Union[dict, str], + file_type: str = "h5py", + return_latent: bool = False, + ) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: self.model.eval() # set the model as eval status to freeze it. test_set = DatasetForVaDER(X, return_labels=False, file_type=file_type) test_loader = DataLoader( @@ -606,6 +616,13 @@ def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: shuffle=False, num_workers=self.num_workers, ) + mu_tilde_collector = [] + stddev_tilde_collector = [] + mu_collector = [] + var_collector = [] + phi_collector = [] + z_collector = [] + imputed_X_collector = [] clustering_results_collector = [] with torch.no_grad(): @@ -614,9 +631,19 @@ def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray: results = self.model.forward(inputs, training=False) mu_tilde = results["mu_tilde"].cpu().numpy() + mu_tilde_collector.append(mu_tilde) + stddev_tilde = results["stddev_tilde"].cpu().numpy() + stddev_tilde_collector.append(stddev_tilde) mu = results["mu"].cpu().numpy() + mu_collector.append(mu) var = results["var"].cpu().numpy() + var_collector.append(var) phi = results["phi"].cpu().numpy() + phi_collector.append(phi) + z = results["z"].cpu().numpy() + z_collector.append(z) + imputed_X = results["imputed_X"].cpu().numpy() + imputed_X_collector.append(imputed_X) def func_to_apply( mu_t_: np.ndarray, @@ -640,4 +667,17 @@ def func_to_apply( clustering_results_collector.append(clustering_results) clustering_results = np.concatenate(clustering_results_collector) + latent_collector = { + "mu_tilde": np.concatenate(mu_tilde_collector), + "stddev_tilde": np.concatenate(stddev_tilde_collector), + "mu": np.concatenate(mu_collector), + "var": np.concatenate(var_collector), + "phi": np.concatenate(phi_collector), + "z": np.concatenate(z_collector), + "imputation": np.concatenate(imputed_X_collector), + } + + if return_latent: + return clustering_results, latent_collector + return clustering_results From 8dfd1892312471ae8ee407d2e1654c3eec4e53b0 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Tue, 26 Sep 2023 13:47:18 +0800 Subject: [PATCH 3/3] feat: invoke cal_internal_cluster_validation_metrics() in testing cases; --- tests/clustering/crli.py | 12 +++++++++--- tests/clustering/vader.py | 13 +++++++++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/clustering/crli.py b/tests/clustering/crli.py index 44c07673..191f58c8 100644 --- a/tests/clustering/crli.py +++ b/tests/clustering/crli.py @@ -14,8 +14,10 @@ from pypots.clustering import CRLI from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_rand_index, cal_cluster_purity -from pypots.utils.visualization import plot_clustering_results +from pypots.utils.metrics import ( + cal_external_cluster_validation_metrics, + cal_internal_cluster_validation_metrics, +) from tests.clustering.config import ( EPOCHS, TRAIN_SET, @@ -75,11 +77,15 @@ def test_1_parameters(self): @pytest.mark.xdist_group(name="clustering-crli") def test_2_cluster(self): - clustering = self.crli.cluster(TEST_SET) + clustering, latent_collector = self.crli.cluster(TEST_SET, return_latent=True) external_metrics = cal_external_cluster_validation_metrics( clustering, DATA["test_y"] ) + internal_metrics = cal_internal_cluster_validation_metrics( + latent_collector["clustering_latent"], DATA["test_y"] + ) logger.info(f"{external_metrics}") + logger.info(f"{internal_metrics}") @pytest.mark.xdist_group(name="clustering-crli") def test_3_saving_path(self): diff --git a/tests/clustering/vader.py b/tests/clustering/vader.py index cd7a8f72..a76b61e8 100644 --- a/tests/clustering/vader.py +++ b/tests/clustering/vader.py @@ -15,7 +15,10 @@ from pypots.clustering import VaDER from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_external_cluster_validation_metrics +from pypots.utils.metrics import ( + cal_external_cluster_validation_metrics, + cal_internal_cluster_validation_metrics, +) from tests.clustering.config import ( EPOCHS, TRAIN_SET, @@ -60,11 +63,17 @@ def test_0_fit(self): @pytest.mark.xdist_group(name="clustering-vader") def test_1_cluster(self): try: - clustering = self.vader.cluster(TEST_SET) + clustering, latent_collector = self.vader.cluster( + TEST_SET, return_latent=True + ) external_metrics = cal_external_cluster_validation_metrics( clustering, DATA["test_y"] ) + internal_metrics = cal_internal_cluster_validation_metrics( + latent_collector["z"], DATA["test_y"] + ) logger.info(f"{external_metrics}") + logger.info(f"{internal_metrics}") except np.linalg.LinAlgError as e: logger.error( f"{e}\n"