diff --git a/pretty_confusion_matrix/pretty_confusion_matrix.py b/pretty_confusion_matrix/pretty_confusion_matrix.py index 75e8b62..14ef372 100644 --- a/pretty_confusion_matrix/pretty_confusion_matrix.py +++ b/pretty_confusion_matrix/pretty_confusion_matrix.py @@ -152,6 +152,7 @@ def pp_matrix( figsize=[8, 8], show_null_values=0, pred_val_axis="y", + categories=[] ): """ print conf matrix with default layout (like matlab) @@ -163,7 +164,8 @@ def pp_matrix( lw linewidth pred_val_axis where to show the prediction values (x or y axis) 'col' or 'x': show predicted values in columns (x axis) instead lines - 'lin' or 'y': show predicted values in lines (y axis) + 'lin' or 'y': show predicted values in lines (y axis + categories: parameter to specify string values of class names instead of just 0, 1, 2, etc. """ if pred_val_axis in ("col", "x"): xlbl = "Predicted" @@ -189,6 +191,8 @@ def pp_matrix( cmap=cmap, linecolor="w", fmt=fmt, + xticklabels=categories, + yticklabels=categories ) # set ticklabels rotation