Skip to content

Commit

Permalink
feat: add cal_nmi();
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Sep 23, 2023
1 parent 8475c96 commit 6cf2caa
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions pypots/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,31 @@ def cal_adjusted_rand_index(
return aRI


def cal_nmi(
class_predictions: np.ndarray,
targets: np.ndarray,
) -> float:
"""Calculate Normalized Mutual Information between two clusterings.
Parameters
----------
class_predictions :
Clustering results returned by a clusterer.
targets :
Ground truth (correct) clustering results.
Returns
-------
NMI : float,
Normalized Mutual Information
"""
NMI = metrics.normalized_mutual_info_score(targets, class_predictions)
return NMI


def cal_cluster_purity(
class_predictions: np.ndarray,
targets: np.ndarray,
Expand Down

0 comments on commit 6cf2caa

Please sign in to comment.