From 1984734f21f11ac2d2e56128d035e888bb4160f7 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sun, 24 Sep 2023 10:39:17 +0800 Subject: [PATCH] feat: add cal_external_cluster_validation_metrics(); --- pypots/utils/metrics.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/pypots/utils/metrics.py b/pypots/utils/metrics.py index a1972f21..ac239648 100644 --- a/pypots/utils/metrics.py +++ b/pypots/utils/metrics.py @@ -599,6 +599,37 @@ def cal_cluster_purity( return cluster_purity +def cal_external_cluster_validation_metrics(class_predictions, targets): + """Computer all external cluster validation metrics available in PyPOTS and return as a dictionary. + + Parameters + ---------- + class_predictions : + Clustering results returned by a clusterer. + + targets : + Ground truth (correct) clustering results. + + Returns + ------- + external_cluster_validation_metrics : dict + A dictionary contains all external cluster validation metrics available in PyPOTS. + """ + + ri = cal_rand_index(class_predictions, targets) + ari = cal_adjusted_rand_index(class_predictions, targets) + nmi = cal_nmi(class_predictions, targets) + cp = cal_cluster_purity(class_predictions, targets) + + external_cluster_validation_metrics = { + "rand_index": ri, + "adjusted_rand_index": ari, + "nmi": nmi, + "cluster_purity": cp, + } + return external_cluster_validation_metrics + + def cal_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float: """Compute the mean Silhouette Coefficient of all samples.