From eb4e1021017da4066ceac3c39f91b231af6ef30d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 12 Oct 2023 16:58:11 +0200 Subject: [PATCH 01/25] Improve generate.py with spatial on generate_template() --- src/spikeinterface/core/generate.py | 50 +++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index dc84d31987..bbb953ccff 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -881,9 +881,10 @@ def generate_single_fake_waveform( depolarization_ms=(0.09, 0.14), repolarization_ms=(0.5, 0.8), recovery_ms=(1.0, 1.5), - positive_amplitude=(0.05, 0.15), + positive_amplitude=(0.1, 0.25), smooth_ms=(0.03, 0.07), decay_power=(1.4, 1.8), + propagation_speed=(250., 350.), # ms / um ) @@ -938,6 +939,7 @@ def generate_templates( * 'positive_amplitude': the positive amplitude in a.u. (default range: (0.05-0.15)) (negative is always -1) * 'smooth_ms': the gaussian smooth in ms (default range: (0.03-0.07)) * 'decay_power': the decay power (default range: (1.2-1.8)) + * 'propagation_speed': mimic a propagation delay with a kind of a "speed" (default range: (250., 350.)). Values contains vector with same size of num_units. If the key is not in dict then it is generated using unit_params_range unit_params_range: dict of tuple @@ -985,12 +987,17 @@ def generate_templates( assert unit_params[k].size == num_units params[k] = unit_params[k] else: - v = rng.random(num_units) + if k in unit_params_range: - lim0, lim1 = unit_params_range[k] + lims = unit_params_range[k] else: - lim0, lim1 = default_unit_params_range[k] - params[k] = v * (lim1 - lim0) + lim0 + lims = default_unit_params_range[k] + if lims is not None: + lim0, lim1 = lims + v = rng.random(num_units) + params[k] = v * (lim1 - lim0) + lim0 + else: + params[k] = [None] * num_units for u in range(num_units): wf = generate_single_fake_waveform( @@ -1006,17 +1013,46 @@ def generate_templates( dtype=dtype, ) + + ## Add a spatial decay depend on distance from unit to each channel alpha = params["alpha"][u] # the espilon avoid enormous factors eps = 1.0 # naive formula for spatial decay pow = params["decay_power"][u] channel_factors = alpha / (distances[u, :] + eps) ** pow + wfs = wf[:, np.newaxis] * channel_factors[np.newaxis, :] + + # This mimic a propagation delay for distant channel + propagation_speed = params["propagation_speed"][u] + if propagation_speed is not None: + # the speed is um/ms + dist = distances[u, :].copy() + dist -= np.min(dist) + delay_s = dist / propagation_speed / 1000. + sample_shifts = delay_s * fs + + # apply the delay with fft transform to get sub sample shift + n = wfs.shape[0] + wfs_f = np.fft.rfft(wfs, axis=0) + if n % 2 == 0: + # n is even sig_f[-1] is nyquist and so pi + omega = np.linspace(0, np.pi, wfs_f.shape[0]) + else: + # n is odd sig_f[-1] is exactly nyquist!! we need (n-1) / n factor!! + omega = np.linspace(0, np.pi * (n - 1) / n, wfs_f.shape[0]) + # broadcast omega and sample_shifts depend the axis + shifts = omega[:, np.newaxis] * sample_shifts[np.newaxis, :] + wfs = np.fft.irfft(wfs_f * np.exp(-1j * shifts), n=n, axis=0) + if upsample_factor is not None: for f in range(upsample_factor): - templates[u, :, :, f] = wf[f::upsample_factor, np.newaxis] * channel_factors[np.newaxis, :] + templates[u, :, :, f] = wfs[f::upsample_factor] else: - templates[u, :, :] = wf[:, np.newaxis] * channel_factors[np.newaxis, :] + templates[u, :, :] = wfs + + + return templates From 0fd12553a70582c3a5cac5b007fe32ac439ddb48 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 16 Oct 2023 19:45:43 +0200 Subject: [PATCH 02/25] Use multi columns probe in generate_ground_truth_recording() --- src/spikeinterface/core/generate.py | 36 +++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index bbb953ccff..6bb5a384e6 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -8,7 +8,7 @@ from .numpyextractors import NumpyRecording, NumpySorting from .basesorting import minimum_spike_dtype -from probeinterface import Probe, generate_linear_probe +from probeinterface import Probe, generate_linear_probe, generate_multi_columns_probe from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting from .snippets_tools import snippets_from_sorting @@ -93,7 +93,6 @@ def generate_recording( probe = probe.to_3d() probe.set_device_channel_indices(np.arange(num_channels)) recording.set_probe(probe, in_place=True) - probe = generate_linear_probe(num_elec=num_channels) return recording @@ -1358,6 +1357,13 @@ def generate_ground_truth_recording( num_units=10, sorting=None, probe=None, + generate_probe_kwargs=dict( + num_columns=2, + xpitch=20, + ypitch=20, + contact_shapes="circle", + contact_shape_params={"radius": 6}, + ), templates=None, ms_before=1.0, ms_after=3.0, @@ -1386,7 +1392,9 @@ def generate_ground_truth_recording( sorting: Sorting or None An external sorting object. If not provide, one is genrated. probe: Probe or None - An external Probe object. If not provided of linear probe is generated. + An external Probe object. If not provided of probe is generated using generate_probe_kwargs. + generate_probe_kwargs: dict + A dict to constuct the Probe using :pyp:func:`probeinterface.generate_multi_columns_probe()`. templates: np.array or None The templates of units. If None they are generated. @@ -1443,8 +1451,28 @@ def generate_ground_truth_recording( num_spikes = sorting.to_spike_vector().size if probe is None: - probe = generate_linear_probe(num_elec=num_channels) + # probe = generate_linear_probe(num_elec=num_channels) + # probe.set_device_channel_indices(np.arange(num_channels)) + + prb_kwargs = generate_probe_kwargs.copy() + if 'num_contact_per_column' in prb_kwargs: + assert (prb_kwargs['num_contact_per_column'] * prb_kwargs['num_columns']) == num_channels, \ + "generate_multi_columns_probe : num_channels do not match num_contact_per_column x num_columns" + elif 'num_contact_per_column' not in prb_kwargs and 'num_columns' in prb_kwargs: + n = num_channels // prb_kwargs['num_columns'] + num_contact_per_column = [n] * prb_kwargs['num_columns'] + mid = prb_kwargs['num_columns'] // 2 + num_contact_per_column[mid] += num_channels % prb_kwargs['num_columns'] + prb_kwargs['num_contact_per_column'] = num_contact_per_column + else: + raise ValueError("num_columns should be provided in dict generate_probe_kwargs") + + print(prb_kwargs) + probe = generate_multi_columns_probe(**prb_kwargs) probe.set_device_channel_indices(np.arange(num_channels)) + print(probe) + + else: num_channels = probe.get_contact_count() From 0b57f4dccb6d2561737362b8351993eaac0939d3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 17 Oct 2023 10:29:27 +0200 Subject: [PATCH 03/25] harmonize refactory period in generate.py --- src/spikeinterface/core/generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 6bb5a384e6..d6924f6f4f 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -121,7 +121,7 @@ def generate_sorting( durations=[10.325, 3.5], #  in s for 2 segments firing_rates=3.0, empty_units=None, - refractory_period_ms=3.0, # in ms + refractory_period_ms=4.0, # in ms add_spikes_on_borders=False, num_spikes_per_border=3, border_size_samples=20, @@ -142,7 +142,7 @@ def generate_sorting( The firing rate of each unit (in Hz). empty_units : list, default: None List of units that will have no spikes. (used for testing mainly). - refractory_period_ms : float, default: 3.0 + refractory_period_ms : float, default: 4.0 The refractory period in ms add_spikes_on_borders : bool, default: False If True, spikes will be added close to the borders of the segments. @@ -1369,7 +1369,7 @@ def generate_ground_truth_recording( ms_after=3.0, upsample_factor=None, upsample_vector=None, - generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=1.5), + generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.), noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0), generate_templates_kwargs=dict(), From 737812f22560d2390ea5869f3579dfede3e7c28d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 17 Oct 2023 10:31:56 +0200 Subject: [PATCH 04/25] clean --- src/spikeinterface/core/generate.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index d6924f6f4f..de69af85f3 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1467,11 +1467,8 @@ def generate_ground_truth_recording( else: raise ValueError("num_columns should be provided in dict generate_probe_kwargs") - print(prb_kwargs) probe = generate_multi_columns_probe(**prb_kwargs) probe.set_device_channel_indices(np.arange(num_channels)) - print(probe) - else: num_channels = probe.get_contact_count() From bf670504d57ba9c3af7f4bcd75398f140d1d6627 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 18 Oct 2023 15:40:32 +0200 Subject: [PATCH 05/25] Move UnitPresenceWidget in the new widgets api. --- .../widgets/_legacy_mpl_widgets/__init__.py | 3 - .../widgets/_legacy_mpl_widgets/presence.py | 135 ------------------ .../widgets/tests/test_widgets.py | 11 +- src/spikeinterface/widgets/unit_presence.py | 108 ++++++++++++++ src/spikeinterface/widgets/widget_list.py | 3 + 5 files changed, 120 insertions(+), 140 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/presence.py create mode 100644 src/spikeinterface/widgets/unit_presence.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 53c2a5c79e..8dca67dce0 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -36,9 +36,6 @@ ) -# unit presence -from .presence import plot_presence, PresenceWidget - # correlogram comparison from .correlogramcomp import ( StudyComparisonCorrelogramBySimilarityWidget, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/presence.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/presence.py deleted file mode 100644 index 863b37eda6..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/presence.py +++ /dev/null @@ -1,135 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class PresenceWidget(BaseWidget): - """ - Estimates of the probability density function for each unit using Gaussian kernels, - - Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - segment_index: None or int - The segment index. - unit_ids: list - List of unit ids - time_range: list - List with start time and end time - time_pixels: int - Number of samples calculated for each density function - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: PresenceWidget - The output widget - """ - - def __init__( - self, sorting, segment_index=None, unit_ids=None, time_range=None, figure=None, time_pixels=200, ax=None - ): - BaseWidget.__init__(self, figure, ax) - self._sorting = sorting - self._time_pixels = time_pixels - if segment_index is None: - nseg = sorting.get_num_segments() - if nseg != 1: - raise ValueError("You must provide segment_index=...") - else: - segment_index = 0 - self.segment_index = segment_index - self._unit_ids = unit_ids - self._figure = None - self._sampling_frequency = sorting.get_sampling_frequency() - self._max_frame = 0 - for unit_id in self._sorting.get_unit_ids(): - spike_train = self._sorting.get_unit_spike_train(unit_id, segment_index=self.segment_index) - if len(spike_train) > 0: - curr_max_frame = np.max(spike_train) - if curr_max_frame > self._max_frame: - self._max_frame = curr_max_frame - self._visible_trange = time_range - if self._visible_trange is None: - self._visible_trange = [0, self._max_frame] - else: - assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" - self._visible_trange = [int(t * self._sampling_frequency) for t in time_range] - - self._visible_trange = self._fix_trange(self._visible_trange) - self.name = "Presence" - - def plot(self): - self._do_plot() - - def _do_plot(self): - import matplotlib.pyplot as plt - from scipy.stats import gaussian_kde - - units_ids = self._unit_ids - if units_ids is None: - units_ids = self._sorting.get_unit_ids() - visible_start_frame = self._visible_trange[0] / self._sampling_frequency - visible_end_frame = self._visible_trange[1] / self._sampling_frequency - - time_grid = np.linspace(visible_start_frame, visible_end_frame, self._time_pixels) - time_den = [] - - self.ax.grid("both") - for u_i, unit_id in enumerate(units_ids): - spiketrain = self._sorting.get_unit_spike_train( - unit_id, - start_frame=self._visible_trange[0], - end_frame=self._visible_trange[1], - segment_index=self.segment_index, - ) - spiketimes = spiketrain / float(self._sampling_frequency) - - if spiketimes[0] != spiketimes[-1]: # not always the same value - time_den.append(gaussian_kde(spiketimes).pdf(time_grid)) - else: - aux = np.zeros_like(time_grid) - aux[np.argmin(np.abs(time_grid - spiketimes))] = 1 - time_den.append(aux) - - self.ax.matshow(np.vstack(time_den), cmap=plt.cm.inferno, aspect="auto") - - self.ax.hlines(np.arange(len(units_ids)) + 0.5, 0, len(time_den[0]), color="k", linewidth=4) - - self.ax.tick_params(axis="y", which="both", grid_linestyle="None") - - self.ax.set_xlim(0, self._time_pixels) - new_labels = [] - self.ax.xaxis.set_ticks_position("bottom") - - for xt in self.ax.get_xticks(): - if xt < self._time_pixels: - new_labels.append("{:.1f}".format(time_grid[int(xt)])) - else: - new_labels.append("{:.1f}".format(visible_end_frame)) - self.ax.set_xticks(self.ax.get_xticks()) - self.ax.set_xticklabels(new_labels) - self.ax.set_yticks(np.arange(len(units_ids))) - self.ax.set_yticklabels(units_ids) - self.ax.set_xlabel("time (s)") - self.ax.set_ylabel("Unit ID") - - def _fix_trange(self, trange): - if trange[1] > self._max_frame: - trange[1] = self._max_frame - if trange[0] < 0: - trange[0] = 0 - return trange - - -def plot_presence(*args, **kwargs): - W = PresenceWidget(*args, **kwargs) - W.plot() - return W - - -plot_presence.__doc__ = PresenceWidget.__doc__ diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 93c97f5913..9c3179cd0a 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -376,6 +376,12 @@ def test_plot_unit_probe_map(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_probe_map(self.we_dense) + + def test_plot_unit_presence(self): + possible_backends = list(sw.UnitPresenceWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_unit_presence(self.sorting) if __name__ == "__main__": @@ -395,15 +401,16 @@ def test_plot_unit_probe_map(self): # mytest.test_plot_unit_templates() # mytest.test_plot_unit_summary() # mytest.test_crosscorrelogram() - mytest.test_isi_distribution() + # mytest.test_isi_distribution() # mytest.test_unit_locations() # mytest.test_quality_metrics() - mytest.test_template_metrics() + # mytest.test_template_metrics() # mytest.test_amplitudes() # mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() # mytest.test_plot_rasters() # mytest.test_plot_unit_probe_map() + mytest.test_plot_unit_presence() plt.show() diff --git a/src/spikeinterface/widgets/unit_presence.py b/src/spikeinterface/widgets/unit_presence.py new file mode 100644 index 0000000000..2b39faeb24 --- /dev/null +++ b/src/spikeinterface/widgets/unit_presence.py @@ -0,0 +1,108 @@ +import numpy as np +from typing import Union + +from probeinterface import ProbeGroup + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors +from ..core.waveform_extractor import WaveformExtractor + + +class UnitPresenceWidget(BaseWidget): + """ + Estimates of the probability density function for each unit using Gaussian kernels, + + Parameters + ---------- + sorting: SortingExtractor + The sorting extractor object + segment_index: None or int + The segment index. + time_range: list + List with start time and end time + bin_duration_s: float, default 0.5 + Bin size (in seconds) for the heat map time axis. + smooth_sigma: float or None + + """ + + def __init__(self, sorting, segment_index=None, unit_ids=None, time_range=None, + bin_duration_s=0.05, smooth_sigma=4.5, + backend=None, + **backend_kwargs, + ): + + if segment_index is None: + nseg = sorting.get_num_segments() + if nseg != 1: + raise ValueError("You must provide segment_index=...") + else: + segment_index = 0 + + + data_plot = dict( + sorting=sorting, + segment_index=segment_index, + time_range=time_range, + bin_duration_s=bin_duration_s, + smooth_sigma=smooth_sigma, + ) + + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + sorting = dp.sorting + + spikes = sorting.to_spike_vector(concatenated=False, use_cache=True) + spikes = spikes[dp.segment_index] + + fs = sorting.get_sampling_frequency() + + if dp.time_range is not None: + t0, t1 = dp.time_range + ind0 = int(t0 * fs) + ind1 = int(t1 * fs) + mask = (spikes['sample_index'] >= ind0) & (spikes['sample_index'] <= ind1) + spikes = spikes[mask] + + + + if spikes.size == 0: + return + + last = spikes['sample_index'][-1] + max_time = last / fs + + num_units = len(sorting.unit_ids) + num_time_bins = int(max_time / dp.bin_duration_s) + 1 + map = np.zeros((num_units, num_time_bins)) + ind0 = spikes['unit_index'] + ind1 = spikes['sample_index'] // int(dp.bin_duration_s * fs) + map[ind0, ind1] += 1 + + if dp.smooth_sigma is not None: + import scipy.signal + n = int(dp.smooth_sigma * 5) + bins = np.arange(-n, n + 1) + smooth_kernel = np.exp(-(bins**2) / (2 * dp.smooth_sigma**2)) + smooth_kernel /= np.sum(smooth_kernel) + smooth_kernel = smooth_kernel[np.newaxis, :] + map = scipy.signal.oaconvolve(map, smooth_kernel, mode='same', axes=1) + + + + im = self.ax.matshow(map, cmap='inferno', aspect="auto") + self.ax.set_xlabel("Time (s)") + self.ax.set_ylabel("Units") + + self.figure.colorbar(im) + diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index ff3d1436ba..706b71967b 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -21,6 +21,7 @@ from .traces import TracesWidget from .unit_depths import UnitDepthsWidget from .unit_locations import UnitLocationsWidget +from .unit_presence import UnitPresenceWidget from .unit_probe_map import UnitProbeMapWidget from .unit_summary import UnitSummaryWidget from .unit_templates import UnitTemplatesWidget @@ -49,6 +50,7 @@ TracesWidget, UnitDepthsWidget, UnitLocationsWidget, + UnitPresenceWidget, UnitProbeMapWidget, UnitSummaryWidget, UnitTemplatesWidget, @@ -112,6 +114,7 @@ plot_traces = TracesWidget plot_unit_depths = UnitDepthsWidget plot_unit_locations = UnitLocationsWidget +plot_unit_presence = UnitPresenceWidget plot_unit_probe_map = UnitProbeMapWidget plot_unit_summary = UnitSummaryWidget plot_unit_templates = UnitTemplatesWidget From 3be5d67c4f12208f5283ab02e04121cd5b4e5b0c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 18 Oct 2023 16:18:28 +0200 Subject: [PATCH 06/25] Move PeakActivityMapWidget to new widgets API --- .../widgets/_legacy_mpl_widgets/__init__.py | 3 - .../activity.py => peak_activity.py} | 99 +++++++++---------- .../widgets/tests/test_widgets.py | 16 ++- src/spikeinterface/widgets/unit_presence.py | 5 - src/spikeinterface/widgets/widget_list.py | 3 + 5 files changed, 63 insertions(+), 63 deletions(-) rename src/spikeinterface/widgets/{_legacy_mpl_widgets/activity.py => peak_activity.py} (59%) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 8dca67dce0..6459987ce5 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,6 +1,3 @@ -# peak activity -from .activity import plot_peak_activity_map, PeakActivityMapWidget - from .multicompgraph import ( plot_multicomp_graph, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py b/src/spikeinterface/widgets/peak_activity.py similarity index 59% rename from src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py rename to src/spikeinterface/widgets/peak_activity.py index 9715b7ea87..a80b6db6eb 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py +++ b/src/spikeinterface/widgets/peak_activity.py @@ -1,5 +1,11 @@ import numpy as np -from .basewidget import BaseWidget +from typing import Union + +from probeinterface import ProbeGroup + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors +from ..core.waveform_extractor import WaveformExtractor class PeakActivityMapWidget(BaseWidget): @@ -17,8 +23,6 @@ class PeakActivityMapWidget(BaseWidget): to avoid multiple computation. detect_peaks_kwargs: None or dict If peaks is None here the kwargs for detect_peak function. - weight_with_amplitudes: bool False by default - Peak are weighted by amplitude bin_duration_s: None or float If None then static image If not None then it is an animation per bin. @@ -28,55 +32,48 @@ class PeakActivityMapWidget(BaseWidget): Plot rates with interpolated map with_channel_ids: bool False default Add channel ids text on the probe - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: ProbeMapWidget - The output widget + + """ - def __init__( - self, + def __init__(self, recording, - peaks=None, - detect_peaks_kwargs={}, - weight_with_amplitudes=True, + peaks, bin_duration_s=None, with_contact_color=True, with_interpolated_map=True, with_channel_ids=False, with_color_bar=True, - figure=None, - ax=None, + backend=None, + **backend_kwargs, ): - import matplotlib.pylab as plt - from matplotlib.animation import FuncAnimation - from probeinterface.plotting import plot_probe - BaseWidget.__init__(self, figure, ax) + data_plot = dict( + recording=recording, + peaks=peaks, + bin_duration_s=bin_duration_s, + with_contact_color=with_contact_color, + with_interpolated_map=with_interpolated_map, + with_channel_ids=with_channel_ids, + with_color_bar=with_color_bar, + ) + + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure - assert recording.get_num_segments() == 1, "Handle only one segment" + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - self.recording = recording - self.peaks = peaks - self.detect_peaks_kwargs = detect_peaks_kwargs - self.weight_with_amplitudes = weight_with_amplitudes - self.bin_duration_s = bin_duration_s - self.with_contact_color = with_contact_color - self.with_interpolated_map = with_interpolated_map - self.with_channel_ids = with_channel_ids + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - def plot(self): - rec = self.recording - peaks = self.peaks - if peaks is None: - from spikeinterface.sortingcomponents.peak_detection import detect_peaks - peaks = detect_peaks(rec, **self.detect_peaks_kwargs) + rec = dp.recording + peaks = dp.peaks fs = rec.get_sampling_frequency() duration = rec.get_total_duration() @@ -88,24 +85,24 @@ def plot(self): ) probe = probes[0] - if self.bin_duration_s is None: - self._plot_one_bin(rec, probe, peaks, duration) + if dp.bin_duration_s is None: + self._plot_one_bin(rec, probe, peaks, duration, dp.with_channel_ids, dp.with_contact_color, dp.with_interpolated_map) else: - bin_size = int(self.bin_duration_s * fs) - num_frames = int(duration / self.bin_duration_s) + bin_size = int(dp.bin_duration_s * fs) + num_frames = int(duration / dp.bin_duration_s) def animate_func(i): i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)]) local_peaks = peaks[i0:i1] - artists = self._plot_one_bin(rec, probe, local_peaks, self.bin_duration_s) + artists = self._plot_one_bin(rec, probe, local_peaks, dp.with_channel_ids, dp.bin_duration_s, dp.with_contact_color, dp.with_interpolated_map) return artists from matplotlib.animation import FuncAnimation self.animation = FuncAnimation(self.figure, animate_func, frames=num_frames, interval=100, blit=True) - def _plot_one_bin(self, rec, probe, peaks, duration): - # TODO: @alessio weight_with_amplitudes is not implemented yet + + def _plot_one_bin(self, rec, probe, peaks, duration, with_channel_ids, with_contact_color, with_interpolated_map): rates = np.zeros(rec.get_num_channels(), dtype="float64") for chan_ind, chan_id in enumerate(rec.channel_ids): mask = peaks["channel_index"] == chan_ind @@ -113,9 +110,9 @@ def _plot_one_bin(self, rec, probe, peaks, duration): rates[chan_ind] = num_spike / duration artists = () - if self.with_contact_color: + if with_contact_color: text_on_contact = None - if self.with_channel_ids: + if with_channel_ids: text_on_contact = self.recording.channel_ids from probeinterface.plotting import plot_probe @@ -130,7 +127,7 @@ def _plot_one_bin(self, rec, probe, peaks, duration): ) artists = artists + (poly, poly_contour) - if self.with_interpolated_map: + if with_interpolated_map: image, xlims, ylims = probe.to_image( rates, pixel_size=0.5, num_pixel=None, method="linear", xlims=None, ylims=None ) @@ -140,10 +137,4 @@ def _plot_one_bin(self, rec, probe, peaks, duration): return artists -def plot_peak_activity_map(*args, **kwargs): - W = PeakActivityMapWidget(*args, **kwargs) - W.plot() - return W - -plot_peak_activity_map.__doc__ = PeakActivityMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 9c3179cd0a..24fa740ed9 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -84,6 +84,9 @@ def setUpClass(cls): cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) + from spikeinterface.sortingcomponents.peak_detection import detect_peaks + cls.peaks = detect_peaks(cls.recording, method="locally_exclusive") + def test_plot_traces(self): possible_backends = list(sw.TracesWidget.get_possible_backends()) for backend in possible_backends: @@ -382,6 +385,16 @@ def test_plot_unit_presence(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_presence(self.sorting) + + def test_plot_peak_activity(self): + possible_backends = list(sw.PeakActivityMapWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_peak_activity(self.recording, self.peaks) + + + + if __name__ == "__main__": @@ -411,6 +424,7 @@ def test_plot_unit_presence(self): # mytest.test_plot_probe_map() # mytest.test_plot_rasters() # mytest.test_plot_unit_probe_map() - mytest.test_plot_unit_presence() + # mytest.test_plot_unit_presence() + mytest.test_plot_peak_activity() plt.show() diff --git a/src/spikeinterface/widgets/unit_presence.py b/src/spikeinterface/widgets/unit_presence.py index 2b39faeb24..1b60a731dd 100644 --- a/src/spikeinterface/widgets/unit_presence.py +++ b/src/spikeinterface/widgets/unit_presence.py @@ -1,11 +1,6 @@ import numpy as np -from typing import Union - -from probeinterface import ProbeGroup from .base import BaseWidget, to_attr -from .utils import get_unit_colors -from ..core.waveform_extractor import WaveformExtractor class UnitPresenceWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 706b71967b..f72e0fc3bd 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -10,6 +10,7 @@ from .crosscorrelograms import CrossCorrelogramsWidget from .isi_distribution import ISIDistributionWidget from .motion import MotionWidget +from .peak_activity import PeakActivityMapWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget from .rasters import RasterWidget @@ -39,6 +40,7 @@ CrossCorrelogramsWidget, ISIDistributionWidget, MotionWidget, + PeakActivityMapWidget, ProbeMapWidget, QualityMetricsWidget, RasterWidget, @@ -103,6 +105,7 @@ plot_crosscorrelograms = CrossCorrelogramsWidget plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget +plot_peak_activity = PeakActivityMapWidget plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget plot_rasters = RasterWidget From dd2c98a65a4b2c29190dcbbfa96b7cf995a8687f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 18 Oct 2023 16:55:57 +0200 Subject: [PATCH 07/25] Remove useless plot for unique comparison objects --- .../widgets/_legacy_mpl_widgets/__init__.py | 10 - .../_legacy_mpl_widgets/gtcomparison.py | 192 ------------------ 2 files changed, 202 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/gtcomparison.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 6459987ce5..2db2ce1428 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -22,16 +22,6 @@ from .sortingperformance import plot_sorting_performance -# ground truth comparions (=comparison over sorter) -from .gtcomparison import ( - plot_gt_performances, - plot_gt_performances_averages, - ComparisonPerformancesWidget, - ComparisonPerformancesAveragesWidget, - plot_gt_performances_by_template_similarity, - ComparisonPerformancesByTemplateSimilarity, -) - # correlogram comparison from .correlogramcomp import ( diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/gtcomparison.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/gtcomparison.py deleted file mode 100644 index a58a8e2e37..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/gtcomparison.py +++ /dev/null @@ -1,192 +0,0 @@ -""" -Various widgets on top of GroundTruthStudy to summary results: - * run times - * performancess - * count units -""" -import numpy as np - -from .basewidget import BaseWidget - - -class ComparisonPerformancesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - gt_comparison: GroundTruthComparison - The ground truth sorting comparison object - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, gt_comp, palette="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.gt_comp = gt_comp - self.palette = palette - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import seaborn as sns - - ax = self.ax - - sns.set_palette(sns.color_palette(self.palette)) - - perf_by_units = self.gt_comp.get_performance() - perf_by_units = perf_by_units.reset_index() - - df = pd.melt( - perf_by_units, var_name="Metric", value_name="Score", value_vars=("accuracy", "precision", "recall") - ) - import seaborn as sns - - sns.swarmplot(data=df, x="Metric", y="Score", hue="Metric", dodge=True, s=3, ax=ax) # order=sorter_list, - # ~ ax.set_xticklabels(sorter_names_short, rotation=30, ha='center') - # ~ ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.5) - - ax.set_ylim(0, 1.05) - ax.set_ylabel(f"Performance") - - -class ComparisonPerformancesAveragesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - gt_comparison: GroundTruthComparison - The ground truth sorting comparison object - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, gt_comp, cmap_name="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.gt_comp = gt_comp - self.cmap_name = cmap_name - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import seaborn as sns - - ax = self.ax - - perf_by_units = self.gt_comp.get_performance() - perf_by_units = perf_by_units.reset_index() - - columns = ["accuracy", "precision", "recall"] - to_agg = {} - ncol = len(columns) - - for column in columns: - perf_by_units[column] = pd.to_numeric(perf_by_units[column], downcast="float") - to_agg[column] = ["mean", "std"] - - data = perf_by_units.agg(to_agg) - - m = data.mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - stds = data.std() - - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - width = 1 / (ncol + 2) - - for c, col in enumerate(columns): - x = 1 + c / (ncol + 2) - yerr = stds[col] - ax.bar(x, m[col], yerr=yerr, width=width, color=cmap(c), label=clean_labels[c]) - - ax.legend() - ax.set_ylabel("metric") - # ax.set_xlim(0, 1) - - -class ComparisonPerformancesByTemplateSimilarity(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - gt_comparison: GroundTruthComparison - The ground truth sorting comparison object - similarity_matrix: matrix - The similarity between the templates in the gt recording and the ones - found by the sorter - ax: matplotlib ax - The ax to be used. If not given a figure is created - - """ - - def __init__(self, gt_comp, similarity_matrix, ax=None, ylim=(0.6, 1)): - from matplotlib import pyplot as plt - import pandas as pd - - self.gt_comp = gt_comp - self.similarity_matrix = similarity_matrix - self.ylim = ylim - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - all_results = {"similarity": [], "accuracy": []} - comp = self.gt_comp - - for i, u1 in enumerate(comp.sorting1.unit_ids): - u2 = comp.best_match_12[u1] - if u2 != -1: - all_results["similarity"] += [ - self.similarity_matrix[comp.sorting1.id_to_index(u1), comp.sorting2.id_to_index(u2)] - ] - all_results["accuracy"] += [comp.agreement_scores.at[u1, u2]] - - all_results["similarity"] = np.array(all_results["similarity"]) - all_results["accuracy"] = np.array(all_results["accuracy"]) - - self.ax.plot(all_results["similarity"], all_results["accuracy"], ".") - - self.ax.set_ylabel("accuracy") - self.ax.set_xlabel("cosine similarity") - if self.ylim is not None: - self.ax.set_ylim(self.ylim) - - -def plot_gt_performances(*args, **kwargs): - W = ComparisonPerformancesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_performances.__doc__ = ComparisonPerformancesWidget.__doc__ - - -def plot_gt_performances_averages(*args, **kwargs): - W = ComparisonPerformancesAveragesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_performances_averages.__doc__ = ComparisonPerformancesAveragesWidget.__doc__ - - -def plot_gt_performances_by_template_similarity(*args, **kwargs): - W = ComparisonPerformancesByTemplateSimilarity(*args, **kwargs) - W.plot() - return W - - -plot_gt_performances_by_template_similarity.__doc__ = ComparisonPerformancesByTemplateSimilarity.__doc__ From da462ea2f2e482128cc29cce617d762029cd60dc Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 23 Oct 2023 15:02:38 +0200 Subject: [PATCH 08/25] remove old sorting performence widgets --- .../_legacy_mpl_widgets/sortingperformance.py | 85 ------------------- 1 file changed, 85 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/sortingperformance.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/sortingperformance.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/sortingperformance.py deleted file mode 100644 index eb5258fa05..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/sortingperformance.py +++ /dev/null @@ -1,85 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - -from ...comparison import GroundTruthComparison - - -class SortingPerformanceWidget(BaseWidget): - """ - Plots sorting performance for each ground-truth unit. - - Parameters - ---------- - gt_sorting_comparison: GroundTruthComparison - The ground truth sorting comparison object - property_name: str - The property of the sorting extractor to use as x-axis (e.g. snr). - If None, no property is used. - metric: str - The performance metric. 'accuracy' (default), 'precision', 'recall', 'miss rate', etc. - markersize: int - The size of the marker - marker: str - The matplotlib marker to use (default '.') - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: SortingPerformanceWidget - The output widget - """ - - def __init__( - self, - sorting_comparison, - metrics, - performance_name="accuracy", - metric_name="snr", - color="b", - markersize=10, - marker=".", - figure=None, - ax=None, - ): - from matplotlib import pyplot as plt - - assert isinstance( - sorting_comparison, GroundTruthComparison - ), "The 'sorting_comparison' object should be a GroundTruthComparison instance" - BaseWidget.__init__(self, figure, ax) - self.sorting_comparison = sorting_comparison - self.metrics = metrics - self.performance_name = performance_name - self.metric_name = metric_name - self.color = color - self.markersize = markersize - self.marker = marker - - def plot(self): - self._do_plot() - - def _do_plot(self): - comp = self.sorting_comparison - unit_ids = comp.sorting1.get_unit_ids() - perf = comp.get_performance()[self.performance_name] - metric = self.metrics[self.metric_name] - - ax = self.ax - - ax.plot(metric, perf, marker=self.marker, markersize=int(self.markersize), ls="", color=self.color) - ax.set_xlabel(self.metric_name) - ax.set_ylabel(self.performance_name) - ax.set_ylim(0, 1.05) - - -def plot_sorting_performance(*args, **kwargs): - W = SortingPerformanceWidget(*args, **kwargs) - W.plot() - return W - - -plot_sorting_performance.__doc__ = SortingPerformanceWidget.__doc__ From e7f9f8f05ebb968d73629fcead3e3877d050edf3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 23 Oct 2023 15:03:20 +0200 Subject: [PATCH 09/25] remove old correlogram comp widget --- .../_legacy_mpl_widgets/correlogramcomp.py | 154 ------------------ 1 file changed, 154 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/correlogramcomp.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/correlogramcomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/correlogramcomp.py deleted file mode 100644 index d224686b3a..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/correlogramcomp.py +++ /dev/null @@ -1,154 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class StudyComparisonCorrelogramBySimilarityWidget(BaseWidget): - def __init__( - self, - study, - metric="cosine_similarity", - similarity_bins=np.linspace(-0.4, 1, 8), - show_legend=False, - ncols=3, - axes=None, - cmap="winter", - ylim=(0, 0.5), - ): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - if axes is None: - num_axes = len(study.sorter_names) - else: - num_axes = None - BaseWidget.__init__(self, None, None, axes, ncols=ncols, num_axes=num_axes) - - self.ncols = ncols - self.cmap = cmap - self.study = study - self.metric = metric - self.similarity_bins = np.asarray(similarity_bins) - self.show_legend = show_legend - self.ylim = ylim - - def plot(self): - my_cmap = plt.get_cmap(self.cmap) - cNorm = matplotlib.colors.Normalize(vmin=self.similarity_bins.min(), vmax=self.similarity_bins.max()) - scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) - - self.study.precompute_scores_by_similarities() - time_bins = self.study.time_bins - - for sorter_ind, sorter_name in enumerate(self.study.sorter_names): - result = self.study.get_error_profile_over_similarity_bins(self.similarity_bins, sorter_name) - - # plot by similarity bins - ax = self.axes.flatten()[sorter_ind] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - for i in range(self.similarity_bins.size - 1): - cmin, cmax = self.similarity_bins[i], self.similarity_bins[i + 1] - colorVal = scalarMap.to_rgba((cmin + cmax) / 2) - ax.plot(time_bins, result[(cmin, cmax)], label="CS in [%g,%g]" % (cmin, cmax), c=colorVal) - - if np.mod(sorter_ind, self.ncols) == 0: - ax.set_ylabel("cc error") - - if sorter_ind >= (len(self.study.sorter_names) // self.ncols): - ax.set_xlabel("lags (ms)") - - ax.set_title(sorter_name) - if self.show_legend: - ax.legend() - - if self.ylim is not None: - ax.set_ylim(self.ylim) - - -class StudyComparisonCorrelogramBySimilarityRangesMeanErrorWidget(BaseWidget): - def __init__( - self, - study, - metric="cosine_similarity", - similarity_ranges=np.linspace(-0.4, 1, 8), - show_legend=False, - ax=None, - show_std=False, - ylim=(0, 0.5), - ): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - BaseWidget.__init__(self, None, ax) - - self.study = study - self.metric = metric - self.show_std = show_std - self.ylim = None - self.similarity_ranges = np.asarray(similarity_ranges) - self.show_legend = show_legend - - def plot(self): - self.study.precompute_scores_by_similarities() - - for sorter_ind, sorter_name in enumerate(self.study.sorter_names): - all_similarities = self.study.all_similarities[sorter_name] - all_errors = self.study.all_errors[sorter_name] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_errors = all_errors[order, :] - - mean_rerrors = [] - std_errors = [] - for i in range(self.similarity_ranges.size - 1): - cmin, cmax = self.similarity_ranges[i], self.similarity_ranges[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_rerrors += [np.nanmean(all_errors[amin:amax])] - std_errors += [np.nanstd(all_errors[amin:amax])] - - xaxis = np.diff(self.similarity_ranges) / 2 + self.similarity_ranges[:-1] - - if not self.show_std: - self.ax.plot(xaxis, mean_rerrors, label=sorter_name, c="C%d" % sorter_ind) - else: - self.ax.errorbar(xaxis, mean_rerrors, yerr=std_errors, label=sorter_name, c="C%d" % sorter_ind) - - self.ax.set_ylabel("cc error") - self.ax.set_xlabel("similarity") - - if self.show_legend: - self.ax.legend() - - if self.ylim is not None: - self.ax.set_ylim(self.ylim) - - -def plot_study_comparison_correlogram_by_similarity(*args, **kwargs): - W = StudyComparisonCorrelogramBySimilarityWidget(*args, **kwargs) - W.plot() - return W - - -plot_study_comparison_correlogram_by_similarity.__doc__ = StudyComparisonCorrelogramBySimilarityWidget.__doc__ - -# def plot_study_comparison_Correlogram_by_similarity_range(*args, **kwargs): -# W = StudyComparisonCorrelogramBySimilarityRangeWidget(*args, **kwargs) -# W.plot() -# return W -# plot_study_comparison_Correlogram_by_similarity_range.__doc__ = StudyComparisonCorrelogramBySimilarityRangeWidget.__doc__ - - -def plot_study_comparison_correlogram_by_similarity_ranges_mean_error(*args, **kwargs): - W = StudyComparisonCorrelogramBySimilarityRangesMeanErrorWidget(*args, **kwargs) - W.plot() - return W - - -plot_study_comparison_correlogram_by_similarity_ranges_mean_error.__doc__ = ( - StudyComparisonCorrelogramBySimilarityRangesMeanErrorWidget.__doc__ -) From 657437ff9821310604157ed4206102674e0d4497 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 23 Oct 2023 15:03:33 +0200 Subject: [PATCH 10/25] remove old correlogram comp widget --- src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 2db2ce1428..afcb04ba76 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -20,7 +20,6 @@ plot_study_comparison_collision_by_similarity_ranges, ) -from .sortingperformance import plot_sorting_performance # correlogram comparison From 3b5296d1f588cca0bb6401ccf0e1c9b5b2f1d244 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 23 Oct 2023 17:47:45 +0200 Subject: [PATCH 11/25] Move collision widgets to new widget API --- .../widgets/_legacy_mpl_widgets/__init__.py | 21 - .../_legacy_mpl_widgets/collisioncomp.py | 503 ------------------ src/spikeinterface/widgets/collision.py | 296 +++++++++++ .../widgets/utils_matplotlib.py | 2 +- src/spikeinterface/widgets/widget_list.py | 6 +- 5 files changed, 302 insertions(+), 526 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py create mode 100644 src/spikeinterface/widgets/collision.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index afcb04ba76..e9cdbe848f 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -7,25 +7,4 @@ plot_multicomp_agreement_by_sorter, MultiCompAgreementBySorterWidget, ) -from .collisioncomp import ( - plot_comparison_collision_pair_by_pair, - ComparisonCollisionPairByPairWidget, - plot_comparison_collision_by_similarity, - ComparisonCollisionBySimilarityWidget, - plot_study_comparison_collision_by_similarity, - StudyComparisonCollisionBySimilarityWidget, - plot_study_comparison_collision_by_similarity_range, - StudyComparisonCollisionBySimilarityRangeWidget, - StudyComparisonCollisionBySimilarityRangesWidget, - plot_study_comparison_collision_by_similarity_ranges, -) - - -# correlogram comparison -from .correlogramcomp import ( - StudyComparisonCorrelogramBySimilarityWidget, - plot_study_comparison_correlogram_by_similarity, - StudyComparisonCorrelogramBySimilarityRangesMeanErrorWidget, - plot_study_comparison_correlogram_by_similarity_ranges_mean_error, -) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py deleted file mode 100644 index 468b96ff3b..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ /dev/null @@ -1,503 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class ComparisonCollisionPairByPairWidget(BaseWidget): - """ - Plots CollisionGTComparison pair by pair. - - Parameters - ---------- - comp: CollisionGTComparison - The collision ground truth comparison object - unit_ids: list - List of considered units - nbins: int - Number of bins - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: MultiCompGraphWidget - The output widget - """ - - def __init__(self, comp, unit_ids=None, figure=None, ax=None): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - BaseWidget.__init__(self, figure, ax) - if unit_ids is None: - # take all units - unit_ids = comp.sorting1.get_unit_ids() - - self.comp = comp - self.unit_ids = unit_ids - - def plot(self): - self._do_plot() - - def _do_plot(self): - from matplotlib import pyplot as plt - - fig = self.figure - - for ax in fig.axes: - ax.remove() - - n = len(self.unit_ids) - gs = gridspec.GridSpec(ncols=n, nrows=n, figure=fig) - - axs = np.empty((n, n), dtype=object) - ax = None - for r in range(n): - for c in range(n): - ax = fig.add_subplot(gs[r, c], sharex=ax, sharey=ax) - if c > 0: - plt.setp(ax.get_yticklabels(), visible=False) - if r < n - 1: - plt.setp(ax.get_xticklabels(), visible=False) - axs[r, c] = ax - - fs = self.comp.sorting1.get_sampling_frequency() - - lags = self.comp.bins / fs * 1000 - width = lags[1] - lags[0] - - for r in range(n): - for c in range(r + 1, n): - ax = axs[r, c] - - u1 = self.unit_ids[r] - u2 = self.unit_ids[c] - ind1 = self.comp.sorting1.id_to_index(u1) - ind2 = self.comp.sorting1.id_to_index(u2) - - tp = self.comp.all_tp[ind1, ind2, :] - fn = self.comp.all_fn[ind1, ind2, :] - ax.bar(lags[:-1], tp, width=width, color="g", align="edge") - ax.bar(lags[:-1], fn, width=width, bottom=tp, color="r", align="edge") - - ax = axs[c, r] - tp = self.comp.all_tp[ind2, ind1, :] - fn = self.comp.all_fn[ind2, ind1, :] - ax.bar(lags[:-1], tp, width=width, color="g", align="edge") - ax.bar(lags[:-1], fn, width=width, bottom=tp, color="r", align="edge") - - for r in range(n): - ax = axs[r, 0] - u1 = self.unit_ids[r] - ax.set_ylabel(f"gt id{u1}") - - for c in range(n): - ax = axs[0, c] - u2 = self.unit_ids[c] - ax.set_title(f"collision with \ngt id{u2}") - - ax = axs[-1, 0] - ax.set_xlabel("collision lag [ms]") - - -class ComparisonCollisionBySimilarityWidget(BaseWidget): - """ - Plots CollisionGTComparison pair by pair orderer by cosine_similarity - - Parameters - ---------- - comp: CollisionGTComparison - The collision ground truth comparison object - templates: array - template of units - mode: 'heatmap' or 'lines' - to see collision curves for every pairs ('heatmap') or as lines averaged over pairs. - similarity_bins: array - if mode is 'lines', the bins used to average the pairs - cmap: string - colormap used to show averages if mode is 'lines' - metric: 'cosine_similarity' - metric for ordering - good_only: True - keep only the pairs with a non zero accuracy (found templates) - min_accuracy: float - If good only, the minimum accuracy every cell should have, individually, to be - considered in a putative pair - unit_ids: list - List of considered units - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - """ - - def __init__( - self, - comp, - templates, - unit_ids=None, - metric="cosine_similarity", - figure=None, - ax=None, - mode="heatmap", - similarity_bins=np.linspace(-0.4, 1, 8), - cmap="winter", - good_only=True, - min_accuracy=0.9, - show_legend=False, - ylim=(0, 1), - ): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - BaseWidget.__init__(self, figure, ax) - - assert mode in ["heatmap", "lines"] - - if unit_ids is None: - # take all units - unit_ids = comp.sorting1.get_unit_ids() - - self.comp = comp - self.cmap = cmap - self.mode = mode - self.ylim = ylim - self.show_legend = show_legend - self.similarity_bins = similarity_bins - self.templates = templates - self.unit_ids = unit_ids - self.metric = metric - self.good_only = good_only - self.min_accuracy = min_accuracy - - def plot(self): - self._do_plot() - - def _do_plot(self): - import sklearn - import matplotlib.pyplot as plt - import matplotlib - - # compute similarity - # take index of template (respect unit_ids order) - all_unit_ids = list(self.comp.sorting1.get_unit_ids()) - template_inds = [all_unit_ids.index(u) for u in self.unit_ids] - - templates = self.templates[template_inds, :, :].copy() - flat_templates = templates.reshape(templates.shape[0], -1) - if self.metric == "cosine_similarity": - similarity_matrix = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - else: - raise NotImplementedError("metric=...") - - fs = self.comp.sorting1.get_sampling_frequency() - lags = self.comp.bins / fs * 1000 - - n = len(self.unit_ids) - - similarities, recall_scores, pair_names = self.comp.compute_collision_by_similarity( - similarity_matrix, unit_ids=self.unit_ids, good_only=self.good_only, min_accuracy=self.min_accuracy - ) - - if self.mode == "heatmap": - fig = self.figure - for ax in fig.axes: - ax.remove() - - n_pair = len(similarities) - - ax0 = fig.add_axes([0.1, 0.1, 0.25, 0.8]) - ax1 = fig.add_axes([0.4, 0.1, 0.5, 0.8], sharey=ax0) - - plt.setp(ax1.get_yticklabels(), visible=False) - - im = ax1.imshow( - recall_scores[::-1, :], - cmap="viridis", - aspect="auto", - interpolation="none", - extent=(lags[0], lags[-1], -0.5, n_pair - 0.5), - ) - im.set_clim(0, 1) - - ax0.plot(similarities, np.arange(n_pair), color="k") - - ax0.set_yticks(np.arange(n_pair)) - ax0.set_yticklabels(pair_names) - # ax0.set_xlim(0,1) - - ax0.set_xlabel(self.metric) - ax0.set_ylabel("pairs") - - ax1.set_xlabel("lag (ms)") - elif self.mode == "lines": - my_cmap = plt.get_cmap(self.cmap) - cNorm = matplotlib.colors.Normalize(vmin=self.similarity_bins.min(), vmax=self.similarity_bins.max()) - scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) - - # plot by similarity bins - if self.ax is None: - fig, ax = plt.subplots() - else: - ax = self.ax - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - order = np.argsort(similarities) - similarities = similarities[order] - recall_scores = recall_scores[order, :] - - for i in range(self.similarity_bins.size - 1): - cmin, cmax = self.similarity_bins[i], self.similarity_bins[i + 1] - - amin, amax = np.searchsorted(similarities, [cmin, cmax]) - mean_recall_scores = np.nanmean(recall_scores[amin:amax], axis=0) - - colorVal = scalarMap.to_rgba((cmin + cmax) / 2) - ax.plot( - lags[:-1] + (lags[1] - lags[0]) / 2, - mean_recall_scores, - label="CS in [%g,%g]" % (cmin, cmax), - c=colorVal, - ) - - if self.show_legend: - ax.legend() - ax.set_ylim(self.ylim) - ax.set_xlabel("lags (ms)") - ax.set_ylabel("collision accuracy") - - -class StudyComparisonCollisionBySimilarityWidget(BaseWidget): - def __init__( - self, - study, - metric="cosine_similarity", - similarity_bins=np.linspace(-0.4, 1, 8), - show_legend=False, - ylim=(0.5, 1), - good_only=True, - min_accuracy=0.9, - ncols=3, - axes=None, - cmap="winter", - ): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - if axes is None: - num_axes = len(study.sorter_names) - else: - num_axes = None - BaseWidget.__init__(self, None, None, axes, ncols=ncols, num_axes=num_axes) - - self.ncols = ncols - self.study = study - self.metric = metric - self.cmap = cmap - self.similarity_bins = np.asarray(similarity_bins) - self.show_legend = show_legend - self.ylim = ylim - self.good_only = good_only - self.min_accuracy = min_accuracy - - def plot(self): - my_cmap = plt.get_cmap(self.cmap) - cNorm = matplotlib.colors.Normalize(vmin=self.similarity_bins.min(), vmax=self.similarity_bins.max()) - scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) - self.study.precompute_scores_by_similarities(self.good_only, min_accuracy=self.min_accuracy) - lags = self.study.get_lags() - - for sorter_ind, sorter_name in enumerate(self.study.sorter_names): - curves = self.study.get_lag_profile_over_similarity_bins(self.similarity_bins, sorter_name) - - # plot by similarity bins - ax = self.axes.flatten()[sorter_ind] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - for i in range(self.similarity_bins.size - 1): - cmin, cmax = self.similarity_bins[i], self.similarity_bins[i + 1] - colorVal = scalarMap.to_rgba((cmin + cmax) / 2) - ax.plot( - lags[:-1] + (lags[1] - lags[0]) / 2, - curves[(cmin, cmax)], - label="CS in [%g,%g]" % (cmin, cmax), - c=colorVal, - ) - - if np.mod(sorter_ind, self.ncols) == 0: - ax.set_ylabel("collision accuracy") - - if sorter_ind > (len(self.study.sorter_names) // self.ncols): - ax.set_xlabel("lags (ms)") - - ax.set_title(sorter_name) - if self.show_legend: - ax.legend() - - if self.ylim is not None: - ax.set_ylim(self.ylim) - - -class StudyComparisonCollisionBySimilarityRangeWidget(BaseWidget): - def __init__( - self, - study, - metric="cosine_similarity", - similarity_range=[0, 1], - show_legend=False, - ylim=(0.5, 1), - good_only=True, - min_accuracy=0.9, - ax=None, - ): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - BaseWidget.__init__(self, None, ax) - - self.study = study - self.metric = metric - self.similarity_range = similarity_range - self.show_legend = show_legend - self.ylim = ylim - self.good_only = good_only - self.min_accuracy = min_accuracy - - def plot(self): - self.study.precompute_scores_by_similarities(self.good_only, min_accuracy=self.min_accuracy) - lags = self.study.get_lags() - - for sorter_ind, sorter_name in enumerate(self.study.sorter_names): - mean_recall_scores = self.study.get_mean_over_similarity_range(self.similarity_range, sorter_name) - self.ax.plot( - lags[:-1] + (lags[1] - lags[0]) / 2, mean_recall_scores, label=sorter_name, c="C%d" % sorter_ind - ) - - self.ax.set_ylabel("collision accuracy") - self.ax.set_xlabel("lags (ms)") - - if self.show_legend: - self.ax.legend() - - if self.ylim is not None: - self.ax.set_ylim(self.ylim) - - -class StudyComparisonCollisionBySimilarityRangesWidget(BaseWidget): - def __init__( - self, - study, - metric="cosine_similarity", - similarity_ranges=np.linspace(-0.4, 1, 8), - show_legend=False, - ylim=(0.5, 1), - good_only=True, - min_accuracy=0.9, - ax=None, - show_std=False, - ): - from matplotlib import pyplot as plt - import matplotlib.gridspec as gridspec - import matplotlib.colors - - BaseWidget.__init__(self, None, ax) - - self.study = study - self.metric = metric - self.similarity_ranges = similarity_ranges - self.show_legend = show_legend - self.ylim = ylim - self.good_only = good_only - self.show_std = show_std - self.min_accuracy = min_accuracy - - def plot(self): - self.study.precompute_scores_by_similarities(self.good_only, min_accuracy=self.min_accuracy) - lags = self.study.get_lags() - - for sorter_ind, sorter_name in enumerate(self.study.sorter_names): - all_similarities = self.study.all_similarities[sorter_name] - all_recall_scores = self.study.all_recall_scores[sorter_name] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - mean_recall_scores = [] - std_recall_scores = [] - for i in range(self.similarity_ranges.size - 1): - cmin, cmax = self.similarity_ranges[i], self.similarity_ranges[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_recall_scores += [np.nanmean(all_recall_scores[amin:amax])] - std_recall_scores += [np.nanstd(all_recall_scores[amin:amax])] - - xaxis = np.diff(self.similarity_ranges) / 2 + self.similarity_ranges[:-1] - - if not self.show_std: - self.ax.plot(xaxis, mean_recall_scores, label=sorter_name, c="C%d" % sorter_ind) - else: - self.ax.errorbar( - xaxis, mean_recall_scores, yerr=std_recall_scores, label=sorter_name, c="C%d" % sorter_ind - ) - - self.ax.set_ylabel("collision accuracy") - self.ax.set_xlabel("similarity") - - if self.show_legend: - self.ax.legend() - - if self.ylim is not None: - self.ax.set_ylim(self.ylim) - - -def plot_comparison_collision_pair_by_pair(*args, **kwargs): - W = ComparisonCollisionPairByPairWidget(*args, **kwargs) - W.plot() - return W - - -plot_comparison_collision_pair_by_pair.__doc__ = ComparisonCollisionPairByPairWidget.__doc__ - - -def plot_comparison_collision_by_similarity(*args, **kwargs): - W = ComparisonCollisionBySimilarityWidget(*args, **kwargs) - W.plot() - return W - - -plot_comparison_collision_by_similarity.__doc__ = ComparisonCollisionBySimilarityWidget.__doc__ - - -def plot_study_comparison_collision_by_similarity(*args, **kwargs): - W = StudyComparisonCollisionBySimilarityWidget(*args, **kwargs) - W.plot() - return W - - -plot_study_comparison_collision_by_similarity.__doc__ = StudyComparisonCollisionBySimilarityWidget.__doc__ - - -def plot_study_comparison_collision_by_similarity_range(*args, **kwargs): - W = StudyComparisonCollisionBySimilarityRangeWidget(*args, **kwargs) - W.plot() - return W - - -plot_study_comparison_collision_by_similarity_range.__doc__ = StudyComparisonCollisionBySimilarityRangeWidget.__doc__ - - -def plot_study_comparison_collision_by_similarity_ranges(*args, **kwargs): - W = StudyComparisonCollisionBySimilarityRangesWidget(*args, **kwargs) - W.plot() - return W - - -plot_study_comparison_collision_by_similarity_ranges.__doc__ = StudyComparisonCollisionBySimilarityRangesWidget.__doc__ diff --git a/src/spikeinterface/widgets/collision.py b/src/spikeinterface/widgets/collision.py new file mode 100644 index 0000000000..3932226231 --- /dev/null +++ b/src/spikeinterface/widgets/collision.py @@ -0,0 +1,296 @@ + + + +import numpy as np + +from .base import BaseWidget, to_attr + + +class ComparisonCollisionBySimilarityWidget(BaseWidget): + """ + Plots CollisionGTComparison pair by pair orderer by cosine_similarity + + Parameters + ---------- + comp: CollisionGTComparison + The collision ground truth comparison object + templates: array + template of units + mode: 'heatmap' or 'lines' + to see collision curves for every pairs ('heatmap') or as lines averaged over pairs. + similarity_bins: array + if mode is 'lines', the bins used to average the pairs + cmap: string + colormap used to show averages if mode is 'lines' + metric: 'cosine_similarity' + metric for ordering + good_only: True + keep only the pairs with a non zero accuracy (found templates) + min_accuracy: float + If good only, the minimum accuracy every cell should have, individually, to be + considered in a putative pair + unit_ids: list + List of considered units + """ + + def __init__( + self, + comp, + templates, + unit_ids=None, + metric="cosine_similarity", + figure=None, + ax=None, + mode="heatmap", + similarity_bins=np.linspace(-0.4, 1, 8), + cmap="winter", + good_only=False, + min_accuracy=0.9, + show_legend=False, + ylim=(0, 1), + + backend=None, + **backend_kwargs, + ): + + if unit_ids is None: + unit_ids = comp.sorting1.get_unit_ids() + + data_plot = dict( + comp=comp, + templates=templates, + unit_ids=unit_ids, + metric=metric, + mode=mode, + similarity_bins=similarity_bins, + cmap=cmap, + good_only=good_only, + min_accuracy=min_accuracy, + show_legend=show_legend, + ylim=ylim, + ) + + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import sklearn + import matplotlib + + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + comp = dp.comp + + # compute similarity + # take index of template (respect unit_ids order) + all_unit_ids = list(comp.sorting1.get_unit_ids()) + template_inds = [all_unit_ids.index(u) for u in dp.unit_ids] + + templates = dp.templates[template_inds, :, :].copy() + flat_templates = templates.reshape(templates.shape[0], -1) + if dp.metric == "cosine_similarity": + similarity_matrix = sklearn.metrics.pairwise.cosine_similarity(flat_templates) + else: + raise NotImplementedError("metric=...") + + fs = comp.sorting1.get_sampling_frequency() + lags = comp.bins / fs * 1000 + + n = len(dp.unit_ids) + + similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( + similarity_matrix, unit_ids=dp.unit_ids, good_only=dp.good_only, min_accuracy=dp.min_accuracy + ) + + if dp.mode == "heatmap": + fig = self.figure + for ax in fig.axes: + ax.remove() + + n_pair = len(similarities) + + ax0 = fig.add_axes([0.1, 0.1, 0.25, 0.8]) + ax1 = fig.add_axes([0.4, 0.1, 0.5, 0.8], sharey=ax0) + + plt.setp(ax1.get_yticklabels(), visible=False) + + im = ax1.imshow( + recall_scores[::-1, :], + cmap="viridis", + aspect="auto", + interpolation="none", + extent=(lags[0], lags[-1], -0.5, n_pair - 0.5), + ) + im.set_clim(0, 1) + + ax0.plot(similarities, np.arange(n_pair), color="k") + + ax0.set_yticks(np.arange(n_pair)) + ax0.set_yticklabels(pair_names) + # ax0.set_xlim(0,1) + + ax0.set_xlabel(dp.metric) + ax0.set_ylabel("pairs") + + ax1.set_xlabel("lag (ms)") + elif dp.mode == "lines": + my_cmap = plt.get_cmap(dp.cmap) + cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max()) + scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) + + # plot by similarity bins + if self.ax is None: + fig, ax = plt.subplots() + else: + ax = self.ax + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + order = np.argsort(similarities) + similarities = similarities[order] + recall_scores = recall_scores[order, :] + + for i in range(dp.similarity_bins.size - 1): + cmin, cmax = dp.similarity_bins[i], dp.similarity_bins[i + 1] + + amin, amax = np.searchsorted(similarities, [cmin, cmax]) + mean_recall_scores = np.nanmean(recall_scores[amin:amax], axis=0) + + colorVal = scalarMap.to_rgba((cmin + cmax) / 2) + ax.plot( + lags[:-1] + (lags[1] - lags[0]) / 2, + mean_recall_scores, + label="CS in [%g,%g]" % (cmin, cmax), + c=colorVal, + ) + + if dp.show_legend: + ax.legend() + ax.set_ylim(dp.ylim) + ax.set_xlabel("lags (ms)") + ax.set_ylabel("collision recall") + + + + +class StudyComparisonCollisionBySimilarityWidget(BaseWidget): + """ + Plots CollisionGTComparison pair by pair orderer by cosine_similarity for all + cases in a study. + + Parameters + ---------- + study: CollisionGTStudy + The collision study object. + case_keys: list or None + A selection of cases to plot, if None, then all. + metric: 'cosine_similarity' + metric for ordering + similarity_bins: array + if mode is 'lines', the bins used to average the pairs + cmap: string + colormap used to show averages if mode is 'lines' + good_only: False + keep only the pairs with a non zero accuracy (found templates) + min_accuracy: float + If good only, the minimum accuracy every cell should have, individually, to be + considered in a putative pair + """ + + def __init__( + self, + study, + case_keys=None, + metric="cosine_similarity", + similarity_bins=np.linspace(-0.4, 1, 8), + show_legend=False, + ylim=(0.5, 1), + good_only=False, + min_accuracy=0.9, + cmap="winter", + + backend=None, + **backend_kwargs, + ): + + if case_keys is None: + case_keys = list(study.cases.keys()) + + data_plot = dict( + study=study, + case_keys=case_keys, + metric=metric, + similarity_bins=similarity_bins, + show_legend=show_legend, + ylim=ylim, + good_only=good_only, + min_accuracy=min_accuracy, + cmap=cmap, + ) + + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import sklearn + import matplotlib + + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + num_axes = len(dp.case_keys) + backend_kwargs["num_axes"] = num_axes + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + study = dp.study + + + my_cmap = plt.get_cmap(dp.cmap) + cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max()) + scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) + study.precompute_scores_by_similarities(case_keys=dp.case_keys, good_only=dp.good_only, min_accuracy=dp.min_accuracy, ) + + + for count, key in enumerate(dp.case_keys): + + lags = study.get_lags(key) + + curves = study.get_lag_profile_over_similarity_bins(dp.similarity_bins, key) + + # plot by similarity bins + ax = self.axes.flatten()[count] + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + for i in range(dp.similarity_bins.size - 1): + cmin, cmax = dp.similarity_bins[i], dp.similarity_bins[i + 1] + colorVal = scalarMap.to_rgba((cmin + cmax) / 2) + ax.plot( + lags[:-1] + (lags[1] - lags[0]) / 2, + curves[(cmin, cmax)], + label="CS in [%g,%g]" % (cmin, cmax), + c=colorVal, + ) + + if count % self.axes.shape[1] == 0: + ax.set_ylabel("collision recall") + + if count > (len(dp.case_keys) // self.axes.shape[1]): + ax.set_xlabel("lags (ms)") + + label = study.cases[key]["label"] + ax.set_title(label) + if dp.show_legend: + ax.legend() + + if dp.ylim is not None: + ax.set_ylim(dp.ylim) + \ No newline at end of file diff --git a/src/spikeinterface/widgets/utils_matplotlib.py b/src/spikeinterface/widgets/utils_matplotlib.py index a9128d7b66..6e52efcd84 100644 --- a/src/spikeinterface/widgets/utils_matplotlib.py +++ b/src/spikeinterface/widgets/utils_matplotlib.py @@ -50,7 +50,7 @@ def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None, if num_axes < ncols: ncols = num_axes nrows = int(np.ceil(num_axes / ncols)) - figure, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize) + figure, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False) ax = None # remove extra axes if ncols * nrows > num_axes: diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index f72e0fc3bd..113be0c53b 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -29,7 +29,7 @@ from .unit_waveforms_density_map import UnitWaveformDensityMapWidget from .unit_waveforms import UnitWaveformsWidget from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyPerformancesVsMetrics - +from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget widget_list = [ AgreementMatrixWidget, @@ -37,6 +37,7 @@ AmplitudesWidget, AutoCorrelogramsWidget, ConfusionMatrixWidget, + ComparisonCollisionBySimilarityWidget, CrossCorrelogramsWidget, ISIDistributionWidget, MotionWidget, @@ -62,6 +63,7 @@ StudyUnitCountsWidget, StudyPerformances, StudyPerformancesVsMetrics, + StudyComparisonCollisionBySimilarityWidget, ] @@ -102,6 +104,7 @@ plot_amplitudes = AmplitudesWidget plot_autocorrelograms = AutoCorrelogramsWidget plot_confusion_matrix = ConfusionMatrixWidget +plot_comparison_collision_by_similarity = ComparisonCollisionBySimilarityWidget plot_crosscorrelograms = CrossCorrelogramsWidget plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget @@ -127,6 +130,7 @@ plot_study_unit_counts = StudyUnitCountsWidget plot_study_performances = StudyPerformances plot_study_performances_vs_metrics = StudyPerformancesVsMetrics +plot_study_comparison_collision_by_similarity = StudyComparisonCollisionBySimilarityWidget def plot_timeseries(*args, **kwargs): From 5215582c728d32eeaf42bc33f3be3ef2fc6694a5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 16:05:32 +0000 Subject: [PATCH 12/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../widgets/_legacy_mpl_widgets/__init__.py | 2 -- src/spikeinterface/widgets/collision.py | 23 ++++--------- src/spikeinterface/widgets/peak_activity.py | 24 +++++++------ .../widgets/tests/test_widgets.py | 9 ++--- src/spikeinterface/widgets/unit_presence.py | 34 +++++++++---------- 5 files changed, 41 insertions(+), 51 deletions(-) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index e9cdbe848f..1ee56a7d4c 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,4 +1,3 @@ - from .multicompgraph import ( plot_multicomp_graph, MultiCompGraphWidget, @@ -7,4 +6,3 @@ plot_multicomp_agreement_by_sorter, MultiCompAgreementBySorterWidget, ) - diff --git a/src/spikeinterface/widgets/collision.py b/src/spikeinterface/widgets/collision.py index 3932226231..2b86a2af2d 100644 --- a/src/spikeinterface/widgets/collision.py +++ b/src/spikeinterface/widgets/collision.py @@ -1,6 +1,3 @@ - - - import numpy as np from .base import BaseWidget, to_attr @@ -30,7 +27,7 @@ class ComparisonCollisionBySimilarityWidget(BaseWidget): If good only, the minimum accuracy every cell should have, individually, to be considered in a putative pair unit_ids: list - List of considered units + List of considered units """ def __init__( @@ -48,11 +45,9 @@ def __init__( min_accuracy=0.9, show_legend=False, ylim=(0, 1), - backend=None, **backend_kwargs, ): - if unit_ids is None: unit_ids = comp.sorting1.get_unit_ids() @@ -176,8 +171,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.set_ylabel("collision recall") - - class StudyComparisonCollisionBySimilarityWidget(BaseWidget): """ Plots CollisionGTComparison pair by pair orderer by cosine_similarity for all @@ -213,11 +206,9 @@ def __init__( good_only=False, min_accuracy=0.9, cmap="winter", - backend=None, **backend_kwargs, ): - if case_keys is None: case_keys = list(study.cases.keys()) @@ -252,19 +243,20 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): study = dp.study - my_cmap = plt.get_cmap(dp.cmap) cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max()) scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) - study.precompute_scores_by_similarities(case_keys=dp.case_keys, good_only=dp.good_only, min_accuracy=dp.min_accuracy, ) - + study.precompute_scores_by_similarities( + case_keys=dp.case_keys, + good_only=dp.good_only, + min_accuracy=dp.min_accuracy, + ) for count, key in enumerate(dp.case_keys): - lags = study.get_lags(key) curves = study.get_lag_profile_over_similarity_bins(dp.similarity_bins, key) - + # plot by similarity bins ax = self.axes.flatten()[count] ax.spines["top"].set_visible(False) @@ -293,4 +285,3 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.ylim is not None: ax.set_ylim(dp.ylim) - \ No newline at end of file diff --git a/src/spikeinterface/widgets/peak_activity.py b/src/spikeinterface/widgets/peak_activity.py index a80b6db6eb..24d4dc0df9 100644 --- a/src/spikeinterface/widgets/peak_activity.py +++ b/src/spikeinterface/widgets/peak_activity.py @@ -36,7 +36,8 @@ class PeakActivityMapWidget(BaseWidget): """ - def __init__(self, + def __init__( + self, recording, peaks, bin_duration_s=None, @@ -47,7 +48,6 @@ def __init__(self, backend=None, **backend_kwargs, ): - data_plot = dict( recording=recording, peaks=peaks, @@ -60,7 +60,6 @@ def __init__(self, BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) - def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt from .utils_matplotlib import make_mpl_figure @@ -71,7 +70,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - rec = dp.recording peaks = dp.peaks @@ -86,7 +84,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): probe = probes[0] if dp.bin_duration_s is None: - self._plot_one_bin(rec, probe, peaks, duration, dp.with_channel_ids, dp.with_contact_color, dp.with_interpolated_map) + self._plot_one_bin( + rec, probe, peaks, duration, dp.with_channel_ids, dp.with_contact_color, dp.with_interpolated_map + ) else: bin_size = int(dp.bin_duration_s * fs) num_frames = int(duration / dp.bin_duration_s) @@ -94,14 +94,21 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def animate_func(i): i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)]) local_peaks = peaks[i0:i1] - artists = self._plot_one_bin(rec, probe, local_peaks, dp.with_channel_ids, dp.bin_duration_s, dp.with_contact_color, dp.with_interpolated_map) + artists = self._plot_one_bin( + rec, + probe, + local_peaks, + dp.with_channel_ids, + dp.bin_duration_s, + dp.with_contact_color, + dp.with_interpolated_map, + ) return artists from matplotlib.animation import FuncAnimation self.animation = FuncAnimation(self.figure, animate_func, frames=num_frames, interval=100, blit=True) - def _plot_one_bin(self, rec, probe, peaks, duration, with_channel_ids, with_contact_color, with_interpolated_map): rates = np.zeros(rec.get_num_channels(), dtype="float64") for chan_ind, chan_id in enumerate(rec.channel_ids): @@ -135,6 +142,3 @@ def _plot_one_bin(self, rec, probe, peaks, duration, with_channel_ids, with_cont artists = artists + (im,) return artists - - - diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 24fa740ed9..efb429a52a 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -85,6 +85,7 @@ def setUpClass(cls): cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) from spikeinterface.sortingcomponents.peak_detection import detect_peaks + cls.peaks = detect_peaks(cls.recording, method="locally_exclusive") def test_plot_traces(self): @@ -379,13 +380,13 @@ def test_plot_unit_probe_map(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_probe_map(self.we_dense) - + def test_plot_unit_presence(self): possible_backends = list(sw.UnitPresenceWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_presence(self.sorting) - + def test_plot_peak_activity(self): possible_backends = list(sw.PeakActivityMapWidget.get_possible_backends()) for backend in possible_backends: @@ -393,10 +394,6 @@ def test_plot_peak_activity(self): sw.plot_peak_activity(self.recording, self.peaks) - - - - if __name__ == "__main__": # unittest.main() diff --git a/src/spikeinterface/widgets/unit_presence.py b/src/spikeinterface/widgets/unit_presence.py index 1b60a731dd..3d605936a2 100644 --- a/src/spikeinterface/widgets/unit_presence.py +++ b/src/spikeinterface/widgets/unit_presence.py @@ -21,12 +21,17 @@ class UnitPresenceWidget(BaseWidget): """ - def __init__(self, sorting, segment_index=None, unit_ids=None, time_range=None, - bin_duration_s=0.05, smooth_sigma=4.5, + def __init__( + self, + sorting, + segment_index=None, + unit_ids=None, + time_range=None, + bin_duration_s=0.05, + smooth_sigma=4.5, backend=None, **backend_kwargs, ): - if segment_index is None: nseg = sorting.get_num_segments() if nseg != 1: @@ -34,7 +39,6 @@ def __init__(self, sorting, segment_index=None, unit_ids=None, time_range=None, else: segment_index = 0 - data_plot = dict( sorting=sorting, segment_index=segment_index, @@ -57,47 +61,43 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting = dp.sorting - spikes = sorting.to_spike_vector(concatenated=False, use_cache=True) + spikes = sorting.to_spike_vector(concatenated=False, use_cache=True) spikes = spikes[dp.segment_index] - + fs = sorting.get_sampling_frequency() if dp.time_range is not None: t0, t1 = dp.time_range ind0 = int(t0 * fs) ind1 = int(t1 * fs) - mask = (spikes['sample_index'] >= ind0) & (spikes['sample_index'] <= ind1) + mask = (spikes["sample_index"] >= ind0) & (spikes["sample_index"] <= ind1) spikes = spikes[mask] - - if spikes.size == 0: return - last = spikes['sample_index'][-1] + last = spikes["sample_index"][-1] max_time = last / fs num_units = len(sorting.unit_ids) num_time_bins = int(max_time / dp.bin_duration_s) + 1 map = np.zeros((num_units, num_time_bins)) - ind0 = spikes['unit_index'] - ind1 = spikes['sample_index'] // int(dp.bin_duration_s * fs) + ind0 = spikes["unit_index"] + ind1 = spikes["sample_index"] // int(dp.bin_duration_s * fs) map[ind0, ind1] += 1 if dp.smooth_sigma is not None: import scipy.signal + n = int(dp.smooth_sigma * 5) bins = np.arange(-n, n + 1) smooth_kernel = np.exp(-(bins**2) / (2 * dp.smooth_sigma**2)) smooth_kernel /= np.sum(smooth_kernel) smooth_kernel = smooth_kernel[np.newaxis, :] - map = scipy.signal.oaconvolve(map, smooth_kernel, mode='same', axes=1) + map = scipy.signal.oaconvolve(map, smooth_kernel, mode="same", axes=1) - - - im = self.ax.matshow(map, cmap='inferno', aspect="auto") + im = self.ax.matshow(map, cmap="inferno", aspect="auto") self.ax.set_xlabel("Time (s)") self.ax.set_ylabel("Units") self.figure.colorbar(im) - From 8a73530ffa4edf25d9dfa91418afe550526c3e9b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 17:38:40 +0000 Subject: [PATCH 13/25] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/psf/black: 23.9.1 → 23.10.0](https://github.com/psf/black/compare/23.9.1...23.10.0) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7153a7dfc0..e907976163 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.9.1 + rev: 23.10.0 hooks: - id: black files: ^src/ From 296751249af7e19ebdbec5586d3d08f31838ca74 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 23 Oct 2023 20:45:55 +0200 Subject: [PATCH 14/25] Update src/spikeinterface/core/generate.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index de69af85f3..7ac6f0cd36 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1392,7 +1392,7 @@ def generate_ground_truth_recording( sorting: Sorting or None An external sorting object. If not provide, one is genrated. probe: Probe or None - An external Probe object. If not provided of probe is generated using generate_probe_kwargs. + An external Probe object. If not provided a probe is generated using generate_probe_kwargs. generate_probe_kwargs: dict A dict to constuct the Probe using :pyp:func:`probeinterface.generate_multi_columns_probe()`. templates: np.array or None From 383ee2ac4f91f88df41c12ddf87f5a745dfcc87b Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 23 Oct 2023 20:46:07 +0200 Subject: [PATCH 15/25] Update src/spikeinterface/core/generate.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 7ac6f0cd36..e876888186 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1394,7 +1394,7 @@ def generate_ground_truth_recording( probe: Probe or None An external Probe object. If not provided a probe is generated using generate_probe_kwargs. generate_probe_kwargs: dict - A dict to constuct the Probe using :pyp:func:`probeinterface.generate_multi_columns_probe()`. + A dict to constuct the Probe using :py:func:`probeinterface.generate_multi_columns_probe()`. templates: np.array or None The templates of units. If None they are generated. From 439d2fed55b376ec0e011c6903d0211d9421d695 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 18:46:14 +0000 Subject: [PATCH 16/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 30 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index e876888186..0285d60f4f 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -883,7 +883,7 @@ def generate_single_fake_waveform( positive_amplitude=(0.1, 0.25), smooth_ms=(0.03, 0.07), decay_power=(1.4, 1.8), - propagation_speed=(250., 350.), # ms / um + propagation_speed=(250.0, 350.0), # ms / um ) @@ -986,7 +986,6 @@ def generate_templates( assert unit_params[k].size == num_units params[k] = unit_params[k] else: - if k in unit_params_range: lims = unit_params_range[k] else: @@ -1012,7 +1011,6 @@ def generate_templates( dtype=dtype, ) - ## Add a spatial decay depend on distance from unit to each channel alpha = params["alpha"][u] # the espilon avoid enormous factors @@ -1028,7 +1026,7 @@ def generate_templates( # the speed is um/ms dist = distances[u, :].copy() dist -= np.min(dist) - delay_s = dist / propagation_speed / 1000. + delay_s = dist / propagation_speed / 1000.0 sample_shifts = delay_s * fs # apply the delay with fft transform to get sub sample shift @@ -1050,9 +1048,6 @@ def generate_templates( else: templates[u, :, :] = wfs - - - return templates @@ -1369,7 +1364,7 @@ def generate_ground_truth_recording( ms_after=3.0, upsample_factor=None, upsample_vector=None, - generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.), + generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0), generate_templates_kwargs=dict(), @@ -1455,15 +1450,18 @@ def generate_ground_truth_recording( # probe.set_device_channel_indices(np.arange(num_channels)) prb_kwargs = generate_probe_kwargs.copy() - if 'num_contact_per_column' in prb_kwargs: - assert (prb_kwargs['num_contact_per_column'] * prb_kwargs['num_columns']) == num_channels, \ + if "num_contact_per_column" in prb_kwargs: + assert ( + prb_kwargs["num_contact_per_column"] * prb_kwargs["num_columns"] + ) == num_channels, ( "generate_multi_columns_probe : num_channels do not match num_contact_per_column x num_columns" - elif 'num_contact_per_column' not in prb_kwargs and 'num_columns' in prb_kwargs: - n = num_channels // prb_kwargs['num_columns'] - num_contact_per_column = [n] * prb_kwargs['num_columns'] - mid = prb_kwargs['num_columns'] // 2 - num_contact_per_column[mid] += num_channels % prb_kwargs['num_columns'] - prb_kwargs['num_contact_per_column'] = num_contact_per_column + ) + elif "num_contact_per_column" not in prb_kwargs and "num_columns" in prb_kwargs: + n = num_channels // prb_kwargs["num_columns"] + num_contact_per_column = [n] * prb_kwargs["num_columns"] + mid = prb_kwargs["num_columns"] // 2 + num_contact_per_column[mid] += num_channels % prb_kwargs["num_columns"] + prb_kwargs["num_contact_per_column"] = num_contact_per_column else: raise ValueError("num_columns should be provided in dict generate_probe_kwargs") From 2aedd47ec47edc048e312a38434f90200ef98d22 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 23 Oct 2023 20:49:09 +0200 Subject: [PATCH 17/25] generate alpha doc --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0285d60f4f..e879651ae7 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -931,7 +931,7 @@ def generate_templates( An optional dict containing parameters per units. Keys are parameter names: - * 'alpha': amplitude of the action potential in a.u. (default range: (5'000-15'000)) + * 'alpha': amplitude of the action potential in a.u. (default range: (6'000-9'000)) * 'depolarization_ms': the depolarization interval in ms (default range: (0.09-0.14)) * 'repolarization_ms': the repolarization interval in ms (default range: (0.5-0.8)) * 'recovery_ms': the recovery interval in ms (default range: (1.0-1.5)) From 82230c875773ef3b8143e7c52081c4ff3a5a9f3f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 23 Oct 2023 20:59:22 +0200 Subject: [PATCH 18/25] clean plot in components benchmarks --- .../benchmark/benchmark_matching.py | 16 ++++++---------- .../benchmark/benchmark_motion_interpolation.py | 1 - 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 4efabbc9c5..eda451e29d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -6,12 +6,9 @@ from spikeinterface.qualitymetrics import compute_quality_metrics from spikeinterface.comparison import CollisionGTComparison, compare_sorter_to_ground_truth from spikeinterface.widgets import ( - plot_sorting_performance, plot_agreement_matrix, plot_comparison_collision_by_similarity, - plot_unit_templates, plot_unit_waveforms, - plot_gt_performances, ) import time @@ -474,13 +471,12 @@ def plot(self, comp, title=None): ax = axs[1, 0] ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) - plot_sorting_performance(comp, self.metrics, performance_name="accuracy", metric_name="snr", ax=ax, color="r") - plot_sorting_performance(comp, self.metrics, performance_name="recall", metric_name="snr", ax=ax, color="g") - plot_sorting_performance(comp, self.metrics, performance_name="precision", metric_name="snr", ax=ax, color="b") - ax.legend(["accuracy", "recall", "precision"]) - - ax = axs[1, 1] - plot_gt_performances(comp, ax=ax) + + for k in ("accuracy", "recall", "precision"): + x = comp.get_performance()[k] + y = self.metrics["snr"] + ax.scatter(x, y, markersize=10, marker=".", label=k) + ax.legend() ax = axs[0, 1] if self.exhaustive_gt: diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index b28b29f17c..d0b3f387b1 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -10,7 +10,6 @@ from spikeinterface.extractors import read_mearec from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference, scale, highpass_filter, whiten from spikeinterface.sorters import run_sorter, read_sorter_folder -from spikeinterface.widgets import plot_unit_waveforms, plot_gt_performances from spikeinterface.comparison import GroundTruthComparison from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording From 08f86acb6d88a3ba8be0442cbb4cfc73f1922f4f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 19:00:02 +0000 Subject: [PATCH 19/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/benchmark/benchmark_matching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index eda451e29d..d961bdbc07 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -471,7 +471,7 @@ def plot(self, comp, title=None): ax = axs[1, 0] ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) - + for k in ("accuracy", "recall", "precision"): x = comp.get_performance()[k] y = self.metrics["snr"] From 50424eb888e35d268b487013d2c9ef5282d53ed0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 23 Oct 2023 21:30:02 +0200 Subject: [PATCH 20/25] oups --- .../_legacy_mpl_widgets/tests/test_widgets_legacy.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 9cd321db3c..fda2e75138 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -75,9 +75,9 @@ def tearDown(self): # fig, axes = plt.subplots(self.num_units, 1) # sw.plot_isi_distribution(self._sorting, axes=axes) - def test_plot_peak_activity_map(self): - sw.plot_peak_activity_map(self._rec, with_channel_ids=True) - sw.plot_peak_activity_map(self._rec, bin_duration_s=1.0) + # def test_plot_peak_activity_map(self): + # sw.plot_peak_activity_map(self._rec, with_channel_ids=True) + # sw.plot_peak_activity_map(self._rec, bin_duration_s=1.0) def test_multicomp_graph(self): msc = sc.compare_multiple_sorters([self._sorting, self._sorting, self._sorting]) @@ -87,9 +87,9 @@ def test_multicomp_graph(self): fig, axes = plt.subplots(len(msc.object_list), 1) sw.plot_multicomp_agreement_by_sorter(msc, axes=axes) - def test_sorting_performance(self): - metrics = compute_quality_metrics(self._we, metric_names=["snr"]) - sw.plot_sorting_performance(self._gt_comp, metrics, performance_name="accuracy", metric_name="snr") + # def test_sorting_performance(self): + # metrics = compute_quality_metrics(self._we, metric_names=["snr"]) + # sw.plot_sorting_performance(self._gt_comp, metrics, performance_name="accuracy", metric_name="snr") # ~ def test_plot_unit_summary(self): # ~ unit_id = self._sorting.unit_ids[4] From 66280e4354fa3e8a8f1c1b923005bf882c78d5a8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 24 Oct 2023 07:14:21 +0200 Subject: [PATCH 21/25] Update src/spikeinterface/core/generate.py --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index e879651ae7..44ea02d32c 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -883,7 +883,7 @@ def generate_single_fake_waveform( positive_amplitude=(0.1, 0.25), smooth_ms=(0.03, 0.07), decay_power=(1.4, 1.8), - propagation_speed=(250.0, 350.0), # ms / um + propagation_speed=(250.0, 350.0), # um / ms ) From 9e32c74069069b6d35bfa8e3a816f85f160b4f66 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 24 Oct 2023 07:32:18 +0200 Subject: [PATCH 22/25] Update docs with new function names --- .../modules_gallery/widgets/plot_2_sort_gallery.py | 10 +++++----- .../modules_gallery/widgets/plot_4_peaks_gallery.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/modules_gallery/widgets/plot_2_sort_gallery.py b/examples/modules_gallery/widgets/plot_2_sort_gallery.py index 2a5ef30e28..c07070927b 100644 --- a/examples/modules_gallery/widgets/plot_2_sort_gallery.py +++ b/examples/modules_gallery/widgets/plot_2_sort_gallery.py @@ -16,13 +16,13 @@ ############################################################################## # plot_rasters() -# ~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~ w_rs = sw.plot_rasters(sorting) ############################################################################## # plot_isi_distribution() -# ~~~~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~~~~ w_isi = sw.plot_isi_distribution(sorting, window_ms=150.0, bin_ms=5.0) @@ -43,10 +43,10 @@ ############################################################################## -# plot_presence() -# ~~~~~~~~~~~~~~~~~~~~~~~~ +# plot_unit_presence() +# ~~~~~~~~~~~~~~~~~~~~ -w_pr = sw.plot_presence(sorting) +w_pr = sw.plot_unit_presence(sorting) plt.show() diff --git a/examples/modules_gallery/widgets/plot_4_peaks_gallery.py b/examples/modules_gallery/widgets/plot_4_peaks_gallery.py index addd87c065..60733afb1d 100644 --- a/examples/modules_gallery/widgets/plot_4_peaks_gallery.py +++ b/examples/modules_gallery/widgets/plot_4_peaks_gallery.py @@ -44,14 +44,14 @@ ############################################################################## # This "peaks" vector can be used in several widgets, for instance -# plot_peak_activity_map() +# plot_peak_activity() -si.plot_peak_activity_map(rec_filtred, peaks=peaks) +si.plot_peak_activity(rec_filtred, peaks=peaks) ############################################################################## # can be also animated with bin_duration_s=1. -si.plot_peak_activity_map(rec_filtred, bin_duration_s=1.) +si.plot_peak_activity(rec_filtred, bin_duration_s=1.) plt.show() From 9cfaaa9d9782e4f7a26866563dd886439df59e6c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 24 Oct 2023 07:53:53 +0200 Subject: [PATCH 23/25] wip: port multi-comparison widgets --- src/spikeinterface/widgets/multicomparison_graph.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/spikeinterface/widgets/multicomparison_graph.py diff --git a/src/spikeinterface/widgets/multicomparison_graph.py b/src/spikeinterface/widgets/multicomparison_graph.py new file mode 100644 index 0000000000..e69de29bb2 From 13c6934ad5e93b6283411ff2e96cca3132524e84 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 24 Oct 2023 08:31:59 +0200 Subject: [PATCH 24/25] Port last legacy widgets and remove _legacy folder --- examples/how_to/get_started.py | 4 +- .../tests/test_multisortingcomparison.py | 7 - src/spikeinterface/widgets/__init__.py | 6 - .../widgets/_legacy_mpl_widgets/__init__.py | 8 - .../widgets/_legacy_mpl_widgets/basewidget.py | 81 ----- .../_legacy_mpl_widgets/multicompgraph.py | 312 ------------------ .../tests/test_legacy_widgets_utils.py | 21 -- .../tests/test_widgets_legacy.py | 128 ------- .../widgets/_legacy_mpl_widgets/utils.py | 37 --- .../widgets/multicomparison_agreement.py | 188 +++++++++++ .../widgets/multicomparison_graph.py | 111 +++++++ .../widgets/tests/test_widgets.py | 39 ++- src/spikeinterface/widgets/widget_list.py | 8 + 13 files changed, 337 insertions(+), 613 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/basewidget.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/multicompgraph.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_legacy_widgets_utils.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/utils.py create mode 100644 src/spikeinterface/widgets/multicomparison_agreement.py diff --git a/examples/how_to/get_started.py b/examples/how_to/get_started.py index 7860c605af..329a2b32b0 100644 --- a/examples/how_to/get_started.py +++ b/examples/how_to/get_started.py @@ -337,8 +337,8 @@ print('Units in agreement between TDC, SC2, and KS2:', sorting_agreement.get_unit_ids()) -w_multi = sw.plot_multicomp_agreement(comp_multi) -w_multi = sw.plot_multicomp_agreement_by_sorter(comp_multi) +w_multi = sw.plot_multicomparison_agreement(comp_multi) +w_multi = sw.plot_multicomparison_agreement_by_sorter(comp_multi) # - # We see that 10 unit were found by all sorters (note that this simulated dataset is a very simple example, and usually sorters do not do such a great job)! diff --git a/src/spikeinterface/comparison/tests/test_multisortingcomparison.py b/src/spikeinterface/comparison/tests/test_multisortingcomparison.py index 0588cca66d..7e4c9ac77b 100644 --- a/src/spikeinterface/comparison/tests/test_multisortingcomparison.py +++ b/src/spikeinterface/comparison/tests/test_multisortingcomparison.py @@ -69,13 +69,6 @@ def test_compare_multiple_sorters(): msc = MultiSortingComparison.load_from_folder(multicomparison_folder) - # import spikeinterface.widgets as sw - # import matplotlib.pyplot as plt - # sw.plot_multicomp_graph(msc) - # sw.plot_multicomp_agreement(msc) - # sw.plot_multicomp_agreement_by_sorter(msc) - # plt.show() - def test_compare_multi_segment(): num_segments = 3 diff --git a/src/spikeinterface/widgets/__init__.py b/src/spikeinterface/widgets/__init__.py index d3066f51fa..d6d181f3fe 100644 --- a/src/spikeinterface/widgets/__init__.py +++ b/src/spikeinterface/widgets/__init__.py @@ -3,9 +3,3 @@ # general functions from .utils import get_some_colors, get_unit_colors, array_to_image from .base import set_default_plotter_backend, get_default_plotter_backend - - -# we keep this to keep compatibility so we have all previous widgets -# except the one that have been ported that are imported -# with "from .widget_list import *" in the first line -from ._legacy_mpl_widgets import * diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py deleted file mode 100644 index 1ee56a7d4c..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .multicompgraph import ( - plot_multicomp_graph, - MultiCompGraphWidget, - plot_multicomp_agreement, - MultiCompGlobalAgreementWidget, - plot_multicomp_agreement_by_sorter, - MultiCompAgreementBySorterWidget, -) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/basewidget.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/basewidget.py deleted file mode 100644 index f69f9a1b0f..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/basewidget.py +++ /dev/null @@ -1,81 +0,0 @@ -import numpy as np - - -# This class replace the old BaseWidget and BaseMultiWidget -class BaseWidget: - def __init__(self, figure=None, ax=None, axes=None, ncols=None, num_axes=None): - """ - figure/ax/axes : only one of then can be not None - """ - import matplotlib.pyplot as plt - - from matplotlib import gridspec - - if figure is not None: - assert ax is None and axes is None, "figure/ax/axes : only one of then can be not None" - ax = figure.add_subplot(111) - axes = np.array([[ax]]) - elif ax is not None: - assert figure is None and axes is None, "figure/ax/axes : only one of then can be not None" - figure = ax.get_figure() - axes = np.array([[ax]]) - elif axes is not None: - assert figure is None and ax is None, "figure/ax/axes : only one of then can be not None" - axes = np.asarray(axes) - figure = axes.flatten()[0].get_figure() - else: - # one fig with one ax - if num_axes is None: - figure, ax = plt.subplots() - axes = np.array([[ax]]) - else: - if num_axes == 0: - # one figure without plots (diffred subplot creation with - figure = plt.figure() - ax = None - axes = None - elif num_axes == 1: - figure = plt.figure() - ax = figure.add_subplot(111) - axes = np.array([[ax]]) - else: - assert ncols is not None - if num_axes < ncols: - ncols = num_axes - nrows = int(np.ceil(num_axes / ncols)) - figure, axes = plt.subplots( - nrows=nrows, - ncols=ncols, - ) - ax = None - # remove extra axes - if ncols * nrows > num_axes: - for extra_ax in axes.flatten()[num_axes:]: - extra_ax.remove() - - self.figure = figure - self.ax = ax - # axes is a 2D array of ax - self.axes = axes - # self.figure.axes is the flatten of all axes - - -class DataWidget: - def __init__(self) -> None: - self.plotter = None - pass - - def _prepare_data(self): - raise NotImplementedError - - -# keep here just in case it is needed - -# def create_tiled_ax(self, i, nrows, ncols, hspace=0.3, wspace=0.3, is_diag=False): -# gs = gridspec.GridSpecFromSubplotSpec(int(nrows), int(ncols), subplot_spec=self.ax, -# hspace=hspace, wspace=wspace) -# r = int(i // ncols) -# c = int(np.mod(i, ncols)) -# gs_sel = gs[r, c] -# ax = self.figure.add_subplot(gs_sel) -# return ax diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/multicompgraph.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/multicompgraph.py deleted file mode 100644 index 47e0c026f0..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/multicompgraph.py +++ /dev/null @@ -1,312 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class MultiCompGraphWidget(BaseWidget): - """ - Plots multi comparison graph. - - Parameters - ---------- - multi_comparison: BaseMultiComparison - The multi comparison object - draw_labels: bool - If True unit labels are shown - node_cmap: matplotlib colormap - The colormap to be used for the nodes (default 'viridis') - edge_cmap: matplotlib colormap - The colormap to be used for the edges (default 'hot') - alpha_edges: float - Alpha value for edges - colorbar: bool - If True a colorbar for the edges is plotted - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: MultiCompGraphWidget - The output widget - """ - - def __init__( - self, - multi_comparison, - draw_labels=False, - node_cmap="viridis", - edge_cmap="hot", - alpha_edges=0.5, - colorbar=False, - figure=None, - ax=None, - ): - import matplotlib - from matplotlib import pyplot as plt - - BaseWidget.__init__(self, figure, ax) - self._msc = multi_comparison - self._draw_labels = draw_labels - self._node_cmap = node_cmap - self._edge_cmap = edge_cmap - self._colorbar = colorbar - self._alpha_edges = alpha_edges - self.name = "MultiCompGraph" - - def plot(self): - self._do_plot() - - def _do_plot(self): - import networkx as nx - - g = self._msc.graph - edge_col = [] - for e in g.edges(data=True): - n1, n2, d = e - edge_col.append(d["weight"]) - nodes_col_dict = {} - for i, sort_name in enumerate(self._msc.name_list): - nodes_col_dict[sort_name] = i - nodes_col = [] - for node in sorted(g.nodes): - nodes_col.append(nodes_col_dict[node[0]]) - nodes_col = np.array(nodes_col) / len(self._msc.name_list) - import matplotlib.pyplot as plt - - _ = plt.set_cmap(self._node_cmap) - _ = nx.draw_networkx_nodes( - g, - pos=nx.circular_layout(sorted(g)), - nodelist=sorted(g.nodes), - node_color=nodes_col, - node_size=20, - ax=self.ax, - ) - _ = nx.draw_networkx_edges( - g, - pos=nx.circular_layout((sorted(g))), - nodelist=sorted(g.nodes), - edge_color=edge_col, - alpha=self._alpha_edges, - edge_cmap=plt.cm.get_cmap(self._edge_cmap), - edge_vmin=self._msc.match_score, - edge_vmax=1, - ax=self.ax, - ) - if self._draw_labels: - labels = {key: f"{key[0]}_{key[1]}" for key in sorted(g.nodes)} - pos = nx.circular_layout(sorted(g)) - # extend position radially - pos_extended = {} - for node, pos in pos.items(): - pos_new = pos + 0.1 * pos - pos_extended[node] = pos_new - _ = nx.draw_networkx_labels(g, pos=pos_extended, labels=labels, ax=self.ax) - - if self._colorbar: - import matplotlib - import matplotlib.pyplot as plt - - norm = matplotlib.colors.Normalize(vmin=self._msc.match_score, vmax=1) - cmap = plt.cm.get_cmap(self._edge_cmap) - m = plt.cm.ScalarMappable(norm=norm, cmap=cmap) - self.figure.colorbar(m) - - self.ax.axis("off") - - -class MultiCompGlobalAgreementWidget(BaseWidget): - """ - Plots multi comparison agreement as pie or bar plot. - - Parameters - ---------- - multi_comparison: BaseMultiComparison - The multi comparison object - plot_type: str - 'pie' or 'bar' - cmap: matplotlib colormap - The colormap to be used for the nodes (default 'Reds') - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: MultiCompGraphWidget - The output widget - """ - - def __init__(self, multi_comparison, plot_type="pie", cmap="YlOrRd", fs=10, figure=None, ax=None): - BaseWidget.__init__(self, figure, ax) - import matplotlib - import matplotlib.pyplot as plt - - self._msc = multi_comparison - self._type = plot_type - self._cmap = cmap - self._fs = fs - self.name = "MultiCompGlobalAgreement" - - def plot(self): - self._do_plot() - - def _do_plot(self): - import matplotlib.pyplot as plt - - cmap = plt.get_cmap(self._cmap) - colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(self._msc.name_list))]) - sg_names, sg_units = self._msc.compute_subgraphs() - # fraction of units with agreement > threshold - v, c = np.unique([len(np.unique(s)) for s in sg_names], return_counts=True) - if self._type == "pie": - p = self.ax.pie(c, colors=colors[v - 1], autopct=lambda pct: _getabs(pct, c), pctdistance=1.25) - self.ax.legend( - p[0], - v, - frameon=False, - title="k=", - handlelength=1, - handletextpad=0.5, - bbox_to_anchor=(1.0, 1.0), - loc=2, - borderaxespad=0.5, - labelspacing=0.15, - fontsize=self._fs, - ) - elif self._type == "bar": - self.ax.bar(v, c, color=colors[v - 1]) - x_labels = [f"k={vi}" for vi in v] - self.ax.spines["top"].set_visible(False) - self.ax.spines["right"].set_visible(False) - self.ax.set_xticks(v) - self.ax.set_xticklabels(x_labels) - else: - raise AttributeError("Wrong plot_type. It can be 'pie' or 'bar'") - self.ax.set_title("Units agreed upon\nby k sorters") - - -class MultiCompAgreementBySorterWidget(BaseWidget): - """ - Plots multi comparison agreement as pie or bar plot. - - Parameters - ---------- - multi_comparison: BaseMultiComparison - The multi comparison object - plot_type: str - 'pie' or 'bar' - cmap: matplotlib colormap - The colormap to be used for the nodes (default 'Reds') - axes: list of matplotlib axes - The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax - and figure parameters are ignored. - show_legend: bool - Show the legend in the last axes (default True). - - Returns - ------- - W: MultiCompGraphWidget - The output widget - """ - - def __init__(self, multi_comparison, plot_type="pie", cmap="YlOrRd", fs=9, axes=None, show_legend=True): - import matplotlib.pyplot as plt - - self._msc = multi_comparison - self._type = plot_type - self._cmap = cmap - self._fs = fs - self._show_legend = show_legend - self.name = "MultiCompAgreementBySorterWidget" - - if axes is None: - ncols = len(self._msc.name_list) - fig, axes = plt.subplots(nrows=1, ncols=ncols, sharex=True, sharey=True) - BaseWidget.__init__(self, None, None, axes) - - def plot(self): - self._do_plot() - - def _do_plot(self): - name_list = self._msc.name_list - import matplotlib.pyplot as plt - - cmap = plt.get_cmap(self._cmap) - colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(self._msc.name_list))]) - sg_names, sg_units = self._msc.compute_subgraphs() - # fraction of units with agreement > threshold - for i, name in enumerate(name_list): - ax = self.axes[i] - v, c = np.unique([len(np.unique(sn)) for sn in sg_names if name in sn], return_counts=True) - if self._type == "pie": - p = ax.pie( - c, - colors=colors[v - 1], - textprops={"color": "k", "fontsize": self._fs}, - autopct=lambda pct: _getabs(pct, c), - pctdistance=1.18, - ) - if (self._show_legend) and (i == len(name_list) - 1): - plt.legend( - p[0], - v, - frameon=False, - title="k=", - handlelength=1, - handletextpad=0.5, - bbox_to_anchor=(1.15, 1.25), - loc=2, - borderaxespad=0.0, - labelspacing=0.15, - ) - elif self._type == "bar": - ax.bar(v, c, color=colors[v - 1]) - x_labels = [f"k={vi}" for vi in v] - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.set_xticks(v) - ax.set_xticklabels(x_labels) - else: - raise AttributeError("Wrong plot_type. It can be 'pie' or 'bar'") - ax.set_title(name) - if self._type == "bar": - ylims = [np.max(ax_single.get_ylim()) for ax_single in self.axes] - max_yval = np.max(ylims) - for ax_single in self.axes: - ax_single.set_ylim([0, max_yval]) - - -def _getabs(pct, allvals): - absolute = int(np.round(pct / 100.0 * np.sum(allvals))) - return "{:d}".format(absolute) - - -def plot_multicomp_graph(*args, **kwargs): - W = MultiCompGraphWidget(*args, **kwargs) - W.plot() - return W - - -plot_multicomp_graph.__doc__ = MultiCompGraphWidget.__doc__ - - -def plot_multicomp_agreement(*args, **kwargs): - W = MultiCompGlobalAgreementWidget(*args, **kwargs) - W.plot() - return W - - -plot_multicomp_agreement.__doc__ = MultiCompGlobalAgreementWidget.__doc__ - - -def plot_multicomp_agreement_by_sorter(*args, **kwargs): - W = MultiCompAgreementBySorterWidget(*args, **kwargs) - W.plot() - return W - - -plot_multicomp_agreement_by_sorter.__doc__ = MultiCompAgreementBySorterWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_legacy_widgets_utils.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_legacy_widgets_utils.py deleted file mode 100644 index 98abe3f4fc..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_legacy_widgets_utils.py +++ /dev/null @@ -1,21 +0,0 @@ -if __name__ != "__main__": - import matplotlib - - matplotlib.use("Agg") - -from spikeinterface import download_dataset -import spikeinterface.extractors as se - -from spikeinterface.widgets.utils import get_unit_colors - - -def test_get_unit_colors(): - local_path = download_dataset(remote_path="mearec/mearec_test_10s.h5") - sorting = se.MEArecSortingExtractor(local_path) - - colors = get_unit_colors(sorting) - print(colors) - - -if __name__ == "__main__": - test_get_unit_colors() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py deleted file mode 100644 index fda2e75138..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ /dev/null @@ -1,128 +0,0 @@ -import unittest -import pytest -import sys -from pathlib import Path - -if __name__ != "__main__": - import matplotlib - - matplotlib.use("Agg") -import matplotlib.pyplot as plt - -from spikeinterface import extract_waveforms, load_waveforms, download_dataset -import spikeinterface.extractors as se -import spikeinterface.widgets as sw -import spikeinterface.comparison as sc -from spikeinterface.postprocessing import compute_spike_amplitudes -from spikeinterface.qualitymetrics import compute_quality_metrics - - -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "widgets" -else: - cache_folder = Path("cache_folder") / "widgets" - - -class TestWidgets(unittest.TestCase): - def setUp(self): - local_path = download_dataset(remote_path="mearec/mearec_test_10s.h5") - self._rec = se.MEArecRecordingExtractor(local_path) - - self._sorting = se.MEArecSortingExtractor(local_path) - - self.num_units = len(self._sorting.get_unit_ids()) - #  self._we = extract_waveforms(self._rec, self._sorting, './toy_example', load_if_exists=True) - if (cache_folder / "mearec_test_old_api").is_dir(): - self._we = load_waveforms(cache_folder / "mearec_test_old_api") - else: - self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test_old_api", sparse=False) - - self._amplitudes = compute_spike_amplitudes(self._we, peak_sign="neg", outputs="by_unit") - self._gt_comp = sc.compare_sorter_to_ground_truth(self._sorting, self._sorting) - - def tearDown(self): - pass - - # def test_plot_unit_probe_map(self): - # sw.plot_unit_probe_map(self._we, with_channel_ids=True) - # sw.plot_unit_probe_map(self._we, animated=True) - - # def test_plot_units_depth_vs_amplitude(self): - # sw.plot_units_depth_vs_amplitude(self._we) - - # def test_amplitudes_timeseries(self): - # sw.plot_amplitudes_timeseries(self._we) - # unit_ids = self._sorting.unit_ids[:4] - # sw.plot_amplitudes_timeseries(self._we, unit_ids=unit_ids) - - # def test_amplitudes_distribution(self): - # sw.plot_amplitudes_distribution(self._we) - - # def test_plot_unit_localization(self): - # sw.plot_unit_localization(self._we, with_channel_ids=True) - # sw.plot_unit_localization(self._we, method='monopolar_triangulation') - - # def test_autocorrelograms(self): - # unit_ids = self._sorting.unit_ids[:4] - # sw.plot_autocorrelograms(self._sorting, unit_ids=unit_ids, window_ms=500.0, bin_ms=20.0) - - # def test_crosscorrelogram(self): - # unit_ids = self._sorting.unit_ids[:4] - # sw.plot_crosscorrelograms(self._sorting, unit_ids=unit_ids, window_ms=500.0, bin_ms=20.0) - - # def test_isi_distribution(self): - # sw.plot_isi_distribution(self._sorting, bin_ms=5.0, window_ms=500.0) - # fig, axes = plt.subplots(self.num_units, 1) - # sw.plot_isi_distribution(self._sorting, axes=axes) - - # def test_plot_peak_activity_map(self): - # sw.plot_peak_activity_map(self._rec, with_channel_ids=True) - # sw.plot_peak_activity_map(self._rec, bin_duration_s=1.0) - - def test_multicomp_graph(self): - msc = sc.compare_multiple_sorters([self._sorting, self._sorting, self._sorting]) - sw.plot_multicomp_graph(msc, edge_cmap="viridis", node_cmap="rainbow", draw_labels=False) - sw.plot_multicomp_agreement(msc) - sw.plot_multicomp_agreement_by_sorter(msc) - fig, axes = plt.subplots(len(msc.object_list), 1) - sw.plot_multicomp_agreement_by_sorter(msc, axes=axes) - - # def test_sorting_performance(self): - # metrics = compute_quality_metrics(self._we, metric_names=["snr"]) - # sw.plot_sorting_performance(self._gt_comp, metrics, performance_name="accuracy", metric_name="snr") - - # ~ def test_plot_unit_summary(self): - # ~ unit_id = self._sorting.unit_ids[4] - # ~ sw.plot_unit_summary(self._we, unit_id) - - -if __name__ == "__main__": - # unittest.main() - - mytest = TestWidgets() - mytest.setUp() - - # ~ mytest.test_timeseries() - # ~ mytest.test_unitwaveforms() - # ~ mytest.test_plot_unit_waveform_density_map() - # mytest.test_unittemplates() - # ~ mytest.test_plot_unit_probe_map() - #  mytest.test_plot_units_depth_vs_amplitude() - # ~ mytest.test_amplitudes_timeseries() - # ~ mytest.test_amplitudes_distribution() - # ~ mytest.test_principal_component() - # ~ mytest.test_plot_unit_localization() - - # ~ mytest.test_autocorrelograms() - # ~ mytest.test_crosscorrelogram() - # ~ mytest.test_isi_distribution() - - # ~ mytest.test_plot_drift_over_time() - # ~ mytest.test_plot_peak_activity_map() - - # ~ mytest.test_multicomp_graph() - #  mytest.test_sorting_performance() - - # ~ mytest.test_plot_unit_summary() - - plt.show() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/utils.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/utils.py deleted file mode 100644 index 6872a8b27b..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/utils.py +++ /dev/null @@ -1,37 +0,0 @@ -import matplotlib.pyplot as plt - -import random - -try: - import distinctipy - - HAVE_DISTINCTIPY = True -except ImportError: - HAVE_DISTINCTIPY = False - - -def get_unit_colors(sorting, map_name="gist_ncar", format="RGBA", shuffle=False): - """ - Return a dict colors per units. - """ - possible_formats = ("RGBA",) - assert format in possible_formats, f"format must be {possible_formats}" - - unit_ids = sorting.unit_ids - - if HAVE_DISTINCTIPY: - colors = distinctipy.get_colors(unit_ids.size) - # add the alpha - colors = [color + (1.0,) for color in colors] - else: - # some map have black or white at border so +10 - margin = max(4, len(unit_ids) // 20) // 2 - cmap = plt.get_cmap(map_name, len(unit_ids) + 2 * margin) - - colors = [cmap(i + margin) for i, unit_id in enumerate(unit_ids)] - if shuffle: - random.shuffle(colors) - - dict_colors = dict(zip(unit_ids, colors)) - - return dict_colors diff --git a/src/spikeinterface/widgets/multicomparison_agreement.py b/src/spikeinterface/widgets/multicomparison_agreement.py new file mode 100644 index 0000000000..41a4555aba --- /dev/null +++ b/src/spikeinterface/widgets/multicomparison_agreement.py @@ -0,0 +1,188 @@ +import numpy as np + +from .base import BaseWidget, to_attr + + +class MultiCompGlobalAgreementWidget(BaseWidget): + """ + Plots multi comparison agreement as pie or bar plot. + + Parameters + ---------- + multi_comparison: BaseMultiComparison + The multi comparison object + plot_type: str + 'pie' or 'bar' + cmap: matplotlib colormap, default: 'YlOrRd' + The colormap to be used for the nodes + fontsize: int, default: 9 + The text fontsize + show_legend: bool, default: True + If True a legend is shown + """ + + def __init__( + self, + multi_comparison, + plot_type="pie", + cmap="YlOrRd", + fontsize=9, + show_legend=True, + backend=None, + **backend_kwargs, + ): + plot_data = dict( + multi_comparison=multi_comparison, + plot_type=plot_type, + cmap=cmap, + fontsize=fontsize, + show_legend=show_legend, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + mcmp = dp.multi_comparison + cmap = plt.get_cmap(dp.cmap) + colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) + sg_names, sg_units = mcmp.compute_subgraphs() + # fraction of units with agreement > threshold + v, c = np.unique([len(np.unique(s)) for s in sg_names], return_counts=True) + if dp.plot_type == "pie": + p = self.ax.pie(c, colors=colors[v - 1], autopct=lambda pct: _getabs(pct, c), pctdistance=1.25) + self.ax.legend( + p[0], + v, + frameon=False, + title="k=", + handlelength=1, + handletextpad=0.5, + bbox_to_anchor=(1.0, 1.0), + loc=2, + borderaxespad=0.5, + labelspacing=0.15, + fontsize=dp.fontsize, + ) + elif dp.plot_type == "bar": + self.ax.bar(v, c, color=colors[v - 1]) + x_labels = [f"k={vi}" for vi in v] + self.ax.spines["top"].set_visible(False) + self.ax.spines["right"].set_visible(False) + self.ax.set_xticks(v) + self.ax.set_xticklabels(x_labels) + else: + raise AttributeError("Wrong plot_type. It can be 'pie' or 'bar'") + self.ax.set_title("Units agreed upon\nby k sorters") + + +class MultiCompAgreementBySorterWidget(BaseWidget): + """ + Plots multi comparison agreement as pie or bar plot. + + Parameters + ---------- + multi_comparison: BaseMultiComparison + The multi comparison object + plot_type: str + 'pie' or 'bar' + cmap: matplotlib colormap + The colormap to be used for the nodes (default 'Reds') + axes: list of matplotlib axes + The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax + and figure parameters are ignored. + show_legend: bool + Show the legend in the last axes (default True). + + Returns + ------- + W: MultiCompGraphWidget + The output widget + """ + + def __init__( + self, + multi_comparison, + plot_type="pie", + cmap="YlOrRd", + fontsize=9, + show_legend=True, + backend=None, + **backend_kwargs, + ): + plot_data = dict( + multi_comparison=multi_comparison, + plot_type=plot_type, + cmap=cmap, + fontsize=fontsize, + show_legend=show_legend, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.colors as mpl_colors + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + mcmp = dp.multi_comparison + name_list = mcmp.name_list + + backend_kwargs["num_axes"] = len(name_list) + backend_kwargs["ncols"] = len(name_list) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + cmap = plt.get_cmap(dp.cmap) + colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) + sg_names, sg_units = mcmp.compute_subgraphs() + # fraction of units with agreement > threshold + for i, name in enumerate(name_list): + ax = np.squeeze(self.axes)[i] + v, c = np.unique([len(np.unique(sn)) for sn in sg_names if name in sn], return_counts=True) + if dp.plot_type == "pie": + p = ax.pie( + c, + colors=colors[v - 1], + textprops={"color": "k", "fontsize": dp.fontsize}, + autopct=lambda pct: _getabs(pct, c), + pctdistance=1.18, + ) + if (dp.show_legend) and (i == len(name_list) - 1): + plt.legend( + p[0], + v, + frameon=False, + title="k=", + handlelength=1, + handletextpad=0.5, + bbox_to_anchor=(1.15, 1.25), + loc=2, + borderaxespad=0.0, + labelspacing=0.15, + ) + elif dp.plot_type == "bar": + ax.bar(v, c, color=colors[v - 1]) + x_labels = [f"k={vi}" for vi in v] + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.set_xticks(v) + ax.set_xticklabels(x_labels) + else: + raise AttributeError("Wrong plot_type. It can be 'pie' or 'bar'") + ax.set_title(name) + + if dp.plot_type == "bar": + ylims = [np.max(ax_single.get_ylim()) for ax_single in self.axes] + max_yval = np.max(ylims) + for ax_single in self.axes: + ax_single.set_ylim([0, max_yval]) + + +def _getabs(pct, allvals): + absolute = int(np.round(pct / 100.0 * np.sum(allvals))) + return f"{absolute}" diff --git a/src/spikeinterface/widgets/multicomparison_graph.py b/src/spikeinterface/widgets/multicomparison_graph.py index e69de29bb2..5de1011194 100644 --- a/src/spikeinterface/widgets/multicomparison_graph.py +++ b/src/spikeinterface/widgets/multicomparison_graph.py @@ -0,0 +1,111 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors + + +class MultiCompGraphWidget(BaseWidget): + """ + Plots multi comparison graph. + + Parameters + ---------- + multi_comparison: BaseMultiComparison + The multi comparison object + draw_labels: bool + If True unit labels are shown + node_cmap: matplotlib colormap + The colormap to be used for the nodes (default 'viridis') + edge_cmap: matplotlib colormap + The colormap to be used for the edges (default 'hot') + alpha_edges: float + Alpha value for edges + colorbar: bool + If True a colorbar for the edges is plotted + """ + + def __init__( + self, + multi_comparison, + draw_labels=False, + node_cmap="viridis", + edge_cmap="hot", + alpha_edges=0.5, + colorbar=False, + backend=None, + **backend_kwargs, + ): + plot_data = dict( + multi_comparison=multi_comparison, + draw_labels=draw_labels, + node_cmap=node_cmap, + edge_cmap=edge_cmap, + alpha_edges=alpha_edges, + colorbar=colorbar, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.colors as mpl_colors + import matplotlib.pyplot as plt + import networkx as nx + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + mcmp = dp.multi_comparison + g = mcmp.graph + edge_col = [] + for e in g.edges(data=True): + n1, n2, d = e + edge_col.append(d["weight"]) + nodes_col_dict = {} + for i, sort_name in enumerate(mcmp.name_list): + nodes_col_dict[sort_name] = i + nodes_col = [] + for node in sorted(g.nodes): + nodes_col.append(nodes_col_dict[node[0]]) + nodes_col = np.array(nodes_col) / len(mcmp.name_list) + + _ = plt.set_cmap(dp.node_cmap) + _ = nx.draw_networkx_nodes( + g, + pos=nx.circular_layout(sorted(g)), + nodelist=sorted(g.nodes), + node_color=nodes_col, + node_size=20, + ax=self.ax, + ) + _ = nx.draw_networkx_edges( + g, + pos=nx.circular_layout((sorted(g))), + nodelist=sorted(g.nodes), + edge_color=edge_col, + alpha=dp.alpha_edges, + edge_cmap=plt.cm.get_cmap(dp.edge_cmap), + edge_vmin=mcmp.match_score, + edge_vmax=1, + ax=self.ax, + ) + if dp.draw_labels: + labels = {key: f"{key[0]}_{key[1]}" for key in sorted(g.nodes)} + pos = nx.circular_layout(sorted(g)) + # extend position radially + pos_extended = {} + for node, pos in pos.items(): + pos_new = pos + 0.1 * pos + pos_extended[node] = pos_new + _ = nx.draw_networkx_labels(g, pos=pos_extended, labels=labels, ax=self.ax) + + if dp.colorbar: + import matplotlib.pyplot as plt + + norm = mpl_colors.Normalize(vmin=mcmp.match_score, vmax=1) + cmap = plt.cm.get_cmap(dp.edge_cmap) + m = plt.cm.ScalarMappable(norm=norm, cmap=cmap) + self.figure.colorbar(m) + + self.ax.axis("off") diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index efb429a52a..052497347d 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -207,7 +207,7 @@ def test_plot_unit_waveforms_density_map_sparsity_None_same_axis(self): **self.backend_kwargs[backend], ) - def test_autocorrelograms(self): + def test_plot_autocorrelograms(self): possible_backends = list(sw.AutoCorrelogramsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -221,7 +221,7 @@ def test_autocorrelograms(self): **self.backend_kwargs[backend], ) - def test_crosscorrelogram(self): + def test_plot_crosscorrelogram(self): possible_backends = list(sw.CrossCorrelogramsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -235,7 +235,7 @@ def test_crosscorrelogram(self): **self.backend_kwargs[backend], ) - def test_isi_distribution(self): + def test_plot_isi_distribution(self): possible_backends = list(sw.ISIDistributionWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -249,7 +249,7 @@ def test_isi_distribution(self): **self.backend_kwargs[backend], ) - def test_amplitudes(self): + def test_plot_amplitudes(self): possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -283,7 +283,7 @@ def test_plot_all_amplitudes_distributions(self): self.we_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) - def test_unit_locations(self): + def test_plot_unit_locations(self): possible_backends = list(sw.UnitLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -294,7 +294,7 @@ def test_unit_locations(self): self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) - def test_spike_locations(self): + def test_plot_spike_locations(self): possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -305,21 +305,21 @@ def test_spike_locations(self): self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) - def test_similarity(self): + def test_plot_similarity(self): possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_template_similarity(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_similarity(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) - def test_quality_metrics(self): + def test_plot_quality_metrics(self): possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_quality_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_quality_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) - def test_template_metrics(self): + def test_plot_template_metrics(self): possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -344,7 +344,7 @@ def test_plot_unit_summary(self): self.we_sparse, self.we_sparse.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) - def test_sorting_summary(self): + def test_plot_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: @@ -393,6 +393,23 @@ def test_plot_peak_activity(self): if backend not in self.skip_backends: sw.plot_peak_activity(self.recording, self.peaks) + def test_plot_multicomparison(self): + mcmp = sc.compare_multiple_sorters([self.sorting, self.sorting, self.sorting]) + possible_backends_graph = list(sw.MultiCompGraphWidget.get_possible_backends()) + for backend in possible_backends_graph: + sw.plot_multicomparison_graph( + mcmp, edge_cmap="viridis", node_cmap="rainbow", draw_labels=False, backend=backend + ) + possible_backends_glob = list(sw.MultiCompGlobalAgreementWidget.get_possible_backends()) + for backend in possible_backends_glob: + sw.plot_multicomparison_agreement(mcmp, backend=backend) + possible_backends_by_sorter = list(sw.MultiCompAgreementBySorterWidget.get_possible_backends()) + for backend in possible_backends_by_sorter: + sw.plot_multicomparison_agreement_by_sorter(mcmp) + if backend == "matplotlib": + _, axes = plt.subplots(len(mcmp.object_list), 1) + sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes) + if __name__ == "__main__": # unittest.main() @@ -422,6 +439,6 @@ def test_plot_peak_activity(self): # mytest.test_plot_rasters() # mytest.test_plot_unit_probe_map() # mytest.test_plot_unit_presence() - mytest.test_plot_peak_activity() + mytest.test_plot_multicomparison() plt.show() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 113be0c53b..21e00918d0 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -10,6 +10,8 @@ from .crosscorrelograms import CrossCorrelogramsWidget from .isi_distribution import ISIDistributionWidget from .motion import MotionWidget +from .multicomparison_agreement import MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget +from .multicomparison_graph import MultiCompGraphWidget from .peak_activity import PeakActivityMapWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget @@ -41,6 +43,9 @@ CrossCorrelogramsWidget, ISIDistributionWidget, MotionWidget, + MultiCompGlobalAgreementWidget, + MultiCompAgreementBySorterWidget, + MultiCompGraphWidget, PeakActivityMapWidget, ProbeMapWidget, QualityMetricsWidget, @@ -108,6 +113,9 @@ plot_crosscorrelograms = CrossCorrelogramsWidget plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget +plot_multicomparison_agreement = MultiCompGlobalAgreementWidget +plot_multicomparison_agreement_by_sorter = MultiCompAgreementBySorterWidget +plot_multicomparison_graph = MultiCompGraphWidget plot_peak_activity = PeakActivityMapWidget plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget From 4af4a03e6445f67044c715c3b182eac67c559743 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 24 Oct 2023 09:57:38 +0200 Subject: [PATCH 25/25] Only one multicomparison widget file --- ...arison_agreement.py => multicomparison.py} | 108 +++++++++++++++++ .../widgets/multicomparison_graph.py | 111 ------------------ src/spikeinterface/widgets/widget_list.py | 3 +- 3 files changed, 109 insertions(+), 113 deletions(-) rename src/spikeinterface/widgets/{multicomparison_agreement.py => multicomparison.py} (65%) delete mode 100644 src/spikeinterface/widgets/multicomparison_graph.py diff --git a/src/spikeinterface/widgets/multicomparison_agreement.py b/src/spikeinterface/widgets/multicomparison.py similarity index 65% rename from src/spikeinterface/widgets/multicomparison_agreement.py rename to src/spikeinterface/widgets/multicomparison.py index 41a4555aba..e01a79dfd5 100644 --- a/src/spikeinterface/widgets/multicomparison_agreement.py +++ b/src/spikeinterface/widgets/multicomparison.py @@ -1,6 +1,114 @@ import numpy as np +from warnings import warn from .base import BaseWidget, to_attr +from .utils import get_unit_colors + + +class MultiCompGraphWidget(BaseWidget): + """ + Plots multi comparison graph. + + Parameters + ---------- + multi_comparison: BaseMultiComparison + The multi comparison object + draw_labels: bool + If True unit labels are shown + node_cmap: matplotlib colormap + The colormap to be used for the nodes (default 'viridis') + edge_cmap: matplotlib colormap + The colormap to be used for the edges (default 'hot') + alpha_edges: float + Alpha value for edges + colorbar: bool + If True a colorbar for the edges is plotted + """ + + def __init__( + self, + multi_comparison, + draw_labels=False, + node_cmap="viridis", + edge_cmap="hot", + alpha_edges=0.5, + colorbar=False, + backend=None, + **backend_kwargs, + ): + plot_data = dict( + multi_comparison=multi_comparison, + draw_labels=draw_labels, + node_cmap=node_cmap, + edge_cmap=edge_cmap, + alpha_edges=alpha_edges, + colorbar=colorbar, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.colors as mpl_colors + import matplotlib.pyplot as plt + import networkx as nx + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + mcmp = dp.multi_comparison + g = mcmp.graph + edge_col = [] + for e in g.edges(data=True): + n1, n2, d = e + edge_col.append(d["weight"]) + nodes_col_dict = {} + for i, sort_name in enumerate(mcmp.name_list): + nodes_col_dict[sort_name] = i + nodes_col = [] + for node in sorted(g.nodes): + nodes_col.append(nodes_col_dict[node[0]]) + nodes_col = np.array(nodes_col) / len(mcmp.name_list) + + _ = plt.set_cmap(dp.node_cmap) + _ = nx.draw_networkx_nodes( + g, + pos=nx.circular_layout(sorted(g)), + nodelist=sorted(g.nodes), + node_color=nodes_col, + node_size=20, + ax=self.ax, + ) + _ = nx.draw_networkx_edges( + g, + pos=nx.circular_layout((sorted(g))), + nodelist=sorted(g.nodes), + edge_color=edge_col, + alpha=dp.alpha_edges, + edge_cmap=plt.cm.get_cmap(dp.edge_cmap), + edge_vmin=mcmp.match_score, + edge_vmax=1, + ax=self.ax, + ) + if dp.draw_labels: + labels = {key: f"{key[0]}_{key[1]}" for key in sorted(g.nodes)} + pos = nx.circular_layout(sorted(g)) + # extend position radially + pos_extended = {} + for node, pos in pos.items(): + pos_new = pos + 0.1 * pos + pos_extended[node] = pos_new + _ = nx.draw_networkx_labels(g, pos=pos_extended, labels=labels, ax=self.ax) + + if dp.colorbar: + import matplotlib.pyplot as plt + + norm = mpl_colors.Normalize(vmin=mcmp.match_score, vmax=1) + cmap = plt.cm.get_cmap(dp.edge_cmap) + m = plt.cm.ScalarMappable(norm=norm, cmap=cmap) + self.figure.colorbar(m) + + self.ax.axis("off") class MultiCompGlobalAgreementWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/multicomparison_graph.py b/src/spikeinterface/widgets/multicomparison_graph.py deleted file mode 100644 index 5de1011194..0000000000 --- a/src/spikeinterface/widgets/multicomparison_graph.py +++ /dev/null @@ -1,111 +0,0 @@ -import numpy as np -from warnings import warn - -from .base import BaseWidget, to_attr -from .utils import get_unit_colors - - -class MultiCompGraphWidget(BaseWidget): - """ - Plots multi comparison graph. - - Parameters - ---------- - multi_comparison: BaseMultiComparison - The multi comparison object - draw_labels: bool - If True unit labels are shown - node_cmap: matplotlib colormap - The colormap to be used for the nodes (default 'viridis') - edge_cmap: matplotlib colormap - The colormap to be used for the edges (default 'hot') - alpha_edges: float - Alpha value for edges - colorbar: bool - If True a colorbar for the edges is plotted - """ - - def __init__( - self, - multi_comparison, - draw_labels=False, - node_cmap="viridis", - edge_cmap="hot", - alpha_edges=0.5, - colorbar=False, - backend=None, - **backend_kwargs, - ): - plot_data = dict( - multi_comparison=multi_comparison, - draw_labels=draw_labels, - node_cmap=node_cmap, - edge_cmap=edge_cmap, - alpha_edges=alpha_edges, - colorbar=colorbar, - ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - - def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.colors as mpl_colors - import matplotlib.pyplot as plt - import networkx as nx - from .utils_matplotlib import make_mpl_figure - - dp = to_attr(data_plot) - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - mcmp = dp.multi_comparison - g = mcmp.graph - edge_col = [] - for e in g.edges(data=True): - n1, n2, d = e - edge_col.append(d["weight"]) - nodes_col_dict = {} - for i, sort_name in enumerate(mcmp.name_list): - nodes_col_dict[sort_name] = i - nodes_col = [] - for node in sorted(g.nodes): - nodes_col.append(nodes_col_dict[node[0]]) - nodes_col = np.array(nodes_col) / len(mcmp.name_list) - - _ = plt.set_cmap(dp.node_cmap) - _ = nx.draw_networkx_nodes( - g, - pos=nx.circular_layout(sorted(g)), - nodelist=sorted(g.nodes), - node_color=nodes_col, - node_size=20, - ax=self.ax, - ) - _ = nx.draw_networkx_edges( - g, - pos=nx.circular_layout((sorted(g))), - nodelist=sorted(g.nodes), - edge_color=edge_col, - alpha=dp.alpha_edges, - edge_cmap=plt.cm.get_cmap(dp.edge_cmap), - edge_vmin=mcmp.match_score, - edge_vmax=1, - ax=self.ax, - ) - if dp.draw_labels: - labels = {key: f"{key[0]}_{key[1]}" for key in sorted(g.nodes)} - pos = nx.circular_layout(sorted(g)) - # extend position radially - pos_extended = {} - for node, pos in pos.items(): - pos_new = pos + 0.1 * pos - pos_extended[node] = pos_new - _ = nx.draw_networkx_labels(g, pos=pos_extended, labels=labels, ax=self.ax) - - if dp.colorbar: - import matplotlib.pyplot as plt - - norm = mpl_colors.Normalize(vmin=mcmp.match_score, vmax=1) - cmap = plt.cm.get_cmap(dp.edge_cmap) - m = plt.cm.ScalarMappable(norm=norm, cmap=cmap) - self.figure.colorbar(m) - - self.ax.axis("off") diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 21e00918d0..00d179127d 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -10,8 +10,7 @@ from .crosscorrelograms import CrossCorrelogramsWidget from .isi_distribution import ISIDistributionWidget from .motion import MotionWidget -from .multicomparison_agreement import MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget -from .multicomparison_graph import MultiCompGraphWidget +from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget from .peak_activity import PeakActivityMapWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget