diff --git a/pypots/utils/metrics.py b/pypots/utils/metrics.py index cc349b50..a1972f21 100644 --- a/pypots/utils/metrics.py +++ b/pypots/utils/metrics.py @@ -543,6 +543,31 @@ def cal_adjusted_rand_index( return aRI +def cal_nmi( + class_predictions: np.ndarray, + targets: np.ndarray, +) -> float: + """Calculate Normalized Mutual Information between two clusterings. + + Parameters + ---------- + class_predictions : + Clustering results returned by a clusterer. + + targets : + Ground truth (correct) clustering results. + + Returns + ------- + NMI : float, + Normalized Mutual Information + + + """ + NMI = metrics.normalized_mutual_info_score(targets, class_predictions) + return NMI + + def cal_cluster_purity( class_predictions: np.ndarray, targets: np.ndarray,