diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7153a7dfc0..e907976163 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.9.1 + rev: 23.10.0 hooks: - id: black files: ^src/ diff --git a/examples/how_to/get_started.py b/examples/how_to/get_started.py index 7860c605af..329a2b32b0 100644 --- a/examples/how_to/get_started.py +++ b/examples/how_to/get_started.py @@ -337,8 +337,8 @@ print('Units in agreement between TDC, SC2, and KS2:', sorting_agreement.get_unit_ids()) -w_multi = sw.plot_multicomp_agreement(comp_multi) -w_multi = sw.plot_multicomp_agreement_by_sorter(comp_multi) +w_multi = sw.plot_multicomparison_agreement(comp_multi) +w_multi = sw.plot_multicomparison_agreement_by_sorter(comp_multi) # - # We see that 10 unit were found by all sorters (note that this simulated dataset is a very simple example, and usually sorters do not do such a great job)! diff --git a/examples/modules_gallery/widgets/plot_2_sort_gallery.py b/examples/modules_gallery/widgets/plot_2_sort_gallery.py index 2a5ef30e28..c07070927b 100644 --- a/examples/modules_gallery/widgets/plot_2_sort_gallery.py +++ b/examples/modules_gallery/widgets/plot_2_sort_gallery.py @@ -16,13 +16,13 @@ ############################################################################## # plot_rasters() -# ~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~ w_rs = sw.plot_rasters(sorting) ############################################################################## # plot_isi_distribution() -# ~~~~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~~~~ w_isi = sw.plot_isi_distribution(sorting, window_ms=150.0, bin_ms=5.0) @@ -43,10 +43,10 @@ ############################################################################## -# plot_presence() -# ~~~~~~~~~~~~~~~~~~~~~~~~ +# plot_unit_presence() +# ~~~~~~~~~~~~~~~~~~~~ -w_pr = sw.plot_presence(sorting) +w_pr = sw.plot_unit_presence(sorting) plt.show() diff --git a/examples/modules_gallery/widgets/plot_4_peaks_gallery.py b/examples/modules_gallery/widgets/plot_4_peaks_gallery.py index addd87c065..60733afb1d 100644 --- a/examples/modules_gallery/widgets/plot_4_peaks_gallery.py +++ b/examples/modules_gallery/widgets/plot_4_peaks_gallery.py @@ -44,14 +44,14 @@ ############################################################################## # This "peaks" vector can be used in several widgets, for instance -# plot_peak_activity_map() +# plot_peak_activity() -si.plot_peak_activity_map(rec_filtred, peaks=peaks) +si.plot_peak_activity(rec_filtred, peaks=peaks) ############################################################################## # can be also animated with bin_duration_s=1. -si.plot_peak_activity_map(rec_filtred, bin_duration_s=1.) +si.plot_peak_activity(rec_filtred, bin_duration_s=1.) plt.show() diff --git a/src/spikeinterface/comparison/tests/test_multisortingcomparison.py b/src/spikeinterface/comparison/tests/test_multisortingcomparison.py index 0588cca66d..7e4c9ac77b 100644 --- a/src/spikeinterface/comparison/tests/test_multisortingcomparison.py +++ b/src/spikeinterface/comparison/tests/test_multisortingcomparison.py @@ -69,13 +69,6 @@ def test_compare_multiple_sorters(): msc = MultiSortingComparison.load_from_folder(multicomparison_folder) - # import spikeinterface.widgets as sw - # import matplotlib.pyplot as plt - # sw.plot_multicomp_graph(msc) - # sw.plot_multicomp_agreement(msc) - # sw.plot_multicomp_agreement_by_sorter(msc) - # plt.show() - def test_compare_multi_segment(): num_segments = 3 diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index dc84d31987..44ea02d32c 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -8,7 +8,7 @@ from .numpyextractors import NumpyRecording, NumpySorting from .basesorting import minimum_spike_dtype -from probeinterface import Probe, generate_linear_probe +from probeinterface import Probe, generate_linear_probe, generate_multi_columns_probe from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting from .snippets_tools import snippets_from_sorting @@ -93,7 +93,6 @@ def generate_recording( probe = probe.to_3d() probe.set_device_channel_indices(np.arange(num_channels)) recording.set_probe(probe, in_place=True) - probe = generate_linear_probe(num_elec=num_channels) return recording @@ -122,7 +121,7 @@ def generate_sorting( durations=[10.325, 3.5], #  in s for 2 segments firing_rates=3.0, empty_units=None, - refractory_period_ms=3.0, # in ms + refractory_period_ms=4.0, # in ms add_spikes_on_borders=False, num_spikes_per_border=3, border_size_samples=20, @@ -143,7 +142,7 @@ def generate_sorting( The firing rate of each unit (in Hz). empty_units : list, default: None List of units that will have no spikes. (used for testing mainly). - refractory_period_ms : float, default: 3.0 + refractory_period_ms : float, default: 4.0 The refractory period in ms add_spikes_on_borders : bool, default: False If True, spikes will be added close to the borders of the segments. @@ -881,9 +880,10 @@ def generate_single_fake_waveform( depolarization_ms=(0.09, 0.14), repolarization_ms=(0.5, 0.8), recovery_ms=(1.0, 1.5), - positive_amplitude=(0.05, 0.15), + positive_amplitude=(0.1, 0.25), smooth_ms=(0.03, 0.07), decay_power=(1.4, 1.8), + propagation_speed=(250.0, 350.0), # um / ms ) @@ -931,13 +931,14 @@ def generate_templates( An optional dict containing parameters per units. Keys are parameter names: - * 'alpha': amplitude of the action potential in a.u. (default range: (5'000-15'000)) + * 'alpha': amplitude of the action potential in a.u. (default range: (6'000-9'000)) * 'depolarization_ms': the depolarization interval in ms (default range: (0.09-0.14)) * 'repolarization_ms': the repolarization interval in ms (default range: (0.5-0.8)) * 'recovery_ms': the recovery interval in ms (default range: (1.0-1.5)) * 'positive_amplitude': the positive amplitude in a.u. (default range: (0.05-0.15)) (negative is always -1) * 'smooth_ms': the gaussian smooth in ms (default range: (0.03-0.07)) * 'decay_power': the decay power (default range: (1.2-1.8)) + * 'propagation_speed': mimic a propagation delay with a kind of a "speed" (default range: (250., 350.)). Values contains vector with same size of num_units. If the key is not in dict then it is generated using unit_params_range unit_params_range: dict of tuple @@ -985,12 +986,16 @@ def generate_templates( assert unit_params[k].size == num_units params[k] = unit_params[k] else: - v = rng.random(num_units) if k in unit_params_range: - lim0, lim1 = unit_params_range[k] + lims = unit_params_range[k] + else: + lims = default_unit_params_range[k] + if lims is not None: + lim0, lim1 = lims + v = rng.random(num_units) + params[k] = v * (lim1 - lim0) + lim0 else: - lim0, lim1 = default_unit_params_range[k] - params[k] = v * (lim1 - lim0) + lim0 + params[k] = [None] * num_units for u in range(num_units): wf = generate_single_fake_waveform( @@ -1006,17 +1011,42 @@ def generate_templates( dtype=dtype, ) + ## Add a spatial decay depend on distance from unit to each channel alpha = params["alpha"][u] # the espilon avoid enormous factors eps = 1.0 # naive formula for spatial decay pow = params["decay_power"][u] channel_factors = alpha / (distances[u, :] + eps) ** pow + wfs = wf[:, np.newaxis] * channel_factors[np.newaxis, :] + + # This mimic a propagation delay for distant channel + propagation_speed = params["propagation_speed"][u] + if propagation_speed is not None: + # the speed is um/ms + dist = distances[u, :].copy() + dist -= np.min(dist) + delay_s = dist / propagation_speed / 1000.0 + sample_shifts = delay_s * fs + + # apply the delay with fft transform to get sub sample shift + n = wfs.shape[0] + wfs_f = np.fft.rfft(wfs, axis=0) + if n % 2 == 0: + # n is even sig_f[-1] is nyquist and so pi + omega = np.linspace(0, np.pi, wfs_f.shape[0]) + else: + # n is odd sig_f[-1] is exactly nyquist!! we need (n-1) / n factor!! + omega = np.linspace(0, np.pi * (n - 1) / n, wfs_f.shape[0]) + # broadcast omega and sample_shifts depend the axis + shifts = omega[:, np.newaxis] * sample_shifts[np.newaxis, :] + wfs = np.fft.irfft(wfs_f * np.exp(-1j * shifts), n=n, axis=0) + if upsample_factor is not None: for f in range(upsample_factor): - templates[u, :, :, f] = wf[f::upsample_factor, np.newaxis] * channel_factors[np.newaxis, :] + templates[u, :, :, f] = wfs[f::upsample_factor] else: - templates[u, :, :] = wf[:, np.newaxis] * channel_factors[np.newaxis, :] + templates[u, :, :] = wfs return templates @@ -1322,12 +1352,19 @@ def generate_ground_truth_recording( num_units=10, sorting=None, probe=None, + generate_probe_kwargs=dict( + num_columns=2, + xpitch=20, + ypitch=20, + contact_shapes="circle", + contact_shape_params={"radius": 6}, + ), templates=None, ms_before=1.0, ms_after=3.0, upsample_factor=None, upsample_vector=None, - generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=1.5), + generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0), generate_templates_kwargs=dict(), @@ -1350,7 +1387,9 @@ def generate_ground_truth_recording( sorting: Sorting or None An external sorting object. If not provide, one is genrated. probe: Probe or None - An external Probe object. If not provided of linear probe is generated. + An external Probe object. If not provided a probe is generated using generate_probe_kwargs. + generate_probe_kwargs: dict + A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`. templates: np.array or None The templates of units. If None they are generated. @@ -1407,8 +1446,28 @@ def generate_ground_truth_recording( num_spikes = sorting.to_spike_vector().size if probe is None: - probe = generate_linear_probe(num_elec=num_channels) + # probe = generate_linear_probe(num_elec=num_channels) + # probe.set_device_channel_indices(np.arange(num_channels)) + + prb_kwargs = generate_probe_kwargs.copy() + if "num_contact_per_column" in prb_kwargs: + assert ( + prb_kwargs["num_contact_per_column"] * prb_kwargs["num_columns"] + ) == num_channels, ( + "generate_multi_columns_probe : num_channels do not match num_contact_per_column x num_columns" + ) + elif "num_contact_per_column" not in prb_kwargs and "num_columns" in prb_kwargs: + n = num_channels // prb_kwargs["num_columns"] + num_contact_per_column = [n] * prb_kwargs["num_columns"] + mid = prb_kwargs["num_columns"] // 2 + num_contact_per_column[mid] += num_channels % prb_kwargs["num_columns"] + prb_kwargs["num_contact_per_column"] = num_contact_per_column + else: + raise ValueError("num_columns should be provided in dict generate_probe_kwargs") + + probe = generate_multi_columns_probe(**prb_kwargs) probe.set_device_channel_indices(np.arange(num_channels)) + else: num_channels = probe.get_contact_count() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 4efabbc9c5..d961bdbc07 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -6,12 +6,9 @@ from spikeinterface.qualitymetrics import compute_quality_metrics from spikeinterface.comparison import CollisionGTComparison, compare_sorter_to_ground_truth from spikeinterface.widgets import ( - plot_sorting_performance, plot_agreement_matrix, plot_comparison_collision_by_similarity, - plot_unit_templates, plot_unit_waveforms, - plot_gt_performances, ) import time @@ -474,13 +471,12 @@ def plot(self, comp, title=None): ax = axs[1, 0] ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) - plot_sorting_performance(comp, self.metrics, performance_name="accuracy", metric_name="snr", ax=ax, color="r") - plot_sorting_performance(comp, self.metrics, performance_name="recall", metric_name="snr", ax=ax, color="g") - plot_sorting_performance(comp, self.metrics, performance_name="precision", metric_name="snr", ax=ax, color="b") - ax.legend(["accuracy", "recall", "precision"]) - ax = axs[1, 1] - plot_gt_performances(comp, ax=ax) + for k in ("accuracy", "recall", "precision"): + x = comp.get_performance()[k] + y = self.metrics["snr"] + ax.scatter(x, y, markersize=10, marker=".", label=k) + ax.legend() ax = axs[0, 1] if self.exhaustive_gt: diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index b28b29f17c..d0b3f387b1 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -10,7 +10,6 @@ from spikeinterface.extractors import read_mearec from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference, scale, highpass_filter, whiten from spikeinterface.sorters import run_sorter, read_sorter_folder -from spikeinterface.widgets import plot_unit_waveforms, plot_gt_performances from spikeinterface.comparison import GroundTruthComparison from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording diff --git a/src/spikeinterface/widgets/__init__.py b/src/spikeinterface/widgets/__init__.py index d3066f51fa..d6d181f3fe 100644 --- a/src/spikeinterface/widgets/__init__.py +++ b/src/spikeinterface/widgets/__init__.py @@ -3,9 +3,3 @@ # general functions from .utils import get_some_colors, get_unit_colors, array_to_image from .base import set_default_plotter_backend, get_default_plotter_backend - - -# we keep this to keep compatibility so we have all previous widgets -# except the one that have been ported that are imported -# with "from .widget_list import *" in the first line -from ._legacy_mpl_widgets import * diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py deleted file mode 100644 index 53c2a5c79e..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -# peak activity -from .activity import plot_peak_activity_map, PeakActivityMapWidget - - -from .multicompgraph import ( - plot_multicomp_graph, - MultiCompGraphWidget, - plot_multicomp_agreement, - MultiCompGlobalAgreementWidget, - plot_multicomp_agreement_by_sorter, - MultiCompAgreementBySorterWidget, -) -from .collisioncomp import ( - plot_comparison_collision_pair_by_pair, - ComparisonCollisionPairByPairWidget, - plot_comparison_collision_by_similarity, - ComparisonCollisionBySimilarityWidget, - plot_study_comparison_collision_by_similarity, - StudyComparisonCollisionBySimilarityWidget, - plot_study_comparison_collision_by_similarity_range, - StudyComparisonCollisionBySimilarityRangeWidget, - StudyComparisonCollisionBySimilarityRangesWidget, - plot_study_comparison_collision_by_similarity_ranges, -) - -from .sortingperformance import plot_sorting_performance - -# ground truth comparions (=comparison over sorter) -from .gtcomparison import ( - plot_gt_performances, - plot_gt_performances_averages, - ComparisonPerformancesWidget, - ComparisonPerformancesAveragesWidget, - plot_gt_performances_by_template_similarity, - ComparisonPerformancesByTemplateSimilarity, -) - - -# unit presence -from .presence import plot_presence, PresenceWidget - -# correlogram comparison -from .correlogramcomp import ( - StudyComparisonCorrelogramBySimilarityWidget, - plot_study_comparison_correlogram_by_similarity, - StudyComparisonCorrelogramBySimilarityRangesMeanErrorWidget, - plot_study_comparison_correlogram_by_similarity_ranges_mean_error, -) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/basewidget.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/basewidget.py deleted file mode 100644 index f69f9a1b0f..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/basewidget.py +++ /dev/null @@ -1,81 +0,0 @@ -import numpy as np - - -# This class replace the old BaseWidget and BaseMultiWidget -class BaseWidget: - def __init__(self, figure=None, ax=None, axes=None, ncols=None, num_axes=None): - """ - figure/ax/axes : only one of then can be not None - """ - import matplotlib.pyplot as plt - - from matplotlib import gridspec - - if figure is not None: - assert ax is None and axes is None, "figure/ax/axes : only one of then can be not None" - ax = figure.add_subplot(111) - axes = np.array([[ax]]) - elif ax is not None: - assert figure is None and axes is None, "figure/ax/axes : only one of then can be not None" - figure = ax.get_figure() - axes = np.array([[ax]]) - elif axes is not None: - assert figure is None and ax is None, "figure/ax/axes : only one of then can be not None" - axes = np.asarray(axes) - figure = axes.flatten()[0].get_figure() - else: - # one fig with one ax - if num_axes is None: - figure, ax = plt.subplots() - axes = np.array([[ax]]) - else: - if num_axes == 0: - # one figure without plots (diffred subplot creation with - figure = plt.figure() - ax = None - axes = None - elif num_axes == 1: - figure = plt.figure() - ax = figure.add_subplot(111) - axes = np.array([[ax]]) - else: - assert ncols is not None - if num_axes < ncols: - ncols = num_axes - nrows = int(np.ceil(num_axes / ncols)) - figure, axes = plt.subplots( - nrows=nrows, - ncols=ncols, - ) - ax = None - # remove extra axes - if ncols * nrows > num_axes: - for extra_ax in axes.flatten()[num_axes:]: - extra_ax.remove() - - self.figure = figure - self.ax = ax - # axes is a 2D array of ax - self.axes = axes - # self.figure.axes is the flatten of all axes - - -class DataWidget: - def __init__(self) -> None: - self.plotter = None - pass - - def _prepare_data(self): - raise NotImplementedError - - -# keep here just in case it is needed - -# def create_tiled_ax(self, i, nrows, ncols, hspace=0.3, wspace=0.3, is_diag=False): -# gs = gridspec.GridSpecFromSubplotSpec(int(nrows), int(ncols), subplot_spec=self.ax, -# hspace=hspace, wspace=wspace) -# r = int(i // ncols) -# c = int(np.mod(i, ncols)) -# gs_sel = gs[r, c] -# ax = self.figure.add_subplot(gs_sel) -# return ax diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py deleted file mode 100644 index 468b96ff3b..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ /dev/null @@ -1,503 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class ComparisonCollisionPairByPairWidget(BaseWidget): - """ - Plots CollisionGTComparison pair by pair. - - Parameters - ---------- - comp: CollisionGTComparison - The collision ground truth comparison object - unit_ids: list - List of considered units - nbins: int - Number of bins - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: MultiCompGraphWidget - The output widget - """ - - def __init__(self, comp, unit_ids=None, figure=None, ax=None): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - BaseWidget.__init__(self, figure, ax) - if unit_ids is None: - # take all units - unit_ids = comp.sorting1.get_unit_ids() - - self.comp = comp - self.unit_ids = unit_ids - - def plot(self): - self._do_plot() - - def _do_plot(self): - from matplotlib import pyplot as plt - - fig = self.figure - - for ax in fig.axes: - ax.remove() - - n = len(self.unit_ids) - gs = gridspec.GridSpec(ncols=n, nrows=n, figure=fig) - - axs = np.empty((n, n), dtype=object) - ax = None - for r in range(n): - for c in range(n): - ax = fig.add_subplot(gs[r, c], sharex=ax, sharey=ax) - if c > 0: - plt.setp(ax.get_yticklabels(), visible=False) - if r < n - 1: - plt.setp(ax.get_xticklabels(), visible=False) - axs[r, c] = ax - - fs = self.comp.sorting1.get_sampling_frequency() - - lags = self.comp.bins / fs * 1000 - width = lags[1] - lags[0] - - for r in range(n): - for c in range(r + 1, n): - ax = axs[r, c] - - u1 = self.unit_ids[r] - u2 = self.unit_ids[c] - ind1 = self.comp.sorting1.id_to_index(u1) - ind2 = self.comp.sorting1.id_to_index(u2) - - tp = self.comp.all_tp[ind1, ind2, :] - fn = self.comp.all_fn[ind1, ind2, :] - ax.bar(lags[:-1], tp, width=width, color="g", align="edge") - ax.bar(lags[:-1], fn, width=width, bottom=tp, color="r", align="edge") - - ax = axs[c, r] - tp = self.comp.all_tp[ind2, ind1, :] - fn = self.comp.all_fn[ind2, ind1, :] - ax.bar(lags[:-1], tp, width=width, color="g", align="edge") - ax.bar(lags[:-1], fn, width=width, bottom=tp, color="r", align="edge") - - for r in range(n): - ax = axs[r, 0] - u1 = self.unit_ids[r] - ax.set_ylabel(f"gt id{u1}") - - for c in range(n): - ax = axs[0, c] - u2 = self.unit_ids[c] - ax.set_title(f"collision with \ngt id{u2}") - - ax = axs[-1, 0] - ax.set_xlabel("collision lag [ms]") - - -class ComparisonCollisionBySimilarityWidget(BaseWidget): - """ - Plots CollisionGTComparison pair by pair orderer by cosine_similarity - - Parameters - ---------- - comp: CollisionGTComparison - The collision ground truth comparison object - templates: array - template of units - mode: 'heatmap' or 'lines' - to see collision curves for every pairs ('heatmap') or as lines averaged over pairs. - similarity_bins: array - if mode is 'lines', the bins used to average the pairs - cmap: string - colormap used to show averages if mode is 'lines' - metric: 'cosine_similarity' - metric for ordering - good_only: True - keep only the pairs with a non zero accuracy (found templates) - min_accuracy: float - If good only, the minimum accuracy every cell should have, individually, to be - considered in a putative pair - unit_ids: list - List of considered units - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - """ - - def __init__( - self, - comp, - templates, - unit_ids=None, - metric="cosine_similarity", - figure=None, - ax=None, - mode="heatmap", - similarity_bins=np.linspace(-0.4, 1, 8), - cmap="winter", - good_only=True, - min_accuracy=0.9, - show_legend=False, - ylim=(0, 1), - ): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - BaseWidget.__init__(self, figure, ax) - - assert mode in ["heatmap", "lines"] - - if unit_ids is None: - # take all units - unit_ids = comp.sorting1.get_unit_ids() - - self.comp = comp - self.cmap = cmap - self.mode = mode - self.ylim = ylim - self.show_legend = show_legend - self.similarity_bins = similarity_bins - self.templates = templates - self.unit_ids = unit_ids - self.metric = metric - self.good_only = good_only - self.min_accuracy = min_accuracy - - def plot(self): - self._do_plot() - - def _do_plot(self): - import sklearn - import matplotlib.pyplot as plt - import matplotlib - - # compute similarity - # take index of template (respect unit_ids order) - all_unit_ids = list(self.comp.sorting1.get_unit_ids()) - template_inds = [all_unit_ids.index(u) for u in self.unit_ids] - - templates = self.templates[template_inds, :, :].copy() - flat_templates = templates.reshape(templates.shape[0], -1) - if self.metric == "cosine_similarity": - similarity_matrix = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - else: - raise NotImplementedError("metric=...") - - fs = self.comp.sorting1.get_sampling_frequency() - lags = self.comp.bins / fs * 1000 - - n = len(self.unit_ids) - - similarities, recall_scores, pair_names = self.comp.compute_collision_by_similarity( - similarity_matrix, unit_ids=self.unit_ids, good_only=self.good_only, min_accuracy=self.min_accuracy - ) - - if self.mode == "heatmap": - fig = self.figure - for ax in fig.axes: - ax.remove() - - n_pair = len(similarities) - - ax0 = fig.add_axes([0.1, 0.1, 0.25, 0.8]) - ax1 = fig.add_axes([0.4, 0.1, 0.5, 0.8], sharey=ax0) - - plt.setp(ax1.get_yticklabels(), visible=False) - - im = ax1.imshow( - recall_scores[::-1, :], - cmap="viridis", - aspect="auto", - interpolation="none", - extent=(lags[0], lags[-1], -0.5, n_pair - 0.5), - ) - im.set_clim(0, 1) - - ax0.plot(similarities, np.arange(n_pair), color="k") - - ax0.set_yticks(np.arange(n_pair)) - ax0.set_yticklabels(pair_names) - # ax0.set_xlim(0,1) - - ax0.set_xlabel(self.metric) - ax0.set_ylabel("pairs") - - ax1.set_xlabel("lag (ms)") - elif self.mode == "lines": - my_cmap = plt.get_cmap(self.cmap) - cNorm = matplotlib.colors.Normalize(vmin=self.similarity_bins.min(), vmax=self.similarity_bins.max()) - scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) - - # plot by similarity bins - if self.ax is None: - fig, ax = plt.subplots() - else: - ax = self.ax - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - order = np.argsort(similarities) - similarities = similarities[order] - recall_scores = recall_scores[order, :] - - for i in range(self.similarity_bins.size - 1): - cmin, cmax = self.similarity_bins[i], self.similarity_bins[i + 1] - - amin, amax = np.searchsorted(similarities, [cmin, cmax]) - mean_recall_scores = np.nanmean(recall_scores[amin:amax], axis=0) - - colorVal = scalarMap.to_rgba((cmin + cmax) / 2) - ax.plot( - lags[:-1] + (lags[1] - lags[0]) / 2, - mean_recall_scores, - label="CS in [%g,%g]" % (cmin, cmax), - c=colorVal, - ) - - if self.show_legend: - ax.legend() - ax.set_ylim(self.ylim) - ax.set_xlabel("lags (ms)") - ax.set_ylabel("collision accuracy") - - -class StudyComparisonCollisionBySimilarityWidget(BaseWidget): - def __init__( - self, - study, - metric="cosine_similarity", - similarity_bins=np.linspace(-0.4, 1, 8), - show_legend=False, - ylim=(0.5, 1), - good_only=True, - min_accuracy=0.9, - ncols=3, - axes=None, - cmap="winter", - ): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - if axes is None: - num_axes = len(study.sorter_names) - else: - num_axes = None - BaseWidget.__init__(self, None, None, axes, ncols=ncols, num_axes=num_axes) - - self.ncols = ncols - self.study = study - self.metric = metric - self.cmap = cmap - self.similarity_bins = np.asarray(similarity_bins) - self.show_legend = show_legend - self.ylim = ylim - self.good_only = good_only - self.min_accuracy = min_accuracy - - def plot(self): - my_cmap = plt.get_cmap(self.cmap) - cNorm = matplotlib.colors.Normalize(vmin=self.similarity_bins.min(), vmax=self.similarity_bins.max()) - scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) - self.study.precompute_scores_by_similarities(self.good_only, min_accuracy=self.min_accuracy) - lags = self.study.get_lags() - - for sorter_ind, sorter_name in enumerate(self.study.sorter_names): - curves = self.study.get_lag_profile_over_similarity_bins(self.similarity_bins, sorter_name) - - # plot by similarity bins - ax = self.axes.flatten()[sorter_ind] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - for i in range(self.similarity_bins.size - 1): - cmin, cmax = self.similarity_bins[i], self.similarity_bins[i + 1] - colorVal = scalarMap.to_rgba((cmin + cmax) / 2) - ax.plot( - lags[:-1] + (lags[1] - lags[0]) / 2, - curves[(cmin, cmax)], - label="CS in [%g,%g]" % (cmin, cmax), - c=colorVal, - ) - - if np.mod(sorter_ind, self.ncols) == 0: - ax.set_ylabel("collision accuracy") - - if sorter_ind > (len(self.study.sorter_names) // self.ncols): - ax.set_xlabel("lags (ms)") - - ax.set_title(sorter_name) - if self.show_legend: - ax.legend() - - if self.ylim is not None: - ax.set_ylim(self.ylim) - - -class StudyComparisonCollisionBySimilarityRangeWidget(BaseWidget): - def __init__( - self, - study, - metric="cosine_similarity", - similarity_range=[0, 1], - show_legend=False, - ylim=(0.5, 1), - good_only=True, - min_accuracy=0.9, - ax=None, - ): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - BaseWidget.__init__(self, None, ax) - - self.study = study - self.metric = metric - self.similarity_range = similarity_range - self.show_legend = show_legend - self.ylim = ylim - self.good_only = good_only - self.min_accuracy = min_accuracy - - def plot(self): - self.study.precompute_scores_by_similarities(self.good_only, min_accuracy=self.min_accuracy) - lags = self.study.get_lags() - - for sorter_ind, sorter_name in enumerate(self.study.sorter_names): - mean_recall_scores = self.study.get_mean_over_similarity_range(self.similarity_range, sorter_name) - self.ax.plot( - lags[:-1] + (lags[1] - lags[0]) / 2, mean_recall_scores, label=sorter_name, c="C%d" % sorter_ind - ) - - self.ax.set_ylabel("collision accuracy") - self.ax.set_xlabel("lags (ms)") - - if self.show_legend: - self.ax.legend() - - if self.ylim is not None: - self.ax.set_ylim(self.ylim) - - -class StudyComparisonCollisionBySimilarityRangesWidget(BaseWidget): - def __init__( - self, - study, - metric="cosine_similarity", - similarity_ranges=np.linspace(-0.4, 1, 8), - show_legend=False, - ylim=(0.5, 1), - good_only=True, - min_accuracy=0.9, - ax=None, - show_std=False, - ): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - BaseWidget.__init__(self, None, ax) - - self.study = study - self.metric = metric - self.similarity_ranges = similarity_ranges - self.show_legend = show_legend - self.ylim = ylim - self.good_only = good_only - self.show_std = show_std - self.min_accuracy = min_accuracy - - def plot(self): - self.study.precompute_scores_by_similarities(self.good_only, min_accuracy=self.min_accuracy) - lags = self.study.get_lags() - - for sorter_ind, sorter_name in enumerate(self.study.sorter_names): - all_similarities = self.study.all_similarities[sorter_name] - all_recall_scores = self.study.all_recall_scores[sorter_name] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - mean_recall_scores = [] - std_recall_scores = [] - for i in range(self.similarity_ranges.size - 1): - cmin, cmax = self.similarity_ranges[i], self.similarity_ranges[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_recall_scores += [np.nanmean(all_recall_scores[amin:amax])] - std_recall_scores += [np.nanstd(all_recall_scores[amin:amax])] - - xaxis = np.diff(self.similarity_ranges) / 2 + self.similarity_ranges[:-1] - - if not self.show_std: - self.ax.plot(xaxis, mean_recall_scores, label=sorter_name, c="C%d" % sorter_ind) - else: - self.ax.errorbar( - xaxis, mean_recall_scores, yerr=std_recall_scores, label=sorter_name, c="C%d" % sorter_ind - ) - - self.ax.set_ylabel("collision accuracy") - self.ax.set_xlabel("similarity") - - if self.show_legend: - self.ax.legend() - - if self.ylim is not None: - self.ax.set_ylim(self.ylim) - - -def plot_comparison_collision_pair_by_pair(*args, **kwargs): - W = ComparisonCollisionPairByPairWidget(*args, **kwargs) - W.plot() - return W - - -plot_comparison_collision_pair_by_pair.__doc__ = ComparisonCollisionPairByPairWidget.__doc__ - - -def plot_comparison_collision_by_similarity(*args, **kwargs): - W = ComparisonCollisionBySimilarityWidget(*args, **kwargs) - W.plot() - return W - - -plot_comparison_collision_by_similarity.__doc__ = ComparisonCollisionBySimilarityWidget.__doc__ - - -def plot_study_comparison_collision_by_similarity(*args, **kwargs): - W = StudyComparisonCollisionBySimilarityWidget(*args, **kwargs) - W.plot() - return W - - -plot_study_comparison_collision_by_similarity.__doc__ = StudyComparisonCollisionBySimilarityWidget.__doc__ - - -def plot_study_comparison_collision_by_similarity_range(*args, **kwargs): - W = StudyComparisonCollisionBySimilarityRangeWidget(*args, **kwargs) - W.plot() - return W - - -plot_study_comparison_collision_by_similarity_range.__doc__ = StudyComparisonCollisionBySimilarityRangeWidget.__doc__ - - -def plot_study_comparison_collision_by_similarity_ranges(*args, **kwargs): - W = StudyComparisonCollisionBySimilarityRangesWidget(*args, **kwargs) - W.plot() - return W - - -plot_study_comparison_collision_by_similarity_ranges.__doc__ = StudyComparisonCollisionBySimilarityRangesWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/correlogramcomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/correlogramcomp.py deleted file mode 100644 index d224686b3a..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/correlogramcomp.py +++ /dev/null @@ -1,154 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class StudyComparisonCorrelogramBySimilarityWidget(BaseWidget): - def __init__( - self, - study, - metric="cosine_similarity", - similarity_bins=np.linspace(-0.4, 1, 8), - show_legend=False, - ncols=3, - axes=None, - cmap="winter", - ylim=(0, 0.5), - ): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - if axes is None: - num_axes = len(study.sorter_names) - else: - num_axes = None - BaseWidget.__init__(self, None, None, axes, ncols=ncols, num_axes=num_axes) - - self.ncols = ncols - self.cmap = cmap - self.study = study - self.metric = metric - self.similarity_bins = np.asarray(similarity_bins) - self.show_legend = show_legend - self.ylim = ylim - - def plot(self): - my_cmap = plt.get_cmap(self.cmap) - cNorm = matplotlib.colors.Normalize(vmin=self.similarity_bins.min(), vmax=self.similarity_bins.max()) - scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) - - self.study.precompute_scores_by_similarities() - time_bins = self.study.time_bins - - for sorter_ind, sorter_name in enumerate(self.study.sorter_names): - result = self.study.get_error_profile_over_similarity_bins(self.similarity_bins, sorter_name) - - # plot by similarity bins - ax = self.axes.flatten()[sorter_ind] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - for i in range(self.similarity_bins.size - 1): - cmin, cmax = self.similarity_bins[i], self.similarity_bins[i + 1] - colorVal = scalarMap.to_rgba((cmin + cmax) / 2) - ax.plot(time_bins, result[(cmin, cmax)], label="CS in [%g,%g]" % (cmin, cmax), c=colorVal) - - if np.mod(sorter_ind, self.ncols) == 0: - ax.set_ylabel("cc error") - - if sorter_ind >= (len(self.study.sorter_names) // self.ncols): - ax.set_xlabel("lags (ms)") - - ax.set_title(sorter_name) - if self.show_legend: - ax.legend() - - if self.ylim is not None: - ax.set_ylim(self.ylim) - - -class StudyComparisonCorrelogramBySimilarityRangesMeanErrorWidget(BaseWidget): - def __init__( - self, - study, - metric="cosine_similarity", - similarity_ranges=np.linspace(-0.4, 1, 8), - show_legend=False, - ax=None, - show_std=False, - ylim=(0, 0.5), - ): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - BaseWidget.__init__(self, None, ax) - - self.study = study - self.metric = metric - self.show_std = show_std - self.ylim = None - self.similarity_ranges = np.asarray(similarity_ranges) - self.show_legend = show_legend - - def plot(self): - self.study.precompute_scores_by_similarities() - - for sorter_ind, sorter_name in enumerate(self.study.sorter_names): - all_similarities = self.study.all_similarities[sorter_name] - all_errors = self.study.all_errors[sorter_name] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_errors = all_errors[order, :] - - mean_rerrors = [] - std_errors = [] - for i in range(self.similarity_ranges.size - 1): - cmin, cmax = self.similarity_ranges[i], self.similarity_ranges[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_rerrors += [np.nanmean(all_errors[amin:amax])] - std_errors += [np.nanstd(all_errors[amin:amax])] - - xaxis = np.diff(self.similarity_ranges) / 2 + self.similarity_ranges[:-1] - - if not self.show_std: - self.ax.plot(xaxis, mean_rerrors, label=sorter_name, c="C%d" % sorter_ind) - else: - self.ax.errorbar(xaxis, mean_rerrors, yerr=std_errors, label=sorter_name, c="C%d" % sorter_ind) - - self.ax.set_ylabel("cc error") - self.ax.set_xlabel("similarity") - - if self.show_legend: - self.ax.legend() - - if self.ylim is not None: - self.ax.set_ylim(self.ylim) - - -def plot_study_comparison_correlogram_by_similarity(*args, **kwargs): - W = StudyComparisonCorrelogramBySimilarityWidget(*args, **kwargs) - W.plot() - return W - - -plot_study_comparison_correlogram_by_similarity.__doc__ = StudyComparisonCorrelogramBySimilarityWidget.__doc__ - -# def plot_study_comparison_Correlogram_by_similarity_range(*args, **kwargs): -# W = StudyComparisonCorrelogramBySimilarityRangeWidget(*args, **kwargs) -# W.plot() -# return W -# plot_study_comparison_Correlogram_by_similarity_range.__doc__ = StudyComparisonCorrelogramBySimilarityRangeWidget.__doc__ - - -def plot_study_comparison_correlogram_by_similarity_ranges_mean_error(*args, **kwargs): - W = StudyComparisonCorrelogramBySimilarityRangesMeanErrorWidget(*args, **kwargs) - W.plot() - return W - - -plot_study_comparison_correlogram_by_similarity_ranges_mean_error.__doc__ = ( - StudyComparisonCorrelogramBySimilarityRangesMeanErrorWidget.__doc__ -) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/gtcomparison.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/gtcomparison.py deleted file mode 100644 index a58a8e2e37..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/gtcomparison.py +++ /dev/null @@ -1,192 +0,0 @@ -""" -Various widgets on top of GroundTruthStudy to summary results: - * run times - * performancess - * count units -""" -import numpy as np - -from .basewidget import BaseWidget - - -class ComparisonPerformancesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - gt_comparison: GroundTruthComparison - The ground truth sorting comparison object - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, gt_comp, palette="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.gt_comp = gt_comp - self.palette = palette - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import seaborn as sns - - ax = self.ax - - sns.set_palette(sns.color_palette(self.palette)) - - perf_by_units = self.gt_comp.get_performance() - perf_by_units = perf_by_units.reset_index() - - df = pd.melt( - perf_by_units, var_name="Metric", value_name="Score", value_vars=("accuracy", "precision", "recall") - ) - import seaborn as sns - - sns.swarmplot(data=df, x="Metric", y="Score", hue="Metric", dodge=True, s=3, ax=ax) # order=sorter_list, - # ~ ax.set_xticklabels(sorter_names_short, rotation=30, ha='center') - # ~ ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.5) - - ax.set_ylim(0, 1.05) - ax.set_ylabel(f"Performance") - - -class ComparisonPerformancesAveragesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - gt_comparison: GroundTruthComparison - The ground truth sorting comparison object - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, gt_comp, cmap_name="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.gt_comp = gt_comp - self.cmap_name = cmap_name - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import seaborn as sns - - ax = self.ax - - perf_by_units = self.gt_comp.get_performance() - perf_by_units = perf_by_units.reset_index() - - columns = ["accuracy", "precision", "recall"] - to_agg = {} - ncol = len(columns) - - for column in columns: - perf_by_units[column] = pd.to_numeric(perf_by_units[column], downcast="float") - to_agg[column] = ["mean", "std"] - - data = perf_by_units.agg(to_agg) - - m = data.mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - stds = data.std() - - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - width = 1 / (ncol + 2) - - for c, col in enumerate(columns): - x = 1 + c / (ncol + 2) - yerr = stds[col] - ax.bar(x, m[col], yerr=yerr, width=width, color=cmap(c), label=clean_labels[c]) - - ax.legend() - ax.set_ylabel("metric") - # ax.set_xlim(0, 1) - - -class ComparisonPerformancesByTemplateSimilarity(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - gt_comparison: GroundTruthComparison - The ground truth sorting comparison object - similarity_matrix: matrix - The similarity between the templates in the gt recording and the ones - found by the sorter - ax: matplotlib ax - The ax to be used. If not given a figure is created - - """ - - def __init__(self, gt_comp, similarity_matrix, ax=None, ylim=(0.6, 1)): - from matplotlib import pyplot as plt - import pandas as pd - - self.gt_comp = gt_comp - self.similarity_matrix = similarity_matrix - self.ylim = ylim - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - all_results = {"similarity": [], "accuracy": []} - comp = self.gt_comp - - for i, u1 in enumerate(comp.sorting1.unit_ids): - u2 = comp.best_match_12[u1] - if u2 != -1: - all_results["similarity"] += [ - self.similarity_matrix[comp.sorting1.id_to_index(u1), comp.sorting2.id_to_index(u2)] - ] - all_results["accuracy"] += [comp.agreement_scores.at[u1, u2]] - - all_results["similarity"] = np.array(all_results["similarity"]) - all_results["accuracy"] = np.array(all_results["accuracy"]) - - self.ax.plot(all_results["similarity"], all_results["accuracy"], ".") - - self.ax.set_ylabel("accuracy") - self.ax.set_xlabel("cosine similarity") - if self.ylim is not None: - self.ax.set_ylim(self.ylim) - - -def plot_gt_performances(*args, **kwargs): - W = ComparisonPerformancesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_performances.__doc__ = ComparisonPerformancesWidget.__doc__ - - -def plot_gt_performances_averages(*args, **kwargs): - W = ComparisonPerformancesAveragesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_performances_averages.__doc__ = ComparisonPerformancesAveragesWidget.__doc__ - - -def plot_gt_performances_by_template_similarity(*args, **kwargs): - W = ComparisonPerformancesByTemplateSimilarity(*args, **kwargs) - W.plot() - return W - - -plot_gt_performances_by_template_similarity.__doc__ = ComparisonPerformancesByTemplateSimilarity.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/presence.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/presence.py deleted file mode 100644 index 863b37eda6..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/presence.py +++ /dev/null @@ -1,135 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class PresenceWidget(BaseWidget): - """ - Estimates of the probability density function for each unit using Gaussian kernels, - - Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - segment_index: None or int - The segment index. - unit_ids: list - List of unit ids - time_range: list - List with start time and end time - time_pixels: int - Number of samples calculated for each density function - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: PresenceWidget - The output widget - """ - - def __init__( - self, sorting, segment_index=None, unit_ids=None, time_range=None, figure=None, time_pixels=200, ax=None - ): - BaseWidget.__init__(self, figure, ax) - self._sorting = sorting - self._time_pixels = time_pixels - if segment_index is None: - nseg = sorting.get_num_segments() - if nseg != 1: - raise ValueError("You must provide segment_index=...") - else: - segment_index = 0 - self.segment_index = segment_index - self._unit_ids = unit_ids - self._figure = None - self._sampling_frequency = sorting.get_sampling_frequency() - self._max_frame = 0 - for unit_id in self._sorting.get_unit_ids(): - spike_train = self._sorting.get_unit_spike_train(unit_id, segment_index=self.segment_index) - if len(spike_train) > 0: - curr_max_frame = np.max(spike_train) - if curr_max_frame > self._max_frame: - self._max_frame = curr_max_frame - self._visible_trange = time_range - if self._visible_trange is None: - self._visible_trange = [0, self._max_frame] - else: - assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" - self._visible_trange = [int(t * self._sampling_frequency) for t in time_range] - - self._visible_trange = self._fix_trange(self._visible_trange) - self.name = "Presence" - - def plot(self): - self._do_plot() - - def _do_plot(self): - import matplotlib.pyplot as plt - from scipy.stats import gaussian_kde - - units_ids = self._unit_ids - if units_ids is None: - units_ids = self._sorting.get_unit_ids() - visible_start_frame = self._visible_trange[0] / self._sampling_frequency - visible_end_frame = self._visible_trange[1] / self._sampling_frequency - - time_grid = np.linspace(visible_start_frame, visible_end_frame, self._time_pixels) - time_den = [] - - self.ax.grid("both") - for u_i, unit_id in enumerate(units_ids): - spiketrain = self._sorting.get_unit_spike_train( - unit_id, - start_frame=self._visible_trange[0], - end_frame=self._visible_trange[1], - segment_index=self.segment_index, - ) - spiketimes = spiketrain / float(self._sampling_frequency) - - if spiketimes[0] != spiketimes[-1]: # not always the same value - time_den.append(gaussian_kde(spiketimes).pdf(time_grid)) - else: - aux = np.zeros_like(time_grid) - aux[np.argmin(np.abs(time_grid - spiketimes))] = 1 - time_den.append(aux) - - self.ax.matshow(np.vstack(time_den), cmap=plt.cm.inferno, aspect="auto") - - self.ax.hlines(np.arange(len(units_ids)) + 0.5, 0, len(time_den[0]), color="k", linewidth=4) - - self.ax.tick_params(axis="y", which="both", grid_linestyle="None") - - self.ax.set_xlim(0, self._time_pixels) - new_labels = [] - self.ax.xaxis.set_ticks_position("bottom") - - for xt in self.ax.get_xticks(): - if xt < self._time_pixels: - new_labels.append("{:.1f}".format(time_grid[int(xt)])) - else: - new_labels.append("{:.1f}".format(visible_end_frame)) - self.ax.set_xticks(self.ax.get_xticks()) - self.ax.set_xticklabels(new_labels) - self.ax.set_yticks(np.arange(len(units_ids))) - self.ax.set_yticklabels(units_ids) - self.ax.set_xlabel("time (s)") - self.ax.set_ylabel("Unit ID") - - def _fix_trange(self, trange): - if trange[1] > self._max_frame: - trange[1] = self._max_frame - if trange[0] < 0: - trange[0] = 0 - return trange - - -def plot_presence(*args, **kwargs): - W = PresenceWidget(*args, **kwargs) - W.plot() - return W - - -plot_presence.__doc__ = PresenceWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/sortingperformance.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/sortingperformance.py deleted file mode 100644 index eb5258fa05..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/sortingperformance.py +++ /dev/null @@ -1,85 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - -from ...comparison import GroundTruthComparison - - -class SortingPerformanceWidget(BaseWidget): - """ - Plots sorting performance for each ground-truth unit. - - Parameters - ---------- - gt_sorting_comparison: GroundTruthComparison - The ground truth sorting comparison object - property_name: str - The property of the sorting extractor to use as x-axis (e.g. snr). - If None, no property is used. - metric: str - The performance metric. 'accuracy' (default), 'precision', 'recall', 'miss rate', etc. - markersize: int - The size of the marker - marker: str - The matplotlib marker to use (default '.') - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: SortingPerformanceWidget - The output widget - """ - - def __init__( - self, - sorting_comparison, - metrics, - performance_name="accuracy", - metric_name="snr", - color="b", - markersize=10, - marker=".", - figure=None, - ax=None, - ): - from matplotlib import pyplot as plt - - assert isinstance( - sorting_comparison, GroundTruthComparison - ), "The 'sorting_comparison' object should be a GroundTruthComparison instance" - BaseWidget.__init__(self, figure, ax) - self.sorting_comparison = sorting_comparison - self.metrics = metrics - self.performance_name = performance_name - self.metric_name = metric_name - self.color = color - self.markersize = markersize - self.marker = marker - - def plot(self): - self._do_plot() - - def _do_plot(self): - comp = self.sorting_comparison - unit_ids = comp.sorting1.get_unit_ids() - perf = comp.get_performance()[self.performance_name] - metric = self.metrics[self.metric_name] - - ax = self.ax - - ax.plot(metric, perf, marker=self.marker, markersize=int(self.markersize), ls="", color=self.color) - ax.set_xlabel(self.metric_name) - ax.set_ylabel(self.performance_name) - ax.set_ylim(0, 1.05) - - -def plot_sorting_performance(*args, **kwargs): - W = SortingPerformanceWidget(*args, **kwargs) - W.plot() - return W - - -plot_sorting_performance.__doc__ = SortingPerformanceWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_legacy_widgets_utils.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_legacy_widgets_utils.py deleted file mode 100644 index 98abe3f4fc..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_legacy_widgets_utils.py +++ /dev/null @@ -1,21 +0,0 @@ -if __name__ != "__main__": - import matplotlib - - matplotlib.use("Agg") - -from spikeinterface import download_dataset -import spikeinterface.extractors as se - -from spikeinterface.widgets.utils import get_unit_colors - - -def test_get_unit_colors(): - local_path = download_dataset(remote_path="mearec/mearec_test_10s.h5") - sorting = se.MEArecSortingExtractor(local_path) - - colors = get_unit_colors(sorting) - print(colors) - - -if __name__ == "__main__": - test_get_unit_colors() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py deleted file mode 100644 index 9cd321db3c..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ /dev/null @@ -1,128 +0,0 @@ -import unittest -import pytest -import sys -from pathlib import Path - -if __name__ != "__main__": - import matplotlib - - matplotlib.use("Agg") -import matplotlib.pyplot as plt - -from spikeinterface import extract_waveforms, load_waveforms, download_dataset -import spikeinterface.extractors as se -import spikeinterface.widgets as sw -import spikeinterface.comparison as sc -from spikeinterface.postprocessing import compute_spike_amplitudes -from spikeinterface.qualitymetrics import compute_quality_metrics - - -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "widgets" -else: - cache_folder = Path("cache_folder") / "widgets" - - -class TestWidgets(unittest.TestCase): - def setUp(self): - local_path = download_dataset(remote_path="mearec/mearec_test_10s.h5") - self._rec = se.MEArecRecordingExtractor(local_path) - - self._sorting = se.MEArecSortingExtractor(local_path) - - self.num_units = len(self._sorting.get_unit_ids()) - #  self._we = extract_waveforms(self._rec, self._sorting, './toy_example', load_if_exists=True) - if (cache_folder / "mearec_test_old_api").is_dir(): - self._we = load_waveforms(cache_folder / "mearec_test_old_api") - else: - self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test_old_api", sparse=False) - - self._amplitudes = compute_spike_amplitudes(self._we, peak_sign="neg", outputs="by_unit") - self._gt_comp = sc.compare_sorter_to_ground_truth(self._sorting, self._sorting) - - def tearDown(self): - pass - - # def test_plot_unit_probe_map(self): - # sw.plot_unit_probe_map(self._we, with_channel_ids=True) - # sw.plot_unit_probe_map(self._we, animated=True) - - # def test_plot_units_depth_vs_amplitude(self): - # sw.plot_units_depth_vs_amplitude(self._we) - - # def test_amplitudes_timeseries(self): - # sw.plot_amplitudes_timeseries(self._we) - # unit_ids = self._sorting.unit_ids[:4] - # sw.plot_amplitudes_timeseries(self._we, unit_ids=unit_ids) - - # def test_amplitudes_distribution(self): - # sw.plot_amplitudes_distribution(self._we) - - # def test_plot_unit_localization(self): - # sw.plot_unit_localization(self._we, with_channel_ids=True) - # sw.plot_unit_localization(self._we, method='monopolar_triangulation') - - # def test_autocorrelograms(self): - # unit_ids = self._sorting.unit_ids[:4] - # sw.plot_autocorrelograms(self._sorting, unit_ids=unit_ids, window_ms=500.0, bin_ms=20.0) - - # def test_crosscorrelogram(self): - # unit_ids = self._sorting.unit_ids[:4] - # sw.plot_crosscorrelograms(self._sorting, unit_ids=unit_ids, window_ms=500.0, bin_ms=20.0) - - # def test_isi_distribution(self): - # sw.plot_isi_distribution(self._sorting, bin_ms=5.0, window_ms=500.0) - # fig, axes = plt.subplots(self.num_units, 1) - # sw.plot_isi_distribution(self._sorting, axes=axes) - - def test_plot_peak_activity_map(self): - sw.plot_peak_activity_map(self._rec, with_channel_ids=True) - sw.plot_peak_activity_map(self._rec, bin_duration_s=1.0) - - def test_multicomp_graph(self): - msc = sc.compare_multiple_sorters([self._sorting, self._sorting, self._sorting]) - sw.plot_multicomp_graph(msc, edge_cmap="viridis", node_cmap="rainbow", draw_labels=False) - sw.plot_multicomp_agreement(msc) - sw.plot_multicomp_agreement_by_sorter(msc) - fig, axes = plt.subplots(len(msc.object_list), 1) - sw.plot_multicomp_agreement_by_sorter(msc, axes=axes) - - def test_sorting_performance(self): - metrics = compute_quality_metrics(self._we, metric_names=["snr"]) - sw.plot_sorting_performance(self._gt_comp, metrics, performance_name="accuracy", metric_name="snr") - - # ~ def test_plot_unit_summary(self): - # ~ unit_id = self._sorting.unit_ids[4] - # ~ sw.plot_unit_summary(self._we, unit_id) - - -if __name__ == "__main__": - # unittest.main() - - mytest = TestWidgets() - mytest.setUp() - - # ~ mytest.test_timeseries() - # ~ mytest.test_unitwaveforms() - # ~ mytest.test_plot_unit_waveform_density_map() - # mytest.test_unittemplates() - # ~ mytest.test_plot_unit_probe_map() - #  mytest.test_plot_units_depth_vs_amplitude() - # ~ mytest.test_amplitudes_timeseries() - # ~ mytest.test_amplitudes_distribution() - # ~ mytest.test_principal_component() - # ~ mytest.test_plot_unit_localization() - - # ~ mytest.test_autocorrelograms() - # ~ mytest.test_crosscorrelogram() - # ~ mytest.test_isi_distribution() - - # ~ mytest.test_plot_drift_over_time() - # ~ mytest.test_plot_peak_activity_map() - - # ~ mytest.test_multicomp_graph() - #  mytest.test_sorting_performance() - - # ~ mytest.test_plot_unit_summary() - - plt.show() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/utils.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/utils.py deleted file mode 100644 index 6872a8b27b..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/utils.py +++ /dev/null @@ -1,37 +0,0 @@ -import matplotlib.pyplot as plt - -import random - -try: - import distinctipy - - HAVE_DISTINCTIPY = True -except ImportError: - HAVE_DISTINCTIPY = False - - -def get_unit_colors(sorting, map_name="gist_ncar", format="RGBA", shuffle=False): - """ - Return a dict colors per units. - """ - possible_formats = ("RGBA",) - assert format in possible_formats, f"format must be {possible_formats}" - - unit_ids = sorting.unit_ids - - if HAVE_DISTINCTIPY: - colors = distinctipy.get_colors(unit_ids.size) - # add the alpha - colors = [color + (1.0,) for color in colors] - else: - # some map have black or white at border so +10 - margin = max(4, len(unit_ids) // 20) // 2 - cmap = plt.get_cmap(map_name, len(unit_ids) + 2 * margin) - - colors = [cmap(i + margin) for i, unit_id in enumerate(unit_ids)] - if shuffle: - random.shuffle(colors) - - dict_colors = dict(zip(unit_ids, colors)) - - return dict_colors diff --git a/src/spikeinterface/widgets/collision.py b/src/spikeinterface/widgets/collision.py new file mode 100644 index 0000000000..2b86a2af2d --- /dev/null +++ b/src/spikeinterface/widgets/collision.py @@ -0,0 +1,287 @@ +import numpy as np + +from .base import BaseWidget, to_attr + + +class ComparisonCollisionBySimilarityWidget(BaseWidget): + """ + Plots CollisionGTComparison pair by pair orderer by cosine_similarity + + Parameters + ---------- + comp: CollisionGTComparison + The collision ground truth comparison object + templates: array + template of units + mode: 'heatmap' or 'lines' + to see collision curves for every pairs ('heatmap') or as lines averaged over pairs. + similarity_bins: array + if mode is 'lines', the bins used to average the pairs + cmap: string + colormap used to show averages if mode is 'lines' + metric: 'cosine_similarity' + metric for ordering + good_only: True + keep only the pairs with a non zero accuracy (found templates) + min_accuracy: float + If good only, the minimum accuracy every cell should have, individually, to be + considered in a putative pair + unit_ids: list + List of considered units + """ + + def __init__( + self, + comp, + templates, + unit_ids=None, + metric="cosine_similarity", + figure=None, + ax=None, + mode="heatmap", + similarity_bins=np.linspace(-0.4, 1, 8), + cmap="winter", + good_only=False, + min_accuracy=0.9, + show_legend=False, + ylim=(0, 1), + backend=None, + **backend_kwargs, + ): + if unit_ids is None: + unit_ids = comp.sorting1.get_unit_ids() + + data_plot = dict( + comp=comp, + templates=templates, + unit_ids=unit_ids, + metric=metric, + mode=mode, + similarity_bins=similarity_bins, + cmap=cmap, + good_only=good_only, + min_accuracy=min_accuracy, + show_legend=show_legend, + ylim=ylim, + ) + + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import sklearn + import matplotlib + + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + comp = dp.comp + + # compute similarity + # take index of template (respect unit_ids order) + all_unit_ids = list(comp.sorting1.get_unit_ids()) + template_inds = [all_unit_ids.index(u) for u in dp.unit_ids] + + templates = dp.templates[template_inds, :, :].copy() + flat_templates = templates.reshape(templates.shape[0], -1) + if dp.metric == "cosine_similarity": + similarity_matrix = sklearn.metrics.pairwise.cosine_similarity(flat_templates) + else: + raise NotImplementedError("metric=...") + + fs = comp.sorting1.get_sampling_frequency() + lags = comp.bins / fs * 1000 + + n = len(dp.unit_ids) + + similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( + similarity_matrix, unit_ids=dp.unit_ids, good_only=dp.good_only, min_accuracy=dp.min_accuracy + ) + + if dp.mode == "heatmap": + fig = self.figure + for ax in fig.axes: + ax.remove() + + n_pair = len(similarities) + + ax0 = fig.add_axes([0.1, 0.1, 0.25, 0.8]) + ax1 = fig.add_axes([0.4, 0.1, 0.5, 0.8], sharey=ax0) + + plt.setp(ax1.get_yticklabels(), visible=False) + + im = ax1.imshow( + recall_scores[::-1, :], + cmap="viridis", + aspect="auto", + interpolation="none", + extent=(lags[0], lags[-1], -0.5, n_pair - 0.5), + ) + im.set_clim(0, 1) + + ax0.plot(similarities, np.arange(n_pair), color="k") + + ax0.set_yticks(np.arange(n_pair)) + ax0.set_yticklabels(pair_names) + # ax0.set_xlim(0,1) + + ax0.set_xlabel(dp.metric) + ax0.set_ylabel("pairs") + + ax1.set_xlabel("lag (ms)") + elif dp.mode == "lines": + my_cmap = plt.get_cmap(dp.cmap) + cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max()) + scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) + + # plot by similarity bins + if self.ax is None: + fig, ax = plt.subplots() + else: + ax = self.ax + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + order = np.argsort(similarities) + similarities = similarities[order] + recall_scores = recall_scores[order, :] + + for i in range(dp.similarity_bins.size - 1): + cmin, cmax = dp.similarity_bins[i], dp.similarity_bins[i + 1] + + amin, amax = np.searchsorted(similarities, [cmin, cmax]) + mean_recall_scores = np.nanmean(recall_scores[amin:amax], axis=0) + + colorVal = scalarMap.to_rgba((cmin + cmax) / 2) + ax.plot( + lags[:-1] + (lags[1] - lags[0]) / 2, + mean_recall_scores, + label="CS in [%g,%g]" % (cmin, cmax), + c=colorVal, + ) + + if dp.show_legend: + ax.legend() + ax.set_ylim(dp.ylim) + ax.set_xlabel("lags (ms)") + ax.set_ylabel("collision recall") + + +class StudyComparisonCollisionBySimilarityWidget(BaseWidget): + """ + Plots CollisionGTComparison pair by pair orderer by cosine_similarity for all + cases in a study. + + Parameters + ---------- + study: CollisionGTStudy + The collision study object. + case_keys: list or None + A selection of cases to plot, if None, then all. + metric: 'cosine_similarity' + metric for ordering + similarity_bins: array + if mode is 'lines', the bins used to average the pairs + cmap: string + colormap used to show averages if mode is 'lines' + good_only: False + keep only the pairs with a non zero accuracy (found templates) + min_accuracy: float + If good only, the minimum accuracy every cell should have, individually, to be + considered in a putative pair + """ + + def __init__( + self, + study, + case_keys=None, + metric="cosine_similarity", + similarity_bins=np.linspace(-0.4, 1, 8), + show_legend=False, + ylim=(0.5, 1), + good_only=False, + min_accuracy=0.9, + cmap="winter", + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + data_plot = dict( + study=study, + case_keys=case_keys, + metric=metric, + similarity_bins=similarity_bins, + show_legend=show_legend, + ylim=ylim, + good_only=good_only, + min_accuracy=min_accuracy, + cmap=cmap, + ) + + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import sklearn + import matplotlib + + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + num_axes = len(dp.case_keys) + backend_kwargs["num_axes"] = num_axes + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + study = dp.study + + my_cmap = plt.get_cmap(dp.cmap) + cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max()) + scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) + study.precompute_scores_by_similarities( + case_keys=dp.case_keys, + good_only=dp.good_only, + min_accuracy=dp.min_accuracy, + ) + + for count, key in enumerate(dp.case_keys): + lags = study.get_lags(key) + + curves = study.get_lag_profile_over_similarity_bins(dp.similarity_bins, key) + + # plot by similarity bins + ax = self.axes.flatten()[count] + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + for i in range(dp.similarity_bins.size - 1): + cmin, cmax = dp.similarity_bins[i], dp.similarity_bins[i + 1] + colorVal = scalarMap.to_rgba((cmin + cmax) / 2) + ax.plot( + lags[:-1] + (lags[1] - lags[0]) / 2, + curves[(cmin, cmax)], + label="CS in [%g,%g]" % (cmin, cmax), + c=colorVal, + ) + + if count % self.axes.shape[1] == 0: + ax.set_ylabel("collision recall") + + if count > (len(dp.case_keys) // self.axes.shape[1]): + ax.set_xlabel("lags (ms)") + + label = study.cases[key]["label"] + ax.set_title(label) + if dp.show_legend: + ax.legend() + + if dp.ylim is not None: + ax.set_ylim(dp.ylim) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/multicompgraph.py b/src/spikeinterface/widgets/multicomparison.py similarity index 59% rename from src/spikeinterface/widgets/_legacy_mpl_widgets/multicompgraph.py rename to src/spikeinterface/widgets/multicomparison.py index 47e0c026f0..e01a79dfd5 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/multicompgraph.py +++ b/src/spikeinterface/widgets/multicomparison.py @@ -1,6 +1,8 @@ import numpy as np +from warnings import warn -from .basewidget import BaseWidget +from .base import BaseWidget, to_attr +from .utils import get_unit_colors class MultiCompGraphWidget(BaseWidget): @@ -21,15 +23,6 @@ class MultiCompGraphWidget(BaseWidget): Alpha value for edges colorbar: bool If True a colorbar for the edges is plotted - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: MultiCompGraphWidget - The output widget """ def __init__( @@ -40,42 +33,44 @@ def __init__( edge_cmap="hot", alpha_edges=0.5, colorbar=False, - figure=None, - ax=None, + backend=None, + **backend_kwargs, ): - import matplotlib - from matplotlib import pyplot as plt - - BaseWidget.__init__(self, figure, ax) - self._msc = multi_comparison - self._draw_labels = draw_labels - self._node_cmap = node_cmap - self._edge_cmap = edge_cmap - self._colorbar = colorbar - self._alpha_edges = alpha_edges - self.name = "MultiCompGraph" - - def plot(self): - self._do_plot() - - def _do_plot(self): + plot_data = dict( + multi_comparison=multi_comparison, + draw_labels=draw_labels, + node_cmap=node_cmap, + edge_cmap=edge_cmap, + alpha_edges=alpha_edges, + colorbar=colorbar, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.colors as mpl_colors + import matplotlib.pyplot as plt import networkx as nx + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) - g = self._msc.graph + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + mcmp = dp.multi_comparison + g = mcmp.graph edge_col = [] for e in g.edges(data=True): n1, n2, d = e edge_col.append(d["weight"]) nodes_col_dict = {} - for i, sort_name in enumerate(self._msc.name_list): + for i, sort_name in enumerate(mcmp.name_list): nodes_col_dict[sort_name] = i nodes_col = [] for node in sorted(g.nodes): nodes_col.append(nodes_col_dict[node[0]]) - nodes_col = np.array(nodes_col) / len(self._msc.name_list) - import matplotlib.pyplot as plt + nodes_col = np.array(nodes_col) / len(mcmp.name_list) - _ = plt.set_cmap(self._node_cmap) + _ = plt.set_cmap(dp.node_cmap) _ = nx.draw_networkx_nodes( g, pos=nx.circular_layout(sorted(g)), @@ -89,13 +84,13 @@ def _do_plot(self): pos=nx.circular_layout((sorted(g))), nodelist=sorted(g.nodes), edge_color=edge_col, - alpha=self._alpha_edges, - edge_cmap=plt.cm.get_cmap(self._edge_cmap), - edge_vmin=self._msc.match_score, + alpha=dp.alpha_edges, + edge_cmap=plt.cm.get_cmap(dp.edge_cmap), + edge_vmin=mcmp.match_score, edge_vmax=1, ax=self.ax, ) - if self._draw_labels: + if dp.draw_labels: labels = {key: f"{key[0]}_{key[1]}" for key in sorted(g.nodes)} pos = nx.circular_layout(sorted(g)) # extend position radially @@ -105,12 +100,11 @@ def _do_plot(self): pos_extended[node] = pos_new _ = nx.draw_networkx_labels(g, pos=pos_extended, labels=labels, ax=self.ax) - if self._colorbar: - import matplotlib + if dp.colorbar: import matplotlib.pyplot as plt - norm = matplotlib.colors.Normalize(vmin=self._msc.match_score, vmax=1) - cmap = plt.cm.get_cmap(self._edge_cmap) + norm = mpl_colors.Normalize(vmin=mcmp.match_score, vmax=1) + cmap = plt.cm.get_cmap(dp.edge_cmap) m = plt.cm.ScalarMappable(norm=norm, cmap=cmap) self.figure.colorbar(m) @@ -127,42 +121,48 @@ class MultiCompGlobalAgreementWidget(BaseWidget): The multi comparison object plot_type: str 'pie' or 'bar' - cmap: matplotlib colormap - The colormap to be used for the nodes (default 'Reds') - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: MultiCompGraphWidget - The output widget + cmap: matplotlib colormap, default: 'YlOrRd' + The colormap to be used for the nodes + fontsize: int, default: 9 + The text fontsize + show_legend: bool, default: True + If True a legend is shown """ - def __init__(self, multi_comparison, plot_type="pie", cmap="YlOrRd", fs=10, figure=None, ax=None): - BaseWidget.__init__(self, figure, ax) - import matplotlib - import matplotlib.pyplot as plt + def __init__( + self, + multi_comparison, + plot_type="pie", + cmap="YlOrRd", + fontsize=9, + show_legend=True, + backend=None, + **backend_kwargs, + ): + plot_data = dict( + multi_comparison=multi_comparison, + plot_type=plot_type, + cmap=cmap, + fontsize=fontsize, + show_legend=show_legend, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - self._msc = multi_comparison - self._type = plot_type - self._cmap = cmap - self._fs = fs - self.name = "MultiCompGlobalAgreement" + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure - def plot(self): - self._do_plot() + dp = to_attr(data_plot) - def _do_plot(self): - import matplotlib.pyplot as plt + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - cmap = plt.get_cmap(self._cmap) - colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(self._msc.name_list))]) - sg_names, sg_units = self._msc.compute_subgraphs() + mcmp = dp.multi_comparison + cmap = plt.get_cmap(dp.cmap) + colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) + sg_names, sg_units = mcmp.compute_subgraphs() # fraction of units with agreement > threshold v, c = np.unique([len(np.unique(s)) for s in sg_names], return_counts=True) - if self._type == "pie": + if dp.plot_type == "pie": p = self.ax.pie(c, colors=colors[v - 1], autopct=lambda pct: _getabs(pct, c), pctdistance=1.25) self.ax.legend( p[0], @@ -175,9 +175,9 @@ def _do_plot(self): loc=2, borderaxespad=0.5, labelspacing=0.15, - fontsize=self._fs, + fontsize=dp.fontsize, ) - elif self._type == "bar": + elif dp.plot_type == "bar": self.ax.bar(v, c, color=colors[v - 1]) x_labels = [f"k={vi}" for vi in v] self.ax.spines["top"].set_visible(False) @@ -213,44 +213,54 @@ class MultiCompAgreementBySorterWidget(BaseWidget): The output widget """ - def __init__(self, multi_comparison, plot_type="pie", cmap="YlOrRd", fs=9, axes=None, show_legend=True): - import matplotlib.pyplot as plt - - self._msc = multi_comparison - self._type = plot_type - self._cmap = cmap - self._fs = fs - self._show_legend = show_legend - self.name = "MultiCompAgreementBySorterWidget" + def __init__( + self, + multi_comparison, + plot_type="pie", + cmap="YlOrRd", + fontsize=9, + show_legend=True, + backend=None, + **backend_kwargs, + ): + plot_data = dict( + multi_comparison=multi_comparison, + plot_type=plot_type, + cmap=cmap, + fontsize=fontsize, + show_legend=show_legend, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - if axes is None: - ncols = len(self._msc.name_list) - fig, axes = plt.subplots(nrows=1, ncols=ncols, sharex=True, sharey=True) - BaseWidget.__init__(self, None, None, axes) + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.colors as mpl_colors + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure - def plot(self): - self._do_plot() + dp = to_attr(data_plot) + mcmp = dp.multi_comparison + name_list = mcmp.name_list - def _do_plot(self): - name_list = self._msc.name_list - import matplotlib.pyplot as plt + backend_kwargs["num_axes"] = len(name_list) + backend_kwargs["ncols"] = len(name_list) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - cmap = plt.get_cmap(self._cmap) - colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(self._msc.name_list))]) - sg_names, sg_units = self._msc.compute_subgraphs() + cmap = plt.get_cmap(dp.cmap) + colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) + sg_names, sg_units = mcmp.compute_subgraphs() # fraction of units with agreement > threshold for i, name in enumerate(name_list): - ax = self.axes[i] + ax = np.squeeze(self.axes)[i] v, c = np.unique([len(np.unique(sn)) for sn in sg_names if name in sn], return_counts=True) - if self._type == "pie": + if dp.plot_type == "pie": p = ax.pie( c, colors=colors[v - 1], - textprops={"color": "k", "fontsize": self._fs}, + textprops={"color": "k", "fontsize": dp.fontsize}, autopct=lambda pct: _getabs(pct, c), pctdistance=1.18, ) - if (self._show_legend) and (i == len(name_list) - 1): + if (dp.show_legend) and (i == len(name_list) - 1): plt.legend( p[0], v, @@ -263,7 +273,7 @@ def _do_plot(self): borderaxespad=0.0, labelspacing=0.15, ) - elif self._type == "bar": + elif dp.plot_type == "bar": ax.bar(v, c, color=colors[v - 1]) x_labels = [f"k={vi}" for vi in v] ax.spines["top"].set_visible(False) @@ -273,7 +283,8 @@ def _do_plot(self): else: raise AttributeError("Wrong plot_type. It can be 'pie' or 'bar'") ax.set_title(name) - if self._type == "bar": + + if dp.plot_type == "bar": ylims = [np.max(ax_single.get_ylim()) for ax_single in self.axes] max_yval = np.max(ylims) for ax_single in self.axes: @@ -282,31 +293,4 @@ def _do_plot(self): def _getabs(pct, allvals): absolute = int(np.round(pct / 100.0 * np.sum(allvals))) - return "{:d}".format(absolute) - - -def plot_multicomp_graph(*args, **kwargs): - W = MultiCompGraphWidget(*args, **kwargs) - W.plot() - return W - - -plot_multicomp_graph.__doc__ = MultiCompGraphWidget.__doc__ - - -def plot_multicomp_agreement(*args, **kwargs): - W = MultiCompGlobalAgreementWidget(*args, **kwargs) - W.plot() - return W - - -plot_multicomp_agreement.__doc__ = MultiCompGlobalAgreementWidget.__doc__ - - -def plot_multicomp_agreement_by_sorter(*args, **kwargs): - W = MultiCompAgreementBySorterWidget(*args, **kwargs) - W.plot() - return W - - -plot_multicomp_agreement_by_sorter.__doc__ = MultiCompAgreementBySorterWidget.__doc__ + return f"{absolute}" diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py b/src/spikeinterface/widgets/peak_activity.py similarity index 58% rename from src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py rename to src/spikeinterface/widgets/peak_activity.py index 9715b7ea87..24d4dc0df9 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py +++ b/src/spikeinterface/widgets/peak_activity.py @@ -1,5 +1,11 @@ import numpy as np -from .basewidget import BaseWidget +from typing import Union + +from probeinterface import ProbeGroup + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors +from ..core.waveform_extractor import WaveformExtractor class PeakActivityMapWidget(BaseWidget): @@ -17,8 +23,6 @@ class PeakActivityMapWidget(BaseWidget): to avoid multiple computation. detect_peaks_kwargs: None or dict If peaks is None here the kwargs for detect_peak function. - weight_with_amplitudes: bool False by default - Peak are weighted by amplitude bin_duration_s: None or float If None then static image If not None then it is an animation per bin. @@ -28,55 +32,46 @@ class PeakActivityMapWidget(BaseWidget): Plot rates with interpolated map with_channel_ids: bool False default Add channel ids text on the probe - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: ProbeMapWidget - The output widget + + """ def __init__( self, recording, - peaks=None, - detect_peaks_kwargs={}, - weight_with_amplitudes=True, + peaks, bin_duration_s=None, with_contact_color=True, with_interpolated_map=True, with_channel_ids=False, with_color_bar=True, - figure=None, - ax=None, + backend=None, + **backend_kwargs, ): - import matplotlib.pylab as plt - from matplotlib.animation import FuncAnimation - from probeinterface.plotting import plot_probe + data_plot = dict( + recording=recording, + peaks=peaks, + bin_duration_s=bin_duration_s, + with_contact_color=with_contact_color, + with_interpolated_map=with_interpolated_map, + with_channel_ids=with_channel_ids, + with_color_bar=with_color_bar, + ) - BaseWidget.__init__(self, figure, ax) + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) - assert recording.get_num_segments() == 1, "Handle only one segment" + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure - self.recording = recording - self.peaks = peaks - self.detect_peaks_kwargs = detect_peaks_kwargs - self.weight_with_amplitudes = weight_with_amplitudes - self.bin_duration_s = bin_duration_s - self.with_contact_color = with_contact_color - self.with_interpolated_map = with_interpolated_map - self.with_channel_ids = with_channel_ids + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - def plot(self): - rec = self.recording - peaks = self.peaks - if peaks is None: - from spikeinterface.sortingcomponents.peak_detection import detect_peaks + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - peaks = detect_peaks(rec, **self.detect_peaks_kwargs) + rec = dp.recording + peaks = dp.peaks fs = rec.get_sampling_frequency() duration = rec.get_total_duration() @@ -88,24 +83,33 @@ def plot(self): ) probe = probes[0] - if self.bin_duration_s is None: - self._plot_one_bin(rec, probe, peaks, duration) + if dp.bin_duration_s is None: + self._plot_one_bin( + rec, probe, peaks, duration, dp.with_channel_ids, dp.with_contact_color, dp.with_interpolated_map + ) else: - bin_size = int(self.bin_duration_s * fs) - num_frames = int(duration / self.bin_duration_s) + bin_size = int(dp.bin_duration_s * fs) + num_frames = int(duration / dp.bin_duration_s) def animate_func(i): i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)]) local_peaks = peaks[i0:i1] - artists = self._plot_one_bin(rec, probe, local_peaks, self.bin_duration_s) + artists = self._plot_one_bin( + rec, + probe, + local_peaks, + dp.with_channel_ids, + dp.bin_duration_s, + dp.with_contact_color, + dp.with_interpolated_map, + ) return artists from matplotlib.animation import FuncAnimation self.animation = FuncAnimation(self.figure, animate_func, frames=num_frames, interval=100, blit=True) - def _plot_one_bin(self, rec, probe, peaks, duration): - # TODO: @alessio weight_with_amplitudes is not implemented yet + def _plot_one_bin(self, rec, probe, peaks, duration, with_channel_ids, with_contact_color, with_interpolated_map): rates = np.zeros(rec.get_num_channels(), dtype="float64") for chan_ind, chan_id in enumerate(rec.channel_ids): mask = peaks["channel_index"] == chan_ind @@ -113,9 +117,9 @@ def _plot_one_bin(self, rec, probe, peaks, duration): rates[chan_ind] = num_spike / duration artists = () - if self.with_contact_color: + if with_contact_color: text_on_contact = None - if self.with_channel_ids: + if with_channel_ids: text_on_contact = self.recording.channel_ids from probeinterface.plotting import plot_probe @@ -130,7 +134,7 @@ def _plot_one_bin(self, rec, probe, peaks, duration): ) artists = artists + (poly, poly_contour) - if self.with_interpolated_map: + if with_interpolated_map: image, xlims, ylims = probe.to_image( rates, pixel_size=0.5, num_pixel=None, method="linear", xlims=None, ylims=None ) @@ -138,12 +142,3 @@ def _plot_one_bin(self, rec, probe, peaks, duration): artists = artists + (im,) return artists - - -def plot_peak_activity_map(*args, **kwargs): - W = PeakActivityMapWidget(*args, **kwargs) - W.plot() - return W - - -plot_peak_activity_map.__doc__ = PeakActivityMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 93c97f5913..052497347d 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -84,6 +84,10 @@ def setUpClass(cls): cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) + from spikeinterface.sortingcomponents.peak_detection import detect_peaks + + cls.peaks = detect_peaks(cls.recording, method="locally_exclusive") + def test_plot_traces(self): possible_backends = list(sw.TracesWidget.get_possible_backends()) for backend in possible_backends: @@ -203,7 +207,7 @@ def test_plot_unit_waveforms_density_map_sparsity_None_same_axis(self): **self.backend_kwargs[backend], ) - def test_autocorrelograms(self): + def test_plot_autocorrelograms(self): possible_backends = list(sw.AutoCorrelogramsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -217,7 +221,7 @@ def test_autocorrelograms(self): **self.backend_kwargs[backend], ) - def test_crosscorrelogram(self): + def test_plot_crosscorrelogram(self): possible_backends = list(sw.CrossCorrelogramsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -231,7 +235,7 @@ def test_crosscorrelogram(self): **self.backend_kwargs[backend], ) - def test_isi_distribution(self): + def test_plot_isi_distribution(self): possible_backends = list(sw.ISIDistributionWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -245,7 +249,7 @@ def test_isi_distribution(self): **self.backend_kwargs[backend], ) - def test_amplitudes(self): + def test_plot_amplitudes(self): possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -279,7 +283,7 @@ def test_plot_all_amplitudes_distributions(self): self.we_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) - def test_unit_locations(self): + def test_plot_unit_locations(self): possible_backends = list(sw.UnitLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -290,7 +294,7 @@ def test_unit_locations(self): self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) - def test_spike_locations(self): + def test_plot_spike_locations(self): possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -301,21 +305,21 @@ def test_spike_locations(self): self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) - def test_similarity(self): + def test_plot_similarity(self): possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_template_similarity(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_similarity(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) - def test_quality_metrics(self): + def test_plot_quality_metrics(self): possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_quality_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_quality_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) - def test_template_metrics(self): + def test_plot_template_metrics(self): possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -340,7 +344,7 @@ def test_plot_unit_summary(self): self.we_sparse, self.we_sparse.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) - def test_sorting_summary(self): + def test_plot_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -377,6 +381,35 @@ def test_plot_unit_probe_map(self): if backend not in self.skip_backends: sw.plot_unit_probe_map(self.we_dense) + def test_plot_unit_presence(self): + possible_backends = list(sw.UnitPresenceWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_unit_presence(self.sorting) + + def test_plot_peak_activity(self): + possible_backends = list(sw.PeakActivityMapWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_peak_activity(self.recording, self.peaks) + + def test_plot_multicomparison(self): + mcmp = sc.compare_multiple_sorters([self.sorting, self.sorting, self.sorting]) + possible_backends_graph = list(sw.MultiCompGraphWidget.get_possible_backends()) + for backend in possible_backends_graph: + sw.plot_multicomparison_graph( + mcmp, edge_cmap="viridis", node_cmap="rainbow", draw_labels=False, backend=backend + ) + possible_backends_glob = list(sw.MultiCompGlobalAgreementWidget.get_possible_backends()) + for backend in possible_backends_glob: + sw.plot_multicomparison_agreement(mcmp, backend=backend) + possible_backends_by_sorter = list(sw.MultiCompAgreementBySorterWidget.get_possible_backends()) + for backend in possible_backends_by_sorter: + sw.plot_multicomparison_agreement_by_sorter(mcmp) + if backend == "matplotlib": + _, axes = plt.subplots(len(mcmp.object_list), 1) + sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes) + if __name__ == "__main__": # unittest.main() @@ -395,15 +428,17 @@ def test_plot_unit_probe_map(self): # mytest.test_plot_unit_templates() # mytest.test_plot_unit_summary() # mytest.test_crosscorrelogram() - mytest.test_isi_distribution() + # mytest.test_isi_distribution() # mytest.test_unit_locations() # mytest.test_quality_metrics() - mytest.test_template_metrics() + # mytest.test_template_metrics() # mytest.test_amplitudes() # mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() # mytest.test_plot_rasters() # mytest.test_plot_unit_probe_map() + # mytest.test_plot_unit_presence() + mytest.test_plot_multicomparison() plt.show() diff --git a/src/spikeinterface/widgets/unit_presence.py b/src/spikeinterface/widgets/unit_presence.py new file mode 100644 index 0000000000..3d605936a2 --- /dev/null +++ b/src/spikeinterface/widgets/unit_presence.py @@ -0,0 +1,103 @@ +import numpy as np + +from .base import BaseWidget, to_attr + + +class UnitPresenceWidget(BaseWidget): + """ + Estimates of the probability density function for each unit using Gaussian kernels, + + Parameters + ---------- + sorting: SortingExtractor + The sorting extractor object + segment_index: None or int + The segment index. + time_range: list + List with start time and end time + bin_duration_s: float, default 0.5 + Bin size (in seconds) for the heat map time axis. + smooth_sigma: float or None + + """ + + def __init__( + self, + sorting, + segment_index=None, + unit_ids=None, + time_range=None, + bin_duration_s=0.05, + smooth_sigma=4.5, + backend=None, + **backend_kwargs, + ): + if segment_index is None: + nseg = sorting.get_num_segments() + if nseg != 1: + raise ValueError("You must provide segment_index=...") + else: + segment_index = 0 + + data_plot = dict( + sorting=sorting, + segment_index=segment_index, + time_range=time_range, + bin_duration_s=bin_duration_s, + smooth_sigma=smooth_sigma, + ) + + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + sorting = dp.sorting + + spikes = sorting.to_spike_vector(concatenated=False, use_cache=True) + spikes = spikes[dp.segment_index] + + fs = sorting.get_sampling_frequency() + + if dp.time_range is not None: + t0, t1 = dp.time_range + ind0 = int(t0 * fs) + ind1 = int(t1 * fs) + mask = (spikes["sample_index"] >= ind0) & (spikes["sample_index"] <= ind1) + spikes = spikes[mask] + + if spikes.size == 0: + return + + last = spikes["sample_index"][-1] + max_time = last / fs + + num_units = len(sorting.unit_ids) + num_time_bins = int(max_time / dp.bin_duration_s) + 1 + map = np.zeros((num_units, num_time_bins)) + ind0 = spikes["unit_index"] + ind1 = spikes["sample_index"] // int(dp.bin_duration_s * fs) + map[ind0, ind1] += 1 + + if dp.smooth_sigma is not None: + import scipy.signal + + n = int(dp.smooth_sigma * 5) + bins = np.arange(-n, n + 1) + smooth_kernel = np.exp(-(bins**2) / (2 * dp.smooth_sigma**2)) + smooth_kernel /= np.sum(smooth_kernel) + smooth_kernel = smooth_kernel[np.newaxis, :] + map = scipy.signal.oaconvolve(map, smooth_kernel, mode="same", axes=1) + + im = self.ax.matshow(map, cmap="inferno", aspect="auto") + self.ax.set_xlabel("Time (s)") + self.ax.set_ylabel("Units") + + self.figure.colorbar(im) diff --git a/src/spikeinterface/widgets/utils_matplotlib.py b/src/spikeinterface/widgets/utils_matplotlib.py index a9128d7b66..6e52efcd84 100644 --- a/src/spikeinterface/widgets/utils_matplotlib.py +++ b/src/spikeinterface/widgets/utils_matplotlib.py @@ -50,7 +50,7 @@ def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None, if num_axes < ncols: ncols = num_axes nrows = int(np.ceil(num_axes / ncols)) - figure, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize) + figure, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False) ax = None # remove extra axes if ncols * nrows > num_axes: diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index ff3d1436ba..00d179127d 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -10,6 +10,8 @@ from .crosscorrelograms import CrossCorrelogramsWidget from .isi_distribution import ISIDistributionWidget from .motion import MotionWidget +from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget +from .peak_activity import PeakActivityMapWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget from .rasters import RasterWidget @@ -21,13 +23,14 @@ from .traces import TracesWidget from .unit_depths import UnitDepthsWidget from .unit_locations import UnitLocationsWidget +from .unit_presence import UnitPresenceWidget from .unit_probe_map import UnitProbeMapWidget from .unit_summary import UnitSummaryWidget from .unit_templates import UnitTemplatesWidget from .unit_waveforms_density_map import UnitWaveformDensityMapWidget from .unit_waveforms import UnitWaveformsWidget from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyPerformancesVsMetrics - +from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget widget_list = [ AgreementMatrixWidget, @@ -35,9 +38,14 @@ AmplitudesWidget, AutoCorrelogramsWidget, ConfusionMatrixWidget, + ComparisonCollisionBySimilarityWidget, CrossCorrelogramsWidget, ISIDistributionWidget, MotionWidget, + MultiCompGlobalAgreementWidget, + MultiCompAgreementBySorterWidget, + MultiCompGraphWidget, + PeakActivityMapWidget, ProbeMapWidget, QualityMetricsWidget, RasterWidget, @@ -49,6 +57,7 @@ TracesWidget, UnitDepthsWidget, UnitLocationsWidget, + UnitPresenceWidget, UnitProbeMapWidget, UnitSummaryWidget, UnitTemplatesWidget, @@ -58,6 +67,7 @@ StudyUnitCountsWidget, StudyPerformances, StudyPerformancesVsMetrics, + StudyComparisonCollisionBySimilarityWidget, ] @@ -98,9 +108,14 @@ plot_amplitudes = AmplitudesWidget plot_autocorrelograms = AutoCorrelogramsWidget plot_confusion_matrix = ConfusionMatrixWidget +plot_comparison_collision_by_similarity = ComparisonCollisionBySimilarityWidget plot_crosscorrelograms = CrossCorrelogramsWidget plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget +plot_multicomparison_agreement = MultiCompGlobalAgreementWidget +plot_multicomparison_agreement_by_sorter = MultiCompAgreementBySorterWidget +plot_multicomparison_graph = MultiCompGraphWidget +plot_peak_activity = PeakActivityMapWidget plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget plot_rasters = RasterWidget @@ -112,6 +127,7 @@ plot_traces = TracesWidget plot_unit_depths = UnitDepthsWidget plot_unit_locations = UnitLocationsWidget +plot_unit_presence = UnitPresenceWidget plot_unit_probe_map = UnitProbeMapWidget plot_unit_summary = UnitSummaryWidget plot_unit_templates = UnitTemplatesWidget @@ -121,6 +137,7 @@ plot_study_unit_counts = StudyUnitCountsWidget plot_study_performances = StudyPerformances plot_study_performances_vs_metrics = StudyPerformancesVsMetrics +plot_study_comparison_collision_by_similarity = StudyComparisonCollisionBySimilarityWidget def plot_timeseries(*args, **kwargs):