diff --git a/pretty_confusion_matrix/pretty_confusion_matrix.py b/pretty_confusion_matrix/pretty_confusion_matrix.py index 75e8b62..1c52f04 100644 --- a/pretty_confusion_matrix/pretty_confusion_matrix.py +++ b/pretty_confusion_matrix/pretty_confusion_matrix.py @@ -240,6 +240,8 @@ def pp_matrix( plt.tight_layout() # set layout slim plt.show() + return fig + def pp_matrix_from_data( y_test,