diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..92ec9b3 Binary files /dev/null and b/.DS_Store differ diff --git a/pretty_confusion_matrix/pretty_confusion_matrix.py b/pretty_confusion_matrix/pretty_confusion_matrix.py index 5993172..a99ac74 100644 --- a/pretty_confusion_matrix/pretty_confusion_matrix.py +++ b/pretty_confusion_matrix/pretty_confusion_matrix.py @@ -152,6 +152,7 @@ def pp_matrix( figsize=[8, 8], show_null_values=0, pred_val_axis="y", + title='Confusion Matrix', ): """ print conf matrix with default layout (like matlab) @@ -234,7 +235,7 @@ def pp_matrix( ax.text(item["x"], item["y"], item["text"], **item["kw"]) # titles and legends - ax.set_title("Confusion matrix") + ax.set_title(title) ax.set_xlabel(xlbl) ax.set_ylabel(ylbl) plt.tight_layout() # set layout slim @@ -254,6 +255,7 @@ def pp_matrix_from_data( figsize=[8, 8], show_null_values=0, pred_val_axis="lin", + title="Confusion Matrix", ): """ plot confusion matrix function with y_test (actual values) and predictions (predic), @@ -268,7 +270,7 @@ def pp_matrix_from_data( columns = [ "class %s" % (i) - for i in list(ascii_uppercase)[0 : len(np.unique(y_test))] + for i in list(ascii_uppercase)[0: len(np.unique(y_test))] ] confm = confusion_matrix(y_test, predictions) @@ -280,4 +282,5 @@ def pp_matrix_from_data( figsize=figsize, show_null_values=show_null_values, pred_val_axis=pred_val_axis, + title=title, )