From 8c4d527573521be93d77706451c22824ad8d6e83 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Oct 2023 09:45:12 +0200 Subject: [PATCH 01/15] Improve GroundTruthStudy --- src/spikeinterface/comparison/groundtruthstudy.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index e5f4ce8b31..917067fba6 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -288,16 +288,19 @@ def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): recording, gt_sorting = self.datasets[dataset_key] we = extract_waveforms(recording, gt_sorting, folder=wf_folder, **extract_kwargs) - def get_waveform_extractor(self, key): + def get_waveform_extractor(self, case_key=None, dataset_key=None): # some recording are not dumpable to json and the waveforms extactor need it! # so we load it with and put after # this should be fixed in PR 2027 so remove this after - dataset_key = self.cases[key]["dataset"] + if case_key is not None: + dataset_key = self.cases[case_key]["dataset"] + wf_folder = self.folder / "waveforms" / self.key_to_str(dataset_key) - we = load_waveforms(wf_folder, with_recording=False) - recording, _ = self.datasets[dataset_key] - we.set_recording(recording) + we = load_waveforms(wf_folder, with_recording=True) + # we = load_waveforms(wf_folder, with_recording=False) + # recording, _ = self.datasets[dataset_key] + # we.set_recording(recording) return we def get_templates(self, key, mode="average"): From 1b8470f484e353f1ff442dcda1f2422f303bfcd5 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Oct 2023 12:10:34 +0200 Subject: [PATCH 02/15] group comparison widget --- .../widgets/agreement_matrix.py | 87 ---- src/spikeinterface/widgets/comparison.py | 459 ++++++++++++++++++ .../widgets/confusion_matrix.py | 79 --- src/spikeinterface/widgets/widget_list.py | 4 +- 4 files changed, 461 insertions(+), 168 deletions(-) delete mode 100644 src/spikeinterface/widgets/agreement_matrix.py create mode 100644 src/spikeinterface/widgets/comparison.py delete mode 100644 src/spikeinterface/widgets/confusion_matrix.py diff --git a/src/spikeinterface/widgets/agreement_matrix.py b/src/spikeinterface/widgets/agreement_matrix.py deleted file mode 100644 index ec6ea1c87c..0000000000 --- a/src/spikeinterface/widgets/agreement_matrix.py +++ /dev/null @@ -1,87 +0,0 @@ -import numpy as np -from warnings import warn - -from .base import BaseWidget, to_attr -from .utils import get_unit_colors - - -class AgreementMatrixWidget(BaseWidget): - """ - Plots sorting comparison agreement matrix. - - Parameters - ---------- - sorting_comparison: GroundTruthComparison or SymmetricSortingComparison - The sorting comparison object. - Symetric or not. - ordered: bool - Order units with best agreement scores. - This enable to see agreement on a diagonal. - count_text: bool - If True counts are displayed as text - unit_ticks: bool - If True unit tick labels are displayed - - """ - - def __init__( - self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, backend=None, **backend_kwargs - ): - plot_data = dict( - sorting_comparison=sorting_comparison, - ordered=ordered, - count_text=count_text, - unit_ticks=unit_ticks, - ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - - def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - - dp = to_attr(data_plot) - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - comp = dp.sorting_comparison - - if dp.ordered: - scores = comp.get_ordered_agreement_scores() - else: - scores = comp.agreement_scores - - N1 = scores.shape[0] - N2 = scores.shape[1] - - unit_ids1 = scores.index.values - unit_ids2 = scores.columns.values - - # Using matshow here just because it sets the ticks up nicely. imshow is faster. - self.ax.matshow(scores.values, cmap="Greens") - - if dp.count_text: - for i, u1 in enumerate(unit_ids1): - u2 = comp.best_match_12[u1] - if u2 != -1: - j = np.where(unit_ids2 == u2)[0][0] - - self.ax.text(j, i, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white") - - # Major ticks - self.ax.set_xticks(np.arange(0, N2)) - self.ax.set_yticks(np.arange(0, N1)) - self.ax.xaxis.tick_bottom() - - # Labels for major ticks - if dp.unit_ticks: - self.ax.set_yticklabels(scores.index, fontsize=12) - self.ax.set_xticklabels(scores.columns, fontsize=12) - - self.ax.set_xlabel(comp.name_list[1], fontsize=20) - self.ax.set_ylabel(comp.name_list[0], fontsize=20) - - self.ax.set_xlim(-0.5, N2 - 0.5) - self.ax.set_ylim( - N1 - 0.5, - -0.5, - ) diff --git a/src/spikeinterface/widgets/comparison.py b/src/spikeinterface/widgets/comparison.py new file mode 100644 index 0000000000..db3de104d2 --- /dev/null +++ b/src/spikeinterface/widgets/comparison.py @@ -0,0 +1,459 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors + + +class MultiCompGraphWidget(BaseWidget): + """ + Plots multi comparison graph. + + Parameters + ---------- + multi_comparison: BaseMultiComparison + The multi comparison object + draw_labels: bool + If True unit labels are shown + node_cmap: matplotlib colormap + The colormap to be used for the nodes (default 'viridis') + edge_cmap: matplotlib colormap + The colormap to be used for the edges (default 'hot') + alpha_edges: float + Alpha value for edges + colorbar: bool + If True a colorbar for the edges is plotted + """ + + def __init__( + self, + multi_comparison, + draw_labels=False, + node_cmap="viridis", + edge_cmap="hot", + alpha_edges=0.5, + colorbar=False, + backend=None, + **backend_kwargs, + ): + plot_data = dict( + multi_comparison=multi_comparison, + draw_labels=draw_labels, + node_cmap=node_cmap, + edge_cmap=edge_cmap, + alpha_edges=alpha_edges, + colorbar=colorbar, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.colors as mpl_colors + import matplotlib.pyplot as plt + import networkx as nx + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + mcmp = dp.multi_comparison + g = mcmp.graph + edge_col = [] + for e in g.edges(data=True): + n1, n2, d = e + edge_col.append(d["weight"]) + nodes_col_dict = {} + for i, sort_name in enumerate(mcmp.name_list): + nodes_col_dict[sort_name] = i + nodes_col = [] + for node in sorted(g.nodes): + nodes_col.append(nodes_col_dict[node[0]]) + nodes_col = np.array(nodes_col) / len(mcmp.name_list) + + _ = plt.set_cmap(dp.node_cmap) + _ = nx.draw_networkx_nodes( + g, + pos=nx.circular_layout(sorted(g)), + nodelist=sorted(g.nodes), + node_color=nodes_col, + node_size=20, + ax=self.ax, + ) + _ = nx.draw_networkx_edges( + g, + pos=nx.circular_layout((sorted(g))), + nodelist=sorted(g.nodes), + edge_color=edge_col, + alpha=dp.alpha_edges, + edge_cmap=plt.cm.get_cmap(dp.edge_cmap), + edge_vmin=mcmp.match_score, + edge_vmax=1, + ax=self.ax, + ) + if dp.draw_labels: + labels = {key: f"{key[0]}_{key[1]}" for key in sorted(g.nodes)} + pos = nx.circular_layout(sorted(g)) + # extend position radially + pos_extended = {} + for node, pos in pos.items(): + pos_new = pos + 0.1 * pos + pos_extended[node] = pos_new + _ = nx.draw_networkx_labels(g, pos=pos_extended, labels=labels, ax=self.ax) + + if dp.colorbar: + import matplotlib.pyplot as plt + + norm = mpl_colors.Normalize(vmin=mcmp.match_score, vmax=1) + cmap = plt.cm.get_cmap(dp.edge_cmap) + m = plt.cm.ScalarMappable(norm=norm, cmap=cmap) + self.figure.colorbar(m) + + self.ax.axis("off") + + +class MultiCompGlobalAgreementWidget(BaseWidget): + """ + Plots multi comparison agreement as pie or bar plot. + + Parameters + ---------- + multi_comparison: BaseMultiComparison + The multi comparison object + plot_type: str + 'pie' or 'bar' + cmap: matplotlib colormap, default: 'YlOrRd' + The colormap to be used for the nodes + fontsize: int, default: 9 + The text fontsize + show_legend: bool, default: True + If True a legend is shown + """ + + def __init__( + self, + multi_comparison, + plot_type="pie", + cmap="YlOrRd", + fontsize=9, + show_legend=True, + backend=None, + **backend_kwargs, + ): + plot_data = dict( + multi_comparison=multi_comparison, + plot_type=plot_type, + cmap=cmap, + fontsize=fontsize, + show_legend=show_legend, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + mcmp = dp.multi_comparison + cmap = plt.get_cmap(dp.cmap) + colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) + sg_names, sg_units = mcmp.compute_subgraphs() + # fraction of units with agreement > threshold + v, c = np.unique([len(np.unique(s)) for s in sg_names], return_counts=True) + if dp.plot_type == "pie": + p = self.ax.pie(c, colors=colors[v - 1], autopct=lambda pct: _getabs(pct, c), pctdistance=1.25) + self.ax.legend( + p[0], + v, + frameon=False, + title="k=", + handlelength=1, + handletextpad=0.5, + bbox_to_anchor=(1.0, 1.0), + loc=2, + borderaxespad=0.5, + labelspacing=0.15, + fontsize=dp.fontsize, + ) + elif dp.plot_type == "bar": + self.ax.bar(v, c, color=colors[v - 1]) + x_labels = [f"k={vi}" for vi in v] + self.ax.spines["top"].set_visible(False) + self.ax.spines["right"].set_visible(False) + self.ax.set_xticks(v) + self.ax.set_xticklabels(x_labels) + else: + raise AttributeError("Wrong plot_type. It can be 'pie' or 'bar'") + self.ax.set_title("Units agreed upon\nby k sorters") + + +class MultiCompAgreementBySorterWidget(BaseWidget): + """ + Plots multi comparison agreement as pie or bar plot. + + Parameters + ---------- + multi_comparison: BaseMultiComparison + The multi comparison object + plot_type: str + 'pie' or 'bar' + cmap: matplotlib colormap + The colormap to be used for the nodes (default 'Reds') + axes: list of matplotlib axes + The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax + and figure parameters are ignored. + show_legend: bool + Show the legend in the last axes (default True). + + Returns + ------- + W: MultiCompGraphWidget + The output widget + """ + + def __init__( + self, + multi_comparison, + plot_type="pie", + cmap="YlOrRd", + fontsize=9, + show_legend=True, + backend=None, + **backend_kwargs, + ): + plot_data = dict( + multi_comparison=multi_comparison, + plot_type=plot_type, + cmap=cmap, + fontsize=fontsize, + show_legend=show_legend, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.colors as mpl_colors + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + mcmp = dp.multi_comparison + name_list = mcmp.name_list + + backend_kwargs["num_axes"] = len(name_list) + backend_kwargs["ncols"] = len(name_list) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + cmap = plt.get_cmap(dp.cmap) + colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) + sg_names, sg_units = mcmp.compute_subgraphs() + # fraction of units with agreement > threshold + for i, name in enumerate(name_list): + ax = np.squeeze(self.axes)[i] + v, c = np.unique([len(np.unique(sn)) for sn in sg_names if name in sn], return_counts=True) + if dp.plot_type == "pie": + p = ax.pie( + c, + colors=colors[v - 1], + textprops={"color": "k", "fontsize": dp.fontsize}, + autopct=lambda pct: _getabs(pct, c), + pctdistance=1.18, + ) + if (dp.show_legend) and (i == len(name_list) - 1): + plt.legend( + p[0], + v, + frameon=False, + title="k=", + handlelength=1, + handletextpad=0.5, + bbox_to_anchor=(1.15, 1.25), + loc=2, + borderaxespad=0.0, + labelspacing=0.15, + ) + elif dp.plot_type == "bar": + ax.bar(v, c, color=colors[v - 1]) + x_labels = [f"k={vi}" for vi in v] + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.set_xticks(v) + ax.set_xticklabels(x_labels) + else: + raise AttributeError("Wrong plot_type. It can be 'pie' or 'bar'") + ax.set_title(name) + + if dp.plot_type == "bar": + ylims = [np.max(ax_single.get_ylim()) for ax_single in self.axes] + max_yval = np.max(ylims) + for ax_single in self.axes: + ax_single.set_ylim([0, max_yval]) + + +def _getabs(pct, allvals): + absolute = int(np.round(pct / 100.0 * np.sum(allvals))) + return f"{absolute}" +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors + + +class ConfusionMatrixWidget(BaseWidget): + """ + Plots sorting comparison confusion matrix. + + Parameters + ---------- + gt_comparison: GroundTruthComparison + The ground truth sorting comparison object + count_text: bool + If True counts are displayed as text + unit_ticks: bool + If True unit tick labels are displayed + + """ + + def __init__(self, gt_comparison, count_text=True, unit_ticks=True, backend=None, **backend_kwargs): + plot_data = dict( + gt_comparison=gt_comparison, + count_text=count_text, + unit_ticks=unit_ticks, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + comp = dp.gt_comparison + + confusion_matrix = comp.get_confusion_matrix() + N1 = confusion_matrix.shape[0] - 1 + N2 = confusion_matrix.shape[1] - 1 + + # Using matshow here just because it sets the ticks up nicely. imshow is faster. + self.ax.matshow(confusion_matrix.values, cmap="Greens") + + if dp.count_text: + for (i, j), z in np.ndenumerate(confusion_matrix.values): + if z != 0: + if z > np.max(confusion_matrix.values) / 2.0: + self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="white") + else: + self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="black") + + self.ax.axhline(int(N1 - 1) + 0.5, color="black") + self.ax.axvline(int(N2 - 1) + 0.5, color="black") + + # Major ticks + self.ax.set_xticks(np.arange(0, N2 + 1)) + self.ax.set_yticks(np.arange(0, N1 + 1)) + self.ax.xaxis.tick_bottom() + + # Labels for major ticks + if dp.unit_ticks: + self.ax.set_yticklabels(confusion_matrix.index, fontsize=12) + self.ax.set_xticklabels(confusion_matrix.columns, fontsize=12) + else: + self.ax.set_xticklabels(np.append([""] * N2, "FN"), fontsize=10) + self.ax.set_yticklabels(np.append([""] * N1, "FP"), fontsize=10) + + self.ax.set_xlabel(comp.name_list[1], fontsize=20) + self.ax.set_ylabel(comp.name_list[0], fontsize=20) + + self.ax.set_xlim(-0.5, N2 + 0.5) + self.ax.set_ylim( + N1 + 0.5, + -0.5, + ) + + + + +class AgreementMatrixWidget(BaseWidget): + """ + Plots sorting comparison agreement matrix. + + Parameters + ---------- + sorting_comparison: GroundTruthComparison or SymmetricSortingComparison + The sorting comparison object. + Symetric or not. + ordered: bool + Order units with best agreement scores. + This enable to see agreement on a diagonal. + count_text: bool + If True counts are displayed as text + unit_ticks: bool + If True unit tick labels are displayed + + """ + + def __init__( + self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, backend=None, **backend_kwargs + ): + plot_data = dict( + sorting_comparison=sorting_comparison, + ordered=ordered, + count_text=count_text, + unit_ticks=unit_ticks, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + comp = dp.sorting_comparison + + if dp.ordered: + scores = comp.get_ordered_agreement_scores() + else: + scores = comp.agreement_scores + + N1 = scores.shape[0] + N2 = scores.shape[1] + + unit_ids1 = scores.index.values + unit_ids2 = scores.columns.values + + # Using matshow here just because it sets the ticks up nicely. imshow is faster. + self.ax.matshow(scores.values, cmap="Greens") + + if dp.count_text: + for i, u1 in enumerate(unit_ids1): + u2 = comp.best_match_12[u1] + if u2 != -1: + j = np.where(unit_ids2 == u2)[0][0] + + self.ax.text(j, i, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white") + + # Major ticks + self.ax.set_xticks(np.arange(0, N2)) + self.ax.set_yticks(np.arange(0, N1)) + self.ax.xaxis.tick_bottom() + + # Labels for major ticks + if dp.unit_ticks: + self.ax.set_yticklabels(scores.index, fontsize=12) + self.ax.set_xticklabels(scores.columns, fontsize=12) + + self.ax.set_xlabel(comp.name_list[1], fontsize=20) + self.ax.set_ylabel(comp.name_list[0], fontsize=20) + + self.ax.set_xlim(-0.5, N2 - 0.5) + self.ax.set_ylim( + N1 - 0.5, + -0.5, + ) diff --git a/src/spikeinterface/widgets/confusion_matrix.py b/src/spikeinterface/widgets/confusion_matrix.py deleted file mode 100644 index 8eb58f30b2..0000000000 --- a/src/spikeinterface/widgets/confusion_matrix.py +++ /dev/null @@ -1,79 +0,0 @@ -import numpy as np -from warnings import warn - -from .base import BaseWidget, to_attr -from .utils import get_unit_colors - - -class ConfusionMatrixWidget(BaseWidget): - """ - Plots sorting comparison confusion matrix. - - Parameters - ---------- - gt_comparison: GroundTruthComparison - The ground truth sorting comparison object - count_text: bool - If True counts are displayed as text - unit_ticks: bool - If True unit tick labels are displayed - - """ - - def __init__(self, gt_comparison, count_text=True, unit_ticks=True, backend=None, **backend_kwargs): - plot_data = dict( - gt_comparison=gt_comparison, - count_text=count_text, - unit_ticks=unit_ticks, - ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - - def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - - dp = to_attr(data_plot) - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - comp = dp.gt_comparison - - confusion_matrix = comp.get_confusion_matrix() - N1 = confusion_matrix.shape[0] - 1 - N2 = confusion_matrix.shape[1] - 1 - - # Using matshow here just because it sets the ticks up nicely. imshow is faster. - self.ax.matshow(confusion_matrix.values, cmap="Greens") - - if dp.count_text: - for (i, j), z in np.ndenumerate(confusion_matrix.values): - if z != 0: - if z > np.max(confusion_matrix.values) / 2.0: - self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="white") - else: - self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="black") - - self.ax.axhline(int(N1 - 1) + 0.5, color="black") - self.ax.axvline(int(N2 - 1) + 0.5, color="black") - - # Major ticks - self.ax.set_xticks(np.arange(0, N2 + 1)) - self.ax.set_yticks(np.arange(0, N1 + 1)) - self.ax.xaxis.tick_bottom() - - # Labels for major ticks - if dp.unit_ticks: - self.ax.set_yticklabels(confusion_matrix.index, fontsize=12) - self.ax.set_xticklabels(confusion_matrix.columns, fontsize=12) - else: - self.ax.set_xticklabels(np.append([""] * N2, "FN"), fontsize=10) - self.ax.set_yticklabels(np.append([""] * N1, "FP"), fontsize=10) - - self.ax.set_xlabel(comp.name_list[1], fontsize=20) - self.ax.set_ylabel(comp.name_list[0], fontsize=20) - - self.ax.set_xlim(-0.5, N2 + 0.5) - self.ax.set_ylim( - N1 + 0.5, - -0.5, - ) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 00d179127d..f18f4cb461 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -2,11 +2,10 @@ from .base import backend_kwargs_desc -from .agreement_matrix import AgreementMatrixWidget + from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget -from .confusion_matrix import ConfusionMatrixWidget from .crosscorrelograms import CrossCorrelogramsWidget from .isi_distribution import ISIDistributionWidget from .motion import MotionWidget @@ -29,6 +28,7 @@ from .unit_templates import UnitTemplatesWidget from .unit_waveforms_density_map import UnitWaveformDensityMapWidget from .unit_waveforms import UnitWaveformsWidget +from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyPerformancesVsMetrics from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget From 22062722769c4577bae63080fd7f9721bc9993a4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Oct 2023 13:06:24 +0200 Subject: [PATCH 03/15] Improve, refactor and simplify some study widgets. --- .../comparison/groundtruthstudy.py | 9 +- src/spikeinterface/widgets/gtstudy.py | 103 +++++++----------- src/spikeinterface/widgets/widget_list.py | 4 +- 3 files changed, 43 insertions(+), 73 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 917067fba6..8cab2afc1b 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -289,18 +289,11 @@ def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): we = extract_waveforms(recording, gt_sorting, folder=wf_folder, **extract_kwargs) def get_waveform_extractor(self, case_key=None, dataset_key=None): - # some recording are not dumpable to json and the waveforms extactor need it! - # so we load it with and put after - # this should be fixed in PR 2027 so remove this after - if case_key is not None: dataset_key = self.cases[case_key]["dataset"] wf_folder = self.folder / "waveforms" / self.key_to_str(dataset_key) we = load_waveforms(wf_folder, with_recording=True) - # we = load_waveforms(wf_folder, with_recording=False) - # recording, _ = self.datasets[dataset_key] - # we.set_recording(recording) return we def get_templates(self, key, mode="average"): @@ -369,7 +362,7 @@ def get_performance_by_unit(self, case_keys=None): perf_by_unit.append(perf) perf_by_unit = pd.concat(perf_by_unit) - perf_by_unit = perf_by_unit.set_index(self.levels) + perf_by_unit = perf_by_unit.set_index(self.levels).sort_index() return perf_by_unit def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 6a27b78dec..cd4bdb6cb3 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -150,6 +150,7 @@ def __init__( self, study, mode="swarm", + performance_names=("accuracy", "precision", "recall"), case_keys=None, backend=None, **backend_kwargs, @@ -161,6 +162,7 @@ def __init__( study=study, perfs=study.get_performance_by_unit(case_keys=case_keys), mode=mode, + performance_names=performance_names, case_keys=case_keys, ) @@ -176,78 +178,55 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) perfs = dp.perfs + study = dp.study + if dp.mode in ("ordered", "snr"): + backend_kwargs["num_axes"] = len(dp.performance_names) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - if dp.mode == "swarm": + if dp.mode == "ordered": + for count, performance_name in enumerate(dp.performance_names): + ax = self.axes.flatten()[count] + for key in dp.case_keys: + label = study.cases[key]["label"] + + val = perfs.xs(key).loc[:, performance_name].values + val = np.sort(val)[::-1] + ax.plot(val, label=label) + ax.set_title(performance_name) + if count == 0: + ax.legend() + + elif dp.mode == "snr": + + metric_name = dp.mode + for count, performance_name in enumerate(dp.performance_names): + ax = self.axes.flatten()[count] + + max_metric = 0 + for key in dp.case_keys: + x = study.get_metrics(key).loc[:, metric_name].values + y = perfs.xs(key).loc[:, performance_name].values + label = study.cases[key]["label"] + ax.scatter(x, y, label=label) + max_metric = max(max_metric, np.max(x)) + ax.set_title(performance_name) + ax.set_xlim(0, max_metric * 1.05) + ax.set_ylim(0, 1.05) + if count == 0: + ax.legend() + + + elif dp.mode == "swarm": levels = perfs.index.names df = pd.melt( perfs.reset_index(), id_vars=levels, var_name="Metric", value_name="Score", - value_vars=("accuracy", "precision", "recall"), + value_vars=dp.performance_names, ) df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True) + - -class StudyPerformancesVsMetrics(BaseWidget): - """ - Plot performances vs a metrics (snr for instance) over case for a study. - - - Parameters - ---------- - study: GroundTruthStudy - A study object. - mode: str - Which mode in "swarm" - case_keys: list or None - A selection of cases to plot, if None, then all. - - """ - - def __init__( - self, - study, - metric_name="snr", - performance_name="accuracy", - case_keys=None, - backend=None, - **backend_kwargs, - ): - if case_keys is None: - case_keys = list(study.cases.keys()) - - plot_data = dict( - study=study, - metric_name=metric_name, - performance_name=performance_name, - case_keys=case_keys, - ) - - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - - def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .utils import get_some_colors - - dp = to_attr(data_plot) - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - study = dp.study - perfs = study.get_performance_by_unit(case_keys=dp.case_keys) - - max_metric = 0 - for key in dp.case_keys: - x = study.get_metrics(key)[dp.metric_name].values - y = perfs.xs(key)[dp.performance_name].values - label = dp.study.cases[key]["label"] - self.ax.scatter(x, y, label=label) - max_metric = max(max_metric, np.max(x)) - - self.ax.legend() - self.ax.set_xlim(0, max_metric * 1.05) - self.ax.set_ylim(0, 1.05) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index f18f4cb461..52aa76165d 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -29,7 +29,7 @@ from .unit_waveforms_density_map import UnitWaveformDensityMapWidget from .unit_waveforms import UnitWaveformsWidget from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget -from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyPerformancesVsMetrics +from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget widget_list = [ @@ -66,7 +66,6 @@ StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, - StudyPerformancesVsMetrics, StudyComparisonCollisionBySimilarityWidget, ] @@ -136,7 +135,6 @@ plot_study_run_times = StudyRunTimesWidget plot_study_unit_counts = StudyUnitCountsWidget plot_study_performances = StudyPerformances -plot_study_performances_vs_metrics = StudyPerformancesVsMetrics plot_study_comparison_collision_by_similarity = StudyComparisonCollisionBySimilarityWidget From a9ec3ebd0a00c942577416102922c28de7424af7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Oct 2023 13:09:19 +0200 Subject: [PATCH 04/15] remove a wrong and sad copy/paste --- src/spikeinterface/widgets/comparison.py | 298 ----------------------- 1 file changed, 298 deletions(-) diff --git a/src/spikeinterface/widgets/comparison.py b/src/spikeinterface/widgets/comparison.py index db3de104d2..e1f51865f9 100644 --- a/src/spikeinterface/widgets/comparison.py +++ b/src/spikeinterface/widgets/comparison.py @@ -1,304 +1,6 @@ import numpy as np -from warnings import warn from .base import BaseWidget, to_attr -from .utils import get_unit_colors - - -class MultiCompGraphWidget(BaseWidget): - """ - Plots multi comparison graph. - - Parameters - ---------- - multi_comparison: BaseMultiComparison - The multi comparison object - draw_labels: bool - If True unit labels are shown - node_cmap: matplotlib colormap - The colormap to be used for the nodes (default 'viridis') - edge_cmap: matplotlib colormap - The colormap to be used for the edges (default 'hot') - alpha_edges: float - Alpha value for edges - colorbar: bool - If True a colorbar for the edges is plotted - """ - - def __init__( - self, - multi_comparison, - draw_labels=False, - node_cmap="viridis", - edge_cmap="hot", - alpha_edges=0.5, - colorbar=False, - backend=None, - **backend_kwargs, - ): - plot_data = dict( - multi_comparison=multi_comparison, - draw_labels=draw_labels, - node_cmap=node_cmap, - edge_cmap=edge_cmap, - alpha_edges=alpha_edges, - colorbar=colorbar, - ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - - def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.colors as mpl_colors - import matplotlib.pyplot as plt - import networkx as nx - from .utils_matplotlib import make_mpl_figure - - dp = to_attr(data_plot) - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - mcmp = dp.multi_comparison - g = mcmp.graph - edge_col = [] - for e in g.edges(data=True): - n1, n2, d = e - edge_col.append(d["weight"]) - nodes_col_dict = {} - for i, sort_name in enumerate(mcmp.name_list): - nodes_col_dict[sort_name] = i - nodes_col = [] - for node in sorted(g.nodes): - nodes_col.append(nodes_col_dict[node[0]]) - nodes_col = np.array(nodes_col) / len(mcmp.name_list) - - _ = plt.set_cmap(dp.node_cmap) - _ = nx.draw_networkx_nodes( - g, - pos=nx.circular_layout(sorted(g)), - nodelist=sorted(g.nodes), - node_color=nodes_col, - node_size=20, - ax=self.ax, - ) - _ = nx.draw_networkx_edges( - g, - pos=nx.circular_layout((sorted(g))), - nodelist=sorted(g.nodes), - edge_color=edge_col, - alpha=dp.alpha_edges, - edge_cmap=plt.cm.get_cmap(dp.edge_cmap), - edge_vmin=mcmp.match_score, - edge_vmax=1, - ax=self.ax, - ) - if dp.draw_labels: - labels = {key: f"{key[0]}_{key[1]}" for key in sorted(g.nodes)} - pos = nx.circular_layout(sorted(g)) - # extend position radially - pos_extended = {} - for node, pos in pos.items(): - pos_new = pos + 0.1 * pos - pos_extended[node] = pos_new - _ = nx.draw_networkx_labels(g, pos=pos_extended, labels=labels, ax=self.ax) - - if dp.colorbar: - import matplotlib.pyplot as plt - - norm = mpl_colors.Normalize(vmin=mcmp.match_score, vmax=1) - cmap = plt.cm.get_cmap(dp.edge_cmap) - m = plt.cm.ScalarMappable(norm=norm, cmap=cmap) - self.figure.colorbar(m) - - self.ax.axis("off") - - -class MultiCompGlobalAgreementWidget(BaseWidget): - """ - Plots multi comparison agreement as pie or bar plot. - - Parameters - ---------- - multi_comparison: BaseMultiComparison - The multi comparison object - plot_type: str - 'pie' or 'bar' - cmap: matplotlib colormap, default: 'YlOrRd' - The colormap to be used for the nodes - fontsize: int, default: 9 - The text fontsize - show_legend: bool, default: True - If True a legend is shown - """ - - def __init__( - self, - multi_comparison, - plot_type="pie", - cmap="YlOrRd", - fontsize=9, - show_legend=True, - backend=None, - **backend_kwargs, - ): - plot_data = dict( - multi_comparison=multi_comparison, - plot_type=plot_type, - cmap=cmap, - fontsize=fontsize, - show_legend=show_legend, - ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - - def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - - dp = to_attr(data_plot) - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - mcmp = dp.multi_comparison - cmap = plt.get_cmap(dp.cmap) - colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) - sg_names, sg_units = mcmp.compute_subgraphs() - # fraction of units with agreement > threshold - v, c = np.unique([len(np.unique(s)) for s in sg_names], return_counts=True) - if dp.plot_type == "pie": - p = self.ax.pie(c, colors=colors[v - 1], autopct=lambda pct: _getabs(pct, c), pctdistance=1.25) - self.ax.legend( - p[0], - v, - frameon=False, - title="k=", - handlelength=1, - handletextpad=0.5, - bbox_to_anchor=(1.0, 1.0), - loc=2, - borderaxespad=0.5, - labelspacing=0.15, - fontsize=dp.fontsize, - ) - elif dp.plot_type == "bar": - self.ax.bar(v, c, color=colors[v - 1]) - x_labels = [f"k={vi}" for vi in v] - self.ax.spines["top"].set_visible(False) - self.ax.spines["right"].set_visible(False) - self.ax.set_xticks(v) - self.ax.set_xticklabels(x_labels) - else: - raise AttributeError("Wrong plot_type. It can be 'pie' or 'bar'") - self.ax.set_title("Units agreed upon\nby k sorters") - - -class MultiCompAgreementBySorterWidget(BaseWidget): - """ - Plots multi comparison agreement as pie or bar plot. - - Parameters - ---------- - multi_comparison: BaseMultiComparison - The multi comparison object - plot_type: str - 'pie' or 'bar' - cmap: matplotlib colormap - The colormap to be used for the nodes (default 'Reds') - axes: list of matplotlib axes - The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax - and figure parameters are ignored. - show_legend: bool - Show the legend in the last axes (default True). - - Returns - ------- - W: MultiCompGraphWidget - The output widget - """ - - def __init__( - self, - multi_comparison, - plot_type="pie", - cmap="YlOrRd", - fontsize=9, - show_legend=True, - backend=None, - **backend_kwargs, - ): - plot_data = dict( - multi_comparison=multi_comparison, - plot_type=plot_type, - cmap=cmap, - fontsize=fontsize, - show_legend=show_legend, - ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - - def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.colors as mpl_colors - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - - dp = to_attr(data_plot) - mcmp = dp.multi_comparison - name_list = mcmp.name_list - - backend_kwargs["num_axes"] = len(name_list) - backend_kwargs["ncols"] = len(name_list) - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - cmap = plt.get_cmap(dp.cmap) - colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) - sg_names, sg_units = mcmp.compute_subgraphs() - # fraction of units with agreement > threshold - for i, name in enumerate(name_list): - ax = np.squeeze(self.axes)[i] - v, c = np.unique([len(np.unique(sn)) for sn in sg_names if name in sn], return_counts=True) - if dp.plot_type == "pie": - p = ax.pie( - c, - colors=colors[v - 1], - textprops={"color": "k", "fontsize": dp.fontsize}, - autopct=lambda pct: _getabs(pct, c), - pctdistance=1.18, - ) - if (dp.show_legend) and (i == len(name_list) - 1): - plt.legend( - p[0], - v, - frameon=False, - title="k=", - handlelength=1, - handletextpad=0.5, - bbox_to_anchor=(1.15, 1.25), - loc=2, - borderaxespad=0.0, - labelspacing=0.15, - ) - elif dp.plot_type == "bar": - ax.bar(v, c, color=colors[v - 1]) - x_labels = [f"k={vi}" for vi in v] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.set_xticks(v) - ax.set_xticklabels(x_labels) - else: - raise AttributeError("Wrong plot_type. It can be 'pie' or 'bar'") - ax.set_title(name) - - if dp.plot_type == "bar": - ylims = [np.max(ax_single.get_ylim()) for ax_single in self.axes] - max_yval = np.max(ylims) - for ax_single in self.axes: - ax_single.set_ylim([0, max_yval]) - - -def _getabs(pct, allvals): - absolute = int(np.round(pct / 100.0 * np.sum(allvals))) - return f"{absolute}" -import numpy as np -from warnings import warn - -from .base import BaseWidget, to_attr -from .utils import get_unit_colors class ConfusionMatrixWidget(BaseWidget): From 4cae0b5a1de456627ccd9251d2a16e880a06e0cf Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Wed, 25 Oct 2023 15:04:22 +0200 Subject: [PATCH 05/15] Fix by Zach Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/widgets/comparison.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/comparison.py b/src/spikeinterface/widgets/comparison.py index e1f51865f9..c45b8bf1db 100644 --- a/src/spikeinterface/widgets/comparison.py +++ b/src/spikeinterface/widgets/comparison.py @@ -87,13 +87,13 @@ class AgreementMatrixWidget(BaseWidget): ---------- sorting_comparison: GroundTruthComparison or SymmetricSortingComparison The sorting comparison object. - Symetric or not. - ordered: bool + Can optionally be symmetric if given a SymmetricSortingComparison + ordered: bool, default: True Order units with best agreement scores. - This enable to see agreement on a diagonal. - count_text: bool + If True, agreement scores can be seen along a diagonal + count_text: bool, default: True If True counts are displayed as text - unit_ticks: bool + unit_ticks: bool, default: True If True unit tick labels are displayed """ From d6330cc4d73866b0995831b66bf80122ce5bb11a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Oct 2023 15:35:39 +0200 Subject: [PATCH 06/15] GroundTruthSTudy : Save comparison object in a folder. --- .../comparison/groundtruthstudy.py | 45 +++++++++++-------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 8cab2afc1b..f31ed773a9 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -88,6 +88,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): (study_folder / "sortings").mkdir() (study_folder / "sortings" / "run_logs").mkdir() (study_folder / "metrics").mkdir() + (study_folder / "comparisons").mkdir() for key, (rec, gt_sorting) in datasets.items(): assert "/" not in key, "'/' cannot be in the key name!" @@ -127,16 +128,18 @@ def scan_folder(self): with open(self.folder / "cases.pickle", "rb") as f: self.cases = pickle.load(f) + + self.sortings = {k: None for k in self.cases} self.comparisons = {k: None for k in self.cases} - - self.sortings = {} for key in self.cases: sorting_folder = self.folder / "sortings" / self.key_to_str(key) if sorting_folder.exists(): - sorting = load_extractor(sorting_folder) - else: - sorting = None - self.sortings[key] = sorting + self.sortings[key] = load_extractor(sorting_folder) + + comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + '.pickle') + if comparison_file.exists(): + with open(comparison_file, mode='rb') as f : + self.comparisons[key] = pickle.load(f) def __repr__(self): t = f"{self.__class__.__name__} {self.folder.stem} \n" @@ -155,6 +158,16 @@ def key_to_str(self, key): else: raise ValueError("Keys for cases must str or tuple") + def remove_sorting(self, key): + sorting_folder = self.folder / "sortings" / self.key_to_str(key) + log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" + comparison_file = self.folder / "comparisons" / self.key_to_str(key) + if sorting_folder.exists(): + shutil.rmtree(sorting_folder) + for f in (log_file, comparison_file): + if f.exists(): + f.unlink() + def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False): if case_keys is None: case_keys = self.cases.keys() @@ -177,13 +190,8 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True # save and skip self.copy_sortings(case_keys=[key]) continue - - if sorting_exists: - # delete older sorting + log before running sorters - shutil.rmtree(sorting_folder) - log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" - if log_file.exists(): - log_file.unlink() + + self.remove_sorting(key) if sorter_folder_exists: shutil.rmtree(sorter_folder) @@ -228,10 +236,7 @@ def copy_sortings(self, case_keys=None, force=True): if sorting is not None: if sorting_folder.exists(): if force: - # delete folder + log - shutil.rmtree(sorting_folder) - if log_file.exists(): - log_file.unlink() + self.remove_sorting(key) else: continue @@ -255,6 +260,10 @@ def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison comp = comparison_class(gt_sorting, sorting, **kwargs) self.comparisons[key] = comp + comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + '.pickle') + with open(comparison_file, mode='wb') as f: + pickle.dump(comp, f) + def get_run_times(self, case_keys=None): import pandas as pd @@ -297,7 +306,7 @@ def get_waveform_extractor(self, case_key=None, dataset_key=None): return we def get_templates(self, key, mode="average"): - we = self.get_waveform_extractor(key) + we = self.get_waveform_extractor(case_key=key) templates = we.get_all_templates(mode=mode) return templates From 79ac8d1c21d843febac666348c0b1a9a4b9b120f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Oct 2023 15:35:57 +0200 Subject: [PATCH 07/15] Improve widgets for GTStudy --- src/spikeinterface/widgets/gtstudy.py | 116 ++++++++++++++++++++-- src/spikeinterface/widgets/widget_list.py | 6 +- 2 files changed, 112 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index cd4bdb6cb3..9867bc6a36 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -1,11 +1,6 @@ import numpy as np from .base import BaseWidget, to_attr -from .utils import get_unit_colors - -from ..core import ChannelSparsity -from ..core.waveform_extractor import WaveformExtractor -from ..core.basesorting import BaseSorting class StudyRunTimesWidget(BaseWidget): @@ -129,7 +124,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.ax.legend() -# TODO : plot optionally average on some levels using group by class StudyPerformances(BaseWidget): """ Plot performances over case for a study. @@ -140,16 +134,17 @@ class StudyPerformances(BaseWidget): study: GroundTruthStudy A study object. mode: str - Which mode in "swarm" + Which mode in "ordered", "snr", "swarm" + performance_names: list + Which performances to plot ("accuracy", "precision", "recall") case_keys: list or None A selection of cases to plot, if None, then all. - """ def __init__( self, study, - mode="swarm", + mode="ordered", performance_names=("accuracy", "precision", "recall"), case_keys=None, backend=None, @@ -230,3 +225,106 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True) +class StudyAgreementMatrix(BaseWidget): + """ + Plot agreement matrix. + + Parameters + ---------- + study: GroundTruthStudy + A study object. + case_keys: list or None + A selection of cases to plot, if None, then all. + ordered: bool + Order units with best agreement scores. + This enable to see agreement on a diagonal. + count_text: bool + If True counts are displayed as text + """ + + def __init__( + self, + study, + ordered=True, count_text=True, + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + case_keys=case_keys, + ordered=ordered, + count_text=count_text, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from .comparison import AgreementMatrixWidget + + dp = to_attr(data_plot) + study = dp.study + + backend_kwargs["num_axes"] = len(dp.case_keys) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + for count, key in enumerate(dp.case_keys): + ax = self.axes.flatten()[count] + comp = study.comparisons[key] + AgreementMatrixWidget(comp, ordered=dp.ordered, count_text=dp.count_text, backend='matplotlib', ax=ax) + label = study.cases[key]["label"] + ax.set_title(label) + + +class StudySummary(BaseWidget): + """ + Plot a summary of a ground truth study. + Internally do: + plot_study_run_times + plot_study_unit_counts + plot_study_performances + plot_study_agreement_matrix + + Parameters + ---------- + study: GroundTruthStudy + A study object. + case_keys: list or None + A selection of cases to plot, if None, then all. + """ + + def __init__( + self, + study, + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + study = data_plot["study"] + case_keys = data_plot["case_keys"] + + + StudyPerformances(study=study, case_keys=case_keys, mode="ordered", backend="matplotlib", **backend_kwargs) + StudyPerformances(study=study, case_keys=case_keys, mode="snr", backend="matplotlib", **backend_kwargs) + StudyAgreementMatrix(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) + StudyRunTimesWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) + StudyUnitCountsWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 52aa76165d..b95e92668a 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -29,7 +29,7 @@ from .unit_waveforms_density_map import UnitWaveformDensityMapWidget from .unit_waveforms import UnitWaveformsWidget from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget -from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances +from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyAgreementMatrix, StudySummary from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget widget_list = [ @@ -66,6 +66,8 @@ StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, + StudyAgreementMatrix, + StudySummary, StudyComparisonCollisionBySimilarityWidget, ] @@ -135,6 +137,8 @@ plot_study_run_times = StudyRunTimesWidget plot_study_unit_counts = StudyUnitCountsWidget plot_study_performances = StudyPerformances +plot_study_agreement_matrix = StudyAgreementMatrix +plot_study_summary = StudySummary plot_study_comparison_collision_by_similarity = StudyComparisonCollisionBySimilarityWidget From 5e7a57fdeffa2f122e96bb10449fc1ddc9a88829 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Oct 2023 21:31:37 +0200 Subject: [PATCH 08/15] Some more cosmetics. --- src/spikeinterface/widgets/comparison.py | 12 +++++------ src/spikeinterface/widgets/gtstudy.py | 26 +++++++++++++++++------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/widgets/comparison.py b/src/spikeinterface/widgets/comparison.py index c45b8bf1db..1b14275459 100644 --- a/src/spikeinterface/widgets/comparison.py +++ b/src/spikeinterface/widgets/comparison.py @@ -142,17 +142,17 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.ax.text(j, i, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white") # Major ticks - self.ax.set_xticks(np.arange(0, N2)) - self.ax.set_yticks(np.arange(0, N1)) self.ax.xaxis.tick_bottom() # Labels for major ticks if dp.unit_ticks: - self.ax.set_yticklabels(scores.index, fontsize=12) - self.ax.set_xticklabels(scores.columns, fontsize=12) + self.ax.set_xticks(np.arange(0, N2)) + self.ax.set_yticks(np.arange(0, N1)) + self.ax.set_yticklabels(scores.index) + self.ax.set_xticklabels(scores.columns) - self.ax.set_xlabel(comp.name_list[1], fontsize=20) - self.ax.set_ylabel(comp.name_list[0], fontsize=20) + self.ax.set_xlabel(comp.name_list[1]) + self.ax.set_ylabel(comp.name_list[0]) self.ax.set_xlim(-0.5, N2 - 0.5) self.ax.set_ylim( diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 9867bc6a36..77c557910f 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -190,7 +190,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.plot(val, label=label) ax.set_title(performance_name) if count == 0: - ax.legend() + ax.legend(loc='upper right') elif dp.mode == "snr": @@ -203,13 +203,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): x = study.get_metrics(key).loc[:, metric_name].values y = perfs.xs(key).loc[:, performance_name].values label = study.cases[key]["label"] - ax.scatter(x, y, label=label) + ax.scatter(x, y, s=10, label=label) max_metric = max(max_metric, np.max(x)) ax.set_title(performance_name) ax.set_xlim(0, max_metric * 1.05) ax.set_ylim(0, 1.05) if count == 0: - ax.legend() + ax.legend(loc='lower right') elif dp.mode == "swarm": @@ -245,7 +245,7 @@ class StudyAgreementMatrix(BaseWidget): def __init__( self, study, - ordered=True, count_text=True, + ordered=True, case_keys=None, backend=None, **backend_kwargs, @@ -257,7 +257,6 @@ def __init__( study=study, case_keys=case_keys, ordered=ordered, - count_text=count_text, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -276,9 +275,22 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): for count, key in enumerate(dp.case_keys): ax = self.axes.flatten()[count] comp = study.comparisons[key] - AgreementMatrixWidget(comp, ordered=dp.ordered, count_text=dp.count_text, backend='matplotlib', ax=ax) + unit_ticks = len(comp.sorting1.unit_ids) <= 16 + count_text = len(comp.sorting1.unit_ids) <= 16 + + + AgreementMatrixWidget(comp, ordered=dp.ordered, count_text=count_text, unit_ticks=unit_ticks, backend='matplotlib', ax=ax) label = study.cases[key]["label"] - ax.set_title(label) + ax.set_xlabel(label) + + if count > 0: + ax.set_ylabel(None) + ax.set_yticks([]) + ax.set_xticks([]) + + # ax0 = self.axes.flatten()[0] + # for ax in self.axes.flatten()[1:]: + # ax.sharey(ax0) class StudySummary(BaseWidget): From 1d112d52806d6b54a202f705b41f107b5a53562b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Oct 2023 19:44:13 +0000 Subject: [PATCH 09/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../comparison/groundtruthstudy.py | 13 ++++++------- src/spikeinterface/widgets/comparison.py | 2 -- src/spikeinterface/widgets/gtstudy.py | 16 +++++++--------- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index f31ed773a9..0d08922543 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -128,7 +128,6 @@ def scan_folder(self): with open(self.folder / "cases.pickle", "rb") as f: self.cases = pickle.load(f) - self.sortings = {k: None for k in self.cases} self.comparisons = {k: None for k in self.cases} for key in self.cases: @@ -136,9 +135,9 @@ def scan_folder(self): if sorting_folder.exists(): self.sortings[key] = load_extractor(sorting_folder) - comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + '.pickle') + comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle") if comparison_file.exists(): - with open(comparison_file, mode='rb') as f : + with open(comparison_file, mode="rb") as f: self.comparisons[key] = pickle.load(f) def __repr__(self): @@ -190,7 +189,7 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True # save and skip self.copy_sortings(case_keys=[key]) continue - + self.remove_sorting(key) if sorter_folder_exists: @@ -260,8 +259,8 @@ def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison comp = comparison_class(gt_sorting, sorting, **kwargs) self.comparisons[key] = comp - comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + '.pickle') - with open(comparison_file, mode='wb') as f: + comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle") + with open(comparison_file, mode="wb") as f: pickle.dump(comp, f) def get_run_times(self, case_keys=None): @@ -300,7 +299,7 @@ def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): def get_waveform_extractor(self, case_key=None, dataset_key=None): if case_key is not None: dataset_key = self.cases[case_key]["dataset"] - + wf_folder = self.folder / "waveforms" / self.key_to_str(dataset_key) we = load_waveforms(wf_folder, with_recording=True) return we diff --git a/src/spikeinterface/widgets/comparison.py b/src/spikeinterface/widgets/comparison.py index 1b14275459..70f98df8b9 100644 --- a/src/spikeinterface/widgets/comparison.py +++ b/src/spikeinterface/widgets/comparison.py @@ -77,8 +77,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ) - - class AgreementMatrixWidget(BaseWidget): """ Plots sorting comparison agreement matrix. diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 77c557910f..946ef6cd5f 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -190,10 +190,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.plot(val, label=label) ax.set_title(performance_name) if count == 0: - ax.legend(loc='upper right') + ax.legend(loc="upper right") elif dp.mode == "snr": - metric_name = dp.mode for count, performance_name in enumerate(dp.performance_names): ax = self.axes.flatten()[count] @@ -209,8 +208,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.set_xlim(0, max_metric * 1.05) ax.set_ylim(0, 1.05) if count == 0: - ax.legend(loc='lower right') - + ax.legend(loc="lower right") elif dp.mode == "swarm": levels = perfs.index.names @@ -223,7 +221,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ) df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True) - + class StudyAgreementMatrix(BaseWidget): """ @@ -278,8 +276,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_ticks = len(comp.sorting1.unit_ids) <= 16 count_text = len(comp.sorting1.unit_ids) <= 16 - - AgreementMatrixWidget(comp, ordered=dp.ordered, count_text=count_text, unit_ticks=unit_ticks, backend='matplotlib', ax=ax) + AgreementMatrixWidget( + comp, ordered=dp.ordered, count_text=count_text, unit_ticks=unit_ticks, backend="matplotlib", ax=ax + ) label = study.cases[key]["label"] ax.set_xlabel(label) @@ -287,7 +286,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.set_ylabel(None) ax.set_yticks([]) ax.set_xticks([]) - + # ax0 = self.axes.flatten()[0] # for ax in self.axes.flatten()[1:]: # ax.sharey(ax0) @@ -334,7 +333,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): study = data_plot["study"] case_keys = data_plot["case_keys"] - StudyPerformances(study=study, case_keys=case_keys, mode="ordered", backend="matplotlib", **backend_kwargs) StudyPerformances(study=study, case_keys=case_keys, mode="snr", backend="matplotlib", **backend_kwargs) StudyAgreementMatrix(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) From b11ff03dae735772feb90c4613060f5367182ce7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 26 Oct 2023 12:13:41 +0200 Subject: [PATCH 10/15] Update src/spikeinterface/widgets/gtstudy.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/widgets/gtstudy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 946ef6cd5f..1e51c7d101 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -295,7 +295,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): class StudySummary(BaseWidget): """ Plot a summary of a ground truth study. - Internally do: + Internally does: plot_study_run_times plot_study_unit_counts plot_study_performances From a64e08673aee417cf3b5525358eab7ee763172e4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 26 Oct 2023 12:14:29 +0200 Subject: [PATCH 11/15] Update src/spikeinterface/widgets/gtstudy.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/widgets/gtstudy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 1e51c7d101..0b4a5182a5 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -305,7 +305,7 @@ class StudySummary(BaseWidget): ---------- study: GroundTruthStudy A study object. - case_keys: list or None + case_keys: list or None, default: None A selection of cases to plot, if None, then all. """ From 4fe1d02eac410d46db0491f0b66913a98e66a9ee Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 26 Oct 2023 12:14:42 +0200 Subject: [PATCH 12/15] Update src/spikeinterface/widgets/gtstudy.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/widgets/gtstudy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 0b4a5182a5..20de41c08a 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -135,7 +135,7 @@ class StudyPerformances(BaseWidget): A study object. mode: str Which mode in "ordered", "snr", "swarm" - performance_names: list + performance_names: list or tuple, default: ('accuracy', 'precision', 'recall') Which performances to plot ("accuracy", "precision", "recall") case_keys: list or None A selection of cases to plot, if None, then all. From 8c02176cf2b247a6654384fc296e5e6db2544b87 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 26 Oct 2023 12:49:41 +0200 Subject: [PATCH 13/15] Remove debug plotsgit add src/spikeinterface/sortingcomponents/clustering/merge.py ! --- .../sortingcomponents/clustering/merge.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 4c79383542..1ed51fb04f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -20,6 +20,9 @@ from .tools import aggregate_sparse_features, FeaturesLoader, compute_template_from_sparse +DEBUG = False + + def merge_clusters( peaks, peak_labels, @@ -81,7 +84,6 @@ def merge_clusters( **job_kwargs, ) - DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -224,8 +226,6 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" else: raise ValueError - # DEBUG = True - DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -233,8 +233,6 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" nx.draw_networkx(sub_graph) plt.show() - # DEBUG = True - DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -551,15 +549,7 @@ def merge( else: final_shift = 0 - # DEBUG = True - DEBUG = False - - # if DEBUG and is_merge: - # if DEBUG and (overlap > 0.1 and overlap <0.3): if DEBUG: - # if DEBUG and not is_merge: - # if DEBUG and (overlap > 0.05 and overlap <0.25): - # if label0 == 49 and label1== 65: import matplotlib.pyplot as plt flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1) @@ -674,8 +664,6 @@ def merge( final_shift = 0 merge_value = np.nan - # DEBUG = False - DEBUG = True if DEBUG and normed_diff < 0.2: # if DEBUG: From 213da0af5b346530f881de32b2d82fd97ef05886 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 26 Oct 2023 12:56:53 +0200 Subject: [PATCH 14/15] Extend docstring --- src/spikeinterface/widgets/gtstudy.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 20de41c08a..cb7eaf9df1 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -133,8 +133,12 @@ class StudyPerformances(BaseWidget): ---------- study: GroundTruthStudy A study object. - mode: str - Which mode in "ordered", "snr", "swarm" + mode: "ordered" | "snr" | "swarm", default: "ordered" + Which plot mode to use: + + * "ordered": plot performance metrics vs unit indices ordered by decreasing accuracy (default) + * "snr": plot performance metrics vs snr + * "swarm": plot performance metrics as a swarm plot (see seaborn.swarmplot for details) performance_names: list or tuple, default: ('accuracy', 'precision', 'recall') Which performances to plot ("accuracy", "precision", "recall") case_keys: list or None From dba32edbc0c4d38066e4e1715115536a3705b48a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 26 Oct 2023 12:59:14 +0200 Subject: [PATCH 15/15] Update src/spikeinterface/widgets/gtstudy.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/widgets/gtstudy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index cb7eaf9df1..6e4433ee60 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -240,8 +240,6 @@ class StudyAgreementMatrix(BaseWidget): ordered: bool Order units with best agreement scores. This enable to see agreement on a diagonal. - count_text: bool - If True counts are displayed as text """ def __init__(