Skip to content

Commit

Permalink
feat: invoke cal_internal_cluster_validation_metrics() in testing cases;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Sep 26, 2023
1 parent d5ed020 commit 8dfd189
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
12 changes: 9 additions & 3 deletions tests/clustering/crli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from pypots.clustering import CRLI
from pypots.optim import Adam
from pypots.utils.logging import logger
from pypots.utils.metrics import cal_rand_index, cal_cluster_purity
from pypots.utils.visualization import plot_clustering_results
from pypots.utils.metrics import (
cal_external_cluster_validation_metrics,
cal_internal_cluster_validation_metrics,
)
from tests.clustering.config import (
EPOCHS,
TRAIN_SET,
Expand Down Expand Up @@ -75,11 +77,15 @@ def test_1_parameters(self):

@pytest.mark.xdist_group(name="clustering-crli")
def test_2_cluster(self):
clustering = self.crli.cluster(TEST_SET)
clustering, latent_collector = self.crli.cluster(TEST_SET, return_latent=True)
external_metrics = cal_external_cluster_validation_metrics(
clustering, DATA["test_y"]
)
internal_metrics = cal_internal_cluster_validation_metrics(
latent_collector["clustering_latent"], DATA["test_y"]
)
logger.info(f"{external_metrics}")
logger.info(f"{internal_metrics}")

@pytest.mark.xdist_group(name="clustering-crli")
def test_3_saving_path(self):
Expand Down
13 changes: 11 additions & 2 deletions tests/clustering/vader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from pypots.clustering import VaDER
from pypots.optim import Adam
from pypots.utils.logging import logger
from pypots.utils.metrics import cal_external_cluster_validation_metrics
from pypots.utils.metrics import (
cal_external_cluster_validation_metrics,
cal_internal_cluster_validation_metrics,
)
from tests.clustering.config import (
EPOCHS,
TRAIN_SET,
Expand Down Expand Up @@ -60,11 +63,17 @@ def test_0_fit(self):
@pytest.mark.xdist_group(name="clustering-vader")
def test_1_cluster(self):
try:
clustering = self.vader.cluster(TEST_SET)
clustering, latent_collector = self.vader.cluster(
TEST_SET, return_latent=True
)
external_metrics = cal_external_cluster_validation_metrics(
clustering, DATA["test_y"]
)
internal_metrics = cal_internal_cluster_validation_metrics(
latent_collector["z"], DATA["test_y"]
)
logger.info(f"{external_metrics}")
logger.info(f"{internal_metrics}")
except np.linalg.LinAlgError as e:
logger.error(
f"{e}\n"
Expand Down

0 comments on commit 8dfd189

Please sign in to comment.