From 6b9fbc2ab8c824f8c83fa34430ca056a904bdebc Mon Sep 17 00:00:00 2001 From: Otto von Sperling Date: Tue, 23 Jun 2020 08:38:37 -0300 Subject: [PATCH] (Refactored) Guarantee that labels are aligned Minor addition of keyword argument 'labels' to calling scikit-learn's confusion_matrix will guarantee that classes are correctly aligned with the axes' labels. It fixes the issue of using columns generated in line 216 as indexers for the y_test and predictions sets. Instead, it creates the confusion matrix with labels from the input sets and later reorders the matrix if columns were passed as arguments. --- confusion_matrix_pretty_print.py | 35 ++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/confusion_matrix_pretty_print.py b/confusion_matrix_pretty_print.py index eb3158b..d6beb3a 100644 --- a/confusion_matrix_pretty_print.py +++ b/confusion_matrix_pretty_print.py @@ -212,20 +212,29 @@ def plot_confusion_matrix_from_data(y_test, predictions, columns=None, annot=Tru from sklearn.metrics import confusion_matrix from pandas import DataFrame - #data - if(not columns): - #labels axis integer: - ##columns = range(1, len(np.unique(y_test))+1) - #labels axis string: - from string import ascii_uppercase - columns = ['class %s' %(i) for i in list(ascii_uppercase)[0:len(np.unique(y_test))]] - - confm = confusion_matrix(y_test, predictions) - cmap = 'Oranges'; - fz = 11; - figsize=[9,9]; + # data + if not columns: + generated_columns = [ + "class {0}".format(str(x).upper()) for x in np.unique(y_test) + ] + + confm = confusion_matrix(y_test, predictions, labels=np.unique(y_test)) + cmap = "Oranges" + fz = 11 + figsize = [9, 9] show_null_values = 2 - df_cm = DataFrame(confm, index=columns, columns=columns) + + df_cm = pd.DataFrame( + confm, + index=(np.unique(y_test) if not columns is None else generated_columns), + columns=(np.unique(y_test) if not columns is None else generated_columns), + ) + + # reorders confusion matrix to respect columns passed as arguments + if columns: + df_cm = df_cm.reindex(index=columns, columns=columns) + df_cm = df_cm.fillna(0.0) + pretty_plot_confusion_matrix(df_cm, fz=fz, cmap=cmap, figsize=figsize, show_null_values=show_null_values, pred_val_axis=pred_val_axis) #