From e49071e38394c039d70cbc083c8b5a2cbb785b1b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 14:53:01 +0200 Subject: [PATCH] Port plot_confusion_matrix to new API. --- .../widgets/_legacy_mpl_widgets/__init__.py | 3 - .../_legacy_mpl_widgets/confusionmatrix.py | 91 ------------------- .../widgets/tests/test_widgets.py | 9 +- src/spikeinterface/widgets/widget_list.py | 3 + 4 files changed, 11 insertions(+), 95 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 045b8acc8e..6013512022 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -15,9 +15,6 @@ # units on probe from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget -# comparison related -from .confusionmatrix import plot_confusion_matrix, ConfusionMatrixWidget - from .multicompgraph import ( plot_multicomp_graph, MultiCompGraphWidget, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py deleted file mode 100644 index 942b613fbf..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class ConfusionMatrixWidget(BaseWidget): - """ - Plots sorting comparison confusion matrix. - - Parameters - ---------- - gt_comparison: GroundTruthComparison - The ground truth sorting comparison object - count_text: bool - If True counts are displayed as text - unit_ticks: bool - If True unit tick labels are displayed - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: ConfusionMatrixWidget - The output widget - """ - - def __init__(self, gt_comparison, count_text=True, unit_ticks=True, figure=None, ax=None): - from matplotlib import pyplot as plt - - BaseWidget.__init__(self, figure, ax) - self._gtcomp = gt_comparison - self._count_text = count_text - self._unit_ticks = unit_ticks - self.name = "ConfusionMatrix" - - def plot(self): - self._do_plot() - - def _do_plot(self): - # a dataframe - confusion_matrix = self._gtcomp.get_confusion_matrix() - - N1 = confusion_matrix.shape[0] - 1 - N2 = confusion_matrix.shape[1] - 1 - - # Using matshow here just because it sets the ticks up nicely. imshow is faster. - self.ax.matshow(confusion_matrix.values, cmap="Greens") - - if self._count_text: - for (i, j), z in np.ndenumerate(confusion_matrix.values): - if z != 0: - if z > np.max(confusion_matrix.values) / 2.0: - self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="white") - else: - self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="black") - - self.ax.axhline(int(N1 - 1) + 0.5, color="black") - self.ax.axvline(int(N2 - 1) + 0.5, color="black") - - # Major ticks - self.ax.set_xticks(np.arange(0, N2 + 1)) - self.ax.set_yticks(np.arange(0, N1 + 1)) - self.ax.xaxis.tick_bottom() - - # Labels for major ticks - if self._unit_ticks: - self.ax.set_yticklabels(confusion_matrix.index, fontsize=12) - self.ax.set_xticklabels(confusion_matrix.columns, fontsize=12) - else: - self.ax.set_xticklabels(np.append([""] * N2, "FN"), fontsize=10) - self.ax.set_yticklabels(np.append([""] * N1, "FP"), fontsize=10) - - self.ax.set_xlabel(self._gtcomp.name_list[1], fontsize=20) - self.ax.set_ylabel(self._gtcomp.name_list[0], fontsize=20) - - self.ax.set_xlim(-0.5, N2 + 0.5) - self.ax.set_ylim( - N1 + 0.5, - -0.5, - ) - - -def plot_confusion_matrix(*args, **kwargs): - W = ConfusionMatrixWidget(*args, **kwargs) - W.plot() - return W - - -plot_confusion_matrix.__doc__ = ConfusionMatrixWidget.__doc__ diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 2f11e5ee3c..0aa309f748 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -330,6 +330,12 @@ def test_plot_agreement_matrix(self): if backend not in self.skip_backends: sw.plot_agreement_matrix(self.gt_comp) + def test_plot_confusion_matrix(self): + possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_confusion_matrix(self.gt_comp) + if __name__ == "__main__": @@ -352,7 +358,8 @@ def test_plot_agreement_matrix(self): # mytest.test_quality_metrics() # mytest.test_template_metrics() # mytest.test_amplitudes() - mytest.test_plot_agreement_matrix() + # mytest.test_plot_agreement_matrix() + mytest.test_plot_confusion_matrix() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 22b33e38aa..d02aa7de7a 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -6,6 +6,7 @@ from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget +from .confusion_matrix import ConfusionMatrixWidget from .crosscorrelograms import CrossCorrelogramsWidget from .motion import MotionWidget from .quality_metrics import QualityMetricsWidget @@ -28,6 +29,7 @@ AllAmplitudesDistributionsWidget, AmplitudesWidget, AutoCorrelogramsWidget, + ConfusionMatrixWidget, CrossCorrelogramsWidget, MotionWidget, QualityMetricsWidget, @@ -82,6 +84,7 @@ plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget plot_amplitudes = AmplitudesWidget plot_autocorrelograms = AutoCorrelogramsWidget +plot_confusion_matrix = ConfusionMatrixWidget plot_crosscorrelograms = CrossCorrelogramsWidget plot_motion = MotionWidget plot_quality_metrics = QualityMetricsWidget