diff --git a/confusion_matrix_pretty_print.py b/confusion_matrix_pretty_print.py index eb3158b..d6beb3a 100644 --- a/confusion_matrix_pretty_print.py +++ b/confusion_matrix_pretty_print.py @@ -212,20 +212,29 @@ def plot_confusion_matrix_from_data(y_test, predictions, columns=None, annot=Tru from sklearn.metrics import confusion_matrix from pandas import DataFrame - #data - if(not columns): - #labels axis integer: - ##columns = range(1, len(np.unique(y_test))+1) - #labels axis string: - from string import ascii_uppercase - columns = ['class %s' %(i) for i in list(ascii_uppercase)[0:len(np.unique(y_test))]] - - confm = confusion_matrix(y_test, predictions) - cmap = 'Oranges'; - fz = 11; - figsize=[9,9]; + # data + if not columns: + generated_columns = [ + "class {0}".format(str(x).upper()) for x in np.unique(y_test) + ] + + confm = confusion_matrix(y_test, predictions, labels=np.unique(y_test)) + cmap = "Oranges" + fz = 11 + figsize = [9, 9] show_null_values = 2 - df_cm = DataFrame(confm, index=columns, columns=columns) + + df_cm = pd.DataFrame( + confm, + index=(np.unique(y_test) if not columns is None else generated_columns), + columns=(np.unique(y_test) if not columns is None else generated_columns), + ) + + # reorders confusion matrix to respect columns passed as arguments + if columns: + df_cm = df_cm.reindex(index=columns, columns=columns) + df_cm = df_cm.fillna(0.0) + pretty_plot_confusion_matrix(df_cm, fz=fz, cmap=cmap, figsize=figsize, show_null_values=show_null_values, pred_val_axis=pred_val_axis) #