Skip to content

Commit

Permalink
Improve widgets for GTStudy
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Oct 25, 2023
1 parent d6330cc commit 79ac8d1
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 10 deletions.
116 changes: 107 additions & 9 deletions src/spikeinterface/widgets/gtstudy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import numpy as np

from .base import BaseWidget, to_attr
from .utils import get_unit_colors

from ..core import ChannelSparsity
from ..core.waveform_extractor import WaveformExtractor
from ..core.basesorting import BaseSorting


class StudyRunTimesWidget(BaseWidget):
Expand Down Expand Up @@ -129,7 +124,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
self.ax.legend()


# TODO : plot optionally average on some levels using group by
class StudyPerformances(BaseWidget):
"""
Plot performances over case for a study.
Expand All @@ -140,16 +134,17 @@ class StudyPerformances(BaseWidget):
study: GroundTruthStudy
A study object.
mode: str
Which mode in "swarm"
Which mode in "ordered", "snr", "swarm"
performance_names: list
Which performances to plot ("accuracy", "precision", "recall")
case_keys: list or None
A selection of cases to plot, if None, then all.
"""

def __init__(
self,
study,
mode="swarm",
mode="ordered",
performance_names=("accuracy", "precision", "recall"),
case_keys=None,
backend=None,
Expand Down Expand Up @@ -230,3 +225,106 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True)


class StudyAgreementMatrix(BaseWidget):
"""
Plot agreement matrix.
Parameters
----------
study: GroundTruthStudy
A study object.
case_keys: list or None
A selection of cases to plot, if None, then all.
ordered: bool
Order units with best agreement scores.
This enable to see agreement on a diagonal.
count_text: bool
If True counts are displayed as text
"""

def __init__(
self,
study,
ordered=True, count_text=True,
case_keys=None,
backend=None,
**backend_kwargs,
):
if case_keys is None:
case_keys = list(study.cases.keys())

plot_data = dict(
study=study,
case_keys=case_keys,
ordered=ordered,
count_text=count_text,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
from .utils_matplotlib import make_mpl_figure
from .comparison import AgreementMatrixWidget

dp = to_attr(data_plot)
study = dp.study

backend_kwargs["num_axes"] = len(dp.case_keys)
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

for count, key in enumerate(dp.case_keys):
ax = self.axes.flatten()[count]
comp = study.comparisons[key]
AgreementMatrixWidget(comp, ordered=dp.ordered, count_text=dp.count_text, backend='matplotlib', ax=ax)
label = study.cases[key]["label"]
ax.set_title(label)


class StudySummary(BaseWidget):
"""
Plot a summary of a ground truth study.
Internally do:
plot_study_run_times
plot_study_unit_counts
plot_study_performances
plot_study_agreement_matrix
Parameters
----------
study: GroundTruthStudy
A study object.
case_keys: list or None
A selection of cases to plot, if None, then all.
"""

def __init__(
self,
study,
case_keys=None,
backend=None,
**backend_kwargs,
):
if case_keys is None:
case_keys = list(study.cases.keys())

plot_data = dict(
study=study,
case_keys=case_keys,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
from .utils_matplotlib import make_mpl_figure

study = data_plot["study"]
case_keys = data_plot["case_keys"]


StudyPerformances(study=study, case_keys=case_keys, mode="ordered", backend="matplotlib", **backend_kwargs)
StudyPerformances(study=study, case_keys=case_keys, mode="snr", backend="matplotlib", **backend_kwargs)
StudyAgreementMatrix(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs)
StudyRunTimesWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs)
StudyUnitCountsWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs)
6 changes: 5 additions & 1 deletion src/spikeinterface/widgets/widget_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .unit_waveforms_density_map import UnitWaveformDensityMapWidget
from .unit_waveforms import UnitWaveformsWidget
from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget
from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances
from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyAgreementMatrix, StudySummary
from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget

widget_list = [
Expand Down Expand Up @@ -66,6 +66,8 @@
StudyRunTimesWidget,
StudyUnitCountsWidget,
StudyPerformances,
StudyAgreementMatrix,
StudySummary,
StudyComparisonCollisionBySimilarityWidget,
]

Expand Down Expand Up @@ -135,6 +137,8 @@
plot_study_run_times = StudyRunTimesWidget
plot_study_unit_counts = StudyUnitCountsWidget
plot_study_performances = StudyPerformances
plot_study_agreement_matrix = StudyAgreementMatrix
plot_study_summary = StudySummary
plot_study_comparison_collision_by_similarity = StudyComparisonCollisionBySimilarityWidget


Expand Down

0 comments on commit 79ac8d1

Please sign in to comment.