diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index cd4bdb6cb3..9867bc6a36 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -1,11 +1,6 @@ import numpy as np from .base import BaseWidget, to_attr -from .utils import get_unit_colors - -from ..core import ChannelSparsity -from ..core.waveform_extractor import WaveformExtractor -from ..core.basesorting import BaseSorting class StudyRunTimesWidget(BaseWidget): @@ -129,7 +124,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.ax.legend() -# TODO : plot optionally average on some levels using group by class StudyPerformances(BaseWidget): """ Plot performances over case for a study. @@ -140,16 +134,17 @@ class StudyPerformances(BaseWidget): study: GroundTruthStudy A study object. mode: str - Which mode in "swarm" + Which mode in "ordered", "snr", "swarm" + performance_names: list + Which performances to plot ("accuracy", "precision", "recall") case_keys: list or None A selection of cases to plot, if None, then all. - """ def __init__( self, study, - mode="swarm", + mode="ordered", performance_names=("accuracy", "precision", "recall"), case_keys=None, backend=None, @@ -230,3 +225,106 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True) +class StudyAgreementMatrix(BaseWidget): + """ + Plot agreement matrix. + + Parameters + ---------- + study: GroundTruthStudy + A study object. + case_keys: list or None + A selection of cases to plot, if None, then all. + ordered: bool + Order units with best agreement scores. + This enable to see agreement on a diagonal. + count_text: bool + If True counts are displayed as text + """ + + def __init__( + self, + study, + ordered=True, count_text=True, + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + case_keys=case_keys, + ordered=ordered, + count_text=count_text, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from .comparison import AgreementMatrixWidget + + dp = to_attr(data_plot) + study = dp.study + + backend_kwargs["num_axes"] = len(dp.case_keys) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + for count, key in enumerate(dp.case_keys): + ax = self.axes.flatten()[count] + comp = study.comparisons[key] + AgreementMatrixWidget(comp, ordered=dp.ordered, count_text=dp.count_text, backend='matplotlib', ax=ax) + label = study.cases[key]["label"] + ax.set_title(label) + + +class StudySummary(BaseWidget): + """ + Plot a summary of a ground truth study. + Internally do: + plot_study_run_times + plot_study_unit_counts + plot_study_performances + plot_study_agreement_matrix + + Parameters + ---------- + study: GroundTruthStudy + A study object. + case_keys: list or None + A selection of cases to plot, if None, then all. + """ + + def __init__( + self, + study, + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + study = data_plot["study"] + case_keys = data_plot["case_keys"] + + + StudyPerformances(study=study, case_keys=case_keys, mode="ordered", backend="matplotlib", **backend_kwargs) + StudyPerformances(study=study, case_keys=case_keys, mode="snr", backend="matplotlib", **backend_kwargs) + StudyAgreementMatrix(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) + StudyRunTimesWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) + StudyUnitCountsWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 52aa76165d..b95e92668a 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -29,7 +29,7 @@ from .unit_waveforms_density_map import UnitWaveformDensityMapWidget from .unit_waveforms import UnitWaveformsWidget from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget -from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances +from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyAgreementMatrix, StudySummary from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget widget_list = [ @@ -66,6 +66,8 @@ StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, + StudyAgreementMatrix, + StudySummary, StudyComparisonCollisionBySimilarityWidget, ] @@ -135,6 +137,8 @@ plot_study_run_times = StudyRunTimesWidget plot_study_unit_counts = StudyUnitCountsWidget plot_study_performances = StudyPerformances +plot_study_agreement_matrix = StudyAgreementMatrix +plot_study_summary = StudySummary plot_study_comparison_collision_by_similarity = StudyComparisonCollisionBySimilarityWidget