From 8dfd1892312471ae8ee407d2e1654c3eec4e53b0 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Tue, 26 Sep 2023 13:47:18 +0800 Subject: [PATCH] feat: invoke cal_internal_cluster_validation_metrics() in testing cases; --- tests/clustering/crli.py | 12 +++++++++--- tests/clustering/vader.py | 13 +++++++++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/clustering/crli.py b/tests/clustering/crli.py index 44c07673..191f58c8 100644 --- a/tests/clustering/crli.py +++ b/tests/clustering/crli.py @@ -14,8 +14,10 @@ from pypots.clustering import CRLI from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_rand_index, cal_cluster_purity -from pypots.utils.visualization import plot_clustering_results +from pypots.utils.metrics import ( + cal_external_cluster_validation_metrics, + cal_internal_cluster_validation_metrics, +) from tests.clustering.config import ( EPOCHS, TRAIN_SET, @@ -75,11 +77,15 @@ def test_1_parameters(self): @pytest.mark.xdist_group(name="clustering-crli") def test_2_cluster(self): - clustering = self.crli.cluster(TEST_SET) + clustering, latent_collector = self.crli.cluster(TEST_SET, return_latent=True) external_metrics = cal_external_cluster_validation_metrics( clustering, DATA["test_y"] ) + internal_metrics = cal_internal_cluster_validation_metrics( + latent_collector["clustering_latent"], DATA["test_y"] + ) logger.info(f"{external_metrics}") + logger.info(f"{internal_metrics}") @pytest.mark.xdist_group(name="clustering-crli") def test_3_saving_path(self): diff --git a/tests/clustering/vader.py b/tests/clustering/vader.py index cd7a8f72..a76b61e8 100644 --- a/tests/clustering/vader.py +++ b/tests/clustering/vader.py @@ -15,7 +15,10 @@ from pypots.clustering import VaDER from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_external_cluster_validation_metrics +from pypots.utils.metrics import ( + cal_external_cluster_validation_metrics, + cal_internal_cluster_validation_metrics, +) from tests.clustering.config import ( EPOCHS, TRAIN_SET, @@ -60,11 +63,17 @@ def test_0_fit(self): @pytest.mark.xdist_group(name="clustering-vader") def test_1_cluster(self): try: - clustering = self.vader.cluster(TEST_SET) + clustering, latent_collector = self.vader.cluster( + TEST_SET, return_latent=True + ) external_metrics = cal_external_cluster_validation_metrics( clustering, DATA["test_y"] ) + internal_metrics = cal_internal_cluster_validation_metrics( + latent_collector["z"], DATA["test_y"] + ) logger.info(f"{external_metrics}") + logger.info(f"{internal_metrics}") except np.linalg.LinAlgError as e: logger.error( f"{e}\n"