From 6cf2caaceb1b13035108119dd38a35d2579cf199 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 23 Sep 2023 19:50:38 +0800 Subject: [PATCH] feat: add cal_nmi(); --- pypots/utils/metrics.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/pypots/utils/metrics.py b/pypots/utils/metrics.py index cc349b50..a1972f21 100644 --- a/pypots/utils/metrics.py +++ b/pypots/utils/metrics.py @@ -543,6 +543,31 @@ def cal_adjusted_rand_index( return aRI +def cal_nmi( + class_predictions: np.ndarray, + targets: np.ndarray, +) -> float: + """Calculate Normalized Mutual Information between two clusterings. + + Parameters + ---------- + class_predictions : + Clustering results returned by a clusterer. + + targets : + Ground truth (correct) clustering results. + + Returns + ------- + NMI : float, + Normalized Mutual Information + + + """ + NMI = metrics.normalized_mutual_info_score(targets, class_predictions) + return NMI + + def cal_cluster_purity( class_predictions: np.ndarray, targets: np.ndarray,